diff --git a/docs/reference/vectors/vector-functions.asciidoc b/docs/reference/vectors/vector-functions.asciidoc index 10dca8084e28a..23419e8eb12b1 100644 --- a/docs/reference/vectors/vector-functions.asciidoc +++ b/docs/reference/vectors/vector-functions.asciidoc @@ -336,6 +336,10 @@ When using `bit` vectors, not all the vector functions are available. The suppor this is the sum of the bitwise AND of the two vectors. If providing `float[]` or `byte[]`, who has `dims` number of elements, as a query vector, the `dotProduct` is the sum of the floating point values using the stored `bit` vector as a mask. +NOTE: When comparing `floats` and `bytes` with `bit` vectors, the `bit` vector is treated as a mask in big-endian order. +For example, if the `bit` vector is `10100001` (e.g. the single byte value `161`) and its compared +with array of values `[1, 2, 3, 4, 5, 6, 7, 8]` the `dotProduct` will be `1 + 3 + 8 = 16`. + Here is an example of using dot-product with bit vectors. [source,console] diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java index de2cb9042610b..2f4743a47a14a 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java @@ -51,6 +51,8 @@ public static long ipByteBinByte(byte[] q, byte[] d) { /** * Compute the inner product of two vectors, where the query vector is a byte vector and the document vector is a bit vector. * This will return the sum of the query vector values using the document vector as a mask. + * When comparing the bits with the bytes, they are done in "big endian" order. For example, if the byte vector + * is [1, 2, 3, 4, 5, 6, 7, 8] and the bit vector is [0b10000000], the inner product will be 1.0. * @param q the query vector * @param d the document vector * @return the inner product of the two vectors @@ -63,9 +65,9 @@ public static int ipByteBit(byte[] q, byte[] d) { // now combine the two vectors, summing the byte dimensions where the bit in d is `1` for (int i = 0; i < d.length; i++) { byte mask = d[i]; - for (int j = 0; j < Byte.SIZE; j++) { + for (int j = Byte.SIZE - 1; j >= 0; j--) { if ((mask & (1 << j)) != 0) { - result += q[i * Byte.SIZE + j]; + result += q[i * Byte.SIZE + Byte.SIZE - 1 - j]; } } } @@ -75,6 +77,8 @@ public static int ipByteBit(byte[] q, byte[] d) { /** * Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a bit vector. * This will return the sum of the query vector values using the document vector as a mask. + * When comparing the bits with the floats, they are done in "big endian" order. For example, if the float vector + * is [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] and the bit vector is [0b10000000], the inner product will be 1.0. * @param q the query vector * @param d the document vector * @return the inner product of the two vectors @@ -86,9 +90,9 @@ public static float ipFloatBit(float[] q, byte[] d) { float result = 0; for (int i = 0; i < d.length; i++) { byte mask = d[i]; - for (int j = 0; j < Byte.SIZE; j++) { + for (int j = Byte.SIZE - 1; j >= 0; j--) { if ((mask & (1 << j)) != 0) { - result += q[i * Byte.SIZE + j]; + result += q[i * Byte.SIZE + Byte.SIZE - 1 - j]; } } } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java index e9e0fd58f7638..368898b934c87 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java @@ -21,6 +21,22 @@ public class ESVectorUtilTests extends BaseVectorizationTests { static final ESVectorizationProvider defaultedProvider = BaseVectorizationTests.defaultProvider(); static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider(); + public void testIpByteBit() { + byte[] q = new byte[16]; + byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) }; + random().nextBytes(q); + int expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15]; + assertEquals(expected, ESVectorUtil.ipByteBit(q, d)); + } + + public void testIpFloatBit() { + float[] q = new float[16]; + byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) }; + random().nextFloat(); + float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15]; + assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6); + } + public void testBitAndCount() { testBasicBitAndImpl(ESVectorUtil::andBitCountLong); }