Skip to content

Commit

Permalink
Switch to ByteBuffer for vector encoding. (#45936)
Browse files Browse the repository at this point in the history
This commit updates the vector encoding and decoding logic to use
`java.nio.ByteBuffer`. Using `ByteBuffer` shows an improvement in
[microbenchmarks](jtibshirani#3), and I think it also helps code
readability. The performance gain might be due to the fact that
`ByteBuffer` uses HotSpot intrinsic candidates like `Unsafe#getIntUnaligned`
under the hood.
  • Loading branch information
jtibshirani committed Aug 30, 2019
1 parent 8e2cb70 commit da5213f
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.ParseContext;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.xpack.vectors.query.VectorDVIndexFieldData;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.xpack.vectors.query.VectorDVIndexFieldData;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.ZoneId;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -181,9 +182,11 @@ public void parse(ParseContext context) throws IOException {

// encode array of floats as array of integers and store into buf
// this code is here and not in the VectorEncoderDecoder so as not to create extra arrays
byte[] buf = indexCreatedVersion.onOrAfter(Version.V_7_4_0) ? new byte[dims * INT_BYTES + INT_BYTES] : new byte[dims * INT_BYTES];
int offset = 0;
byte[] bytes = indexCreatedVersion.onOrAfter(Version.V_7_4_0) ? new byte[dims * INT_BYTES + INT_BYTES] : new byte[dims * INT_BYTES];

ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
double dotProduct = 0f;

int dim = 0;
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
if (dim++ >= dims) {
Expand All @@ -192,28 +195,22 @@ public void parse(ParseContext context) throws IOException {
}
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()::getTokenLocation);
float value = context.parser().floatValue(true);
int intValue = Float.floatToIntBits(value);
buf[offset++] = (byte) (intValue >> 24);
buf[offset++] = (byte) (intValue >> 16);
buf[offset++] = (byte) (intValue >> 8);
buf[offset++] = (byte) intValue;

byteBuffer.putFloat(value);
dotProduct += value * value;
}
if (dim != dims) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] of doc [" +
context.sourceToParse().id() + "] has number of dimensions [" + dim +
"] less than defined in the mapping [" + dims +"]");
}

if (indexCreatedVersion.onOrAfter(Version.V_7_4_0)) {
// encode vector magnitude at the end
float vectorMagnitude = (float) Math.sqrt(dotProduct);
int vectorMagnitudeIntValue = Float.floatToIntBits(vectorMagnitude);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 24);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 16);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 8);
buf[offset++] = (byte) vectorMagnitudeIntValue;
byteBuffer.putFloat(vectorMagnitude);
}
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(buf));
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(bytes));
if (context.doc().getByKey(fieldType().name()) != null) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() +
"] doesn't not support indexing multiple values for the same field in the same document");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import org.apache.lucene.util.InPlaceMergeSorter;
import org.elasticsearch.Version;

import java.nio.ByteBuffer;

// static utility functions for encoding and decoding dense_vector and sparse_vector fields
public final class VectorEncoderDecoder {
static final byte INT_BYTES = 4;
Expand All @@ -34,36 +36,31 @@ public static BytesRef encodeSparseVector(Version indexVersion, int[] dims, floa

// 2. Encode dimensions
// as each dimension is a positive value that doesn't exceed 65535, 2 bytes is enough for encoding it
byte[] buf = indexVersion.onOrAfter(Version.V_7_4_0) ? new byte[dimCount * (INT_BYTES + SHORT_BYTES) + INT_BYTES] :
byte[] bytes = indexVersion.onOrAfter(Version.V_7_4_0) ? new byte[dimCount * (INT_BYTES + SHORT_BYTES) + INT_BYTES] :
new byte[dimCount * (INT_BYTES + SHORT_BYTES)];
int offset = 0;
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);

for (int dim = 0; dim < dimCount; dim++) {
buf[offset++] = (byte) (dims[dim] >> 8);
buf[offset++] = (byte) dims[dim];
int dimValue = dims[dim];
byteBuffer.put((byte) (dimValue >> 8));
byteBuffer.put((byte) dimValue);
}

// 3. Encode values
double dotProduct = 0.0f;
for (int dim = 0; dim < dimCount; dim++) {
int intValue = Float.floatToIntBits(values[dim]);
buf[offset++] = (byte) (intValue >> 24);
buf[offset++] = (byte) (intValue >> 16);
buf[offset++] = (byte) (intValue >> 8);
buf[offset++] = (byte) intValue;
dotProduct += values[dim] * values[dim];
float value = values[dim];
byteBuffer.putFloat(value);
dotProduct += value * value;
}

// 4. Encode vector magnitude at the end
if (indexVersion.onOrAfter(Version.V_7_4_0)) {
float vectorMagnitude = (float) Math.sqrt(dotProduct);
int vectorMagnitudeIntValue = Float.floatToIntBits(vectorMagnitude);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 24);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 16);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 8);
buf[offset++] = (byte) vectorMagnitudeIntValue;
byteBuffer.putFloat(vectorMagnitude);
}

return new BytesRef(buf);
return new BytesRef(bytes);
}

