Skip to content

Commit

Permalink
Index pretransform support in search_preassigned (facebookresearch#3225)
Browse files Browse the repository at this point in the history
Summary:

This diff fixes issue [facebookresearch#3113](facebookresearch#3113), e.g. introduces support for index pretransform in `search_preassigned`.

Reviewed By: mdouze

Differential Revision: D53188584
  • Loading branch information
mlomeli1 authored and facebook-github-bot committed Jan 30, 2024
1 parent 2817344 commit 640e86d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
5 changes: 5 additions & 0 deletions contrib/ivf_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None):
Supports indexes with pretransforms (as opposed to the
IndexIVF.search_preassigned, that cannot be applied with pretransform).
"""
if isinstance(index_ivf, faiss.IndexPreTransform):
assert index_ivf.chain.size() == 1, "chain must have only one component"
transform = faiss.downcast_VectorTransform(index_ivf.chain.at(0))
xq = transform.apply(xq)
index_ivf = faiss.downcast_index(index_ivf.index)
n, d = xq.shape
if isinstance(index_ivf, faiss.IndexBinaryIVF):
d *= 8
Expand Down
20 changes: 20 additions & 0 deletions tests/test_contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,26 @@ def test_PR_multiple(self):

class TestPreassigned(unittest.TestCase):

def test_index_pretransformed(self):

ds = datasets.SyntheticDataset(128, 2000, 2000, 200)
xt = ds.get_train()
xq = ds.get_queries()
xb = ds.get_database()
index = faiss.index_factory(128, 'PCA64,IVF64,PQ4np')
index.train(xt)
index.add(xb)
index_downcasted = faiss.extract_index_ivf(index)
index_downcasted.nprobe = 10
xq_trans = index.chain.at(0).apply_py(xq)
D_ref, I_ref = index.search(xq, 4)

quantizer = index_downcasted.quantizer
Dq, Iq = quantizer.search(xq_trans, index_downcasted.nprobe)
D, I = ivf_tools.search_preassigned(index, xq, 4, Iq, Dq)
np.testing.assert_array_equal(D_ref, D)
np.testing.assert_array_equal(I_ref, I)

def test_float(self):
ds = datasets.SyntheticDataset(128, 2000, 2000, 200)

Expand Down

0 comments on commit 640e86d

Please sign in to comment.