From 009f5be93e4be7799ce6bceec229ba34aff7e8ed Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Tue, 2 Jan 2024 15:17:39 -0800 Subject: [PATCH] Multi vector support for Faiss HNSW Apply the parentId filter to the Faiss HNSW search method. This ensures that documents are deduplicated based on their parentId, and the method returns k results for documents with nested fields. Signed-off-by: Heemin Kim --- .../knn_extension/faiss/utils/BitSet.h | 12 +- jni/include/knn_extension/faiss/utils/Heap.h | 4 +- jni/src/knn_extension/faiss/utils/BitSet.cpp | 18 +-- .../opensearch/knn/index/query/KNNQuery.java | 9 +- .../knn/index/query/KNNQueryFactory.java | 9 +- .../opensearch/knn/index/query/KNNWeight.java | 55 +++++---- .../org/opensearch/knn/jni/JNIService.java | 13 ++- .../opensearch/knn/index/NestedSearchIT.java | 26 ++++- .../knn/index/codec/KNNCodecTestCase.java | 17 ++- .../knn/index/codec/KNNCodecTestUtil.java | 2 +- .../memory/NativeMemoryLoadStrategyTests.java | 2 +- .../knn/index/query/KNNQueryFactoryTests.java | 24 ++++ .../knn/index/query/KNNWeightTests.java | 110 +++++++++++++++--- .../opensearch/knn/jni/JNIServiceTests.java | 16 +-- 14 files changed, 239 insertions(+), 78 deletions(-) diff --git a/jni/include/knn_extension/faiss/utils/BitSet.h b/jni/include/knn_extension/faiss/utils/BitSet.h index 0b481d578d..deb97ecdfa 100644 --- a/jni/include/knn_extension/faiss/utils/BitSet.h +++ b/jni/include/knn_extension/faiss/utils/BitSet.h @@ -32,18 +32,18 @@ struct BitSet { * bitmap: 10001000 00000100 * * for next_set_bit call with 4 - * 1. it looks for bitmap[0] - * 2. bitmap[0] >> 4 + * 1. it looks for words[0] + * 2. words[0] >> 4 * 3. count trailing zero of the result from step 2 which is 3 * 4. return 4(current index) + 3(result from step 3) */ struct FixedBitSet : public BitSet { - // Length of bitmap - size_t numBits; + // The exact number of longs needed to hold numBits (<= bits.length) + size_t num_words; - // Pointer to an array of uint64_t + // Array of uint64_t holding the bits // Using uint64_t to leverage function __builtin_ctzll which is defined in faiss/impl/platform_macros.h - uint64_t* bitmap; + uint64_t* words; FixedBitSet(const int* int_array, const int length); idx_t next_set_bit(idx_t index) const; diff --git a/jni/include/knn_extension/faiss/utils/Heap.h b/jni/include/knn_extension/faiss/utils/Heap.h index 08d9823119..274626eb2f 100644 --- a/jni/include/knn_extension/faiss/utils/Heap.h +++ b/jni/include/knn_extension/faiss/utils/Heap.h @@ -152,7 +152,7 @@ inline void maxheap_push( std::unordered_map* parent_id_to_index, int64_t parent_id) { - assert(parent_id_to_index->find(parent_id) != parent_id_to_index->end() && "parent id should not exist in the binary heap"); + assert(parent_id_to_index->find(parent_id) == parent_id_to_index->end() && "parent id should not exist in the binary heap"); up_heap>( bh_val, @@ -189,7 +189,7 @@ inline void maxheap_replace_top( std::unordered_map* parent_id_to_index, int64_t parent_id) { - assert(parent_id_to_index->find(parent_id) != parent_id_to_index->end() && "parent id should not exist in the binary heap"); + assert(parent_id_to_index->find(parent_id) == parent_id_to_index->end() && "parent id should not exist in the binary heap"); parent_id_to_id->erase(bh_ids[0]); parent_id_to_index->erase(bh_ids[0]); diff --git a/jni/src/knn_extension/faiss/utils/BitSet.cpp b/jni/src/knn_extension/faiss/utils/BitSet.cpp index 90cd7d1f0d..19b774180a 100644 --- a/jni/src/knn_extension/faiss/utils/BitSet.cpp +++ b/jni/src/knn_extension/faiss/utils/BitSet.cpp @@ -10,25 +10,27 @@ FixedBitSet::FixedBitSet(const int* int_array, const int length){ assert(int_array && "int_array should not be null"); const int* maxValue = std::max_element(int_array, int_array + length); - this->numBits = (*maxValue >> 6) + 1; // div by 64 - this->bitmap = new uint64_t[this->numBits](); + this->num_words = (*maxValue >> 6) + 1; // div by 64 + this->words = new uint64_t[this->num_words](); for(int i = 0 ; i < length ; i ++) { int value = int_array[i]; - int bitsetArrayIndex = value >> 6; - this->bitmap[bitsetArrayIndex] |= 1ULL << (value & 63); // Equivalent of 1ULL << (value % 64) + int bitset_array_index = value >> 6; + this->words[bitset_array_index] |= 1ULL << (value & 63); // Equivalent of 1ULL << (value % 64) } } idx_t FixedBitSet::next_set_bit(idx_t index) const { idx_t i = index >> 6; // div by 64 - uint64_t word = this->bitmap[i] >> (index & 63); // Equivalent of bitmap[i] >> (index % 64) + uint64_t word = this->words[i] >> (index & 63); // Equivalent of words[i] >> (index % 64) + // word is non zero after right shift, it means, next set bit is in current word + // The index of set bit is "given index" + "trailing zero in the right shifted word" if (word != 0) { return index + __builtin_ctzll(word); } - while (++i < this->numBits) { - word = this->bitmap[i]; + while (++i < this->num_words) { + word = this->words[i]; if (word != 0) { return (i << 6) + __builtin_ctzll(word); } @@ -38,5 +40,5 @@ idx_t FixedBitSet::next_set_bit(idx_t index) const { } FixedBitSet::~FixedBitSet() { - delete this->bitmap; + delete this->words; } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 5ac207c431..20db72a5e1 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -15,6 +15,7 @@ import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Weight; +import org.apache.lucene.search.join.BitSetProducer; import org.opensearch.knn.index.KNNSettings; import java.io.IOException; @@ -33,20 +34,24 @@ public class KNNQuery extends Query { @Getter @Setter private Query filterQuery; + @Getter + private BitSetProducer parentsFilter; - public KNNQuery(String field, float[] queryVector, int k, String indexName) { + public KNNQuery(String field, float[] queryVector, int k, String indexName, final BitSetProducer parentsFilter) { this.field = field; this.queryVector = queryVector; this.k = k; this.indexName = indexName; + this.parentsFilter = parentsFilter; } - public KNNQuery(String field, float[] queryVector, int k, String indexName, Query filterQuery) { + public KNNQuery(String field, float[] queryVector, int k, String indexName, Query filterQuery, BitSetProducer parentsFilter) { this.field = field; this.queryVector = queryVector; this.k = k; this.indexName = indexName; this.filterQuery = filterQuery; + this.parentsFilter = parentsFilter; } public String getField() { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index c073450af4..65187dcd2e 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -80,17 +80,17 @@ public static Query create(CreateQueryRequest createQueryRequest) { final VectorDataType vectorDataType = createQueryRequest.getVectorDataType(); final Query filterQuery = getFilterQuery(createQueryRequest); + BitSetProducer parentFilter = createQueryRequest.context == null ? null : createQueryRequest.context.getParentFilter(); if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) { log.debug("Creating custom k-NN query with filters for index: {}, field: {} , k: {}", indexName, fieldName, k); - return new KNNQuery(fieldName, vector, k, indexName, filterQuery); + return new KNNQuery(fieldName, vector, k, indexName, filterQuery, parentFilter); } log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KNNQuery(fieldName, vector, k, indexName); + return new KNNQuery(fieldName, vector, k, indexName, parentFilter); } log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - BitSetProducer parentFilter = createQueryRequest.context == null ? null : createQueryRequest.context.getParentFilter(); if (VectorDataType.BYTE == vectorDataType) { return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter); } else if (VectorDataType.FLOAT == vectorDataType) { @@ -187,9 +187,6 @@ static class CreateQueryRequest { private VectorDataType vectorDataType; @Getter private int k; - // can be null in cases filter not passed with the knn query - @Getter - private BitSetProducer parentFilter; private QueryBuilder filter; // can be null in cases filter not passed with the knn query private QueryShardContext context; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index c166be3c27..c20d1efc76 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -12,6 +12,7 @@ import org.apache.lucene.search.FilteredDocIdSetIterator; import org.apache.lucene.search.HitQueue; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; @@ -117,9 +118,9 @@ public Scorer scorer(LeafReaderContext context) throws IOException { * This improves the recall. */ if (filterWeight != null && canDoExactSearch(filterIdsArray.length)) { - docIdsToScoreMap.putAll(doExactSearch(context, filterIdsArray)); + docIdsToScoreMap.putAll(doExactSearch(context, filterIdsArray, knnQuery.getParentsFilter())); } else { - Map annResults = doANNSearch(context, filterIdsArray); + Map annResults = doANNSearch(context, filterIdsArray, knnQuery.getParentsFilter()); if (annResults == null) { return null; } @@ -131,7 +132,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { annResults.size(), filterIdsArray.length ); - annResults = doExactSearch(context, filterIdsArray); + annResults = doExactSearch(context, filterIdsArray, knnQuery.getParentsFilter()); } docIdsToScoreMap.putAll(annResults); } @@ -172,23 +173,31 @@ private int[] getFilterIdsArray(final LeafReaderContext context) throws IOExcept if (filterWeight == null) { return new int[0]; } - final BitSet filteredDocsBitSet = getFilteredDocsBitSet(context, this.filterWeight); - final int[] filteredIds = new int[filteredDocsBitSet.cardinality()]; - int filteredIdsIndex = 0; - int docId = 0; - while (docId < filteredDocsBitSet.length()) { - docId = filteredDocsBitSet.nextSetBit(docId); - if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) { - break; - } - filteredIds[filteredIdsIndex] = docId; - filteredIdsIndex++; - docId++; + return bitSetToIntArray(getFilteredDocsBitSet(context, this.filterWeight)); + } + + private int[] getParentIdsArray(final LeafReaderContext context, final BitSetProducer parentFilter) throws IOException { + if (parentFilter == null) { + return null; } - return filteredIds; + return bitSetToIntArray(parentFilter.getBitSet(context)); } - private Map doANNSearch(final LeafReaderContext context, final int[] filterIdsArray) throws IOException { + private int[] bitSetToIntArray(final BitSet bitSet) { + final int[] intArray = new int[bitSet.cardinality()]; + final BitSetIterator bitSetIterator = new BitSetIterator(bitSet, bitSet.cardinality()); + int index = 0; + int docId = bitSetIterator.nextDoc(); + while (docId != DocIdSetIterator.NO_MORE_DOCS) { + assert index < intArray.length; + intArray[index++] = docId; + docId = bitSetIterator.nextDoc(); + } + return intArray; + } + + private Map doANNSearch(final LeafReaderContext context, final int[] filterIdsArray, final BitSetProducer parentFilter) + throws IOException { SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); @@ -265,13 +274,14 @@ private Map doANNSearch(final LeafReaderContext context, final i if (indexAllocation.isClosed()) { throw new RuntimeException("Index has already been closed"); } - + int[] parentIds = getParentIdsArray(context, parentFilter); results = JNIService.queryIndex( indexAllocation.getMemoryAddress(), knnQuery.getQueryVector(), knnQuery.getK(), knnEngine.getName(), - filterIdsArray + filterIdsArray, + parentIds ); } catch (Exception e) { @@ -296,9 +306,14 @@ private Map doANNSearch(final LeafReaderContext context, final i .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); } - private Map doExactSearch(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) { + private Map doExactSearch( + final LeafReaderContext leafReaderContext, + final int[] filterIdsArray, + final BitSetProducer parentFilter + ) throws IOException { final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader()); final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + final BitSet parentIds = parentFilter.getBitSet(leafReaderContext); float[] queryVector = this.knnQuery.getQueryVector(); try { final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName()); diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 42d12e984c..beef9f927c 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -101,7 +101,14 @@ public static long loadIndex(String indexPath, Map parameters, S * @param filteredIds array of ints on which should be used for search. * @return KNNQueryResult array of k neighbors */ - public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, String engineName, int[] filteredIds) { + public static KNNQueryResult[] queryIndex( + long indexPointer, + float[] queryVector, + int k, + String engineName, + int[] filteredIds, + int[] parentIds + ) { if (KNNEngine.NMSLIB.getName().equals(engineName)) { return NmslibService.queryIndex(indexPointer, queryVector, k); } @@ -112,9 +119,9 @@ public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector // filterIds. FilterIds is coming as empty then its the case where we need to do search with Faiss engine // normally. if (ArrayUtils.isNotEmpty(filteredIds)) { - return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, null); + return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, parentIds); } - return FaissService.queryIndex(indexPointer, queryVector, k, null); + return FaissService.queryIndex(indexPointer, queryVector, k, parentIds); } throw new IllegalArgumentException("QueryIndex not supported for provided engine"); } diff --git a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java index 4f39c8bba7..107d45cfd3 100644 --- a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java @@ -55,7 +55,7 @@ public final void cleanUp() { } @SneakyThrows - public void testNestedSearch_whenKIsTwo_thenReturnTwoResults() { + public void testNestedSearchWithLucene_whenKIsTwo_thenReturnTwoResults() { createKnnIndex(2, KNNEngine.LUCENE.getName()); String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) @@ -78,6 +78,30 @@ public void testNestedSearch_whenKIsTwo_thenReturnTwoResults() { assertEquals(2, hits.size()); } + @SneakyThrows + public void testNestedSearchWithFaiss_whenKIsTwo_thenReturnTwoResults() { + createKnnIndex(2, KNNEngine.FAISS.getName()); + + String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .add(FIELD_NAME_VECTOR, new Float[] { 1f, 1f }, new Float[] { 1f, 1f }) + .build(); + addNestedKnnDoc(INDEX_NAME, "1", doc1); + + String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .add(FIELD_NAME_VECTOR, new Float[] { 2f, 2f }, new Float[] { 2f, 2f }) + .build(); + addNestedKnnDoc(INDEX_NAME, "2", doc2); + + Float[] queryVector = { 1f, 1f }; + Response response = queryNestedField(INDEX_NAME, 2, queryVector); + + List hits = (List) ((Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + EntityUtils.toString(response.getEntity()) + ).map().get("hits")).get("hits"); + assertEquals(2, hits.size()); + } + /** * { * "properties": { diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 40309027d8..42eb817594 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -13,6 +13,7 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.join.BitSetProducer; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.index.mapper.MapperService; @@ -162,14 +163,20 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception { // query to verify distance for each of the field IndexSearcher searcher = new IndexSearcher(reader); - float score = searcher.search(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy"), 10).scoreDocs[0].score; - float score1 = searcher.search(new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy"), 10).scoreDocs[0].score; + float score = searcher.search( + new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy", (BitSetProducer) null), + 10 + ).scoreDocs[0].score; + float score1 = searcher.search( + new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy", (BitSetProducer) null), + 10 + ).scoreDocs[0].score; assertEquals(1.0f / (1 + 25), score, 0.01f); assertEquals(1.0f / (1 + 169), score1, 0.01f); // query to determine the hits - assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy"))); - assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy"))); + assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy", (BitSetProducer) null))); + assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy", (BitSetProducer) null))); reader.close(); dir.close(); @@ -254,7 +261,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); float[] query = { 10.0f, 10.0f, 10.0f }; IndexSearcher searcher = new IndexSearcher(reader); - TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy"), 10); + TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy", (BitSetProducer) null), 10); assertEquals(3, topDocs.scoreDocs[0].doc); assertEquals(2, topDocs.scoreDocs[1].doc); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index 1a9507e6a1..23a0643ad2 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -333,7 +333,7 @@ public static void assertLoadableByEngine( ); int k = 2; float[] queryVector = new float[dimension]; - KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName(), null); + KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName(), null, null); assertTrue(results.length > 0); JNIService.free(indexPtr, knnEngine.getName()); } diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index fb91266ab9..798c64a17e 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -74,7 +74,7 @@ public void testIndexLoadStrategy_load() throws IOException { // Confirm that the file was loaded by querying float[] query = new float[dimension]; Arrays.fill(query, numVectors + 1); - KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName(), null); + KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName(), null, null); assertTrue(results.length > 0); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index a6b915a852..3b91a0c3c2 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -131,6 +131,30 @@ public void testCreate_whenLuceneWithParentFilter_thenReturnDiversifyingQuery() validateDiversifyingQueryWithParentFilter(VectorDataType.FLOAT, DiversifyingChildrenFloatKnnVectorQuery.class); } + public void testCreate_whenFaissWithParentFilter_thenSuccess() { + final KNNEngine knnEngine = KNNEngine.FAISS; + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + BitSetProducer parentFilter = mock(BitSetProducer.class); + when(mockQueryShardContext.getParentFilter()).thenReturn(parentFilter); + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .k(testK) + .context(mockQueryShardContext) + .build(); + final Query query = KNNQueryFactory.create(createQueryRequest); + assertTrue(query instanceof KNNQuery); + assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); + assertEquals(testFieldName, ((KNNQuery) query).getField()); + assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector()); + assertEquals(testK, ((KNNQuery) query).getK()); + assertEquals(parentFilter, ((KNNQuery) query).getParentsFilter()); + } + private void validateDiversifyingQueryWithParentFilter(final VectorDataType type, final Class expectedQueryClass) { List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index ba6675d3df..267e48cbe2 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -22,9 +22,11 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.Weight; +import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.StringHelper; import org.apache.lucene.util.Version; import org.junit.Before; @@ -48,6 +50,7 @@ import java.io.IOException; import java.nio.file.Path; import java.util.ArrayList; +import java.util.Arrays; import java.util.Comparator; import java.util.List; import java.util.Map; @@ -154,10 +157,10 @@ public void testQueryScoreForFaissWithModel() throws IOException { SpaceType spaceType = SpaceType.L2; final Function scoreTranslator = spaceType::scoreTranslation; final String modelId = "modelId"; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any())) .thenReturn(getKNNQueryResults()); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); ModelDao modelDao = mock(ModelDao.class); ModelMetadata modelMetadata = mock(ModelMetadata.class); @@ -221,7 +224,7 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException { SpaceType spaceType = SpaceType.L2; final String modelId = "modelId"; - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); ModelDao modelDao = mock(ModelDao.class); ModelMetadata modelMetadata = mock(ModelMetadata.class); @@ -253,7 +256,7 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException { @SneakyThrows public void testShardWithoutFiles() { - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); final KNNWeight knnWeight = new KNNWeight(query, 0.0f); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); @@ -294,10 +297,10 @@ public void testShardWithoutFiles() { @SneakyThrows public void testEmptyQueryResults() { final KNNQueryResult[] knnQueryResults = new KNNQueryResult[] {}; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any())) .thenReturn(knnQueryResults); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); final KNNWeight knnWeight = new KNNWeight(query, 0.0f); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); @@ -338,7 +341,7 @@ public void testEmptyQueryResults() { public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { int k = 3; final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds))) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds), any())) .thenReturn(getFilteredKNNQueryResults()); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -351,7 +354,7 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { when(liveDocsBits.length()).thenReturn(1000); when(leafReaderContext.reader()).thenReturn(reader); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null); final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); @@ -395,7 +398,7 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); assertNotNull(docIdSetIterator); assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); - jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds))); + jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds), any())); final List actualDocIds = new ArrayList<>(); final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); @@ -415,7 +418,7 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() { final SegmentReader reader = mock(SegmentReader.class); when(leafReaderContext.reader()).thenReturn(reader); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null); final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); @@ -465,7 +468,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS final SegmentReader reader = mock(SegmentReader.class); when(leafReaderContext.reader()).thenReturn(reader); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null); final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); @@ -534,7 +537,7 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length)); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null); final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name()); @@ -577,7 +580,7 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() { when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); when(filterScorer.iterator()).thenReturn(DocIdSetIterator.empty()); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null); final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); final FieldInfos fieldInfos = mock(FieldInfos.class); @@ -593,15 +596,92 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() { assertEquals(0, docIdSetIterator.cost()); } + @SneakyThrows + public void testANNWithParentsFilter_whenSet_thenBitSetIsPassedToJNI() { + SegmentReader reader = getMockedSegmentReader(); + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + when(leafReaderContext.reader()).thenReturn(reader); + + // Prepare parentFilter + final int[] parentsFilter = { 10, 64 }; + final FixedBitSet bitset = new FixedBitSet(65); + Arrays.stream(parentsFilter).forEach(i -> bitset.set(i)); + final BitSetProducer bitSetProducer = mock(BitSetProducer.class); + + // Prepare query and weight + when(bitSetProducer.getBitSet(leafReaderContext)).thenReturn(bitset); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, 1, INDEX_NAME, null, bitSetProducer); + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, null); + + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), eq(parentsFilter))) + .thenReturn(getKNNQueryResults()); + + // Execute + Scorer knnScorer = knnWeight.scorer(leafReaderContext); + + // Verify + jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), eq(parentsFilter))); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + } + + private SegmentReader getMockedSegmentReader() { + final SegmentReader reader = mock(SegmentReader.class); + when(reader.maxDoc()).thenReturn(1); + + // Prepare live docs + int liveDocId = 0; + final Bits liveDocsBits = mock(Bits.class); + when(liveDocsBits.get(liveDocId)).thenReturn(true); + when(liveDocsBits.length()).thenReturn(1); + when(reader.getLiveDocs()).thenReturn(liveDocsBits); + + // Prepare directory + final Path path = mock(Path.class); + final FSDirectory directory = mock(FSDirectory.class); + when(directory.getDirectory()).thenReturn(path); + when(reader.directory()).thenReturn(directory); + + // Prepare segment + final SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + true, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(SEGMENT_FILES_FAISS); + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + // Prepare fieldInfos + final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName()); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(fieldInfo.attributes()).thenReturn(attributesMap); + final FieldInfos fieldInfos = mock(FieldInfos.class); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + + return reader; + } + private void testQueryScore( final Function scoreTranslator, final Set segmentFiles, final Map fileAttributes ) throws IOException { - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any())) .thenReturn(getKNNQueryResults()); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); final KNNWeight knnWeight = new KNNWeight(query, 0.0f); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index 185f2953d5..04e810cf6c 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -590,12 +590,12 @@ public void testLoadIndex_faiss_valid() throws IOException { } public void testQueryIndex_invalidEngine() { - expectThrows(IllegalArgumentException.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, "invalid" + "-engine", null)); + expectThrows(IllegalArgumentException.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, "invalid" + "-engine", null, null)); } public void testQueryIndex_nmslib_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB.getName(), null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB.getName(), null, null)); } public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { @@ -618,7 +618,7 @@ public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { ); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB.getName(), null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB.getName(), null, null)); } public void testQueryIndex_nmslib_valid() throws IOException { @@ -644,7 +644,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.NMSLIB.getName(), null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.NMSLIB.getName(), null, null); assertEquals(k, results.length); } } @@ -652,7 +652,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { public void testQueryIndex_faiss_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, FAISS_NAME, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, FAISS_NAME, null, null)); } public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { @@ -671,7 +671,7 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), FAISS_NAME); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, FAISS_NAME, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, FAISS_NAME, null, null)); } public void testQueryIndex_faiss_valid() throws IOException { @@ -700,13 +700,13 @@ public void testQueryIndex_faiss_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null, null); assertEquals(k, results.length); } // Filter will result in no ids for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, new int[] { 0 }); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, new int[] { 0 }, null); assertEquals(0, results.length); } }