Skip to content

Commit

Permalink
Enabled the efficient filtering support for Faiss Engine (#907)
Browse files Browse the repository at this point in the history
  • Loading branch information
navneet1v authored Jun 2, 2023
1 parent f11f1f1 commit 119b8d6
Show file tree
Hide file tree
Showing 26 changed files with 416 additions and 62 deletions.
4 changes: 2 additions & 2 deletions DEVELOPER_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ In addition to this, the plugin has been tested with JDK 17, and this JDK versio

#### CMake

The plugin requires that cmake >= 3.17.2 is installed in order to build the JNI libraries.
The plugin requires that cmake >= 3.23.1 is installed in order to build the JNI libraries.

One easy way to install on mac or linux is to use pip:
```bash
pip install cmake==3.17.2
pip install cmake==3.23.1
```

#### Faiss Dependencies
Expand Down
4 changes: 2 additions & 2 deletions jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
#

cmake_minimum_required(VERSION 3.17)
cmake_minimum_required(VERSION 3.23.1)

project(KNNPlugin_JNI)

Expand Down Expand Up @@ -95,7 +95,7 @@ if (${CONFIG_NMSLIB} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST}
set_target_properties(${TARGET_LIB_NMSLIB} PROPERTIES SUFFIX ${LIB_EXT})
set_target_properties(${TARGET_LIB_NMSLIB} PROPERTIES POSITION_INDEPENDENT_CODE ON)

if (WIN32)
if (NOT "${WIN32}" STREQUAL "")
# Use RUNTIME_OUTPUT_DIRECTORY, to build the target library (opensearchknn_nmslib) in the specified directory at runtime.
set_target_properties(${TARGET_LIB_NMSLIB} PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/release)
else()
Expand Down
2 changes: 1 addition & 1 deletion jni/external/faiss
Submodule faiss updated 665 files
6 changes: 6 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ namespace knn_jni {
jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ);

// Execute a query against the index located in memory at indexPointerJ along with Filters
//
// Return an array of KNNQueryResults
jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);

Expand Down
8 changes: 8 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex
(JNIEnv *, jclass, jlong, jfloatArray, jint);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryIndex_WithFilter
* Signature: (J[FI[J)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: free
Expand Down
163 changes: 153 additions & 10 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@
#include "faiss/IndexHNSW.h"
#include "faiss/IndexIVFFlat.h"
#include "faiss/MetaIndexes.h"
#include "faiss/Index.h"
#include "faiss/impl/IDSelector.h"

#include <algorithm>
#include <jni.h>
#include <string>
#include <vector>

// Defines type of IDSelector
enum FilterIdsSelectorType{
BITMAP, BATCH
};

// Translate space type to faiss metric
faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType);
Expand All @@ -33,7 +39,19 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
const std::unordered_map<std::string, jobject>& parametersCpp, faiss::Index * index);

// Train an index with data provided
void InternalTrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float* x);
void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x);

// Create the SearchParams based on the Index Type
std::unique_ptr<faiss::SearchParameters> buildSearchParams(const faiss::IndexIDMap *indexReader, faiss::IDSelector* idSelector);

// Helps to choose the right FilterIdsSelectorType for Faiss
FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength);

// Converts the int FilterIds to Faiss ids type array.
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds);

// Concerts the FilterIds to BitMap
void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector);

void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) {
Expand Down Expand Up @@ -181,12 +199,17 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI

jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ) {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr);
}

jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ) {

if (queryVectorJ == nullptr) {
throw std::runtime_error("Query Vector cannot be null");
}

auto *indexReader = reinterpret_cast<faiss::Index*>(indexPointerJ);
auto *indexReader = reinterpret_cast<faiss::IndexIDMap *>(indexPointerJ);

if (indexReader == nullptr) {
throw std::runtime_error("Invalid pointer to index");
Expand All @@ -195,14 +218,50 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniU
// The ids vector will hold the top k ids from the search and the dis vector will hold the top k distances from
// the query point
std::vector<float> dis(kJ);
std::vector<faiss::Index::idx_t> ids(kJ);
std::vector<faiss::idx_t> ids(kJ);
float* rawQueryvector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr);

try {
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data());
} catch (...) {
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT);
throw;
// create the filterSearch params if the filterIdsJ is not a null pointer
if(filterIdsJ != nullptr) {
int *filteredIdsArray = jniUtil->GetIntArrayElements(env, filterIdsJ, nullptr);
int filterIdsLength = env->GetArrayLength(filterIdsJ);
std::unique_ptr<faiss::IDSelector> idSelector;
FilterIdsSelectorType idSelectorType = getIdSelectorType(filteredIdsArray, filterIdsLength);
// start with empty vectors for 2 different types of empty Selectors. We need define them here to avoid copying of data
// during the returns. We could have used pass by reference, but we choose pointers. Returning reference to local
// vector is also an option which can be efficient than copying during returns but it requires upto date C++ compilers.
// To avoid all those confusions, its better to work with pointers here. Ref: https://cplusplus.com/forum/general/56177/
std::vector<faiss::idx_t> convertedIds;
std::vector<uint8_t> bitmap;
// Choose a selector which suits best
if(idSelectorType == BATCH) {
convertedIds.resize(filterIdsLength);
convertFilterIdsToFaissIdType(filteredIdsArray, filterIdsLength, convertedIds.data());
idSelector.reset(new faiss::IDSelectorBatch(convertedIds.size(), convertedIds.data()));
} else {
int maxIdValue = filteredIdsArray[filterIdsLength - 1];
// >> 3 is equivalent to value / 8
const int bitsetArraySize = (maxIdValue >> 3) + 1;
bitmap.resize(bitsetArraySize, 0);
buildFilterIdsBitMap(filteredIdsArray, filterIdsLength, bitmap.data());
idSelector.reset(new faiss::IDSelectorBitmap(filterIdsLength, bitmap.data()));
}
std::unique_ptr<faiss::SearchParameters> searchParameters = buildSearchParams(indexReader, idSelector.get());
try {
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters.get());
} catch (...) {
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT);
jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
throw;
}
jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
} else {
try {
std::cout << "Doing query" << std::endl;
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data());
} catch (...) {
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT);
throw;
}
}
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT);

Expand All @@ -227,6 +286,33 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniU
return results;
}

/**
* Based on the type of the index reader we need to return the SearchParameters. The way we do this by dynamically
* casting the IndexReader.
* @param indexReader
* @param idSelector
* @return SearchParameters
*/
std::unique_ptr<faiss::SearchParameters> buildSearchParams(const faiss::IndexIDMap *indexReader, faiss::IDSelector* idSelector) {
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader) {
// we need to make this variable unique_ptr so that the scope can be shared with caller function.
std::unique_ptr<faiss::SearchParametersHNSW> hnswParams(new faiss::SearchParametersHNSW);
hnswParams->sel = idSelector;
return hnswParams;
}

auto ivfReader = dynamic_cast<const faiss::IndexIVF*>(indexReader->index);
auto ivfFlatReader = dynamic_cast<const faiss::IndexIVFFlat*>(indexReader->index);
if(ivfReader || ivfFlatReader) {
// we need to make this variable unique_ptr so that the scope can be shared with caller function.
std::unique_ptr<faiss::SearchParametersIVF> ivfParams(new faiss::SearchParametersIVF);
ivfParams->sel = idSelector;
return ivfParams;
}
throw std::runtime_error("Invalid Index Type supported for Filtered Search on Faiss");
}

void knn_jni::faiss_wrapper::Free(jlong indexPointer) {
auto *indexWrapper = reinterpret_cast<faiss::Index*>(indexPointer);
delete indexWrapper;
Expand Down Expand Up @@ -344,7 +430,7 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
}
}

void InternalTrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float* x) {
void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
if (auto * indexIvf = dynamic_cast<faiss::IndexIVF*>(index)) {
if (indexIvf->quantizer_trains_alone == 2) {
InternalTrainIndex(indexIvf->quantizer, n, x);
Expand All @@ -356,3 +442,60 @@ void InternalTrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float
index->train(n, x);
}
}

/**
* This function takes a call on what ID Selector to use:
* https://github.com/facebookresearch/faiss/wiki/Setting-search-parameters-for-one-query#idselectorarray-idselectorbatch-and-idselectorbitmap
*
* class storage lookup construction(Opensearch + Faiss)
* IDSelectorArray O(k) O(k) O(2k)
* IDSelectorBatch O(k) O(1) O(2k)
* IDSelectorBitmap O(n/8) O(1) O(k) -> n is the max value of id in the index
*
* TODO: We need to ideally decide when we can take another hit of K iterations in latency. Some facts:
* an OpenSearch Index can have max segment size as 5GB which, which on a vector with dimension of 128 boils down to
* 7.5M vectors.
* Ref: https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#hnsw-memory-estimation
* M = 16
* Dimension = 128
* (1.1 * ( 4 * 128 + 8 * 16) * 7500000)/(1024*1024*1024) ~ 4.9GB
* Ids are sequential in a Segment which means for IDSelectorBitmap total size if the max ID has value of 7.5M will be
* 7500000/(8*1024) = 915KBs in worst case. But with larger dimensions this worst case value will decrease.
*
* With 915KB how many ids can be represented as an array of 64-bit longs : 117,120 ids
* So iterating on 117k ids for 1 single pass is also time consuming. So, we are currently concluding to consider only size
* as factor. We need to improve on this.
*
* TODO: Best way is to implement a SparseBitSet in C++. This can be done by extending the IDSelector Interface of Faiss.
*
* @param filterIds
* @param filterIdsLength
* @return std::string
*/
FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength) {
int maxIdValue = filterIds[filterIdsLength - 1];
if(filterIdsLength * sizeof(faiss::idx_t) * 8 <= maxIdValue ) {
return BATCH;
}
return BITMAP;
}

void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds) {
for (int i = 0; i < filterIdsLength; i++) {
convertedFilterIds[i] = filterIds[i];
}
}

void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector) {
/**
* Coming from Faiss IDSelectorBitmap::is_member function bitmap id will be selected
* iff id / 8 < n and bit number (i%8) of bitmap[floor(i / 8)] is 1.
*/
for(int i = 0 ; i < filterIdsLength ; i ++) {
int value = filterIds[i];
// / , % are expensive operation. Hence, using BitShift operation as they are fast.
int bitsetArrayIndex = value >> 3 ; // is equivalent to value / 8
// (value & 7) equivalent to value % 8
bitsetVector[bitsetArrayIndex] = bitsetVector[bitsetArrayIndex] | (1 << (value & 7));
}
}
12 changes: 12 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
return nullptr;
}

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jintArray filteredIdsJ) {

try {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return nullptr;

}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ)
{
try {
Expand Down
33 changes: 33 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@

package org.opensearch.knn.index.query;

import lombok.Getter;
import lombok.Setter;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
Expand All @@ -25,13 +30,25 @@ public class KNNQuery extends Query {
private final int k;
private final String indexName;

@Getter
@Setter
private Query filterQuery;

public KNNQuery(String field, float[] queryVector, int k, String indexName) {
this.field = field;
this.queryVector = queryVector;
this.k = k;
this.indexName = indexName;
}

public KNNQuery(String field, float[] queryVector, int k, String indexName, Query filterQuery) {
this.field = field;
this.queryVector = queryVector;
this.k = k;
this.indexName = indexName;
this.filterQuery = filterQuery;
}

public String getField() {
return this.field;
}
Expand Down Expand Up @@ -61,9 +78,25 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
if (!KNNSettings.isKNNPluginEnabled()) {
throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled to true");
}
final Weight filterWeight = getFilterWeight(searcher);
if (filterWeight != null) {
return new KNNWeight(this, boost, filterWeight);
}
return new KNNWeight(this, boost);
}

private Weight getFilterWeight(IndexSearcher searcher) throws IOException {
if (this.getFilterQuery() != null) {
// Run the filter query
final BooleanQuery booleanQuery = new BooleanQuery.Builder().add(this.getFilterQuery(), BooleanClause.Occur.FILTER)
.add(new FieldExistsQuery(this.getField()), BooleanClause.Occur.FILTER)
.build();
final Query rewritten = searcher.rewrite(booleanQuery);
return searcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
}
return null;
}

@Override
public void visit(QueryVisitor visitor) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ protected Query doToQuery(QueryShardContext context) {
);
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null) {
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null && knnEngine != KNNEngine.FAISS) {
throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine));
}

Expand Down
Loading

0 comments on commit 119b8d6

Please sign in to comment.