Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Re-Call Issue Fix with Binary Quantized Vectors #2078

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
class QuantizationIndexUtils {

/**
* Processes and returns the vector based on whether quantization is applied or not.
* Processes the vector from {@link KNNVectorValues} and returns either a cloned quantized vector or a cloned original vector.
*
* @param knnVectorValues the KNN vector values to be processed.
* @param indexBuildSetup the setup containing quantization state and output, along with other parameters.
* @return the processed vector, either quantized or original.
* @throws IOException if an I/O error occurs during processing.
* @param knnVectorValues The KNN vector values containing the original vector.
* @param indexBuildSetup The setup containing the quantization state and output details.
* @return The quantized vector (as a byte array) or the original/cloned vector.
* @throws IOException If an I/O error occurs while processing the vector.
*/
static Object processAndReturnVector(KNNVectorValues<?> knnVectorValues, IndexBuildSetup indexBuildSetup) throws IOException {
QuantizationService quantizationService = QuantizationService.getInstance();
Expand All @@ -33,7 +33,11 @@ static Object processAndReturnVector(KNNVectorValues<?> knnVectorValues, IndexBu
knnVectorValues.getVector(),
indexBuildSetup.getQuantizationOutput()
);
return indexBuildSetup.getQuantizationOutput().getQuantizedVector();
/**
* Return a copy of the quantized vector. The quantization output hands back the same array reference
* on every call, so returning it directly caused the same vector to be added repeatedly during transfer.
*/
return indexBuildSetup.getQuantizationOutput().getQuantizedVectorCopy();
} else {
return knnVectorValues.conditionalCloneVector();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,16 @@ public boolean isPrepared(int vectorLength) {
public byte[] getQuantizedVector() {
// Hands back the internal array itself (no copy is made), so the caller shares state with this object.
return quantizedVector;
}

/**
 * Produces an independent snapshot of the quantized vector.
 *
 * @return a newly allocated byte array holding the same bytes as the internal quantized vector.
 */
@Override
public byte[] getQuantizedVectorCopy() {
    // Array clone() allocates a new array of the same length and copies every element.
    return quantizedVector.clone();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,32 @@ public interface QuantizationOutput<T> {
/**
* Returns the quantized vector.
*
* @return the quantized data.
* This method provides access to the quantized data in its current state.
* It returns the same reference to the internal quantized vector on each call, meaning any modifications
* to the returned array will directly affect the internal state of the object. This design is intentional
* to avoid unnecessary copying of data and improve performance, especially in scenarios where frequent
* access to the quantized vector is required.
*
* <p><b>Important:</b> As this method returns a direct reference to the internal array, care must be taken
* when modifying the returned array. If the returned vector is altered, the changes will reflect in the
* quantized vector managed by the object, which could lead to unintended side effects.</p>
*
* <p><b>Usage Example:</b></p>
* <pre>
* byte[] quantizedData = quantizationOutput.getQuantizedVector();
* // Use or modify quantizedData, but be cautious that changes affect the internal state.
* </pre>
*
* This method does not create a deep copy of the vector to avoid performance overhead in real-time
* or high-frequency operations. If a separate copy of the vector is needed, the caller should call
* {@link #getQuantizedVectorCopy()} or manually clone or copy the returned array.
*
* <p><b>Example to clone the array:</b></p>
* <pre>
* byte[] clonedData = Arrays.copyOf(quantizationOutput.getQuantizedVector(), quantizationOutput.getQuantizedVector().length);
* </pre>
*
* @return the quantized vector (same reference on each invocation).
*/
T getQuantizedVector();

Expand All @@ -33,4 +58,11 @@ public interface QuantizationOutput<T> {
* @return true if the quantized vector is already prepared, false otherwise.
*/
boolean isPrepared(int vectorLength);

/**
* Returns a copy of the quantized vector.
*
* @return a copy of the quantized data.
*/
T getQuantizedVectorCopy();
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public void quantize(final float[] vector, final QuantizationState state, final
if (thresholds == null || thresholds[0].length != vector.length) {
throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
}
if (!output.isPrepared(vectorLength)) output.prepareQuantizedVector(vectorLength);
output.prepareQuantizedVector(vectorLength);
BitPacker.quantizeAndPackBits(vector, thresholds, bitsPerCoordinate, output.getQuantizedVector());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public void quantize(final float[] vector, final QuantizationState state, final
if (thresholds == null || thresholds.length != vectorLength) {
throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
}
if (!output.isPrepared(vectorLength)) output.prepareQuantizedVector(vectorLength);
output.prepareQuantizedVector(vectorLength);
BitPacker.quantizeAndPackBits(vector, thresholds, output.getQuantizedVector());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,16 @@ public void testBuildAndWrite_withQuantization() {
ArgumentCaptor<float[]> vectorCaptor = ArgumentCaptor.forClass(float[].class);
// New: Create QuantizationOutput and mock the quantization process
QuantizationOutput<byte[]> quantizationOutput = mock(QuantizationOutput.class);
when(quantizationOutput.getQuantizedVector()).thenReturn(new byte[] { 1, 2 });
when(quantizationOutput.getQuantizedVectorCopy()).thenReturn(new byte[] { 1, 2 });
when(quantizationService.createQuantizationOutput(eq(quantizationState.getQuantizationParams()))).thenReturn(
quantizationOutput
);

// Quantize the vector with the quantization output
when(quantizationService.quantize(eq(quantizationState), vectorCaptor.capture(), eq(quantizationOutput))).thenAnswer(
invocation -> {
quantizationOutput.getQuantizedVector();
return quantizationOutput.getQuantizedVector();
quantizationOutput.getQuantizedVectorCopy();
return quantizationOutput.getQuantizedVectorCopy();
}
);
when(quantizationState.getDimensions()).thenReturn(2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,16 +164,16 @@ public void testBuildAndWrite_withQuantization() {
ArgumentCaptor<float[]> vectorCaptor = ArgumentCaptor.forClass(float[].class);
// New: Create QuantizationOutput and mock the quantization process
QuantizationOutput<byte[]> quantizationOutput = mock(QuantizationOutput.class);
when(quantizationOutput.getQuantizedVector()).thenReturn(new byte[] { 1, 2 });
when(quantizationOutput.getQuantizedVectorCopy()).thenReturn(new byte[] { 1, 2 });
when(quantizationService.createQuantizationOutput(eq(quantizationState.getQuantizationParams()))).thenReturn(
quantizationOutput
);

// Quantize the vector with the quantization output
when(quantizationService.quantize(eq(quantizationState), vectorCaptor.capture(), eq(quantizationOutput))).thenAnswer(
invocation -> {
quantizationOutput.getQuantizedVector();
return quantizationOutput.getQuantizedVector();
quantizationOutput.getQuantizedVectorCopy();
return quantizationOutput.getQuantizedVectorCopy();
}
);
when(quantizationState.getDimensions()).thenReturn(2);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.quantization.output;

import org.junit.Before;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;

public class BinaryQuantizationOutputTests extends KNNTestCase {

private static final int BITS_PER_COORDINATE = 1;
private BinaryQuantizationOutput quantizationOutput;

@Before
public void setUp() throws Exception {
super.setUp();
quantizationOutput = new BinaryQuantizationOutput(BITS_PER_COORDINATE);
}

public void testPrepareQuantizedVector_ShouldInitializeCorrectly_WhenVectorLengthIsValid() {
// Arrange
int vectorLength = 10;

// Act
quantizationOutput.prepareQuantizedVector(vectorLength);

// Assert
assertNotNull(quantizationOutput.getQuantizedVector());
}

public void testPrepareQuantizedVector_ShouldThrowException_WhenVectorLengthIsZeroOrNegative() {
// Act and Assert
expectThrows(IllegalArgumentException.class, () -> quantizationOutput.prepareQuantizedVector(0));
expectThrows(IllegalArgumentException.class, () -> quantizationOutput.prepareQuantizedVector(-1));
}

public void testIsPrepared_ShouldReturnTrue_WhenCalledWithSameVectorLength() {
// Arrange
int vectorLength = 8;
quantizationOutput.prepareQuantizedVector(vectorLength);
// Act and Assert
assertTrue(quantizationOutput.isPrepared(vectorLength));
}

public void testIsPrepared_ShouldReturnFalse_WhenCalledWithDifferentVectorLength() {
// Arrange
int vectorLength = 8;
quantizationOutput.prepareQuantizedVector(vectorLength);
// Act and Assert
assertFalse(quantizationOutput.isPrepared(vectorLength + 1));
}

public void testGetQuantizedVector_ShouldReturnSameReference() {
// Arrange
int vectorLength = 5;
quantizationOutput.prepareQuantizedVector(vectorLength);
// Act
byte[] vector = quantizationOutput.getQuantizedVector();
// Assert
assertEquals(vector, quantizationOutput.getQuantizedVector());
}

public void testGetQuantizedVectorCopy_ShouldReturnCopyOfVector() {
// Arrange
int vectorLength = 5;
quantizationOutput.prepareQuantizedVector(vectorLength);

// Act
byte[] vectorCopy = quantizationOutput.getQuantizedVectorCopy();

// Assert
assertNotSame(vectorCopy, quantizationOutput.getQuantizedVector());
assertArrayEquals(vectorCopy, quantizationOutput.getQuantizedVector());
}

public void testGetQuantizedVectorCopy_ShouldReturnNewCopyOnEachCall() {
// Arrange
int vectorLength = 5;
quantizationOutput.prepareQuantizedVector(vectorLength);

// Act
byte[] vectorCopy1 = quantizationOutput.getQuantizedVectorCopy();
byte[] vectorCopy2 = quantizationOutput.getQuantizedVectorCopy();

// Assert
assertNotSame(vectorCopy1, vectorCopy2);
}

public void testPrepareQuantizedVector_ShouldResetQuantizedVector_WhenCalledWithDifferentLength() {
// Arrange
int initialLength = 5;
int newLength = 10;
quantizationOutput.prepareQuantizedVector(initialLength);
byte[] initialVector = quantizationOutput.getQuantizedVector();

// Act
quantizationOutput.prepareQuantizedVector(newLength);
byte[] newVector = quantizationOutput.getQuantizedVector();

// Assert
assertNotSame(initialVector, newVector); // The array reference should change
assertEquals(newVector.length, (BITS_PER_COORDINATE * newLength + 7) / 8); // Correct size for new vector
}

public void testPrepareQuantizedVector_ShouldRetainSameArray_WhenCalledWithSameLength() {
// Arrange
int vectorLength = 5;
quantizationOutput.prepareQuantizedVector(vectorLength);
byte[] initialVector = quantizationOutput.getQuantizedVector();

// Act
quantizationOutput.prepareQuantizedVector(vectorLength);
byte[] newVector = quantizationOutput.getQuantizedVector();

// Assert
assertSame(newVector, initialVector); // The array reference should remain the same
}
}
Loading