dak ブログ

python、rubyなどのプログラミング、MySQL、サーバーの設定などの備忘録。レゴの写真も。

nmslib のパラメータ

2021-03-19 23:53:43 | python
ベクトル検索の nmslib のパラメータについてのメモ。

ベクトル間の cos 類似度が高い順に検索するには、以下のように space='cosinesimil' を指定して初期化します。
index = nmslib.init(space='cosinesimil') 


■サンプルプログラム(cos 類似度版)
import numpy as np
import nmslib

index = nmslib.init(method='hnsw', space='cosinesimil')

vs = [
    [1.0, 2.0],
    [2.0, 2.0],
    [3.0, 3.0],
]

# インデックス生成
vs = np.array(vs, dtype=np.float32)
index.addDataPointBatch(vs)
index.createIndex({}, print_progress=False)

# 検索
v = np.array([1.0, 1.0], dtype=np.float32)
ids, dists = index.knnQuery(v, 10)

print('ids: %s' % (ids))
print('dists: %s' % (sims))

実行結果
ids: [2 1 0]
dists: [0.0000000e+00 5.9604645e-08 5.1316738e-02]

※cos類似度=1の場合に距離=0

一方、ベクトル間の距離が短い順に検索するには、space='l2' を指定して初期化します。
index = nmslib.init(space='l2')


■サンプルプログラム(距離版)
import numpy as np
import nmslib

index = nmslib.init(method='hnsw', space='l2')

vs = [
    [1.0, 2.0],
    [2.0, 2.0],
    [3.0, 3.0],
]

# インデックス生成
vs = np.array(vs, dtype=np.float32)
index.addDataPointBatch(vs)
index.createIndex({}, print_progress=False)

# 検索
v = np.array([1.0, 1.0], dtype=np.float32)
ids, dists = index.knnQuery(v, 10)

print('ids: %s' % (ids))
print('dists: %s' % (dists))

実行結果
ids: [0 1 2]
dists: [1. 2. 8.]

※距離は2乗和となっている

その他に、method='hnsw' の場合、createIndex() メソッドの第1引数のパラメータに
'post' を指定すると後処理の制御ができます。
0: 後処理なし、1, 2: 後処理あり(2 の方が強力)