From 932ae4d0486771e0e402ef63c13c832923da8fac Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Thu, 1 Feb 2024 14:53:42 -0800 Subject: [PATCH] Pass correct value on IDSelectorBitmap initialization Signed-off-by: Heemin Kim --- CHANGELOG.md | 1 + jni/src/faiss_wrapper.cpp | 4 +- jni/tests/faiss_wrapper_test.cpp | 68 ++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b3f172cbb..b0d366420 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Fix script score queries not getting cached [#1367](https://github.com/opensearch-project/k-NN/pull/1367) * Fix KNNScorer to apply boost [#1403](https://github.com/opensearch-project/k-NN/pull/1403) * Fix equals and hashCode methods for KNNQuery and KNNQueryBuilder [#1397](https://github.com/opensearch-project/k-NN/pull/1397) +* Pass correct value on IDSelectorBitmap initialization [#1444](https://github.com/opensearch-project/k-NN/pull/1444) ### Infrastructure * Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289) * Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307) diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 8e9deb07b..4609f3144 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -228,7 +228,7 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter // create the filterSearch params if the filterIdsJ is not a null pointer if(filterIdsJ != nullptr) { int *filteredIdsArray = jniUtil->GetIntArrayElements(env, filterIdsJ, nullptr); - int filterIdsLength = env->GetArrayLength(filterIdsJ); + int filterIdsLength = jniUtil->GetJavaIntArrayLength(env, filterIdsJ); std::unique_ptr idSelector; FilterIdsSelectorType idSelectorType = getIdSelectorType(filteredIdsArray, filterIdsLength); // start with empty vectors for 2 different types of empty Selectors. We need define them here to avoid copying of data @@ -248,7 +248,7 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter const int bitsetArraySize = (maxIdValue >> 3) + 1; bitmap.resize(bitsetArraySize, 0); buildFilterIdsBitMap(filteredIdsArray, filterIdsLength, bitmap.data()); - idSelector.reset(new faiss::IDSelectorBitmap(filterIdsLength, bitmap.data())); + idSelector.reset(new faiss::IDSelectorBitmap(bitsetArraySize, bitmap.data())); } faiss::SearchParameters *searchParameters; faiss::SearchParametersHNSW hnswParams; diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 5fa5165bb..ed3ec880d 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -229,6 +229,74 @@ TEST(FaissQueryIndexTest, BasicAssertions) { } } +//Test for a bug reported in https://github.com/opensearch-project/k-NN/issues/1435 +TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) { + // Define the index data + faiss::idx_t numIds = 200; + std::vector ids; + std::vector vectors; + std::vector> queries; + + int dim = 16; + for (int64_t i = 1; i < numIds + 1; i++) { + std::vector query; + query.reserve(dim); + ids.push_back(i); + for (int j = 0; j < dim; j++) { + float vector = test_util::RandomFloat(-500.0, 500.0); + vectors.push_back(vector); + query.push_back(vector); + } + queries.push_back(query); + } + + std::vector filterIds; + for (int64_t i = 154; i < 163; i++) { + filterIds.push_back(i); + } + std::unordered_set filterIdSet(filterIds.begin(), filterIds.end()); + + faiss::MetricType metricType = faiss::METRIC_L2; + std::string method = "HNSW32,Flat"; + + // Create the index + std::unique_ptr createdIndex( + test_util::FaissCreateIndex(2, method, metricType)); + auto createdIndexWithData = + test_util::FaissAddData(createdIndex.get(), ids, vectors); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength( + jniEnv, reinterpret_cast(&filterIds))) + .WillRepeatedly(Return(filterIds.size())); + + int k = 20; + for (auto query : queries) { + std::unique_ptr *>> results( + reinterpret_cast *> *>( + knn_jni::faiss_wrapper::QueryIndex_WithFilter( + &mockJNIUtil, jniEnv, + reinterpret_cast(&createdIndexWithData), + reinterpret_cast(&query), k, + reinterpret_cast(&filterIds), nullptr))); + + ASSERT_TRUE(results->size() <= filterIds.size()); + ASSERT_TRUE(results->size() > 0); + for (const auto& pairPtr : *results) { + auto it = filterIdSet.find(pairPtr->first); + ASSERT_NE(it, filterIdSet.end()); + } + + // Need to free up each result + for (auto it : *results.get()) { + delete it; + } + } +} + TEST(FaissQueryIndexWithParentFilterTest, BasicAssertions) { // Define the index data faiss::idx_t numIds = 100;