Skip to content

Commit

Permalink
Support script score when doc value is disabled (opensearch-project#1573)
Browse files Browse the repository at this point in the history

* support script score when doc value is disabled

Signed-off-by: panguixin <[email protected]>

* add test

Signed-off-by: panguixin <[email protected]>

* apply review comments

Signed-off-by: panguixin <[email protected]>

* fix test

Signed-off-by: panguixin <[email protected]>

---------

Signed-off-by: panguixin <[email protected]>
(cherry picked from commit 771c4b5)
  • Loading branch information
bugmakerrrrrr committed Apr 2, 2024
1 parent 3ec60af commit e2506c3
Show file tree
Hide file tree
Showing 14 changed files with 385 additions and 326 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x)
### Features
### Enhancements
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

package org.opensearch.knn.index;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.index.fielddata.LeafFieldData;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.fielddata.SortedBinaryDocValues;
Expand Down Expand Up @@ -39,10 +40,29 @@ public long ramBytesUsed() {
@Override
public ScriptDocValues<float[]> getScriptValues() {
try {
BinaryDocValues values = DocValues.getBinary(reader, fieldName);
return new KNNVectorScriptDocValues(values, fieldName, vectorDataType);
FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName);
if (fieldInfo == null) {
return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType);
}

DocIdSetIterator values;
if (fieldInfo.hasVectorValues()) {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
values = reader.getFloatVectorValues(fieldName);
break;
case BYTE:
values = reader.getByteVectorValues(fieldName);
break;
default:
throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding());
}
} else {
values = DocValues.getBinary(reader, fieldName);
}
return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType);
} catch (IOException e) {
throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e);
throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e);
}
}

Expand Down
109 changes: 98 additions & 11 deletions src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,30 @@

package org.opensearch.knn.index;

import java.io.IOException;
import java.util.Objects;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.ExceptionsHelper;
import org.opensearch.index.fielddata.ScriptDocValues;

import java.io.IOException;

@RequiredArgsConstructor
public final class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public abstract class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {

private final BinaryDocValues binaryDocValues;
private final DocIdSetIterator vectorValues;
private final String fieldName;
@Getter
private final VectorDataType vectorDataType;
private boolean docExists = false;

@Override
public void setNextDocId(int docId) throws IOException {
if (binaryDocValues.advanceExact(docId)) {
docExists = true;
return;
}
docExists = false;
docExists = vectorValues.docID() == docId || vectorValues.advance(docId) == docId;
}

public float[] getValue() {
Expand All @@ -43,12 +43,14 @@ public float[] getValue() {
throw new IllegalStateException(errorMessage);
}
try {
return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue());
return doGetValue();
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}

/**
 * Reads the vector from the underlying values and returns it as a float array.
 *
 * <p>{@link #getValue()} calls this method and converts any {@link IOException} it throws
 * via {@code ExceptionsHelper.convertToOpenSearchException}.
 *
 * @return the vector as a float array
 * @throws IOException if the underlying values cannot be read
 */
protected abstract float[] doGetValue() throws IOException;

@Override
public int size() {
return docExists ? 1 : 0;
Expand All @@ -58,4 +60,89 @@ public int size() {
public float[] get(int i) {
throw new UnsupportedOperationException("knn vector does not support this operation");
}

/**
 * Builds the {@link KNNVectorScriptDocValues} implementation that matches the runtime type of {@code values}.
 *
 * <p>{@link ByteVectorValues}, {@link FloatVectorValues} and {@link BinaryDocValues} are the only
 * iterator types recognised; each one is wrapped by its dedicated reader class.
 *
 * @param values         iterator over the field's vectors; must not be {@code null}
 * @param fieldName      name of the field
 * @param vectorDataType data type of the vector, passed through to the created instance
 * @return a script doc values instance backed by {@code values}
 * @throws NullPointerException     if {@code values} is {@code null}
 * @throws IllegalArgumentException if {@code values} is none of the supported iterator types
 */
public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) {
    Objects.requireNonNull(values, "values must not be null");

    if (values instanceof ByteVectorValues) {
        return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType);
    }
    if (values instanceof FloatVectorValues) {
        return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType);
    }
    if (values instanceof BinaryDocValues) {
        return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType);
    }
    // Any other iterator type has no known vector representation.
    throw new IllegalArgumentException("Unsupported values type: " + values.getClass());
}

