From f28434bd57c74154099d36bd42f2a0b6d11d111f Mon Sep 17 00:00:00 2001 From: Maria Lomeli Date: Mon, 29 Jan 2024 10:23:06 -0800 Subject: [PATCH] Index pretransform support in search_preassigned (#3225) Summary: This diff fixes issue [#3113](https://github.com/facebookresearch/faiss/issues/3113), e.g. introduces support for index pretransform in `search_preassigned`. Reviewed By: mdouze Differential Revision: D53188584 --- contrib/ivf_tools.py | 5 +++++ tests/test_contrib.py | 20 ++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/contrib/ivf_tools.py b/contrib/ivf_tools.py index 26ada886a1..1c10eb0386 100644 --- a/contrib/ivf_tools.py +++ b/contrib/ivf_tools.py @@ -32,6 +32,11 @@ def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None): Supports indexes with pretransforms (as opposed to the IndexIVF.search_preassigned, that cannot be applied with pretransform). """ + if isinstance(index_ivf, faiss.IndexPreTransform): + assert index_ivf.chain.size() == 1, "chain must have only one component" + transform = faiss.downcast_VectorTransform(index_ivf.chain.at(0)) + xq = transform.apply(xq) + index_ivf = faiss.downcast_index(index_ivf.index) n, d = xq.shape if isinstance(index_ivf, faiss.IndexBinaryIVF): d *= 8 diff --git a/tests/test_contrib.py b/tests/test_contrib.py index 36c17792ce..f3411163de 100644 --- a/tests/test_contrib.py +++ b/tests/test_contrib.py @@ -306,6 +306,26 @@ def test_PR_multiple(self): class TestPreassigned(unittest.TestCase): + def test_index_pretransformed(self): + + ds = datasets.SyntheticDataset(128, 2000, 2000, 200) + xt = ds.get_train() + xq = ds.get_queries() + xb = ds.get_database() + index = faiss.index_factory(128, 'PCA64,IVF64,PQ4np') + index.train(xt) + index.add(xb) + index_downcasted = faiss.extract_index_ivf(index) + index_downcasted.nprobe = 10 + xq_trans = index.chain.at(0).apply_py(xq) + D_ref, I_ref = index.search(xq, 4) + + quantizer = index_downcasted.quantizer + Dq, Iq = quantizer.search(xq_trans, index_downcasted.nprobe) + D, I = ivf_tools.search_preassigned(index, xq, 4, Iq, Dq) + np.testing.assert_array_equal(D_ref, D) + np.testing.assert_array_equal(I_ref, I) + def test_float(self): ds = datasets.SyntheticDataset(128, 2000, 2000, 200)