Skip to content

Commit

Permalink
Improve halfbyte transposition performance
Browse files Browse the repository at this point in the history
  • Loading branch information
benwtrent committed Nov 22, 2024
1 parent 35116c3 commit 374a8e7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,56 +23,38 @@
public class BQSpaceUtils {

public static final short B_QUERY = 4;
// the first four bits masked
private static final int B_QUERY_MASK = 15;

/**
* Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
* Transpose the query vector into a byte array allowing for efficient bitwise operations with the
* index bit vectors. The idea here is to organize the query vector bits such that the first bit
* of every dimension is in the first set dimensions bits, or (dimensions/8) bytes. The second,
* third, and fourth bits are in the second, third, and fourth set of dimensions bits,
* respectively. This allows for direct bitwise comparisons with the stored index vectors through
* summing the bitwise results with the relative required bit shifts.
*
* @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
* @param dimensions the number of dimensions in the query vector
* @param quantQueryByte the byte array to store the transposed query vector
*/
public static void transposeBin(byte[] q, int dimensions, byte[] quantQueryByte) {
// TODO: rewrite this in Panama Vector API
int qOffset = 0;
final byte[] v1 = new byte[4];
final byte[] v = new byte[32];
for (int i = 0; i < dimensions; i += 32) {
// for every four bytes we shift left (with remainder across those bytes)
for (int j = 0; j < v.length; j += 4) {
v[j] = (byte) (q[qOffset + j] << B_QUERY | ((q[qOffset + j] >>> B_QUERY) & B_QUERY_MASK));
v[j + 1] = (byte) (q[qOffset + j + 1] << B_QUERY | ((q[qOffset + j + 1] >>> B_QUERY) & B_QUERY_MASK));
v[j + 2] = (byte) (q[qOffset + j + 2] << B_QUERY | ((q[qOffset + j + 2] >>> B_QUERY) & B_QUERY_MASK));
v[j + 3] = (byte) (q[qOffset + j + 3] << B_QUERY | ((q[qOffset + j + 3] >>> B_QUERY) & B_QUERY_MASK));
}
for (int j = 0; j < B_QUERY; j++) {
moveMaskEpi8Byte(v, v1);
for (int k = 0; k < 4; k++) {
quantQueryByte[(B_QUERY - j - 1) * (dimensions / 8) + i / 8 + k] = v1[k];
v1[k] = 0;
}
for (int k = 0; k < v.length; k += 4) {
v[k] = (byte) (v[k] + v[k]);
v[k + 1] = (byte) (v[k + 1] + v[k + 1]);
v[k + 2] = (byte) (v[k + 2] + v[k + 2]);
v[k + 3] = (byte) (v[k + 3] + v[k + 3]);
}
}
qOffset += 32;
}
}

private static void moveMaskEpi8Byte(byte[] v, byte[] v1b) {
int m = 0;
for (int k = 0; k < v.length; k++) {
if ((v[k] & 0b10000000) == 0b10000000) {
v1b[m] |= 0b00000001;
}
if (k % 8 == 7) {
m++;
} else {
v1b[m] <<= 1;
public static void transposeHalfByte(byte[] q, byte[] quantQueryByte) {
for (int i = 0; i < q.length;) {
assert q[i] >= 0 && q[i] <= 15;
int lowerByte = 0;
int lowerMiddleByte = 0;
int upperMiddleByte = 0;
int upperByte = 0;
for (int j = 7; j >= 0 && i < q.length; j--) {
lowerByte |= (q[i] & 1) << j;
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
upperByte |= ((q[i] >> 3) & 1) << j;
i++;
}
int index = ((i + 7) / 8) - 1;
quantQueryByte[index] = (byte) lowerByte;
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,7 @@ public QueryAndIndexResults quantizeQueryAndIndex(float[] vector, byte[] indexDe

// q¯ = Δ · q¯𝑢 + 𝑣𝑙 · 1𝐷
// q¯ is an approximation of q′ (scalar quantized approximation)
// FIXME: vectors need to be padded but that's expensive; update transponseBin to deal
byteQuery = BQVectorUtils.pad(byteQuery, discretizedDimensions);
BQSpaceUtils.transposeBin(byteQuery, discretizedDimensions, queryDestination);
BQSpaceUtils.transposeHalfByte(byteQuery, queryDestination);
QueryFactors factors = new QueryFactors(quantResult.quantizedSum, distToC, lower, width, normVmC, vDotC);
final float[] indexCorrections;
if (similarityFunction == EUCLIDEAN) {
Expand Down Expand Up @@ -366,9 +364,7 @@ public QueryFactors quantizeForQuery(float[] vector, byte[] destination, float[]

// q¯ = Δ · q¯𝑢 + 𝑣𝑙 · 1𝐷
// q¯ is an approximation of q′ (scalar quantized approximation)
// FIXME: vectors need to be padded but that's expensive; update transponseBin to deal
byteQuery = BQVectorUtils.pad(byteQuery, discretizedDimensions);
BQSpaceUtils.transposeBin(byteQuery, discretizedDimensions, destination);
BQSpaceUtils.transposeHalfByte(byteQuery, destination);

QueryFactors factors;
if (similarityFunction != EUCLIDEAN) {
Expand Down

0 comments on commit 374a8e7

Please sign in to comment.