From 585a82aad6483c4c3e01cfa5352dc7d92e386a89 Mon Sep 17 00:00:00 2001 From: Vikasht34 Date: Mon, 9 Sep 2024 10:44:29 -0700 Subject: [PATCH] Re-Call Issue Fix with Binary Quantized Vectors (#2071) * Re-Call Issue Fix with Binary Quantized Vectors Signed-off-by: VIKASH TIWARI * Feedback Fix Signed-off-by: VIKASH TIWARI --------- Signed-off-by: VIKASH TIWARI Signed-off-by: Vikasht34 (cherry picked from commit ce735c4cf42b0dbc3d601f576592dad3cc16a19d) --- .../nativeindex/QuantizationIndexUtils.java | 16 ++- .../BinaryQuantizationOutput.java | 12 ++ .../QuantizationOutput.java | 34 ++++- .../quantizer/MultiBitScalarQuantizer.java | 2 +- .../quantizer/OneBitScalarQuantizer.java | 2 +- .../DefaultIndexBuildStrategyTests.java | 6 +- ...ptimizedNativeIndexBuildStrategyTests.java | 6 +- .../output/BinaryQuantizationOutputTests.java | 121 ++++++++++++++++++ 8 files changed, 184 insertions(+), 15 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/quantization/output/BinaryQuantizationOutputTests.java diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java index bebe9e8b0..c5994d66b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java @@ -18,12 +18,12 @@ class QuantizationIndexUtils { /** - * Processes and returns the vector based on whether quantization is applied or not. + * Processes the vector from {@link KNNVectorValues} and returns either a cloned quantized vector or a cloned original vector. * - * @param knnVectorValues the KNN vector values to be processed. - * @param indexBuildSetup the setup containing quantization state and output, along with other parameters. - * @return the processed vector, either quantized or original. - * @throws IOException if an I/O error occurs during processing. 
+ * @param knnVectorValues The KNN vector values containing the original vector. + * @param indexBuildSetup The setup containing the quantization state and output details. + * @return The quantized vector (as a byte array) or the original/cloned vector. + * @throws IOException If an I/O error occurs while processing the vector. */ static Object processAndReturnVector(KNNVectorValues knnVectorValues, IndexBuildSetup indexBuildSetup) throws IOException { QuantizationService quantizationService = QuantizationService.getInstance(); @@ -33,7 +33,11 @@ static Object processAndReturnVector(KNNVectorValues knnVectorValues, IndexBu knnVectorValues.getVector(), indexBuildSetup.getQuantizationOutput() ); - return indexBuildSetup.getQuantizationOutput().getQuantizedVector(); + /** + * Returns a copy of the quantized vector. getQuantizedVector() hands back the same internal array on every + * call, so without a copy the same vector reference was being added repeatedly during transfer. + */ + return indexBuildSetup.getQuantizationOutput().getQuantizedVectorCopy(); } else { return knnVectorValues.conditionalCloneVector(); } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java index 388fd9e94..dc8634b9d 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java @@ -63,4 +63,16 @@ public boolean isPrepared(int vectorLength) { public byte[] getQuantizedVector() { return quantizedVector; } + + /** + * Returns a copy of the quantized vector. + * + * @return a copy of the quantized vector byte array. 
+ */ + @Override + public byte[] getQuantizedVectorCopy() { + byte[] clonedByteArray = new byte[quantizedVector.length]; + System.arraycopy(quantizedVector, 0, clonedByteArray, 0, quantizedVector.length); + return clonedByteArray; + } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java index 4d088f91f..29124c268 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java @@ -14,7 +14,32 @@ public interface QuantizationOutput { /** * Returns the quantized vector. * - * @return the quantized data. + * This method provides access to the quantized data in its current state. + * It returns the same reference to the internal quantized vector on each call, meaning any modifications + * to the returned array will directly affect the internal state of the object. This design is intentional + * to avoid unnecessary copying of data and improve performance, especially in scenarios where frequent + * access to the quantized vector is required. + * + *

<p><b>Important:</b> As this method returns a direct reference to the internal array, care must be taken + * when modifying the returned array. If the returned vector is altered, the changes are reflected in the + * quantized vector managed by the object, which could lead to unintended side effects.</p>

+ * + *

<p><b>Usage Example:</b></p>

+ * <pre>
+     * byte[] quantizedData = quantizationOutput.getQuantizedVector();
+     * // Use or modify quantizedData, but be cautious that changes affect the internal state.
+     * </pre>
+ * + * This method does not create a deep copy of the vector to avoid performance overhead in real-time + * or high-frequency operations. If a separate copy of the vector is needed, call + * {@link #getQuantizedVectorCopy()}, or clone or copy the returned array manually. + * + *

<p><b>Example to clone the array:</b></p>

+ * <pre>
+     * byte[] clonedData = Arrays.copyOf(quantizationOutput.getQuantizedVector(), quantizationOutput.getQuantizedVector().length);
+     * </pre>
+ * + * @return the quantized vector (same reference on each invocation). */ T getQuantizedVector(); @@ -33,4 +58,11 @@ public interface QuantizationOutput { * @return true if the quantized vector is already prepared, false otherwise. */ boolean isPrepared(int vectorLength); + + /** + * Returns a copy of the quantized vector. + * + * @return a copy of the quantized data. + */ + T getQuantizedVectorCopy(); } diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java index a0e6ec402..0bcc252d1 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java @@ -139,7 +139,7 @@ public void quantize(final float[] vector, final QuantizationState state, final if (thresholds == null || thresholds[0].length != vector.length) { throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); } - if (!output.isPrepared(vectorLength)) output.prepareQuantizedVector(vectorLength); + output.prepareQuantizedVector(vectorLength); BitPacker.quantizeAndPackBits(vector, thresholds, bitsPerCoordinate, output.getQuantizedVector()); } diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java index ac48a9523..3cba89c39 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -84,7 +84,7 @@ public void quantize(final float[] vector, final QuantizationState state, final if (thresholds == null || thresholds.length != vectorLength) { throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); } - if 
(!output.isPrepared(vectorLength)) output.prepareQuantizedVector(vectorLength); + output.prepareQuantizedVector(vectorLength); BitPacker.quantizeAndPackBits(vector, thresholds, output.getQuantizedVector()); } diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java index 9c2e5a4b7..abb61ccd9 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java @@ -138,7 +138,7 @@ public void testBuildAndWrite_withQuantization() { ArgumentCaptor vectorCaptor = ArgumentCaptor.forClass(float[].class); // New: Create QuantizationOutput and mock the quantization process QuantizationOutput quantizationOutput = mock(QuantizationOutput.class); - when(quantizationOutput.getQuantizedVector()).thenReturn(new byte[] { 1, 2 }); + when(quantizationOutput.getQuantizedVectorCopy()).thenReturn(new byte[] { 1, 2 }); when(quantizationService.createQuantizationOutput(eq(quantizationState.getQuantizationParams()))).thenReturn( quantizationOutput ); @@ -146,8 +146,8 @@ public void testBuildAndWrite_withQuantization() { // Quantize the vector with the quantization output when(quantizationService.quantize(eq(quantizationState), vectorCaptor.capture(), eq(quantizationOutput))).thenAnswer( invocation -> { - quantizationOutput.getQuantizedVector(); - return quantizationOutput.getQuantizedVector(); + quantizationOutput.getQuantizedVectorCopy(); + return quantizationOutput.getQuantizedVectorCopy(); } ); when(quantizationState.getDimensions()).thenReturn(2); diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java index 62c3b7a71..77abe1cd2 100644 --- 
a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java @@ -164,7 +164,7 @@ public void testBuildAndWrite_withQuantization() { ArgumentCaptor vectorCaptor = ArgumentCaptor.forClass(float[].class); // New: Create QuantizationOutput and mock the quantization process QuantizationOutput quantizationOutput = mock(QuantizationOutput.class); - when(quantizationOutput.getQuantizedVector()).thenReturn(new byte[] { 1, 2 }); + when(quantizationOutput.getQuantizedVectorCopy()).thenReturn(new byte[] { 1, 2 }); when(quantizationService.createQuantizationOutput(eq(quantizationState.getQuantizationParams()))).thenReturn( quantizationOutput ); @@ -172,8 +172,8 @@ public void testBuildAndWrite_withQuantization() { // Quantize the vector with the quantization output when(quantizationService.quantize(eq(quantizationState), vectorCaptor.capture(), eq(quantizationOutput))).thenAnswer( invocation -> { - quantizationOutput.getQuantizedVector(); - return quantizationOutput.getQuantizedVector(); + quantizationOutput.getQuantizedVectorCopy(); + return quantizationOutput.getQuantizedVectorCopy(); } ); when(quantizationState.getDimensions()).thenReturn(2); diff --git a/src/test/java/org/opensearch/knn/quantization/output/BinaryQuantizationOutputTests.java b/src/test/java/org/opensearch/knn/quantization/output/BinaryQuantizationOutputTests.java new file mode 100644 index 000000000..8eab5d00c --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/output/BinaryQuantizationOutputTests.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.output; + +import org.junit.Before; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; + +public class 
BinaryQuantizationOutputTests extends KNNTestCase { + + private static final int BITS_PER_COORDINATE = 1; + private BinaryQuantizationOutput quantizationOutput; + + @Before + public void setUp() throws Exception { + super.setUp(); + quantizationOutput = new BinaryQuantizationOutput(BITS_PER_COORDINATE); + } + + public void testPrepareQuantizedVector_ShouldInitializeCorrectly_WhenVectorLengthIsValid() { + // Arrange + int vectorLength = 10; + + // Act + quantizationOutput.prepareQuantizedVector(vectorLength); + + // Assert + assertNotNull(quantizationOutput.getQuantizedVector()); + } + + public void testPrepareQuantizedVector_ShouldThrowException_WhenVectorLengthIsZeroOrNegative() { + // Act and Assert + expectThrows(IllegalArgumentException.class, () -> quantizationOutput.prepareQuantizedVector(0)); + expectThrows(IllegalArgumentException.class, () -> quantizationOutput.prepareQuantizedVector(-1)); + } + + public void testIsPrepared_ShouldReturnTrue_WhenCalledWithSameVectorLength() { + // Arrange + int vectorLength = 8; + quantizationOutput.prepareQuantizedVector(vectorLength); + // Act and Assert + assertTrue(quantizationOutput.isPrepared(vectorLength)); + } + + public void testIsPrepared_ShouldReturnFalse_WhenCalledWithDifferentVectorLength() { + // Arrange + int vectorLength = 8; + quantizationOutput.prepareQuantizedVector(vectorLength); + // Act and Assert + assertFalse(quantizationOutput.isPrepared(vectorLength + 1)); + } + + public void testGetQuantizedVector_ShouldReturnSameReference() { + // Arrange + int vectorLength = 5; + quantizationOutput.prepareQuantizedVector(vectorLength); + // Act + byte[] vector = quantizationOutput.getQuantizedVector(); + // Assert + assertEquals(vector, quantizationOutput.getQuantizedVector()); + } + + public void testGetQuantizedVectorCopy_ShouldReturnCopyOfVector() { + // Arrange + int vectorLength = 5; + quantizationOutput.prepareQuantizedVector(vectorLength); + + // Act + byte[] vectorCopy = 
quantizationOutput.getQuantizedVectorCopy(); + + // Assert + assertNotSame(vectorCopy, quantizationOutput.getQuantizedVector()); + assertArrayEquals(vectorCopy, quantizationOutput.getQuantizedVector()); + } + + public void testGetQuantizedVectorCopy_ShouldReturnNewCopyOnEachCall() { + // Arrange + int vectorLength = 5; + quantizationOutput.prepareQuantizedVector(vectorLength); + + // Act + byte[] vectorCopy1 = quantizationOutput.getQuantizedVectorCopy(); + byte[] vectorCopy2 = quantizationOutput.getQuantizedVectorCopy(); + + // Assert + assertNotSame(vectorCopy1, vectorCopy2); + } + + public void testPrepareQuantizedVector_ShouldResetQuantizedVector_WhenCalledWithDifferentLength() { + // Arrange + int initialLength = 5; + int newLength = 10; + quantizationOutput.prepareQuantizedVector(initialLength); + byte[] initialVector = quantizationOutput.getQuantizedVector(); + + // Act + quantizationOutput.prepareQuantizedVector(newLength); + byte[] newVector = quantizationOutput.getQuantizedVector(); + + // Assert + assertNotSame(initialVector, newVector); // The array reference should change + assertEquals(newVector.length, (BITS_PER_COORDINATE * newLength + 7) / 8); // Correct size for new vector + } + + public void testPrepareQuantizedVector_ShouldRetainSameArray_WhenCalledWithSameLength() { + // Arrange + int vectorLength = 5; + quantizationOutput.prepareQuantizedVector(vectorLength); + byte[] initialVector = quantizationOutput.getQuantizedVector(); + + // Act + quantizationOutput.prepareQuantizedVector(vectorLength); + byte[] newVector = quantizationOutput.getQuantizedVector(); + + // Assert + assertSame(newVector, initialVector); // The array reference should remain the same + } +}