forked from opensearch-project/k-NN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Handle multi-vector in exact search scenario
Signed-off-by: Heemin Kim <[email protected]>
- Loading branch information
Showing
12 changed files
with
504 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
54 changes: 54 additions & 0 deletions
54
src/main/java/org/opensearch/knn/index/query/filtered/FilteredKNNIterator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.filtered; | ||
|
||
import org.apache.lucene.index.BinaryDocValues; | ||
import org.opensearch.knn.index.SpaceType; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene | ||
* https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162 | ||
* | ||
* The class is used in KNNWeight to score filtered KNN field by iterating filterIdsArray. | ||
*/ | ||
public abstract class FilteredKNNIterator<T extends KNNQueryVector> { | ||
// Array of doc ids to iterate | ||
protected final int[] filterIdsArray; | ||
protected float currentScore = Float.NEGATIVE_INFINITY; | ||
protected final T queryVector; | ||
protected final BinaryDocValues values; | ||
protected final SpaceType spaceType; | ||
protected int currentPos = 0; | ||
|
||
public FilteredKNNIterator(final int[] filterIdsArray, | ||
final T queryVector, | ||
final BinaryDocValues values, | ||
final SpaceType spaceType) { | ||
this.filterIdsArray = filterIdsArray; | ||
this.queryVector = queryVector; | ||
this.values = values; | ||
this.spaceType = spaceType; | ||
} | ||
|
||
/** | ||
* Advance to the next doc and update score | ||
* DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs | ||
* | ||
* @return next doc id | ||
*/ | ||
abstract public int nextDoc() throws IOException; | ||
|
||
/** | ||
* Return a score of current doc | ||
* | ||
* @return current score | ||
*/ | ||
public float score() { | ||
return currentScore; | ||
} | ||
} |
33 changes: 33 additions & 0 deletions
33
src/main/java/org/opensearch/knn/index/query/filtered/KNNFloatQueryVector.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.filtered; | ||
|
||
import org.apache.lucene.index.BinaryDocValues; | ||
import org.apache.lucene.util.BytesRef; | ||
import org.opensearch.knn.index.SpaceType; | ||
import org.opensearch.knn.index.codec.util.KNNVectorSerializer; | ||
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.IOException; | ||
|
||
public class KNNFloatQueryVector implements KNNQueryVector { | ||
private final float[] queryVector; | ||
|
||
public KNNFloatQueryVector(final float[] queryVector) { | ||
this.queryVector = queryVector; | ||
} | ||
|
||
public float score(final BinaryDocValues values, final SpaceType spaceType) throws IOException { | ||
final BytesRef value = values.binaryValue(); | ||
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); | ||
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); | ||
final float[] vector = vectorSerializer.byteToFloatArray(byteStream); | ||
// Calculates a similarity score between the two vectors with a specified function. Higher similarity | ||
// scores correspond to closer vectors. | ||
return spaceType.getVectorSimilarityFunction().compare(queryVector, vector); | ||
} | ||
} |
15 changes: 15 additions & 0 deletions
15
src/main/java/org/opensearch/knn/index/query/filtered/KNNQueryVector.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.filtered; | ||
|
||
import org.apache.lucene.index.BinaryDocValues; | ||
import org.opensearch.knn.index.SpaceType; | ||
|
||
import java.io.IOException; | ||
|
||
public interface KNNQueryVector { | ||
float score(final BinaryDocValues values, final SpaceType spaceType) throws IOException; | ||
} |
55 changes: 55 additions & 0 deletions
55
src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredKNNIterator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.filtered; | ||
|
||
import org.apache.lucene.index.BinaryDocValues; | ||
import org.apache.lucene.search.DocIdSetIterator; | ||
import org.apache.lucene.util.BitSet; | ||
import org.opensearch.knn.index.SpaceType; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* This iterator iterates filterIdsArray to score. However, it dedupe docs per each parent doc | ||
* of which ID is set in parentBitSet and only return best child doc with the highest score. | ||
*/ | ||
public class NestedFilteredKNNIterator<T extends KNNQueryVector> extends FilteredKNNIterator<T> { | ||
private final BitSet parentBitSet; | ||
private int currentParent = -1; | ||
private int bestChild = -1; | ||
|
||
public NestedFilteredKNNIterator( | ||
final int[] filterIdsArray, | ||
final T queryVector, | ||
final BinaryDocValues values, | ||
final SpaceType spaceType, | ||
final BitSet parentBitSet | ||
) { | ||
super(filterIdsArray, queryVector, values, spaceType); | ||
this.parentBitSet = parentBitSet; | ||
} | ||
|
||
@Override | ||
public int nextDoc() throws IOException { | ||
if (currentPos >= filterIdsArray.length) { | ||
return DocIdSetIterator.NO_MORE_DOCS; | ||
} | ||
currentScore = Float.NEGATIVE_INFINITY; | ||
currentParent = parentBitSet.nextSetBit(filterIdsArray[currentPos]); | ||
do { | ||
int currentChild = filterIdsArray[currentPos]; | ||
values.advance(currentChild); | ||
final float score = queryVector.score(values, spaceType); | ||
if (score > currentScore) { | ||
bestChild = currentChild; | ||
currentScore = score; | ||
} | ||
currentPos++; | ||
} while (currentPos < filterIdsArray.length && filterIdsArray[currentPos] < currentParent); | ||
|
||
return bestChild; | ||
} | ||
} |
38 changes: 38 additions & 0 deletions
38
src/main/java/org/opensearch/knn/index/query/filtered/PlainFilteredKNNIterator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.filtered; | ||
|
||
import org.apache.lucene.index.BinaryDocValues; | ||
import org.apache.lucene.search.DocIdSetIterator; | ||
import org.opensearch.knn.index.SpaceType; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* Basic implementation of FilteredKNNIterator which iterate all doc IDs in filterIdsArray | ||
*/ | ||
public class PlainFilteredKNNIterator<T extends KNNQueryVector> extends FilteredKNNIterator<T> { | ||
private int currentDoc = -1; | ||
|
||
public PlainFilteredKNNIterator( | ||
final int[] filterIdsArray, | ||
final T queryVector, | ||
final BinaryDocValues values, | ||
final SpaceType spaceType | ||
) { | ||
super(filterIdsArray, queryVector, values, spaceType); | ||
} | ||
|
||
@Override | ||
public int nextDoc() throws IOException { | ||
if (currentPos >= filterIdsArray.length) { | ||
return DocIdSetIterator.NO_MORE_DOCS; | ||
} | ||
currentDoc = values.advance(filterIdsArray[currentPos++]); | ||
currentScore = queryVector.score(values, spaceType); | ||
return currentDoc; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.common; | ||
|
||
public class Constants { | ||
public static final String FILTER_FIELD = "filter"; | ||
public static final String TERM_FIELD = "term"; | ||
} |
Oops, something went wrong.