From e5c507e1129e1c49175de1ae058e0963e16efa40 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Tue, 30 Jan 2024 14:43:56 -0800 Subject: [PATCH] Add IDGrouper for HNSW --- faiss/CMakeLists.txt | 3 + faiss/Index.h | 8 +- faiss/IndexHNSW.cpp | 13 +- faiss/IndexIDMap.cpp | 29 +++++ faiss/IndexIDMap.h | 22 ++++ faiss/impl/HNSW.cpp | 6 + faiss/impl/IDGrouper.cpp | 51 ++++++++ faiss/impl/IDGrouper.h | 51 ++++++++ faiss/impl/ResultHandler.h | 189 +++++++++++++++++++++++++++++ faiss/utils/GroupHeap.h | 182 ++++++++++++++++++++++++++++ tests/CMakeLists.txt | 2 + tests/test_group_heap.cpp | 98 +++++++++++++++ tests/test_id_grouper.cpp | 241 +++++++++++++++++++++++++++++++++++++ 13 files changed, 890 insertions(+), 5 deletions(-) create mode 100644 faiss/impl/IDGrouper.cpp create mode 100644 faiss/impl/IDGrouper.h create mode 100644 faiss/utils/GroupHeap.h create mode 100644 tests/test_group_heap.cpp create mode 100644 tests/test_id_grouper.cpp diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index 33e1849568..5c2ffab816 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -54,6 +54,7 @@ set(FAISS_SRC impl/AuxIndexStructures.cpp impl/CodePacker.cpp impl/IDSelector.cpp + impl/IDGrouper.cpp impl/FaissException.cpp impl/HNSW.cpp impl/NSG.cpp @@ -149,6 +150,7 @@ set(FAISS_HEADERS impl/AuxIndexStructures.h impl/CodePacker.h impl/IDSelector.h + impl/IDGrouper.h impl/DistanceComputer.h impl/FaissAssert.h impl/FaissException.h @@ -183,6 +185,7 @@ set(FAISS_HEADERS invlists/InvertedLists.h invlists/InvertedListsIOHook.h utils/AlignedTable.h + utils/GroupHeap.h utils/Heap.h utils/WorkerThread.h utils/distances.h diff --git a/faiss/Index.h b/faiss/Index.h index 3d1bdb996a..a862285830 100644 --- a/faiss/Index.h +++ b/faiss/Index.h @@ -38,9 +38,10 @@ namespace faiss { -/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and -/// impl/DistanceComputer.h +/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h +/// ,impl/IDGrouper.h and impl/DistanceComputer.h struct IDSelector; +struct IDGrouper; struct RangeSearchResult; struct DistanceComputer; @@ -52,6 +53,9 @@ struct DistanceComputer; struct SearchParameters { /// if non-null, only these IDs will be considered during search. IDSelector* sel = nullptr; + /// if non-null, only best matched ID per group will be included in the + /// result. + IDGrouper* grp = nullptr; /// make sure we can dynamic_cast this virtual ~SearchParameters() {} }; diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 3325c8c0e1..5017c25dcc 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -352,10 +352,17 @@ void IndexHNSW::search( const SearchParameters* params_in) const { FAISS_THROW_IF_NOT(k > 0); - using RH = HeapBlockResultHandler; - RH bres(n, distances, labels, k); + if (params_in && params_in->grp) { + using RH = GroupedHeapBlockResultHandler; + RH bres(n, distances, labels, k, params_in->grp); - hnsw_search(this, n, x, bres, params_in); + hnsw_search(this, n, x, bres, params_in); + } else { + using RH = HeapBlockResultHandler; + RH bres(n, distances, labels, k); + + hnsw_search(this, n, x, bres, params_in); + } if (is_similarity_metric(this->metric_type)) { // we need to revert the negated distances diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp index dc84052b2f..3f375e7bb6 100644 --- a/faiss/IndexIDMap.cpp +++ b/faiss/IndexIDMap.cpp @@ -102,6 +102,23 @@ struct ScopedSelChange { } }; +/// RAII object to reset the IDGrouper in the params object +struct ScopedGrpChange { + SearchParameters* params = nullptr; + IDGrouper* old_grp = nullptr; + + void set(SearchParameters* params_2, IDGrouper* new_grp) { + this->params = params_2; + old_grp = params_2->grp; + params_2->grp = new_grp; + } + ~ScopedGrpChange() { + if (params) { + params->grp = old_grp; + } + } +}; + } // namespace template @@ -114,6 +131,8 @@ void IndexIDMapTemplate::search( const SearchParameters* params) const { IDSelectorTranslated this_idtrans(this->id_map, nullptr); ScopedSelChange sel_change; + IDGrouperTranslated this_idgrptrans(this->id_map, nullptr); + ScopedGrpChange grp_change; if (params && params->sel) { auto idtrans = dynamic_cast(params->sel); @@ -131,6 +150,16 @@ void IndexIDMapTemplate::search( sel_change.set(params_non_const, &this_idtrans); } } + + if (params && params->grp) { + auto idtrans = dynamic_cast(params->grp); + + if (!idtrans) { + auto params_non_const = const_cast(params); + this_idgrptrans.grp = params->grp; + grp_change.set(params_non_const, &this_idgrptrans); + } + } index->search(n, x, k, distances, labels, params); idx_t* li = labels; #pragma omp parallel for diff --git a/faiss/IndexIDMap.h b/faiss/IndexIDMap.h index 2d16412301..a68887bdef 100644 --- a/faiss/IndexIDMap.h +++ b/faiss/IndexIDMap.h @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -124,4 +125,25 @@ struct IDSelectorTranslated : IDSelector { } }; +// IDGrouper that translates the ids using an IDMap +struct IDGrouperTranslated : IDGrouper { + const std::vector& id_map; + const IDGrouper* grp; + + IDGrouperTranslated( + const std::vector& id_map, + const IDGrouper* grp) + : id_map(id_map), grp(grp) {} + + IDGrouperTranslated(IndexBinaryIDMap& index_idmap, const IDGrouper* grp) + : id_map(index_idmap.id_map), grp(grp) {} + + IDGrouperTranslated(IndexIDMap& index_idmap, const IDGrouper* grp) + : id_map(index_idmap.id_map), grp(grp) {} + + idx_t get_group(idx_t id) const override { + return grp->get_group(id_map[id]); + } +}; + } // namespace faiss diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index d8c8225968..35f9fb5d58 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -804,6 +804,12 @@ int extract_k_from_ResultHandler(ResultHandler& res) { if (auto hres = dynamic_cast(&res)) { return hres->k; } + + if (auto hres = dynamic_cast< + GroupedHeapBlockResultHandler::SingleResultHandler*>(&res)) { + return hres->k; + } + return 1; } diff --git a/faiss/impl/IDGrouper.cpp b/faiss/impl/IDGrouper.cpp new file mode 100644 index 0000000000..ca9f5fda7a --- /dev/null +++ b/faiss/impl/IDGrouper.cpp @@ -0,0 +1,51 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace faiss { + +/*********************************************************************** + * IDGrouperBitmap + ***********************************************************************/ + +IDGrouperBitmap::IDGrouperBitmap(size_t n, uint64_t* bitmap) + : n(n), bitmap(bitmap) {} + +idx_t IDGrouperBitmap::get_group(idx_t id) const { + assert(id >= 0 && "id shouldn't be less than zero"); + assert(id < this->n * 64 && "is should be less than total number of bits"); + + idx_t index = id >> 6; // div by 64 + uint64_t block = this->bitmap[index] >> + (id & 63); // Equivalent of words[i] >> (index % 64) + // block is non zero after right shift, it means, next set bit is in current + // block The index of set bit is "given index" + "trailing zero in the right + // shifted word" + if (block != 0) { + return id + __builtin_ctzll(block); + } + + while (++index < this->n) { + block = this->bitmap[index]; + if (block != 0) { + return (index << 6) + __builtin_ctzll(block); + } + } + + return NO_MORE_DOCS; +} + +void IDGrouperBitmap::set_group(idx_t group_id) { + idx_t index = group_id >> 6; + this->bitmap[index] |= 1ULL + << (group_id & 63); // Equivalent of 1ULL << (value % 64) +} + +} // namespace faiss diff --git a/faiss/impl/IDGrouper.h b/faiss/impl/IDGrouper.h new file mode 100644 index 0000000000..d56113d971 --- /dev/null +++ b/faiss/impl/IDGrouper.h @@ -0,0 +1,51 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include + +/** IDGrouper is intended to define a group of vectors to include only + * the nearest vector of each group during search */ + +namespace faiss { + +/** Encapsulates a group id of ids */ +struct IDGrouper { + const idx_t NO_MORE_DOCS = std::numeric_limits::max(); + virtual idx_t get_group(idx_t id) const = 0; + virtual ~IDGrouper() {} +}; + +/** One bit per element. Constructed with a bitmap, size ceil(n / 8). + */ +struct IDGrouperBitmap : IDGrouper { + // length of the bitmap array + size_t n; + + // Array of uint64_t holding the bits + // Using uint64_t to leverage function __builtin_ctzll which is defined in + // faiss/impl/platform_macros.h Group id of a given id is next set bit in + // the bitmap + uint64_t* bitmap; + + /** Construct with a binary mask + * + * @param n size of the bitmap array + * @param bitmap group id of a given id is next set bit in the bitmap + */ + IDGrouperBitmap(size_t n, uint64_t* bitmap); + idx_t get_group(idx_t id) const final; + void set_group(idx_t group_id); + ~IDGrouperBitmap() override {} +}; + +} // namespace faiss diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h index 713fe8e49f..d307fd70b0 100644 --- a/faiss/impl/ResultHandler.h +++ b/faiss/impl/ResultHandler.h @@ -13,6 +13,8 @@ #include #include +#include +#include #include #include #include @@ -267,6 +269,193 @@ struct HeapBlockResultHandler : BlockResultHandler { } }; +/***************************************************************** + * Heap based result handler with grouping + *****************************************************************/ + +template +struct GroupedHeapBlockResultHandler : BlockResultHandler { + using T = typename C::T; + using TI = typename C::TI; + using BlockResultHandler::i0; + using BlockResultHandler::i1; + + T* heap_dis_tab; + TI* heap_ids_tab; + int64_t k; // number of results to keep + + IDGrouper* id_grouper; + TI* heap_group_ids_tab; + std::unordered_map* group_id_to_index_in_heap_tab; + + GroupedHeapBlockResultHandler( + size_t nq, + T* heap_dis_tab, + TI* heap_ids_tab, + size_t k, + IDGrouper* id_grouper) + : BlockResultHandler(nq), + heap_dis_tab(heap_dis_tab), + heap_ids_tab(heap_ids_tab), + k(k), + id_grouper(id_grouper) {} + + /****************************************************** + * API for 1 result at a time (each SingleResultHandler is + * called from 1 thread) + */ + + struct SingleResultHandler : ResultHandler { + GroupedHeapBlockResultHandler& hr; + using ResultHandler::threshold; + size_t k; + + T* heap_dis; + TI* heap_ids; + TI* heap_group_ids; + std::unordered_map group_id_to_index_in_heap; + + explicit SingleResultHandler(GroupedHeapBlockResultHandler& hr) + : hr(hr), k(hr.k) {} + + /// begin results for query # i + void begin(size_t i) { + heap_dis = hr.heap_dis_tab + i * k; + heap_ids = hr.heap_ids_tab + i * k; + heap_heapify(k, heap_dis, heap_ids); + threshold = heap_dis[0]; + heap_group_ids = new TI[hr.k]; + for (size_t i = 0; i < hr.k; i++) { + heap_group_ids[i] = -1; + } + } + + /// add one result for query i + bool add_result(T dis, TI idx) final { + if (!C::cmp(threshold, dis)) { + return false; + } + + idx_t group_id = hr.id_grouper->get_group(idx); + typename std::unordered_map::const_iterator it_pos = + group_id_to_index_in_heap.find(group_id); + if (it_pos == group_id_to_index_in_heap.end()) { + group_heap_replace_top( + k, + heap_dis, + heap_ids, + heap_group_ids, + dis, + idx, + group_id, + &group_id_to_index_in_heap); + threshold = heap_dis[0]; + return true; + } else { + size_t pos = it_pos->second; + if (!C::cmp(heap_dis[pos], dis)) { + return false; + } + group_heap_replace_at( + pos, + k, + heap_dis, + heap_ids, + heap_group_ids, + dis, + idx, + group_id, + &group_id_to_index_in_heap); + threshold = heap_dis[0]; + return true; + } + } + + /// series of results for query i is done + void end() { + heap_reorder(k, heap_dis, heap_ids); + delete heap_group_ids; + } + }; + + /****************************************************** + * API for multiple results (called from 1 thread) + */ + + /// begin + void begin_multiple(size_t i0_2, size_t i1_2) final { + this->i0 = i0_2; + this->i1 = i1_2; + for (size_t i = i0; i < i1; i++) { + heap_heapify(k, heap_dis_tab + i * k, heap_ids_tab + i * k); + } + size_t size = (i1 - i0) * k; + heap_group_ids_tab = new TI[size]; + for (size_t i = 0; i < size; i++) { + heap_group_ids_tab[i] = -1; + } + group_id_to_index_in_heap_tab = + new std::unordered_map[i1 - i0]; + } + + /// add results for query i0..i1 and j0..j1 + void add_results(size_t j0, size_t j1, const T* dis_tab) final { +#pragma omp parallel for + for (int64_t i = i0; i < i1; i++) { + T* heap_dis = heap_dis_tab + i * k; + TI* heap_ids = heap_ids_tab + i * k; + const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0; + T thresh = heap_dis[0]; // NOLINT(*-use-default-none) + for (size_t j = j0; j < j1; j++) { + T dis = dis_tab_i[j]; + if (C::cmp(thresh, dis)) { + idx_t group_id = id_grouper->get_group(j); + typename std::unordered_map::const_iterator + it_pos = group_id_to_index_in_heap_tab[i - i0].find( + group_id); + if (it_pos == group_id_to_index_in_heap_tab[i - i0].end()) { + group_heap_replace_top( + k, + heap_dis, + heap_ids, + heap_group_ids_tab + ((i - i0) * k), + dis, + j, + group_id, + &group_id_to_index_in_heap_tab[i - i0]); + thresh = heap_dis[0]; + } else { + size_t pos = it_pos->first; + if (C::cmp(heap_dis[pos], dis)) { + group_heap_replace_at( + pos, + k, + heap_dis, + heap_ids, + heap_group_ids_tab + ((i - i0) * k), + dis, + j, + group_id, + &group_id_to_index_in_heap_tab[i - i0]); + thresh = heap_dis[0]; + } + } + } + } + } + } + + /// series of results for queries i0..i1 is done + void end_multiple() final { + // maybe parallel for + for (size_t i = i0; i < i1; i++) { + heap_reorder(k, heap_dis_tab + i * k, heap_ids_tab + i * k); + } + delete group_id_to_index_in_heap_tab; + delete heap_group_ids_tab; + } +}; + /***************************************************************** * Reservoir result handler * diff --git a/faiss/utils/GroupHeap.h b/faiss/utils/GroupHeap.h new file mode 100644 index 0000000000..3b7078da6e --- /dev/null +++ b/faiss/utils/GroupHeap.h @@ -0,0 +1,182 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include +#include + +#include +#include + +namespace faiss { + +/** + * From start_index, it compare its value with parent node's and swap if needed. + * Continue until either there is no swap or it reaches the top node. + */ +template +static inline void group_up_heap( + typename C::T* heap_dis, + typename C::TI* heap_ids, + typename C::TI* heap_group_ids, + std::unordered_map* group_id_to_index_in_heap, + size_t start_index) { + heap_dis--; /* Use 1-based indexing for easier node->child translation */ + heap_ids--; + heap_group_ids--; + size_t i = start_index + 1, i_father; + typename C::T target_dis = heap_dis[i]; + typename C::TI target_id = heap_ids[i]; + typename C::TI target_group_id = heap_group_ids[i]; + + while (i > 1) { + i_father = i >> 1; + if (!C::cmp2( + target_dis, + heap_dis[i_father], + target_id, + heap_ids[i_father])) { + /* the heap structure is ok */ + break; + } + heap_dis[i] = heap_dis[i_father]; + heap_ids[i] = heap_ids[i_father]; + heap_group_ids[i] = heap_group_ids[i_father]; + (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; + i = i_father; + } + heap_dis[i] = target_dis; + heap_ids[i] = target_id; + heap_group_ids[i] = target_group_id; + (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; +} + +/** + * From start_index, it compare its value with child node's and swap if needed. + * Continue until either there is no swap or it reaches the leaf node. + */ +template +static inline void group_down_heap( + size_t k, + typename C::T* heap_dis, + typename C::TI* heap_ids, + typename C::TI* heap_group_ids, + std::unordered_map* group_id_to_index_in_heap, + size_t start_index) { + heap_dis--; /* Use 1-based indexing for easier node->child translation */ + heap_ids--; + heap_group_ids--; + size_t i = start_index + 1, i1, i2; + typename C::T target_dis = heap_dis[i]; + typename C::TI target_id = heap_ids[i]; + typename C::TI target_group_id = heap_group_ids[i]; + + while (1) { + i1 = i << 1; + i2 = i1 + 1; + if (i1 > k) { + break; + } + + // Note that C::cmp2() is a bool function answering + // `(a1 > b1) || ((a1 == b1) && (a2 > b2))` for max + // heap and same with the `<` sign for min heap. + if ((i2 == k + 1) || + C::cmp2(heap_dis[i1], heap_dis[i2], heap_ids[i1], heap_ids[i2])) { + if (C::cmp2(target_dis, heap_dis[i1], target_id, heap_ids[i1])) { + break; + } + heap_dis[i] = heap_dis[i1]; + heap_ids[i] = heap_ids[i1]; + heap_group_ids[i] = heap_group_ids[i1]; + (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; + i = i1; + } else { + if (C::cmp2(target_dis, heap_dis[i2], target_id, heap_ids[i2])) { + break; + } + heap_dis[i] = heap_dis[i2]; + heap_ids[i] = heap_ids[i2]; + heap_group_ids[i] = heap_group_ids[i2]; + (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; + i = i2; + } + } + heap_dis[i] = target_dis; + heap_ids[i] = target_id; + heap_group_ids[i] = target_group_id; + (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; +} + +template +static inline void group_heap_replace_top( + size_t k, + typename C::T* heap_dis, + typename C::TI* heap_ids, + typename C::TI* heap_group_ids, + typename C::T dis, + typename C::TI id, + typename C::TI group_id, + std::unordered_map* group_id_to_index_in_heap) { + assert(group_id_to_index_in_heap->find(group_id) == + group_id_to_index_in_heap->end() && + "group id should not exist in the binary heap"); + + group_id_to_index_in_heap->erase(heap_group_ids[0]); + heap_group_ids[0] = group_id; + heap_dis[0] = dis; + heap_ids[0] = id; + (*group_id_to_index_in_heap)[group_id] = 0; + group_down_heap( + k, + heap_dis, + heap_ids, + heap_group_ids, + group_id_to_index_in_heap, + 0); +} + +template +static inline void group_heap_replace_at( + size_t pos, + size_t k, + typename C::T* heap_dis, + typename C::TI* heap_ids, + typename C::TI* heap_group_ids, + typename C::T dis, + typename C::TI id, + typename C::TI group_id, + std::unordered_map* group_id_to_index_in_heap) { + assert(group_id_to_index_in_heap->find(group_id) != + group_id_to_index_in_heap->end() && + "group id should exist in the binary heap"); + assert(group_id_to_index_in_heap->find(group_id)->second == pos && + "index of group id in the heap should be same as pos"); + + heap_dis[pos] = dis; + heap_ids[pos] = id; + group_up_heap( + heap_dis, heap_ids, heap_group_ids, group_id_to_index_in_heap, pos); + group_down_heap( + k, + heap_dis, + heap_ids, + heap_group_ids, + group_id_to_index_in_heap, + pos); +} + +} // namespace faiss \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3980d7dd7c..c888a5a60d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -27,6 +27,8 @@ set(FAISS_TEST_SRC test_approx_topk.cpp test_RCQ_cropping.cpp test_distances_simd.cpp + test_id_grouper.cpp + test_group_heap.cpp test_heap.cpp test_code_distance.cpp test_hnsw.cpp diff --git a/tests/test_group_heap.cpp b/tests/test_group_heap.cpp new file mode 100644 index 0000000000..0e8fe7a7e0 --- /dev/null +++ b/tests/test_group_heap.cpp @@ -0,0 +1,98 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include + +using namespace faiss; + +TEST(GroupHeap, group_heap_replace_top) { + using C = CMax; + const int k = 100; + float binary_heap_values[k]; + int64_t binary_heap_ids[k]; + heap_heapify(k, binary_heap_values, binary_heap_ids); + int64_t binary_heap_group_ids[k]; + for (size_t i = 0; i < k; i++) { + binary_heap_group_ids[i] = -1; + } + std::unordered_map group_id_to_index_in_heap; + for (int i = 1000; i > 0; i--) { + group_heap_replace_top( + k, + binary_heap_values, + binary_heap_ids, + binary_heap_group_ids, + i * 10.0, + i, + i, + &group_id_to_index_in_heap); + } + + heap_reorder(k, binary_heap_values, binary_heap_ids); + + for (int i = 0; i < k; i++) { + ASSERT_EQ((i + 1) * 10.0, binary_heap_values[i]); + ASSERT_EQ(i + 1, binary_heap_ids[i]); + } +} + +TEST(GroupHeap, group_heap_replace_at) { + using C = CMax; + const int k = 10; + float binary_heap_values[k]; + int64_t binary_heap_ids[k]; + heap_heapify(k, binary_heap_values, binary_heap_ids); + int64_t binary_heap_group_ids[k]; + for (size_t i = 0; i < k; i++) { + binary_heap_group_ids[i] = -1; + } + std::unordered_map group_id_to_index_in_heap; + + std::unordered_map group_id_to_id; + for (int i = 1000; i > 0; i--) { + int64_t group_id = rand() % 100; + group_id_to_id[group_id] = i; + if (group_id_to_index_in_heap.find(group_id) == + group_id_to_index_in_heap.end()) { + group_heap_replace_top( + k, + binary_heap_values, + binary_heap_ids, + binary_heap_group_ids, + i * 10.0, + i, + group_id, + &group_id_to_index_in_heap); + } else { + group_heap_replace_at( + group_id_to_index_in_heap.at(group_id), + k, + binary_heap_values, + binary_heap_ids, + binary_heap_group_ids, + i * 10.0, + i, + group_id, + &group_id_to_index_in_heap); + } + } + + heap_reorder(k, binary_heap_values, binary_heap_ids); + + std::vector sorted_ids; + for (const auto& pair : group_id_to_id) { + sorted_ids.push_back(pair.second); + } + std::sort(sorted_ids.begin(), sorted_ids.end()); + + for (int i = 0; i < k && binary_heap_ids[i] != -1; i++) { + ASSERT_EQ(sorted_ids[i] * 10.0, binary_heap_values[i]); + ASSERT_EQ(sorted_ids[i], binary_heap_ids[i]); + } +} diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp new file mode 100644 index 0000000000..6601795b8b --- /dev/null +++ b/tests/test_id_grouper.cpp @@ -0,0 +1,241 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +// 64-bit int +using idx_t = faiss::idx_t; + +using namespace faiss; + +TEST(IdGrouper, get_group) { + uint64_t ids1[1] = {0b1000100010001000}; + IDGrouperBitmap bitmap(1, ids1); + + ASSERT_EQ(3, bitmap.get_group(0)); + ASSERT_EQ(3, bitmap.get_group(1)); + ASSERT_EQ(3, bitmap.get_group(2)); + ASSERT_EQ(3, bitmap.get_group(3)); + ASSERT_EQ(7, bitmap.get_group(4)); + ASSERT_EQ(7, bitmap.get_group(5)); + ASSERT_EQ(7, bitmap.get_group(6)); + ASSERT_EQ(7, bitmap.get_group(7)); + ASSERT_EQ(11, bitmap.get_group(8)); + ASSERT_EQ(11, bitmap.get_group(9)); + ASSERT_EQ(11, bitmap.get_group(10)); + ASSERT_EQ(11, bitmap.get_group(11)); + ASSERT_EQ(15, bitmap.get_group(12)); + ASSERT_EQ(15, bitmap.get_group(13)); + ASSERT_EQ(15, bitmap.get_group(14)); + ASSERT_EQ(15, bitmap.get_group(15)); + ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(16)); +} + +TEST(IdGrouper, set_group) { + idx_t group_ids[] = {64, 127, 128, 1022}; + uint64_t ids[16] = {}; // 1023 / 64 + 1 + IDGrouperBitmap bitmap(16, ids); + + for (int i = 0; i < 4; i++) { + bitmap.set_group(group_ids[i]); + } + + int group_id_index = 0; + for (int i = 0; i <= group_ids[3]; i++) { + ASSERT_EQ(group_ids[group_id_index], bitmap.get_group(i)); + if (group_ids[group_id_index] == i) { + group_id_index++; + } + } + ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(group_ids[3] + 1)); +} + +TEST(IdGrouper, sanity_test) { + int d = 1; // dimension + int nb = 10; // database size + + std::mt19937 rng; + std::uniform_real_distribution<> distrib; + + float* xb = new float[d * nb]; + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < d; j++) + xb[d * i + j] = distrib(rng); + xb[d * i] += i / 1000.; + } + + uint64_t bitmap[1] = {}; + faiss::IDGrouperBitmap id_grouper(1, bitmap); + for (int i = 0; i < nb; i++) { + id_grouper.set_group(i); + } + + int k = 5; + int m = 8; + faiss::Index* index = + new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2); + index->add(nb, xb); // add vectors to the index + + // search + auto pSearchParameters = new faiss::SearchParametersHNSW(); + + idx_t* expectedI = new idx_t[k]; + float* expectedD = new float[k]; + index->search(1, xb, k, expectedD, expectedI, pSearchParameters); + + idx_t* I = new idx_t[k]; + float* D = new float[k]; + pSearchParameters->grp = &id_grouper; + index->search(1, xb, k, D, I, pSearchParameters); + + // compare + for (int j = 0; j < k; j++) { + ASSERT_EQ(expectedI[j], I[j]); + ASSERT_EQ(expectedD[j], D[j]); + } + + delete[] expectedI; + delete[] expectedD; + delete[] I; + delete[] D; + delete[] xb; +} + +TEST(IdGrouper, bitmap_with_hnsw) { + int d = 1; // dimension + int nb = 10; // database size + + std::mt19937 rng; + std::uniform_real_distribution<> distrib; + + float* xb = new float[d * nb]; + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < d; j++) + xb[d * i + j] = distrib(rng); + xb[d * i] += i / 1000.; + } + + uint64_t bitmap[1] = {}; + faiss::IDGrouperBitmap id_grouper(1, bitmap); + for (int i = 0; i < nb; i++) { + if (i % 2 == 1) { + id_grouper.set_group(i); + } + } + + int k = 10; + int m = 8; + faiss::Index* index = + new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2); + index->add(nb, xb); // add vectors to the index + + // search + idx_t* I = new idx_t[k]; + float* D = new float[k]; + + auto pSearchParameters = new faiss::SearchParametersHNSW(); + pSearchParameters->grp = &id_grouper; + + index->search(1, xb, k, D, I, pSearchParameters); + + std::unordered_set group_ids; + ASSERT_EQ(0, I[0]); + ASSERT_EQ(0, D[0]); + group_ids.insert(id_grouper.get_group(I[0])); + for (int j = 1; j < 5; j++) { + ASSERT_NE(-1, I[j]); + ASSERT_NE(std::numeric_limits::max(), D[j]); + group_ids.insert(id_grouper.get_group(I[j])); + } + for (int j = 5; j < k; j++) { + ASSERT_EQ(-1, I[j]); + ASSERT_EQ(std::numeric_limits::max(), D[j]); + } + ASSERT_EQ(5, group_ids.size()); + + delete[] I; + delete[] D; + delete[] xb; +} + +TEST(IdGrouper, bitmap_with_hnswn_idmap) { + int d = 1; // dimension + int nb = 10; // database size + + std::mt19937 rng; + std::uniform_real_distribution<> distrib; + + float* xb = new float[d * nb]; + idx_t* xids = new idx_t[d * nb]; + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < d; j++) + xb[d * i + j] = distrib(rng); + xb[d * i] += i / 1000.; + } + + uint64_t bitmap[1] = {}; + faiss::IDGrouperBitmap id_grouper(1, bitmap); + int num_grp = 0; + int grp_size = 2; + int id_in_grp = 0; + for (int i = 0; i < nb; i++) { + xids[i] = i + num_grp; + id_in_grp++; + if (id_in_grp == grp_size) { + id_grouper.set_group(i + num_grp + 1); + num_grp++; + id_in_grp = 0; + } + } + + int k = 10; + int m = 8; + faiss::Index* index = + new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2); + faiss::IndexIDMap id_map = + faiss::IndexIDMap(index); // add vectors to the index + id_map.add_with_ids(nb, xb, xids); + + // search + idx_t* I = new idx_t[k]; + float* D = new float[k]; + + auto pSearchParameters = new faiss::SearchParametersHNSW(); + pSearchParameters->grp = &id_grouper; + + id_map.search(1, xb, k, D, I, pSearchParameters); + + std::unordered_set group_ids; + ASSERT_EQ(0, I[0]); + ASSERT_EQ(0, D[0]); + group_ids.insert(id_grouper.get_group(I[0])); + for (int j = 1; j < 5; j++) { + ASSERT_NE(-1, I[j]); + ASSERT_NE(std::numeric_limits::max(), D[j]); + group_ids.insert(id_grouper.get_group(I[j])); + } + for (int j = 5; j < k; j++) { + ASSERT_EQ(-1, I[j]); + ASSERT_EQ(std::numeric_limits::max(), D[j]); + } + ASSERT_EQ(5, group_ids.size()); + + delete[] I; + delete[] D; + delete[] xb; +}