Skip to content

Commit

Permalink
Address Review Comments
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 11, 2023
1 parent 91c01e0 commit 598ab43
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,21 @@
package org.opensearch.knn.index;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.util.BytesRef;
import org.opensearch.ExceptionsHelper;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.Locale;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;

@RequiredArgsConstructor
public final class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {

private final BinaryDocValues binaryDocValues;
private final String fieldName;
@Getter
private final VectorDataType vectorDataType;
private boolean docExists;

public KNNVectorScriptDocValues(BinaryDocValues binaryDocValues, String fieldName, VectorDataType vectorDataType) {
this.binaryDocValues = binaryDocValues;
this.fieldName = fieldName;
this.vectorDataType = vectorDataType;
}
private boolean docExists = false;

@Override
public void setNextDocId(int docId) throws IOException {
Expand All @@ -55,31 +43,7 @@ public float[] getValue() {
throw new IllegalStateException(errorMessage);
}
try {
BytesRef value = binaryDocValues.binaryValue();
if (VectorDataType.BYTE.equals(vectorDataType)) {
float[] vector = new float[value.length];
int i = 0;
int j = value.offset;

while (i < value.length) {
vector[i++] = value.bytes[j++];
}
return vector;
} else if (VectorDataType.FLOAT.equals(vectorDataType)) {
ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
return vector;
} else {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Invalid value provided for [%s] field. Supported values are [%s]",
VECTOR_DATA_TYPE_FIELD,
SUPPORTED_VECTOR_DATA_TYPES
)
);
}
return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue());
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
Expand Down
31 changes: 31 additions & 0 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
Expand All @@ -31,6 +35,18 @@ public enum VectorDataType {
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction);
}

/**
 * Builds a float vector from a byte-typed doc value: every byte in the
 * referenced slice becomes one vector component.
 *
 * @param binaryValue binary value of doc values; only the slice starting at
 *                    {@code offset} and spanning {@code length} bytes is read
 * @return float array of size {@code binaryValue.length} holding the widened bytes
 */
@Override
public float[] getVectorFromDocValues(BytesRef binaryValue) {
    final int dimension = binaryValue.length;
    final float[] result = new float[dimension];
    // The slice does not necessarily start at index 0 of the backing array,
    // so every read is shifted by the slice offset.
    for (int pos = 0; pos < dimension; pos++) {
        result[pos] = binaryValue.bytes[binaryValue.offset + pos];
    }
    return result;
}
},
FLOAT("float") {

Expand All @@ -39,6 +55,13 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio
return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction);
}

/**
 * Builds a float vector from a float-typed doc value by handing the
 * referenced byte slice to a serializer obtained from
 * {@link KNNVectorSerializerFactory}.
 *
 * @param binaryValue binary value of doc values; only the slice starting at
 *                    {@code offset} and spanning {@code length} bytes is exposed to the serializer
 * @return float array produced by the selected serializer
 */
@Override
public float[] getVectorFromDocValues(BytesRef binaryValue) {
    final ByteArrayInputStream serializedVector = new ByteArrayInputStream(
        binaryValue.bytes,
        binaryValue.offset,
        binaryValue.length
    );
    // The same stream is passed twice: first to pick the serializer from the
    // stream content, then to the chosen serializer for the actual decoding.
    return KNNVectorSerializerFactory.getSerializerByStreamContent(serializedVector).byteToFloatArray(serializedVector);
}

};

public static final String SUPPORTED_VECTOR_DATA_TYPES = Arrays.stream(VectorDataType.values())
Expand All @@ -57,6 +80,14 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio
*/
public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction);

/**
 * Deserializes a float vector from the binary value stored in doc values.
 * The method is abstract, so each data type constant supplies its own
 * decoding of the bytes.
 *
 * @param binaryValue binary value of doc values; implementations read the slice
 *                    described by its {@code bytes}, {@code offset} and {@code length}
 * @return float vector deserialized from the binary value
 */
public abstract float[] getVectorFromDocValues(BytesRef binaryValue);

/**
* Validates if given VectorDataType is in the list of supported data types.
* @param vectorDataType VectorDataType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();

if (VectorDataType.BYTE.equals(vectorDataType)) {
if (VectorDataType.BYTE == vectorDataType) {
Optional<byte[]> bytesArrayOptional = getBytesFromContext(context, dimension);

if (!bytesArrayOptional.isPresent()) {
Expand All @@ -501,7 +501,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else if (VectorDataType.FLOAT.equals(vectorDataType)) {
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);

if (!floatsArrayOptional.isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public static void validateVectorDimension(int dimension, int vectorSize) {
* @param vectorDataType VectorDataType Parameter
*/
public static void validateVectorDataTypeWithEngine(ParametrizedFieldMapper.Parameter<VectorDataType> vectorDataType) {
if (VectorDataType.FLOAT.equals(vectorDataType.getValue())) {
if (VectorDataType.FLOAT == vectorDataType.getValue()) {
return;
}
throw new IllegalArgumentException(
Expand All @@ -123,7 +123,7 @@ public static void validateVectorDataTypeWithKnnIndexSetting(
ParametrizedFieldMapper.Parameter<VectorDataType> vectorDataType
) {

if (VectorDataType.FLOAT.equals(vectorDataType.getValue())) {
if (VectorDataType.FLOAT == vectorDataType.getValue()) {
return;
}
if (knnIndexSetting) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();

if (VectorDataType.BYTE.equals(vectorDataType)) {
if (VectorDataType.BYTE == vectorDataType) {
Optional<byte[]> bytesArrayOptional = getBytesFromContext(context, dimension);
if (bytesArrayOptional.isEmpty()) {
return;
Expand All @@ -96,7 +96,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
if (hasDocValues && vectorFieldType != null) {
context.doc().add(new VectorField(name(), array, vectorFieldType));
}
} else if (VectorDataType.FLOAT.equals(vectorDataType)) {
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);

if (floatsArrayOptional.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public static float[] convertVectorToPrimitive(Object vector, VectorDataType vec
primitiveVector = new float[tmp.size()];
for (int i = 0; i < primitiveVector.length; i++) {
float value = tmp.get(i).floatValue();
if (VectorDataType.BYTE.equals(vectorDataType)) {
if (VectorDataType.BYTE == vectorDataType) {
validateByteVectorValue(value);
}
primitiveVector[i] = value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ private static float[] toFloat(List<Number> inputVector, VectorDataType vectorDa
int index = 0;
for (final Number val : inputVector) {
float floatValue = val.floatValue();
if (VectorDataType.BYTE.equals(vectorDataType)) {
if (VectorDataType.BYTE == vectorDataType) {
validateByteVectorValue(floatValue);
}
value[index++] = floatValue;
Expand Down

0 comments on commit 598ab43

Please sign in to comment.