diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index 4a6f360b5..d8ef9e413 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -56,11 +56,11 @@ In addition to this, the plugin has been tested with JDK 17, and this JDK versio #### CMake -The plugin requires that cmake >= 3.17.2 is installed in order to build the JNI libraries. +The plugin requires that cmake >= 3.23.1 is installed in order to build the JNI libraries. One easy way to install on mac or linux is to use pip: ```bash -pip install cmake==3.17.2 +pip install cmake==3.23.1 ``` #### Faiss Dependencies diff --git a/jni/CMakeLists.txt b/jni/CMakeLists.txt index 668ce684d..29a844ee0 100644 --- a/jni/CMakeLists.txt +++ b/jni/CMakeLists.txt @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # -cmake_minimum_required(VERSION 3.17) +cmake_minimum_required(VERSION 3.23.1) project(KNNPlugin_JNI) @@ -95,7 +95,7 @@ if (${CONFIG_NMSLIB} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} set_target_properties(${TARGET_LIB_NMSLIB} PROPERTIES SUFFIX ${LIB_EXT}) set_target_properties(${TARGET_LIB_NMSLIB} PROPERTIES POSITION_INDEPENDENT_CODE ON) - if (WIN32) + if (NOT "${WIN32}" STREQUAL "") # Use RUNTIME_OUTPUT_DIRECTORY, to build the target library (opensearchknn_nmslib) in the specified directory at runtime. set_target_properties(${TARGET_LIB_NMSLIB} PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/release) else() diff --git a/jni/external/faiss b/jni/external/faiss index 88eabe97f..3219e3d12 160000 --- a/jni/external/faiss +++ b/jni/external/faiss @@ -1 +1 @@ -Subproject commit 88eabe97f96d0c0964dfa075f74373c64d46da80 +Subproject commit 3219e3d12e6fc36dfdfe17d4cf238ef70bf89568 diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 6c8a86143..284214631 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -40,6 +40,12 @@ namespace knn_jni { jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ); + // Execute a query against the index located in memory at indexPointerJ along with Filters + // + // Return an array of KNNQueryResults + jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, + jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ); + // Free the index located in memory at indexPointerJ void Free(jlong indexPointer); diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 1ab6c5681..a25264335 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -50,6 +50,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex (JNIEnv *, jclass, jlong, jfloatArray, jint); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: queryIndex_WithFilter + * Signature: (J[FI[J)[Lorg/opensearch/knn/index/query/KNNQueryResult; + */ +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter + (JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray); + /* * Class: org_opensearch_knn_jni_FaissService * Method: free diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index e0fcc822b..2e626f9c6 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -18,12 +18,18 @@ #include "faiss/IndexHNSW.h" #include "faiss/IndexIVFFlat.h" #include "faiss/MetaIndexes.h" +#include "faiss/Index.h" +#include "faiss/impl/IDSelector.h" #include #include #include #include +// Defines type of IDSelector +enum FilterIdsSelectorType{ + BITMAP, BATCH +}; // Translate space type to faiss metric faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType); @@ -33,7 +39,19 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, const std::unordered_map& parametersCpp, faiss::Index * index); // Train an index with data provided -void InternalTrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float* x); +void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x); + +// Create the SearchParams based on the Index Type +std::unique_ptr buildSearchParams(const faiss::IndexIDMap *indexReader, faiss::IDSelector* idSelector); + +// Helps to choose the right FilterIdsSelectorType for Faiss +FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength); + +// Converts the int FilterIds to Faiss ids type array. +void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds); + +// Concerts the FilterIds to BitMap +void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector); void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) { @@ -181,12 +199,17 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ) { + return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr); +} + +jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, + jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ) { if (queryVectorJ == nullptr) { throw std::runtime_error("Query Vector cannot be null"); } - auto *indexReader = reinterpret_cast(indexPointerJ); + auto *indexReader = reinterpret_cast(indexPointerJ); if (indexReader == nullptr) { throw std::runtime_error("Invalid pointer to index"); @@ -195,14 +218,50 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniU // The ids vector will hold the top k ids from the search and the dis vector will hold the top k distances from // the query point std::vector dis(kJ); - std::vector ids(kJ); + std::vector ids(kJ); float* rawQueryvector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr); - - try { - indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data()); - } catch (...) { - jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); - throw; + // create the filterSearch params if the filterIdsJ is not a null pointer + if(filterIdsJ != nullptr) { + int *filteredIdsArray = jniUtil->GetIntArrayElements(env, filterIdsJ, nullptr); + int filterIdsLength = env->GetArrayLength(filterIdsJ); + std::unique_ptr idSelector; + FilterIdsSelectorType idSelectorType = getIdSelectorType(filteredIdsArray, filterIdsLength); + // start with empty vectors for 2 different types of empty Selectors. We need define them here to avoid copying of data + // during the returns. We could have used pass by reference, but we choose pointers. Returning reference to local + // vector is also an option which can be efficient than copying during returns but it requires upto date C++ compilers. + // To avoid all those confusions, its better to work with pointers here. Ref: https://cplusplus.com/forum/general/56177/ + std::vector convertedIds; + std::vector bitmap; + // Choose a selector which suits best + if(idSelectorType == BATCH) { + convertedIds.resize(filterIdsLength); + convertFilterIdsToFaissIdType(filteredIdsArray, filterIdsLength, convertedIds.data()); + idSelector.reset(new faiss::IDSelectorBatch(convertedIds.size(), convertedIds.data())); + } else { + int maxIdValue = filteredIdsArray[filterIdsLength - 1]; + // >> 3 is equivalent to value / 8 + const int bitsetArraySize = (maxIdValue >> 3) + 1; + bitmap.resize(bitsetArraySize, 0); + buildFilterIdsBitMap(filteredIdsArray, filterIdsLength, bitmap.data()); + idSelector.reset(new faiss::IDSelectorBitmap(filterIdsLength, bitmap.data())); + } + std::unique_ptr searchParameters = buildSearchParams(indexReader, idSelector.get()); + try { + indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters.get()); + } catch (...) { + jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); + jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + throw; + } + jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + } else { + try { + std::cout << "Doing query" << std::endl; + indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data()); + } catch (...) { + jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); + throw; + } } jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); @@ -227,6 +286,33 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniU return results; } +/** + * Based on the type of the index reader we need to return the SearchParameters. The way we do this by dynamically + * casting the IndexReader. + * @param indexReader + * @param idSelector + * @return SearchParameters + */ +std::unique_ptr buildSearchParams(const faiss::IndexIDMap *indexReader, faiss::IDSelector* idSelector) { + auto hnswReader = dynamic_cast(indexReader->index); + if(hnswReader) { + // we need to make this variable unique_ptr so that the scope can be shared with caller function. + std::unique_ptr hnswParams(new faiss::SearchParametersHNSW); + hnswParams->sel = idSelector; + return hnswParams; + } + + auto ivfReader = dynamic_cast(indexReader->index); + auto ivfFlatReader = dynamic_cast(indexReader->index); + if(ivfReader || ivfFlatReader) { + // we need to make this variable unique_ptr so that the scope can be shared with caller function. + std::unique_ptr ivfParams(new faiss::SearchParametersIVF); + ivfParams->sel = idSelector; + return ivfParams; + } + throw std::runtime_error("Invalid Index Type supported for Filtered Search on Faiss"); +} + void knn_jni::faiss_wrapper::Free(jlong indexPointer) { auto *indexWrapper = reinterpret_cast(indexPointer); delete indexWrapper; @@ -344,7 +430,7 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, } } -void InternalTrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float* x) { +void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { if (auto * indexIvf = dynamic_cast(index)) { if (indexIvf->quantizer_trains_alone == 2) { InternalTrainIndex(indexIvf->quantizer, n, x); @@ -356,3 +442,60 @@ void InternalTrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float index->train(n, x); } } + +/** + * This function takes a call on what ID Selector to use: + * https://github.com/facebookresearch/faiss/wiki/Setting-search-parameters-for-one-query#idselectorarray-idselectorbatch-and-idselectorbitmap + * + * class storage lookup construction(Opensearch + Faiss) + * IDSelectorArray O(k) O(k) O(2k) + * IDSelectorBatch O(k) O(1) O(2k) + * IDSelectorBitmap O(n/8) O(1) O(k) -> n is the max value of id in the index + * + * TODO: We need to ideally decide when we can take another hit of K iterations in latency. Some facts: + * an OpenSearch Index can have max segment size as 5GB which, which on a vector with dimension of 128 boils down to + * 7.5M vectors. + * Ref: https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#hnsw-memory-estimation + * M = 16 + * Dimension = 128 + * (1.1 * ( 4 * 128 + 8 * 16) * 7500000)/(1024*1024*1024) ~ 4.9GB + * Ids are sequential in a Segment which means for IDSelectorBitmap total size if the max ID has value of 7.5M will be + * 7500000/(8*1024) = 915KBs in worst case. But with larger dimensions this worst case value will decrease. + * + * With 915KB how many ids can be represented as an array of 64-bit longs : 117,120 ids + * So iterating on 117k ids for 1 single pass is also time consuming. So, we are currently concluding to consider only size + * as factor. We need to improve on this. + * + * TODO: Best way is to implement a SparseBitSet in C++. This can be done by extending the IDSelector Interface of Faiss. + * + * @param filterIds + * @param filterIdsLength + * @return std::string + */ +FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength) { + int maxIdValue = filterIds[filterIdsLength - 1]; + if(filterIdsLength * sizeof(faiss::idx_t) * 8 <= maxIdValue ) { + return BATCH; + } + return BITMAP; +} + +void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds) { + for (int i = 0; i < filterIdsLength; i++) { + convertedFilterIds[i] = filterIds[i]; + } +} + +void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector) { + /** + * Coming from Faiss IDSelectorBitmap::is_member function bitmap id will be selected + * iff id / 8 < n and bit number (i%8) of bitmap[floor(i / 8)] is 1. + */ + for(int i = 0 ; i < filterIdsLength ; i ++) { + int value = filterIds[i]; + // / , % are expensive operation. Hence, using BitShift operation as they are fast. + int bitsetArrayIndex = value >> 3 ; // is equivalent to value / 8 + // (value & 7) equivalent to value % 8 + bitsetVector[bitsetArrayIndex] = bitsetVector[bitsetArrayIndex] | (1 << (value & 7)); + } +} diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 543ce8ec4..1b79d9114 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -88,6 +88,18 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd return nullptr; } +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter + (JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jintArray filteredIdsJ) { + + try { + return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; + +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ) { try { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 9bf38008b..5ac207c43 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -5,6 +5,11 @@ package org.opensearch.knn.index.query; +import lombok.Getter; +import lombok.Setter; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -25,6 +30,10 @@ public class KNNQuery extends Query { private final int k; private final String indexName; + @Getter + @Setter + private Query filterQuery; + public KNNQuery(String field, float[] queryVector, int k, String indexName) { this.field = field; this.queryVector = queryVector; @@ -32,6 +41,14 @@ public KNNQuery(String field, float[] queryVector, int k, String indexName) { this.indexName = indexName; } + public KNNQuery(String field, float[] queryVector, int k, String indexName, Query filterQuery) { + this.field = field; + this.queryVector = queryVector; + this.k = k; + this.indexName = indexName; + this.filterQuery = filterQuery; + } + public String getField() { return this.field; } @@ -61,9 +78,25 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo if (!KNNSettings.isKNNPluginEnabled()) { throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled to true"); } + final Weight filterWeight = getFilterWeight(searcher); + if (filterWeight != null) { + return new KNNWeight(this, boost, filterWeight); + } return new KNNWeight(this, boost); } + private Weight getFilterWeight(IndexSearcher searcher) throws IOException { + if (this.getFilterQuery() != null) { + // Run the filter query + final BooleanQuery booleanQuery = new BooleanQuery.Builder().add(this.getFilterQuery(), BooleanClause.Occur.FILTER) + .add(new FieldExistsQuery(this.getField()), BooleanClause.Occur.FILTER) + .build(); + final Query rewritten = searcher.rewrite(booleanQuery); + return searcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f); + } + return null; + } + @Override public void visit(QueryVisitor visitor) { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 3de0e69d4..5efa89fd0 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -283,7 +283,7 @@ protected Query doToQuery(QueryShardContext context) { ); } - if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null) { + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null && knnEngine != KNNEngine.FAISS) { throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine)); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 188bbc150..20c456c4a 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -59,27 +59,53 @@ public static Query create(CreateQueryRequest createQueryRequest) { final String fieldName = createQueryRequest.getFieldName(); final int k = createQueryRequest.getK(); final float[] vector = createQueryRequest.getVector(); + final Query filterQuery = getFilterQuery(createQueryRequest); if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { + if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) { + log.debug( + String.format( + "Creating custom k-NN query with filters for index: %s \"\", field: %s \"\", " + "k: %d", + indexName, + fieldName, + k + ) + ); + return new KNNQuery(fieldName, vector, k, indexName, filterQuery); + } log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); return new KNNQuery(fieldName, vector, k, indexName); } + if (filterQuery != null) { + log.debug( + String.format("Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k) + ); + return new KnnFloatVectorQuery(fieldName, vector, k, filterQuery); + } + log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); + return new KnnFloatVectorQuery(fieldName, vector, k); + } + + private static Query getFilterQuery(CreateQueryRequest createQueryRequest) { if (createQueryRequest.getFilter().isPresent()) { final QueryShardContext queryShardContext = createQueryRequest.getContext() .orElseThrow(() -> new RuntimeException("Shard context cannot be null")); log.debug( - String.format("Creating Lucene k-NN query with filter for index [%s], field [%s] and k [%d]", indexName, fieldName, k) + String.format( + "Creating k-NN query with filter for index [%s], field [%s] and k [%d]", + createQueryRequest.getIndexName(), + createQueryRequest.fieldName, + createQueryRequest.k + ) ); try { - final Query filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext); - return new KnnFloatVectorQuery(fieldName, vector, k, filterQuery); + return createQueryRequest.getFilter().get().toQuery(queryShardContext); } catch (IOException e) { throw new RuntimeException("Cannot create knn query with filter", e); } } - log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KnnFloatVectorQuery(fieldName, vector, k); + return null; } /** diff --git a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java index 0005212bf..3e5c8fff6 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java @@ -56,4 +56,37 @@ public float score() { public int docID() { return docIdsIter.docID(); } + + /** + * Returns the Empty Scorer implementation. We use this scorer to short circuit the actual search when it is not + * required. + * @param knnWeight {@link KNNWeight} + * @return {@link KNNScorer} + */ + public static Scorer emptyScorer(KNNWeight knnWeight) { + return new Scorer(knnWeight) { + private final DocIdSetIterator docIdsIter = DocIdSetIterator.empty(); + + @Override + public DocIdSetIterator iterator() { + return docIdsIter; + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return 0; + } + + @Override + public float score() throws IOException { + assert docID() != DocIdSetIterator.NO_MORE_DOCS; + return 0; + } + + @Override + public int docID() { + return docIdsIter.docID(); + } + }; + } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 716aed412..050c36881 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -5,6 +5,12 @@ package org.opensearch.knn.index.query; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.FilteredDocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.jni.JNIService; @@ -13,8 +19,6 @@ import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; import org.opensearch.knn.index.util.KNNEngine; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FilterLeafReader; import org.apache.lucene.index.LeafReaderContext; @@ -49,20 +53,30 @@ /** * Calculate query weights and build query scorers. */ +@Log4j2 public class KNNWeight extends Weight { - private static Logger logger = LogManager.getLogger(KNNWeight.class); private static ModelDao modelDao; private final KNNQuery knnQuery; private final float boost; - private NativeMemoryCacheManager nativeMemoryCacheManager; + private final NativeMemoryCacheManager nativeMemoryCacheManager; + private final Weight filterWeight; public KNNWeight(KNNQuery query, float boost) { super(query); this.knnQuery = query; this.boost = boost; this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); + this.filterWeight = null; + } + + public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { + super(query); + this.knnQuery = query; + this.boost = boost; + this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); + this.filterWeight = filterWeight; } public static void initialize(ModelDao modelDao) { @@ -76,13 +90,20 @@ public Explanation explain(LeafReaderContext context, int doc) { @Override public Scorer scorer(LeafReaderContext context) throws IOException { + final int[] filterIdsArray = getFilterIdsArray(context); + // We don't need to go to JNI layer if no documents are found which satisfy the filters + // We should give this condition a deeper look that where it should be placed. For now I feel this is a good + // place, + if (filterWeight != null && filterIdsArray.length == 0) { + return KNNScorer.emptyScorer(this); + } SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); if (fieldInfo == null) { - logger.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName()); + log.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName()); return null; } @@ -121,7 +142,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { .collect(Collectors.toList()); if (engineFiles.isEmpty()) { - logger.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName()); + log.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName()); return null; } @@ -148,7 +169,6 @@ public Scorer scorer(LeafReaderContext context) throws IOException { // Now that we have the allocation, we need to readLock it indexAllocation.readLock(); - try { if (indexAllocation.isClosed()) { throw new RuntimeException("Index has already been closed"); @@ -158,8 +178,10 @@ public Scorer scorer(LeafReaderContext context) throws IOException { indexAllocation.getMemoryAddress(), knnQuery.getQueryVector(), knnQuery.getK(), - knnEngine.getName() + knnEngine.getName(), + filterIdsArray ); + } catch (Exception e) { GRAPH_QUERY_ERRORS.increment(); throw new RuntimeException(e); @@ -174,7 +196,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { * neighbors we are inverting the scores. */ if (results.length == 0) { - logger.debug("[KNN] Query yielded 0 results"); + log.debug("[KNN] Query yielded 0 results"); return null; } @@ -191,6 +213,59 @@ public Scorer scorer(LeafReaderContext context) throws IOException { return new KNNScorer(this, docIdSetIter, scores, boost); } + private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx, final Weight filterWeight) throws IOException { + final Bits liveDocs = ctx.reader().getLiveDocs(); + final int maxDoc = ctx.reader().maxDoc(); + + final Scorer scorer = filterWeight.scorer(ctx); + if (scorer == null) { + return new FixedBitSet(0); + } + + final BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc); + // TODO: Based on this cost shift to exact search, because even in ANN search you have to calculate the + // distance for K vectors. This can avoid calls to native layer and save some latency. + final int cost = acceptDocs.cardinality(); + log.debug("Number of docs valid for filter is = Cost for filtered k-nn is : {}", cost); + return acceptDocs; + } + + private BitSet createBitSet(final DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException { + if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) { + // If we already have a BitSet and no deletions, reuse the BitSet + return ((BitSetIterator) filteredDocIdsIterator).getBitSet(); + } + // Create a new BitSet from matching and live docs + FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(filteredDocIdsIterator) { + @Override + protected boolean match(int doc) { + return liveDocs == null || liveDocs.get(doc); + } + }; + return BitSet.of(filterIterator, maxDoc); + } + + private int[] getFilterIdsArray(final LeafReaderContext context) throws IOException { + if (filterWeight == null) { + return new int[0]; + } + final BitSet filteredDocsBitSet = getFilteredDocsBitSet(context, this.filterWeight); + final int[] filteredIds = new int[filteredDocsBitSet.cardinality()]; + int filteredIdsIndex = 0; + int docId = 0; + while (true) { + docId = filteredDocsBitSet.nextSetBit(docId); + if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) { + break; + } + log.debug("Docs in filtered docs id set is : {}", docId); + filteredIds[filteredIdsIndex] = docId; + filteredIdsIndex++; + docId++; + } + return filteredIds; + } + @Override public boolean isCacheable(LeafReaderContext context) { return true; diff --git a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java index fe28de43e..776ea5366 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java @@ -32,6 +32,7 @@ public enum KNNEngine implements KNNLibrary { public static final KNNEngine DEFAULT = NMSLIB; private static final Set CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS); + private static final Set ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); private static Map MAX_DIMENSIONS_BY_ENGINE = Map.of( KNNEngine.NMSLIB, @@ -105,6 +106,10 @@ public static Set getEnginesThatCreateCustomSegmentFiles() { return CUSTOM_SEGMENT_FILE_ENGINES; } + public static Set getEnginesThatSupportsFilters() { + return ENGINES_SUPPORTING_FILTERS; + } + /** * Return number of max allowed dimensions per single vector based on the knn engine * @param knnEngine knn engine to check max dimensions value diff --git a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java index b990ce33b..ba1d3ac84 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java @@ -122,6 +122,6 @@ public interface KNNLibrary { * @return list of file extensions that will be read/write with mmap */ default List mmapFileExtensions() { - return Collections.EMPTY_LIST; + return Collections.emptyList(); } } diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index f1d869bd2..5dce15d6e 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -24,7 +24,7 @@ * * In order to compile C++ header file, run: * javac -h jni/include src/main/java/org/opensearch/knn/jni/FaissService.java - * src/main/java/org/opensearch/knn/index/KNNQueryResult.java + * src/main/java/org/opensearch/knn/index/query/KNNQueryResult.java * src/main/java/org/opensearch/knn/common/KNNConstants.java */ class FaissService { @@ -83,6 +83,8 @@ public static native void createIndexFromTemplate( */ public static native KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k); + public static native KNNQueryResult[] queryIndexWithFilter(long indexPointer, float[] queryVector, int k, int[] filterIds); + /** * Free native memory pointer */ diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index e32880fff..f45fb0c73 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -11,6 +11,7 @@ package org.opensearch.knn.jni; +import org.apache.commons.lang.ArrayUtils; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.util.KNNEngine; @@ -94,20 +95,27 @@ public static long loadIndex(String indexPath, Map parameters, S * Query an index * * @param indexPointer pointer to index in memory - * @param queryVector vector to be used for query - * @param k neighbors to be returned - * @param engineName name of engine to query index + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param engineName name of engine to query index + * @param filteredIds array of ints on which should be used for search. * @return KNNQueryResult array of k neighbors */ - public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, String engineName) { + public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, String engineName, int[] filteredIds) { if (KNNEngine.NMSLIB.getName().equals(engineName)) { return NmslibService.queryIndex(indexPointer, queryVector, k); } if (KNNEngine.FAISS.getName().equals(engineName)) { + // This code assumes that if filteredIds == null / filteredIds.length == 0 if filter is specified then empty + // k-NN results are already returned. Otherwise, it's a filter case and we need to run search with + // filterIds. FilterIds is coming as empty then its the case where we need to do search with Faiss engine + // normally. + if (ArrayUtils.isNotEmpty(filteredIds)) { + return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds); + } return FaissService.queryIndex(indexPointer, queryVector, k); } - throw new IllegalArgumentException("QueryIndex not supported for provided engine"); } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestGetModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestGetModelHandler.java index 09f2daab2..8b1f0676b 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestGetModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestGetModelHandler.java @@ -12,8 +12,8 @@ package org.opensearch.knn.plugin.rest; import com.google.common.collect.ImmutableList; +import org.apache.commons.lang.StringUtils; import org.opensearch.client.node.NodeClient; -import org.opensearch.common.Strings; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.GetModelAction; import org.opensearch.knn.plugin.transport.GetModelRequest; @@ -50,7 +50,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { String modelID = restRequest.param(MODEL_ID); - if (!Strings.hasText(modelID)) { + if (StringUtils.isBlank(modelID)) { throw new IllegalArgumentException("model ID cannot be empty"); } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestKNNStatsHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestKNNStatsHandler.java index 3536b40fe..9049a83db 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestKNNStatsHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestKNNStatsHandler.java @@ -6,12 +6,12 @@ package org.opensearch.knn.plugin.rest; import lombok.AllArgsConstructor; +import org.apache.commons.lang.StringUtils; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.KNNStatsAction; import org.opensearch.knn.plugin.transport.KNNStatsRequest; import com.google.common.collect.ImmutableList; import org.opensearch.client.node.NodeClient; -import org.opensearch.common.Strings; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestActions; @@ -83,7 +83,7 @@ private KNNStatsRequest getRequest(RestRequest request) { // parse the nodes the user wants to query String[] nodeIdsArr = null; String nodesIdsStr = request.param("nodeId"); - if (!Strings.isEmpty(nodesIdsStr)) { + if (StringUtils.isNotEmpty(nodesIdsStr)) { nodeIdsArr = nodesIdsStr.split(","); } @@ -93,7 +93,7 @@ private KNNStatsRequest getRequest(RestRequest request) { // parse the stats the customer wants to see Set statsSet = null; String statsStr = request.param("stat"); - if (!Strings.isEmpty(statsStr)) { + if (StringUtils.isNotEmpty(statsStr)) { statsSet = new HashSet<>(Arrays.asList(statsStr.split(","))); } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestKNNWarmupHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestKNNWarmupHandler.java index f457d6782..a31c2f297 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestKNNWarmupHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestKNNWarmupHandler.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.rest; +import org.apache.commons.lang.StringUtils; import org.opensearch.knn.common.exception.KNNInvalidIndicesException; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.KNNWarmupAction; @@ -15,7 +16,6 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.Strings; import org.opensearch.common.settings.Settings; import org.opensearch.index.Index; import org.opensearch.rest.BaseRestHandler; @@ -81,7 +81,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } private KNNWarmupRequest createKNNWarmupRequest(RestRequest request) { - String[] indexNames = Strings.splitStringByCommaToArray(request.param("index")); + String[] indexNames = StringUtils.split(request.param("index"), ","); Index[] indices = indexNameExpressionResolver.concreteIndices(clusterService.state(), strictExpandOpen(), indexNames); List invalidIndexNames = new ArrayList<>(); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java index 792ccc543..fee82adb5 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java @@ -11,9 +11,9 @@ package org.opensearch.knn.plugin.transport; +import org.apache.commons.lang.StringUtils; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.common.Strings; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -43,7 +43,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public ActionRequestValidationException validate() { - if (Strings.hasText(modelID)) { + if (StringUtils.isNotBlank(modelID)) { return null; } return addValidationError("Model id cannot be empty ", null); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java index 774029c58..b692e4a86 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java @@ -11,6 +11,7 @@ package org.opensearch.knn.plugin.transport; +import org.apache.commons.lang.StringUtils; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.search.SearchRequest; @@ -19,7 +20,6 @@ import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.Strings; import org.opensearch.common.ValidationException; import org.opensearch.common.collect.ImmutableOpenMap; import org.opensearch.common.inject.Inject; @@ -107,7 +107,7 @@ protected DiscoveryNode selectNode(String preferredNode, TrainingJobRouteDecisio if (response.getTrainingJobCount() < 1) { selectedNode = currentNode; // Return right away if the user didnt pass a preferred node or this is the preferred node - if (Strings.isEmpty(preferredNode) || selectedNode.getId().equals(preferredNode)) { + if (StringUtils.isEmpty(preferredNode) || selectedNode.getId().equals(preferredNode)) { return selectedNode; } } diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index e01688ddc..27a1c6025 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -11,9 +11,9 @@ package org.opensearch.knn.training; +import org.apache.commons.lang.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.common.Strings; import org.opensearch.common.UUIDs; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; @@ -68,7 +68,7 @@ public TrainingJob( String description ) { // Generate random base64 string if one is not provided - this.modelId = Strings.hasText(modelId) ? modelId : UUIDs.randomBase64UUID(); + this.modelId = StringUtils.isNotBlank(modelId) ? modelId : UUIDs.randomBase64UUID(); this.knnMethodContext = Objects.requireNonNull(knnMethodContext, "MethodContext cannot be null."); this.nativeMemoryCacheManager = Objects.requireNonNull(nativeMemoryCacheManager, "NativeMemoryCacheManager cannot be null."); this.trainingDataEntryContext = Objects.requireNonNull(trainingDataEntryContext, "TrainingDataEntryContext cannot be null."); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index 1e8c255f2..ad0cd37a0 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -333,7 +333,7 @@ public static void assertLoadableByEngine( ); int k = 2; float[] queryVector = new float[dimension]; - KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName()); + KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName(), null); assertTrue(results.length > 0); JNIService.free(indexPtr, knnEngine.getName()); } diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 8d94b1afb..ce08e0350 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -74,7 +74,7 @@ public void testIndexLoadStrategy_load() throws IOException { // Confirm that the file was loaded by querying float[] query = new float[dimension]; Arrays.fill(query, numVectors + 1); - KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName()); + KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName(), null); assertTrue(results.length > 0); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 444f763a6..ec8675ab0 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -133,7 +133,8 @@ public void testQueryScoreForFaissWithModel() throws IOException { SpaceType spaceType = SpaceType.L2; final Function scoreTranslator = spaceType::scoreTranslation; final String modelId = "modelId"; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString())).thenReturn(getKNNQueryResults()); + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any())) + .thenReturn(getKNNQueryResults()); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); @@ -272,7 +273,8 @@ public void testShardWithoutFiles() { @SneakyThrows public void testEmptyQueryResults() { final KNNQueryResult[] knnQueryResults = new KNNQueryResult[] {}; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString())).thenReturn(knnQueryResults); + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any())) + .thenReturn(knnQueryResults); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); final KNNWeight knnWeight = new KNNWeight(query, 0.0f); @@ -316,7 +318,8 @@ private void testQueryScore( final Set segmentFiles, final Map fileAttributes ) throws IOException { - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString())).thenReturn(getKNNQueryResults()); + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any())) + .thenReturn(getKNNQueryResults()); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); final KNNWeight knnWeight = new KNNWeight(query, 0.0f); diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index f4971e6fd..39f7384ad 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -583,12 +583,12 @@ public void testLoadIndex_faiss_valid() throws IOException { } public void testQueryIndex_invalidEngine() { - expectThrows(IllegalArgumentException.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, "invalid-engine")); + expectThrows(IllegalArgumentException.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, "invalid" + "-engine", null)); } public void testQueryIndex_nmslib_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB.getName())); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB.getName(), null)); } public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { @@ -611,7 +611,7 @@ public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { ); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB.getName())); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB.getName(), null)); } public void testQueryIndex_nmslib_valid() throws IOException { @@ -637,7 +637,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.NMSLIB.getName()); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.NMSLIB.getName(), null); assertEquals(k, results.length); } } @@ -645,7 +645,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { public void testQueryIndex_faiss_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, FAISS_NAME)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, FAISS_NAME, null)); } public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { @@ -664,7 +664,7 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), FAISS_NAME); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, FAISS_NAME)); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, FAISS_NAME, null)); } public void testQueryIndex_faiss_valid() throws IOException { @@ -693,7 +693,7 @@ public void testQueryIndex_faiss_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null); assertEquals(k, results.length); } }