Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-Call Issue Fix with Binary Quantized Vectors #2071

Merged
merged 3 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ static Object processAndReturnVector(KNNVectorValues<?> knnVectorValues, IndexBu
knnVectorValues.getVector(),
indexBuildSetup.getQuantizationOutput()
);
return indexBuildSetup.getQuantizationOutput().getQuantizedVector();
return indexBuildSetup.getQuantizationOutput().getQuantizedVectorCopy();
Vikasht34 marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we update the unit test for this class to make sure we always return a copy? we need a assertNotSame check in the unit test of this class

} else {
return knnVectorValues.conditionalCloneVector();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,16 @@ public boolean isPrepared(int vectorLength) {
public byte[] getQuantizedVector() {
return quantizedVector;
}

/**
* Returns a copy of the quantized vector.
*
* @return a copy of the quantized vector byte array.
*/
@Override
public byte[] getQuantizedVectorCopy() {
Vikasht34 marked this conversation as resolved.
Show resolved Hide resolved
Vikasht34 marked this conversation as resolved.
Show resolved Hide resolved
byte[] clonedByteArray = new byte[quantizedVector.length];
System.arraycopy(quantizedVector, 0, clonedByteArray, 0, quantizedVector.length);
return clonedByteArray;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,11 @@ public interface QuantizationOutput<T> {
* @return true if the quantized vector is already prepared, false otherwise.
*/
boolean isPrepared(int vectorLength);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: We should be removing this as its a dead code at this point, we can add it back if we ever need it in the future


/**
* Returns a copy of the quantized vector.
*
* @return a copy of the quantized data.
*/
T getQuantizedVectorCopy();
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public void quantize(final float[] vector, final QuantizationState state, final
if (thresholds == null || thresholds[0].length != vector.length) {
throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
}
if (!output.isPrepared(vectorLength)) output.prepareQuantizedVector(vectorLength);
output.prepareQuantizedVector(vectorLength);
BitPacker.quantizeAndPackBits(vector, thresholds, bitsPerCoordinate, output.getQuantizedVector());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public void quantize(final float[] vector, final QuantizationState state, final
if (thresholds == null || thresholds.length != vectorLength) {
throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
}
if (!output.isPrepared(vectorLength)) output.prepareQuantizedVector(vectorLength);
output.prepareQuantizedVector(vectorLength);
BitPacker.quantizeAndPackBits(vector, thresholds, output.getQuantizedVector());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,16 @@ public void testBuildAndWrite_withQuantization() {
ArgumentCaptor<float[]> vectorCaptor = ArgumentCaptor.forClass(float[].class);
// New: Create QuantizationOutput and mock the quantization process
QuantizationOutput<byte[]> quantizationOutput = mock(QuantizationOutput.class);
when(quantizationOutput.getQuantizedVector()).thenReturn(new byte[] { 1, 2 });
when(quantizationOutput.getQuantizedVectorCopy()).thenReturn(new byte[] { 1, 2 });
when(quantizationService.createQuantizationOutput(eq(quantizationState.getQuantizationParams()))).thenReturn(
quantizationOutput
);

// Quantize the vector with the quantization output
when(quantizationService.quantize(eq(quantizationState), vectorCaptor.capture(), eq(quantizationOutput))).thenAnswer(
invocation -> {
quantizationOutput.getQuantizedVector();
return quantizationOutput.getQuantizedVector();
quantizationOutput.getQuantizedVectorCopy();
return quantizationOutput.getQuantizedVectorCopy();
}
);
when(quantizationState.getDimensions()).thenReturn(2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,16 @@ public void testBuildAndWrite_withQuantization() {
ArgumentCaptor<float[]> vectorCaptor = ArgumentCaptor.forClass(float[].class);
// New: Create QuantizationOutput and mock the quantization process
QuantizationOutput<byte[]> quantizationOutput = mock(QuantizationOutput.class);
when(quantizationOutput.getQuantizedVector()).thenReturn(new byte[] { 1, 2 });
when(quantizationOutput.getQuantizedVectorCopy()).thenReturn(new byte[] { 1, 2 });
when(quantizationService.createQuantizationOutput(eq(quantizationState.getQuantizationParams()))).thenReturn(
quantizationOutput
);

// Quantize the vector with the quantization output
when(quantizationService.quantize(eq(quantizationState), vectorCaptor.capture(), eq(quantizationOutput))).thenAnswer(
invocation -> {
quantizationOutput.getQuantizedVector();
return quantizationOutput.getQuantizedVector();
quantizationOutput.getQuantizedVectorCopy();
return quantizationOutput.getQuantizedVectorCopy();
}
);
when(quantizationState.getDimensions()).thenReturn(2);
Expand Down
Loading