Skip to content

Commit

Permalink
Handle multi-vector in exact search scenario
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Jan 9, 2024
1 parent 0abed23 commit 9f82400
Show file tree
Hide file tree
Showing 12 changed files with 504 additions and 32 deletions.
47 changes: 27 additions & 20 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.index.query.filtered.FilteredKNNIterator;
import org.opensearch.knn.index.query.filtered.KNNFloatQueryVector;
import org.opensearch.knn.index.query.filtered.NestedFilteredKNNIterator;
import org.opensearch.knn.index.query.filtered.PlainFilteredKNNIterator;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
Expand All @@ -44,7 +45,6 @@
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.plugin.stats.KNNCounter;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
Expand Down Expand Up @@ -306,33 +306,23 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final i
}

private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) throws IOException {
final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
float[] queryVector = this.knnQuery.getQueryVector();
try {
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName());
final SpaceType spaceType = getSpaceType(fieldInfo);
// Creating min heap and init with MAX DocID and Score as -INF.
final HitQueue queue = new HitQueue(this.knnQuery.getK(), true);
ScoreDoc topDoc = queue.top();
final Map<Integer, Float> docToScore = new HashMap<>();
for (int filterId : filterIdsArray) {
int docId = values.advance(filterId);
final BytesRef value = values.binaryValue();
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
float score = spaceType.getVectorSimilarityFunction().compare(queryVector, vector);
if (score > topDoc.score) {
topDoc.score = score;
FilteredKNNIterator scorer = getFilteredKNNIterator(leafReaderContext, filterIdsArray);
int docId;
while ((docId = scorer.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
if (scorer.score() > topDoc.score) {
topDoc.score = scorer.score();
topDoc.doc = docId;
// As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we
// have seen till now on top.
topDoc = queue.updateTop();
}
}

// If scores are negative we will remove them.
// This is done, because there can be negative values in the Heap as we init the heap with Score as -INF.
// If filterIds < k, the some values in heap can have a negative score.
Expand All @@ -352,6 +342,23 @@ private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderCont
return Collections.emptyMap();
}

private FilteredKNNIterator getFilteredKNNIterator(final LeafReaderContext leafReaderContext, final int[] filterIdsArray)
throws IOException {
final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName());
final SpaceType spaceType = getSpaceType(fieldInfo);
return knnQuery.getParentsFilter() == null
? new PlainFilteredKNNIterator(filterIdsArray, new KNNFloatQueryVector(knnQuery.getQueryVector()), values, spaceType)
: new NestedFilteredKNNIterator(
filterIdsArray,
new KNNFloatQueryVector(knnQuery.getQueryVector()),
values,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
}

private Scorer convertSearchResponseToScorer(final Map<Integer, Float> docsToScore) throws IOException {
final int maxDoc = Collections.max(docsToScore.keySet()) + 1;
final DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.opensearch.knn.index.SpaceType;

import java.io.IOException;

/**
* Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene
* https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162
*
* The class is used in KNNWeight to score filtered KNN field by iterating filterIdsArray.
*/
public abstract class FilteredKNNIterator<T extends KNNQueryVector> {
// Array of doc ids to iterate
protected final int[] filterIdsArray;
protected float currentScore = Float.NEGATIVE_INFINITY;
protected final T queryVector;
protected final BinaryDocValues values;
protected final SpaceType spaceType;
protected int currentPos = 0;

public FilteredKNNIterator(final int[] filterIdsArray,
final T queryVector,
final BinaryDocValues values,
final SpaceType spaceType) {
this.filterIdsArray = filterIdsArray;
this.queryVector = queryVector;
this.values = values;
this.spaceType = spaceType;
}

/**
* Advance to the next doc and update score
* DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs
*
* @return next doc id
*/
abstract public int nextDoc() throws IOException;

/**
* Return a score of current doc
*
* @return current score
*/
public float score() {
return currentScore;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.io.IOException;

public class KNNFloatQueryVector implements KNNQueryVector {
private final float[] queryVector;

public KNNFloatQueryVector(final float[] queryVector) {
this.queryVector = queryVector;
}

public float score(final BinaryDocValues values, final SpaceType spaceType) throws IOException {
final BytesRef value = values.binaryValue();
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
return spaceType.getVectorSimilarityFunction().compare(queryVector, vector);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.opensearch.knn.index.SpaceType;

import java.io.IOException;

public interface KNNQueryVector {
float score(final BinaryDocValues values, final SpaceType spaceType) throws IOException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.opensearch.knn.index.SpaceType;

import java.io.IOException;

/**
* This iterator iterates filterIdsArray to score. However, it dedupe docs per each parent doc
* of which ID is set in parentBitSet and only return best child doc with the highest score.
*/
public class NestedFilteredKNNIterator<T extends KNNQueryVector> extends FilteredKNNIterator<T> {
private final BitSet parentBitSet;
private int currentParent = -1;
private int bestChild = -1;

public NestedFilteredKNNIterator(
final int[] filterIdsArray,
final T queryVector,
final BinaryDocValues values,
final SpaceType spaceType,
final BitSet parentBitSet
) {
super(filterIdsArray, queryVector, values, spaceType);
this.parentBitSet = parentBitSet;
}

@Override
public int nextDoc() throws IOException {
if (currentPos >= filterIdsArray.length) {
return DocIdSetIterator.NO_MORE_DOCS;
}
currentScore = Float.NEGATIVE_INFINITY;
currentParent = parentBitSet.nextSetBit(filterIdsArray[currentPos]);
do {
int currentChild = filterIdsArray[currentPos];
values.advance(currentChild);
final float score = queryVector.score(values, spaceType);
if (score > currentScore) {
bestChild = currentChild;
currentScore = score;
}
currentPos++;
} while (currentPos < filterIdsArray.length && filterIdsArray[currentPos] < currentParent);

return bestChild;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.knn.index.SpaceType;

import java.io.IOException;

/**
* Basic implementation of FilteredKNNIterator which iterate all doc IDs in filterIdsArray
*/
public class PlainFilteredKNNIterator<T extends KNNQueryVector> extends FilteredKNNIterator<T> {
private int currentDoc = -1;

public PlainFilteredKNNIterator(
final int[] filterIdsArray,
final T queryVector,
final BinaryDocValues values,
final SpaceType spaceType
) {
super(filterIdsArray, queryVector, values, spaceType);
}

@Override
public int nextDoc() throws IOException {
if (currentPos >= filterIdsArray.length) {
return DocIdSetIterator.NO_MORE_DOCS;
}
currentDoc = values.advance(filterIdsArray[currentPos++]);
currentScore = queryVector.score(values, spaceType);
return currentDoc;
}
}
11 changes: 11 additions & 0 deletions src/test/java/org/opensearch/knn/common/Constants.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.common;

public class Constants {
public static final String FILTER_FIELD = "filter";
public static final String TERM_FIELD = "term";
}
Loading

0 comments on commit 9f82400

Please sign in to comment.