diff --git a/CHANGELOG.md b/CHANGELOG.md index ffd99a595..b55634b53 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,4 +19,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Infrastructure ### Documentation ### Maintenance +* Bump faiss lib commit to 32f0e8cf92cd2275b60364517bb1cce67aa29a55 [#1443](https://github.com/opensearch-project/k-NN/pull/1443) ### Refactoring diff --git a/jni/CMakeLists.txt b/jni/CMakeLists.txt index 0bec4b945..0f0b58738 100644 --- a/jni/CMakeLists.txt +++ b/jni/CMakeLists.txt @@ -181,17 +181,15 @@ if (${CONFIG_FAISS} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} S ${TARGET_LIB_FAISS} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/utils/BitSet.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/MultiVectorResultCollector.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/MultiVectorResultCollectorFactory.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_util.cpp + ) target_link_libraries(${TARGET_LIB_FAISS} ${TARGET_LINK_FAISS_LIB} ${TARGET_LIB_COMMON} OpenMP::OpenMP_CXX) target_include_directories(${TARGET_LIB_FAISS} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - ${CMAKE_CURRENT_SOURCE_DIR}/include/knn_extension/faiss - ${CMAKE_CURRENT_SOURCE_DIR}/include/knn_extension/faiss/utils $ENV{JAVA_HOME}/include $ENV{JAVA_HOME}/include/${JVM_OS_TYPE} - ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss) + ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss + ) set_target_properties(${TARGET_LIB_FAISS} PROPERTIES SUFFIX ${LIB_EXT}) set_target_properties(${TARGET_LIB_FAISS} PROPERTIES POSITION_INDEPENDENT_CODE ON) @@ -228,12 +226,9 @@ if ("${WIN32}" STREQUAL "") add_executable( jni_test tests/faiss_wrapper_test.cpp + tests/faiss_util_test.cpp tests/nmslib_wrapper_test.cpp tests/test_util.cpp - tests/knn_extension/faiss/utils/BitSetTest.cpp - 
tests/knn_extension/faiss/utils/HeapTest.cpp - tests/knn_extension/faiss/MultiVectorResultCollectorTest.cpp - tests/knn_extension/faiss/MultiVectorResultCollectorFactoryTest.cpp ) target_link_libraries( diff --git a/jni/external/faiss b/jni/external/faiss index 0013c702f..32f0e8cf9 160000 --- a/jni/external/faiss +++ b/jni/external/faiss @@ -1 +1 @@ -Subproject commit 0013c702f47bedbf6159ac356e61f378ccd12ac8 +Subproject commit 32f0e8cf92cd2275b60364517bb1cce67aa29a55 diff --git a/jni/include/faiss_util.h b/jni/include/faiss_util.h new file mode 100644 index 000000000..f23540aef --- /dev/null +++ b/jni/include/faiss_util.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// The OpenSearch Contributors require contributions made to +// this file be licensed under the Apache-2.0 license or a +// compatible open source license. +// +// Modifications Copyright OpenSearch Contributors. See +// GitHub history for details. + +/** + * This file contains util methods which are free of JNI to be used in faiss_wrapper.cpp + */ + +#ifndef OPENSEARCH_KNN_FAISS_UTIL_H +#define OPENSEARCH_KNN_FAISS_UTIL_H + +#include "faiss/impl/IDGrouper.h" +#include + +namespace faiss_util { + std::unique_ptr buildIDGrouperBitmap(int *parentIdsArray, int parentIdsLength, std::vector* bitmap); +}; + + +#endif //OPENSEARCH_KNN_FAISS_UTIL_H diff --git a/jni/include/knn_extension/faiss/MultiVectorResultCollector.h b/jni/include/knn_extension/faiss/MultiVectorResultCollector.h deleted file mode 100644 index a11a278d9..000000000 --- a/jni/include/knn_extension/faiss/MultiVectorResultCollector.h +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include -#include -#include "knn_extension/faiss/utils/BitSet.h" -#include - -namespace os_faiss { - -using idx_t = faiss::idx_t; -/** - * Implementation of ResultCollector to support multi vector - * - * Only supports HNSW algorithm - * - * Example: - * When 
there is two lucene document with two nested fields, the parent_bit_set value of 100100 is provided where - * parent doc ids are 2, and 5. Doc id for nested fields of parent document 2 are 0, and 1. Doc id for nested fields - * of parent document 5 are 3, and 4. For faiss, only nested fields are stored. Therefore corresponding doc ids for - * nested fields 0, 1, 3, 4 is 0, 1, 2, 3 in faiss. This mapping data is stored in id_map parameter. - * - * When collect method is called - * 1. It switches from faiss id to lucene id and look for its parent id. - * 2. See if the parent id already exist in heap using either parent_id_to_id or parent_id_to_index. - * 3. If it does not exist, add the parent id and distance value in the heap(bh_ids, bh_val) and update parent_id_to_id, and parent_id_to_index. - * 4. If it does exist, update the distance value(bh_val), parent_id_to_id, and parent_id_to_index. - * - * When post_process method is called - * 1. Convert lucene parent ID to faiss doc ID using parent_id_to_id - */ -struct MultiVectorResultCollector:faiss::ResultCollector { - // BitSet of lucene parent doc ID - const BitSet* parent_bit_set; - - // Mapping data from Faiss doc ID to Lucene doc ID - const std::vector* id_map; - - // Lucene parent doc ID to to Faiss doc ID - // Lucene parent doc ID to index in heap(bh_val, bh_ids) - std::unordered_map parent_id_to_id; - std::unordered_map parent_id_to_index; - MultiVectorResultCollector(const BitSet* parent_bit_set, const std::vector* id_map); - - /** - * - * @param k max size of bh_val, and bh_ids - * @param nres number of results in bh_val, and bh_ids - * @param bh_val binary heap storing values (For this case distance from query to result) - * @param bh_ids binary heap storing document IDs - * @param val a new value to add in bh_val - * @param ids a new doc id to add in bh_ids - */ - void collect( - int k, - int& nres, - float* bh_val, - int64_t* bh_ids, - float val, - int64_t ids) override; - void post_process(int64_t nres, 
int64_t* bh_ids) override; -}; - -} // namespace os_faiss - diff --git a/jni/include/knn_extension/faiss/MultiVectorResultCollectorFactory.h b/jni/include/knn_extension/faiss/MultiVectorResultCollectorFactory.h deleted file mode 100644 index 45c0338b3..000000000 --- a/jni/include/knn_extension/faiss/MultiVectorResultCollectorFactory.h +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include -#include "knn_extension/faiss/utils/BitSet.h" - -namespace os_faiss { -/** - * Create MultiVectorResultCollector for single query request - * - * Creating new collector is required because MultiVectorResultCollector has instance variables - * which should be isolated for each query. - */ -struct MultiVectorResultCollectorFactory:faiss::ResultCollectorFactory { - BitSet* parent_bit_set; - - MultiVectorResultCollectorFactory(BitSet* parent_bit_set); - faiss::ResultCollector* new_collector() override; - void delete_collector(faiss::ResultCollector* resultCollector) override; -}; - -} // namespace os_faiss diff --git a/jni/include/knn_extension/faiss/utils/BitSet.h b/jni/include/knn_extension/faiss/utils/BitSet.h deleted file mode 100644 index 0c8079d37..000000000 --- a/jni/include/knn_extension/faiss/utils/BitSet.h +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include -#include -#include - -using idx_t = faiss::idx_t; - -struct BitSet { - const int NO_MORE_DOCS = std::numeric_limits::max(); - /** - * Returns the index of the first set bit starting at the index specified. - * NO_MORE_DOCS is returned if there are no more set bits. - */ - virtual idx_t next_set_bit(idx_t index) const = 0; - virtual ~BitSet() = default; -}; - - -/** - * BitSet of fixed length (numBits), implemented using an array of unit64. 
- * See https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java - * - * Here a block is 64 bit. However, for simplicity let's assume its size is 8 bits. - * Then, if have an array of 3, 7, and 10, it will be represented in bitmap as follow. - * [0] [1] - * bitmap: 10001000 00000100 - * - * for next_set_bit call with 4 - * 1. it looks for words[0] - * 2. words[0] >> 4 - * 3. count trailing zero of the result from step 2 which is 3 - * 4. return 4(current index) + 3(result from step 3) - */ -struct FixedBitSet : public BitSet { - // The number of bits in use - idx_t num_bits; - - // The exact number of longs needed to hold num_bits - size_t num_words; - - // Array of uint64_t holding the bits - // Using uint64_t to leverage function __builtin_ctzll which is defined in faiss/impl/platform_macros.h - uint64_t* words; - - FixedBitSet(const int* int_array, const int length); - idx_t next_set_bit(idx_t index) const; - ~FixedBitSet(); -}; diff --git a/jni/include/knn_extension/faiss/utils/Heap.h b/jni/include/knn_extension/faiss/utils/Heap.h deleted file mode 100644 index 2aa19da52..000000000 --- a/jni/include/knn_extension/faiss/utils/Heap.h +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include -#include -#include - -#include -#include -#include - -#include -#include -#include - -// Collection of heap operations with parent id to dedupe -namespace os_faiss { - -/** - * From start_index, it compare its value with parent node's and swap if needed. - * Continue until either there is no swap or it reaches the top node. 
- * - * @param bh_val binary heap storing values - * @param bh_ids binary heap storing parent ids - * @param val new value to add - * @param id new id to add - * @parent_id_to_id parent doc id to id mapping data, see MultiVectorResultCollector.h - * @parent_id_to_index parent doc id to index mapping data, see MultiVectorResultCollector.h - * @parent_id parent id of given id - * @start_index an index to start up-heap from in the binary heap(bh_val, and bh_ids) - */ -template -static inline void up_heap( - typename C::T* bh_val, - typename C::TI* bh_ids, - typename C::T val, - typename C::TI id, - std::unordered_map* parent_id_to_id, - std::unordered_map* parent_id_to_index, - typename C::TI parent_id, - size_t start_index) { - bh_val--; /* Use 1-based indexing for easier node->child translation */ - bh_ids--; - size_t i = start_index + 1, i_father; - - while (i > 1) { - i_father = i >> 1; - if (!C::cmp2(val, bh_val[i_father], parent_id, bh_ids[i_father])) { - /* the heap structure is ok */ - break; - } - bh_val[i] = bh_val[i_father]; - bh_ids[i] = bh_ids[i_father]; - (*parent_id_to_index)[bh_ids[i]] = i - 1; - i = i_father; - } - bh_val[i] = val; - bh_ids[i] = parent_id; - (*parent_id_to_id)[parent_id] = id; - (*parent_id_to_index)[parent_id] = i - 1; -} - -/** - * From start_index, it compare its value with child node's and swap if needed. - * Continue until either there is no swap or it reaches the leaf node. 
- * - * @param nres number of values in the binary heap(bh_val, and bh_ids) - * @param bh_val binary heap storing values - * @param bh_ids binary heap storing parent ids - * @param val new value to add - * @param id new id to add - * @parent_id_to_id parent doc id to id mapping data, see MultiVectorResultCollector.h - * @parent_id_to_index parent doc id to index mapping data, see MultiVectorResultCollector.h - * @parent_id parent id of given id - * @start_index an index to start up-heap from in the binary heap(bh_val, and bh_ids) - */ -template -static inline void down_heap( - int nres, - typename C::T* bh_val, - typename C::TI* bh_ids, - typename C::T val, - typename C::TI id, - std::unordered_map* parent_id_to_id, - std::unordered_map* parent_id_to_index, - typename C::TI parent_id, - size_t start_index) { - bh_val--; /* Use 1-based indexing for easier node->child translation */ - bh_ids--; - size_t i = start_index + 1, i1, i2; - - while (1) { - i1 = i << 1; - i2 = i1 + 1; - if (i1 > nres) { - break; - } - - // Note that C::cmp2() is a bool function answering - // `(a1 > b1) || ((a1 == b1) && (a2 > b2))` for max - // heap and same with the `<` sign for min heap. - if ((i2 == nres + 1) || - C::cmp2(bh_val[i1], bh_val[i2], bh_ids[i1], bh_ids[i2])) { - if (C::cmp2(val, bh_val[i1], parent_id, bh_ids[i1])) { - break; - } - bh_val[i] = bh_val[i1]; - bh_ids[i] = bh_ids[i1]; - (*parent_id_to_index)[bh_ids[i]] = i - 1; - i = i1; - } else { - if (C::cmp2(val, bh_val[i2], parent_id, bh_ids[i2])) { - break; - } - bh_val[i] = bh_val[i2]; - bh_ids[i] = bh_ids[i2]; - (*parent_id_to_index)[bh_ids[i]] = i - 1; - i = i2; - } - } - bh_val[i] = val; - bh_ids[i] = parent_id; - (*parent_id_to_id)[parent_id] = id; - (*parent_id_to_index)[parent_id] = i - 1; -} - -/** - * Push the value to the max heap - * As the heap contains only one value per group id, pushing a value of existing group id - * will break the data integrity. For existing group id, use maxheap_update instead. 
- * The parent_id should not exist in in bh_ids, parent_id_to_id, and parent_id_to_index. - * - * @param nres number of values in the binary heap(bh_val, and bh_ids) - * @param bh_val binary heap storing values - * @param bh_ids binary heap storing parent ids - * @param val new value to add - * @param id new id to add - * @parent_id_to_id parent doc id to id mapping data, see MultiVectorResultCollector.h - * @parent_id_to_index parent doc id to index mapping data, see MultiVectorResultCollector.h - * @parent_id parent id of given id - */ -template -inline void maxheap_push( - int nres, - T* bh_val, - int64_t* bh_ids, - T val, - int64_t id, - std::unordered_map* parent_id_to_id, - std::unordered_map* parent_id_to_index, - int64_t parent_id) { - - assert(parent_id_to_index->find(parent_id) == parent_id_to_index->end() && "parent id should not exist in the binary heap"); - - up_heap>( - bh_val, - bh_ids, - val, - id, - parent_id_to_id, - parent_id_to_index, - parent_id, - nres); -} - -/** - * Update the top node with given value - * The parent_id should not exist in in bh_ids, parent_id_to_id, and parent_id_to_index. 
- * - * @param nres number of values in the binary heap(bh_val, and bh_ids) - * @param bh_val binary heap storing values - * @param bh_ids binary heap storing parent ids - * @param val new value to add - * @param id new id to add - * @parent_id_to_id parent doc id to id mapping data, see MultiVectorResultCollector.h - * @parent_id_to_index parent doc id to index mapping data, see MultiVectorResultCollector.h - * @parent_id parent id of given id - */ -template -inline void maxheap_replace_top( - int nres, - T* bh_val, - int64_t* bh_ids, - T val, - int64_t id, - std::unordered_map* parent_id_to_id, - std::unordered_map* parent_id_to_index, - int64_t parent_id) { - - assert(parent_id_to_index->find(parent_id) == parent_id_to_index->end() && "parent id should not exist in the binary heap"); - - parent_id_to_id->erase(bh_ids[0]); - parent_id_to_index->erase(bh_ids[0]); - down_heap>( - nres, - bh_val, - bh_ids, - val, - id, - parent_id_to_id, - parent_id_to_index, - parent_id, - 0); -} - -/** - * Update value of the parent_id in the binary heap and id of the parent_id in parent_id_to_id - * The parent_id should exist in bh_ids, parent_id_to_id, and parent_id_to_index. 
- * - * @param nres number of values in the binary heap(bh_val, and bh_ids) - * @param bh_val binary heap storing values - * @param bh_ids binary heap storing parent ids - * @param val new value to update - * @param id new id to update - * @parent_id_to_id parent doc id to id mapping data, see MultiVectorResultCollector.h - * @parent_id_to_index parent doc id to index mapping data, see MultiVectorResultCollector.h - * @parent_id parent id of given id - */ -template -inline void maxheap_update( - int nres, - T* bh_val, - int64_t* bh_ids, - T val, - int64_t id, - std::unordered_map* parent_id_to_id, - std::unordered_map* parent_id_to_index, - int64_t parent_id) { - size_t target_index = parent_id_to_index->at(parent_id); - up_heap>( - bh_val, - bh_ids, - val, - id, - parent_id_to_id, - parent_id_to_index, - parent_id, - target_index); - down_heap>( - nres, - bh_val, - bh_ids, - val, - id, - parent_id_to_id, - parent_id_to_index, - parent_id, - target_index); -} - -} // namespace os_faiss diff --git a/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch b/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch index b07620c7e..a22e28130 100644 --- a/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch +++ b/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch @@ -1,90 +1,131 @@ -From baa7e23c54637d68adac45f09633939b402405a9 Mon Sep 17 00:00:00 2001 +From 0d1385959ddecabb2825957e48ff28ff0e8abf53 Mon Sep 17 00:00:00 2001 From: Heemin Kim -Date: Wed, 6 Dec 2023 16:33:52 -0800 -Subject: [PATCH] Custom patch to support multi-vector +Date: Tue, 30 Jan 2024 14:43:56 -0800 +Subject: [PATCH] Add IDGrouper for HNSW Signed-off-by: Heemin Kim --- - faiss/CMakeLists.txt | 2 + - faiss/Index.h | 6 ++- - faiss/IndexIDMap.cpp | 24 ++++++++++ - faiss/impl/HNSW.cpp | 27 +++++++---- - faiss/impl/ResultCollector.h | 74 +++++++++++++++++++++++++++++ - faiss/impl/ResultCollectorFactory.h | 33 +++++++++++++ - 6 files changed, 154 
insertions(+), 12 deletions(-) - create mode 100644 faiss/impl/ResultCollector.h - create mode 100644 faiss/impl/ResultCollectorFactory.h + faiss/CMakeLists.txt | 3 + + faiss/Index.h | 8 +- + faiss/IndexHNSW.cpp | 13 ++- + faiss/IndexIDMap.cpp | 29 ++++++ + faiss/IndexIDMap.h | 22 +++++ + faiss/impl/HNSW.cpp | 10 +- + faiss/impl/IDGrouper.cpp | 51 ++++++++++ + faiss/impl/IDGrouper.h | 51 ++++++++++ + faiss/impl/ResultHandler.h | 187 ++++++++++++++++++++++++++++++++++++ + faiss/utils/GroupHeap.h | 182 +++++++++++++++++++++++++++++++++++ + tests/CMakeLists.txt | 2 + + tests/test_group_heap.cpp | 98 +++++++++++++++++++ + tests/test_id_grouper.cpp | 189 +++++++++++++++++++++++++++++++++++++ + 13 files changed, 838 insertions(+), 7 deletions(-) + create mode 100644 faiss/impl/IDGrouper.cpp + create mode 100644 faiss/impl/IDGrouper.h + create mode 100644 faiss/utils/GroupHeap.h + create mode 100644 tests/test_group_heap.cpp + create mode 100644 tests/test_id_grouper.cpp diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt -index 27701586..af682a05 100644 +index a890a46f..137e68d4 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt -@@ -162,6 +162,8 @@ set(FAISS_HEADERS - impl/ProductQuantizer.h - impl/Quantizer.h - impl/ResidualQuantizer.h -+ impl/ResultCollector.h -+ impl/ResultCollectorFactory.h - impl/ResultHandler.h - impl/ScalarQuantizer.h - impl/ThreadedIndex-inl.h +@@ -54,6 +54,7 @@ set(FAISS_SRC + impl/AuxIndexStructures.cpp + impl/CodePacker.cpp + impl/IDSelector.cpp ++ impl/IDGrouper.cpp + impl/FaissException.cpp + impl/HNSW.cpp + impl/NSG.cpp +@@ -149,6 +150,7 @@ set(FAISS_HEADERS + impl/AuxIndexStructures.h + impl/CodePacker.h + impl/IDSelector.h ++ impl/IDGrouper.h + impl/DistanceComputer.h + impl/FaissAssert.h + impl/FaissException.h +@@ -183,6 +185,7 @@ set(FAISS_HEADERS + invlists/InvertedLists.h + invlists/InvertedListsIOHook.h + utils/AlignedTable.h ++ utils/GroupHeap.h + utils/Heap.h + utils/WorkerThread.h + utils/distances.h diff --git 
a/faiss/Index.h b/faiss/Index.h -index 4b4b302b..13eab0c0 100644 +index 4b4b302b..3b673d1e 100644 --- a/faiss/Index.h +++ b/faiss/Index.h -@@ -38,11 +38,12 @@ +@@ -38,9 +38,10 @@ namespace faiss { -/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and -/// impl/DistanceComputer.h -+/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h, -+/// impl/DistanceComputer.h, and impl/ResultCollectorFactory.h ++/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h ++/// ,impl/IDGrouper.h and impl/DistanceComputer.h struct IDSelector; ++struct IDGrouper; struct RangeSearchResult; struct DistanceComputer; -+struct ResultCollectorFactory; - /** Parent class for the optional search paramenters. - * -@@ -52,6 +53,7 @@ struct DistanceComputer; +@@ -52,6 +53,9 @@ struct DistanceComputer; struct SearchParameters { /// if non-null, only these IDs will be considered during search. IDSelector* sel = nullptr; -+ ResultCollectorFactory* col = nullptr; ++ /// if non-null, only best matched ID per group will be included in the ++ /// result. 
++ IDGrouper* grp = nullptr; /// make sure we can dynamic_cast this virtual ~SearchParameters() {} }; +diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp +index 9a67332d..a5e0fea0 100644 +--- a/faiss/IndexHNSW.cpp ++++ b/faiss/IndexHNSW.cpp +@@ -354,10 +354,17 @@ void IndexHNSW::search( + const SearchParameters* params_in) const { + FAISS_THROW_IF_NOT(k > 0); + +- using RH = HeapBlockResultHandler; +- RH bres(n, distances, labels, k); ++ if (params_in && params_in->grp) { ++ using RH = GroupedHeapBlockResultHandler; ++ RH bres(n, distances, labels, k, params_in->grp); + +- hnsw_search(this, n, x, bres, params_in); ++ hnsw_search(this, n, x, bres, params_in); ++ } else { ++ using RH = HeapBlockResultHandler; ++ RH bres(n, distances, labels, k); ++ ++ hnsw_search(this, n, x, bres, params_in); ++ } + + if (is_similarity_metric(this->metric_type)) { + // we need to revert the negated distances diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp -index 7972bec9..0f82a17c 100644 +index e093bbda..e24365d5 100644 --- a/faiss/IndexIDMap.cpp +++ b/faiss/IndexIDMap.cpp -@@ -18,6 +18,7 @@ - #include - #include - #include -+#include - - namespace faiss { - -@@ -102,6 +103,24 @@ struct ScopedSelChange { +@@ -102,6 +102,23 @@ struct ScopedSelChange { } }; -+// RAII object to reset the id_map parameter in ResultCollectorFactory object -+// This object make sure to reset the id_map parameter in ResultCollectorFactory -+// once the program exist current method scope. 
-+struct ScopedColChange { -+ ResultCollectorFactory* collector_factory = nullptr; -+ void set( -+ ResultCollectorFactory* collector_factory, -+ const std::vector* id_map) { -+ this->collector_factory = collector_factory; -+ collector_factory->id_map = id_map; ++/// RAII object to reset the IDGrouper in the params object ++struct ScopedGrpChange { ++ SearchParameters* params = nullptr; ++ IDGrouper* old_grp = nullptr; ++ ++ void set(SearchParameters* params_2, IDGrouper* new_grp) { ++ this->params = params_2; ++ old_grp = params_2->grp; ++ params_2->grp = new_grp; + } -+ ~ScopedColChange() { -+ if (collector_factory) { -+ collector_factory->id_map = nullptr; ++ ~ScopedGrpChange() { ++ if (params) { ++ params->grp = old_grp; + } + } +}; @@ -92,97 +133,161 @@ index 7972bec9..0f82a17c 100644 } // namespace template -@@ -114,6 +133,7 @@ void IndexIDMapTemplate::search( +@@ -114,6 +131,8 @@ void IndexIDMapTemplate::search( const SearchParameters* params) const { IDSelectorTranslated this_idtrans(this->id_map, nullptr); ScopedSelChange sel_change; -+ ScopedColChange col_change; ++ IDGrouperTranslated this_idgrptrans(this->id_map, nullptr); ++ ScopedGrpChange grp_change; if (params && params->sel) { auto idtrans = dynamic_cast(params->sel); -@@ -131,6 +151,10 @@ void IndexIDMapTemplate::search( +@@ -131,6 +150,16 @@ void IndexIDMapTemplate::search( sel_change.set(params_non_const, &this_idtrans); } } + -+ if (params && params->col && !params->col->id_map) { -+ col_change.set(params->col, &this->id_map); ++ if (params && params->grp) { ++ auto idtrans = dynamic_cast(params->grp); ++ ++ if (!idtrans) { ++ auto params_non_const = const_cast(params); ++ this_idgrptrans.grp = params->grp; ++ grp_change.set(params_non_const, &this_idgrptrans); ++ } + } index->search(n, x, k, distances, labels, params); idx_t* li = labels; #pragma omp parallel for -diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp -index 9fc201ea..5b5900d1 100644 ---- a/faiss/impl/HNSW.cpp -+++ 
b/faiss/impl/HNSW.cpp -@@ -14,6 +14,7 @@ - #include - #include - #include -+#include - #include +diff --git a/faiss/IndexIDMap.h b/faiss/IndexIDMap.h +index 2d164123..a68887bd 100644 +--- a/faiss/IndexIDMap.h ++++ b/faiss/IndexIDMap.h +@@ -9,6 +9,7 @@ - #include -@@ -530,6 +531,15 @@ int search_from_candidates( - int level, - int nres_in = 0, - const SearchParametersHNSW* params = nullptr) { -+ ResultCollectorFactory defaultFactory; -+ ResultCollectorFactory* collectorFactory; -+ if (params == nullptr || params->col == nullptr) { -+ collectorFactory = &defaultFactory; -+ } else { -+ collectorFactory = params->col; -+ } -+ ResultCollector* collector = collectorFactory->new_collector(); -+ - int nres = nres_in; - int ndis = 0; + #include + #include ++#include + #include -@@ -544,11 +554,7 @@ int search_from_candidates( - float d = candidates.dis[i]; - FAISS_ASSERT(v1 >= 0); - if (!sel || sel->is_member(v1)) { -- if (nres < k) { -- faiss::maxheap_push(++nres, D, I, d, v1); -- } else if (d < D[0]) { -- faiss::maxheap_replace_top(nres, D, I, d, v1); -- } -+ collector->collect(k, nres, D, I, d, v1); - } - vt.set(v1); + #include +@@ -124,4 +125,25 @@ struct IDSelectorTranslated : IDSelector { } -@@ -612,11 +618,7 @@ int search_from_candidates( + }; - auto add_to_heap = [&](const size_t idx, const float dis) { - if (!sel || sel->is_member(idx)) { -- if (nres < k) { -- faiss::maxheap_push(++nres, D, I, dis, idx); -- } else if (dis < D[0]) { -- faiss::maxheap_replace_top(nres, D, I, dis, idx); -- } -+ collector->collect(k, nres, D, I, dis, idx); - } - candidates.push(idx, dis); - }; -@@ -660,6 +662,11 @@ int search_from_candidates( - } ++// IDGrouper that translates the ids using an IDMap ++struct IDGrouperTranslated : IDGrouper { ++ const std::vector& id_map; ++ const IDGrouper* grp; ++ ++ IDGrouperTranslated( ++ const std::vector& id_map, ++ const IDGrouper* grp) ++ : id_map(id_map), grp(grp) {} ++ ++ IDGrouperTranslated(IndexBinaryIDMap& index_idmap, const IDGrouper* grp) 
++ : id_map(index_idmap.id_map), grp(grp) {} ++ ++ IDGrouperTranslated(IndexIDMap& index_idmap, const IDGrouper* grp) ++ : id_map(index_idmap.id_map), grp(grp) {} ++ ++ idx_t get_group(idx_t id) const override { ++ return grp->get_group(id_map[id]); ++ } ++}; ++ + } // namespace faiss +diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp +index fb4de678..b6f602a0 100644 +--- a/faiss/impl/HNSW.cpp ++++ b/faiss/impl/HNSW.cpp +@@ -110,8 +110,8 @@ void HNSW::print_neighbor_stats(int level) const { + level, + nb_neighbors(level)); + size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0; +-#pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \ +- reduction(+: tot_reciprocal) reduction(+: n_node) ++#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \ ++ reduction(+ : tot_reciprocal) reduction(+ : n_node) + for (int i = 0; i < levels.size(); i++) { + if (levels[i] > level) { + n_node++; +@@ -804,6 +804,12 @@ int extract_k_from_ResultHandler(ResultHandler& res) { + if (auto hres = dynamic_cast(&res)) { + return hres->k; } ++ ++ if (auto hres = dynamic_cast< ++ GroupedHeapBlockResultHandler::SingleResultHandler*>(&res)) { ++ return hres->k; ++ } ++ + return 1; + } -+ // Completed collection of result. Run post processor. -+ collector->post_process(nres, I); -+ // Collector completed its task. Release all resource of the collector. -+ collectorFactory->delete_collector(collector); -+ - if (level == 0) { - stats.n1++; - if (candidates.size() == 0) { -diff --git a/faiss/impl/ResultCollector.h b/faiss/impl/ResultCollector.h +diff --git a/faiss/impl/IDGrouper.cpp b/faiss/impl/IDGrouper.cpp +new file mode 100644 +index 00000000..ca9f5fda +--- /dev/null ++++ b/faiss/impl/IDGrouper.cpp +@@ -0,0 +1,51 @@ ++/** ++ * Copyright (c) Facebook, Inc. and its affiliates. ++ * ++ * This source code is licensed under the MIT license found in the ++ * LICENSE file in the root directory of this source tree. 
++ */ ++ ++#include ++#include ++#include ++ ++namespace faiss { ++ ++/*********************************************************************** ++ * IDGrouperBitmap ++ ***********************************************************************/ ++ ++IDGrouperBitmap::IDGrouperBitmap(size_t n, uint64_t* bitmap) ++ : n(n), bitmap(bitmap) {} ++ ++idx_t IDGrouperBitmap::get_group(idx_t id) const { ++ assert(id >= 0 && "id shouldn't be less than zero"); ++ assert(id < this->n * 64 && "id should be less than total number of bits"); ++ ++ idx_t index = id >> 6; // div by 64 ++ uint64_t block = this->bitmap[index] >> ++ (id & 63); // Equivalent of words[i] >> (index % 64) ++ // If block is non-zero after the right shift, the next set bit is in the ++ // current block: its position is the given id plus the number of trailing ++ // zeros in the shifted word. ++ if (block != 0) { ++ return id + __builtin_ctzll(block); ++ } ++ ++ while (++index < this->n) { ++ block = this->bitmap[index]; ++ if (block != 0) { ++ return (index << 6) + __builtin_ctzll(block); ++ } ++ } ++ ++ return NO_MORE_DOCS; ++} ++ ++void IDGrouperBitmap::set_group(idx_t group_id) { ++ idx_t index = group_id >> 6; ++ this->bitmap[index] |= 1ULL ++ << (group_id & 63); // Equivalent of 1ULL << (value % 64) ++} ++ ++} // namespace faiss +diff --git a/faiss/impl/IDGrouper.h b/faiss/impl/IDGrouper.h new file mode 100644 -index 00000000..a0489fd6 +index 00000000..d56113d9 --- /dev/null -+++ b/faiss/impl/ResultCollector.h -@@ -0,0 +1,74 @@ ++++ b/faiss/impl/IDGrouper.h +@@ -0,0 +1,51 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * @@ -192,77 +297,259 @@ index 00000000..a0489fd6 + +#pragma once + ++#include +#include +#include + +#include -+#include + -+/** -+ * ResultCollector is intended to define how to collect search result -+ * For each single search result, collect method will be called. -+ * After every results are collected, post_process method is called at the end. 
-+ */ ++/** IDGrouper is intended to define a group of vectors to include only ++ * the nearest vector of each group during search */ + +namespace faiss { + -+/** Encapsulates a set of ids to handle. */ -+struct ResultCollector { -+ /** -+ * For each result, collect method is called to store result -+ * @param k number of vectors to search -+ * @param nres number of results in queue -+ * @param bh_val search result, distances from query -+ * @param bh_ids search result, ids of vectors -+ * @param val distance from query for current vector -+ * @param ids id of current vector ++/** Maps an id to the id of the group it belongs to */ ++struct IDGrouper { ++ const idx_t NO_MORE_DOCS = std::numeric_limits::max(); ++ virtual idx_t get_group(idx_t id) const = 0; ++ virtual ~IDGrouper() {} ++}; ++ ++/** One bit per id. Constructed with a bitmap stored as n 64-bit words. ++ */ ++struct IDGrouperBitmap : IDGrouper { ++ // length of the bitmap array ++ size_t n; ++ ++ // Array of uint64_t holding the bits ++ // Using uint64_t to leverage function __builtin_ctzll which is defined in ++ // faiss/impl/platform_macros.h. The group id of a given id is the next set ++ // bit in the bitmap ++ uint64_t* bitmap; ++ ++ /** Construct with a binary mask ++ * ++ * @param n size of the bitmap array ++ * @param bitmap group id of a given id is next set bit in the bitmap + */ -+ virtual void collect( -+ int k, -+ int& nres, -+ float* bh_val, -+ idx_t* bh_ids, -+ float val, -+ idx_t ids) = 0; -+ -+ // This method is called after all result is collected -+ virtual void post_process(idx_t nres, idx_t* bh_ids) = 0; -+ virtual ~ResultCollector() {} ++ IDGrouperBitmap(size_t n, uint64_t* bitmap); ++ idx_t get_group(idx_t id) const final; ++ void set_group(idx_t group_id); ++ ~IDGrouperBitmap() override {} +}; + -+struct DefaultCollector : ResultCollector { -+ void collect( -+ int k, -+ int& nres, -+ float* bh_val, -+ idx_t* bh_ids, -+ float val, -+ idx_t ids) override { -+ if (nres < k) { -+ faiss::maxheap_push(++nres, 
bh_val, bh_ids, val, ids); -+ } else if (val < bh_val[0]) { -+ faiss::maxheap_replace_top(nres, bh_val, bh_ids, val, ids); ++} // namespace faiss +diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h +index 270de8dc..2f7f3e7f 100644 +--- a/faiss/impl/ResultHandler.h ++++ b/faiss/impl/ResultHandler.h +@@ -12,6 +12,8 @@ + #pragma once + + #include ++#include ++#include + #include + #include + +@@ -265,6 +267,191 @@ struct HeapBlockResultHandler : BlockResultHandler { + } + }; + ++/***************************************************************** ++ * Heap based result handler with grouping ++ *****************************************************************/ ++ ++template ++struct GroupedHeapBlockResultHandler : BlockResultHandler { ++ using T = typename C::T; ++ using TI = typename C::TI; ++ using BlockResultHandler::i0; ++ using BlockResultHandler::i1; ++ ++ T* heap_dis_tab; ++ TI* heap_ids_tab; ++ int64_t k; // number of results to keep ++ ++ IDGrouper* id_grouper; ++ TI* heap_group_ids_tab; ++ std::unordered_map* group_id_to_index_in_heap_tab; ++ ++ GroupedHeapBlockResultHandler( ++ size_t nq, ++ T* heap_dis_tab, ++ TI* heap_ids_tab, ++ size_t k, ++ IDGrouper* id_grouper) ++ : BlockResultHandler(nq), ++ heap_dis_tab(heap_dis_tab), ++ heap_ids_tab(heap_ids_tab), ++ k(k), ++ id_grouper(id_grouper) {} ++ ++ /****************************************************** ++ * API for 1 result at a time (each SingleResultHandler is ++ * called from 1 thread) ++ */ ++ ++ struct SingleResultHandler : ResultHandler { ++ GroupedHeapBlockResultHandler& hr; ++ using ResultHandler::threshold; ++ size_t k; ++ ++ T* heap_dis; ++ TI* heap_ids; ++ TI* heap_group_ids; ++ std::unordered_map group_id_to_index_in_heap; ++ ++ explicit SingleResultHandler(GroupedHeapBlockResultHandler& hr) ++ : hr(hr), k(hr.k) {} ++ ++ /// begin results for query # i ++ void begin(size_t i) { ++ heap_dis = hr.heap_dis_tab + i * k; ++ heap_ids = hr.heap_ids_tab + i * k; ++ heap_heapify(k, 
heap_dis, heap_ids); ++ threshold = heap_dis[0]; ++ heap_group_ids = new TI[hr.k]; ++ for (size_t i = 0; i < hr.k; i++) { ++ heap_group_ids[i] = -1; ++ } + } ++ ++ /// add one result for query i ++ bool add_result(T dis, TI idx) final { ++ if (!C::cmp(threshold, dis)) { ++ return false; ++ } ++ ++ idx_t group_id = hr.id_grouper->get_group(idx); ++ typename std::unordered_map::const_iterator it_pos = ++ group_id_to_index_in_heap.find(group_id); ++ if (it_pos == group_id_to_index_in_heap.end()) { ++ group_heap_replace_top( ++ k, ++ heap_dis, ++ heap_ids, ++ heap_group_ids, ++ dis, ++ idx, ++ group_id, ++ &group_id_to_index_in_heap); ++ return true; ++ } else { ++ size_t pos = it_pos->second; ++ if (!C::cmp(heap_dis[pos], dis)) { ++ return false; ++ } ++ group_heap_replace_at( ++ pos, ++ k, ++ heap_dis, ++ heap_ids, ++ heap_group_ids, ++ dis, ++ idx, ++ group_id, ++ &group_id_to_index_in_heap); ++ return true; ++ } ++ } ++ ++ /// series of results for query i is done ++ void end() { ++ heap_reorder(k, heap_dis, heap_ids); ++ delete heap_group_ids; ++ } ++ }; ++ ++ /****************************************************** ++ * API for multiple results (called from 1 thread) ++ */ ++ ++ /// begin ++ void begin_multiple(size_t i0_2, size_t i1_2) final { ++ this->i0 = i0_2; ++ this->i1 = i1_2; ++ for (size_t i = i0; i < i1; i++) { ++ heap_heapify(k, heap_dis_tab + i * k, heap_ids_tab + i * k); ++ } ++ size_t size = (i1 - i0) * k; ++ heap_group_ids_tab = new TI[size]; ++ for (size_t i = 0; i < size; i++) { ++ heap_group_ids_tab[i] = -1; ++ } ++ group_id_to_index_in_heap_tab = ++ new std::unordered_map[i1 - i0]; + } + -+ // This method is called once all result is collected so that final post -+ // processing can be done For example, if the result is collected using -+ // group id, the group id can be converted back to its original id inside -+ // this method -+ void post_process(idx_t nres, idx_t* bh_ids) override { -+ // Do nothing ++ /// add results for query i0..i1 and 
j0..j1 ++ void add_results(size_t j0, size_t j1, const T* dis_tab) final { ++#pragma omp parallel for ++ for (int64_t i = i0; i < i1; i++) { ++ T* heap_dis = heap_dis_tab + i * k; ++ TI* heap_ids = heap_ids_tab + i * k; ++ const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0; ++ T thresh = heap_dis[0]; // NOLINT(*-use-default-none) ++ for (size_t j = j0; j < j1; j++) { ++ T dis = dis_tab_i[j]; ++ if (C::cmp(thresh, dis)) { ++ idx_t group_id = id_grouper->get_group(j); ++ typename std::unordered_map::const_iterator ++ it_pos = group_id_to_index_in_heap_tab[i - i0].find( ++ group_id); ++ if (it_pos == group_id_to_index_in_heap_tab[i - i0].end()) { ++ group_heap_replace_top( ++ k, ++ heap_dis, ++ heap_ids, ++ heap_group_ids_tab + ((i - i0) * k), ++ dis, ++ j, ++ group_id, ++ &group_id_to_index_in_heap_tab[i - i0]); ++ thresh = heap_dis[0]; ++ } else { ++ size_t pos = it_pos->first; ++ if (C::cmp(heap_dis[pos], dis)) { ++ group_heap_replace_at( ++ pos, ++ k, ++ heap_dis, ++ heap_ids, ++ heap_group_ids_tab + ((i - i0) * k), ++ dis, ++ j, ++ group_id, ++ &group_id_to_index_in_heap_tab[i - i0]); ++ thresh = heap_dis[0]; ++ } ++ } ++ } ++ } ++ } + } + -+ ~DefaultCollector() override {} ++ /// series of results for queries i0..i1 is done ++ void end_multiple() final { ++ // maybe parallel for ++ for (size_t i = i0; i < i1; i++) { ++ heap_reorder(k, heap_dis_tab + i * k, heap_ids_tab + i * k); ++ } ++ delete group_id_to_index_in_heap_tab; ++ delete heap_group_ids_tab; ++ } +}; + -+} // namespace faiss -diff --git a/faiss/impl/ResultCollectorFactory.h b/faiss/impl/ResultCollectorFactory.h + /***************************************************************** + * Reservoir result handler + * +diff --git a/faiss/utils/GroupHeap.h b/faiss/utils/GroupHeap.h new file mode 100644 -index 00000000..b460b20b +index 00000000..3b7078da --- /dev/null -+++ b/faiss/impl/ResultCollectorFactory.h -@@ -0,0 +1,33 @@ ++++ b/faiss/utils/GroupHeap.h +@@ -0,0 +1,182 @@ +/** + * Copyright (c) 
Facebook, Inc. and its affiliates. + * @@ -271,31 +558,493 @@ index 00000000..b460b20b + */ + +#pragma once -+#include ++ ++#include ++#include ++#include ++ ++#include ++#include ++#include ++ ++#include ++#include ++ ++#include ++#include ++ +namespace faiss { + -+/** ResultCollectorFactory to create a ResultCollector object */ -+struct ResultCollectorFactory { -+ DefaultCollector default_collector; -+ const std::vector* id_map = nullptr; ++/** ++ * From start_index, it compare its value with parent node's and swap if needed. ++ * Continue until either there is no swap or it reaches the top node. ++ */ ++template ++static inline void group_up_heap( ++ typename C::T* heap_dis, ++ typename C::TI* heap_ids, ++ typename C::TI* heap_group_ids, ++ std::unordered_map* group_id_to_index_in_heap, ++ size_t start_index) { ++ heap_dis--; /* Use 1-based indexing for easier node->child translation */ ++ heap_ids--; ++ heap_group_ids--; ++ size_t i = start_index + 1, i_father; ++ typename C::T target_dis = heap_dis[i]; ++ typename C::TI target_id = heap_ids[i]; ++ typename C::TI target_group_id = heap_group_ids[i]; + -+ // Create a new ResultCollector object -+ virtual ResultCollector* new_collector() { -+ return &default_collector; ++ while (i > 1) { ++ i_father = i >> 1; ++ if (!C::cmp2( ++ target_dis, ++ heap_dis[i_father], ++ target_id, ++ heap_ids[i_father])) { ++ /* the heap structure is ok */ ++ break; ++ } ++ heap_dis[i] = heap_dis[i_father]; ++ heap_ids[i] = heap_ids[i_father]; ++ heap_group_ids[i] = heap_group_ids[i_father]; ++ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; ++ i = i_father; + } ++ heap_dis[i] = target_dis; ++ heap_ids[i] = target_id; ++ heap_group_ids[i] = target_group_id; ++ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; ++} + -+ // For default case, the factory share single object and no need to delete -+ // the object. For other case, the factory can create a new object which -+ // need to be deleted later. 
We have deleteCollector method to handle both -+ // case as factory class knows how to release resource that it created -+ virtual void delete_collector(ResultCollector* collector) { -+ // Do nothing ++/** ++ * From start_index, it compare its value with child node's and swap if needed. ++ * Continue until either there is no swap or it reaches the leaf node. ++ */ ++template ++static inline void group_down_heap( ++ size_t k, ++ typename C::T* heap_dis, ++ typename C::TI* heap_ids, ++ typename C::TI* heap_group_ids, ++ std::unordered_map* group_id_to_index_in_heap, ++ size_t start_index) { ++ heap_dis--; /* Use 1-based indexing for easier node->child translation */ ++ heap_ids--; ++ heap_group_ids--; ++ size_t i = start_index + 1, i1, i2; ++ typename C::T target_dis = heap_dis[i]; ++ typename C::TI target_id = heap_ids[i]; ++ typename C::TI target_group_id = heap_group_ids[i]; ++ ++ while (1) { ++ i1 = i << 1; ++ i2 = i1 + 1; ++ if (i1 > k) { ++ break; ++ } ++ ++ // Note that C::cmp2() is a bool function answering ++ // `(a1 > b1) || ((a1 == b1) && (a2 > b2))` for max ++ // heap and same with the `<` sign for min heap. 
++ if ((i2 == k + 1) || ++ C::cmp2(heap_dis[i1], heap_dis[i2], heap_ids[i1], heap_ids[i2])) { ++ if (C::cmp2(target_dis, heap_dis[i1], target_id, heap_ids[i1])) { ++ break; ++ } ++ heap_dis[i] = heap_dis[i1]; ++ heap_ids[i] = heap_ids[i1]; ++ heap_group_ids[i] = heap_group_ids[i1]; ++ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; ++ i = i1; ++ } else { ++ if (C::cmp2(target_dis, heap_dis[i2], target_id, heap_ids[i2])) { ++ break; ++ } ++ heap_dis[i] = heap_dis[i2]; ++ heap_ids[i] = heap_ids[i2]; ++ heap_group_ids[i] = heap_group_ids[i2]; ++ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; ++ i = i2; ++ } + } ++ heap_dis[i] = target_dis; ++ heap_ids[i] = target_id; ++ heap_group_ids[i] = target_group_id; ++ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; ++} + -+ virtual ~ResultCollectorFactory() {} -+}; ++template ++static inline void group_heap_replace_top( ++ size_t k, ++ typename C::T* heap_dis, ++ typename C::TI* heap_ids, ++ typename C::TI* heap_group_ids, ++ typename C::T dis, ++ typename C::TI id, ++ typename C::TI group_id, ++ std::unordered_map* group_id_to_index_in_heap) { ++ assert(group_id_to_index_in_heap->find(group_id) == ++ group_id_to_index_in_heap->end() && ++ "group id should not exist in the binary heap"); ++ ++ group_id_to_index_in_heap->erase(heap_group_ids[0]); ++ heap_group_ids[0] = group_id; ++ heap_dis[0] = dis; ++ heap_ids[0] = id; ++ (*group_id_to_index_in_heap)[group_id] = 0; ++ group_down_heap( ++ k, ++ heap_dis, ++ heap_ids, ++ heap_group_ids, ++ group_id_to_index_in_heap, ++ 0); ++} ++ ++template ++static inline void group_heap_replace_at( ++ size_t pos, ++ size_t k, ++ typename C::T* heap_dis, ++ typename C::TI* heap_ids, ++ typename C::TI* heap_group_ids, ++ typename C::T dis, ++ typename C::TI id, ++ typename C::TI group_id, ++ std::unordered_map* group_id_to_index_in_heap) { ++ assert(group_id_to_index_in_heap->find(group_id) != ++ group_id_to_index_in_heap->end() && ++ "group id should exist in the 
binary heap"); ++ assert(group_id_to_index_in_heap->find(group_id)->second == pos && ++ "index of group id in the heap should be same as pos"); ++ ++ heap_dis[pos] = dis; ++ heap_ids[pos] = id; ++ group_up_heap( ++ heap_dis, heap_ids, heap_group_ids, group_id_to_index_in_heap, pos); ++ group_down_heap( ++ k, ++ heap_dis, ++ heap_ids, ++ heap_group_ids, ++ group_id_to_index_in_heap, ++ pos); ++} + +} // namespace faiss +\ No newline at end of file +diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt +index cc0a4f4c..96e19328 100644 +--- a/tests/CMakeLists.txt ++++ b/tests/CMakeLists.txt +@@ -26,6 +26,8 @@ set(FAISS_TEST_SRC + test_approx_topk.cpp + test_RCQ_cropping.cpp + test_distances_simd.cpp ++ test_id_grouper.cpp ++ test_group_heap.cpp + test_heap.cpp + test_code_distance.cpp + test_hnsw.cpp +diff --git a/tests/test_group_heap.cpp b/tests/test_group_heap.cpp +new file mode 100644 +index 00000000..0e8fe7a7 +--- /dev/null ++++ b/tests/test_group_heap.cpp +@@ -0,0 +1,98 @@ ++/** ++ * Copyright (c) Facebook, Inc. and its affiliates. ++ * ++ * This source code is licensed under the MIT license found in the ++ * LICENSE file in the root directory of this source tree. 
++ */ ++#include ++#include ++#include ++#include ++ ++using namespace faiss; ++ ++TEST(GroupHeap, group_heap_replace_top) { ++ using C = CMax; ++ const int k = 100; ++ float binary_heap_values[k]; ++ int64_t binary_heap_ids[k]; ++ heap_heapify(k, binary_heap_values, binary_heap_ids); ++ int64_t binary_heap_group_ids[k]; ++ for (size_t i = 0; i < k; i++) { ++ binary_heap_group_ids[i] = -1; ++ } ++ std::unordered_map group_id_to_index_in_heap; ++ for (int i = 1000; i > 0; i--) { ++ group_heap_replace_top( ++ k, ++ binary_heap_values, ++ binary_heap_ids, ++ binary_heap_group_ids, ++ i * 10.0, ++ i, ++ i, ++ &group_id_to_index_in_heap); ++ } ++ ++ heap_reorder(k, binary_heap_values, binary_heap_ids); ++ ++ for (int i = 0; i < k; i++) { ++ ASSERT_EQ((i + 1) * 10.0, binary_heap_values[i]); ++ ASSERT_EQ(i + 1, binary_heap_ids[i]); ++ } ++} ++ ++TEST(GroupHeap, group_heap_replace_at) { ++ using C = CMax; ++ const int k = 10; ++ float binary_heap_values[k]; ++ int64_t binary_heap_ids[k]; ++ heap_heapify(k, binary_heap_values, binary_heap_ids); ++ int64_t binary_heap_group_ids[k]; ++ for (size_t i = 0; i < k; i++) { ++ binary_heap_group_ids[i] = -1; ++ } ++ std::unordered_map group_id_to_index_in_heap; ++ ++ std::unordered_map group_id_to_id; ++ for (int i = 1000; i > 0; i--) { ++ int64_t group_id = rand() % 100; ++ group_id_to_id[group_id] = i; ++ if (group_id_to_index_in_heap.find(group_id) == ++ group_id_to_index_in_heap.end()) { ++ group_heap_replace_top( ++ k, ++ binary_heap_values, ++ binary_heap_ids, ++ binary_heap_group_ids, ++ i * 10.0, ++ i, ++ group_id, ++ &group_id_to_index_in_heap); ++ } else { ++ group_heap_replace_at( ++ group_id_to_index_in_heap.at(group_id), ++ k, ++ binary_heap_values, ++ binary_heap_ids, ++ binary_heap_group_ids, ++ i * 10.0, ++ i, ++ group_id, ++ &group_id_to_index_in_heap); ++ } ++ } ++ ++ heap_reorder(k, binary_heap_values, binary_heap_ids); ++ ++ std::vector sorted_ids; ++ for (const auto& pair : group_id_to_id) { ++ 
sorted_ids.push_back(pair.second); ++ } ++ std::sort(sorted_ids.begin(), sorted_ids.end()); ++ ++ for (int i = 0; i < k && binary_heap_ids[i] != -1; i++) { ++ ASSERT_EQ(sorted_ids[i] * 10.0, binary_heap_values[i]); ++ ASSERT_EQ(sorted_ids[i], binary_heap_ids[i]); ++ } ++} +diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp +new file mode 100644 +index 00000000..2aed5500 +--- /dev/null ++++ b/tests/test_id_grouper.cpp +@@ -0,0 +1,189 @@ ++/** ++ * Copyright (c) Facebook, Inc. and its affiliates. ++ * ++ * This source code is licensed under the MIT license found in the ++ * LICENSE file in the root directory of this source tree. ++ */ ++#include ++#include ++#include ++#include ++ ++#include ++#include ++#include ++#include ++#include ++ ++// 64-bit int ++using idx_t = faiss::idx_t; ++ ++using namespace faiss; ++ ++TEST(IdGrouper, get_group) { ++ uint64_t ids1[1] = {0b1000100010001000}; ++ IDGrouperBitmap bitmap(1, ids1); ++ ++ ASSERT_EQ(3, bitmap.get_group(0)); ++ ASSERT_EQ(3, bitmap.get_group(1)); ++ ASSERT_EQ(3, bitmap.get_group(2)); ++ ASSERT_EQ(3, bitmap.get_group(3)); ++ ASSERT_EQ(7, bitmap.get_group(4)); ++ ASSERT_EQ(7, bitmap.get_group(5)); ++ ASSERT_EQ(7, bitmap.get_group(6)); ++ ASSERT_EQ(7, bitmap.get_group(7)); ++ ASSERT_EQ(11, bitmap.get_group(8)); ++ ASSERT_EQ(11, bitmap.get_group(9)); ++ ASSERT_EQ(11, bitmap.get_group(10)); ++ ASSERT_EQ(11, bitmap.get_group(11)); ++ ASSERT_EQ(15, bitmap.get_group(12)); ++ ASSERT_EQ(15, bitmap.get_group(13)); ++ ASSERT_EQ(15, bitmap.get_group(14)); ++ ASSERT_EQ(15, bitmap.get_group(15)); ++ ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(16)); ++} ++ ++TEST(IdGrouper, set_group) { ++ idx_t group_ids[] = {64, 127, 128, 1022}; ++ uint64_t ids[16] = {}; // 1023 / 64 + 1 ++ IDGrouperBitmap bitmap(16, ids); ++ ++ for (int i = 0; i < 4; i++) { ++ bitmap.set_group(group_ids[i]); ++ } ++ ++ int group_id_index = 0; ++ for (int i = 0; i <= group_ids[3]; i++) { ++ ASSERT_EQ(group_ids[group_id_index], 
bitmap.get_group(i)); ++ if (group_ids[group_id_index] == i) { ++ group_id_index++; ++ } ++ } ++ ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(group_ids[3] + 1)); ++} ++ ++TEST(IdGrouper, bitmap_with_hnsw) { ++ int d = 1; // dimension ++ int nb = 10; // database size ++ ++ std::mt19937 rng; ++ std::uniform_real_distribution<> distrib; ++ ++ float* xb = new float[d * nb]; ++ ++ for (int i = 0; i < nb; i++) { ++ for (int j = 0; j < d; j++) ++ xb[d * i + j] = distrib(rng); ++ xb[d * i] += i / 1000.; ++ } ++ ++ uint64_t bitmap[1] = {}; ++ faiss::IDGrouperBitmap id_grouper(1, bitmap); ++ for (int i = 0; i < nb; i++) { ++ if (i % 2 == 1) { ++ id_grouper.set_group(i); ++ } ++ } ++ ++ int k = 10; ++ int m = 8; ++ faiss::Index* index = ++ new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2); ++ index->add(nb, xb); // add vectors to the index ++ ++ // search ++ idx_t* I = new idx_t[k]; ++ float* D = new float[k]; ++ ++ auto pSearchParameters = new faiss::SearchParametersHNSW(); ++ pSearchParameters->grp = &id_grouper; ++ ++ index->search(1, xb, k, D, I, pSearchParameters); ++ ++ std::unordered_set group_ids; ++ ASSERT_EQ(0, I[0]); ++ ASSERT_EQ(0, D[0]); ++ group_ids.insert(id_grouper.get_group(I[0])); ++ for (int j = 1; j < 5; j++) { ++ ASSERT_NE(-1, I[j]); ++ ASSERT_NE(std::numeric_limits::max(), D[j]); ++ group_ids.insert(id_grouper.get_group(I[j])); ++ } ++ for (int j = 5; j < k; j++) { ++ ASSERT_EQ(-1, I[j]); ++ ASSERT_EQ(std::numeric_limits::max(), D[j]); ++ } ++ ASSERT_EQ(5, group_ids.size()); ++ ++ delete[] I; ++ delete[] D; ++ delete[] xb; ++} ++ ++TEST(IdGrouper, bitmap_with_hnswn_idmap) { ++ int d = 1; // dimension ++ int nb = 10; // database size ++ ++ std::mt19937 rng; ++ std::uniform_real_distribution<> distrib; ++ ++ float* xb = new float[d * nb]; ++ idx_t* xids = new idx_t[d * nb]; ++ ++ for (int i = 0; i < nb; i++) { ++ for (int j = 0; j < d; j++) ++ xb[d * i + j] = distrib(rng); ++ xb[d * i] += i / 1000.; ++ } ++ ++ uint64_t bitmap[1] = {}; ++ 
faiss::IDGrouperBitmap id_grouper(1, bitmap); ++ int num_grp = 0; ++ int grp_size = 2; ++ int id_in_grp = 0; ++ for (int i = 0; i < nb; i++) { ++ xids[i] = i + num_grp; ++ id_in_grp++; ++ if (id_in_grp == grp_size) { ++ id_grouper.set_group(i + num_grp + 1); ++ num_grp++; ++ id_in_grp = 0; ++ } ++ } ++ ++ int k = 10; ++ int m = 8; ++ faiss::Index* index = ++ new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2); ++ faiss::IndexIDMap id_map = ++ faiss::IndexIDMap(index); // add vectors to the index ++ id_map.add_with_ids(nb, xb, xids); ++ ++ // search ++ idx_t* I = new idx_t[k]; ++ float* D = new float[k]; ++ ++ auto pSearchParameters = new faiss::SearchParametersHNSW(); ++ pSearchParameters->grp = &id_grouper; ++ ++ id_map.search(1, xb, k, D, I, pSearchParameters); ++ ++ std::unordered_set group_ids; ++ ASSERT_EQ(0, I[0]); ++ ASSERT_EQ(0, D[0]); ++ group_ids.insert(id_grouper.get_group(I[0])); ++ for (int j = 1; j < 5; j++) { ++ ASSERT_NE(-1, I[j]); ++ ASSERT_NE(std::numeric_limits::max(), D[j]); ++ group_ids.insert(id_grouper.get_group(I[j])); ++ } ++ for (int j = 5; j < k; j++) { ++ ASSERT_EQ(-1, I[j]); ++ ASSERT_EQ(std::numeric_limits::max(), D[j]); ++ } ++ ASSERT_EQ(5, group_ids.size()); ++ ++ delete[] I; ++ delete[] D; ++ delete[] xb; ++} -- 2.39.3 (Apple Git-145) diff --git a/jni/src/faiss_util.cpp b/jni/src/faiss_util.cpp new file mode 100644 index 000000000..c2abe7f26 --- /dev/null +++ b/jni/src/faiss_util.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// The OpenSearch Contributors require contributions made to +// this file be licensed under the Apache-2.0 license or a +// compatible open source license. +// +// Modifications Copyright OpenSearch Contributors. See +// GitHub history for details. 
+ +#include "faiss_util.h" +#include + +std::unique_ptr faiss_util::buildIDGrouperBitmap(int *parentIdsArray, int parentIdsLength, std::vector* bitmap) { + const int* maxValue = std::max_element(parentIdsArray, parentIdsArray + parentIdsLength); + int num_bits = *maxValue + 1; + int num_blocks = (num_bits >> 6) + 1; // div by 64 + bitmap->resize(num_blocks, 0); + std::unique_ptr idGrouper(new faiss::IDGrouperBitmap(num_blocks, bitmap->data())); + for (int i = 0; i < parentIdsLength; i++) { + idGrouper->set_group(parentIdsArray[i]); + } + return idGrouper; +} diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 4609f3144..e88254b86 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -11,7 +11,7 @@ #include "jni_util.h" #include "faiss_wrapper.h" -#include "knn_extension/faiss/MultiVectorResultCollectorFactory.h" +#include "faiss_util.h" #include "faiss/impl/io.h" #include "faiss/index_factory.h" @@ -51,9 +51,7 @@ void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, fa // Concerts the FilterIds to BitMap void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector); -os_faiss::MultiVectorResultCollectorFactory* buildResultCollectorFactory(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ); - -void releaseResultCollectorFactory(os_faiss::MultiVectorResultCollectorFactory* collectorFactory); +std::unique_ptr buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector* bitmap); void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) { @@ -253,13 +251,18 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter faiss::SearchParameters *searchParameters; faiss::SearchParametersHNSW hnswParams; faiss::SearchParametersIVF ivfParams; + std::unique_ptr idGrouper; + std::vector 
idGrouperBitmap; auto hnswReader = dynamic_cast(indexReader->index); if(hnswReader) { // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default // value of ef_search = 16 which will then be used. hnswParams.efSearch = hnswReader->hnsw.efSearch; hnswParams.sel = idSelector.get(); - hnswParams.col = buildResultCollectorFactory(jniUtil, env, parentIdsJ); + if (parentIdsJ != nullptr) { + idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); + hnswParams.grp = idGrouper.get(); + } searchParameters = &hnswParams; } else { auto ivfReader = dynamic_cast(indexReader->index); @@ -274,30 +277,29 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter } catch (...) { jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); - releaseResultCollectorFactory(dynamic_cast(hnswParams.col)); throw; } jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); - releaseResultCollectorFactory(dynamic_cast(hnswParams.col)); } else { faiss::SearchParameters *searchParameters = nullptr; faiss::SearchParametersHNSW hnswParams; + std::unique_ptr idGrouper; + std::vector idGrouperBitmap; auto hnswReader = dynamic_cast(indexReader->index); if(hnswReader!= nullptr && parentIdsJ != nullptr) { // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default // value of ef_search = 16 which will then be used. hnswParams.efSearch = hnswReader->hnsw.efSearch; - hnswParams.col = buildResultCollectorFactory(jniUtil, env, parentIdsJ); + idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); + hnswParams.grp = idGrouper.get(); searchParameters = &hnswParams; } try { indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters); } catch (...) 
{ jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); - releaseResultCollectorFactory(dynamic_cast(hnswParams.col)); throw; } - releaseResultCollectorFactory(dynamic_cast(hnswParams.col)); } jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); @@ -509,21 +511,10 @@ void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bi } } -os_faiss::MultiVectorResultCollectorFactory* buildResultCollectorFactory(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ) { - if (parentIdsJ == nullptr) { - return nullptr; - } +std::unique_ptr buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector* bitmap) { int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr); int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ); - auto* parent_id_filter = new FixedBitSet(parentIdsArray, parentIdsLength); + std::unique_ptr idGrouper = faiss_util::buildIDGrouperBitmap(parentIdsArray, parentIdsLength, bitmap); jniUtil->ReleaseIntArrayElements(env, parentIdsJ, parentIdsArray, JNI_ABORT); - return new os_faiss::MultiVectorResultCollectorFactory(parent_id_filter); -} - -void releaseResultCollectorFactory(os_faiss::MultiVectorResultCollectorFactory* collectorFactory) { - if (collectorFactory == nullptr) { - return; - } - delete collectorFactory->parent_bit_set; - delete collectorFactory; + return idGrouper; } diff --git a/jni/src/knn_extension/faiss/MultiVectorResultCollector.cpp b/jni/src/knn_extension/faiss/MultiVectorResultCollector.cpp deleted file mode 100644 index a7564d3aa..000000000 --- a/jni/src/knn_extension/faiss/MultiVectorResultCollector.cpp +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -#include "MultiVectorResultCollector.h" -#include "knn_extension/faiss/utils/Heap.h" -#include "knn_extension/faiss/utils/BitSet.h" - -namespace 
os_faiss { - -using idx_t = faiss::idx_t; - -MultiVectorResultCollector::MultiVectorResultCollector(const BitSet* parent_bit_set, const std::vector* id_map) -: parent_bit_set(parent_bit_set), id_map(id_map) {} - -void MultiVectorResultCollector::collect( - int k, - int& nres, - float* bh_val, - int64_t* bh_ids, - float val, - int64_t ids) { - idx_t group_id = id_map ? parent_bit_set->next_set_bit(id_map->at(ids)) : parent_bit_set->next_set_bit(ids); - if (parent_id_to_index.find(group_id) == - parent_id_to_index.end()) { - if (nres < k) { - maxheap_push( - nres++, - bh_val, - bh_ids, - val, - ids, - &parent_id_to_id, - &parent_id_to_index, - group_id); - } else if (val < bh_val[0]) { - maxheap_replace_top( - nres, - bh_val, - bh_ids, - val, - ids, - &parent_id_to_id, - &parent_id_to_index, - group_id); - } - } else if (val < bh_val[parent_id_to_index.at(group_id)]) { - maxheap_update( - nres, - bh_val, - bh_ids, - val, - ids, - &parent_id_to_id, - &parent_id_to_index, - group_id); - } -} - -void MultiVectorResultCollector::post_process(int64_t nres, int64_t* bh_ids) { - for (size_t icnt = 0; icnt < nres; icnt++) { - bh_ids[icnt] = parent_id_to_id.at(bh_ids[icnt]); - } -} - -} // namespace os_faiss diff --git a/jni/src/knn_extension/faiss/MultiVectorResultCollectorFactory.cpp b/jni/src/knn_extension/faiss/MultiVectorResultCollectorFactory.cpp deleted file mode 100644 index f4c7c0656..000000000 --- a/jni/src/knn_extension/faiss/MultiVectorResultCollectorFactory.cpp +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -#include "MultiVectorResultCollectorFactory.h" -#include "MultiVectorResultCollector.h" - -namespace os_faiss { - -MultiVectorResultCollectorFactory::MultiVectorResultCollectorFactory(BitSet* parent_bit_set) - : parent_bit_set(parent_bit_set) {} - -// id_map is set in IndexIDMap.cpp of faiss library with custom patch -// 
https://github.com/opensearch-project/k-NN/blob/feature/multi-vector/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch#L109 -faiss::ResultCollector* MultiVectorResultCollectorFactory::new_collector() { - return new MultiVectorResultCollector(parent_bit_set, id_map); -} - -void MultiVectorResultCollectorFactory::delete_collector(faiss::ResultCollector* resultCollector) { - delete resultCollector; -} - -} // namespace os_faiss diff --git a/jni/src/knn_extension/faiss/utils/BitSet.cpp b/jni/src/knn_extension/faiss/utils/BitSet.cpp deleted file mode 100644 index 33e9470e0..000000000 --- a/jni/src/knn_extension/faiss/utils/BitSet.cpp +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -#include -#include -#include "BitSet.h" - -FixedBitSet::FixedBitSet(const int* int_array, const int length){ - assert(int_array && "int_array should not be null"); - const int* maxValue = std::max_element(int_array, int_array + length); - this->num_bits = *maxValue + 1; - this->num_words = (num_bits >> 6) + 1; // div by 64 - this->words = new uint64_t[this->num_words](); - for(int i = 0 ; i < length ; i ++) { - int value = int_array[i]; - int bitset_array_index = value >> 6; - this->words[bitset_array_index] |= 1ULL << (value & 63); // Equivalent of 1ULL << (value % 64) - } -} - -idx_t FixedBitSet::next_set_bit(idx_t index) const { - assert(index >= 0 && "index shouldn't be less than zero"); - assert(index < this->num_bits && "index should be less than total number of bits"); - - idx_t i = index >> 6; // div by 64 - uint64_t word = this->words[i] >> (index & 63); // Equivalent of words[i] >> (index % 64) - // word is non zero after right shift, it means, next set bit is in current word - // The index of set bit is "given index" + "trailing zero in the right shifted word" - if (word != 0) { - return index + __builtin_ctzll(word); - } - - while (++i < this->num_words) { - word = this->words[i]; - if (word != 
0) { - return (i << 6) + __builtin_ctzll(word); - } - } - - return NO_MORE_DOCS; -} - -FixedBitSet::~FixedBitSet() { - delete this->words; -} diff --git a/jni/tests/faiss_util_test.cpp b/jni/tests/faiss_util_test.cpp new file mode 100644 index 000000000..d8b45d951 --- /dev/null +++ b/jni/tests/faiss_util_test.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// The OpenSearch Contributors require contributions made to +// this file be licensed under the Apache-2.0 license or a +// compatible open source license. +// +// Modifications Copyright OpenSearch Contributors. See +// GitHub history for details. + +#include "faiss_util.h" + +#include + +#include "gtest/gtest.h" + +TEST(IDGrouperBitMapTest, BasicAssertions) { + int ids[] = {128, 1024}; + size_t length = sizeof(ids) / sizeof(ids[0]); + std::vector bitmap; + std::unique_ptr idGrouperBitmap = faiss_util::buildIDGrouperBitmap(ids, length, &bitmap); + int groupIndex = 0; + for (int i = 0; i <= ids[length - 1]; i++) { + if (i > ids[groupIndex]) { + groupIndex++; + } + ASSERT_EQ(ids[groupIndex], idGrouperBitmap->get_group(i)); + } +} diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 5afe09c22..58daaee2b 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -205,7 +205,7 @@ TEST(FaissQueryIndexTest, BasicAssertions) { // Create the index std::unique_ptr createdIndex( - test_util::FaissCreateIndex(2, method, metricType)); + test_util::FaissCreateIndex(dim, method, metricType)); auto createdIndexWithData = test_util::FaissAddData(createdIndex.get(), ids, vectors); @@ -262,7 +262,7 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) { // Create the index std::unique_ptr createdIndex( - test_util::FaissCreateIndex(2, method, metricType)); + test_util::FaissCreateIndex(dim, method, metricType)); auto createdIndexWithData = test_util::FaissAddData(createdIndex.get(), ids, vectors); @@ -335,7 +335,7 @@ 
TEST(FaissQueryIndexWithParentFilterTest, BasicAssertions) { // Create the index std::unique_ptr createdIndex( - test_util::FaissCreateIndex(2, method, metricType)); + test_util::FaissCreateIndex(dim, method, metricType)); auto createdIndexWithData = test_util::FaissAddData(createdIndex.get(), ids, vectors); @@ -474,7 +474,7 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { // Assert that Index is of type IndexHNSWSQ ASSERT_NE(indexIDMap, nullptr); ASSERT_NE(dynamic_cast(indexIDMap->index), nullptr); - + // Clean up std::remove(indexPath.c_str()); } diff --git a/jni/tests/knn_extension/faiss/MultiVectorResultCollectorFactoryTest.cpp b/jni/tests/knn_extension/faiss/MultiVectorResultCollectorFactoryTest.cpp deleted file mode 100644 index 3177360bb..000000000 --- a/jni/tests/knn_extension/faiss/MultiVectorResultCollectorFactoryTest.cpp +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -#include "knn_extension/faiss/MultiVectorResultCollectorFactory.h" -#include "knn_extension/faiss/MultiVectorResultCollector.h" - -#include - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "jni_util.h" - -using ::testing::NiceMock; -using ::testing::Return; -using idx_t = faiss::idx_t; - - -TEST(MultiVectorResultCollectorFactoryTest, BasicAssertions) { - int parent_ids[1] = {1}; - FixedBitSet parent_id_filter(parent_ids, 1); - - std::unordered_map distance1; - distance1[0] = 10; - distance1[1] = 11; - - std::unordered_map distance2; - distance2[0] = 11; - distance2[1] = 10; - - os_faiss::MultiVectorResultCollectorFactory* rc_factory = new os_faiss::MultiVectorResultCollectorFactory(&parent_id_filter); - faiss::ResultCollector* rc1 = rc_factory->new_collector(); - faiss::ResultCollector* rc2 = rc_factory->new_collector(); - ASSERT_NE(rc1, rc2); - - int k = 1; - int nres1 = 0; - int nres2 = 0; - float* bh_val = new float[k * 2]; - int64_t* bh_ids = new int64_t[k * 2]; - // Verify two collector 
are thread safe each other. - // Simulate multi thread by interleaving collect methods of two ResultCollectors. - for (int i = 0; i < distance1.size(); i++) { - rc1->collect(k, nres1, bh_val, bh_ids, distance1.at(i), i); - rc2->collect(k, nres2, bh_val + k, bh_ids + k, distance2.at(i), i); - } - rc1->post_process(nres1, bh_ids); - rc2->post_process(nres2, bh_ids + k); - - ASSERT_EQ(0, bh_ids[0]); - ASSERT_EQ(1, bh_ids[1]); - - rc_factory->delete_collector(rc1); - rc_factory->delete_collector(rc2); - delete rc_factory; - delete[] bh_val; - delete[] bh_ids; -} - -// Verify that id_map is passed to collector -TEST(MultiVectorResultCollectorFactoryWithIdMapTest, BasicAssertions) { - int parent_ids[1] = {1}; - FixedBitSet parent_id_filter(parent_ids, 1); - std::vector id_map; - - os_faiss::MultiVectorResultCollectorFactory* rc_factory = new os_faiss::MultiVectorResultCollectorFactory(&parent_id_filter); - os_faiss::MultiVectorResultCollector* rc1 = dynamic_cast(rc_factory->new_collector()); - ASSERT_EQ(nullptr, rc1->id_map); - - rc_factory->id_map = &id_map; - os_faiss::MultiVectorResultCollector* rc2 = dynamic_cast(rc_factory->new_collector()); - ASSERT_EQ(&id_map, rc2->id_map); - - rc_factory->delete_collector(rc1); - rc_factory->delete_collector(rc2); - delete rc_factory; -} diff --git a/jni/tests/knn_extension/faiss/MultiVectorResultCollectorTest.cpp b/jni/tests/knn_extension/faiss/MultiVectorResultCollectorTest.cpp deleted file mode 100644 index 934d708dd..000000000 --- a/jni/tests/knn_extension/faiss/MultiVectorResultCollectorTest.cpp +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -#include "knn_extension/faiss/MultiVectorResultCollector.h" - -#include - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -using ::testing::NiceMock; -using ::testing::Return; -using idx_t = faiss::idx_t; - - -TEST(MultiVectorResultCollectorTest, BasicAssertions) { - // Data - // Parent ID: 2, ID: 0, 
Distance: 10 - // Parent ID: 2, ID: 1, Distance: 11 - // Parent ID: 5, ID: 3, Distance: 12 - // Parent ID: 5, ID: 4, Distance: 13 - // After collector handing the data with k = 3, it should return data with id 0 and 2, one from each group. - // Parent bit set representation: 100100 - int parent_ids[2] = {2, 5}; - FixedBitSet parent_id_filter(parent_ids, 2); - - idx_t ids[] = {0, 1, 2, 3}; - float distances[] = {10, 11, 12, 13}; - - os_faiss::MultiVectorResultCollector* rc = new os_faiss::MultiVectorResultCollector(&parent_id_filter, nullptr); - int k = 3; - int nres = 0; - float* bh_val = new float[k]; - int64_t* bh_ids = new int64_t[k]; - for (int i = 0; i < 4; i++) { - rc->collect(k, nres, bh_val, bh_ids, distances[i], ids[i]); - } - - // Parent ID is stored before finalize - ASSERT_EQ(5, bh_ids[0]); - ASSERT_EQ(2, bh_ids[1]); - - rc->post_process(nres, bh_ids); - - // Parent ID is converted to ID after finalize - ASSERT_EQ(3, bh_ids[0]); - ASSERT_EQ(0, bh_ids[1]); - - delete rc; - delete[] bh_val; - delete[] bh_ids; -} - -TEST(MultiVectorResultCollectorWithIDMapTest, BasicAssertions) { - // Data - // Parent ID: 2, Lucene ID: 0, Faiss ID: 0, Distance: 10 - // Parent ID: 2, Lucene ID: 1, Faiss ID: 1, Distance: 11 - // Parent ID: 5, Lucene ID: 3, Faiss ID: 2, Distance: 12 - // Parent ID: 5, Lucene ID: 4, Faiss ID: 3, Distance: 13 - // After collector handing the data with k = 3, it should return data with id 0 and 2, one from each group. 
- - // Parent bit set representation with Lucene ID: 100100 - int parent_ids[2] = {2, 5}; - FixedBitSet parent_id_filter(parent_ids, 2); - - idx_t faiss_ids[] = {0, 1, 2, 3}; - float distances[] = {10, 11, 12, 13}; - - // Faiss IDs to Lucene ID mapping - std::vector id_map = {0, 1, 3, 4}; - - os_faiss::MultiVectorResultCollector* rc = new os_faiss::MultiVectorResultCollector(&parent_id_filter, &id_map); - int k = 3; - int nres = 0; - float* bh_val = new float[k]; - int64_t* bh_ids = new int64_t[k]; - for (int i = 0; i < 4; i++) { - rc->collect(k, nres, bh_val, bh_ids, distances[i], faiss_ids[i]); - } - - // Parent ID is stored before finalize - ASSERT_EQ(5, bh_ids[0]); - ASSERT_EQ(2, bh_ids[1]); - - rc->post_process(nres, bh_ids); - - // Parent ID is converted to Faiss ID after finalize - ASSERT_EQ(2, bh_ids[0]); - ASSERT_EQ(0, bh_ids[1]); - - delete rc; - delete[] bh_val; - delete[] bh_ids; -} diff --git a/jni/tests/knn_extension/faiss/utils/BitSetTest.cpp b/jni/tests/knn_extension/faiss/utils/BitSetTest.cpp deleted file mode 100644 index 96ad6b3c2..000000000 --- a/jni/tests/knn_extension/faiss/utils/BitSetTest.cpp +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -#include "knn_extension/faiss/utils/BitSet.h" - -#include - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -using ::testing::NiceMock; -using ::testing::Return; -using idx_t = faiss::idx_t; - -TEST(FixedBitSetTest, BasicAssertions) { - int ids1[4] = {3, 7, 11, 15}; - FixedBitSet single_block(ids1, 4); - - ASSERT_EQ(3, single_block.next_set_bit(0)); - ASSERT_EQ(3, single_block.next_set_bit(1)); - ASSERT_EQ(3, single_block.next_set_bit(2)); - ASSERT_EQ(3, single_block.next_set_bit(3)); - ASSERT_EQ(7, single_block.next_set_bit(4)); - ASSERT_EQ(7, single_block.next_set_bit(5)); - ASSERT_EQ(7, single_block.next_set_bit(6)); - ASSERT_EQ(7, single_block.next_set_bit(7)); - ASSERT_EQ(11, single_block.next_set_bit(8)); - ASSERT_EQ(11, 
single_block.next_set_bit(9)); - ASSERT_EQ(11, single_block.next_set_bit(10)); - ASSERT_EQ(11, single_block.next_set_bit(11)); - ASSERT_EQ(15, single_block.next_set_bit(12)); - ASSERT_EQ(15, single_block.next_set_bit(13)); - ASSERT_EQ(15, single_block.next_set_bit(14)); - ASSERT_EQ(15, single_block.next_set_bit(15)); - ASSERT_EQ(single_block.NO_MORE_DOCS, single_block.next_set_bit(16)); - - int ids2[5] = {64, 128, 127, 1024, 34565}; - int ids2_sorted[5]; - std::copy(ids2, ids2 + 5, ids2_sorted); - std::sort(ids2_sorted, ids2_sorted + 5); - FixedBitSet multi_blocks(ids2, 5); - int parent_index = 0; - for (int i = 0; i < ids2[4] + 1; i++) { - ASSERT_EQ(ids2_sorted[parent_index], multi_blocks.next_set_bit(i)); - if (ids2_sorted[parent_index] == i) { - parent_index++; - } - } - ASSERT_EQ(multi_blocks.NO_MORE_DOCS, multi_blocks.next_set_bit(ids2[4] + 1)); -} diff --git a/jni/tests/knn_extension/faiss/utils/HeapTest.cpp b/jni/tests/knn_extension/faiss/utils/HeapTest.cpp deleted file mode 100644 index 97a30babd..000000000 --- a/jni/tests/knn_extension/faiss/utils/HeapTest.cpp +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -#include "knn_extension/faiss/utils/Heap.h" -#include "faiss/utils/Heap.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -using ::testing::NiceMock; -using ::testing::Return; -using ::testing::ElementsAreArray; - -TEST(MaxHeapUpdateTest, BasicAssertions) { - const int k = 5; - int nres = 0; - float binary_heap_values[k]; - int64_t binary_heap_ids[k]; - float input_values[] = {1.1f, 2.1f, 3.1f, 4.1f, 5.1f}; - int64_t input_ids[] = {1, 2, 3, 4, 5}; - int64_t group_ids[] = {11, 22, 33, 44, 55}; - std::unordered_map group_id_to_id; - std::unordered_map group_id_to_index; - - // Push - for (int i = 0; i < k; i++) { - os_faiss::maxheap_push( - nres++, - binary_heap_values, - binary_heap_ids, - input_values[i], - input_ids[i], - &group_id_to_id, - &group_id_to_index, - 
group_ids[i]); - } - - // Verify heap data - // The top node in the max heap should be the one with max value(5.1f) - ASSERT_EQ(5.1f, binary_heap_values[0]); - ASSERT_EQ(55, binary_heap_ids[0]); - ASSERT_EQ(5, group_id_to_id.at(binary_heap_ids[0])); - - // Replace top - os_faiss::maxheap_replace_top( - nres, - binary_heap_values, - binary_heap_ids, - 0.1f, - 6, - &group_id_to_id, - &group_id_to_index, - 66); - - // Verify heap data - // Previous top value(5.1f) should have been removed and the next max value(4.1f) should be in the top node. - ASSERT_EQ(4.1f, binary_heap_values[0]); - ASSERT_EQ(44, binary_heap_ids[0]); - ASSERT_EQ(4, group_id_to_id.at(binary_heap_ids[0])); - - // Update - os_faiss::maxheap_update( - nres, - binary_heap_values, - binary_heap_ids, - 0.2f, - 7, - &group_id_to_id, - &group_id_to_index, - 33); - - // Verify heap data - // node id 3 with group id 33 should have been replaced by node id 7 with new value - ASSERT_EQ(7, group_id_to_id.at(33)); - - // Verify heap is in order - float expectedValues[] = {4.1f, 2.1f, 1.1f, 0.2f, 0.1f}; - int64_t expectedIds[] = {4, 2, 1, 7, 6}; - for (int i = 0; i < k; i++) { - ASSERT_EQ(expectedValues[i], binary_heap_values[0]); - ASSERT_EQ(expectedIds[i], group_id_to_id.at(binary_heap_ids[0])); - faiss::maxheap_pop(nres--, binary_heap_values, binary_heap_ids); - } -}