diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 538be9815f..9f7d522053 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -6,33 +6,21 @@ package org.opensearch.knn.index; import lombok.Getter; +import lombok.RequiredArgsConstructor; import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.util.BytesRef; import org.opensearch.ExceptionsHelper; import org.opensearch.index.fielddata.ScriptDocValues; -import org.opensearch.knn.index.codec.util.KNNVectorSerializer; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import java.io.ByteArrayInputStream; import java.io.IOException; -import java.util.Locale; - -import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; -import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; +@RequiredArgsConstructor public final class KNNVectorScriptDocValues extends ScriptDocValues { private final BinaryDocValues binaryDocValues; private final String fieldName; @Getter private final VectorDataType vectorDataType; - private boolean docExists; - - public KNNVectorScriptDocValues(BinaryDocValues binaryDocValues, String fieldName, VectorDataType vectorDataType) { - this.binaryDocValues = binaryDocValues; - this.fieldName = fieldName; - this.vectorDataType = vectorDataType; - } + private boolean docExists = false; @Override public void setNextDocId(int docId) throws IOException { @@ -55,31 +43,7 @@ public float[] getValue() { throw new IllegalStateException(errorMessage); } try { - BytesRef value = binaryDocValues.binaryValue(); - if (VectorDataType.BYTE.equals(vectorDataType)) { - float[] vector = new float[value.length]; - int i = 0; - int j = value.offset; - - while (i < value.length) { - vector[i++] = value.bytes[j++]; - } - return vector; - } else if (VectorDataType.FLOAT.equals(vectorDataType)) { - ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); - final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); - final float[] vector = vectorSerializer.byteToFloatArray(byteStream); - return vector; - } else { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "Invalid value provided for [%s] field. Supported values are [%s]", - VECTOR_DATA_TYPE_FIELD, - SUPPORTED_VECTOR_DATA_TYPES - ) - ); - } + return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue()); } catch (IOException e) { throw ExceptionsHelper.convertToOpenSearchException(e); } diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index be4d7110c4..23b374e9d0 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -11,7 +11,11 @@ import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.codec.util.KNNVectorSerializer; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import java.io.ByteArrayInputStream; import java.util.Arrays; import java.util.Locale; import java.util.Objects; @@ -31,6 +35,18 @@ public enum VectorDataType { public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction); } + + @Override + public float[] getVectorFromDocValues(BytesRef binaryValue) { + float[] vector = new float[binaryValue.length]; + int i = 0; + int j = binaryValue.offset; + + while (i < binaryValue.length) { + vector[i++] = binaryValue.bytes[j++]; + } + return vector; + } }, FLOAT("float") { @@ -39,6 +55,13 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction); } + @Override + public float[] getVectorFromDocValues(BytesRef binaryValue) { + ByteArrayInputStream byteStream = new ByteArrayInputStream(binaryValue.bytes, binaryValue.offset, binaryValue.length); + final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); + return vectorSerializer.byteToFloatArray(byteStream); + } + }; public static final String SUPPORTED_VECTOR_DATA_TYPES = Arrays.stream(VectorDataType.values()) @@ -57,6 +80,14 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio */ public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction); + /** + * Deserializes float vector from doc values binary value. + * + * @param binaryValue Binary Value of DocValues + * @return float vector deserialized from binary value + */ + public abstract float[] getVectorFromDocValues(BytesRef binaryValue); + /** * Validates if given VectorDataType is in the list of supported data types. * @param vectorDataType VectorDataType diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 5e5fdae2ce..346d4c2388 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -490,7 +490,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); - if (VectorDataType.BYTE.equals(vectorDataType)) { + if (VectorDataType.BYTE == vectorDataType) { Optional bytesArrayOptional = getBytesFromContext(context, dimension); if (!bytesArrayOptional.isPresent()) { @@ -501,7 +501,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx context.doc().add(point); addStoredFieldForVectorField(context, fieldType, name(), point.toString()); - } else if (VectorDataType.FLOAT.equals(vectorDataType)) { + } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension); if (!floatsArrayOptional.isPresent()) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index a8dc2d1104..bf331eeb3b 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -96,7 +96,7 @@ public static void validateVectorDimension(int dimension, int vectorSize) { * @param vectorDataType VectorDataType Parameter */ public static void validateVectorDataTypeWithEngine(ParametrizedFieldMapper.Parameter vectorDataType) { - if (VectorDataType.FLOAT.equals(vectorDataType.getValue())) { + if (VectorDataType.FLOAT == vectorDataType.getValue()) { return; } throw new IllegalArgumentException( @@ -123,7 +123,7 @@ public static void validateVectorDataTypeWithKnnIndexSetting( ParametrizedFieldMapper.Parameter vectorDataType ) { - if (VectorDataType.FLOAT.equals(vectorDataType.getValue())) { + if (VectorDataType.FLOAT == vectorDataType.getValue()) { return; } if (knnIndexSetting) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 4b5a73d9ac..94e42ee7c2 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -82,7 +82,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); - if (VectorDataType.BYTE.equals(vectorDataType)) { + if (VectorDataType.BYTE == vectorDataType) { Optional bytesArrayOptional = getBytesFromContext(context, dimension); if (bytesArrayOptional.isEmpty()) { return; @@ -96,7 +96,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx if (hasDocValues && vectorFieldType != null) { context.doc().add(new VectorField(name(), array, vectorFieldType)); } - } else if (VectorDataType.FLOAT.equals(vectorDataType)) { + } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension); if (floatsArrayOptional.isEmpty()) { diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index 397b6f497e..3ec1a99418 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -112,7 +112,7 @@ public static float[] convertVectorToPrimitive(Object vector, VectorDataType vec primitiveVector = new float[tmp.size()]; for (int i = 0; i < primitiveVector.length; i++) { float value = tmp.get(i).floatValue(); - if (VectorDataType.BYTE.equals(vectorDataType)) { + if (VectorDataType.BYTE == vectorDataType) { validateByteVectorValue(value); } primitiveVector[i] = value; diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java index 0a62f5e57e..130c4d8e0c 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java @@ -63,7 +63,7 @@ private static float[] toFloat(List inputVector, VectorDataType vectorDa int index = 0; for (final Number val : inputVector) { float floatValue = val.floatValue(); - if (VectorDataType.BYTE.equals(vectorDataType)) { + if (VectorDataType.BYTE == vectorDataType) { validateByteVectorValue(floatValue); } value[index++] = floatValue;