Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Switch to ByteBuffer for vector encoding. #45936

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.ParseContext;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.xpack.vectors.query.VectorDVIndexFieldData;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.xpack.vectors.query.VectorDVIndexFieldData;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.ZoneId;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -181,9 +182,11 @@ public void parse(ParseContext context) throws IOException {

// encode array of floats as array of integers and store into buf
// this code is here and not in the VectorEncoderDecoder so as not to create extra arrays
byte[] buf = indexCreatedVersion.onOrAfter(Version.V_7_4_0) ? new byte[dims * INT_BYTES + INT_BYTES] : new byte[dims * INT_BYTES];
int offset = 0;
byte[] bytes = indexCreatedVersion.onOrAfter(Version.V_7_4_0) ? new byte[dims * INT_BYTES + INT_BYTES] : new byte[dims * INT_BYTES];

ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
double dotProduct = 0f;

int dim = 0;
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
if (dim++ >= dims) {
Expand All @@ -192,28 +195,22 @@ public void parse(ParseContext context) throws IOException {
}
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()::getTokenLocation);
float value = context.parser().floatValue(true);
int intValue = Float.floatToIntBits(value);
buf[offset++] = (byte) (intValue >> 24);
buf[offset++] = (byte) (intValue >> 16);
buf[offset++] = (byte) (intValue >> 8);
buf[offset++] = (byte) intValue;

byteBuffer.putFloat(value);
dotProduct += value * value;
}
if (dim != dims) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] of doc [" +
context.sourceToParse().id() + "] has number of dimensions [" + dim +
"] less than defined in the mapping [" + dims +"]");
}

if (indexCreatedVersion.onOrAfter(Version.V_7_4_0)) {
// encode vector magnitude at the end
float vectorMagnitude = (float) Math.sqrt(dotProduct);
int vectorMagnitudeIntValue = Float.floatToIntBits(vectorMagnitude);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 24);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 16);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 8);
buf[offset++] = (byte) vectorMagnitudeIntValue;
byteBuffer.putFloat(vectorMagnitude);
}
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(buf));
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(bytes));
if (context.doc().getByKey(fieldType().name()) != null) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() +
"] doesn't not support indexing multiple values for the same field in the same document");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import org.apache.lucene.util.InPlaceMergeSorter;
import org.elasticsearch.Version;

import java.nio.ByteBuffer;

// static utility functions for encoding and decoding dense_vector and sparse_vector fields
public final class VectorEncoderDecoder {
static final byte INT_BYTES = 4;
Expand All @@ -34,36 +36,31 @@ public static BytesRef encodeSparseVector(Version indexVersion, int[] dims, floa

// 2. Encode dimensions
// as each dimension is a positive value that doesn't exceed 65535, 2 bytes is enough for encoding it
byte[] buf = indexVersion.onOrAfter(Version.V_7_4_0) ? new byte[dimCount * (INT_BYTES + SHORT_BYTES) + INT_BYTES] :
byte[] bytes = indexVersion.onOrAfter(Version.V_7_4_0) ? new byte[dimCount * (INT_BYTES + SHORT_BYTES) + INT_BYTES] :
new byte[dimCount * (INT_BYTES + SHORT_BYTES)];
int offset = 0;
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);

for (int dim = 0; dim < dimCount; dim++) {
buf[offset++] = (byte) (dims[dim] >> 8);
buf[offset++] = (byte) dims[dim];
int dimValue = dims[dim];
byteBuffer.put((byte) (dimValue >> 8));
byteBuffer.put((byte) dimValue);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if byteBuffer.putShort((short) dimValue) would be better?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it would be nice to keep this explicit shifting to match the logic for decoding. We aren't able to use getShort in decoding, because we want to interpret the two bytes as an unsigned value.

}

// 3. Encode values
double dotProduct = 0.0f;
for (int dim = 0; dim < dimCount; dim++) {
int intValue = Float.floatToIntBits(values[dim]);
buf[offset++] = (byte) (intValue >> 24);
buf[offset++] = (byte) (intValue >> 16);
buf[offset++] = (byte) (intValue >> 8);
buf[offset++] = (byte) intValue;
dotProduct += values[dim] * values[dim];
float value = values[dim];
byteBuffer.putFloat(value);
dotProduct += value * value;
}

// 4. Encode vector magnitude at the end
if (indexVersion.onOrAfter(Version.V_7_4_0)) {
float vectorMagnitude = (float) Math.sqrt(dotProduct);
int vectorMagnitudeIntValue = Float.floatToIntBits(vectorMagnitude);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 24);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 16);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 8);
buf[offset++] = (byte) vectorMagnitudeIntValue;
byteBuffer.putFloat(vectorMagnitude);
}

return new BytesRef(buf);
return new BytesRef(bytes);
}

