Skip to content

Commit

Permalink
Address Review Comments
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 7, 2023
1 parent 1fcba06 commit 4c24661
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,21 @@
package org.opensearch.knn.index;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.util.BytesRef;
import org.opensearch.ExceptionsHelper;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.Locale;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;

@RequiredArgsConstructor
public final class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {

private final BinaryDocValues binaryDocValues;
private final String fieldName;
@Getter
private final VectorDataType vectorDataType;
private boolean docExists;

public KNNVectorScriptDocValues(BinaryDocValues binaryDocValues, String fieldName, VectorDataType vectorDataType) {
this.binaryDocValues = binaryDocValues;
this.fieldName = fieldName;
this.vectorDataType = vectorDataType;
}
private boolean docExists = false;

@Override
public void setNextDocId(int docId) throws IOException {
Expand All @@ -55,31 +43,7 @@ public float[] getValue() {
throw new IllegalStateException(errorMessage);
}
try {
BytesRef value = binaryDocValues.binaryValue();
if (VectorDataType.BYTE.equals(vectorDataType)) {
float[] vector = new float[value.length];
int i = 0;
int j = value.offset;

while (i < value.length) {
vector[i++] = value.bytes[j++];
}
return vector;
} else if (VectorDataType.FLOAT.equals(vectorDataType)) {
ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
return vector;
} else {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Invalid value provided for [%s] field. Supported values are [%s]",
VECTOR_DATA_TYPE_FIELD,
SUPPORTED_VECTOR_DATA_TYPES
)
);
}
return vectorDataType.getValue(binaryDocValues.binaryValue());
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
Expand Down
31 changes: 31 additions & 0 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
Expand All @@ -31,6 +35,18 @@ public enum VectorDataType {
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction);
}

@Override
public float[] getValue(BytesRef binaryValue) {
float[] vector = new float[binaryValue.length];
int i = 0;
int j = binaryValue.offset;

while (i < binaryValue.length) {
vector[i++] = binaryValue.bytes[j++];
}
return vector;
}
},
FLOAT("float") {

Expand All @@ -39,6 +55,13 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio
return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction);
}

@Override
public float[] getValue(BytesRef binaryValue) {
ByteArrayInputStream byteStream = new ByteArrayInputStream(binaryValue.bytes, binaryValue.offset, binaryValue.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
return vectorSerializer.byteToFloatArray(byteStream);
}

};

public static final String SUPPORTED_VECTOR_DATA_TYPES = Arrays.stream(VectorDataType.values())
Expand All @@ -57,6 +80,14 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio
*/
public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction);

/**
* Deserializes float vector from doc values binary value.
*
* @param binaryValue Binary Value of DocValues
* @return float vector deserialized from binary value
*/
public abstract float[] getValue(BytesRef binaryValue);

/**
* Validates if given VectorDataType is in the list of supported data types.
* @param vectorDataType VectorDataType
Expand Down

0 comments on commit 4c24661

Please sign in to comment.