/**
 * Script doc values backed by {@link ByteVectorValues}; each byte component is widened to a float.
 */
private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues {
    private final ByteVectorValues byteVectors;

    KNNByteVectorScriptDocValues(ByteVectorValues byteVectors, String field, VectorDataType type) {
        super(byteVectors, field, type);
        this.byteVectors = byteVectors;
    }

    /**
     * Copies the current byte vector into a freshly allocated float array of the same length.
     */
    @Override
    protected float[] doGetValue() throws IOException {
        final byte[] raw = byteVectors.vectorValue();
        final int dimension = raw.length;
        final float[] widened = new float[dimension];
        int idx = 0;
        while (idx < dimension) {
            // byte -> float is a lossless widening conversion (sign is preserved)
            widened[idx] = raw[idx];
            idx++;
        }
        return widened;
    }
}

/**
 * Script doc values backed by {@link FloatVectorValues}.
 */
private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues {
    private final FloatVectorValues floatVectors;

    KNNFloatVectorScriptDocValues(FloatVectorValues floatVectors, String field, VectorDataType type) {
        super(floatVectors, field, type);
        this.floatVectors = floatVectors;
    }

    /**
     * Returns the array handed back by {@link FloatVectorValues#vectorValue()} as-is, without copying.
     */
    @Override
    protected float[] doGetValue() throws IOException {
        // NOTE(review): Lucene may reuse this array across documents -- confirm before callers retain it.
        final float[] current = floatVectors.vectorValue();
        return current;
    }
}

/**
 * Script doc values backed by {@link BinaryDocValues}; the binary payload is decoded
 * by the instance's {@link VectorDataType}.
 */
private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues {
    private final BinaryDocValues binaryValues;

    KNNNativeVectorScriptDocValues(BinaryDocValues binaryValues, String field, VectorDataType type) {
        super(binaryValues, field, type);
        this.binaryValues = binaryValues;
    }

    /**
     * Decodes the current binary value into a float array via {@code VectorDataType#getVectorFromDocValues}.
     */
    @Override
    protected float[] doGetValue() throws IOException {
        final VectorDataType decoder = getVectorDataType();
        return decoder.getVectorFromDocValues(binaryValues.binaryValue());
    }
}

/**
 * Returns a {@link KNNVectorScriptDocValues} backed by {@link DocIdSetIterator#empty()}.
 *
 * <p>Its {@code doGetValue()} always throws {@link UnsupportedOperationException}.
 *
 * @param fieldName The name of the field.
 * @param type      The data type of the vector.
 * @return An empty KNNVectorScriptDocValues object.
 */
public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) {
    final DocIdSetIterator noDocs = DocIdSetIterator.empty();
    return new KNNVectorScriptDocValues(noDocs, fieldName, type) {
        @Override
        protected float[] doGetValue() throws IOException {
            // There is never a value to read from an empty iterator.
            throw new UnsupportedOperationException("empty values");
        }
    };
}
}
6 changes: 3 additions & 3 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() {

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
Expand Down Expand Up @@ -257,7 +257,7 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() {

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
Expand Down Expand Up @@ -827,7 +827,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed(

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@

package org.opensearch.knn.index;

import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.knn.KNNTestCase;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.document.BinaryDocValuesField;
Expand All @@ -13,7 +22,6 @@
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.store.Directory;
import org.junit.Assert;
import org.junit.Before;
Expand All @@ -24,6 +32,7 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase {

private static final String MOCK_INDEX_FIELD_NAME = "test-index-field-name";
private static final float[] SAMPLE_VECTOR_DATA = new float[] { 1.0f, 2.0f };
private static final byte[] SAMPLE_BYTE_VECTOR_DATA = new byte[] { 1, 2 };
private KNNVectorScriptDocValues scriptDocValues;
private Directory directory;
private DirectoryReader reader;
Expand All @@ -32,26 +41,39 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase {
public void setUp() throws Exception {
super.setUp();
directory = newDirectory();
createKNNVectorDocument(directory);
Class<? extends DocIdSetIterator> valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class);
createKNNVectorDocument(directory, valuesClass);
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
scriptDocValues = new KNNVectorScriptDocValues(
leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME),
MOCK_INDEX_FIELD_NAME,
VectorDataType.FLOAT
);
LeafReader leafReader = reader.getContext().leaves().get(0).reader();
DocIdSetIterator vectorValues;
if (BinaryDocValues.class.equals(valuesClass)) {
vectorValues = DocValues.getBinary(leafReader, MOCK_INDEX_FIELD_NAME);
} else if (ByteVectorValues.class.equals(valuesClass)) {
vectorValues = leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME);
} else {
vectorValues = leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME);
}

scriptDocValues = KNNVectorScriptDocValues.create(vectorValues, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT);
}

private void createKNNVectorDocument(Directory directory) throws IOException {
private void createKNNVectorDocument(Directory directory, Class<? extends DocIdSetIterator> valuesClass) throws IOException {
IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random()));
IndexWriter writer = new IndexWriter(directory, conf);
Document knnDocument = new Document();
knnDocument.add(
new BinaryDocValuesField(
Field field;
if (BinaryDocValues.class.equals(valuesClass)) {
field = new BinaryDocValuesField(
MOCK_INDEX_FIELD_NAME,
new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue()
)
);
);
} else if (ByteVectorValues.class.equals(valuesClass)) {
field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA);
} else {
field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA);
}