/**
Expand All @@ -75,12 +72,14 @@ public static int[] decodeSparseVectorDims(Version indexVersion, BytesRef vector
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}

int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES) :
vectorBR.length / (INT_BYTES + SHORT_BYTES);
int offset = vectorBR.offset;
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, dimCount * SHORT_BYTES);

int[] dims = new int[dimCount];
for (int dim = 0; dim < dimCount; dim++) {
dims[dim] = ((vectorBR.bytes[offset++] & 0xFF) << 8) | (vectorBR.bytes[offset++] & 0xFF);
dims[dim] = ((byteBuffer.get() & 0xFF) << 8) | (byteBuffer.get() & 0xFF);
}
return dims;
}
Expand All @@ -94,21 +93,19 @@ public static float[] decodeSparseVector(Version indexVersion, BytesRef vectorBR
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}

int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES) :
vectorBR.length / (INT_BYTES + SHORT_BYTES);
int offset = vectorBR.offset + SHORT_BYTES * dimCount;
float[] vector = new float[dimCount];

ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, offset, dimCount * INT_BYTES);
for (int dim = 0; dim < dimCount; dim++) {
int intValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) |
((vectorBR.bytes[offset++] & 0xFF) << 16) |
((vectorBR.bytes[offset++] & 0xFF) << 8) |
(vectorBR.bytes[offset++] & 0xFF);
vector[dim] = Float.intBitsToFloat(intValue);
vector[dim] = byteBuffer.getFloat();
}
return vector;
}


/**
* Sorts dimensions in the ascending order and
* sorts values in the same order as their corresponding dimensions
Expand Down Expand Up @@ -174,15 +171,13 @@ public static float[] decodeDenseVector(Version indexVersion, BytesRef vectorBR)
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}

int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / INT_BYTES : vectorBR.length/ INT_BYTES;
int offset = vectorBR.offset;
float[] vector = new float[dimCount];

ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jtibshirani I was wondering if we plan to eventually switch to decodeAndDotProduct? If not how about using FloatBuffer instead? With FloatBuffer we get all values in bulk instead of loop.

From benchmarks:

    static float[] decodeWithFloatBuffer(BytesRef vectorBR) {
        if (vectorBR == null) {
            throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
        }
        ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
        FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
        float[] vector = new float[vectorBR.length / INT_BYTES];
        floatBuffer.get(vector);
        return vector;
    }

VectorFunctionBenchmark.decodeNoop avgt 30 42.340 ± 0.706 ns/op
VectorFunctionBenchmark.decodeWithByteBuffer avgt 30 91.239 ± 1.644 ns/op
VectorFunctionBenchmark.decodeWithFloatBuffer avgt 30 72.101 ± 1.872 ns/op

A similar comment for decodeSparseVector.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I'm planning to switch to combined decoding and dot product in an upcoming PR, as I've measured that it really helps performance.

As a side note, I also tried FloatBuffer while benchmarking. Although it helped in microbenchmarks, it actually hurt in macrobenchmarks, so I decided not to pursue it. I haven't yet done a root cause analysis of why FloatBuffer hurt overall search performance.

for (int dim = 0; dim < dimCount; dim++) {
int intValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) |
((vectorBR.bytes[offset++] & 0xFF) << 16) |
((vectorBR.bytes[offset++] & 0xFF) << 8) |
(vectorBR.bytes[offset++] & 0xFF);
vector[dim] = Float.intBitsToFloat(intValue);
vector[dim] = byteBuffer.getFloat();
}
return vector;
}
Expand All @@ -198,14 +193,10 @@ public static float getVectorMagnitude(Version indexVersion, BytesRef vectorBR,
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}

if (indexVersion.onOrAfter(Version.V_7_4_0)) { // decode vector magnitude
int offset = vectorBR.offset + vectorBR.length - 4;
int vectorMagnitudeIntValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) |
((vectorBR.bytes[offset++] & 0xFF) << 16) |
((vectorBR.bytes[offset++] & 0xFF) << 8) |
(vectorBR.bytes[offset++] & 0xFF);
float vectorMagnitude = Float.intBitsToFloat(vectorMagnitudeIntValue);
return vectorMagnitude;
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4);
} else { // calculate vector magnitude
double dotProduct = 0f;
for (int dim = 0; dim < vector.length; dim++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.Version;
import org.elasticsearch.test.ESTestCase;

import java.nio.ByteBuffer;
import java.util.HashSet;
import java.util.Set;
import java.util.Arrays;
Expand Down Expand Up @@ -137,44 +138,28 @@ public void testSparseVectorEncodingDecodingBefore7_4() {

// imitates the code in DenseVectorFieldMapper::parse
public static BytesRef mockEncodeDenseVector(float[] values) {
final short INT_BYTES = VectorEncoderDecoder.INT_BYTES;
byte[] buf = new byte[INT_BYTES * values.length + INT_BYTES];
int offset = 0;
byte[] bytes = new byte[VectorEncoderDecoder.INT_BYTES * values.length + VectorEncoderDecoder.INT_BYTES];
double dotProduct = 0f;
int intValue;
for (float value: values) {
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
for (float value : values) {
byteBuffer.putFloat(value);
dotProduct += value * value;
intValue = Float.floatToIntBits(value);
buf[offset++] = (byte) (intValue >> 24);
buf[offset++] = (byte) (intValue >> 16);
buf[offset++] = (byte) (intValue >> 8);
buf[offset++] = (byte) intValue;
}
// encode vector magnitude at the end
float vectorMagnitude = (float) Math.sqrt(dotProduct);
int vectorMagnitudeIntValue = Float.floatToIntBits(vectorMagnitude);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 24);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 16);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 8);
buf[offset++] = (byte) vectorMagnitudeIntValue;

return new BytesRef(buf);
byteBuffer.putFloat(vectorMagnitude);
return new BytesRef(bytes);
}

// imitates the code in DenseVectorFieldMapper::parse before version 7.4
public static BytesRef mockEncodeDenseVectorBefore7_4(float[] values) {
final short INT_BYTES = VectorEncoderDecoder.INT_BYTES;
byte[] buf = new byte[INT_BYTES * values.length];
int offset = 0;
int intValue;
for (float value: values) {
intValue = Float.floatToIntBits(value);
buf[offset++] = (byte) (intValue >> 24);
buf[offset++] = (byte) (intValue >> 16);
buf[offset++] = (byte) (intValue >> 8);
buf[offset++] = (byte) intValue;
byte[] bytes = new byte[VectorEncoderDecoder.INT_BYTES * values.length];
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
for (float value : values) {
byteBuffer.putFloat(value);
}
return new BytesRef(buf, 0, offset);
return new BytesRef(bytes);

}

// generate unique random dims
Expand Down