-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathannoy.py
34 lines (24 loc) · 987 Bytes
/
annoy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from .base import BaseANN
import annoy
class AnnoyANN(BaseANN):
    """BaseANN adapter around an ``annoy.AnnoyIndex``.

    The index is built from scratch in :meth:`add`, persisted with
    :meth:`write` and restored with :meth:`read`. Item ids are the row
    positions of the vectors passed to :meth:`add`.
    """

    def __init__(self, metric="euclidean"):
        """Create an empty wrapper.

        Args:
            metric: Distance metric handed to ``annoy.AnnoyIndex``. Defaults
                to ``"euclidean"`` (the value previously hard-coded). The same
                value is used both when building (``add``) and when loading
                (``read``), so the two can no longer drift apart.
        """
        self.n_trees, self.index = None, None
        self.metric = metric

    def set_index_param(self, param):
        """Store the build-time parameter ``param["n_trees"]``."""
        self.n_trees = param["n_trees"]

    def has_train(self):
        """Return False: there is no separate training step; ``add`` builds the index."""
        return False

    def add(self, vecs):
        """Build a new index over ``vecs``, replacing any existing one.

        Args:
            vecs: 2-D array-like (assumed a numpy array of shape (N, D));
                ``vecs.shape[1]`` gives the dimension and each row is
                converted with ``.tolist()``.

        Row ``n`` of ``vecs`` is stored under item id ``n``. The forest is
        built with ``self.n_trees`` trees, so ``set_index_param`` is expected
        to have been called first.
        """
        self.index = annoy.AnnoyIndex(f=vecs.shape[1], metric=self.metric)
        for n, vec in enumerate(vecs):
            self.index.add_item(n, vec.tolist())
        self.index.build(self.n_trees)

    def query(self, vecs, topk, param):
        """Return neighbour ids for each row of ``vecs``.

        Args:
            vecs: Query vectors; each row is converted with ``.tolist()``.
            topk: Number of neighbours requested per query (passed as ``n``).
            param: Search parameters. ``param["search_k"]`` is forwarded to
                annoy; when the key is absent, -1 (annoy's own default for
                ``search_k``) is used instead of raising ``KeyError``.

        Returns:
            A list with one list of item ids per query vector.
        """
        # Resolve once instead of re-reading the mapping for every query row.
        search_k = param.get("search_k", -1)
        return [
            self.index.get_nns_by_vector(vector=vec.tolist(), n=topk, search_k=search_k)
            for vec in vecs
        ]

    def write(self, path):
        """Save the current index to ``path`` via ``AnnoyIndex.save``."""
        self.index.save(path)

    def read(self, path, D):
        """Load a ``D``-dimensional index from ``path``.

        The index is opened with ``self.metric`` (the same metric ``add``
        uses) and loaded with ``prefault=True``. ``self.n_trees`` is then
        refreshed from the loaded index.
        """
        self.index = annoy.AnnoyIndex(f=D, metric=self.metric)
        self.index.load(path, prefault=True)
        self.n_trees = self.index.get_n_trees()

    def stringify_index_param(self, param):
        """Return the index file name for ``param``, e.g. ``"ntrees10.bin"``."""
        return f"ntrees{param['n_trees']}.bin"