From 90f98731a5a7c5420cfa8f20ad9a96dd2ccdb226 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Thu, 12 Dec 2024 11:36:55 -0800 Subject: [PATCH] Support expand_nested_docs parameter for nmslib engine Signed-off-by: Heemin Kim --- CHANGELOG.md | 1 + .../knn/index/engine/KNNEngine.java | 1 - .../knn/index/query/KNNQueryFactory.java | 12 +------ .../GroupedNestedDocIdSetIterator.java | 31 +++++++++++++------ .../GroupedNestedDocIdSetIteratorTests.java | 29 +++++++++++++++++ 5 files changed, 52 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cbd0ef2f..4ae86dba1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x) ### Features - Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283] +- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java index f75c7f1d9..1e560a11b 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java @@ -34,7 +34,6 @@ public enum KNNEngine implements KNNLibrary { private static final Set CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS); private static final Set ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); public static final Set ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); - public static final Set ENGINES_SUPPORTING_MULTI_VECTORS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); private static Map MAX_DIMENSIONS_BY_ENGINE = Map.of( KNNEngine.NMSLIB, diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index d01a9aff6..0c1efef88 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -26,7 +26,6 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; -import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_MULTI_VECTORS; /** * Creates the Lucene k-NN queries @@ -50,7 +49,6 @@ public static Query create(CreateQueryRequest createQueryRequest) { final Query filterQuery = getFilterQuery(createQueryRequest); final Map methodParameters = createQueryRequest.getMethodParameters(); final RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null); - final KNNEngine knnEngine = createQueryRequest.getKnnEngine(); final boolean expandNested = createQueryRequest.isExpandNested(); BitSetProducer parentFilter = null; if (createQueryRequest.getContext().isPresent()) { @@ -110,15 +108,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { .build(); } - if (createQueryRequest.getRescoreContext().isPresent()) { - return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested); - } - - if (ENGINES_SUPPORTING_MULTI_VECTORS.contains(knnEngine) && expandNested) { - return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested); - } - - return knnQuery; + return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested); } Integer requestEfSearch = null; diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java index 19842a67a..727c508fb 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java @@ -19,9 +19,8 @@ * A `DocIdSetIterator` that iterates over all nested document IDs belongs to the same parent document for a given * set of nested document IDs. * - * The {@link #docIds} should include only a single nested document ID per parent document. Otherwise, the nested documents - * of that parent document will be iterated multiple times. - * + * It is permissible for {@link #docIds} to contain multiple nested document IDs linked to a single parent document. + * In such cases, this iterator will still iterate over each nested document ID only once. */ public class GroupedNestedDocIdSetIterator extends DocIdSetIterator { private final BitSet parentBitSet; @@ -99,9 +98,14 @@ public long cost() { private long calculateCost() { long numDocs = 0; + int lastDocId = -1; for (int docId : docIds) { - for (int i = parentBitSet.prevSetBit(docId) + 1; i < parentBitSet.nextSetBit(docId); i++) { - if (filterBits.get(i)) { + if (docId < lastDocId) { + continue; + } + + for (lastDocId = parentBitSet.prevSetBit(docId) + 1; lastDocId < parentBitSet.nextSetBit(docId); lastDocId++) { + if (filterBits.get(lastDocId)) { numDocs++; } } @@ -111,12 +115,19 @@ private long calculateCost() { private void moveToNextIndex() { currentIndex++; - if (currentIndex >= docIds.size()) { - currentDocId = NO_MORE_DOCS; + while (currentIndex < docIds.size()) { + // Advance currentIndex until the docId at the currentIndex is greater than currentDocId. + // This ensures proper handling when docIds contain multiple entries under the same parent ID + // that have already been iterated. + if (docIds.get(currentIndex) <= currentDocId) { + currentIndex++; + continue; + } + currentDocId = parentBitSet.prevSetBit(docIds.get(currentIndex)) + 1; + currentParentId = parentBitSet.nextSetBit(docIds.get(currentIndex)); + assert currentParentId != NO_MORE_DOCS; return; } - currentDocId = parentBitSet.prevSetBit(docIds.get(currentIndex)) + 1; - currentParentId = parentBitSet.nextSetBit(docIds.get(currentIndex)); - assert currentParentId != NO_MORE_DOCS; + currentDocId = NO_MORE_DOCS; } } diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java index 55f3d91d9..976b50ea6 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java @@ -70,4 +70,33 @@ public void testGroupedNestedDocIdSetIterator_whenAdvanceIsCalled_thenBehaveAsEx assertEquals(DocIdSetIterator.NO_MORE_DOCS, groupedNestedDocIdSetIterator.docID()); assertEquals(expectedDocIds.size(), groupedNestedDocIdSetIterator.cost()); } + + public void testGroupedNestedDocIdSetIterator_whenGivenMultipleDocsUnderSameParent_thenBehaveAsExpected() throws Exception { + // 0, 1, 2(parent), 3, 4, 5, 6, 7(parent), 8, 9, 10(parent) + BitSet parentBitSet = new FixedBitSet(new long[1], 11); + parentBitSet.set(2); + parentBitSet.set(7); + parentBitSet.set(10); + + BitSet filterBits = new FixedBitSet(new long[1], 11); + filterBits.set(1); + filterBits.set(8); + filterBits.set(9); + + // Run + Set docIds = Set.of(0, 1, 3, 4, 5, 8, 9); + GroupedNestedDocIdSetIterator groupedNestedDocIdSetIterator = new GroupedNestedDocIdSetIterator(parentBitSet, docIds, filterBits); + + // Verify + Set expectedDocIds = Set.of(1, 8, 9); + groupedNestedDocIdSetIterator.advance(1); + assertEquals(1, groupedNestedDocIdSetIterator.docID()); + groupedNestedDocIdSetIterator.nextDoc(); + assertEquals(8, groupedNestedDocIdSetIterator.docID()); + groupedNestedDocIdSetIterator.advance(9); + assertEquals(9, groupedNestedDocIdSetIterator.docID()); + groupedNestedDocIdSetIterator.nextDoc(); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, groupedNestedDocIdSetIterator.docID()); + assertEquals(expectedDocIds.size(), groupedNestedDocIdSetIterator.cost()); + } }