Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabled the efficient filtering support for Faiss Engine #907

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DEVELOPER_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ In addition to this, the plugin has been tested with JDK 17, and this JDK versio

#### CMake

The plugin requires that cmake >= 3.17.2 is installed in order to build the JNI libraries.
The plugin requires that cmake >= 3.23.1 is installed in order to build the JNI libraries.

One easy way to install on mac or linux is to use pip:
```bash
pip install cmake==3.17.2
pip install cmake==3.23.1
```

#### Faiss Dependencies
Expand Down
4 changes: 2 additions & 2 deletions jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
#

cmake_minimum_required(VERSION 3.17)
cmake_minimum_required(VERSION 3.23.1)

project(KNNPlugin_JNI)

Expand Down Expand Up @@ -95,7 +95,7 @@ if (${CONFIG_NMSLIB} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST}
set_target_properties(${TARGET_LIB_NMSLIB} PROPERTIES SUFFIX ${LIB_EXT})
set_target_properties(${TARGET_LIB_NMSLIB} PROPERTIES POSITION_INDEPENDENT_CODE ON)

if (WIN32)
if (NOT "${WIN32}" STREQUAL "")
# Use RUNTIME_OUTPUT_DIRECTORY, to build the target library (opensearchknn_nmslib) in the specified directory at runtime.
set_target_properties(${TARGET_LIB_NMSLIB} PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/release)
else()
Expand Down
2 changes: 1 addition & 1 deletion jni/external/faiss
Submodule faiss updated 665 files
6 changes: 6 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ namespace knn_jni {
jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ);

// Execute a query against the index located in memory at indexPointerJ, restricting the
// search to the ids listed in filterIdsJ. When filterIdsJ is null the search runs without
// a filter (QueryIndex delegates to this function that way).
// NOTE(review): the implementation reads the last element of filterIdsJ as the largest id,
// so the ids are presumably expected in ascending order -- confirm against the Java caller.
//
// Return an array of KNNQueryResults
jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);

Expand Down
8 changes: 8 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex
(JNIEnv *, jclass, jlong, jfloatArray, jint);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryIndexWithFilter
* Signature: (J[FI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: free
Expand Down
163 changes: 153 additions & 10 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@
#include "faiss/IndexHNSW.h"
#include "faiss/IndexIVFFlat.h"
#include "faiss/MetaIndexes.h"
#include "faiss/Index.h"
#include "faiss/impl/IDSelector.h"

#include <algorithm>
#include <jni.h>
#include <string>
#include <vector>

// Defines type of IDSelector used to hand the filter ids to Faiss:
//   BITMAP - the ids are packed into a bit set and wrapped in a faiss::IDSelectorBitmap
//   BATCH  - the ids are copied into a faiss::idx_t array and wrapped in a faiss::IDSelectorBatch
enum FilterIdsSelectorType{
BITMAP, BATCH
};

// Translate space type to faiss metric
faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType);
Expand All @@ -33,7 +39,19 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
const std::unordered_map<std::string, jobject>& parametersCpp, faiss::Index * index);

// Train an index with data provided
void InternalTrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float* x);
void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x);

// Create the SearchParams based on the Index Type
std::unique_ptr<faiss::SearchParameters> buildSearchParams(const faiss::IndexIDMap *indexReader, faiss::IDSelector* idSelector);

// Helps to choose the right FilterIdsSelectorType for Faiss
FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength);

// Converts the int FilterIds to Faiss ids type array.
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds);

// Converts the FilterIds to a BitMap
void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector);

void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) {
Expand Down Expand Up @@ -181,12 +199,17 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI

jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ) {
navneet1v marked this conversation as resolved.
Show resolved Hide resolved
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr);
}

jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ) {

if (queryVectorJ == nullptr) {
throw std::runtime_error("Query Vector cannot be null");
}

auto *indexReader = reinterpret_cast<faiss::Index*>(indexPointerJ);
auto *indexReader = reinterpret_cast<faiss::IndexIDMap *>(indexPointerJ);

if (indexReader == nullptr) {
throw std::runtime_error("Invalid pointer to index");
Expand All @@ -195,14 +218,50 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniU
// The ids vector will hold the top k ids from the search and the dis vector will hold the top k distances from
// the query point
std::vector<float> dis(kJ);
std::vector<faiss::Index::idx_t> ids(kJ);
std::vector<faiss::idx_t> ids(kJ);
float* rawQueryvector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr);

try {
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data());
} catch (...) {
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT);
throw;
// create the filterSearch params if the filterIdsJ is not a null pointer
if(filterIdsJ != nullptr) {
int *filteredIdsArray = jniUtil->GetIntArrayElements(env, filterIdsJ, nullptr);
int filterIdsLength = env->GetArrayLength(filterIdsJ);
std::unique_ptr<faiss::IDSelector> idSelector;
FilterIdsSelectorType idSelectorType = getIdSelectorType(filteredIdsArray, filterIdsLength);
// start with empty vectors for 2 different types of empty Selectors. We need define them here to avoid copying of data
// during the returns. We could have used pass by reference, but we choose pointers. Returning reference to local
// vector is also an option which can be efficient than copying during returns but it requires upto date C++ compilers.
// To avoid all those confusions, its better to work with pointers here. Ref: https://cplusplus.com/forum/general/56177/
std::vector<faiss::idx_t> convertedIds;
std::vector<uint8_t> bitmap;
navneet1v marked this conversation as resolved.
Show resolved Hide resolved
// Choose a selector which suits best
if(idSelectorType == BATCH) {
convertedIds.resize(filterIdsLength);
convertFilterIdsToFaissIdType(filteredIdsArray, filterIdsLength, convertedIds.data());
idSelector.reset(new faiss::IDSelectorBatch(convertedIds.size(), convertedIds.data()));
} else {
int maxIdValue = filteredIdsArray[filterIdsLength - 1];
// >> 3 is equivalent to value / 8
const int bitsetArraySize = (maxIdValue >> 3) + 1;
bitmap.resize(bitsetArraySize, 0);
buildFilterIdsBitMap(filteredIdsArray, filterIdsLength, bitmap.data());
idSelector.reset(new faiss::IDSelectorBitmap(filterIdsLength, bitmap.data()));
}
std::unique_ptr<faiss::SearchParameters> searchParameters = buildSearchParams(indexReader, idSelector.get());
try {
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters.get());
} catch (...) {
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT);
jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
throw;
}
jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
} else {
try {
std::cout << "Doing query" << std::endl;
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data());
} catch (...) {
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT);
throw;
}
}
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT);

Expand All @@ -227,6 +286,33 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniU
return results;
}

/**
 * Builds the Faiss SearchParameters object matching the concrete type of the index wrapped by the given
 * IndexIDMap and attaches the IDSelector to it, so that the search only considers the selected ids.
 *
 * The wrapped index type is detected with dynamic_cast. faiss::IndexIVFFlat derives from faiss::IndexIVF,
 * so a single cast to IndexIVF is enough to recognise every IVF variant.
 *
 * @param indexReader IndexIDMap whose wrapped index (indexReader->index) decides the parameter type
 * @param idSelector selector stored as a raw pointer in the returned parameters; the caller keeps ownership
 *                   and has to keep it alive for as long as the parameters are used
 * @return SearchParametersHNSW for an HNSW index, SearchParametersIVF for an IVF index
 * @throws std::runtime_error when the wrapped index is neither HNSW nor IVF
 */
std::unique_ptr<faiss::SearchParameters> buildSearchParams(const faiss::IndexIDMap *indexReader, faiss::IDSelector* idSelector) {
    const faiss::Index *wrappedIndex = indexReader->index;

    if (dynamic_cast<const faiss::IndexHNSW*>(wrappedIndex) != nullptr) {
        // unique_ptr hands ownership of the parameters over to the caller.
        std::unique_ptr<faiss::SearchParametersHNSW> hnswParams(new faiss::SearchParametersHNSW);
        hnswParams->sel = idSelector;
        return hnswParams;
    }

    if (dynamic_cast<const faiss::IndexIVF*>(wrappedIndex) != nullptr) {
        // unique_ptr hands ownership of the parameters over to the caller.
        std::unique_ptr<faiss::SearchParametersIVF> ivfParams(new faiss::SearchParametersIVF);
        ivfParams->sel = idSelector;
        return ivfParams;
    }

    throw std::runtime_error("Invalid Index Type supported for Filtered Search on Faiss");
}

