Skip to content

Commit

Permalink
Handle multi-vector in exact search scenario (#1399)
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 authored Jan 19, 2024
1 parent 709b448 commit 8c98265
Show file tree
Hide file tree
Showing 9 changed files with 449 additions and 45 deletions.
71 changes: 38 additions & 33 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,40 @@
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.FilterDirectory;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.common.io.PathUtils;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.index.query.filtered.FilteredIdsKNNIterator;
import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNIterator;
import org.opensearch.knn.index.util.KNNEngine;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.FilterDirectory;
import org.apache.lucene.util.DocIdSetBuilder;
import org.opensearch.common.io.PathUtils;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.plugin.stats.KNNCounter;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
Expand Down Expand Up @@ -306,33 +304,23 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final i
}

private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) throws IOException {
final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
float[] queryVector = this.knnQuery.getQueryVector();
try {
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName());
final SpaceType spaceType = getSpaceType(fieldInfo);
// Creating min heap and init with MAX DocID and Score as -INF.
final HitQueue queue = new HitQueue(this.knnQuery.getK(), true);
ScoreDoc topDoc = queue.top();
final Map<Integer, Float> docToScore = new HashMap<>();
for (int filterId : filterIdsArray) {
int docId = values.advance(filterId);
final BytesRef value = values.binaryValue();
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
float score = spaceType.getVectorSimilarityFunction().compare(queryVector, vector);
if (score > topDoc.score) {
topDoc.score = score;
FilteredIdsKNNIterator iterator = getFilteredKNNIterator(leafReaderContext, filterIdsArray);
int docId;
while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
if (iterator.score() > topDoc.score) {
topDoc.score = iterator.score();
topDoc.doc = docId;
// As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we
// have seen till now on top.
topDoc = queue.updateTop();
}
}

// If scores are negative we will remove them.
// This is done, because there can be negative values in the Heap as we init the heap with Score as -INF.
// If filterIds < k, the some values in heap can have a negative score.
Expand All @@ -352,6 +340,23 @@ private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderCont
return Collections.emptyMap();
}

private FilteredIdsKNNIterator getFilteredKNNIterator(final LeafReaderContext leafReaderContext, final int[] filterIdsArray)
throws IOException {
final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName());
final SpaceType spaceType = getSpaceType(fieldInfo);
return knnQuery.getParentsFilter() == null
? new FilteredIdsKNNIterator(filterIdsArray, knnQuery.getQueryVector(), values, spaceType)
: new NestedFilteredIdsKNNIterator(
filterIdsArray,
knnQuery.getQueryVector(),
values,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
}

private Scorer convertSearchResponseToScorer(final Map<Integer, Float> docsToScore) throws IOException {
final int maxDoc = Collections.max(docsToScore.keySet()) + 1;
final DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.io.IOException;

/**
* Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene
* https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162
*
* The class is used in KNNWeight to score filtered KNN field by iterating filterIdsArray.
*/
public class FilteredIdsKNNIterator {
// Array of doc ids to iterate
protected final int[] filterIdsArray;
protected final float[] queryVector;
protected final BinaryDocValues binaryDocValues;
protected final SpaceType spaceType;
protected float currentScore = Float.NEGATIVE_INFINITY;
protected int currentPos = 0;

public FilteredIdsKNNIterator(
final int[] filterIdsArray,
final float[] queryVector,
final BinaryDocValues binaryDocValues,
final SpaceType spaceType
) {
this.filterIdsArray = filterIdsArray;
this.queryVector = queryVector;
this.binaryDocValues = binaryDocValues;
this.spaceType = spaceType;
}

/**
* Advance to the next doc and update score value with score of the next doc.
* DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs
*
* @return next doc id
*/
public int nextDoc() throws IOException {
if (currentPos >= filterIdsArray.length) {
return DocIdSetIterator.NO_MORE_DOCS;
}
int docId = binaryDocValues.advance(filterIdsArray[currentPos]);
currentScore = computeScore();
currentPos++;
return docId;
}

public float score() {
return currentScore;
}

protected float computeScore() throws IOException {
final BytesRef value = binaryDocValues.binaryValue();
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
return spaceType.getVectorSimilarityFunction().compare(queryVector, vector);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.opensearch.knn.index.SpaceType;

import java.io.IOException;

/**
* This iterator iterates filterIdsArray to score. However, it dedupe docs per each parent doc
* of which ID is set in parentBitSet and only return best child doc with the highest score.
*/
public class NestedFilteredIdsKNNIterator extends FilteredIdsKNNIterator {
private final BitSet parentBitSet;

public NestedFilteredIdsKNNIterator(
final int[] filterIdsArray,
final float[] queryVector,
final BinaryDocValues values,
final SpaceType spaceType,
final BitSet parentBitSet
) {
super(filterIdsArray, queryVector, values, spaceType);
this.parentBitSet = parentBitSet;
}

/**
* Advance to the next best child doc per parent and update score with the best score among child docs from the parent.
* DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs
*
* @return next best child doc id
*/
@Override
public int nextDoc() throws IOException {
if (currentPos >= filterIdsArray.length) {
return DocIdSetIterator.NO_MORE_DOCS;
}
currentScore = Float.NEGATIVE_INFINITY;
int currentParent = parentBitSet.nextSetBit(filterIdsArray[currentPos]);
int bestChild = -1;
while (currentPos < filterIdsArray.length && filterIdsArray[currentPos] < currentParent) {
binaryDocValues.advance(filterIdsArray[currentPos]);
float score = computeScore();
if (score > currentScore) {
bestChild = filterIdsArray[currentPos];
currentScore = score;
}
currentPos++;
}

return bestChild;
}
}
11 changes: 11 additions & 0 deletions src/test/java/org/opensearch/knn/common/Constants.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.common;

public class Constants {
public static final String FIELD_FILTER = "filter";
public static final String FIELD_TERM = "term";
}
Loading

0 comments on commit 8c98265

Please sign in to comment.