-
Notifications
You must be signed in to change notification settings - Fork 24.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Switch to ByteBuffer for vector encoding. #45936
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,8 @@ | |
import org.apache.lucene.util.InPlaceMergeSorter; | ||
import org.elasticsearch.Version; | ||
|
||
import java.nio.ByteBuffer; | ||
|
||
// static utility functions for encoding and decoding dense_vector and sparse_vector fields | ||
public final class VectorEncoderDecoder { | ||
static final byte INT_BYTES = 4; | ||
|
@@ -34,36 +36,31 @@ public static BytesRef encodeSparseVector(Version indexVersion, int[] dims, floa | |
|
||
// 2. Encode dimensions | ||
// as each dimension is a positive value that doesn't exceed 65535, 2 bytes is enough for encoding it | ||
byte[] buf = indexVersion.onOrAfter(Version.V_7_4_0) ? new byte[dimCount * (INT_BYTES + SHORT_BYTES) + INT_BYTES] : | ||
byte[] bytes = indexVersion.onOrAfter(Version.V_7_4_0) ? new byte[dimCount * (INT_BYTES + SHORT_BYTES) + INT_BYTES] : | ||
new byte[dimCount * (INT_BYTES + SHORT_BYTES)]; | ||
int offset = 0; | ||
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); | ||
|
||
for (int dim = 0; dim < dimCount; dim++) { | ||
buf[offset++] = (byte) (dims[dim] >> 8); | ||
buf[offset++] = (byte) dims[dim]; | ||
int dimValue = dims[dim]; | ||
byteBuffer.put((byte) (dimValue >> 8)); | ||
byteBuffer.put((byte) dimValue); | ||
} | ||
|
||
// 3. Encode values | ||
double dotProduct = 0.0f; | ||
for (int dim = 0; dim < dimCount; dim++) { | ||
int intValue = Float.floatToIntBits(values[dim]); | ||
buf[offset++] = (byte) (intValue >> 24); | ||
buf[offset++] = (byte) (intValue >> 16); | ||
buf[offset++] = (byte) (intValue >> 8); | ||
buf[offset++] = (byte) intValue; | ||
dotProduct += values[dim] * values[dim]; | ||
float value = values[dim]; | ||
byteBuffer.putFloat(value); | ||
dotProduct += value * value; | ||
} | ||
|
||
// 4. Encode vector magnitude at the end | ||
if (indexVersion.onOrAfter(Version.V_7_4_0)) { | ||
float vectorMagnitude = (float) Math.sqrt(dotProduct); | ||
int vectorMagnitudeIntValue = Float.floatToIntBits(vectorMagnitude); | ||
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 24); | ||
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 16); | ||
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 8); | ||
buf[offset++] = (byte) vectorMagnitudeIntValue; | ||
byteBuffer.putFloat(vectorMagnitude); | ||
} | ||
|
||
return new BytesRef(buf); | ||
return new BytesRef(bytes); | ||
} | ||
|
||
/** | ||
|
@@ -75,12 +72,14 @@ public static int[] decodeSparseVectorDims(Version indexVersion, BytesRef vector | |
if (vectorBR == null) { | ||
throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); | ||
} | ||
|
||
int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES) : | ||
vectorBR.length / (INT_BYTES + SHORT_BYTES); | ||
int offset = vectorBR.offset; | ||
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, dimCount * SHORT_BYTES); | ||
|
||
int[] dims = new int[dimCount]; | ||
for (int dim = 0; dim < dimCount; dim++) { | ||
dims[dim] = ((vectorBR.bytes[offset++] & 0xFF) << 8) | (vectorBR.bytes[offset++] & 0xFF); | ||
dims[dim] = ((byteBuffer.get() & 0xFF) << 8) | (byteBuffer.get() & 0xFF); | ||
} | ||
return dims; | ||
} | ||
|
@@ -94,21 +93,19 @@ public static float[] decodeSparseVector(Version indexVersion, BytesRef vectorBR | |
if (vectorBR == null) { | ||
throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); | ||
} | ||
|
||
int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES) : | ||
vectorBR.length / (INT_BYTES + SHORT_BYTES); | ||
int offset = vectorBR.offset + SHORT_BYTES * dimCount; | ||
float[] vector = new float[dimCount]; | ||
|
||
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, offset, dimCount * INT_BYTES); | ||
for (int dim = 0; dim < dimCount; dim++) { | ||
int intValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) | | ||
((vectorBR.bytes[offset++] & 0xFF) << 16) | | ||
((vectorBR.bytes[offset++] & 0xFF) << 8) | | ||
(vectorBR.bytes[offset++] & 0xFF); | ||
vector[dim] = Float.intBitsToFloat(intValue); | ||
vector[dim] = byteBuffer.getFloat(); | ||
} | ||
return vector; | ||
} | ||
|
||
|
||
/** | ||
* Sorts dimensions in the ascending order and | ||
* sorts values in the same order as their corresponding dimensions | ||
|
@@ -174,15 +171,13 @@ public static float[] decodeDenseVector(Version indexVersion, BytesRef vectorBR) | |
if (vectorBR == null) { | ||
throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); | ||
} | ||
|
||
int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / INT_BYTES : vectorBR.length/ INT_BYTES; | ||
int offset = vectorBR.offset; | ||
float[] vector = new float[dimCount]; | ||
|
||
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jtibshirani I was wondering if we plan to eventually switch to From benchmarks: static float[] decodeWithFloatBuffer(BytesRef vectorBR) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
float[] vector = new float[vectorBR.length / INT_BYTES];
floatBuffer.get(vector);
return vector;
} VectorFunctionBenchmark.decodeNoop avgt 30 42.340 ± 0.706 ns/op A similar comment for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right, I'm planning to switch to combined decoding and dot product in an upcoming PR, as I've measured that it really helps performance. As a side note, I also tried |
||
for (int dim = 0; dim < dimCount; dim++) { | ||
int intValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) | | ||
((vectorBR.bytes[offset++] & 0xFF) << 16) | | ||
((vectorBR.bytes[offset++] & 0xFF) << 8) | | ||
(vectorBR.bytes[offset++] & 0xFF); | ||
vector[dim] = Float.intBitsToFloat(intValue); | ||
vector[dim] = byteBuffer.getFloat(); | ||
} | ||
return vector; | ||
} | ||
|
@@ -198,14 +193,10 @@ public static float getVectorMagnitude(Version indexVersion, BytesRef vectorBR, | |
if (vectorBR == null) { | ||
throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); | ||
} | ||
|
||
if (indexVersion.onOrAfter(Version.V_7_4_0)) { // decode vector magnitude | ||
int offset = vectorBR.offset + vectorBR.length - 4; | ||
int vectorMagnitudeIntValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) | | ||
((vectorBR.bytes[offset++] & 0xFF) << 16) | | ||
((vectorBR.bytes[offset++] & 0xFF) << 8) | | ||
(vectorBR.bytes[offset++] & 0xFF); | ||
float vectorMagnitude = Float.intBitsToFloat(vectorMagnitudeIntValue); | ||
return vectorMagnitude; | ||
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); | ||
return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4); | ||
} else { // calculate vector magnitude | ||
double dotProduct = 0f; | ||
for (int dim = 0; dim < vector.length; dim++) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if
byteBuffer.putShort((short) dimValue))
would be better?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought it would be nice to keep this explicit shifting to match the logic for decoding. We aren't able to use
getShort
in decoding, because we want to interpret the two bytes as an unsigned value.