Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Faiss Query With Filters: Reduce iteration and memory for id filter #1402

Merged
merged 22 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
9015089
Optimize Faiss Query With Filters. Reduce iteration copy for docid se…
luyuncheng Jan 19, 2024
103307c
Optimize Faiss Query With Filters. Reduce iteration copy for docid se…
luyuncheng Jan 22, 2024
c2cd334
Using int64_t instead of long type for GetLongArrayElements
luyuncheng Jan 24, 2024
bfbfb55
Add IDSelectorJlongBitmap
luyuncheng Jan 29, 2024
a82e59f
1. Add IDSelectorJlongBitmap and UT for it
luyuncheng Jan 31, 2024
965621f
1. Add IDSelectorJlongBitmap and UT for it
luyuncheng Jan 31, 2024
da85df8
Rebase remote-tracking branch 'origin/main' into Filter
luyuncheng Feb 6, 2024
f5a7f95
Merge remote-tracking branch 'origin/main' into Filter
luyuncheng Feb 6, 2024
263a575
Rebase remote-tracking branch 'origin/main' into Filter
luyuncheng Feb 6, 2024
dcef6c2
tidy
luyuncheng Feb 6, 2024
9ca9c98
Add Changelog
luyuncheng Feb 6, 2024
568972f
fix javadoc tasks
luyuncheng Feb 7, 2024
a2b27ee
fix bwc javadoc
luyuncheng Feb 7, 2024
a48a928
UpdatedFilterIdsSelector
luyuncheng Feb 7, 2024
5d303e7
Merge branch 'main' into Filter
luyuncheng Feb 8, 2024
1aed422
UpdatedFilterIdsSelector
luyuncheng Feb 16, 2024
b8e961c
Merge remote-tracking branch 'origin/main' into Filter
luyuncheng Feb 16, 2024
3970f98
Rebase faiss_wrapper.cpp
luyuncheng Feb 16, 2024
f5ebc1a
UpdatedFilterIdsSelector For description Select different FilterIdsSe…
luyuncheng Feb 26, 2024
3e1aaee
UpdatedFilterIdsSelector For description Select different FilterIdsSe…
luyuncheng Feb 26, 2024
f747ec7
UpdatedFilterIdsSelector as Byte.SIZE
luyuncheng Mar 4, 2024
b042b62
UpdatedFilterIdsSelector For comments
luyuncheng Mar 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.12...2.x)
### Features
### Enhancements
* Optize Faiss Query With Filters: Reduce iteration and memory for id filter [#1402](https://github.com/opensearch-project/k-NN/pull/1402)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
3 changes: 2 additions & 1 deletion jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ namespace knn_jni {
//
// Return an array of KNNQueryResults
jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ, jintArray parentIdsJ);
jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ,
jint filterIdsTypeJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);
Expand Down
9 changes: 9 additions & 0 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ namespace knn_jni {

virtual int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) = 0;

virtual int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ) = 0;

virtual int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ) = 0;

virtual int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ) = 0;
Expand All @@ -94,6 +96,8 @@ namespace knn_jni {

virtual jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy) = 0;

virtual jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy) = 0;

virtual jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) = 0;

virtual jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance) = 0;
Expand All @@ -108,6 +112,8 @@ namespace knn_jni {

virtual void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode) = 0;

virtual void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode) = 0;

virtual void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val) = 0;

virtual void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf) = 0;
Expand Down Expand Up @@ -139,20 +145,23 @@ namespace knn_jni {
int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ);
int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ);
int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ);
int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ);
int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ);
int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ);

void DeleteLocalRef(JNIEnv *env, jobject obj);
jbyte * GetByteArrayElements(JNIEnv *env, jbyteArray array, jboolean * isCopy);
jfloat * GetFloatArrayElements(JNIEnv *env, jfloatArray array, jboolean * isCopy);
jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy);
jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy);
jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index);
jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance);
jobjectArray NewObjectArray(JNIEnv *env, jsize len, jclass clazz, jobject init);
jbyteArray NewByteArray(JNIEnv *env, jsize len);
void ReleaseByteArrayElements(JNIEnv *env, jbyteArray array, jbyte *elems, int mode);
void ReleaseFloatArrayElements(JNIEnv *env, jfloatArray array, jfloat *elems, int mode);
void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode);
void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode);
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val);
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);

Expand Down
2 changes: 1 addition & 1 deletion jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
* Signature: (J[FI[J)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray, jintArray);
(JNIEnv *, jclass, jlong, jfloatArray, jint, jlongArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
Expand Down
120 changes: 35 additions & 85 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,32 @@

// Defines type of IDSelector
enum FilterIdsSelectorType{
BITMAP, BATCH
BITMAP = 0, BATCH = 1,
};
namespace faiss {

// Using jlong to do Bitmap selector, jlong[] equals to lucene FixedBitSet#bits
struct IDSelectorJlongBitmap : IDSelector {
luyuncheng marked this conversation as resolved.
Show resolved Hide resolved
size_t n;
const jlong* bitmap;

/** Construct with a binary mask like Lucene FixedBitSet
*
* @param n size of the bitmap array
* @param bitmap id like Lucene FixedBitSet bits
*/
IDSelectorJlongBitmap(size_t n, const jlong* bitmap) : n(n), bitmap(bitmap) {};
bool is_member(idx_t id) const final {
uint64_t index = id;
uint64_t i = index >> 6; // div 64
if (i >= n ) {
return false;
}
return (bitmap[i] >> ( index & 63)) & 1L;
}
~IDSelectorJlongBitmap() override {}
};
}
// Translate space type to faiss metric
faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType);

Expand All @@ -42,9 +65,6 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
// Train an index with data provided
void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x);

