Skip to content

Commit

Permalink
Add DocValues Support for Lucene Byte Sized Vector (#953)
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 12, 2023
1 parent 77db3ab commit c7565ac
Show file tree
Hide file tree
Showing 18 changed files with 532 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ public class KNNVectorDVLeafFieldData implements LeafFieldData {

private final LeafReader reader;
private final String fieldName;
private final VectorDataType vectorDataType;

public KNNVectorDVLeafFieldData(LeafReader reader, String fieldName) {
public KNNVectorDVLeafFieldData(LeafReader reader, String fieldName, VectorDataType vectorDataType) {
this.reader = reader;
this.fieldName = fieldName;
this.vectorDataType = vectorDataType;
}

@Override
Expand All @@ -38,7 +40,7 @@ public long ramBytesUsed() {
public ScriptDocValues<float[]> getScriptValues() {
try {
BinaryDocValues values = DocValues.getBinary(reader, fieldName);
return new KNNVectorScriptDocValues(values, fieldName);
return new KNNVectorScriptDocValues(values, fieldName, vectorDataType);
} catch (IOException e) {
throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ public class KNNVectorIndexFieldData implements IndexFieldData<KNNVectorDVLeafFi

private final String fieldName;
private final ValuesSourceType valuesSourceType;
private final VectorDataType vectorDataType;

public KNNVectorIndexFieldData(String fieldName, ValuesSourceType valuesSourceType) {
public KNNVectorIndexFieldData(String fieldName, ValuesSourceType valuesSourceType, VectorDataType vectorDataType) {
this.fieldName = fieldName;
this.valuesSourceType = valuesSourceType;
this.vectorDataType = vectorDataType;
}

@Override
Expand All @@ -39,7 +41,7 @@ public ValuesSourceType getValuesSourceType() {

@Override
public KNNVectorDVLeafFieldData load(LeafReaderContext context) {
return new KNNVectorDVLeafFieldData(context.reader(), fieldName);
return new KNNVectorDVLeafFieldData(context.reader(), fieldName, vectorDataType);
}

@Override
Expand Down Expand Up @@ -70,15 +72,17 @@ public static class Builder implements IndexFieldData.Builder {

private final String name;
private final ValuesSourceType valuesSourceType;
private final VectorDataType vectorDataType;

public Builder(String name, ValuesSourceType valuesSourceType) {
public Builder(String name, ValuesSourceType valuesSourceType, VectorDataType vectorDataType) {
this.name = name;
this.valuesSourceType = valuesSourceType;
this.vectorDataType = vectorDataType;
}

@Override
public IndexFieldData<?> build(IndexFieldDataCache cache, CircuitBreakerService breakerService) {
return new KNNVectorIndexFieldData(name, valuesSourceType);
return new KNNVectorIndexFieldData(name, valuesSourceType, vectorDataType);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,22 @@

package org.opensearch.knn.index;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.util.BytesRef;
import org.opensearch.ExceptionsHelper;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.io.IOException;

@RequiredArgsConstructor
public final class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {

private final BinaryDocValues binaryDocValues;
private final String fieldName;
private boolean docExists;

public KNNVectorScriptDocValues(BinaryDocValues binaryDocValues, String fieldName) {
this.binaryDocValues = binaryDocValues;
this.fieldName = fieldName;
}
@Getter
private final VectorDataType vectorDataType;
private boolean docExists = false;

@Override
public void setNextDocId(int docId) throws IOException {
Expand All @@ -47,11 +43,7 @@ public float[] getValue() {
throw new IllegalStateException(errorMessage);
}
try {
BytesRef value = binaryDocValues.binaryValue();
ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
return vector;
return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue());
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
Expand Down
31 changes: 31 additions & 0 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
Expand All @@ -31,6 +35,18 @@ public enum VectorDataType {
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction);
}

@Override
public float[] getVectorFromDocValues(BytesRef binaryValue) {
float[] vector = new float[binaryValue.length];
int i = 0;
int j = binaryValue.offset;

while (i < binaryValue.length) {
vector[i++] = binaryValue.bytes[j++];
}
return vector;
}
},
FLOAT("float") {

Expand All @@ -39,6 +55,13 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio
return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction);
}

@Override
public float[] getVectorFromDocValues(BytesRef binaryValue) {
ByteArrayInputStream byteStream = new ByteArrayInputStream(binaryValue.bytes, binaryValue.offset, binaryValue.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
return vectorSerializer.byteToFloatArray(byteStream);
}

};

public static final String SUPPORTED_VECTOR_DATA_TYPES = Arrays.stream(VectorDataType.values())
Expand All @@ -57,6 +80,14 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio
*/
public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction);

/**
* Deserializes float vector from doc values binary value.
*
* @param binaryValue Binary Value of DocValues
* @return float vector deserialized from binary value
*/
public abstract float[] getVectorFromDocValues(BytesRef binaryValue);

/**
* Validates if given VectorDataType is in the list of supported data types.
* @param vectorDataType VectorDataType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,20 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDimension;

/**
Expand Down Expand Up @@ -241,6 +244,12 @@ public KNNVectorFieldMapper build(BuilderContext context) {
.build();
return new LuceneFieldMapper(createLuceneFieldMapperInput);
}

// Validates and throws exception if data_type field is set in the index mapping
// using any VectorDataType (other than float, which is default) because other
// VectorDataTypes are only supported for lucene engine.
validateVectorDataTypeWithEngine(vectorDataType);

return new MethodFieldMapper(
name,
mappedFieldType,
Expand Down Expand Up @@ -286,9 +295,14 @@ public KNNVectorFieldMapper build(BuilderContext context) {
this.efConstruction = LegacyFieldMapper.getEfConstruction(context.indexSettings());
}

// Validates and throws exception if index.knn is set to true in the index settings
// using any VectorDataType (other than float, which is default) because we are using NMSLIB engine for LegacyFieldMapper
// and it only supports float VectorDataType
validateVectorDataTypeWithKnnIndexSetting(context.indexSettings().getAsBoolean(KNN_INDEX, false), vectorDataType);

return new LegacyFieldMapper(
name,
new KNNVectorFieldType(buildFullName(context), metaValue, dimension.getValue()),
new KNNVectorFieldType(buildFullName(context), metaValue, dimension.getValue(), vectorDataType.getValue()),
multiFieldsBuilder,
copyToBuilder,
ignoreMalformed,
Expand Down Expand Up @@ -348,10 +362,6 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
throw new IllegalArgumentException(String.format("Dimension value missing for vector: %s", name));
}

// Validates and throws exception if data_type field is set in the index mapping
// using any VectorDataType (other than float, which is default) with any engine (except lucene).
validateVectorDataTypeWithEngine(builder.knnMethodContext, builder.vectorDataType);

return builder;
}
}
Expand All @@ -363,8 +373,8 @@ public static class KNNVectorFieldType extends MappedFieldType {
KNNMethodContext knnMethodContext;
VectorDataType vectorDataType;

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension) {
this(name, meta, dimension, null, null, DEFAULT_VECTOR_DATA_TYPE_FIELD);
public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, VectorDataType vectorDataType) {
this(name, meta, dimension, null, null, vectorDataType);
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext) {
Expand Down Expand Up @@ -426,7 +436,7 @@ public Query termQuery(Object value, QueryShardContext context) {
@Override
public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier<SearchLookup> searchLookup) {
failIfNoDocValues();
return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES);
return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES, this.vectorDataType);
}
}

Expand Down Expand Up @@ -480,16 +490,34 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();

Optional<float[]> arrayOptional = getFloatsFromContext(context, dimension);
if (VectorDataType.BYTE == vectorDataType) {
Optional<byte[]> bytesArrayOptional = getBytesFromContext(context, dimension);

if (!arrayOptional.isPresent()) {
return;
if (!bytesArrayOptional.isPresent()) {
return;
}
final byte[] array = bytesArrayOptional.get();
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);

if (!floatsArrayOptional.isPresent()) {
return;
}
final float[] array = floatsArrayOptional.get();
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
);
}
final float[] array = arrayOptional.get();
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
context.path().remove();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@
import org.apache.lucene.index.DocValuesType;
import org.opensearch.index.mapper.ParametrizedFieldMapper;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Locale;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
Expand Down Expand Up @@ -92,20 +90,43 @@ public static void validateVectorDimension(int dimension, int vectorSize) {

/**
* Validates and throws exception if data_type field is set in the index mapping
* using any VectorDataType (other than float, which is default) with any engine (except lucene).
* using any VectorDataType (other than float, which is default) because other
* VectorDataTypes are only supported for lucene engine.
*
* @param knnMethodContext KNNMethodContext Parameter
* @param vectorDataType VectorDataType Parameter
*/
public static void validateVectorDataTypeWithEngine(
ParametrizedFieldMapper.Parameter<KNNMethodContext> knnMethodContext,
public static void validateVectorDataTypeWithEngine(ParametrizedFieldMapper.Parameter<VectorDataType> vectorDataType) {
if (VectorDataType.FLOAT == vectorDataType.getValue()) {
return;
}
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"[%s] field with value [%s] is only supported for [%s] engine",
VECTOR_DATA_TYPE_FIELD,
vectorDataType.getValue().getValue(),
LUCENE_NAME
)
);
}

/**
* Validates and throws exception if index.knn is set to true in the index settings
* using any VectorDataType (other than float, which is default) because we are using NMSLIB engine
* for LegacyFieldMapper, and it only supports float VectorDataType
*
* @param knnIndexSetting index.knn setting in the index settings
* @param vectorDataType VectorDataType Parameter
*/
public static void validateVectorDataTypeWithKnnIndexSetting(
boolean knnIndexSetting,
ParametrizedFieldMapper.Parameter<VectorDataType> vectorDataType
) {
if (vectorDataType.getValue() == DEFAULT_VECTOR_DATA_TYPE_FIELD) {

if (VectorDataType.FLOAT == vectorDataType.getValue()) {
return;
}
if ((knnMethodContext.getValue() == null && KNNEngine.DEFAULT != KNNEngine.LUCENE)
|| knnMethodContext.getValue().getKnnEngine() != KNNEngine.LUCENE) {
if (knnIndexSetting) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();

if (VectorDataType.BYTE.equals(vectorDataType)) {
if (VectorDataType.BYTE == vectorDataType) {
Optional<byte[]> bytesArrayOptional = getBytesFromContext(context, dimension);
if (bytesArrayOptional.isEmpty()) {
return;
Expand All @@ -96,7 +96,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
if (hasDocValues && vectorFieldType != null) {
context.doc().add(new VectorField(name(), array, vectorFieldType));
}
} else if (VectorDataType.FLOAT.equals(vectorDataType)) {
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);

if (floatsArrayOptional.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ protected Query doToQuery(QueryShardContext context) {
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)
&& filter != null
&& !KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) {
&& filter != null
&& !KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) {
throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine));
}

Expand Down
Loading

0 comments on commit c7565ac

Please sign in to comment.