From df6d1fab8951501136edcd72517721c8881281d3 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Wed, 3 Jan 2024 13:24:43 -0800 Subject: [PATCH] Add support of multi vector in jni (#1364) Signed-off-by: Heemin Kim --- jni/CMakeLists.txt | 23 +- jni/include/faiss_wrapper.h | 5 +- .../faiss/MultiVectorResultCollector.h | 69 +++++ .../faiss/MultiVectorResultCollectorFactory.h | 26 ++ .../knn_extension/faiss/utils/BitSet.h | 51 ++++ jni/include/knn_extension/faiss/utils/Heap.h | 253 ++++++++++++++++++ .../org_opensearch_knn_jni_FaissService.h | 4 +- jni/src/faiss_wrapper.cpp | 48 +++- .../faiss/MultiVectorResultCollector.cpp | 67 +++++ .../MultiVectorResultCollectorFactory.cpp | 24 ++ jni/src/knn_extension/faiss/utils/BitSet.cpp | 42 +++ .../org_opensearch_knn_jni_FaissService.cpp | 8 +- jni/tests/faiss_wrapper_test.cpp | 85 +++++- .../MultiVectorResultCollectorFactoryTest.cpp | 78 ++++++ .../faiss/MultiVectorResultCollectorTest.cpp | 96 +++++++ .../knn_extension/faiss/utils/BitSetTest.cpp | 52 ++++ .../knn_extension/faiss/utils/HeapTest.cpp | 86 ++++++ .../org/opensearch/knn/jni/FaissService.java | 29 +- .../org/opensearch/knn/jni/JNIService.java | 4 +- 19 files changed, 1022 insertions(+), 28 deletions(-) create mode 100644 jni/include/knn_extension/faiss/MultiVectorResultCollector.h create mode 100644 jni/include/knn_extension/faiss/MultiVectorResultCollectorFactory.h create mode 100644 jni/include/knn_extension/faiss/utils/BitSet.h create mode 100644 jni/include/knn_extension/faiss/utils/Heap.h create mode 100644 jni/src/knn_extension/faiss/MultiVectorResultCollector.cpp create mode 100644 jni/src/knn_extension/faiss/MultiVectorResultCollectorFactory.cpp create mode 100644 jni/src/knn_extension/faiss/utils/BitSet.cpp create mode 100644 jni/tests/knn_extension/faiss/MultiVectorResultCollectorFactoryTest.cpp create mode 100644 jni/tests/knn_extension/faiss/MultiVectorResultCollectorTest.cpp create mode 100644 jni/tests/knn_extension/faiss/utils/BitSetTest.cpp create mode 100644 jni/tests/knn_extension/faiss/utils/HeapTest.cpp diff --git a/jni/CMakeLists.txt b/jni/CMakeLists.txt index 6e66e17ac7..04dca217c0 100644 --- a/jni/CMakeLists.txt +++ b/jni/CMakeLists.txt @@ -158,9 +158,21 @@ if (${CONFIG_FAISS} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} S set(FAISS_ENABLE_PYTHON OFF) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/external/faiss EXCLUDE_FROM_ALL) - add_library(${TARGET_LIB_FAISS} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp) + add_library( + ${TARGET_LIB_FAISS} SHARED + ${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/utils/BitSet.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/MultiVectorResultCollector.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/MultiVectorResultCollectorFactory.cpp) target_link_libraries(${TARGET_LIB_FAISS} faiss ${TARGET_LIB_COMMON} OpenMP::OpenMP_CXX) - target_include_directories(${TARGET_LIB_FAISS} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include $ENV{JAVA_HOME}/include $ENV{JAVA_HOME}/include/${JVM_OS_TYPE} ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss) + target_include_directories(${TARGET_LIB_FAISS} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/include/knn_extension/faiss + ${CMAKE_CURRENT_SOURCE_DIR}/include/knn_extension/faiss/utils + $ENV{JAVA_HOME}/include + $ENV{JAVA_HOME}/include/${JVM_OS_TYPE} + ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss) set_target_properties(${TARGET_LIB_FAISS} PROPERTIES SUFFIX ${LIB_EXT}) set_target_properties(${TARGET_LIB_FAISS} PROPERTIES POSITION_INDEPENDENT_CODE ON) @@ -198,7 +210,12 @@ if ("${WIN32}" STREQUAL "") jni_test tests/faiss_wrapper_test.cpp tests/nmslib_wrapper_test.cpp - tests/test_util.cpp) + tests/test_util.cpp + tests/knn_extension/faiss/utils/BitSetTest.cpp + tests/knn_extension/faiss/utils/HeapTest.cpp + tests/knn_extension/faiss/MultiVectorResultCollectorTest.cpp + tests/knn_extension/faiss/MultiVectorResultCollectorFactoryTest.cpp + ) target_link_libraries( jni_test diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 284214631f..0785260001 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -13,7 +13,6 @@ #define OPENSEARCH_KNN_FAISS_WRAPPER_H #include "jni_util.h" - #include namespace knn_jni { @@ -38,13 +37,13 @@ namespace knn_jni { // // Return an array of KNNQueryResults jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ); + jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ); // Execute a query against the index located in memory at indexPointerJ along with Filters // // Return an array of KNNQueryResults jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ); + jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ, jintArray parentIdsJ); // Free the index located in memory at indexPointerJ void Free(jlong indexPointer); diff --git a/jni/include/knn_extension/faiss/MultiVectorResultCollector.h b/jni/include/knn_extension/faiss/MultiVectorResultCollector.h new file mode 100644 index 0000000000..a11a278d9e --- /dev/null +++ b/jni/include/knn_extension/faiss/MultiVectorResultCollector.h @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include "knn_extension/faiss/utils/BitSet.h" +#include + +namespace os_faiss { + +using idx_t = faiss::idx_t; +/** + * Implementation of ResultCollector to support multi vector + * + * Only supports HNSW algorithm + * + * Example: + * When there is two lucene document with two nested fields, the parent_bit_set value of 100100 is provided where + * parent doc ids are 2, and 5. Doc id for nested fields of parent document 2 are 0, and 1. Doc id for nested fields + * of parent document 5 are 3, and 4. For faiss, only nested fields are stored. Therefore corresponding doc ids for + * nested fields 0, 1, 3, 4 is 0, 1, 2, 3 in faiss. This mapping data is stored in id_map parameter. + * + * When collect method is called + * 1. It switches from faiss id to lucene id and look for its parent id. + * 2. See if the parent id already exist in heap using either parent_id_to_id or parent_id_to_index. + * 3. If it does not exist, add the parent id and distance value in the heap(bh_ids, bh_val) and update parent_id_to_id, and parent_id_to_index. + * 4. If it does exist, update the distance value(bh_val), parent_id_to_id, and parent_id_to_index. + * + * When post_process method is called + * 1. Convert lucene parent ID to faiss doc ID using parent_id_to_id + */ +struct MultiVectorResultCollector:faiss::ResultCollector { + // BitSet of lucene parent doc ID + const BitSet* parent_bit_set; + + // Mapping data from Faiss doc ID to Lucene doc ID + const std::vector* id_map; + + // Lucene parent doc ID to to Faiss doc ID + // Lucene parent doc ID to index in heap(bh_val, bh_ids) + std::unordered_map parent_id_to_id; + std::unordered_map parent_id_to_index; + MultiVectorResultCollector(const BitSet* parent_bit_set, const std::vector* id_map); + + /** + * + * @param k max size of bh_val, and bh_ids + * @param nres number of results in bh_val, and bh_ids + * @param bh_val binary heap storing values (For this case distance from query to result) + * @param bh_ids binary heap storing document IDs + * @param val a new value to add in bh_val + * @param ids a new doc id to add in bh_ids + */ + void collect( + int k, + int& nres, + float* bh_val, + int64_t* bh_ids, + float val, + int64_t ids) override; + void post_process(int64_t nres, int64_t* bh_ids) override; +}; + +} // namespace os_faiss + diff --git a/jni/include/knn_extension/faiss/MultiVectorResultCollectorFactory.h b/jni/include/knn_extension/faiss/MultiVectorResultCollectorFactory.h new file mode 100644 index 0000000000..45c0338b33 --- /dev/null +++ b/jni/include/knn_extension/faiss/MultiVectorResultCollectorFactory.h @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include "knn_extension/faiss/utils/BitSet.h" + +namespace os_faiss { +/** + * Create MultiVectorResultCollector for single query request + * + * Creating new collector is required because MultiVectorResultCollector has instance variables + * which should be isolated for each query. + */ +struct MultiVectorResultCollectorFactory:faiss::ResultCollectorFactory { + BitSet* parent_bit_set; + + MultiVectorResultCollectorFactory(BitSet* parent_bit_set); + faiss::ResultCollector* new_collector() override; + void delete_collector(faiss::ResultCollector* resultCollector) override; +}; + +} // namespace os_faiss diff --git a/jni/include/knn_extension/faiss/utils/BitSet.h b/jni/include/knn_extension/faiss/utils/BitSet.h new file mode 100644 index 0000000000..0b481d578d --- /dev/null +++ b/jni/include/knn_extension/faiss/utils/BitSet.h @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include + +using idx_t = faiss::idx_t; + +struct BitSet { + const int NO_MORE_DOCS = std::numeric_limits::max(); + /** + * Returns the index of the first set bit starting at the index specified. + * NO_MORE_DOCS is returned if there are no more set bits. + */ + virtual idx_t next_set_bit(idx_t index) const = 0; + virtual ~BitSet() = default; +}; + + +/** + * BitSet of fixed length (numBits), implemented using an array of unit64. + * See https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java + * + * Here a block is 64 bit. However, for simplicity let's assume its size is 8 bits. + * Then, if have an array of 3, 7, and 10, it will be represented in bitmap as follow. + * [0] [1] + * bitmap: 10001000 00000100 + * + * for next_set_bit call with 4 + * 1. it looks for bitmap[0] + * 2. bitmap[0] >> 4 + * 3. count trailing zero of the result from step 2 which is 3 + * 4. return 4(current index) + 3(result from step 3) + */ +struct FixedBitSet : public BitSet { + // Length of bitmap + size_t numBits; + + // Pointer to an array of uint64_t + // Using uint64_t to leverage function __builtin_ctzll which is defined in faiss/impl/platform_macros.h + uint64_t* bitmap; + + FixedBitSet(const int* int_array, const int length); + idx_t next_set_bit(idx_t index) const; + ~FixedBitSet(); +}; diff --git a/jni/include/knn_extension/faiss/utils/Heap.h b/jni/include/knn_extension/faiss/utils/Heap.h new file mode 100644 index 0000000000..08d9823119 --- /dev/null +++ b/jni/include/knn_extension/faiss/utils/Heap.h @@ -0,0 +1,253 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +// Collection of heap operations with parent id to dedupe +namespace os_faiss { + +/** + * From start_index, it compare its value with parent node's and swap if needed. + * Continue until either there is no swap or it reaches the top node. + * + * @param bh_val binary heap storing values + * @param bh_ids binary heap storing parent ids + * @param val new value to add + * @param id new id to add + * @parent_id_to_id parent doc id to id mapping data, see MultiVectorResultCollector.h + * @parent_id_to_index parent doc id to index mapping data, see MultiVectorResultCollector.h + * @parent_id parent id of given id + * @start_index an index to start up-heap from in the binary heap(bh_val, and bh_ids) + */ +template +static inline void up_heap( + typename C::T* bh_val, + typename C::TI* bh_ids, + typename C::T val, + typename C::TI id, + std::unordered_map* parent_id_to_id, + std::unordered_map* parent_id_to_index, + typename C::TI parent_id, + size_t start_index) { + bh_val--; /* Use 1-based indexing for easier node->child translation */ + bh_ids--; + size_t i = start_index + 1, i_father; + + while (i > 1) { + i_father = i >> 1; + if (!C::cmp2(val, bh_val[i_father], parent_id, bh_ids[i_father])) { + /* the heap structure is ok */ + break; + } + bh_val[i] = bh_val[i_father]; + bh_ids[i] = bh_ids[i_father]; + (*parent_id_to_index)[bh_ids[i]] = i - 1; + i = i_father; + } + bh_val[i] = val; + bh_ids[i] = parent_id; + (*parent_id_to_id)[parent_id] = id; + (*parent_id_to_index)[parent_id] = i - 1; +} + +/** + * From start_index, it compare its value with child node's and swap if needed. + * Continue until either there is no swap or it reaches the leaf node. + * + * @param nres number of values in the binary heap(bh_val, and bh_ids) + * @param bh_val binary heap storing values + * @param bh_ids binary heap storing parent ids + * @param val new value to add + * @param id new id to add + * @parent_id_to_id parent doc id to id mapping data, see MultiVectorResultCollector.h + * @parent_id_to_index parent doc id to index mapping data, see MultiVectorResultCollector.h + * @parent_id parent id of given id + * @start_index an index to start up-heap from in the binary heap(bh_val, and bh_ids) + */ +template +static inline void down_heap( + int nres, + typename C::T* bh_val, + typename C::TI* bh_ids, + typename C::T val, + typename C::TI id, + std::unordered_map* parent_id_to_id, + std::unordered_map* parent_id_to_index, + typename C::TI parent_id, + size_t start_index) { + bh_val--; /* Use 1-based indexing for easier node->child translation */ + bh_ids--; + size_t i = start_index + 1, i1, i2; + + while (1) { + i1 = i << 1; + i2 = i1 + 1; + if (i1 > nres) { + break; + } + + // Note that C::cmp2() is a bool function answering + // `(a1 > b1) || ((a1 == b1) && (a2 > b2))` for max + // heap and same with the `<` sign for min heap. + if ((i2 == nres + 1) || + C::cmp2(bh_val[i1], bh_val[i2], bh_ids[i1], bh_ids[i2])) { + if (C::cmp2(val, bh_val[i1], parent_id, bh_ids[i1])) { + break; + } + bh_val[i] = bh_val[i1]; + bh_ids[i] = bh_ids[i1]; + (*parent_id_to_index)[bh_ids[i]] = i - 1; + i = i1; + } else { + if (C::cmp2(val, bh_val[i2], parent_id, bh_ids[i2])) { + break; + } + bh_val[i] = bh_val[i2]; + bh_ids[i] = bh_ids[i2]; + (*parent_id_to_index)[bh_ids[i]] = i - 1; + i = i2; + } + } + bh_val[i] = val; + bh_ids[i] = parent_id; + (*parent_id_to_id)[parent_id] = id; + (*parent_id_to_index)[parent_id] = i - 1; +} + +/** + * Push the value to the max heap + * The parent_id should not exist in in bh_ids, parent_id_to_id, and parent_id_to_index. + * + * @param nres number of values in the binary heap(bh_val, and bh_ids) + * @param bh_val binary heap storing values + * @param bh_ids binary heap storing parent ids + * @param val new value to add + * @param id new id to add + * @parent_id_to_id parent doc id to id mapping data, see MultiVectorResultCollector.h + * @parent_id_to_index parent doc id to index mapping data, see MultiVectorResultCollector.h + * @parent_id parent id of given id + */ +template +inline void maxheap_push( + int nres, + T* bh_val, + int64_t* bh_ids, + T val, + int64_t id, + std::unordered_map* parent_id_to_id, + std::unordered_map* parent_id_to_index, + int64_t parent_id) { + + assert(parent_id_to_index->find(parent_id) != parent_id_to_index->end() && "parent id should not exist in the binary heap"); + + up_heap>( + bh_val, + bh_ids, + val, + id, + parent_id_to_id, + parent_id_to_index, + parent_id, + nres); +} + +/** + * Update the top node with given value + * The parent_id should not exist in in bh_ids, parent_id_to_id, and parent_id_to_index. + * + * @param nres number of values in the binary heap(bh_val, and bh_ids) + * @param bh_val binary heap storing values + * @param bh_ids binary heap storing parent ids + * @param val new value to add + * @param id new id to add + * @parent_id_to_id parent doc id to id mapping data, see MultiVectorResultCollector.h + * @parent_id_to_index parent doc id to index mapping data, see MultiVectorResultCollector.h + * @parent_id parent id of given id + */ +template +inline void maxheap_replace_top( + int nres, + T* bh_val, + int64_t* bh_ids, + T val, + int64_t id, + std::unordered_map* parent_id_to_id, + std::unordered_map* parent_id_to_index, + int64_t parent_id) { + + assert(parent_id_to_index->find(parent_id) != parent_id_to_index->end() && "parent id should not exist in the binary heap"); + + parent_id_to_id->erase(bh_ids[0]); + parent_id_to_index->erase(bh_ids[0]); + down_heap>( + nres, + bh_val, + bh_ids, + val, + id, + parent_id_to_id, + parent_id_to_index, + parent_id, + 0); +} + +/** + * Update value of the parent_id in the binary heap and id of the parent_id in parent_id_to_id + * The parent_id should exist in bh_ids, parent_id_to_id, and parent_id_to_index. + * + * @param nres number of values in the binary heap(bh_val, and bh_ids) + * @param bh_val binary heap storing values + * @param bh_ids binary heap storing parent ids + * @param val new value to update + * @param id new id to update + * @parent_id_to_id parent doc id to id mapping data, see MultiVectorResultCollector.h + * @parent_id_to_index parent doc id to index mapping data, see MultiVectorResultCollector.h + * @parent_id parent id of given id + */ +template +inline void maxheap_update( + int nres, + T* bh_val, + int64_t* bh_ids, + T val, + int64_t id, + std::unordered_map* parent_id_to_id, + std::unordered_map* parent_id_to_index, + int64_t parent_id) { + size_t target_index = parent_id_to_index->at(parent_id); + up_heap>( + bh_val, + bh_ids, + val, + id, + parent_id_to_id, + parent_id_to_index, + parent_id, + target_index); + down_heap>( + nres, + bh_val, + bh_ids, + val, + id, + parent_id_to_id, + parent_id_to_index, + parent_id, + target_index); +} + +} // namespace os_faiss diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index a252643355..aefadcee46 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -48,7 +48,7 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex * Signature: (J[FI)[Lorg/opensearch/knn/index/query/KNNQueryResult; */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex - (JNIEnv *, jclass, jlong, jfloatArray, jint); + (JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray); /* * Class: org_opensearch_knn_jni_FaissService @@ -56,7 +56,7 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd * Signature: (J[FI[J)[Lorg/opensearch/knn/index/query/KNNQueryResult; */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter - (JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray); + (JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray, jintArray); /* * Class: org_opensearch_knn_jni_FaissService diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index e8fb4de201..8e9deb07b4 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -11,6 +11,7 @@ #include "jni_util.h" #include "faiss_wrapper.h" +#include "knn_extension/faiss/MultiVectorResultCollectorFactory.h" #include "faiss/impl/io.h" #include "faiss/index_factory.h" @@ -50,6 +51,10 @@ void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, fa // Concerts the FilterIds to BitMap void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector); +os_faiss::MultiVectorResultCollectorFactory* buildResultCollectorFactory(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ); + +void releaseResultCollectorFactory(os_faiss::MultiVectorResultCollectorFactory* collectorFactory); + void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) { @@ -195,13 +200,12 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI } jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ) { - return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr); + jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ) { + return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr, parentIdsJ); } jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ) { - + jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ, jintArray parentIdsJ) { if (queryVectorJ == nullptr) { throw std::runtime_error("Query Vector cannot be null"); } @@ -255,6 +259,7 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter // value of ef_search = 16 which will then be used. hnswParams.efSearch = hnswReader->hnsw.efSearch; hnswParams.sel = idSelector.get(); + hnswParams.col = buildResultCollectorFactory(jniUtil, env, parentIdsJ); searchParameters = &hnswParams; } else { auto ivfReader = dynamic_cast(indexReader->index); @@ -269,16 +274,30 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter } catch (...) { jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + releaseResultCollectorFactory(dynamic_cast(hnswParams.col)); throw; } jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + releaseResultCollectorFactory(dynamic_cast(hnswParams.col)); } else { + faiss::SearchParameters *searchParameters = nullptr; + faiss::SearchParametersHNSW hnswParams; + auto hnswReader = dynamic_cast(indexReader->index); + if(hnswReader!= nullptr && parentIdsJ != nullptr) { + // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default + // value of ef_search = 16 which will then be used. + hnswParams.efSearch = hnswReader->hnsw.efSearch; + hnswParams.col = buildResultCollectorFactory(jniUtil, env, parentIdsJ); + searchParameters = &hnswParams; + } try { - indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data()); + indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters); } catch (...) { jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); + releaseResultCollectorFactory(dynamic_cast(hnswParams.col)); throw; } + releaseResultCollectorFactory(dynamic_cast(hnswParams.col)); } jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); @@ -489,3 +508,22 @@ void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bi bitsetVector[bitsetArrayIndex] = bitsetVector[bitsetArrayIndex] | (1 << (value & 7)); } } + +os_faiss::MultiVectorResultCollectorFactory* buildResultCollectorFactory(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ) { + if (parentIdsJ == nullptr) { + return nullptr; + } + int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr); + int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ); + auto* parent_id_filter = new FixedBitSet(parentIdsArray, parentIdsLength); + jniUtil->ReleaseIntArrayElements(env, parentIdsJ, parentIdsArray, JNI_ABORT); + return new os_faiss::MultiVectorResultCollectorFactory(parent_id_filter); +} + +void releaseResultCollectorFactory(os_faiss::MultiVectorResultCollectorFactory* collectorFactory) { + if (collectorFactory == nullptr) { + return; + } + delete collectorFactory->parent_bit_set; + delete collectorFactory; +} diff --git a/jni/src/knn_extension/faiss/MultiVectorResultCollector.cpp b/jni/src/knn_extension/faiss/MultiVectorResultCollector.cpp new file mode 100644 index 0000000000..a7564d3aab --- /dev/null +++ b/jni/src/knn_extension/faiss/MultiVectorResultCollector.cpp @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "MultiVectorResultCollector.h" +#include "knn_extension/faiss/utils/Heap.h" +#include "knn_extension/faiss/utils/BitSet.h" + +namespace os_faiss { + +using idx_t = faiss::idx_t; + +MultiVectorResultCollector::MultiVectorResultCollector(const BitSet* parent_bit_set, const std::vector* id_map) +: parent_bit_set(parent_bit_set), id_map(id_map) {} + +void MultiVectorResultCollector::collect( + int k, + int& nres, + float* bh_val, + int64_t* bh_ids, + float val, + int64_t ids) { + idx_t group_id = id_map ? parent_bit_set->next_set_bit(id_map->at(ids)) : parent_bit_set->next_set_bit(ids); + if (parent_id_to_index.find(group_id) == + parent_id_to_index.end()) { + if (nres < k) { + maxheap_push( + nres++, + bh_val, + bh_ids, + val, + ids, + &parent_id_to_id, + &parent_id_to_index, + group_id); + } else if (val < bh_val[0]) { + maxheap_replace_top( + nres, + bh_val, + bh_ids, + val, + ids, + &parent_id_to_id, + &parent_id_to_index, + group_id); + } + } else if (val < bh_val[parent_id_to_index.at(group_id)]) { + maxheap_update( + nres, + bh_val, + bh_ids, + val, + ids, + &parent_id_to_id, + &parent_id_to_index, + group_id); + } +} + +void MultiVectorResultCollector::post_process(int64_t nres, int64_t* bh_ids) { + for (size_t icnt = 0; icnt < nres; icnt++) { + bh_ids[icnt] = parent_id_to_id.at(bh_ids[icnt]); + } +} + +} // namespace os_faiss diff --git a/jni/src/knn_extension/faiss/MultiVectorResultCollectorFactory.cpp b/jni/src/knn_extension/faiss/MultiVectorResultCollectorFactory.cpp new file mode 100644 index 0000000000..f4c7c0656d --- /dev/null +++ b/jni/src/knn_extension/faiss/MultiVectorResultCollectorFactory.cpp @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "MultiVectorResultCollectorFactory.h" +#include "MultiVectorResultCollector.h" + +namespace os_faiss { + +MultiVectorResultCollectorFactory::MultiVectorResultCollectorFactory(BitSet* parent_bit_set) + : parent_bit_set(parent_bit_set) {} + +// id_map is set in IndexIDMap.cpp of faiss library with custom patch +// https://github.com/opensearch-project/k-NN/blob/feature/multi-vector/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch#L109 +faiss::ResultCollector* MultiVectorResultCollectorFactory::new_collector() { + return new MultiVectorResultCollector(parent_bit_set, id_map); +} + +void MultiVectorResultCollectorFactory::delete_collector(faiss::ResultCollector* resultCollector) { + delete resultCollector; +} + +} // namespace os_faiss diff --git a/jni/src/knn_extension/faiss/utils/BitSet.cpp b/jni/src/knn_extension/faiss/utils/BitSet.cpp new file mode 100644 index 0000000000..90cd7d1f0d --- /dev/null +++ b/jni/src/knn_extension/faiss/utils/BitSet.cpp @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include "BitSet.h" + +FixedBitSet::FixedBitSet(const int* int_array, const int length){ + assert(int_array && "int_array should not be null"); + const int* maxValue = std::max_element(int_array, int_array + length); + this->numBits = (*maxValue >> 6) + 1; // div by 64 + this->bitmap = new uint64_t[this->numBits](); + for(int i = 0 ; i < length ; i ++) { + int value = int_array[i]; + int bitsetArrayIndex = value >> 6; + this->bitmap[bitsetArrayIndex] |= 1ULL << (value & 63); // Equivalent of 1ULL << (value % 64) + } +} + +idx_t FixedBitSet::next_set_bit(idx_t index) const { + idx_t i = index >> 6; // div by 64 + uint64_t word = this->bitmap[i] >> (index & 63); // Equivalent of bitmap[i] >> (index % 64) + + if (word != 0) { + return index + __builtin_ctzll(word); + } + + while (++i < this->numBits) { + word = this->bitmap[i]; + if (word != 0) { + return (i << 6) + __builtin_ctzll(word); + } + } + + return NO_MORE_DOCS; +} + +FixedBitSet::~FixedBitSet() { + delete this->bitmap; +} diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 1b79d91143..a7b24fcab7 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -77,10 +77,10 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEn JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex(JNIEnv * env, jclass cls, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ) + jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ) { try { - return knn_jni::faiss_wrapper::QueryIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ); + return knn_jni::faiss_wrapper::QueryIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, parentIdsJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); @@ -89,10 +89,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd } JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter - (JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jintArray filteredIdsJ) { + (JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jintArray filteredIdsJ, jintArray parentIdsJ) { try { - return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ); + return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, parentIdsJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index abe4ecb205..5fa5165bb0 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -40,7 +40,7 @@ TEST(FaissCreateIndexTest, BasicAssertions) { std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); std::string spaceType = knn_jni::L2; - std::string index_description = "Flat"; // TODO: Revert bach to HNSW32,Flat + std::string index_description = "HNSW32,Flat"; std::unordered_map parametersMap; parametersMap[knn_jni::SPACE_TYPE] = (jobject)&spaceType; @@ -87,7 +87,7 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); faiss::MetricType metricType = faiss::METRIC_L2; - std::string method = "Flat"; // TODO: Revert bach to HNSW32,Flat + std::string method = "HNSW32,Flat"; std::unique_ptr createdIndex( test_util::FaissCreateIndex(dim, method, metricType)); @@ -135,7 +135,7 @@ TEST(FaissLoadIndexTest, BasicAssertions) { std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); faiss::MetricType metricType = faiss::METRIC_L2; - std::string method = "Flat"; // TODO: Revert bach to HNSW32,Flat + std::string method = "HNSW32,Flat"; // Create the index std::unique_ptr createdIndex( @@ -186,7 +186,7 @@ TEST(FaissQueryIndexTest, BasicAssertions) { } faiss::MetricType metricType = faiss::METRIC_L2; - std::string method = "Flat"; // TODO: Revert bach to HNSW32,Flat + std::string method = "HNSW32,Flat"; // Define query data int k = 10; @@ -218,7 +218,7 @@ TEST(FaissQueryIndexTest, BasicAssertions) { knn_jni::faiss_wrapper::QueryIndex( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), k))); + reinterpret_cast(&query), k, nullptr))); ASSERT_EQ(k, results->size()); @@ -229,11 +229,84 @@ TEST(FaissQueryIndexTest, BasicAssertions) { } } +TEST(FaissQueryIndexWithParentFilterTest, BasicAssertions) { + // Define the index data + faiss::idx_t numIds = 100; + std::vector ids; + std::vector vectors; + std::vector parentIds; + int dim = 16; + for (int64_t i = 1; i < numIds + 1; i++) { + if (i % 10 == 0) { + parentIds.push_back(i); + continue; + } + ids.push_back(i); + for (int j = 0; j < dim; j++) { + vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); + } + } + + faiss::MetricType metricType = faiss::METRIC_L2; + std::string method = "HNSW32,Flat"; + + // Define query data + int k = 20; + int numQueries = 100; + std::vector> queries; + + for (int i = 0; i < numQueries; i++) { + std::vector query; + query.reserve(dim); + for (int j = 0; j < dim; j++) { + query.push_back(test_util::RandomFloat(-500.0, 500.0)); + } + queries.push_back(query); + } + + // Create the index + std::unique_ptr createdIndex( + test_util::FaissCreateIndex(2, method, metricType)); + auto createdIndexWithData = + test_util::FaissAddData(createdIndex.get(), ids, vectors); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength( + jniEnv, reinterpret_cast(&parentIds))) + .WillRepeatedly(Return(parentIds.size())); + for (auto query : queries) { + std::unique_ptr *>> results( + reinterpret_cast *> *>( + knn_jni::faiss_wrapper::QueryIndex( + &mockJNIUtil, jniEnv, + reinterpret_cast(&createdIndexWithData), + reinterpret_cast(&query), k, + reinterpret_cast(&parentIds)))); + + // Even with k 20, result should have only 10 which is total number of groups + ASSERT_EQ(10, results->size()); + // Result should be one for each group + std::set idSet; + for (const auto& pairPtr : *results) { + idSet.insert(pairPtr->first / 10); + } + ASSERT_EQ(10, idSet.size()); + + // Need to free up each result + for (auto it : *results.get()) { + delete it; + } + } +} + TEST(FaissFreeTest, BasicAssertions) { // Define the data int dim = 2; faiss::MetricType metricType = faiss::METRIC_L2; - std::string method = "Flat"; // TODO: Revert bach to HNSW32,Flat + std::string method = "HNSW32,Flat"; // Create the index faiss::Index *createdIndex( diff --git a/jni/tests/knn_extension/faiss/MultiVectorResultCollectorFactoryTest.cpp b/jni/tests/knn_extension/faiss/MultiVectorResultCollectorFactoryTest.cpp new file mode 100644 index 0000000000..3177360bbd --- /dev/null +++ b/jni/tests/knn_extension/faiss/MultiVectorResultCollectorFactoryTest.cpp @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "knn_extension/faiss/MultiVectorResultCollectorFactory.h" +#include "knn_extension/faiss/MultiVectorResultCollector.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "jni_util.h" + +using ::testing::NiceMock; +using ::testing::Return; +using idx_t = faiss::idx_t; + + +TEST(MultiVectorResultCollectorFactoryTest, BasicAssertions) { + int parent_ids[1] = {1}; + FixedBitSet parent_id_filter(parent_ids, 1); + + std::unordered_map distance1; + distance1[0] = 10; + distance1[1] = 11; + + std::unordered_map distance2; + distance2[0] = 11; + distance2[1] = 10; + + os_faiss::MultiVectorResultCollectorFactory* rc_factory = new os_faiss::MultiVectorResultCollectorFactory(&parent_id_filter); + faiss::ResultCollector* rc1 = rc_factory->new_collector(); + faiss::ResultCollector* rc2 = rc_factory->new_collector(); + ASSERT_NE(rc1, rc2); + + int k = 1; + int nres1 = 0; + int nres2 = 0; + float* bh_val = new float[k * 2]; + int64_t* bh_ids = new int64_t[k * 2]; + // Verify two collector are thread safe each other. + // Simulate multi thread by interleaving collect methods of two ResultCollectors. + for (int i = 0; i < distance1.size(); i++) { + rc1->collect(k, nres1, bh_val, bh_ids, distance1.at(i), i); + rc2->collect(k, nres2, bh_val + k, bh_ids + k, distance2.at(i), i); + } + rc1->post_process(nres1, bh_ids); + rc2->post_process(nres2, bh_ids + k); + + ASSERT_EQ(0, bh_ids[0]); + ASSERT_EQ(1, bh_ids[1]); + + rc_factory->delete_collector(rc1); + rc_factory->delete_collector(rc2); + delete rc_factory; + delete[] bh_val; + delete[] bh_ids; +} + +// Verify that id_map is passed to collector +TEST(MultiVectorResultCollectorFactoryWithIdMapTest, BasicAssertions) { + int parent_ids[1] = {1}; + FixedBitSet parent_id_filter(parent_ids, 1); + std::vector id_map; + + os_faiss::MultiVectorResultCollectorFactory* rc_factory = new os_faiss::MultiVectorResultCollectorFactory(&parent_id_filter); + os_faiss::MultiVectorResultCollector* rc1 = dynamic_cast(rc_factory->new_collector()); + ASSERT_EQ(nullptr, rc1->id_map); + + rc_factory->id_map = &id_map; + os_faiss::MultiVectorResultCollector* rc2 = dynamic_cast(rc_factory->new_collector()); + ASSERT_EQ(&id_map, rc2->id_map); + + rc_factory->delete_collector(rc1); + rc_factory->delete_collector(rc2); + delete rc_factory; +} diff --git a/jni/tests/knn_extension/faiss/MultiVectorResultCollectorTest.cpp b/jni/tests/knn_extension/faiss/MultiVectorResultCollectorTest.cpp new file mode 100644 index 0000000000..934d708dd1 --- /dev/null +++ b/jni/tests/knn_extension/faiss/MultiVectorResultCollectorTest.cpp @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "knn_extension/faiss/MultiVectorResultCollector.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using ::testing::NiceMock; +using ::testing::Return; +using idx_t = faiss::idx_t; + + +TEST(MultiVectorResultCollectorTest, BasicAssertions) { + // Data + // Parent ID: 2, ID: 0, Distance: 10 + // Parent ID: 2, ID: 1, Distance: 11 + // Parent ID: 5, ID: 3, Distance: 12 + // Parent ID: 5, ID: 4, Distance: 13 + // After collector handing the data with k = 3, it should return data with id 0 and 2, one from each group. + // Parent bit set representation: 100100 + int parent_ids[2] = {2, 5}; + FixedBitSet parent_id_filter(parent_ids, 2); + + idx_t ids[] = {0, 1, 2, 3}; + float distances[] = {10, 11, 12, 13}; + + os_faiss::MultiVectorResultCollector* rc = new os_faiss::MultiVectorResultCollector(&parent_id_filter, nullptr); + int k = 3; + int nres = 0; + float* bh_val = new float[k]; + int64_t* bh_ids = new int64_t[k]; + for (int i = 0; i < 4; i++) { + rc->collect(k, nres, bh_val, bh_ids, distances[i], ids[i]); + } + + // Parent ID is stored before finalize + ASSERT_EQ(5, bh_ids[0]); + ASSERT_EQ(2, bh_ids[1]); + + rc->post_process(nres, bh_ids); + + // Parent ID is converted to ID after finalize + ASSERT_EQ(3, bh_ids[0]); + ASSERT_EQ(0, bh_ids[1]); + + delete rc; + delete[] bh_val; + delete[] bh_ids; +} + +TEST(MultiVectorResultCollectorWithIDMapTest, BasicAssertions) { + // Data + // Parent ID: 2, Lucene ID: 0, Faiss ID: 0, Distance: 10 + // Parent ID: 2, Lucene ID: 1, Faiss ID: 1, Distance: 11 + // Parent ID: 5, Lucene ID: 3, Faiss ID: 2, Distance: 12 + // Parent ID: 5, Lucene ID: 4, Faiss ID: 3, Distance: 13 + // After collector handing the data with k = 3, it should return data with id 0 and 2, one from each group. + + // Parent bit set representation with Lucene ID: 100100 + int parent_ids[2] = {2, 5}; + FixedBitSet parent_id_filter(parent_ids, 2); + + idx_t faiss_ids[] = {0, 1, 2, 3}; + float distances[] = {10, 11, 12, 13}; + + // Faiss IDs to Lucene ID mapping + std::vector id_map = {0, 1, 3, 4}; + + os_faiss::MultiVectorResultCollector* rc = new os_faiss::MultiVectorResultCollector(&parent_id_filter, &id_map); + int k = 3; + int nres = 0; + float* bh_val = new float[k]; + int64_t* bh_ids = new int64_t[k]; + for (int i = 0; i < 4; i++) { + rc->collect(k, nres, bh_val, bh_ids, distances[i], faiss_ids[i]); + } + + // Parent ID is stored before finalize + ASSERT_EQ(5, bh_ids[0]); + ASSERT_EQ(2, bh_ids[1]); + + rc->post_process(nres, bh_ids); + + // Parent ID is converted to Faiss ID after finalize + ASSERT_EQ(2, bh_ids[0]); + ASSERT_EQ(0, bh_ids[1]); + + delete rc; + delete[] bh_val; + delete[] bh_ids; +} diff --git a/jni/tests/knn_extension/faiss/utils/BitSetTest.cpp b/jni/tests/knn_extension/faiss/utils/BitSetTest.cpp new file mode 100644 index 0000000000..96ad6b3c21 --- /dev/null +++ b/jni/tests/knn_extension/faiss/utils/BitSetTest.cpp @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "knn_extension/faiss/utils/BitSet.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using ::testing::NiceMock; +using ::testing::Return; +using idx_t = faiss::idx_t; + +TEST(FixedBitSetTest, BasicAssertions) { + int ids1[4] = {3, 7, 11, 15}; + FixedBitSet single_block(ids1, 4); + + ASSERT_EQ(3, single_block.next_set_bit(0)); + ASSERT_EQ(3, single_block.next_set_bit(1)); + ASSERT_EQ(3, single_block.next_set_bit(2)); + ASSERT_EQ(3, single_block.next_set_bit(3)); + ASSERT_EQ(7, single_block.next_set_bit(4)); + ASSERT_EQ(7, single_block.next_set_bit(5)); + ASSERT_EQ(7, single_block.next_set_bit(6)); + ASSERT_EQ(7, single_block.next_set_bit(7)); + ASSERT_EQ(11, single_block.next_set_bit(8)); + ASSERT_EQ(11, single_block.next_set_bit(9)); + ASSERT_EQ(11, single_block.next_set_bit(10)); + ASSERT_EQ(11, single_block.next_set_bit(11)); + ASSERT_EQ(15, single_block.next_set_bit(12)); + ASSERT_EQ(15, single_block.next_set_bit(13)); + ASSERT_EQ(15, single_block.next_set_bit(14)); + ASSERT_EQ(15, single_block.next_set_bit(15)); + ASSERT_EQ(single_block.NO_MORE_DOCS, single_block.next_set_bit(16)); + + int ids2[5] = {64, 128, 127, 1024, 34565}; + int ids2_sorted[5]; + std::copy(ids2, ids2 + 5, ids2_sorted); + std::sort(ids2_sorted, ids2_sorted + 5); + FixedBitSet multi_blocks(ids2, 5); + int parent_index = 0; + for (int i = 0; i < ids2[4] + 1; i++) { + ASSERT_EQ(ids2_sorted[parent_index], multi_blocks.next_set_bit(i)); + if (ids2_sorted[parent_index] == i) { + parent_index++; + } + } + ASSERT_EQ(multi_blocks.NO_MORE_DOCS, multi_blocks.next_set_bit(ids2[4] + 1)); +} diff --git a/jni/tests/knn_extension/faiss/utils/HeapTest.cpp b/jni/tests/knn_extension/faiss/utils/HeapTest.cpp new file mode 100644 index 0000000000..97a30babd1 --- /dev/null +++ b/jni/tests/knn_extension/faiss/utils/HeapTest.cpp @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "knn_extension/faiss/utils/Heap.h" +#include "faiss/utils/Heap.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::ElementsAreArray; + +TEST(MaxHeapUpdateTest, BasicAssertions) { + const int k = 5; + int nres = 0; + float binary_heap_values[k]; + int64_t binary_heap_ids[k]; + float input_values[] = {1.1f, 2.1f, 3.1f, 4.1f, 5.1f}; + int64_t input_ids[] = {1, 2, 3, 4, 5}; + int64_t group_ids[] = {11, 22, 33, 44, 55}; + std::unordered_map group_id_to_id; + std::unordered_map group_id_to_index; + + // Push + for (int i = 0; i < k; i++) { + os_faiss::maxheap_push( + nres++, + binary_heap_values, + binary_heap_ids, + input_values[i], + input_ids[i], + &group_id_to_id, + &group_id_to_index, + group_ids[i]); + } + + // Verify heap data + // The top node in the max heap should be the one with max value(5.1f) + ASSERT_EQ(5.1f, binary_heap_values[0]); + ASSERT_EQ(55, binary_heap_ids[0]); + ASSERT_EQ(5, group_id_to_id.at(binary_heap_ids[0])); + + // Replace top + os_faiss::maxheap_replace_top( + nres, + binary_heap_values, + binary_heap_ids, + 0.1f, + 6, + &group_id_to_id, + &group_id_to_index, + 66); + + // Verify heap data + // Previous top value(5.1f) should have been removed and the next max value(4.1f) should be in the top node. + ASSERT_EQ(4.1f, binary_heap_values[0]); + ASSERT_EQ(44, binary_heap_ids[0]); + ASSERT_EQ(4, group_id_to_id.at(binary_heap_ids[0])); + + // Update + os_faiss::maxheap_update( + nres, + binary_heap_values, + binary_heap_ids, + 0.2f, + 7, + &group_id_to_id, + &group_id_to_index, + 33); + + // Verify heap data + // node id 3 with group id 33 should have been replaced by node id 7 with new value + ASSERT_EQ(7, group_id_to_id.at(33)); + + // Verify heap is in order + float expectedValues[] = {4.1f, 2.1f, 1.1f, 0.2f, 0.1f}; + int64_t expectedIds[] = {4, 2, 1, 7, 6}; + for (int i = 0; i < k; i++) { + ASSERT_EQ(expectedValues[i], binary_heap_values[0]); + ASSERT_EQ(expectedIds[i], group_id_to_id.at(binary_heap_ids[0])); + faiss::maxheap_pop(nres--, binary_heap_values, binary_heap_ids); + } +} diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 5dce15d6e0..abf3e052a9 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -74,16 +74,39 @@ public static native void createIndexFromTemplate( public static native long loadIndex(String indexPath); /** - * Query an index + * Query an index without filter + * + * If the "knn" field is a nested field, each vector value within that nested field will be assigned its + * own document ID. In this situation, the term "parent ID" corresponds to the original document ID. + * The arrangement of parent IDs and nested field IDs is assured to have all nested field IDs appearing first, + * followed by the parent ID, in consecutive order without any gaps. Because of this ID pattern, + * we can determine the parent ID of a specific nested field ID using only an array of parent IDs. * * @param indexPointer pointer to index in memory * @param queryVector vector to be used for query * @param k neighbors to be returned + * @param parentIds list of parent doc ids when the knn field is a nested field * @return KNNQueryResult array of k neighbors */ - public static native KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k); + public static native KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, int[] parentIds); - public static native KNNQueryResult[] queryIndexWithFilter(long indexPointer, float[] queryVector, int k, int[] filterIds); + /** + * Query an index with filter + * + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param filterIds list of doc ids to include in the query result + * @param parentIds list of parent doc ids when the knn field is a nested field + * @return KNNQueryResult array of k neighbors + */ + public static native KNNQueryResult[] queryIndexWithFilter( + long indexPointer, + float[] queryVector, + int k, + int[] filterIds, + int[] parentIds + ); /** * Free native memory pointer diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index f45fb0c736..42d12e984c 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -112,9 +112,9 @@ public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector // filterIds. FilterIds is coming as empty then its the case where we need to do search with Faiss engine // normally. if (ArrayUtils.isNotEmpty(filteredIds)) { - return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds); + return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, null); } - return FaissService.queryIndex(indexPointer, queryVector, k); + return FaissService.queryIndex(indexPointer, queryVector, k, null); } throw new IllegalArgumentException("QueryIndex not supported for provided engine"); }