void knn_jni::faiss_wrapper::Free(jlong indexPointer) {
auto *indexWrapper = reinterpret_cast<faiss::Index*>(indexPointer);
delete indexWrapper;
Expand Down Expand Up @@ -344,7 +430,7 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
}
}

void InternalTrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float* x) {
void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
if (auto * indexIvf = dynamic_cast<faiss::IndexIVF*>(index)) {
if (indexIvf->quantizer_trains_alone == 2) {
InternalTrainIndex(indexIvf->quantizer, n, x);
Expand All @@ -356,3 +442,60 @@ void InternalTrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float
index->train(n, x);
}
}

/**
 * This function takes a call on what ID Selector to use:
 * https://github.com/facebookresearch/faiss/wiki/Setting-search-parameters-for-one-query#idselectorarray-idselectorbatch-and-idselectorbitmap
 *
 * class               storage     lookup      construction(Opensearch + Faiss)
 * IDSelectorArray     O(k)        O(k)        O(2k)
 * IDSelectorBatch     O(k)        O(1)        O(2k)
 * IDSelectorBitmap    O(n/8)      O(1)        O(k) -> n is the max value of id in the index
 *
 * Only the memory footprint is compared: BATCH is returned when storing the ids as faiss::idx_t takes no
 * more bits (filterIdsLength * sizeof(faiss::idx_t) * 8) than the value of the largest id, otherwise
 * BITMAP is returned. The largest id is taken from the last element of filterIds.
 * NOTE(review): that presumes the ids arrive in ascending order -- confirm against the caller.
 *
 * An empty (or null) filter and a negative largest id both yield BATCH, because a bitmap cannot be sized
 * for them.
 *
 * TODO: We need to ideally decide when we can take another hit of K iterations in latency. Some facts:
 * an OpenSearch Index can have a max segment size of 5GB, which on a vector with dimension of 128 boils down to
 * 7.5M vectors.
 * Ref: https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#hnsw-memory-estimation
 * M = 16
 * Dimension = 128
 * (1.1 * ( 4 * 128 + 8 * 16) * 7500000)/(1024*1024*1024) ~ 4.9GB
 * Ids are sequential in a Segment which means for IDSelectorBitmap total size if the max ID has value of 7.5M will be
 * 7500000/(8*1024) = 915KBs in worst case. But with larger dimensions this worst case value will decrease.
 *
 * With 915KB how many ids can be represented as an array of 64-bit longs : 117,120 ids
 * So iterating on 117k ids for 1 single pass is also time consuming. So, we are currently concluding to consider only size
 * as factor. We need to improve on this.
 *
 * TODO: Best way is to implement a SparseBitSet in C++. This can be done by extending the IDSelector Interface of Faiss.
 *
 * @param filterIds array of filter ids; may be null only when filterIdsLength is 0
 * @param filterIdsLength number of entries in filterIds
 * @return BATCH or BITMAP
 */
FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength) {
    // An empty filter has no last element to read; a batch selector over zero ids is always valid.
    if (filterIds == nullptr || filterIdsLength <= 0) {
        return BATCH;
    }

    const int maxIdValue = filterIds[filterIdsLength - 1];
    // A negative id cannot size a bitmap; keep it on the batch path instead of relying on the
    // implicit signed-to-unsigned conversion in the comparison below.
    if (maxIdValue < 0) {
        return BATCH;
    }

    const std::size_t batchSizeInBits = static_cast<std::size_t>(filterIdsLength) * sizeof(faiss::idx_t) * 8;
    if (batchSizeInBits <= static_cast<std::size_t>(maxIdValue)) {
        return BATCH;
    }
    return BITMAP;
}

// Copies the first filterIdsLength entries of filterIds into convertedFilterIds, converting each int to
// faiss::idx_t on assignment. The destination must have room for filterIdsLength entries.
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds) {
    std::copy_n(filterIds, filterIdsLength, convertedFilterIds);
}