/**
Expand All @@ -75,12 +72,14 @@ public static int[] decodeSparseVectorDims(Version indexVersion, BytesRef vector
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}

int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES) :
vectorBR.length / (INT_BYTES + SHORT_BYTES);
int offset = vectorBR.offset;
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, dimCount * SHORT_BYTES);

int[] dims = new int[dimCount];
for (int dim = 0; dim < dimCount; dim++) {
dims[dim] = ((vectorBR.bytes[offset++] & 0xFF) << 8) | (vectorBR.bytes[offset++] & 0xFF);
dims[dim] = ((byteBuffer.get() & 0xFF) << 8) | (byteBuffer.get() & 0xFF);
}
return dims;
}
Expand All @@ -94,21 +93,19 @@ public static float[] decodeSparseVector(Version indexVersion, BytesRef vectorBR
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}

int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES) :
vectorBR.length / (INT_BYTES + SHORT_BYTES);
int offset = vectorBR.offset + SHORT_BYTES * dimCount;
float[] vector = new float[dimCount];

ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, offset, dimCount * INT_BYTES);
for (int dim = 0; dim < dimCount; dim++) {
int intValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) |
((vectorBR.bytes[offset++] & 0xFF) << 16) |
((vectorBR.bytes[offset++] & 0xFF) << 8) |
(vectorBR.bytes[offset++] & 0xFF);
vector[dim] = Float.intBitsToFloat(intValue);
vector[dim] = byteBuffer.getFloat();
}
return vector;
}


/**
* Sorts dimensions in the ascending order and
* sorts values in the same order as their corresponding dimensions
Expand Down Expand Up @@ -174,15 +171,13 @@ public static float[] decodeDenseVector(Version indexVersion, BytesRef vectorBR)
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}

int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / INT_BYTES : vectorBR.length/ INT_BYTES;
int offset = vectorBR.offset;
float[] vector = new float[dimCount];

ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
for (int dim = 0; dim < dimCount; dim++) {
int intValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) |
((vectorBR.bytes[offset++] & 0xFF) << 16) |
((vectorBR.bytes[offset++] & 0xFF) << 8) |
(vectorBR.bytes[offset++] & 0xFF);
vector[dim] = Float.intBitsToFloat(intValue);
vector[dim] = byteBuffer.getFloat();
}
return vector;
}
Expand All @@ -198,14 +193,10 @@ public static float getVectorMagnitude(Version indexVersion, BytesRef vectorBR,
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}

if (indexVersion.onOrAfter(Version.V_7_4_0)) { // decode vector magnitude
int offset = vectorBR.offset + vectorBR.length - 4;
int vectorMagnitudeIntValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) |
((vectorBR.bytes[offset++] & 0xFF) << 16) |
((vectorBR.bytes[offset++] & 0xFF) << 8) |
(vectorBR.bytes[offset++] & 0xFF);
float vectorMagnitude = Float.intBitsToFloat(vectorMagnitudeIntValue);
return vectorMagnitude;
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4);
} else { // calculate vector magnitude
double dotProduct = 0f;
for (int dim = 0; dim < vector.length; dim++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.Version;
import org.elasticsearch.test.ESTestCase;

import java.nio.ByteBuffer;
import java.util.HashSet;
import java.util.Set;
import java.util.Arrays;
Expand Down Expand Up @@ -137,44 +138,28 @@ public void testSparseVectorEncodingDecodingBefore7_4() {

// imitates the code in DenseVectorFieldMapper::parse
public static BytesRef mockEncodeDenseVector(float[] values) {
final short INT_BYTES = VectorEncoderDecoder.INT_BYTES;
byte[] buf = new byte[INT_BYTES * values.length + INT_BYTES];
int offset = 0;
byte[] bytes = new byte[VectorEncoderDecoder.INT_BYTES * values.length + VectorEncoderDecoder.INT_BYTES];
double dotProduct = 0f;
int intValue;
for (float value: values) {
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
for (float value : values) {
byteBuffer.putFloat(value);
dotProduct += value * value;
intValue = Float.floatToIntBits(value);
buf[offset++] = (byte) (intValue >> 24);
buf[offset++] = (byte) (intValue >> 16);
buf[offset++] = (byte) (intValue >> 8);
buf[offset++] = (byte) intValue;
}
// encode vector magnitude at the end
float vectorMagnitude = (float) Math.sqrt(dotProduct);
int vectorMagnitudeIntValue = Float.floatToIntBits(vectorMagnitude);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 24);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 16);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 8);
buf[offset++] = (byte) vectorMagnitudeIntValue;

return new BytesRef(buf);
byteBuffer.putFloat(vectorMagnitude);
return new BytesRef(bytes);
}

// imitates the code in DenseVectorFieldMapper::parse before version 7.4
public static BytesRef mockEncodeDenseVectorBefore7_4(float[] values) {
final short INT_BYTES = VectorEncoderDecoder.INT_BYTES;
byte[] buf = new byte[INT_BYTES * values.length];
int offset = 0;
int intValue;
for (float value: values) {
intValue = Float.floatToIntBits(value);
buf[offset++] = (byte) (intValue >> 24);
buf[offset++] = (byte) (intValue >> 16);
buf[offset++] = (byte) (intValue >> 8);
buf[offset++] = (byte) intValue;
byte[] bytes = new byte[VectorEncoderDecoder.INT_BYTES * values.length];
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
for (float value : values) {
byteBuffer.putFloat(value);
}
return new BytesRef(buf, 0, offset);
return new BytesRef(bytes);

}

// generate unique random dims
Expand Down

0 comments on commit da5213f

Please sign in to comment.