From 7ddea22957862603af7392e7076ed13885255538 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Wed, 6 Nov 2024 16:27:58 -0800 Subject: [PATCH] POC for multi innerhit with lucene Signed-off-by: Heemin Kim --- .../knn/index/query/KNNQueryFactory.java | 6 +- .../nativelib/LuceneEngineKNNVectorQuery.java | 196 ++++++++++++++++++ .../nativelib/NativeEngineKnnVectorQuery.java | 3 + 3 files changed, 203 insertions(+), 2 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/query/nativelib/LuceneEngineKNNVectorQuery.java diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index dab2e08c8..b1b5f314b 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -15,6 +15,7 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.nativelib.LuceneEngineKNNVectorQuery; import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import org.opensearch.knn.index.query.rescore.RescoreContext; @@ -95,7 +96,8 @@ public static Query create(CreateQueryRequest createQueryRequest) { .rescoreContext(rescoreContext) .build(); } - return createQueryRequest.getRescoreContext().isPresent() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery; +// return createQueryRequest.getRescoreContext().isPresent() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery; + return new NativeEngineKnnVectorQuery(knnQuery); } Integer requestEfSearch = null; @@ -162,7 +164,7 @@ private static Query getKnnFloatVectorQuery( if (parentFilter == null) { return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery); } else { - return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter); + return new LuceneEngineKNNVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter); } } } diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/LuceneEngineKNNVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/LuceneEngineKNNVectorQuery.java new file mode 100644 index 000000000..f3290d991 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/LuceneEngineKNNVectorQuery.java @@ -0,0 +1,196 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.nativelib; + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.DocIdSetBuilder; +import org.opensearch.knn.index.query.ResultUtil; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; + +public class LuceneEngineKNNVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery { + private final BitSetProducer parentsFilter; + private final Query childFilter; + private final int k; + private final float[] query; + private final LuceneKnnFloatVectorQuery knnFloatVectorQuery; + private boolean doReWrite; + + public LuceneEngineKNNVectorQuery(String field, float[] query, Query childFilter, int k, BitSetProducer parentsFilter) { + super(field, query, childFilter, k, parentsFilter); + this.childFilter = childFilter; + this.parentsFilter = parentsFilter; + this.k = k; + this.query = query; + this.knnFloatVectorQuery = new LuceneKnnFloatVectorQuery(field, query, 1000); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + return doReWrite ? super.rewrite(indexSearcher) : this; + } + + private Query doReWrite(IndexSearcher indexSearcher) throws IOException { + try { + doReWrite = true; + return indexSearcher.rewrite(this); + } finally { + doReWrite = false; + } + } + + @Override + public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, float boost) throws IOException { + final IndexReader reader = indexSearcher.getIndexReader(); + Query docAndScoreQuery = doReWrite(indexSearcher); + Weight weight = docAndScoreQuery.createWeight(indexSearcher, scoreMode, boost); + List leafReaderContexts = reader.leaves(); + List> perLeafResults; + perLeafResults = doSearch(indexSearcher, leafReaderContexts, weight, k); + ResultUtil.reduceToTopK(perLeafResults, k); + TopDocs[] topDocs = retrieveAll(indexSearcher, leafReaderContexts, weight, perLeafResults); + long sum = 0; + for (TopDocs topDoc : topDocs) { + sum += topDoc.totalHits.value; + } + TopDocs topK = TopDocs.merge((int)sum, topDocs); + if (topK.scoreDocs.length == 0) { + return new MatchNoDocsQuery().createWeight(indexSearcher, scoreMode, boost); + } + + return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost); + } + + private TopDocs[] retrieveAll(final IndexSearcher indexSearcher, final List leafReaderContexts, final Weight weight, final List> perLeafResults) throws IOException { + // Construct query + List> rescoreTasks = new ArrayList<>(leafReaderContexts.size()); + for (int i = 0; i < perLeafResults.size(); i++) { + LeafReaderContext leafReaderContext = leafReaderContexts.get(i); + int finalI = i; + rescoreTasks.add(() -> { + // Here, it should be bitset of parent doc id. + DocIdSetIterator iterator = getAllSiblings(leafReaderContext, perLeafResults.get(finalI)); + return knnFloatVectorQuery.exactSearch(leafReaderContext, iterator, null); + }); + } + return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks).toArray(TopDocs[]::new); + } + + private DocIdSetIterator getAllSiblings(final LeafReaderContext leafReaderContext, final Map integerFloatMap) throws IOException { + if (integerFloatMap.isEmpty()) { + return DocIdSetIterator.empty(); + } + final int maxDoc = Collections.max(integerFloatMap.keySet()); + BitSet parentBitSet = parentsFilter.getBitSet(leafReaderContext); + final int maxParentDoc = parentBitSet.nextSetBit(maxDoc) + 1; + return resultMapToDocIds(leafReaderContext, integerFloatMap, maxParentDoc); + } + + public DocIdSetIterator resultMapToDocIds(final LeafReaderContext leafReaderContext, Map resultMap, final int maxDoc) throws IOException { + if (resultMap.isEmpty()) { + return DocIdSetIterator.empty(); + } + final DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc); + final DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(maxDoc - resultMap.size()); + BitSet parentBitSet = parentsFilter.getBitSet(leafReaderContext); + resultMap.keySet().forEach(key -> { + for (int i = parentBitSet.prevSetBit(key) + 1; i < parentBitSet.nextSetBit(key); i++) { + setAdder.add(i); + } + }); + return docIdSetBuilder.build().iterator(); + } + + private Query createDocAndScoreQuery(IndexReader reader, TopDocs topK) { + int len = topK.scoreDocs.length; + Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc)); + int[] docs = new int[len]; + float[] scores = new float[len]; + for (int i = 0; i < len; i++) { + docs[i] = topK.scoreDocs[i].doc; + scores[i] = topK.scoreDocs[i].score; + } + int[] segmentStarts = findSegmentStarts(reader, docs); + return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.getContext().id()); + } + + private List> doSearch( + final IndexSearcher indexSearcher, + List leafReaderContexts, + Weight weight, + int k + ) throws IOException { + List>> tasks = new ArrayList<>(leafReaderContexts.size()); + for (LeafReaderContext leafReaderContext : leafReaderContexts) { + tasks.add(() -> searchLeaf(leafReaderContext, weight, k)); + } + return indexSearcher.getTaskExecutor().invokeAll(tasks); + } + + static int[] findSegmentStarts(IndexReader reader, int[] docs) { + int[] starts = new int[reader.leaves().size() + 1]; + starts[starts.length - 1] = docs.length; + if (starts.length == 2) { + return starts; + } + int resultIndex = 0; + for (int i = 1; i < starts.length - 1; i++) { + int upper = reader.leaves().get(i).docBase; + resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); + if (resultIndex < 0) { + resultIndex = -1 - resultIndex; + } + starts[i] = resultIndex; + } + return starts; + } + + private Map searchLeaf(LeafReaderContext ctx, Weight weight, int k) throws IOException { + final Map leafDocScores = new HashMap<>(); + Scorer scorer = weight.scorer(ctx); + DocIdSetIterator iterator = scorer.iterator(); + iterator.nextDoc(); + while (iterator.docID() != DocIdSetIterator.NO_MORE_DOCS) { + leafDocScores.put(scorer.docID(), scorer.score()); + iterator.nextDoc(); + } + return leafDocScores; + } + + public class LuceneKnnFloatVectorQuery extends KnnFloatVectorQuery { + + public LuceneKnnFloatVectorQuery(final String field, final float[] target, final int k) { + super(field, target, k); + } + + @Override + protected TopDocs exactSearch(final LeafReaderContext context, final DocIdSetIterator acceptIterator, final QueryTimeout queryTimeout) throws IOException { + return super.exactSearch(context, acceptIterator, queryTimeout); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index a34a0f1ee..b596543d2 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -75,6 +75,9 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size()); } ResultUtil.reduceToTopK(perLeafResults, finalK); + if (knnQuery.getParentsFilter() != null) { + + } TopDocs[] topDocs = new TopDocs[perLeafResults.size()]; for (int i = 0; i < perLeafResults.size(); i++) { topDocs[i] = ResultUtil.resultMapToTopDocs(perLeafResults.get(i), leafReaderContexts.get(i).docBase);