diff --git a/CHANGELOG.md b/CHANGELOG.md index 666c6df07..875e3ecc6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Added efficient filtering support for Faiss Engine ([#936](https://github.com/opensearch-project/k-NN/pull/936)) * Add Indexing Support for Lucene Byte Sized Vector ([#937](https://github.com/opensearch-project/k-NN/pull/937)) * Add Querying Support for Lucene Byte Sized Vector ([#956](https://github.com/opensearch-project/k-NN/pull/956)) +* Add DocValues Support for Lucene Byte Sized Vector ([#953](https://github.com/opensearch-project/k-NN/pull/953)) + ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index 5f522e3de..f4caa4f20 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -18,10 +18,12 @@ public class KNNVectorDVLeafFieldData implements LeafFieldData { private final LeafReader reader; private final String fieldName; + private final VectorDataType vectorDataType; - public KNNVectorDVLeafFieldData(LeafReader reader, String fieldName) { + public KNNVectorDVLeafFieldData(LeafReader reader, String fieldName, VectorDataType vectorDataType) { this.reader = reader; this.fieldName = fieldName; + this.vectorDataType = vectorDataType; } @Override @@ -38,7 +40,7 @@ public long ramBytesUsed() { public ScriptDocValues getScriptValues() { try { BinaryDocValues values = DocValues.getBinary(reader, fieldName); - return new KNNVectorScriptDocValues(values, fieldName); + return new KNNVectorScriptDocValues(values, fieldName, vectorDataType); } catch (IOException e) { throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e); } diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorIndexFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorIndexFieldData.java index 367cfae53..deef8bae1 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorIndexFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorIndexFieldData.java @@ -21,10 +21,12 @@ public class KNNVectorIndexFieldData implements IndexFieldData build(IndexFieldDataCache cache, CircuitBreakerService breakerService) { - return new KNNVectorIndexFieldData(name, valuesSourceType); + return new KNNVectorIndexFieldData(name, valuesSourceType, vectorDataType); } } } diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 0c8240dd4..9f7d52205 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -5,26 +5,22 @@ package org.opensearch.knn.index; +import lombok.Getter; +import lombok.RequiredArgsConstructor; import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.util.BytesRef; import org.opensearch.ExceptionsHelper; import org.opensearch.index.fielddata.ScriptDocValues; -import org.opensearch.knn.index.codec.util.KNNVectorSerializer; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import java.io.ByteArrayInputStream; import java.io.IOException; +@RequiredArgsConstructor public final class KNNVectorScriptDocValues extends ScriptDocValues { private final BinaryDocValues binaryDocValues; private final String fieldName; - private boolean docExists; - - public KNNVectorScriptDocValues(BinaryDocValues binaryDocValues, String fieldName) { - this.binaryDocValues = binaryDocValues; - this.fieldName = fieldName; - } + @Getter + private final VectorDataType vectorDataType; + private boolean docExists = false; @Override public void setNextDocId(int docId) throws IOException { @@ -47,11 +43,7 @@ public float[] getValue() { throw new IllegalStateException(errorMessage); } try { - BytesRef value = binaryDocValues.binaryValue(); - ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); - final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); - final float[] vector = vectorSerializer.byteToFloatArray(byteStream); - return vector; + return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue()); } catch (IOException e) { throw ExceptionsHelper.convertToOpenSearchException(e); } diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index be4d7110c..23b374e9d 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -11,7 +11,11 @@ import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.codec.util.KNNVectorSerializer; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import java.io.ByteArrayInputStream; import java.util.Arrays; import java.util.Locale; import java.util.Objects; @@ -31,6 +35,18 @@ public enum VectorDataType { public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction); } + + @Override + public float[] getVectorFromDocValues(BytesRef binaryValue) { + float[] vector = new float[binaryValue.length]; + int i = 0; + int j = binaryValue.offset; + + while (i < binaryValue.length) { + vector[i++] = binaryValue.bytes[j++]; + } + return vector; + } }, FLOAT("float") { @@ -39,6 +55,13 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction); } + @Override + public float[] getVectorFromDocValues(BytesRef binaryValue) { + ByteArrayInputStream byteStream = new ByteArrayInputStream(binaryValue.bytes, binaryValue.offset, binaryValue.length); + final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); + return vectorSerializer.byteToFloatArray(byteStream); + } + }; public static final String SUPPORTED_VECTOR_DATA_TYPES = Arrays.stream(VectorDataType.values()) @@ -57,6 +80,14 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio */ public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction); + /** + * Deserializes float vector from doc values binary value. + * + * @param binaryValue Binary Value of DocValues + * @return float vector deserialized from binary value + */ + public abstract float[] getVectorFromDocValues(BytesRef binaryValue); + /** * Validates if given VectorDataType is in the list of supported data types. * @param vectorDataType VectorDataType diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java index d56e09a3f..9e210fcd9 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java @@ -18,7 +18,7 @@ public class KNNCodecService extends CodecService { private final MapperService mapperService; public KNNCodecService(CodecServiceConfig codecServiceConfig) { - super(codecServiceConfig.getMapperService(), codecServiceConfig.getLogger()); + super(codecServiceConfig.getMapperService(), codecServiceConfig.getIndexSettings(), codecServiceConfig.getLogger()); mapperService = codecServiceConfig.getMapperService(); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 4b9980e27..346d4c238 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -45,6 +45,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.function.Supplier; @@ -52,10 +53,12 @@ import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVectorValue; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDimension; /** @@ -241,6 +244,12 @@ public KNNVectorFieldMapper build(BuilderContext context) { .build(); return new LuceneFieldMapper(createLuceneFieldMapperInput); } + + // Validates and throws exception if data_type field is set in the index mapping + // using any VectorDataType (other than float, which is default) because other + // VectorDataTypes are only supported for lucene engine. + validateVectorDataTypeWithEngine(vectorDataType); + return new MethodFieldMapper( name, mappedFieldType, @@ -286,9 +295,14 @@ public KNNVectorFieldMapper build(BuilderContext context) { this.efConstruction = LegacyFieldMapper.getEfConstruction(context.indexSettings()); } + // Validates and throws exception if index.knn is set to true in the index settings + // using any VectorDataType (other than float, which is default) because we are using NMSLIB engine for LegacyFieldMapper + // and it only supports float VectorDataType + validateVectorDataTypeWithKnnIndexSetting(context.indexSettings().getAsBoolean(KNN_INDEX, false), vectorDataType); + return new LegacyFieldMapper( name, - new KNNVectorFieldType(buildFullName(context), metaValue, dimension.getValue()), + new KNNVectorFieldType(buildFullName(context), metaValue, dimension.getValue(), vectorDataType.getValue()), multiFieldsBuilder, copyToBuilder, ignoreMalformed, @@ -348,10 +362,6 @@ public Mapper.Builder parse(String name, Map node, ParserCont throw new IllegalArgumentException(String.format("Dimension value missing for vector: %s", name)); } - // Validates and throws exception if data_type field is set in the index mapping - // using any VectorDataType (other than float, which is default) with any engine (except lucene). - validateVectorDataTypeWithEngine(builder.knnMethodContext, builder.vectorDataType); - return builder; } } @@ -363,8 +373,8 @@ public static class KNNVectorFieldType extends MappedFieldType { KNNMethodContext knnMethodContext; VectorDataType vectorDataType; - public KNNVectorFieldType(String name, Map meta, int dimension) { - this(name, meta, dimension, null, null, DEFAULT_VECTOR_DATA_TYPE_FIELD); + public KNNVectorFieldType(String name, Map meta, int dimension, VectorDataType vectorDataType) { + this(name, meta, dimension, null, null, vectorDataType); } public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext) { @@ -426,7 +436,7 @@ public Query termQuery(Object value, QueryShardContext context) { @Override public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier searchLookup) { failIfNoDocValues(); - return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES); + return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES, this.vectorDataType); } } @@ -480,16 +490,34 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); - Optional arrayOptional = getFloatsFromContext(context, dimension); + if (VectorDataType.BYTE == vectorDataType) { + Optional bytesArrayOptional = getBytesFromContext(context, dimension); - if (!arrayOptional.isPresent()) { - return; + if (!bytesArrayOptional.isPresent()) { + return; + } + final byte[] array = bytesArrayOptional.get(); + VectorField point = new VectorField(name(), array, fieldType); + + context.doc().add(point); + addStoredFieldForVectorField(context, fieldType, name(), point.toString()); + } else if (VectorDataType.FLOAT == vectorDataType) { + Optional floatsArrayOptional = getFloatsFromContext(context, dimension); + + if (!floatsArrayOptional.isPresent()) { + return; + } + final float[] array = floatsArrayOptional.get(); + VectorField point = new VectorField(name(), array, fieldType); + + context.doc().add(point); + addStoredFieldForVectorField(context, fieldType, name(), point.toString()); + } else { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) + ); } - final float[] array = arrayOptional.get(); - VectorField point = new VectorField(name(), array, fieldType); - context.doc().add(point); - addStoredFieldForVectorField(context, fieldType, name(), point.toString()); context.path().remove(); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 2784d2a33..bf331eeb3 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -16,13 +16,11 @@ import org.apache.lucene.index.DocValuesType; import org.opensearch.index.mapper.ParametrizedFieldMapper; import org.opensearch.index.mapper.ParseContext; -import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.util.Locale; -import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; @@ -92,20 +90,43 @@ public static void validateVectorDimension(int dimension, int vectorSize) { /** * Validates and throws exception if data_type field is set in the index mapping - * using any VectorDataType (other than float, which is default) with any engine (except lucene). + * using any VectorDataType (other than float, which is default) because other + * VectorDataTypes are only supported for lucene engine. * - * @param knnMethodContext KNNMethodContext Parameter * @param vectorDataType VectorDataType Parameter */ - public static void validateVectorDataTypeWithEngine( - ParametrizedFieldMapper.Parameter knnMethodContext, + public static void validateVectorDataTypeWithEngine(ParametrizedFieldMapper.Parameter vectorDataType) { + if (VectorDataType.FLOAT == vectorDataType.getValue()) { + return; + } + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field with value [%s] is only supported for [%s] engine", + VECTOR_DATA_TYPE_FIELD, + vectorDataType.getValue().getValue(), + LUCENE_NAME + ) + ); + } + + /** + * Validates and throws exception if index.knn is set to true in the index settings + * using any VectorDataType (other than float, which is default) because we are using NMSLIB engine + * for LegacyFieldMapper, and it only supports float VectorDataType + * + * @param knnIndexSetting index.knn setting in the index settings + * @param vectorDataType VectorDataType Parameter + */ + public static void validateVectorDataTypeWithKnnIndexSetting( + boolean knnIndexSetting, ParametrizedFieldMapper.Parameter vectorDataType ) { - if (vectorDataType.getValue() == DEFAULT_VECTOR_DATA_TYPE_FIELD) { + + if (VectorDataType.FLOAT == vectorDataType.getValue()) { return; } - if ((knnMethodContext.getValue() == null && KNNEngine.DEFAULT != KNNEngine.LUCENE) - || knnMethodContext.getValue().getKnnEngine() != KNNEngine.LUCENE) { + if (knnIndexSetting) { throw new IllegalArgumentException( String.format( Locale.ROOT, diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 4b5a73d9a..94e42ee7c 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -82,7 +82,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); - if (VectorDataType.BYTE.equals(vectorDataType)) { + if (VectorDataType.BYTE == vectorDataType) { Optional bytesArrayOptional = getBytesFromContext(context, dimension); if (bytesArrayOptional.isEmpty()) { return; @@ -96,7 +96,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx if (hasDocValues && vectorFieldType != null) { context.doc().add(new VectorField(name(), array, vectorFieldType)); } - } else if (VectorDataType.FLOAT.equals(vectorDataType)) { + } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension); if (floatsArrayOptional.isEmpty()) { diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index ca7526dcb..16bf6e204 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -54,7 +54,11 @@ public L2(Object query, MappedFieldType fieldType) { throw new IllegalArgumentException("Incompatible field_type for l2 space. The field type must " + "be knn_vector."); } - this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray( + query, + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() + ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); } @@ -81,7 +85,11 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) { throw new IllegalArgumentException("Incompatible field_type for cosine space. The field type must " + "be knn_vector."); } - this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray( + query, + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() + ); float qVectorSquaredMagnitude = getVectorMagnitudeSquared(this.processedQuery); this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude); } @@ -159,7 +167,11 @@ public L1(Object query, MappedFieldType fieldType) { throw new IllegalArgumentException("Incompatible field_type for l1 space. The field type must " + "be knn_vector."); } - this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray( + query, + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() + ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); } @@ -185,7 +197,11 @@ public LInf(Object query, MappedFieldType fieldType) { throw new IllegalArgumentException("Incompatible field_type for l-inf space. The field type must " + "be knn_vector."); } - this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray( + query, + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() + ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); } @@ -213,7 +229,11 @@ public InnerProd(Object query, MappedFieldType fieldType) { ); } - this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray( + query, + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() + ); this.scoringMethod = (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v)); } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index 6f68d16b6..3ec1a9941 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.script; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.index.mapper.BinaryFieldMapper; @@ -16,6 +17,7 @@ import java.util.Base64; import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; public class KNNScoringSpaceUtil { @@ -85,8 +87,8 @@ public static BigInteger parseToBigInteger(Object object) { * @param expectedDimensions int representing the expected dimension of this array. * @return float[] of the object */ - public static float[] parseToFloatArray(Object object, int expectedDimensions) { - float[] floatArray = convertVectorToPrimitive(object); + public static float[] parseToFloatArray(Object object, int expectedDimensions, VectorDataType vectorDataType) { + float[] floatArray = convertVectorToPrimitive(object, vectorDataType); if (expectedDimensions != floatArray.length) { KNNCounter.SCRIPT_QUERY_ERRORS.increment(); throw new IllegalStateException( @@ -103,13 +105,17 @@ public static float[] parseToFloatArray(Object object, int expectedDimensions) { * @return Float array representing the vector */ @SuppressWarnings("unchecked") - public static float[] convertVectorToPrimitive(Object vector) { + public static float[] convertVectorToPrimitive(Object vector, VectorDataType vectorDataType) { float[] primitiveVector = null; if (vector != null) { - final ArrayList tmp = (ArrayList) vector; + final ArrayList tmp = (ArrayList) vector; primitiveVector = new float[tmp.size()]; for (int i = 0; i < primitiveVector.length; i++) { - primitiveVector[i] = tmp.get(i).floatValue(); + float value = tmp.get(i).floatValue(); + if (VectorDataType.BYTE == vectorDataType) { + validateByteVectorValue(value); + } + primitiveVector[i] = value; } } return primitiveVector; diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java index 90468c2e7..130c4d8e0 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java @@ -8,11 +8,14 @@ import org.opensearch.knn.index.KNNVectorScriptDocValues; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.knn.index.VectorDataType; import java.math.BigInteger; import java.util.List; import java.util.Objects; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; + public class KNNScoringUtil { private static Logger logger = LogManager.getLogger(KNNScoringUtil.class); @@ -54,12 +57,16 @@ public static float l2Squared(float[] queryVector, float[] inputVector) { return squaredDistance; } - private static float[] toFloat(List inputVector) { + private static float[] toFloat(List inputVector, VectorDataType vectorDataType) { Objects.requireNonNull(inputVector); float[] value = new float[inputVector.size()]; int index = 0; for (final Number val : inputVector) { - value[index++] = val.floatValue(); + float floatValue = val.floatValue(); + if (VectorDataType.BYTE == vectorDataType) { + validateByteVectorValue(floatValue); + } + value[index++] = floatValue; } return value; } @@ -81,7 +88,7 @@ private static float[] toFloat(List inputVector) { * @return L2 score */ public static float l2Squared(List queryVector, KNNVectorScriptDocValues docValues) { - return l2Squared(toFloat(queryVector), docValues.getValue()); + return l2Squared(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); } /** @@ -127,7 +134,11 @@ public static float cosinesimilOptimized(float[] queryVector, float[] inputVecto * @return cosine score */ public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) { - return cosinesimilOptimized(toFloat(queryVector), docValues.getValue(), queryVectorMagnitude.floatValue()); + return cosinesimilOptimized( + toFloat(queryVector, docValues.getVectorDataType()), + docValues.getValue(), + queryVectorMagnitude.floatValue() + ); } /** @@ -172,7 +183,7 @@ public static float cosinesimil(float[] queryVector, float[] inputVector) { * @return cosine score */ public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues) { - return cosinesimil(toFloat(queryVector), docValues.getValue()); + return cosinesimil(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); } /** @@ -232,7 +243,7 @@ public static float l1Norm(float[] queryVector, float[] inputVector) { * @return L1 score */ public static float l1Norm(List queryVector, KNNVectorScriptDocValues docValues) { - return l1Norm(toFloat(queryVector), docValues.getValue()); + return l1Norm(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); } /** @@ -270,7 +281,7 @@ public static float lInfNorm(float[] queryVector, float[] inputVector) { * @return L-inf score */ public static float lInfNorm(List queryVector, KNNVectorScriptDocValues docValues) { - return lInfNorm(toFloat(queryVector), docValues.getValue()); + return lInfNorm(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); } /** @@ -307,6 +318,6 @@ public static float innerProduct(float[] queryVector, float[] inputVector) { * @return inner product score */ public static float innerProduct(List queryVector, KNNVectorScriptDocValues docValues) { - return innerProduct(toFloat(queryVector), docValues.getValue()); + return innerProduct(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java index 8bda1aefc..cbe11dd6b 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java @@ -62,30 +62,38 @@ public void tearDown() throws Exception { } public void testGetScriptValues() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), MOCK_INDEX_FIELD_NAME); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData( + leafReaderContext.reader(), + MOCK_INDEX_FIELD_NAME, + VectorDataType.FLOAT + ); ScriptDocValues scriptValues = leafFieldData.getScriptValues(); assertNotNull(scriptValues); assertTrue(scriptValues instanceof KNNVectorScriptDocValues); } public void testGetScriptValuesWrongFieldName() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "invalid"); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "invalid", VectorDataType.FLOAT); ScriptDocValues scriptValues = leafFieldData.getScriptValues(); assertNotNull(scriptValues); } public void testGetScriptValuesWrongFieldType() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), MOCK_NUMERIC_INDEX_FIELD_NAME); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData( + leafReaderContext.reader(), + MOCK_NUMERIC_INDEX_FIELD_NAME, + VectorDataType.FLOAT + ); expectThrows(IllegalStateException.class, () -> leafFieldData.getScriptValues()); } public void testRamBytesUsed() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), ""); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "", VectorDataType.FLOAT); assertEquals(0, leafFieldData.ramBytesUsed()); } public void testGetBytesValues() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), ""); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "", VectorDataType.FLOAT); expectThrows(UnsupportedOperationException.class, () -> leafFieldData.getBytesValues()); } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java index 8523c4146..ee57cb190 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java @@ -27,7 +27,7 @@ public class KNNVectorIndexFieldDataTests extends KNNTestCase { @Before public void setUp() throws Exception { super.setUp(); - indexFieldData = new KNNVectorIndexFieldData(MOCK_INDEX_FIELD_NAME, CoreValuesSourceType.BYTES); + indexFieldData = new KNNVectorIndexFieldData(MOCK_INDEX_FIELD_NAME, CoreValuesSourceType.BYTES, VectorDataType.FLOAT); directory = newDirectory(); createEmptyDocument(directory); } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index 876117940..a0df3ce64 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -37,7 +37,8 @@ public void setUp() throws Exception { LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); scriptDocValues = new KNNVectorScriptDocValues( leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), - MOCK_INDEX_FIELD_NAME + MOCK_INDEX_FIELD_NAME, + VectorDataType.FLOAT ); } diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java index 711160cf9..43976b901 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java @@ -8,19 +8,28 @@ import lombok.SneakyThrows; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.junit.After; +import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; import org.opensearch.common.Strings; +import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.rest.RestStatus; +import org.opensearch.script.Script; +import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; +import java.util.Map; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; @@ -37,6 +46,7 @@ public class VectorDataTypeIT extends KNNRestTestCase { private static final String KNN_VECTOR_TYPE = "knn_vector"; private static final int EF_CONSTRUCTION = 128; private static final int M = 16; + private static final QueryBuilder MATCH_ALL_QUERY_BUILDER = new MatchAllQueryBuilder(); @After @SneakyThrows @@ -248,6 +258,172 @@ public void testByteVectorDataTypeWithNmslibEngine() { ); } + @SneakyThrows + public void testByteVectorDataTypeWithLegacyFieldMapperKnnIndexSetting() { + // Create an index with byte vector data_type and index.knn as true without setting KnnMethodContext, + // which should throw an exception because the LegacyFieldMapper will use NMSLIB engine and byte data_type + // is not supported for NMSLIB engine. + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION, 2) + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue()) + .endObject() + .endObject() + .endObject(); + + String mapping = Strings.toString(builder); + + ResponseException ex = expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, mapping)); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field with value [%s] is only supported for [%s] engine", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue(), + LUCENE_NAME + ) + ) + ); + + } + + public void testDocValuesWithByteVectorDataTypeLuceneEngine() throws Exception { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + ingestL2ByteTestData(); + + Byte[] queryVector = { 1, 1 }; + Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + validateL2SearchResults(response); + } + + public void testDocValuesWithFloatVectorDataTypeLuceneEngine() throws Exception { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue()); + ingestL2FloatTestData(); + + Byte[] queryVector = { 1, 1 }; + Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + validateL2SearchResults(response); + } + + public void testL2ScriptScoreWithByteVectorDataType() throws Exception { + createKnnIndexMappingForScripting(2, VectorDataType.BYTE.getValue()); + ingestL2ByteTestData(); + + Byte[] queryVector = { 1, 1 }; + Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + validateL2SearchResults(response); + } + + public void testL2ScriptScoreWithFloatVectorDataType() throws Exception { + createKnnIndexMappingForScripting(2, VectorDataType.FLOAT.getValue()); + ingestL2FloatTestData(); + + Float[] queryVector = { 1.0f, 1.0f }; + Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + validateL2SearchResults(response); + } + + public void testL2PainlessScriptingWithByteVectorDataType() throws Exception { + createKnnIndexMappingForScripting(2, VectorDataType.BYTE.getValue()); + ingestL2ByteTestData(); + + String source = String.format("1/(1 + l2Squared([1, 1], doc['%s']))", FIELD_NAME); + Request request = constructScriptScoreContextSearchRequest( + INDEX_NAME, + MATCH_ALL_QUERY_BUILDER, + Collections.emptyMap(), + Script.DEFAULT_SCRIPT_LANG, + source, + 4 + ); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + validateL2SearchResults(response); + } + + public void testL2PainlessScriptingWithFloatVectorDataType() throws Exception { + createKnnIndexMappingForScripting(2, VectorDataType.FLOAT.getValue()); + ingestL2FloatTestData(); + + String source = String.format("1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME); + Request request = constructScriptScoreContextSearchRequest( + INDEX_NAME, + MATCH_ALL_QUERY_BUILDER, + Collections.emptyMap(), + Script.DEFAULT_SCRIPT_LANG, + source, + 4 + ); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + validateL2SearchResults(response); + } + + public void testKNNScriptScoreWithInvalidVectorDataType() { + // Set an invalid value for data_type field while creating the index for script scoring which should throw an exception + ResponseException ex = expectThrows(ResponseException.class, () -> createKnnIndexMappingForScripting(2, "invalid_data_type")); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "Invalid value provided for [%s] field. Supported values are [%s]", + VECTOR_DATA_TYPE_FIELD, + SUPPORTED_VECTOR_DATA_TYPES + ) + ) + ); + } + + public void testKNNScriptScoreWithInvalidByteQueryVector() throws Exception { + // Create an index with byte vector data_type, add docs and run a scoring script query with decimal values + // which should throw exception + createKnnIndexMappingForScripting(2, VectorDataType.BYTE.getValue()); + + Byte[] f1 = { 6, 6 }; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); + + Byte[] f2 = { 2, 2 }; + addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); + + // Construct Search Request with query vector having decimal values + Float[] queryVector = { 10.67f, 19.78f }; + Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER); + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue() + ) + ) + ); + } + @SneakyThrows private void ingestL2ByteTestData() { Byte[] b1 = { 6, 6 }; @@ -312,6 +488,40 @@ private void createKnnIndexMappingWithCustomEngine(int dimension, SpaceType spac createKnnIndex(INDEX_NAME, mapping); } + private void createKnnIndexMappingForScripting(int dimension, String vectorDataType) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION, dimension) + .field(VECTOR_DATA_TYPE_FIELD, vectorDataType) + .endObject() + .endObject() + .endObject(); + + String mapping = Strings.toString(builder); + createKnnIndex(INDEX_NAME, Settings.EMPTY, mapping); + } + + @SneakyThrows + private Request createScriptQueryRequest(Byte[] queryVector, String spaceType, QueryBuilder qb) { + Map params = new HashMap<>(); + params.put("field", FIELD_NAME); + params.put("query_value", queryVector); + params.put("space_type", spaceType); + return constructKNNScriptQueryRequest(INDEX_NAME, qb, params); + } + + @SneakyThrows + private Request createScriptQueryRequest(Float[] queryVector, String spaceType, QueryBuilder qb) { + Map params = new HashMap<>(); + params.put("field", FIELD_NAME); + params.put("query_value", queryVector); + params.put("space_type", spaceType); + return constructKNNScriptQueryRequest(INDEX_NAME, qb, params); + } + @SneakyThrows private void validateL2SearchResults(Response response) { diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java new file mode 100644 index 000000000..4423c85d8 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import lombok.SneakyThrows; +import org.apache.lucene.document.BinaryDocValuesField; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.junit.Assert; +import org.opensearch.knn.KNNTestCase; + +import java.io.IOException; + +public class VectorDataTypeTests extends KNNTestCase { + + private static final String MOCK_FLOAT_INDEX_FIELD_NAME = "test-float-index-field-name"; + private static final String MOCK_BYTE_INDEX_FIELD_NAME = "test-byte-index-field-name"; + private static final float[] SAMPLE_FLOAT_VECTOR_DATA = new float[] { 10.0f, 25.0f }; + private static final byte[] SAMPLE_BYTE_VECTOR_DATA = new byte[] { 10, 25 }; + private Directory directory; + private DirectoryReader reader; + + @SneakyThrows + public void testGetDocValuesWithFloatVectorDataType() { + KNNVectorScriptDocValues scriptDocValues = getKNNFloatVectorScriptDocValues(); + + scriptDocValues.setNextDocId(0); + Assert.assertArrayEquals(SAMPLE_FLOAT_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); + + reader.close(); + directory.close(); + } + + @SneakyThrows + public void testGetDocValuesWithByteVectorDataType() { + KNNVectorScriptDocValues scriptDocValues = getKNNByteVectorScriptDocValues(); + + scriptDocValues.setNextDocId(0); + Assert.assertArrayEquals(SAMPLE_FLOAT_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); + + reader.close(); + directory.close(); + } + + @SneakyThrows + private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() { + directory = newDirectory(); + createKNNFloatVectorDocument(directory); + reader = DirectoryReader.open(directory); + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + return new KNNVectorScriptDocValues( + leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME), + VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME, + VectorDataType.FLOAT + ); + } + + @SneakyThrows + private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() { + directory = newDirectory(); + createKNNByteVectorDocument(directory); + reader = DirectoryReader.open(directory); + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + return new KNNVectorScriptDocValues( + leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME), + VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME, + VectorDataType.BYTE + ); + } + + private void createKNNFloatVectorDocument(Directory directory) throws IOException { + IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); + IndexWriter writer = new IndexWriter(directory, conf); + Document knnDocument = new Document(); + knnDocument.add( + new BinaryDocValuesField( + MOCK_FLOAT_INDEX_FIELD_NAME, + new VectorField(MOCK_FLOAT_INDEX_FIELD_NAME, SAMPLE_FLOAT_VECTOR_DATA, new FieldType()).binaryValue() + ) + ); + writer.addDocument(knnDocument); + writer.commit(); + writer.close(); + } + + private void createKNNByteVectorDocument(Directory directory) throws IOException { + IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); + IndexWriter writer = new IndexWriter(directory, conf); + Document knnDocument = new Document(); + knnDocument.add( + new BinaryDocValuesField( + MOCK_BYTE_INDEX_FIELD_NAME, + new VectorField(MOCK_BYTE_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA, new FieldType()).binaryValue() + ) + ); + writer.addDocument(knnDocument); + writer.commit(); + writer.close(); + } +} diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java index 92fd56e45..b5bc4b95f 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java @@ -6,6 +6,7 @@ package org.opensearch.knn.plugin.script; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.NumberFieldMapper; @@ -64,11 +65,14 @@ public void testParseKNNVectorQuery() { KNNVectorFieldMapper.KNNVectorFieldType fieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(fieldType.getDimension()).thenReturn(3); - assertArrayEquals(arrayFloat, KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject, 3), 0.1f); + assertArrayEquals(arrayFloat, KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject, 3, VectorDataType.FLOAT), 0.1f); - expectThrows(IllegalStateException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject, 4)); + expectThrows( + IllegalStateException.class, + () -> KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject, 4, VectorDataType.FLOAT) + ); String invalidObject = "invalidObject"; - expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3)); + expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3, VectorDataType.FLOAT)); } } diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 49add790e..4a2bb7254 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -7,6 +7,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNVectorScriptDocValues; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; @@ -81,7 +82,7 @@ public void testGetInvalidVectorMagnitudeSquared() { public void testConvertInvalidVectorToPrimitive() { float[] primitiveVector = null; - assertEquals(primitiveVector, KNNScoringSpaceUtil.convertVectorToPrimitive(primitiveVector)); + assertEquals(primitiveVector, KNNScoringSpaceUtil.convertVectorToPrimitive(primitiveVector, VectorDataType.FLOAT)); } public void testCosineSimilQueryVectorZeroMagnitude() { @@ -243,7 +244,11 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx if (scriptDocValues == null) { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues(leafReaderContext.reader().getBinaryDocValues(fieldName), fieldName); + scriptDocValues = new KNNVectorScriptDocValues( + leafReaderContext.reader().getBinaryDocValues(fieldName), + fieldName, + VectorDataType.FLOAT + ); } return scriptDocValues; }