From 32762a5c42c30cce0c9687f99b3e501890ae32ce Mon Sep 17 00:00:00 2001 From: panguixin Date: Fri, 22 Mar 2024 00:22:23 +0800 Subject: [PATCH] support script score when doc value is disabled Signed-off-by: panguixin --- .../knn/index/KNNVectorDVLeafFieldData.java | 21 ++++- .../knn/index/KNNVectorScriptDocValues.java | 82 ++++++++++++++++--- .../index/KNNVectorScriptDocValuesTests.java | 2 +- .../knn/index/VectorDataTypeTests.java | 4 +- .../plugin/script/KNNScoringUtilTests.java | 2 +- 5 files changed, 92 insertions(+), 19 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index f4caa4f20..8dc8d1f85 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -5,9 +5,11 @@ package org.opensearch.knn.index; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.index.fielddata.LeafFieldData; import org.opensearch.index.fielddata.ScriptDocValues; import org.opensearch.index.fielddata.SortedBinaryDocValues; @@ -39,10 +41,21 @@ public long ramBytesUsed() { @Override public ScriptDocValues getScriptValues() { try { - BinaryDocValues values = DocValues.getBinary(reader, fieldName); - return new KNNVectorScriptDocValues(values, fieldName, vectorDataType); + DocIdSetIterator values = null; + FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName); + System.out.println(fieldInfo); + if (fieldInfo.hasVectorValues()) { + values = fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32 + ? reader.getFloatVectorValues(fieldName) + : reader.getByteVectorValues(fieldName); + System.out.println("use vector values"); + } else { + values = DocValues.getBinary(reader, fieldName); + System.out.println("use binary values"); + } + return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType); } catch (IOException e) { - throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e); + throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e); } } diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 9f7d52205..496997bc6 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -5,18 +5,21 @@ package org.opensearch.knn.index; +import java.io.IOException; +import lombok.AccessLevel; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.ExceptionsHelper; import org.opensearch.index.fielddata.ScriptDocValues; -import java.io.IOException; - -@RequiredArgsConstructor -public final class KNNVectorScriptDocValues extends ScriptDocValues { +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +public abstract class KNNVectorScriptDocValues extends ScriptDocValues { - private final BinaryDocValues binaryDocValues; + private final DocIdSetIterator vectorValues; private final String fieldName; @Getter private final VectorDataType vectorDataType; @@ -24,11 +27,7 @@ public final class KNNVectorScriptDocValues extends ScriptDocValues { @Override public void setNextDocId(int docId) throws IOException { - if (binaryDocValues.advanceExact(docId)) { - docExists = true; - return; - } - docExists = false; + docExists = vectorValues.docID() == docId || vectorValues.advance(docId) == docId; } public float[] getValue() { @@ -43,12 +42,14 @@ public float[] getValue() { throw new IllegalStateException(errorMessage); } try { - return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue()); + return doGetValue(); } catch (IOException e) { throw ExceptionsHelper.convertToOpenSearchException(e); } } + protected abstract float[] doGetValue() throws IOException; + @Override public int size() { return docExists ? 1 : 0; @@ -58,4 +59,63 @@ public int size() { public float[] get(int i) { throw new UnsupportedOperationException("knn vector does not support this operation"); } + + public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) { + if (values instanceof ByteVectorValues) { + return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType); + } else if (values instanceof FloatVectorValues) { + return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType); + } else if (values instanceof BinaryDocValues) { + return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType); + } else { + throw new IllegalArgumentException("Unsupported values type: " + values.getClass()); + } + } + + private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues { + private final ByteVectorValues values; + + KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected float[] doGetValue() throws IOException { + byte[] bytes = values.vectorValue(); + float[] value = new float[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + value[i] = (float) bytes[i]; + } + return value; + } + } + + private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues { + private final FloatVectorValues values; + + KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected float[] doGetValue() throws IOException { + return values.vectorValue(); + } + } + + private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues { + private final BinaryDocValues values; + + KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected float[] doGetValue() throws IOException { + return getVectorDataType().getVectorFromDocValues(values.binaryValue()); + } + } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index a0df3ce64..2e3531b18 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -35,7 +35,7 @@ public void setUp() throws Exception { createKNNVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues( + scriptDocValues = KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java index 4423c85d8..19270717d 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -57,7 +57,7 @@ private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() { createKNNFloatVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return new KNNVectorScriptDocValues( + return KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME, VectorDataType.FLOAT @@ -70,7 +70,7 @@ private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() { createKNNByteVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return new KNNVectorScriptDocValues( + return KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME, VectorDataType.BYTE diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 8c43a4acf..22110accd 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -280,7 +280,7 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx if (scriptDocValues == null) { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues( + scriptDocValues = KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(fieldName), fieldName, VectorDataType.FLOAT