Skip to content

Commit

Permalink
Correct bit * byte and bit * float script comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
benwtrent committed Nov 23, 2024
1 parent 2b8e4e7 commit 87c1f93
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
4 changes: 4 additions & 0 deletions docs/reference/vectors/vector-functions.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,10 @@ When using `bit` vectors, not all the vector functions are available. The suppor
this is the sum of the bitwise AND of the two vectors. If providing `float[]` or `byte[]`, who has `dims` number of elements, as a query vector, the `dotProduct` is
the sum of the floating point values using the stored `bit` vector as a mask.

NOTE: When comparing `floats` and `bytes` with `bit` vectors, the `bit` vector is treated as a mask in big-endian order.
For example, if the `bit` vector is `10100001` (e.g. the single byte value `161`) and its compared
with array of values `[1, 2, 3, 4, 5, 6, 7, 8]` the `dotProduct` will be `1 + 3 + 8 = 16`.

Here is an example of using dot-product with bit vectors.

[source,console]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ public static long ipByteBinByte(byte[] q, byte[] d) {
/**
* Compute the inner product of two vectors, where the query vector is a byte vector and the document vector is a bit vector.
* This will return the sum of the query vector values using the document vector as a mask.
* When comparing the bits with the bytes, they are done in "big endian" order. For example, if the byte vector
* is [1, 2, 3, 4, 5, 6, 7, 8] and the bit vector is [0b10000000], the inner product will be 1.0.
* @param q the query vector
* @param d the document vector
* @return the inner product of the two vectors
Expand All @@ -63,9 +65,9 @@ public static int ipByteBit(byte[] q, byte[] d) {
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
for (int i = 0; i < d.length; i++) {
byte mask = d[i];
for (int j = 0; j < Byte.SIZE; j++) {
for (int j = Byte.SIZE - 1; j >= 0; j--) {
if ((mask & (1 << j)) != 0) {
result += q[i * Byte.SIZE + j];
result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
}
}
}
Expand All @@ -75,6 +77,8 @@ public static int ipByteBit(byte[] q, byte[] d) {
/**
* Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a bit vector.
* This will return the sum of the query vector values using the document vector as a mask.
* When comparing the bits with the floats, they are done in "big endian" order. For example, if the float vector
* is [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] and the bit vector is [0b10000000], the inner product will be 1.0.
* @param q the query vector
* @param d the document vector
* @return the inner product of the two vectors
Expand All @@ -86,9 +90,9 @@ public static float ipFloatBit(float[] q, byte[] d) {
float result = 0;
for (int i = 0; i < d.length; i++) {
byte mask = d[i];
for (int j = 0; j < Byte.SIZE; j++) {
for (int j = Byte.SIZE - 1; j >= 0; j--) {
if ((mask & (1 << j)) != 0) {
result += q[i * Byte.SIZE + j];
result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,22 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
static final ESVectorizationProvider defaultedProvider = BaseVectorizationTests.defaultProvider();
static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider();

public void testIpByteBit() {
byte[] q = new byte[16];
byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
random().nextBytes(q);
int expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
assertEquals(expected, ESVectorUtil.ipByteBit(q, d));
}

public void testIpFloatBit() {
float[] q = new float[16];
byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
random().nextFloat();
float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6);
}

public void testBitAndCount() {
testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
}
Expand Down

0 comments on commit 87c1f93

Please sign in to comment.