diff --git a/contrib/ivf_tools.py b/contrib/ivf_tools.py index 26ada886a1..1c10eb0386 100644 --- a/contrib/ivf_tools.py +++ b/contrib/ivf_tools.py @@ -32,6 +32,11 @@ def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None): Supports indexes with pretransforms (as opposed to the IndexIVF.search_preassigned, that cannot be applied with pretransform). """ + if isinstance(index_ivf, faiss.IndexPreTransform): + assert index_ivf.chain.size() == 1, "chain must have only one component" + transform = faiss.downcast_VectorTransform(index_ivf.chain.at(0)) + xq = transform.apply(xq) + index_ivf = faiss.downcast_index(index_ivf.index) n, d = xq.shape if isinstance(index_ivf, faiss.IndexBinaryIVF): d *= 8 diff --git a/tests/test_contrib.py b/tests/test_contrib.py index 36c17792ce..f3411163de 100644 --- a/tests/test_contrib.py +++ b/tests/test_contrib.py @@ -306,6 +306,26 @@ def test_PR_multiple(self): class TestPreassigned(unittest.TestCase): + def test_index_pretransformed(self): + + ds = datasets.SyntheticDataset(128, 2000, 2000, 200) + xt = ds.get_train() + xq = ds.get_queries() + xb = ds.get_database() + index = faiss.index_factory(128, 'PCA64,IVF64,PQ4np') + index.train(xt) + index.add(xb) + index_downcasted = faiss.extract_index_ivf(index) + index_downcasted.nprobe = 10 + xq_trans = index.chain.at(0).apply_py(xq) + D_ref, I_ref = index.search(xq, 4) + + quantizer = index_downcasted.quantizer + Dq, Iq = quantizer.search(xq_trans, index_downcasted.nprobe) + D, I = ivf_tools.search_preassigned(index, xq, 4, Iq, Dq) + np.testing.assert_array_equal(D_ref, D) + np.testing.assert_array_equal(I_ref, I) + def test_float(self): ds = datasets.SyntheticDataset(128, 2000, 2000, 200)