diff --git a/faiss/IndexIVF.cpp b/faiss/IndexIVF.cpp index a1fa8cd16b..65d017aa5a 100644 --- a/faiss/IndexIVF.cpp +++ b/faiss/IndexIVF.cpp @@ -203,7 +203,8 @@ void IndexIVF::add_core( idx_t n, const float* x, const idx_t* xids, - const idx_t* coarse_idx) { + const idx_t* coarse_idx, + void* inverted_list_context) { // do some blocking to avoid excessive allocs idx_t bs = 65536; if (n > bs) { @@ -218,7 +219,8 @@ void IndexIVF::add_core( i1 - i0, x + i0 * d, xids ? xids + i0 : nullptr, - coarse_idx + i0); + coarse_idx + i0, + inverted_list_context); } return; } @@ -249,7 +251,10 @@ void IndexIVF::add_core( if (list_no >= 0 && list_no % nt == rank) { idx_t id = xids ? xids[i] : ntotal + i; size_t ofs = invlists->add_entry( - list_no, id, flat_codes.get() + i * code_size); + list_no, + id, + flat_codes.get() + i * code_size, + inverted_list_context); dm_adder.add(i, list_no, ofs); @@ -445,6 +450,9 @@ void IndexIVF::search_preassigned( : pmode == 1 ? nprobe > 1 : nprobe * n > 1); + void* inverted_list_context = + params ? params->inverted_list_context : nullptr; + #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap) { std::unique_ptr scanner( @@ -507,7 +515,7 @@ void IndexIVF::search_preassigned( nlist); // don't waste time on empty lists - if (invlists->is_empty(key)) { + if (invlists->is_empty(key, inverted_list_context)) { return (size_t)0; } @@ -520,7 +528,7 @@ void IndexIVF::search_preassigned( size_t list_size = 0; std::unique_ptr it( - invlists->get_iterator(key)); + invlists->get_iterator(key, inverted_list_context)); nheap += scanner->iterate_codes( it.get(), simi, idxi, k, list_size); @@ -783,6 +791,9 @@ void IndexIVF::range_search_preassigned( : pmode == 1 ? nprobe > 1 : nprobe * nx > 1); + void* inverted_list_context = + params ? params->inverted_list_context : nullptr; + #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis) { RangeSearchPartialResult pres(result); @@ -804,7 +815,7 @@ void IndexIVF::range_search_preassigned( ik, nlist); - if (invlists->is_empty(key)) { + if (invlists->is_empty(key, inverted_list_context)) { return; } @@ -813,7 +824,7 @@ void IndexIVF::range_search_preassigned( scanner->set_list(key, coarse_dis[i * nprobe + ik]); if (invlists->use_iterator) { std::unique_ptr it( - invlists->get_iterator(key)); + invlists->get_iterator(key, inverted_list_context)); scanner->iterate_codes_range( it.get(), radius, qres, list_size); diff --git a/faiss/IndexIVF.h b/faiss/IndexIVF.h index d0981caa42..45c65ef839 100644 --- a/faiss/IndexIVF.h +++ b/faiss/IndexIVF.h @@ -72,6 +72,8 @@ struct SearchParametersIVF : SearchParameters { size_t nprobe = 1; ///< number of probes at query time size_t max_codes = 0; ///< max nb of codes to visit to do a query SearchParameters* quantizer_params = nullptr; + /// context object to pass to InvertedLists + void* inverted_list_context = nullptr; virtual ~SearchParametersIVF() {} }; @@ -232,7 +234,8 @@ struct IndexIVF : Index, IndexIVFInterface { idx_t n, const float* x, const idx_t* xids, - const idx_t* precomputed_idx); + const idx_t* precomputed_idx, + void* inverted_list_context = nullptr); /** Encodes a set of vectors as they would appear in the inverted lists * diff --git a/faiss/IndexIVFFlat.cpp b/faiss/IndexIVFFlat.cpp index e985683eba..1b36fea379 100644 --- a/faiss/IndexIVFFlat.cpp +++ b/faiss/IndexIVFFlat.cpp @@ -47,7 +47,8 @@ void IndexIVFFlat::add_core( idx_t n, const float* x, const idx_t* xids, - const idx_t* coarse_idx) { + const idx_t* coarse_idx, + void* inverted_list_context) { FAISS_THROW_IF_NOT(is_trained); FAISS_THROW_IF_NOT(coarse_idx); FAISS_THROW_IF_NOT(!by_residual); @@ -70,8 +71,8 @@ void IndexIVFFlat::add_core( if (list_no >= 0 && list_no % nt == rank) { idx_t id = xids ? xids[i] : ntotal + i; const float* xi = x + i * d; - size_t offset = - invlists->add_entry(list_no, id, (const uint8_t*)xi); + size_t offset = invlists->add_entry( + list_no, id, (const uint8_t*)xi, inverted_list_context); dm_adder.add(i, list_no, offset); n_add++; } else if (rank == 0 && list_no == -1) { diff --git a/faiss/IndexIVFFlat.h b/faiss/IndexIVFFlat.h index a0233052fa..8e47547e02 100644 --- a/faiss/IndexIVFFlat.h +++ b/faiss/IndexIVFFlat.h @@ -32,7 +32,8 @@ struct IndexIVFFlat : IndexIVF { idx_t n, const float* x, const idx_t* xids, - const idx_t* precomputed_idx) override; + const idx_t* precomputed_idx, + void* inverted_list_context = nullptr) override; void encode_vectors( idx_t n, diff --git a/faiss/IndexIVFPQ.cpp b/faiss/IndexIVFPQ.cpp index 6de78b9539..5d02c5ee0e 100644 --- a/faiss/IndexIVFPQ.cpp +++ b/faiss/IndexIVFPQ.cpp @@ -135,8 +135,9 @@ void IndexIVFPQ::add_core( idx_t n, const float* x, const idx_t* xids, - const idx_t* coarse_idx) { - add_core_o(n, x, xids, nullptr, coarse_idx); + const idx_t* coarse_idx, + void* inverted_list_context) { + add_core_o(n, x, xids, nullptr, coarse_idx, inverted_list_context); } static std::unique_ptr compute_residuals( @@ -212,7 +213,8 @@ void IndexIVFPQ::add_core_o( const float* x, const idx_t* xids, float* residuals_2, - const idx_t* precomputed_idx) { + const idx_t* precomputed_idx, + void* inverted_list_context) { idx_t bs = index_ivfpq_add_core_o_bs; if (n > bs) { for (idx_t i0 = 0; i0 < n; i0 += bs) { @@ -229,7 +231,8 @@ void IndexIVFPQ::add_core_o( x + i0 * d, xids ? xids + i0 : nullptr, residuals_2 ? residuals_2 + i0 * d : nullptr, - precomputed_idx ? precomputed_idx + i0 : nullptr); + precomputed_idx ? precomputed_idx + i0 : nullptr, + inverted_list_context); } return; } @@ -281,7 +284,8 @@ void IndexIVFPQ::add_core_o( } uint8_t* code = xcodes.get() + i * code_size; - size_t offset = invlists->add_entry(key, id, code); + size_t offset = + invlists->add_entry(key, id, code, inverted_list_context); if (residuals_2) { float* res2 = residuals_2 + i * d; diff --git a/faiss/IndexIVFPQ.h b/faiss/IndexIVFPQ.h index ab49f1e549..d5d21da49d 100644 --- a/faiss/IndexIVFPQ.h +++ b/faiss/IndexIVFPQ.h @@ -71,7 +71,8 @@ struct IndexIVFPQ : IndexIVF { idx_t n, const float* x, const idx_t* xids, - const idx_t* precomputed_idx) override; + const idx_t* precomputed_idx, + void* inverted_list_context = nullptr) override; /// same as add_core, also: /// - output 2nd level residuals if residuals_2 != NULL @@ -81,7 +82,8 @@ struct IndexIVFPQ : IndexIVF { const float* x, const idx_t* xids, float* residuals_2, - const idx_t* precomputed_idx = nullptr); + const idx_t* precomputed_idx = nullptr, + void* inverted_list_context = nullptr); /// trains the product quantizer void train_encoder(idx_t n, const float* x, const idx_t* assign) override; diff --git a/faiss/IndexIVFPQR.cpp b/faiss/IndexIVFPQR.cpp index 2dd967e829..f55332cddf 100644 --- a/faiss/IndexIVFPQR.cpp +++ b/faiss/IndexIVFPQR.cpp @@ -91,7 +91,8 @@ void IndexIVFPQR::add_core( idx_t n, const float* x, const idx_t* xids, - const idx_t* precomputed_idx) { + const idx_t* precomputed_idx, + void* /*inverted_list_context*/) { std::unique_ptr residual_2(new float[n * d]); idx_t n0 = ntotal; diff --git a/faiss/IndexIVFPQR.h b/faiss/IndexIVFPQR.h index 73502879f2..7642d2f232 100644 --- a/faiss/IndexIVFPQR.h +++ b/faiss/IndexIVFPQR.h @@ -48,7 +48,8 @@ struct IndexIVFPQR : IndexIVFPQ { idx_t n, const float* x, const idx_t* xids, - const idx_t* precomputed_idx) override; + const idx_t* precomputed_idx, + void* inverted_list_context = nullptr) override; void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons) const override; diff --git a/faiss/IndexScalarQuantizer.cpp b/faiss/IndexScalarQuantizer.cpp index b7199df39d..9203a98932 100644 --- a/faiss/IndexScalarQuantizer.cpp +++ b/faiss/IndexScalarQuantizer.cpp @@ -207,7 +207,8 @@ void IndexIVFScalarQuantizer::add_core( idx_t n, const float* x, const idx_t* xids, - const idx_t* coarse_idx) { + const idx_t* coarse_idx, + void* inverted_list_context) { FAISS_THROW_IF_NOT(is_trained); std::unique_ptr squant(sq.select_quantizer()); @@ -236,7 +237,8 @@ void IndexIVFScalarQuantizer::add_core( memset(one_code.data(), 0, code_size); squant->encode_vector(xi, one_code.data()); - size_t ofs = invlists->add_entry(list_no, id, one_code.data()); + size_t ofs = invlists->add_entry( + list_no, id, one_code.data(), inverted_list_context); dm_add.add(i, list_no, ofs); diff --git a/faiss/IndexScalarQuantizer.h b/faiss/IndexScalarQuantizer.h index c064bbeeb3..27332500c1 100644 --- a/faiss/IndexScalarQuantizer.h +++ b/faiss/IndexScalarQuantizer.h @@ -91,7 +91,8 @@ struct IndexIVFScalarQuantizer : IndexIVF { idx_t n, const float* x, const idx_t* xids, - const idx_t* precomputed_idx) override; + const idx_t* precomputed_idx, + void* inverted_list_context = nullptr) override; InvertedListScanner* get_InvertedListScanner( bool store_pairs, diff --git a/faiss/invlists/InvertedLists.cpp b/faiss/invlists/InvertedLists.cpp index ca87ae00ea..cc337d004b 100644 --- a/faiss/invlists/InvertedLists.cpp +++ b/faiss/invlists/InvertedLists.cpp @@ -28,11 +28,12 @@ InvertedLists::InvertedLists(size_t nlist, size_t code_size) InvertedLists::~InvertedLists() {} -bool InvertedLists::is_empty(size_t list_no) const { - return use_iterator - ? !std::unique_ptr(get_iterator(list_no)) - ->is_available() - : list_size(list_no) == 0; +bool InvertedLists::is_empty(size_t list_no, void* inverted_list_context) + const { + return use_iterator ? !std::unique_ptr( + get_iterator(list_no, inverted_list_context)) + ->is_available() + : list_size(list_no) == 0; } idx_t InvertedLists::get_single_id(size_t list_no, size_t offset) const { @@ -58,7 +59,8 @@ const uint8_t* InvertedLists::get_single_code(size_t list_no, size_t offset) size_t InvertedLists::add_entry( size_t list_no, idx_t theid, - const uint8_t* code) { + const uint8_t* code, + void* /*inverted_list_context*/) { return add_entries(list_no, 1, &theid, code); } @@ -76,7 +78,9 @@ void InvertedLists::reset() { } } -InvertedListsIterator* InvertedLists::get_iterator(size_t /*list_no*/) const { +InvertedListsIterator* InvertedLists::get_iterator( + size_t /*list_no*/, + void* /*inverted_list_context*/) const { FAISS_THROW_MSG("get_iterator is not supported"); } diff --git a/faiss/invlists/InvertedLists.h b/faiss/invlists/InvertedLists.h index c4d681452b..90a9d65411 100644 --- a/faiss/invlists/InvertedLists.h +++ b/faiss/invlists/InvertedLists.h @@ -51,13 +51,15 @@ struct InvertedLists { * Read only functions */ // check if the list is empty - bool is_empty(size_t list_no) const; + bool is_empty(size_t list_no, void* inverted_list_context) const; /// get the size of a list virtual size_t list_size(size_t list_no) const = 0; /// get iterable for lists that use_iterator - virtual InvertedListsIterator* get_iterator(size_t list_no) const; + virtual InvertedListsIterator* get_iterator( + size_t list_no, + void* inverted_list_context) const; /** get the codes for an inverted list * must be released by release_codes @@ -94,7 +96,11 @@ struct InvertedLists { * writing functions */ /// add one entry to an inverted list - virtual size_t add_entry(size_t list_no, idx_t theid, const uint8_t* code); + virtual size_t add_entry( + size_t list_no, + idx_t theid, + const uint8_t* code, + void* inverted_list_context = nullptr); virtual size_t add_entries( size_t list_no, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cc0a4f4cfd..a1c3266961 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -10,6 +10,7 @@ set(FAISS_TEST_SRC test_ivfpq_codec.cpp test_ivfpq_indexing.cpp test_lowlevel_ivf.cpp + test_ivf_index.cpp test_merge.cpp test_omp_threads.cpp test_ondisk_ivf.cpp diff --git a/tests/test_ivf_index.cpp b/tests/test_ivf_index.cpp new file mode 100644 index 0000000000..54cb7945f9 --- /dev/null +++ b/tests/test_ivf_index.cpp @@ -0,0 +1,242 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace { + +// stores all ivf lists, used to verify the context +// object is passed to the iterator +class TestContext { + public: + TestContext() {} + + void save_code(size_t list_no, const uint8_t* code, size_t code_size) { + list_nos.emplace(id, list_no); + codes.emplace(id, std::vector(code_size)); + for (size_t i = 0; i < code_size; i++) { + codes[id][i] = code[i]; + } + id++; + } + + // id to codes map + std::unordered_map> codes; + // id to list_no map + std::unordered_map list_nos; + faiss::idx_t id = 0; + std::set lists_probed; +}; + +// the iterator that iterates over the codes stored in context object +class TestInvertedListIterator : public faiss::InvertedListsIterator { + public: + TestInvertedListIterator(size_t list_no, TestContext* context) + : list_no{list_no}, context{context} { + it = context->codes.cbegin(); + seek_next(); + } + ~TestInvertedListIterator() override {} + + // move the cursor to the first valid entry + void seek_next() { + while (it != context->codes.cend() && + context->list_nos[it->first] != list_no) { + it++; + } + } + + virtual bool is_available() const override { + return it != context->codes.cend(); + } + + virtual void next() override { + it++; + seek_next(); + } + + virtual std::pair get_id_and_codes() + override { + if (it == context->codes.cend()) { + FAISS_THROW_MSG("invalid state"); + } + return std::make_pair(it->first, it->second.data()); + } + + private: + size_t list_no; + TestContext* context; + decltype(context->codes.cbegin()) it; +}; + +class TestInvertedLists : public faiss::InvertedLists { + public: + TestInvertedLists(size_t nlist, size_t code_size) + : faiss::InvertedLists(nlist, code_size) { + use_iterator = true; + } + + ~TestInvertedLists() override {} + size_t list_size(size_t /*list_no*/) const override { + FAISS_THROW_MSG("unexpected call"); + } + + faiss::InvertedListsIterator* get_iterator(size_t list_no, void* context) + const override { + auto testContext = (TestContext*)context; + testContext->lists_probed.insert(list_no); + return new TestInvertedListIterator(list_no, testContext); + } + + const uint8_t* get_codes(size_t /* list_no */) const override { + FAISS_THROW_MSG("unexpected call"); + } + + const faiss::idx_t* get_ids(size_t /* list_no */) const override { + FAISS_THROW_MSG("unexpected call"); + } + + // store the codes in context object + size_t add_entry( + size_t list_no, + faiss::idx_t /*theid*/, + const uint8_t* code, + void* context) override { + auto testContext = (TestContext*)context; + testContext->save_code(list_no, code, code_size); + return 0; + } + + size_t add_entries( + size_t /*list_no*/, + size_t /*n_entry*/, + const faiss::idx_t* /*ids*/, + const uint8_t* /*code*/) override { + FAISS_THROW_MSG("unexpected call"); + } + + void update_entries( + size_t /*list_no*/, + size_t /*offset*/, + size_t /*n_entry*/, + const faiss::idx_t* /*ids*/, + const uint8_t* /*code*/) override { + FAISS_THROW_MSG("unexpected call"); + } + + void resize(size_t /*list_no*/, size_t /*new_size*/) override { + FAISS_THROW_MSG("unexpected call"); + } +}; +} // namespace + +TEST(IVF, list_context) { + // this test verifies that the context object is passed + // to the InvertedListsIterator and InvertedLists::add_entry. + // the test InvertedLists and InvertedListsIterator reads/writes + // to the test context object. + // the test verifies the context object is modified as expected. + + constexpr int d = 32; // dimension + constexpr int nb = 100000; // database size + constexpr int nlist = 100; + + std::mt19937 rng; + std::uniform_real_distribution<> distrib; + + // disable parallism, or we need to make Context object + // thread-safe + omp_set_num_threads(1); + + faiss::IndexFlatL2 quantizer(d); // the other index + faiss::IndexIVFFlat index(&quantizer, d, nlist); + TestInvertedLists inverted_lists(nlist, index.code_size); + index.replace_invlists(&inverted_lists); + { + // training + constexpr size_t nt = 1500; // nb of training vectors + std::vector trainvecs(nt * d); + for (size_t i = 0; i < nt * d; i++) { + trainvecs[i] = distrib(rng); + } + index.verbose = true; + index.train(nt, trainvecs.data()); + } + TestContext context; + std::vector query_vector; + constexpr faiss::idx_t query_vector_id = 100; + { + // populating the database + std::vector database(nb * d); + for (size_t i = 0; i < nb * d; i++) { + database[i] = distrib(rng); + // populate the query vector + if (i >= query_vector_id * d && i < query_vector_id * d + d) { + query_vector.push_back(database[i]); + } + } + std::vector coarse_idx(nb); + index.quantizer->assign(nb, database.data(), coarse_idx.data()); + // pass dummy ids, the acutal ids are assigned in TextContext object + std::vector xids(nb, 42); + index.add_core( + nb, database.data(), xids.data(), coarse_idx.data(), &context); + + // check the context object get updated + EXPECT_EQ(nb, context.id) << "should have added all ids"; + EXPECT_EQ(nb, context.codes.size()) + << "should have correct number of codes"; + EXPECT_EQ(nb, context.list_nos.size()) + << "should have correct number of list numbers"; + } + { + constexpr faiss::idx_t k = 100; + constexpr size_t nprobe = 10; + std::vector distances(k); + std::vector labels(k); + faiss::SearchParametersIVF params; + params.inverted_list_context = &context; + params.nprobe = nprobe; + index.search( + 1, + query_vector.data(), + k, + distances.data(), + labels.data(), + ¶ms); + EXPECT_EQ(nprobe, context.lists_probed.size()) + << "should probe nprobe lists"; + + // check the result contains the query vector, the probablity of + // this fail should be low + auto query_vector_listno = context.list_nos[query_vector_id]; + auto& lists_probed = context.lists_probed; + EXPECT_TRUE( + std::find( + lists_probed.cbegin(), + lists_probed.cend(), + query_vector_listno) != lists_probed.cend()) + << "should probe the list of the query vector"; + EXPECT_TRUE( + std::find(labels.cbegin(), labels.cend(), query_vector_id) != + labels.cend()) + << "should return the query vector"; + } +}