/**
 * Sets, for every id in filterIds, the matching bit in bitsetVector. Bits that are already set are kept.
 *
 * Layout follows Faiss IDSelectorBitmap::is_member: an id is selected iff id / 8 < n and bit number
 * (id % 8) of bitmap[floor(id / 8)] is 1. The caller has to provide a bitsetVector large enough to hold
 * the byte of the largest id.
 */
void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector) {
    int remaining = filterIdsLength;
    for (const int* cursor = filterIds; remaining > 0; ++cursor, --remaining) {
        const int id = *cursor;
        // Shifts and masks stand in for the more expensive / and % operations:
        // id >> 3 is id / 8 (the byte), id & 7 is id % 8 (the bit inside that byte).
        const int byteSlot = id >> 3;
        const int bitMask = 1 << (id & 7);
        bitsetVector[byteSlot] = static_cast<uint8_t>(bitsetVector[byteSlot] | bitMask);
    }
}
12 changes: 12 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
return nullptr;
}

// JNI entry point for FaissService.queryIndexWithFilter. Forwards the arguments to
// knn_jni::faiss_wrapper::QueryIndex_WithFilter; any C++ exception is handed to
// jniUtil.CatchCppExceptionAndThrowJava and nullptr is returned in that case.
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
        (JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jintArray filteredIdsJ) {
    jobjectArray queryResults = nullptr;
    try {
        queryResults = knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ,
                                                                     kJ, filteredIdsJ);
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
    return queryResults;
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ)
{
try {
Expand Down
33 changes: 33 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@

package org.opensearch.knn.index.query;

import lombok.Getter;
import lombok.Setter;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
Expand All @@ -25,13 +30,25 @@ public class KNNQuery extends Query {
private final int k;
private final String indexName;

@Getter
@Setter
private Query filterQuery;

/**
 * Creates a k-NN query without a filter; equivalent to passing a null filterQuery.
 *
 * @param field name of the vector field to search
 * @param queryVector vector to find neighbours for
 * @param k number of neighbours requested
 * @param indexName name of the index the query runs against
 */
public KNNQuery(String field, float[] queryVector, int k, String indexName) {
    this(field, queryVector, k, indexName, null);
}

/**
 * Creates a k-NN query with an optional filter.
 *
 * @param field name of the vector field to search
 * @param queryVector vector to find neighbours for
 * @param k number of neighbours requested
 * @param indexName name of the index the query runs against
 * @param filterQuery filter query stored on this query; may be null when no filter applies
 */
public KNNQuery(String field, float[] queryVector, int k, String indexName, Query filterQuery) {
    this.field = field;
    this.queryVector = queryVector;
    this.k = k;
    this.indexName = indexName;
    this.filterQuery = filterQuery;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe rewrite above constructor to call this new one, only passing filter as null?


public String getField() {
return this.field;
}
Expand Down Expand Up @@ -61,9 +78,25 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
if (!KNNSettings.isKNNPluginEnabled()) {
throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled to true");
}
final Weight filterWeight = getFilterWeight(searcher);
if (filterWeight != null) {
return new KNNWeight(this, boost, filterWeight);
}
return new KNNWeight(this, boost);
}

private Weight getFilterWeight(IndexSearcher searcher) throws IOException {
if (this.getFilterQuery() != null) {
// Run the filter query
final BooleanQuery booleanQuery = new BooleanQuery.Builder().add(this.getFilterQuery(), BooleanClause.Occur.FILTER)
.add(new FieldExistsQuery(this.getField()), BooleanClause.Occur.FILTER)
.build();
final Query rewritten = searcher.rewrite(booleanQuery);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this call take care of the recursive part of the rewrite calls?

return searcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
}
return null;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are aware that null is possible, can we return Optional here and check if it's empty in the client?

}

@Override
public void visit(QueryVisitor visitor) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ protected Query doToQuery(QueryShardContext context) {
);
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null) {
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null && knnEngine != KNNEngine.FAISS) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This if condition has grown to the size of its own method

throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine));
}

Expand Down
Loading