From e7a5a42ebcceacbb2745c18a933bdde5de028ffd Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Tue, 2 Jan 2024 15:17:39 -0800 Subject: [PATCH] Multi vector support for Faiss HNSW Pass parentId filter to faiss HNSW search method so that documents will be deduped on its parentId and k results will be returned for documents with nested fields. Signed-off-by: Heemin Kim --- .../opensearch/knn/index/query/KNNQuery.java | 9 +++-- .../knn/index/query/KNNQueryFactory.java | 6 ++-- .../opensearch/knn/index/query/KNNWeight.java | 36 ++++++++++++------- .../org/opensearch/knn/jni/JNIService.java | 7 ++-- .../knn/index/codec/KNNCodecTestCase.java | 11 +++--- .../knn/index/codec/KNNCodecTestUtil.java | 2 +- .../memory/NativeMemoryLoadStrategyTests.java | 2 +- .../knn/index/query/KNNWeightTests.java | 31 ++++++++-------- .../opensearch/knn/jni/JNIServiceTests.java | 16 ++++----- 9 files changed, 70 insertions(+), 50 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 5ac207c43..33a570284 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -15,6 +15,7 @@ import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Weight; +import org.apache.lucene.search.join.BitSetProducer; import org.opensearch.knn.index.KNNSettings; import java.io.IOException; @@ -33,20 +34,24 @@ public class KNNQuery extends Query { @Getter @Setter private Query filterQuery; + @Getter + private BitSetProducer bitSetProducer; - public KNNQuery(String field, float[] queryVector, int k, String indexName) { + public KNNQuery(String field, float[] queryVector, int k, String indexName, final BitSetProducer bitSetProducer) { this.field = field; this.queryVector = queryVector; this.k = k; this.indexName = indexName; + this.bitSetProducer = bitSetProducer; } - public KNNQuery(String field, float[] queryVector, int k, String indexName, Query filterQuery) { + public KNNQuery(String field, float[] queryVector, int k, String indexName, Query filterQuery, BitSetProducer bitSetProducer) { this.field = field; this.queryVector = queryVector; this.k = k; this.indexName = indexName; this.filterQuery = filterQuery; + this.bitSetProducer = bitSetProducer; } public String getField() { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index c073450af..741f917a7 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -80,17 +80,17 @@ public static Query create(CreateQueryRequest createQueryRequest) { final VectorDataType vectorDataType = createQueryRequest.getVectorDataType(); final Query filterQuery = getFilterQuery(createQueryRequest); + BitSetProducer parentFilter = createQueryRequest.context == null ? null : createQueryRequest.context.getParentFilter(); if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) { log.debug("Creating custom k-NN query with filters for index: {}, field: {} , k: {}", indexName, fieldName, k); - return new KNNQuery(fieldName, vector, k, indexName, filterQuery); + return new KNNQuery(fieldName, vector, k, indexName, filterQuery, parentFilter); } log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KNNQuery(fieldName, vector, k, indexName); + return new KNNQuery(fieldName, vector, k, indexName, parentFilter); } log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - BitSetProducer parentFilter = createQueryRequest.context == null ? null : createQueryRequest.context.getParentFilter(); if (VectorDataType.BYTE == vectorDataType) { return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter); } else if (VectorDataType.FLOAT == vectorDataType) { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index c166be3c2..46540481a 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -12,6 +12,7 @@ import org.apache.lucene.search.FilteredDocIdSetIterator; import org.apache.lucene.search.HitQueue; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; @@ -119,7 +120,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { if (filterWeight != null && canDoExactSearch(filterIdsArray.length)) { docIdsToScoreMap.putAll(doExactSearch(context, filterIdsArray)); } else { - Map annResults = doANNSearch(context, filterIdsArray); + Map annResults = doANNSearch(context, filterIdsArray, knnQuery.getBitSetProducer()); if (annResults == null) { return null; } @@ -172,23 +173,33 @@ private int[] getFilterIdsArray(final LeafReaderContext context) throws IOExcept if (filterWeight == null) { return new int[0]; } - final BitSet filteredDocsBitSet = getFilteredDocsBitSet(context, this.filterWeight); - final int[] filteredIds = new int[filteredDocsBitSet.cardinality()]; - int filteredIdsIndex = 0; + return bitSetToIntArray(getFilteredDocsBitSet(context, this.filterWeight)); + } + + private int[] getParentIdsArray(final LeafReaderContext context, final BitSetProducer parentFilter) throws IOException { + if (parentFilter == null) { + return null; + } + return bitSetToIntArray(parentFilter.getBitSet(context)); + } + + private int[] bitSetToIntArray(final BitSet bitSet) { + final int[] intArray = new int[bitSet.cardinality()]; + int index = 0; int docId = 0; - while (docId < filteredDocsBitSet.length()) { - docId = filteredDocsBitSet.nextSetBit(docId); + while (docId < bitSet.length()) { + docId = bitSet.nextSetBit(docId); if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) { break; } - filteredIds[filteredIdsIndex] = docId; - filteredIdsIndex++; + intArray[index] = docId; + index++; docId++; } - return filteredIds; + return intArray; } - private Map doANNSearch(final LeafReaderContext context, final int[] filterIdsArray) throws IOException { + private Map doANNSearch(final LeafReaderContext context, final int[] filterIdsArray, final BitSetProducer parentFilter) throws IOException { SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); @@ -265,13 +276,14 @@ private Map doANNSearch(final LeafReaderContext context, final i if (indexAllocation.isClosed()) { throw new RuntimeException("Index has already been closed"); } - + int[] parentIds = getParentIdsArray(context, parentFilter); results = JNIService.queryIndex( indexAllocation.getMemoryAddress(), knnQuery.getQueryVector(), knnQuery.getK(), knnEngine.getName(), - filterIdsArray + filterIdsArray, + parentIds ); } catch (Exception e) { diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 42d12e984..8af568dbf 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -12,6 +12,7 @@ package org.opensearch.knn.jni; import org.apache.commons.lang.ArrayUtils; +import org.apache.lucene.util.BitSet; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.util.KNNEngine; @@ -101,7 +102,7 @@ public static long loadIndex(String indexPath, Map parameters, S * @param filteredIds array of ints on which should be used for search. * @return KNNQueryResult array of k neighbors */ - public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, String engineName, int[] filteredIds) { + public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, String engineName, int[] filteredIds, int[] parentIds) { if (KNNEngine.NMSLIB.getName().equals(engineName)) { return NmslibService.queryIndex(indexPointer, queryVector, k); } @@ -112,9 +113,9 @@ public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector // filterIds. FilterIds is coming as empty then its the case where we need to do search with Faiss engine // normally. if (ArrayUtils.isNotEmpty(filteredIds)) { - return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, null); + return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, parentIds); } - return FaissService.queryIndex(indexPointer, queryVector, k, null); + return FaissService.queryIndex(indexPointer, queryVector, k, parentIds); } throw new IllegalArgumentException("QueryIndex not supported for provided engine"); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 40309027d..5096a3dce 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -13,6 +13,7 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.join.BitSetProducer; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.index.mapper.MapperService; @@ -162,14 +163,14 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception { // query to verify distance for each of the field IndexSearcher searcher = new IndexSearcher(reader); - float score = searcher.search(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy"), 10).scoreDocs[0].score; - float score1 = searcher.search(new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy"), 10).scoreDocs[0].score; + float score = searcher.search(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy", (BitSetProducer) null), 10).scoreDocs[0].score; + float score1 = searcher.search(new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy", (BitSetProducer) null), 10).scoreDocs[0].score; assertEquals(1.0f / (1 + 25), score, 0.01f); assertEquals(1.0f / (1 + 169), score1, 0.01f); // query to determine the hits - assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy"))); - assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy"))); + assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy", (BitSetProducer) null))); + assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy", (BitSetProducer) null))); reader.close(); dir.close(); @@ -254,7 +255,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); float[] query = { 10.0f, 10.0f, 10.0f }; IndexSearcher searcher = new IndexSearcher(reader); - TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy"), 10); + TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy", (BitSetProducer) null), 10); assertEquals(3, topDocs.scoreDocs[0].doc); assertEquals(2, topDocs.scoreDocs[1].doc); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index 1a9507e6a..23a0643ad 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -333,7 +333,7 @@ public static void assertLoadableByEngine( ); int k = 2; float[] queryVector = new float[dimension]; - KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName(), null); + KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName(), null, null); assertTrue(results.length > 0); JNIService.free(indexPtr, knnEngine.getName()); } diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index fb91266ab..798c64a17 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -74,7 +74,7 @@ public void testIndexLoadStrategy_load() throws IOException { // Confirm that the file was loaded by querying float[] query = new float[dimension]; Arrays.fill(query, numVectors + 1); - KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName(), null); + KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName(), null, null); assertTrue(results.length > 0); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index ba6675d3d..5bf8db242 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -22,6 +22,7 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.Weight; +import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; @@ -154,10 +155,10 @@ public void testQueryScoreForFaissWithModel() throws IOException { SpaceType spaceType = SpaceType.L2; final Function scoreTranslator = spaceType::scoreTranslation; final String modelId = "modelId"; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any())) .thenReturn(getKNNQueryResults()); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); ModelDao modelDao = mock(ModelDao.class); ModelMetadata modelMetadata = mock(ModelMetadata.class); @@ -221,7 +222,7 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException { SpaceType spaceType = SpaceType.L2; final String modelId = "modelId"; - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); ModelDao modelDao = mock(ModelDao.class); ModelMetadata modelMetadata = mock(ModelMetadata.class); @@ -253,7 +254,7 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException { @SneakyThrows public void testShardWithoutFiles() { - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); final KNNWeight knnWeight = new KNNWeight(query, 0.0f); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); @@ -294,10 +295,10 @@ public void testShardWithoutFiles() { @SneakyThrows public void testEmptyQueryResults() { final KNNQueryResult[] knnQueryResults = new KNNQueryResult[] {}; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any())) .thenReturn(knnQueryResults); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); final KNNWeight knnWeight = new KNNWeight(query, 0.0f); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); @@ -338,7 +339,7 @@ public void testEmptyQueryResults() { public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { int k = 3; final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds))) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds), any())) .thenReturn(getFilteredKNNQueryResults()); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -351,7 +352,7 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { when(liveDocsBits.length()).thenReturn(1000); when(leafReaderContext.reader()).thenReturn(reader); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null); final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); @@ -395,7 +396,7 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); assertNotNull(docIdSetIterator); assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); - jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds))); + jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds), any())); final List actualDocIds = new ArrayList<>(); final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); @@ -415,7 +416,7 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() { final SegmentReader reader = mock(SegmentReader.class); when(leafReaderContext.reader()).thenReturn(reader); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null); final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); @@ -465,7 +466,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS final SegmentReader reader = mock(SegmentReader.class); when(leafReaderContext.reader()).thenReturn(reader); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null); final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); @@ -534,7 +535,7 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length)); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null); final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name()); @@ -577,7 +578,7 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() { when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); when(filterScorer.iterator()).thenReturn(DocIdSetIterator.empty()); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null); final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); final FieldInfos fieldInfos = mock(FieldInfos.class); @@ -598,10 +599,10 @@ private void testQueryScore( final Set segmentFiles, final Map fileAttributes ) throws IOException { - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any())) .thenReturn(getKNNQueryResults()); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); final KNNWeight knnWeight = new KNNWeight(query, 0.0f); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index 185f2953d..04e810cf6 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -590,12 +590,12 @@ public void testLoadIndex_faiss_valid() throws IOException { } public void testQueryIndex_invalidEngine() { - expectThrows(IllegalArgumentException.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, "invalid" + "-engine", null)); + expectThrows(IllegalArgumentException.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, "invalid" + "-engine", null, null)); } public void testQueryIndex_nmslib_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB.getName(), null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB.getName(), null, null)); } public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { @@ -618,7 +618,7 @@ public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { ); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB.getName(), null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB.getName(), null, null)); } public void testQueryIndex_nmslib_valid() throws IOException { @@ -644,7 +644,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.NMSLIB.getName(), null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.NMSLIB.getName(), null, null); assertEquals(k, results.length); } } @@ -652,7 +652,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { public void testQueryIndex_faiss_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, FAISS_NAME, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, FAISS_NAME, null, null)); } public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { @@ -671,7 +671,7 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), FAISS_NAME); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, FAISS_NAME, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, FAISS_NAME, null, null)); } public void testQueryIndex_faiss_valid() throws IOException { @@ -700,13 +700,13 @@ public void testQueryIndex_faiss_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null, null); assertEquals(k, results.length); } // Filter will result in no ids for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, new int[] { 0 }); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, new int[] { 0 }, null); assertEquals(0, results.length); } }