Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DocValues Support for Lucene Byte Sized Vector #953

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Added efficient filtering support for Faiss Engine ([#936](https://github.com/opensearch-project/k-NN/pull/936))
* Add Indexing Support for Lucene Byte Sized Vector ([#937](https://github.com/opensearch-project/k-NN/pull/937))
* Add Querying Support for Lucene Byte Sized Vector ([#956](https://github.com/opensearch-project/k-NN/pull/956))
* Add DocValues Support for Lucene Byte Sized Vector ([#953](https://github.com/opensearch-project/k-NN/pull/953))

### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ public class KNNVectorDVLeafFieldData implements LeafFieldData {

private final LeafReader reader;
private final String fieldName;
private final VectorDataType vectorDataType;

public KNNVectorDVLeafFieldData(LeafReader reader, String fieldName) {
public KNNVectorDVLeafFieldData(LeafReader reader, String fieldName, VectorDataType vectorDataType) {
this.reader = reader;
this.fieldName = fieldName;
this.vectorDataType = vectorDataType;
}

@Override
Expand All @@ -38,7 +40,7 @@ public long ramBytesUsed() {
public ScriptDocValues<float[]> getScriptValues() {
try {
BinaryDocValues values = DocValues.getBinary(reader, fieldName);
return new KNNVectorScriptDocValues(values, fieldName);
return new KNNVectorScriptDocValues(values, fieldName, vectorDataType);
} catch (IOException e) {
throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ public class KNNVectorIndexFieldData implements IndexFieldData<KNNVectorDVLeafFi

private final String fieldName;
private final ValuesSourceType valuesSourceType;
private final VectorDataType vectorDataType;

public KNNVectorIndexFieldData(String fieldName, ValuesSourceType valuesSourceType) {
public KNNVectorIndexFieldData(String fieldName, ValuesSourceType valuesSourceType, VectorDataType vectorDataType) {
this.fieldName = fieldName;
this.valuesSourceType = valuesSourceType;
this.vectorDataType = vectorDataType;
}

@Override
Expand All @@ -39,7 +41,7 @@ public ValuesSourceType getValuesSourceType() {

@Override
public KNNVectorDVLeafFieldData load(LeafReaderContext context) {
return new KNNVectorDVLeafFieldData(context.reader(), fieldName);
return new KNNVectorDVLeafFieldData(context.reader(), fieldName, vectorDataType);
}

@Override
Expand Down Expand Up @@ -70,15 +72,17 @@ public static class Builder implements IndexFieldData.Builder {

private final String name;
private final ValuesSourceType valuesSourceType;
private final VectorDataType vectorDataType;

public Builder(String name, ValuesSourceType valuesSourceType) {
public Builder(String name, ValuesSourceType valuesSourceType, VectorDataType vectorDataType) {
this.name = name;
this.valuesSourceType = valuesSourceType;
this.vectorDataType = vectorDataType;
}

@Override
public IndexFieldData<?> build(IndexFieldDataCache cache, CircuitBreakerService breakerService) {
return new KNNVectorIndexFieldData(name, valuesSourceType);
return new KNNVectorIndexFieldData(name, valuesSourceType, vectorDataType);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,22 @@

package org.opensearch.knn.index;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.util.BytesRef;
import org.opensearch.ExceptionsHelper;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.io.IOException;

@RequiredArgsConstructor
public final class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {

private final BinaryDocValues binaryDocValues;
private final String fieldName;
private boolean docExists;

public KNNVectorScriptDocValues(BinaryDocValues binaryDocValues, String fieldName) {
this.binaryDocValues = binaryDocValues;
this.fieldName = fieldName;
}
@Getter
private final VectorDataType vectorDataType;
private boolean docExists = false;

@Override
public void setNextDocId(int docId) throws IOException {
Expand All @@ -47,11 +43,7 @@ public float[] getValue() {
throw new IllegalStateException(errorMessage);
}
try {
BytesRef value = binaryDocValues.binaryValue();
ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
return vector;
return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue());
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
Expand Down
31 changes: 31 additions & 0 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
Expand All @@ -31,6 +35,18 @@ public enum VectorDataType {
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction);
}

@Override
public float[] getVectorFromDocValues(BytesRef binaryValue) {
float[] vector = new float[binaryValue.length];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use a float[] for the byte type? Shouldn't it be an int[]?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use this method in the scripting functions to retrieve the vector from docValues and calculate the score, so it doesn't make any difference whether we return it as int[] or float[]. If we returned it as int[], we would also need to overload methods and add variants of the space-type functions in ScoringUtils that accept int[].

int i = 0;
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
int j = binaryValue.offset;

while (i < binaryValue.length) {
vector[i++] = binaryValue.bytes[j++];
}
return vector;
}
},
FLOAT("float") {

Expand All @@ -39,6 +55,13 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio
return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction);
}

/**
 * Decodes a float vector from a doc-values binary value by wrapping the referenced byte
 * range in a stream and handing that stream to a {@link KNNVectorSerializer}.
 *
 * @param binaryValue binary value of doc values; only the range given by its offset and length is read
 * @return float vector produced by the serializer
 */
@Override
public float[] getVectorFromDocValues(BytesRef binaryValue) {
    final ByteArrayInputStream serializedVector = new ByteArrayInputStream(
        binaryValue.bytes,
        binaryValue.offset,
        binaryValue.length
    );
    // The serializer is obtained from the factory using this stream, and the same stream is then decoded.
    return KNNVectorSerializerFactory.getSerializerByStreamContent(serializedVector).byteToFloatArray(serializedVector);
}

};

public static final String SUPPORTED_VECTOR_DATA_TYPES = Arrays.stream(VectorDataType.values())
Expand All @@ -57,6 +80,14 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio
*/
public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction);

/**
 * Converts the binary value read from doc values into a float vector.
 * Each data type supplies its own decoding: BYTE copies every byte of the referenced range
 * into one float element, while FLOAT decodes the range through a KNNVectorSerializer.
 *
 * @param binaryValue Binary Value of DocValues; only the range described by its offset and length is read
 * @return float vector decoded from the binary value
 */
public abstract float[] getVectorFromDocValues(BytesRef binaryValue);

/**
* Validates if given VectorDataType is in the list of supported data types.
* @param vectorDataType VectorDataType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class KNNCodecService extends CodecService {
private final MapperService mapperService;

public KNNCodecService(CodecServiceConfig codecServiceConfig) {
super(codecServiceConfig.getMapperService(), codecServiceConfig.getLogger());
super(codecServiceConfig.getMapperService(), codecServiceConfig.getIndexSettings(), codecServiceConfig.getLogger());
mapperService = codecServiceConfig.getMapperService();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,20 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the changes in this file related to adding DocValues support for Lucene byte-sized vectors?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Here we add some extra validation checks with respect to the knnIndex setting and the knn engine. The changes are also needed to ingest doc values for byte vectors using script scoring.

import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDimension;

/**
Expand Down Expand Up @@ -241,6 +244,12 @@ public KNNVectorFieldMapper build(BuilderContext context) {
.build();
return new LuceneFieldMapper(createLuceneFieldMapperInput);
}

// Validates and throws exception if data_type field is set in the index mapping
// using any VectorDataType (other than float, which is default) because other
// VectorDataTypes are only supported for lucene engine.
validateVectorDataTypeWithEngine(vectorDataType);

return new MethodFieldMapper(
name,
mappedFieldType,
Expand Down Expand Up @@ -286,9 +295,14 @@ public KNNVectorFieldMapper build(BuilderContext context) {
this.efConstruction = LegacyFieldMapper.getEfConstruction(context.indexSettings());
}

// Validates and throws exception if index.knn is set to true in the index settings
// using any VectorDataType (other than float, which is default) because we are using NMSLIB engine for LegacyFieldMapper
// and it only supports float VectorDataType
validateVectorDataTypeWithKnnIndexSetting(context.indexSettings().getAsBoolean(KNN_INDEX, false), vectorDataType);

return new LegacyFieldMapper(
name,
new KNNVectorFieldType(buildFullName(context), metaValue, dimension.getValue()),
new KNNVectorFieldType(buildFullName(context), metaValue, dimension.getValue(), vectorDataType.getValue()),
multiFieldsBuilder,
copyToBuilder,
ignoreMalformed,
Expand Down Expand Up @@ -348,10 +362,6 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
throw new IllegalArgumentException(String.format("Dimension value missing for vector: %s", name));
}

// Validates and throws exception if data_type field is set in the index mapping
// using any VectorDataType (other than float, which is default) with any engine (except lucene).
validateVectorDataTypeWithEngine(builder.knnMethodContext, builder.vectorDataType);

return builder;
}
}
Expand All @@ -363,8 +373,8 @@ public static class KNNVectorFieldType extends MappedFieldType {
KNNMethodContext knnMethodContext;
VectorDataType vectorDataType;

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension) {
this(name, meta, dimension, null, null, DEFAULT_VECTOR_DATA_TYPE_FIELD);
public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, VectorDataType vectorDataType) {
this(name, meta, dimension, null, null, vectorDataType);
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext) {
Expand Down Expand Up @@ -426,7 +436,7 @@ public Query termQuery(Object value, QueryShardContext context) {
@Override
public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier<SearchLookup> searchLookup) {
failIfNoDocValues();
return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES);
return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES, this.vectorDataType);
}
}

Expand Down Expand Up @@ -480,16 +490,34 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();

Optional<float[]> arrayOptional = getFloatsFromContext(context, dimension);
if (VectorDataType.BYTE == vectorDataType) {
Optional<byte[]> bytesArrayOptional = getBytesFromContext(context, dimension);

if (!arrayOptional.isPresent()) {
return;
if (!bytesArrayOptional.isPresent()) {
return;
}
final byte[] array = bytesArrayOptional.get();
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);

if (!floatsArrayOptional.isPresent()) {
return;
}
final float[] array = floatsArrayOptional.get();
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
);
}
final float[] array = arrayOptional.get();
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
context.path().remove();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@
import org.apache.lucene.index.DocValuesType;
import org.opensearch.index.mapper.ParametrizedFieldMapper;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Locale;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
Expand Down Expand Up @@ -92,20 +90,43 @@ public static void validateVectorDimension(int dimension, int vectorSize) {

/**
* Validates and throws exception if data_type field is set in the index mapping
* using any VectorDataType (other than float, which is default) with any engine (except lucene).
* using any VectorDataType (other than float, which is default) because other
* VectorDataTypes are only supported for lucene engine.
*
* @param knnMethodContext KNNMethodContext Parameter
* @param vectorDataType VectorDataType Parameter
*/
public static void validateVectorDataTypeWithEngine(
ParametrizedFieldMapper.Parameter<KNNMethodContext> knnMethodContext,
public static void validateVectorDataTypeWithEngine(ParametrizedFieldMapper.Parameter<VectorDataType> vectorDataType) {
    final VectorDataType configuredType = vectorDataType.getValue();
    // Only a non-float data type is rejected here; float passes through without any check.
    if (VectorDataType.FLOAT != configuredType) {
        final String message = String.format(
            Locale.ROOT,
            "[%s] field with value [%s] is only supported for [%s] engine",
            VECTOR_DATA_TYPE_FIELD,
            configuredType.getValue(),
            LUCENE_NAME
        );
        throw new IllegalArgumentException(message);
    }
}

/**
* Validates and throws exception if index.knn is set to true in the index settings
* using any VectorDataType (other than float, which is default) because we are using NMSLIB engine
* for LegacyFieldMapper, and it only supports float VectorDataType
*
* @param knnIndexSetting index.knn setting in the index settings
* @param vectorDataType VectorDataType Parameter
*/
public static void validateVectorDataTypeWithKnnIndexSetting(
boolean knnIndexSetting,
ParametrizedFieldMapper.Parameter<VectorDataType> vectorDataType
) {
if (vectorDataType.getValue() == DEFAULT_VECTOR_DATA_TYPE_FIELD) {

if (VectorDataType.FLOAT == vectorDataType.getValue()) {
return;
}
if ((knnMethodContext.getValue() == null && KNNEngine.DEFAULT != KNNEngine.LUCENE)
|| knnMethodContext.getValue().getKnnEngine() != KNNEngine.LUCENE) {
if (knnIndexSetting) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();

if (VectorDataType.BYTE.equals(vectorDataType)) {
if (VectorDataType.BYTE == vectorDataType) {
Optional<byte[]> bytesArrayOptional = getBytesFromContext(context, dimension);
if (bytesArrayOptional.isEmpty()) {
return;
Expand All @@ -96,7 +96,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
if (hasDocValues && vectorFieldType != null) {
context.doc().add(new VectorField(name(), array, vectorFieldType));
}
} else if (VectorDataType.FLOAT.equals(vectorDataType)) {
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);

if (floatsArrayOptional.isEmpty()) {
Expand Down
Loading