Skip to content

Commit

Permalink
Pass correct value on IDSelectorBitmap initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
heemin32 committed Feb 1, 2024
1 parent fe592f5 commit 950a15a
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Fix script score queries not getting cached [#1367](https://github.com/opensearch-project/k-NN/pull/1367)
* Fix KNNScorer to apply boost [#1403](https://github.com/opensearch-project/k-NN/pull/1403)
* Fix equals and hashCode methods for KNNQuery and KNNQueryBuilder [#1397](https://github.com/opensearch-project/k-NN/pull/1397)
* Pass correct value on IDSelectorBitmap initialization [#1444](https://github.com/opensearch-project/k-NN/pull/1444)
### Infrastructure
* Upgrade gradle to 8.4 [#1289](https://github.com/opensearch-project/k-NN/pull/1289)
* Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307)
Expand Down
4 changes: 2 additions & 2 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
// create the filterSearch params if the filterIdsJ is not a null pointer
if(filterIdsJ != nullptr) {
int *filteredIdsArray = jniUtil->GetIntArrayElements(env, filterIdsJ, nullptr);
int filterIdsLength = env->GetArrayLength(filterIdsJ);
int filterIdsLength = jniUtil->GetJavaIntArrayLength(env, filterIdsJ);
std::unique_ptr<faiss::IDSelector> idSelector;
FilterIdsSelectorType idSelectorType = getIdSelectorType(filteredIdsArray, filterIdsLength);
// start with empty vectors for 2 different types of empty Selectors. We need to define them here to avoid copying of data
Expand All @@ -248,7 +248,7 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
const int bitsetArraySize = (maxIdValue >> 3) + 1;
bitmap.resize(bitsetArraySize, 0);
buildFilterIdsBitMap(filteredIdsArray, filterIdsLength, bitmap.data());
idSelector.reset(new faiss::IDSelectorBitmap(filterIdsLength, bitmap.data()));
idSelector.reset(new faiss::IDSelectorBitmap(bitsetArraySize, bitmap.data()));
}
faiss::SearchParameters *searchParameters;
faiss::SearchParametersHNSW hnswParams;
Expand Down
68 changes: 68 additions & 0 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,74 @@ TEST(FaissQueryIndexTest, BasicAssertions) {
}
}

// Regression test for the bug reported in
// https://github.com/opensearch-project/k-NN/issues/1435
//
// Builds an "HNSW32,Flat" index populated with ids 1..200, restricts the
// search to the nine ids 154..162, and checks for every query that at least
// one hit is returned and that every returned id belongs to the filter set.
TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) {
    // Define the index data
    const faiss::idx_t numIds = 200;
    const int dim = 16;
    std::vector<faiss::idx_t> ids;
    std::vector<float> vectors;
    std::vector<std::vector<float>> queries;
    ids.reserve(numIds);
    vectors.reserve(numIds * dim);
    queries.reserve(numIds);

    // Every indexed vector is also used as a query, so ids run 1..numIds and
    // each generated component is appended to both `vectors` and the query.
    for (int64_t id = 1; id < numIds + 1; id++) {
        ids.push_back(id);
        queries.emplace_back();
        std::vector<float> &query = queries.back();
        query.reserve(dim);
        for (int j = 0; j < dim; j++) {
            const float component = test_util::RandomFloat(-500.0, 500.0);
            vectors.push_back(component);
            query.push_back(component);
        }
    }

    // Filter ids 154..162: nine ids whose values are all far larger than the
    // number of ids in the filter.
    std::vector<int> filterIds;
    for (int id = 154; id < 163; id++) {
        filterIds.push_back(id);
    }
    const std::unordered_set<int> filterIdSet(filterIds.begin(), filterIds.end());

    faiss::MetricType metricType = faiss::METRIC_L2;
    std::string method = "HNSW32,Flat";

    // Create the index
    // NOTE(review): the index is created with a literal 2 although the data
    // above is generated with dim = 16 -- confirm whether this is intended.
    std::unique_ptr<faiss::Index> createdIndex(
        test_util::FaissCreateIndex(2, method, metricType));
    auto createdIndexWithData =
        test_util::FaissAddData(createdIndex.get(), ids, vectors);

    // Setup jni
    JNIEnv *jniEnv = nullptr;
    NiceMock<test_util::MockJNIUtil> mockJNIUtil;
    EXPECT_CALL(mockJNIUtil,
                GetJavaIntArrayLength(
                    jniEnv, reinterpret_cast<jintArray>(&filterIds)))
        .WillRepeatedly(Return(filterIds.size()));

    const int k = 20;
    // Iterate by (non-const) reference: the vector's address is handed to the
    // wrapper in place of a jfloatArray, so no per-query copy is needed.
    for (auto &query : queries) {
        std::unique_ptr<std::vector<std::pair<int, float> *>> results(
            reinterpret_cast<std::vector<std::pair<int, float> *> *>(
                knn_jni::faiss_wrapper::QueryIndex_WithFilter(
                    &mockJNIUtil, jniEnv,
                    reinterpret_cast<jlong>(&createdIndexWithData),
                    reinterpret_cast<jfloatArray>(&query), k,
                    reinterpret_cast<jintArray>(&filterIds), nullptr)));

        // Take ownership of each returned pair before asserting anything, so
        // that a failing ASSERT_* (which returns from the test body) does not
        // leak the results.
        std::vector<std::unique_ptr<std::pair<int, float>>> ownedResults;
        ownedResults.reserve(results->size());
        for (auto *rawPair : *results) {
            ownedResults.emplace_back(rawPair);
        }

        ASSERT_LE(ownedResults.size(), filterIds.size());
        ASSERT_FALSE(ownedResults.empty());
        for (const auto &result : ownedResults) {
            ASSERT_NE(filterIdSet.find(result->first), filterIdSet.end());
        }
    }
}

TEST(FaissQueryIndexWithParentFilterTest, BasicAssertions) {
// Define the index data
faiss::idx_t numIds = 100;
Expand Down

0 comments on commit 950a15a

Please sign in to comment.