diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 538be9815f..9f7d522053 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -6,33 +6,21 @@ package org.opensearch.knn.index; import lombok.Getter; +import lombok.RequiredArgsConstructor; import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.util.BytesRef; import org.opensearch.ExceptionsHelper; import org.opensearch.index.fielddata.ScriptDocValues; -import org.opensearch.knn.index.codec.util.KNNVectorSerializer; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import java.io.ByteArrayInputStream; import java.io.IOException; -import java.util.Locale; - -import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; -import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; +@RequiredArgsConstructor public final class KNNVectorScriptDocValues extends ScriptDocValues { private final BinaryDocValues binaryDocValues; private final String fieldName; @Getter private final VectorDataType vectorDataType; - private boolean docExists; - - public KNNVectorScriptDocValues(BinaryDocValues binaryDocValues, String fieldName, VectorDataType vectorDataType) { - this.binaryDocValues = binaryDocValues; - this.fieldName = fieldName; - this.vectorDataType = vectorDataType; - } + private boolean docExists = false; @Override public void setNextDocId(int docId) throws IOException { @@ -55,31 +43,7 @@ public float[] getValue() { throw new IllegalStateException(errorMessage); } try { - BytesRef value = binaryDocValues.binaryValue(); - if (VectorDataType.BYTE.equals(vectorDataType)) { - float[] vector = new float[value.length]; - int i = 0; - int j = value.offset; - - while (i < value.length) { - vector[i++] = value.bytes[j++]; - } - return vector; - } else if (VectorDataType.FLOAT.equals(vectorDataType)) { - ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); - final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); - final float[] vector = vectorSerializer.byteToFloatArray(byteStream); - return vector; - } else { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "Invalid value provided for [%s] field. Supported values are [%s]", - VECTOR_DATA_TYPE_FIELD, - SUPPORTED_VECTOR_DATA_TYPES - ) - ); - } + return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue()); } catch (IOException e) { throw ExceptionsHelper.convertToOpenSearchException(e); } diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index be4d7110c4..23b374e9d0 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -11,7 +11,11 @@ import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.codec.util.KNNVectorSerializer; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import java.io.ByteArrayInputStream; import java.util.Arrays; import java.util.Locale; import java.util.Objects; @@ -31,6 +35,18 @@ public enum VectorDataType { public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction); } + + @Override + public float[] getVectorFromDocValues(BytesRef binaryValue) { + float[] vector = new float[binaryValue.length]; + int i = 0; + int j = binaryValue.offset; + + while (i < binaryValue.length) { + vector[i++] = binaryValue.bytes[j++]; + } + return vector; + } }, FLOAT("float") { @@ -39,6 +55,13 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction); } + @Override + public float[] getVectorFromDocValues(BytesRef binaryValue) { + ByteArrayInputStream byteStream = new ByteArrayInputStream(binaryValue.bytes, binaryValue.offset, binaryValue.length); + final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); + return vectorSerializer.byteToFloatArray(byteStream); + } + }; public static final String SUPPORTED_VECTOR_DATA_TYPES = Arrays.stream(VectorDataType.values()) @@ -57,6 +80,14 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio */ public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction); + /** + * Deserializes float vector from doc values binary value. + * + * @param binaryValue Binary Value of DocValues + * @return float vector deserialized from binary value + */ + public abstract float[] getVectorFromDocValues(BytesRef binaryValue); + /** * Validates if given VectorDataType is in the list of supported data types. * @param vectorDataType VectorDataType