knnDocument.add(field);
writer.addDocument(knnDocument);
writer.commit();
writer.close();
Expand Down Expand Up @@ -83,4 +105,18 @@ public void testSize() throws IOException {
public void testGet() throws IOException {
expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0));
}

/**
 * {@code create} must reject an iterator type it has no reader for (here: numeric doc values).
 */
public void testUnsupportedValues() throws IOException {
    final DocIdSetIterator unsupported = DocValues.emptyNumeric();
    expectThrows(
        IllegalArgumentException.class,
        () -> KNNVectorScriptDocValues.create(unsupported, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT)
    );
}

/**
 * Verifies that the instance returned by {@code emptyValues} never reports a value,
 * both before and after it has been asked to move to a document.
 */
public void testEmptyValues() throws IOException {
    KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT);
    assertEquals(0, values.size());
    // Advance the empty instance itself rather than the shared scriptDocValues fixture;
    // otherwise the second assertion only re-checks the untouched initial state.
    values.setNextDocId(0);
    assertEquals(0, values.size());
}
}
9 changes: 4 additions & 5 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Floats;
import org.apache.http.util.EntityUtils;
import lombok.SneakyThrows;
import org.apache.commons.lang.math.RandomUtils;
Expand Down Expand Up @@ -307,14 +306,14 @@ public void testIndexReopening() throws Exception {
final float[] searchVector = TEST_QUERY_VECTORS[0];
final int k = 1 + RandomUtils.nextInt(TEST_INDEX_VECTORS.length);

final List<Float[]> knnResultsBeforeIndexClosure = queryResults(searchVector, k);
final List<float[]> knnResultsBeforeIndexClosure = queryResults(searchVector, k);

closeIndex(INDEX_NAME);
openIndex(INDEX_NAME);

ensureGreen(INDEX_NAME);

final List<Float[]> knnResultsAfterIndexClosure = queryResults(searchVector, k);
final List<float[]> knnResultsAfterIndexClosure = queryResults(searchVector, k);

assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray());
}
Expand Down Expand Up @@ -365,15 +364,15 @@ private void validateQueries(SpaceType spaceType, String fieldName) throws IOExc

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
float distance = TestUtils.computeDistFromSpaceType(spaceType, primitiveArray, queryVector);
float rawScore = VECTOR_SIMILARITY_TO_SCORE.get(spaceType.getVectorSimilarityFunction()).apply(distance);
assertEquals(KNNEngine.LUCENE.score(rawScore, spaceType), actualScores.get(j), 0.0001);
}
}
}

private List<Float[]> queryResults(final float[] searchVector, final int k) throws Exception {
private List<float[]> queryResults(final float[] searchVector, final int k) throws Exception {
final String responseBody = EntityUtils.toString(
searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, searchVector, k), k).getEntity()
);
Expand Down
4 changes: 1 addition & 3 deletions src/test/java/org/opensearch/knn/index/NmslibIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@

import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.containsString;

Expand Down Expand Up @@ -115,7 +113,7 @@ public void testEndToEnd() throws IOException, InterruptedException {

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.NMSLIB.score(KNNScoringUtil.l1Norm(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
Expand Down
Loading

0 comments on commit e2506c3

Please sign in to comment.