Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add different metric type and initialization methods to nanopq #24

Merged
merged 2 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion nanopq/convert_faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from .opq import OPQ
from .pq import PQ

# Lookup from nanopq metric names to the faiss metric constants handed to
# faiss.IndexPQ. 'dot' and 'angular' both map to faiss's inner-product metric.
# NOTE(review): this is evaluated at import time, so `faiss` must already be
# bound when the module loads -- confirm against the import guard above.
faiss_metric_map = {
    "l2": faiss.METRIC_L2,
    "dot": faiss.METRIC_INNER_PRODUCT,
    "angular": faiss.METRIC_INNER_PRODUCT,
}

def nanopq_to_faiss(pq_nanopq):
"""Convert a :class:`nanopq.PQ` instance to `faiss.IndexPQ </~https://github.com/facebookresearch/faiss/blob/master/IndexPQ.h>`_.
Expand All @@ -31,7 +36,7 @@ def nanopq_to_faiss(pq_nanopq):
D = pq_nanopq.Ds * pq_nanopq.M
nbits = {np.uint8: 8, np.uint16: 16, np.uint32: 32}[pq_nanopq.code_dtype]

pq_faiss = faiss.IndexPQ(D, pq_nanopq.M, nbits)
pq_faiss = faiss.IndexPQ(D, pq_nanopq.M, nbits, faiss_metric_map[pq_nanopq.metric])

for m in range(pq_nanopq.M):
# Prepare std::vector<float>
Expand Down
4 changes: 2 additions & 2 deletions nanopq/opq.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class OPQ(object):

"""

def __init__(self, M, Ks=256, verbose=True):
self.pq = PQ(M, Ks, verbose)
def __init__(self, M, Ks=256, metric='l2', minit='random', verbose=True):
    """Create the wrapped PQ instance and leave ``R`` unset.

    Args:
        M: Forwarded to :class:`PQ` as its first positional argument.
        Ks: Forwarded to :class:`PQ` as its second positional argument.
        metric (str): Forwarded to :class:`PQ` as ``metric``.
        minit (str): Forwarded to :class:`PQ` as ``minit``.
        verbose (bool): Forwarded to :class:`PQ` as ``verbose``.

    """
    # Everything except M and Ks is passed by keyword so the call does not
    # depend on the order of PQ's optional parameters.
    pq_options = {"metric": metric, "minit": minit, "verbose": verbose}
    self.pq = PQ(M, Ks, **pq_options)
    # R stays None until something else assigns it.
    self.R = None

def __eq__(self, other):
Expand Down
56 changes: 45 additions & 11 deletions nanopq/pq.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,27 @@
import warnings
import numpy as np
from scipy.cluster.vq import kmeans2, vq


def dist_l2(q, x):
    """Return the squared Euclidean distance from ``q`` to every row of ``x``.

    Args:
        q (np.ndarray): Query vector, broadcast against ``x``.
        x (np.ndarray): 2-D array; the norm is taken along axis 1, so each
            row yields one distance.

    Returns:
        np.ndarray: One squared L2 distance per row of ``x``.

    """
    row_norms = np.linalg.norm(q - x, ord=2, axis=1)
    return row_norms ** 2


def dist_ip(q, x):
    """Return the inner product between ``q`` and every row of ``x``.

    Args:
        q (np.ndarray): Query vector; it is turned into a single column
            before the matrix product.
        x (np.ndarray): Array whose rows are multiplied with ``q``.

    Returns:
        np.ndarray: The matrix product summed over its last axis, i.e. one
        value per row of ``x`` for a 1-D ``q`` and 2-D ``x``.

    """
    q_column = q[None, :].T
    products = np.matmul(x, q_column)
    return products.sum(axis=-1)


def dist_angular(q, x):
    """Return the table entries used for the 'angular' metric.

    This is exactly :func:`dist_ip`; no ``1 - value`` conversion is applied
    here.

    Args:
        q (np.ndarray): Query vector, passed through to :func:`dist_ip`.
        x (np.ndarray): Row vectors, passed through to :func:`dist_ip`.

    Returns:
        np.ndarray: Whatever :func:`dist_ip` returns for the same arguments.

    """
    similarity = dist_ip(q, x)
    return similarity


# Dispatch table: metric name -> function(q, x) that fills one row of the
# distance table.
metric_function_map = {
    "l2": dist_l2,
    "angular": dist_angular,
    "dot": dist_ip,
}


class PQ(object):
"""Pure python implementation of Product Quantization (PQ) [Jegou11]_.

Expand All @@ -19,12 +39,14 @@ class PQ(object):
M (int): The number of sub-space
Ks (int): The number of codewords for each subspace
(typically 256, so that each sub-vector is quantized
into 8 bits = 1 byte = uint8)
into 8 bits = 1 byte = uint8)
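metric (str): Type of metric used among vectors ('l2', 'dot' or 'angular')
minit (str): Initialization method handed to kmeans2 ('random', '++', 'points' or 'matrix')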
verbose (bool): Verbose flag

Attributes:
M (int): The number of sub-space
Ks (int): The number of codewords for each subspace
metric (str): Type of metric used among vectors
verbose (bool): Verbose flag
code_dtype (object): dtype of PQ-code. Either np.uint{8, 16, 32}
codewords (np.ndarray): shape=(M, Ks, Ds) with dtype=np.float32.
Expand All @@ -33,17 +55,22 @@ class PQ(object):

"""

def __init__(self, M, Ks=256, verbose=True):
def __init__(self, M, Ks=256, metric='l2', minit='random', verbose=True):
assert 0 < Ks <= 2 ** 32
self.M, self.Ks, self.verbose = M, Ks, verbose
assert metric in ['l2', 'dot', 'angular']
assert minit in ['random', '++', 'points', 'matrix']
self.M, self.Ks, self.verbose, self.metric = M, Ks, verbose, metric
self.code_dtype = (
np.uint8 if Ks <= 2 ** 8 else (np.uint16 if Ks <= 2 ** 16 else np.uint32)
)
self.codewords = None
self.Ds = None
self.metric = metric
self.minit = minit

if verbose:
print("M: {}, Ks: {}, code_dtype: {}".format(M, Ks, self.code_dtype))
print("M: {}, Ks: {}, metric : {}, code_dtype: {} minit: {}".format(
M, Ks, self.code_dtype, metric, minit))

def __eq__(self, other):
if isinstance(other, PQ):
Expand Down Expand Up @@ -88,9 +115,9 @@ def fit(self, vecs, iter=20, seed=123):
for m in range(self.M):
if self.verbose:
print("Training the subspace: {} / {}".format(m, self.M))
vecs_sub = vecs[:, m * self.Ds : (m + 1) * self.Ds]
self.codewords[m], _ = kmeans2(vecs_sub, self.Ks, iter=iter, minit="points")

vecs_sub = vecs[:, m * self.Ds: (m + 1) * self.Ds]
self.codewords[m], _ = kmeans2(
vecs_sub, self.Ks, iter=iter, minit=self.minit)
return self

def encode(self, vecs):
Expand Down Expand Up @@ -167,10 +194,11 @@ def dtable(self, query):
# dtable[m][ks] : distance between m-th subvec and ks-th codeword of m-th codewords
dtable = np.empty((self.M, self.Ks), dtype=np.float32)
for m in range(self.M):
query_sub = query[m * self.Ds : (m + 1) * self.Ds]
dtable[m, :] = np.linalg.norm(self.codewords[m] - query_sub, axis=1) ** 2
query_sub = query[m * self.Ds: (m + 1) * self.Ds]
dtable[m, :] = metric_function_map[self.metric](
query_sub, self.codewords[m])

return DistanceTable(dtable)
return DistanceTable(dtable, D=D, metric=self.metric)


class DistanceTable(object):
Expand All @@ -183,6 +211,7 @@ class DistanceTable(object):
Args:
dtable (np.ndarray): Distance table with shape=(M, Ks) and dtype=np.float32
computed by :func:`PQ.dtable` or :func:`OPQ.dtable`
metric (str): metric type to calculate distance

Attributes:
dtable (np.ndarray): Distance table with shape=(M, Ks) and dtype=np.float32.
Expand All @@ -191,10 +220,13 @@ class DistanceTable(object):

"""

def __init__(self, dtable):
def __init__(self, dtable, D, metric='l2'):
assert dtable.ndim == 2
assert dtable.dtype == np.float32
assert metric in ['l2', 'dot', 'angular']
self.dtable = dtable
self.metric = metric
self.D = D

def adist(self, codes):
"""Given PQ-codes, compute Asymmetric Distances between the query (self.dtable)
Expand All @@ -215,6 +247,8 @@ def adist(self, codes):

# Fetch distance values using codes. The following codes are
dists = np.sum(self.dtable[range(M), codes], axis=1)
if self.metric == 'angular':
dists = 1 - dists

# The above line is equivalent to the followings:
# dists = np.zeros((N, )).astype(np.float32)
Expand Down
34 changes: 34 additions & 0 deletions tests/test_pq.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,40 @@ def test_pickle(self):
)
self.assertTrue(np.allclose(pq.codewords, pq2.codewords))
self.assertTrue(pq == pq2)

def test_ip(self):
# Check that asymmetric distances computed with metric='dot' equal a
# hand-built inner-product table lookup, and that their signed difference
# from the exact inner products averages out to less than 1e-7.
N, D, M, Ks = 100, 12, 4, 10
X = np.random.random((N, D)).astype(np.float32)
pq = nanopq.PQ(M=M, Ks=Ks, metric='dot')
pq.fit(X)
X_ = pq.encode(X)
# One of the training vectors doubles as the query.
q = X[13]
dist1 = pq.dtable(q).adist(X_)
# Rebuild the table by hand: for each subspace, multiply that subspace's
# codewords with the matching slice of the query.
dtable = np.empty((pq.M, pq.Ks), dtype=np.float32)
for m in range(pq.M):
query_sub = q[m * pq.Ds : (m + 1) * pq.Ds]
dtable[m, :] = np.matmul(pq.codewords[m], query_sub[None, :].T).sum(axis=-1)
# Pick one table entry per subspace for every code and sum them.
dist2 = np.sum(dtable[range(M), X_], axis=1)
# Exact equality is required between the library result and the manual one.
self.assertTrue((dist1 == dist2).all())
# Compare against the exact inner products of the raw vectors with the query.
self.assertTrue(abs(np.mean(np.matmul(X, q[:, None]).squeeze() - dist1)) < 1e-7)

def test_angular(self):
# Check that asymmetric distances computed with metric='angular' equal
# 1 minus a hand-built inner-product table lookup, and that their signed
# difference from 1 minus the exact cosine averages out to less than 1e-7.
N, D, M, Ks = 100, 12, 4, 10
X = np.random.random((N, D)).astype(np.float32)
# Replace all-zero rows with a constant unit-norm vector so the
# normalisation below never divides by zero.
X[np.linalg.norm(X, axis=1) == 0] = 1.0 / np.sqrt(X.shape[1])
# Scale every row to unit L2 norm.
X /= np.linalg.norm(X, ord=2, axis=-1)[:, None]
pq = nanopq.PQ(M=M, Ks=Ks, metric='angular')
pq.fit(X)
X_ = pq.encode(X)
# One of the training vectors doubles as the query.
q = X[13]
dist1 = pq.dtable(q).adist(X_)
# Rebuild the table by hand: for each subspace, multiply that subspace's
# codewords with the matching slice of the query.
dtable = np.empty((pq.M, pq.Ks), dtype=np.float32)
for m in range(pq.M):
query_sub = q[m * pq.Ds : (m + 1) * pq.Ds]
dtable[m, :] = np.matmul(pq.codewords[m], query_sub[None, :].T).sum(axis=-1)
# Expected value is 1 minus the summed table entries selected by each code.
dist2 = 1 - np.sum(dtable[range(M), X_], axis=1)
# Exact equality is required between the library result and the manual one.
self.assertTrue((dist1 == dist2).all())
# Compare against 1 minus the cosine computed from the raw vectors.
self.assertTrue(abs(np.mean((1-np.matmul(X, q[:, None]) / (np.linalg.norm(q) * np.linalg.norm(X, ord=2, axis=-1))) - dist1)) < 1e-7)


if __name__ == "__main__":
Expand Down