From a4f6c31e2c7080b8bf4774cca11ddd74f3d54492 Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Fri, 23 Aug 2019 12:58:58 -0700 Subject: [PATCH] Switch to ByteBuffer for vector encoding. --- .../mapper/DenseVectorFieldMapper.java | 25 ++++--- .../vectors/mapper/VectorEncoderDecoder.java | 65 ++++++++----------- .../mapper/VectorEncoderDecoderTests.java | 41 ++++-------- 3 files changed, 52 insertions(+), 79 deletions(-) diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java index 9d5a4d676491c..2c56832b51f9b 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java @@ -26,10 +26,11 @@ import org.elasticsearch.index.mapper.MapperParsingException; import org.elasticsearch.index.mapper.ParseContext; import org.elasticsearch.index.query.QueryShardContext; -import org.elasticsearch.xpack.vectors.query.VectorDVIndexFieldData; import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.xpack.vectors.query.VectorDVIndexFieldData; import java.io.IOException; +import java.nio.ByteBuffer; import java.time.ZoneId; import java.util.List; import java.util.Map; @@ -181,9 +182,11 @@ public void parse(ParseContext context) throws IOException { // encode array of floats as array of integers and store into buf // this code is here and not int the VectorEncoderDecoder so not to create extra arrays - byte[] buf = indexCreatedVersion.onOrAfter(Version.V_7_4_0) ? new byte[dims * INT_BYTES + INT_BYTES] : new byte[dims * INT_BYTES]; - int offset = 0; + byte[] bytes = indexCreatedVersion.onOrAfter(Version.V_7_4_0) ? new byte[dims * INT_BYTES + INT_BYTES] : new byte[dims * INT_BYTES]; + + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); double dotProduct = 0f; + int dim = 0; for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) { if (dim++ >= dims) { @@ -192,11 +195,8 @@ public void parse(ParseContext context) throws IOException { } ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()::getTokenLocation); float value = context.parser().floatValue(true); - int intValue = Float.floatToIntBits(value); - buf[offset++] = (byte) (intValue >> 24); - buf[offset++] = (byte) (intValue >> 16); - buf[offset++] = (byte) (intValue >> 8); - buf[offset++] = (byte) intValue; + + byteBuffer.putFloat(value); dotProduct += value * value; } if (dim != dims) { @@ -204,16 +204,13 @@ public void parse(ParseContext context) throws IOException { context.sourceToParse().id() + "] has number of dimensions [" + dim + "] less than defined in the mapping [" + dims +"]"); } + if (indexCreatedVersion.onOrAfter(Version.V_7_4_0)) { // encode vector magnitude at the end float vectorMagnitude = (float) Math.sqrt(dotProduct); - int vectorMagnitudeIntValue = Float.floatToIntBits(vectorMagnitude); - buf[offset++] = (byte) (vectorMagnitudeIntValue >> 24); - buf[offset++] = (byte) (vectorMagnitudeIntValue >> 16); - buf[offset++] = (byte) (vectorMagnitudeIntValue >> 8); - buf[offset++] = (byte) vectorMagnitudeIntValue; + byteBuffer.putFloat(vectorMagnitude); } - BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(buf)); + BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(bytes)); if (context.doc().getByKey(fieldType().name()) != null) { throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't not support indexing multiple values for the same field in the same document"); diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java index 38a240663fe8c..686823c248bb9 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java @@ -11,6 +11,8 @@ import org.apache.lucene.util.InPlaceMergeSorter; import org.elasticsearch.Version; +import java.nio.ByteBuffer; + // static utility functions for encoding and decoding dense_vector and sparse_vector fields public final class VectorEncoderDecoder { static final byte INT_BYTES = 4; @@ -34,36 +36,31 @@ public static BytesRef encodeSparseVector(Version indexVersion, int[] dims, floa // 2. Encode dimensions // as each dimension is a positive value that doesn't exceed 65535, 2 bytes is enough for encoding it - byte[] buf = indexVersion.onOrAfter(Version.V_7_4_0) ? new byte[dimCount * (INT_BYTES + SHORT_BYTES) + INT_BYTES] : + byte[] bytes = indexVersion.onOrAfter(Version.V_7_4_0) ? new byte[dimCount * (INT_BYTES + SHORT_BYTES) + INT_BYTES] : new byte[dimCount * (INT_BYTES + SHORT_BYTES)]; - int offset = 0; + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + for (int dim = 0; dim < dimCount; dim++) { - buf[offset++] = (byte) (dims[dim] >> 8); - buf[offset++] = (byte) dims[dim]; + int dimValue = dims[dim]; + byteBuffer.put((byte) (dimValue >> 8)); + byteBuffer.put((byte) dimValue); } // 3. Encode values double dotProduct = 0.0f; for (int dim = 0; dim < dimCount; dim++) { - int intValue = Float.floatToIntBits(values[dim]); - buf[offset++] = (byte) (intValue >> 24); - buf[offset++] = (byte) (intValue >> 16); - buf[offset++] = (byte) (intValue >> 8); - buf[offset++] = (byte) intValue; - dotProduct += values[dim] * values[dim]; + float value = values[dim]; + byteBuffer.putFloat(value); + dotProduct += value * value; } // 4. Encode vector magnitude at the end if (indexVersion.onOrAfter(Version.V_7_4_0)) { float vectorMagnitude = (float) Math.sqrt(dotProduct); - int vectorMagnitudeIntValue = Float.floatToIntBits(vectorMagnitude); - buf[offset++] = (byte) (vectorMagnitudeIntValue >> 24); - buf[offset++] = (byte) (vectorMagnitudeIntValue >> 16); - buf[offset++] = (byte) (vectorMagnitudeIntValue >> 8); - buf[offset++] = (byte) vectorMagnitudeIntValue; + byteBuffer.putFloat(vectorMagnitude); } - return new BytesRef(buf); + return new BytesRef(bytes); } /** @@ -75,12 +72,14 @@ public static int[] decodeSparseVectorDims(Version indexVersion, BytesRef vector if (vectorBR == null) { throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } + int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES) : vectorBR.length / (INT_BYTES + SHORT_BYTES); - int offset = vectorBR.offset; + ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, dimCount * SHORT_BYTES); + int[] dims = new int[dimCount]; for (int dim = 0; dim < dimCount; dim++) { - dims[dim] = ((vectorBR.bytes[offset++] & 0xFF) << 8) | (vectorBR.bytes[offset++] & 0xFF); + dims[dim] = ((byteBuffer.get() & 0xFF) << 8) | (byteBuffer.get() & 0xFF); } return dims; } @@ -94,21 +93,19 @@ public static float[] decodeSparseVector(Version indexVersion, BytesRef vectorBR if (vectorBR == null) { throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } + int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES) : vectorBR.length / (INT_BYTES + SHORT_BYTES); int offset = vectorBR.offset + SHORT_BYTES * dimCount; float[] vector = new float[dimCount]; + + ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, offset, dimCount * INT_BYTES); for (int dim = 0; dim < dimCount; dim++) { - int intValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) | - ((vectorBR.bytes[offset++] & 0xFF) << 16) | - ((vectorBR.bytes[offset++] & 0xFF) << 8) | - (vectorBR.bytes[offset++] & 0xFF); - vector[dim] = Float.intBitsToFloat(intValue); + vector[dim] = byteBuffer.getFloat(); } return vector; } - /** * Sorts dimensions in the ascending order and * sorts values in the same order as their corresponding dimensions @@ -174,15 +171,13 @@ public static float[] decodeDenseVector(Version indexVersion, BytesRef vectorBR) if (vectorBR == null) { throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } + int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / INT_BYTES : vectorBR.length/ INT_BYTES; - int offset = vectorBR.offset; float[] vector = new float[dimCount]; + + ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); for (int dim = 0; dim < dimCount; dim++) { - int intValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) | - ((vectorBR.bytes[offset++] & 0xFF) << 16) | - ((vectorBR.bytes[offset++] & 0xFF) << 8) | - (vectorBR.bytes[offset++] & 0xFF); - vector[dim] = Float.intBitsToFloat(intValue); + vector[dim] = byteBuffer.getFloat(); } return vector; } @@ -198,14 +193,10 @@ public static float getVectorMagnitude(Version indexVersion, BytesRef vectorBR, if (vectorBR == null) { throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } + if (indexVersion.onOrAfter(Version.V_7_4_0)) { // decode vector magnitude - int offset = vectorBR.offset + vectorBR.length - 4; - int vectorMagnitudeIntValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) | - ((vectorBR.bytes[offset++] & 0xFF) << 16) | - ((vectorBR.bytes[offset++] & 0xFF) << 8) | - (vectorBR.bytes[offset++] & 0xFF); - float vectorMagnitude = Float.intBitsToFloat(vectorMagnitudeIntValue); - return vectorMagnitude; + ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); + return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4); } else { // calculate vector magnitude double dotProduct = 0f; for (int dim = 0; dim < vector.length; dim++) { diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java index 51426755aca7a..ba7de2ad74528 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.Version; import org.elasticsearch.test.ESTestCase; +import java.nio.ByteBuffer; import java.util.HashSet; import java.util.Set; import java.util.Arrays; @@ -137,44 +138,28 @@ public void testSparseVectorEncodingDecodingBefore7_4() { // imitates the code in DenseVectorFieldMapper::parse public static BytesRef mockEncodeDenseVector(float[] values) { - final short INT_BYTES = VectorEncoderDecoder.INT_BYTES; - byte[] buf = new byte[INT_BYTES * values.length + INT_BYTES]; - int offset = 0; + byte[] bytes = new byte[VectorEncoderDecoder.INT_BYTES * values.length + VectorEncoderDecoder.INT_BYTES]; double dotProduct = 0f; - int intValue; - for (float value: values) { + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + for (float value : values) { + byteBuffer.putFloat(value); dotProduct += value * value; - intValue = Float.floatToIntBits(value); - buf[offset++] = (byte) (intValue >> 24); - buf[offset++] = (byte) (intValue >> 16); - buf[offset++] = (byte) (intValue >> 8); - buf[offset++] = (byte) intValue; } // encode vector magnitude at the end float vectorMagnitude = (float) Math.sqrt(dotProduct); - int vectorMagnitudeIntValue = Float.floatToIntBits(vectorMagnitude); - buf[offset++] = (byte) (vectorMagnitudeIntValue >> 24); - buf[offset++] = (byte) (vectorMagnitudeIntValue >> 16); - buf[offset++] = (byte) (vectorMagnitudeIntValue >> 8); - buf[offset++] = (byte) vectorMagnitudeIntValue; - - return new BytesRef(buf); + byteBuffer.putFloat(vectorMagnitude); + return new BytesRef(bytes); } // imitates the code in DenseVectorFieldMapper::parse before version 7.4 public static BytesRef mockEncodeDenseVectorBefore7_4(float[] values) { - final short INT_BYTES = VectorEncoderDecoder.INT_BYTES; - byte[] buf = new byte[INT_BYTES * values.length]; - int offset = 0; - int intValue; - for (float value: values) { - intValue = Float.floatToIntBits(value); - buf[offset++] = (byte) (intValue >> 24); - buf[offset++] = (byte) (intValue >> 16); - buf[offset++] = (byte) (intValue >> 8); - buf[offset++] = (byte) intValue; + byte[] bytes = new byte[VectorEncoderDecoder.INT_BYTES * values.length]; + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + for (float value : values) { + byteBuffer.putFloat(value); } - return new BytesRef(buf, 0, offset); + return new BytesRef(bytes); + } // generate unique random dims