// Helps to choose the right FilterIdsSelectorType for Faiss
FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength);

// Converts the int FilterIds to Faiss ids type array.
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds);

Expand Down Expand Up @@ -199,11 +219,12 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI

jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ) {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr, parentIdsJ);
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr, 0, parentIdsJ);
}

jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ, jintArray parentIdsJ) {
jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {

if (queryVectorJ == nullptr) {
throw std::runtime_error("Query Vector cannot be null");
}
Expand All @@ -225,28 +246,14 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
omp_set_num_threads(1);
// create the filterSearch params if the filterIdsJ is not a null pointer
if(filterIdsJ != nullptr) {
int *filteredIdsArray = jniUtil->GetIntArrayElements(env, filterIdsJ, nullptr);
int filterIdsLength = jniUtil->GetJavaIntArrayLength(env, filterIdsJ);
jlong *filteredIdsArray = jniUtil->GetLongArrayElements(env, filterIdsJ, nullptr);
int filterIdsLength = jniUtil->GetJavaLongArrayLength(env, filterIdsJ);
std::unique_ptr<faiss::IDSelector> idSelector;
FilterIdsSelectorType idSelectorType = getIdSelectorType(filteredIdsArray, filterIdsLength);
// start with empty vectors for 2 different types of empty Selectors. We need define them here to avoid copying of data
// during the returns. We could have used pass by reference, but we choose pointers. Returning reference to local
// vector is also an option which can be efficient than copying during returns but it requires upto date C++ compilers.
// To avoid all those confusions, its better to work with pointers here. Ref: https://cplusplus.com/forum/general/56177/
std::vector<faiss::idx_t> convertedIds;
std::vector<uint8_t> bitmap;
// Choose a selector which suits best
if(idSelectorType == BATCH) {
convertedIds.resize(filterIdsLength);
convertFilterIdsToFaissIdType(filteredIdsArray, filterIdsLength, convertedIds.data());
idSelector.reset(new faiss::IDSelectorBatch(convertedIds.size(), convertedIds.data()));
if(filterIdsTypeJ == BITMAP) {
idSelector.reset(new faiss::IDSelectorJlongBitmap(filterIdsLength, filteredIdsArray));
luyuncheng marked this conversation as resolved.
Show resolved Hide resolved
} else {
int maxIdValue = filteredIdsArray[filterIdsLength - 1];
// >> 3 is equivalent to value / 8
const int bitsetArraySize = (maxIdValue >> 3) + 1;
bitmap.resize(bitsetArraySize, 0);
buildFilterIdsBitMap(filteredIdsArray, filterIdsLength, bitmap.data());
idSelector.reset(new faiss::IDSelectorBitmap(bitsetArraySize, bitmap.data()));
faiss::idx_t* batchIndices = reinterpret_cast<faiss::idx_t*>(filteredIdsArray);
heemin32 marked this conversation as resolved.
Show resolved Hide resolved
idSelector.reset(new faiss::IDSelectorBatch(filterIdsLength, batchIndices));
}
faiss::SearchParameters *searchParameters;
faiss::SearchParametersHNSW hnswParams;
Expand Down Expand Up @@ -276,10 +283,10 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters);
} catch (...) {
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT);
jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
throw;
}
jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
} else {
faiss::SearchParameters *searchParameters = nullptr;
faiss::SearchParametersHNSW hnswParams;
Expand Down Expand Up @@ -454,63 +461,6 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
}
}

/**
* This function takes a call on what ID Selector to use:
* https://github.com/facebookresearch/faiss/wiki/Setting-search-parameters-for-one-query#idselectorarray-idselectorbatch-and-idselectorbitmap
*
* class storage lookup construction(Opensearch + Faiss)
* IDSelectorArray O(k) O(k) O(2k)
* IDSelectorBatch O(k) O(1) O(2k)
* IDSelectorBitmap O(n/8) O(1) O(k) -> n is the max value of id in the index
*
* TODO: We need to ideally decide when we can take another hit of K iterations in latency. Some facts:
* an OpenSearch Index can have max segment size as 5GB which, which on a vector with dimension of 128 boils down to
* 7.5M vectors.
* Ref: https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#hnsw-memory-estimation
* M = 16
* Dimension = 128
* (1.1 * ( 4 * 128 + 8 * 16) * 7500000)/(1024*1024*1024) ~ 4.9GB
* Ids are sequential in a Segment which means for IDSelectorBitmap total size if the max ID has value of 7.5M will be
* 7500000/(8*1024) = 915KBs in worst case. But with larger dimensions this worst case value will decrease.
*
* With 915KB how many ids can be represented as an array of 64-bit longs : 117,120 ids
* So iterating on 117k ids for 1 single pass is also time consuming. So, we are currently concluding to consider only size
* as factor. We need to improve on this.
*
* TODO: Best way is to implement a SparseBitSet in C++. This can be done by extending the IDSelector Interface of Faiss.
*
* @param filterIds
* @param filterIdsLength
* @return std::string
*/
FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength) {
int maxIdValue = filterIds[filterIdsLength - 1];
if(filterIdsLength * sizeof(faiss::idx_t) * 8 <= maxIdValue ) {
return BATCH;
}
return BITMAP;
}

void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds) {
for (int i = 0; i < filterIdsLength; i++) {
convertedFilterIds[i] = filterIds[i];
}
}

void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector) {
/**
* Coming from Faiss IDSelectorBitmap::is_member function bitmap id will be selected
* iff id / 8 < n and bit number (i%8) of bitmap[floor(i / 8)] is 1.
*/
for(int i = 0 ; i < filterIdsLength ; i ++) {
int value = filterIds[i];
// / , % are expensive operation. Hence, using BitShift operation as they are fast.
int bitsetArrayIndex = value >> 3 ; // is equivalent to value / 8
// (value & 7) equivalent to value % 8
bitsetVector[bitsetArrayIndex] = bitsetVector[bitsetArrayIndex] | (1 << (value & 7));
}
}

std::unique_ptr<faiss::IDGrouperBitmap> buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector<uint64_t>* bitmap) {
int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr);
int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ);
Expand Down
26 changes: 26 additions & 0 deletions jni/src/jni_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,17 @@ int knn_jni::JNIUtil::GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) {
return length;
}

int knn_jni::JNIUtil::GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ) {

if (arrayJ == nullptr) {
throw std::runtime_error("Array cannot be null");
}

int length = env->GetArrayLength(arrayJ);
this->HasExceptionInStack(env, "Unable to get array length");
return length;
}

int knn_jni::JNIUtil::GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ) {

if (arrayJ == nullptr) {
Expand Down Expand Up @@ -376,6 +387,17 @@ jint * knn_jni::JNIUtil::GetIntArrayElements(JNIEnv *env, jintArray array, jbool
return intArray;
}

jlong * knn_jni::JNIUtil::GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy) {
// Lets check for error here
jlong * longArray = env->GetLongArrayElements(array, isCopy);
if (longArray == nullptr) {
this->HasExceptionInStack(env, "Unable to get long array");
throw std::runtime_error("Unable to get long array");
}

return longArray;
}

jobject knn_jni::JNIUtil::GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) {
jobject object = env->GetObjectArrayElement(array, index);
this->HasExceptionInStack(env, "Unable to get object");
Expand Down Expand Up @@ -424,6 +446,10 @@ void knn_jni::JNIUtil::ReleaseIntArrayElements(JNIEnv *env, jintArray array, jin
env->ReleaseIntArrayElements(array, elems, mode);
}

void knn_jni::JNIUtil::ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode) {
env->ReleaseLongArrayElements(array, elems, mode);
}

void knn_jni::JNIUtil::SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val) {
env->SetObjectArrayElement(array, index, val);
this->HasExceptionInStack(env, "Unable to set object array element");
Expand Down
4 changes: 2 additions & 2 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
}

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jintArray filteredIdsJ, jintArray parentIdsJ) {
(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {

try {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, parentIdsJ);
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
14 changes: 9 additions & 5 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,13 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) {
queries.push_back(query);
}

std::vector<int> filterIds;
int num_bits = test_util::bits2words(164);
std::vector<jlong> bitmap(num_bits,0);
std::vector<int64_t> filterIds;

for (int64_t i = 154; i < 163; i++) {
filterIds.push_back(i);
test_util::setBitSet(i, bitmap.data(), bitmap.size());
}
std::unordered_set<int> filterIdSet(filterIds.begin(), filterIds.end());

Expand All @@ -270,9 +274,9 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) {
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;
EXPECT_CALL(mockJNIUtil,
GetJavaIntArrayLength(
jniEnv, reinterpret_cast<jintArray>(&filterIds)))
.WillRepeatedly(Return(filterIds.size()));
GetJavaLongArrayLength(
jniEnv, reinterpret_cast<jlongArray>(&bitmap)))
.WillRepeatedly(Return(bitmap.size()));

int k = 20;
for (auto query : queries) {
Expand All @@ -282,7 +286,7 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) {
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), k,
reinterpret_cast<jintArray>(&filterIds), nullptr)));
reinterpret_cast<jlongArray>(&bitmap), 0, nullptr)));

ASSERT_TRUE(results->size() <= filterIds.size());
ASSERT_TRUE(results->size() > 0);
Expand Down
Loading
Loading