diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java index e68121a7db..476c95b8d5 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java @@ -8,7 +8,6 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; @@ -56,8 +55,13 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept iterateVectorValuesOnce(knnVectorValues); IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo); - int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector()); - try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { + try ( + final OffHeapVectorTransfer vectorTransfer = getVectorTransfer( + indexInfo.getVectorDataType(), + indexBuildSetup.getBytesPerVector(), + indexInfo.getTotalLiveDocs() + ) + ) { final List transferredDocIds = new ArrayList<>(indexInfo.getTotalLiveDocs()); while (knnVectorValues.docId() != NO_MORE_DOCS) { diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java index af3f4777f4..b7e337081d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java @@ -7,7 +7,6 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; -import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.engine.KNNEngine; @@ -70,10 +69,15 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept ) ); - int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector()); - try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { + try ( + final OffHeapVectorTransfer vectorTransfer = getVectorTransfer( + indexInfo.getVectorDataType(), + indexBuildSetup.getBytesPerVector(), + indexInfo.getTotalLiveDocs() + ) + ) { - final List transferredDocIds = new ArrayList<>(transferLimit); + final List transferredDocIds = new ArrayList<>(vectorTransfer.getTransferLimit()); while (knnVectorValues.docId() != NO_MORE_DOCS) { Object vector = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, indexBuildSetup); diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java index ffa12a2315..964007fc09 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java @@ -17,8 +17,8 @@ */ public final class OffHeapBinaryVectorTransfer extends OffHeapVectorTransfer { - public OffHeapBinaryVectorTransfer(int transferLimit) { - super(transferLimit); + public OffHeapBinaryVectorTransfer(int bytesPerVector, int totalVectorsToTransfer) { + super(bytesPerVector, totalVectorsToTransfer); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java index 83ebf2fa3e..16e3334789 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java @@ -17,8 +17,8 @@ */ public final class OffHeapByteVectorTransfer extends OffHeapVectorTransfer { - public OffHeapByteVectorTransfer(int transferLimit) { - super(transferLimit); + public OffHeapByteVectorTransfer(int bytesPerVector, int totalVectorsToTransfer) { + super(bytesPerVector, totalVectorsToTransfer); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java index 0eb28d791e..767f572718 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java @@ -15,8 +15,8 @@ */ public final class OffHeapFloatVectorTransfer extends OffHeapVectorTransfer { - public OffHeapFloatVectorTransfer(int transferLimit) { - super(transferLimit); + public OffHeapFloatVectorTransfer(int bytesPerVector, int totalVectorsToTransfer) { + super(bytesPerVector, totalVectorsToTransfer); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java index 43c27c8da9..8a248e06c4 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.codec.transfer; import lombok.Getter; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import java.io.Closeable; @@ -27,16 +28,22 @@ public abstract class OffHeapVectorTransfer implements Closeable { @Getter private long vectorAddress; + @Getter protected final int transferLimit; - private final List vectorsToTransfer; + private List vectorsToTransfer; - public OffHeapVectorTransfer(final int transferLimit) { - this.transferLimit = transferLimit; - this.vectorsToTransfer = new ArrayList<>(transferLimit); + public OffHeapVectorTransfer(int bytesPerVector, int totalVectorsToTransfer) { + this.transferLimit = computeTransferLimit(bytesPerVector, totalVectorsToTransfer); + this.vectorsToTransfer = new ArrayList<>(this.transferLimit); this.vectorAddress = 0; } + private int computeTransferLimit(int bytesPerVector, int totalVectorsToTransfer) { + int limit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector); + return Math.min(limit, totalVectorsToTransfer); + } + /** * Transfer vectors to off-heap * @param vector float[] or byte[] @@ -90,7 +97,7 @@ public void close() { */ public void reset() { vectorAddress = 0; - vectorsToTransfer.clear(); + vectorsToTransfer = null; } protected abstract void deallocate(); diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java index 446b6ae806..3bc55f7fa4 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java @@ -18,18 +18,23 @@ public final class OffHeapVectorTransferFactory { /** * Gets the right vector transfer object based on vector data type * @param vectorDataType {@link VectorDataType} - * @param transferLimit max number of vectors that can be transferred to off heap in one transfer + * @param bytesPerVector Bytes used per vector + * @param totalVectorsToTransfer total number of vectors that will be transferred off heap * @return Correct implementation of {@link OffHeapVectorTransfer} * @param float[] or byte[] */ - public static OffHeapVectorTransfer getVectorTransfer(final VectorDataType vectorDataType, final int transferLimit) { + public static OffHeapVectorTransfer getVectorTransfer( + final VectorDataType vectorDataType, + int bytesPerVector, + int totalVectorsToTransfer + ) { switch (vectorDataType) { case FLOAT: - return (OffHeapVectorTransfer) new OffHeapFloatVectorTransfer(transferLimit); + return (OffHeapVectorTransfer) new OffHeapFloatVectorTransfer(bytesPerVector, totalVectorsToTransfer); case BINARY: - return (OffHeapVectorTransfer) new OffHeapBinaryVectorTransfer(transferLimit); + return (OffHeapVectorTransfer) new OffHeapBinaryVectorTransfer(bytesPerVector, totalVectorsToTransfer); case BYTE: - return (OffHeapVectorTransfer) new OffHeapByteVectorTransfer(transferLimit); + return (OffHeapVectorTransfer) new OffHeapByteVectorTransfer(bytesPerVector, totalVectorsToTransfer); default: throw new IllegalArgumentException("Unsupported vector data type: " + vectorDataType); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 265876310e..f149fa1d28 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -782,6 +782,8 @@ public ParametrizedFieldMapper.Builder getMergeBuilder() { .vectorDataType(vectorDataType) .versionCreated(indexCreatedVersion) .dimension(fieldType().getKnnMappingConfig().getDimension()) + .compressionLevel(fieldType().getKnnMappingConfig().getCompressionLevel()) + .mode(fieldType().getKnnMappingConfig().getMode()) .build(); } diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java index f70d9cce6c..abb61ccd93 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java @@ -57,14 +57,11 @@ public void testBuildAndWrite() { final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); try ( - MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class); MockedStatic mockedJNIService = mockStatic(JNIService.class); MockedStatic mockedOffHeapVectorTransferFactory = mockStatic(OffHeapVectorTransferFactory.class) ) { - - mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); - mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 8, 3)) .thenReturn(offHeapVectorTransfer); when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); @@ -131,7 +128,7 @@ public void testBuildAndWrite_withQuantization() { mockedJNIService.when(() -> JNIService.initIndex(3, 2, Map.of("index", "param"), KNNEngine.FAISS)).thenReturn(100L); OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); - mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 8, 3)) .thenReturn(offHeapVectorTransfer); QuantizationService quantizationService = mock(QuantizationService.class); @@ -237,14 +234,12 @@ public void testBuildAndWriteWithModel() { docs ); try ( - MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class); MockedStatic mockedJNIService = mockStatic(JNIService.class); MockedStatic mockedOffHeapVectorTransferFactory = mockStatic(OffHeapVectorTransferFactory.class) ) { - mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); - mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 8, 3)) .thenReturn(offHeapVectorTransfer); when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java index 37a738dae0..77abe1cd2c 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java @@ -9,8 +9,6 @@ import org.mockito.ArgumentCaptor; import org.mockito.MockedStatic; import org.mockito.Mockito; -import org.opensearch.core.common.unit.ByteSizeValue; -import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; @@ -49,7 +47,6 @@ public void testBuildAndWrite() { final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); try ( - MockedStatic mockedKNNSettings = Mockito.mockStatic(KNNSettings.class); MockedStatic mockedJNIService = Mockito.mockStatic(JNIService.class); MockedStatic mockedOffHeapVectorTransferFactory = Mockito.mockStatic( OffHeapVectorTransferFactory.class @@ -57,13 +54,13 @@ public void testBuildAndWrite() { ) { // Limits transfer to 2 vectors - mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); mockedJNIService.when(() -> JNIService.initIndex(3, 2, Map.of("index", "param"), KNNEngine.FAISS)).thenReturn(100L); OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); - mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 8, 3)) .thenReturn(offHeapVectorTransfer); + when(offHeapVectorTransfer.getTransferLimit()).thenReturn(2); when(offHeapVectorTransfer.transfer(vectorTransferCapture.capture(), eq(false))).thenReturn(false) .thenReturn(true) .thenReturn(false); @@ -145,7 +142,6 @@ public void testBuildAndWrite_withQuantization() { final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); try ( - MockedStatic mockedKNNSettings = Mockito.mockStatic(KNNSettings.class); MockedStatic mockedJNIService = Mockito.mockStatic(JNIService.class); MockedStatic mockedOffHeapVectorTransferFactory = Mockito.mockStatic( OffHeapVectorTransferFactory.class @@ -154,11 +150,11 @@ public void testBuildAndWrite_withQuantization() { ) { // Limits transfer to 2 vectors - mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); mockedJNIService.when(() -> JNIService.initIndex(3, 2, Map.of("index", "param"), KNNEngine.FAISS)).thenReturn(100L); OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); - mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + when(offHeapVectorTransfer.getTransferLimit()).thenReturn(2); + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 8, 3)) .thenReturn(offHeapVectorTransfer); QuantizationService quantizationService = mock(QuantizationService.class); diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java index 39415d811b..09984ba46d 100644 --- a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java @@ -5,22 +5,30 @@ package org.opensearch.knn.index.codec.transfer; +import org.mockito.MockedStatic; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.test.OpenSearchTestCase; +import static org.mockito.Mockito.mockStatic; + public class OffHeapVectorTransferFactoryTests extends OpenSearchTestCase { public void testOffHeapVectorTransferFactory() { - var floatVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 10); - assertEquals(OffHeapFloatVectorTransfer.class, floatVectorTransfer.getClass()); - assertNotSame(floatVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 10)); + try (MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class)) { + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); + var floatVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 10, 10); + assertEquals(OffHeapFloatVectorTransfer.class, floatVectorTransfer.getClass()); + assertNotSame(floatVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 10, 10)); - var byteVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BYTE, 10); - assertEquals(OffHeapByteVectorTransfer.class, byteVectorTransfer.getClass()); - assertNotSame(byteVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BYTE, 10)); + var byteVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BYTE, 10, 10); + assertEquals(OffHeapByteVectorTransfer.class, byteVectorTransfer.getClass()); + assertNotSame(byteVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BYTE, 10, 10)); - var binaryVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BINARY, 10); - assertEquals(OffHeapBinaryVectorTransfer.class, binaryVectorTransfer.getClass()); - assertNotSame(binaryVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BINARY, 10)); + var binaryVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BINARY, 10, 10); + assertEquals(OffHeapBinaryVectorTransfer.class, binaryVectorTransfer.getClass()); + assertNotSame(binaryVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BINARY, 10, 10)); + } } } diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java index f1650db8f3..fb2ef274e6 100644 --- a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java @@ -6,10 +6,15 @@ package org.opensearch.knn.index.codec.transfer; import lombok.SneakyThrows; +import org.mockito.MockedStatic; +import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNSettings; import java.util.List; +import static org.mockito.Mockito.mockStatic; + public class OffHeapVectorTransferTests extends KNNTestCase { @SneakyThrows @@ -22,21 +27,27 @@ public void testFloatTransfer() { new float[] { 0.3f, 0.4f } ); - OffHeapFloatVectorTransfer vectorTransfer = new OffHeapFloatVectorTransfer(2); - long vectorAddress = 0; - assertFalse(vectorTransfer.transfer(vectors.get(0), false)); - assertEquals(0, vectorTransfer.getVectorAddress()); - assertTrue(vectorTransfer.transfer(vectors.get(1), false)); - vectorAddress = vectorTransfer.getVectorAddress(); - assertFalse(vectorTransfer.transfer(vectors.get(2), false)); - assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); - assertTrue(vectorTransfer.transfer(vectors.get(3), false)); - assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); - assertFalse(vectorTransfer.transfer(vectors.get(4), false)); - assertTrue(vectorTransfer.flush(false)); - vectorTransfer.reset(); - assertEquals(0, vectorTransfer.getVectorAddress()); - vectorTransfer.close(); + try (MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class)) { + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); + + OffHeapFloatVectorTransfer vectorTransfer = new OffHeapFloatVectorTransfer(8, 5); + long vectorAddress = 0; + assertFalse(vectorTransfer.transfer(vectors.get(0), false)); + assertEquals(0, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(1), false)); + vectorAddress = vectorTransfer.getVectorAddress(); + assertFalse(vectorTransfer.transfer(vectors.get(2), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(3), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertFalse(vectorTransfer.transfer(vectors.get(4), false)); + assertTrue(vectorTransfer.flush(false)); + vectorTransfer.reset(); + assertEquals(0, vectorTransfer.getVectorAddress()); + vectorTransfer.close(); + + } + } @SneakyThrows @@ -49,20 +60,23 @@ public void testByteTransfer() { new byte[] { 8, 9 } ); - OffHeapByteVectorTransfer vectorTransfer = new OffHeapByteVectorTransfer(2); - long vectorAddress = 0; - assertFalse(vectorTransfer.transfer(vectors.get(0), false)); - assertEquals(0, vectorTransfer.getVectorAddress()); - assertTrue(vectorTransfer.transfer(vectors.get(1), false)); - vectorAddress = vectorTransfer.getVectorAddress(); - assertFalse(vectorTransfer.transfer(vectors.get(2), false)); - assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); - assertTrue(vectorTransfer.transfer(vectors.get(3), false)); - assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); - assertFalse(vectorTransfer.transfer(vectors.get(4), false)); - assertTrue(vectorTransfer.flush(false)); - vectorTransfer.close(); - assertEquals(0, vectorTransfer.getVectorAddress()); + try (MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class)) { + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(4)); + OffHeapByteVectorTransfer vectorTransfer = new OffHeapByteVectorTransfer(2, 5); + long vectorAddress = 0; + assertFalse(vectorTransfer.transfer(vectors.get(0), false)); + assertEquals(0, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(1), false)); + vectorAddress = vectorTransfer.getVectorAddress(); + assertFalse(vectorTransfer.transfer(vectors.get(2), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(3), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertFalse(vectorTransfer.transfer(vectors.get(4), false)); + assertTrue(vectorTransfer.flush(false)); + vectorTransfer.close(); + assertEquals(0, vectorTransfer.getVectorAddress()); + } } @SneakyThrows @@ -75,18 +89,21 @@ public void testBinaryTransfer() { new byte[] { 8, 9 } ); - OffHeapBinaryVectorTransfer vectorTransfer = new OffHeapBinaryVectorTransfer(2); - long vectorAddress = 0; - assertFalse(vectorTransfer.transfer(vectors.get(0), false)); - assertEquals(0, vectorTransfer.getVectorAddress()); - assertTrue(vectorTransfer.transfer(vectors.get(1), false)); - vectorAddress = vectorTransfer.getVectorAddress(); - assertFalse(vectorTransfer.transfer(vectors.get(2), false)); - assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); - assertTrue(vectorTransfer.transfer(vectors.get(3), false)); - assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); - assertFalse(vectorTransfer.transfer(vectors.get(4), false)); - assertTrue(vectorTransfer.flush(false)); - vectorTransfer.close(); + try (MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class)) { + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(4)); + OffHeapBinaryVectorTransfer vectorTransfer = new OffHeapBinaryVectorTransfer(2, 5); + long vectorAddress = 0; + assertFalse(vectorTransfer.transfer(vectors.get(0), false)); + assertEquals(0, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(1), false)); + vectorAddress = vectorTransfer.getVectorAddress(); + assertFalse(vectorTransfer.transfer(vectors.get(2), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(3), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertFalse(vectorTransfer.transfer(vectors.get(4), false)); + assertTrue(vectorTransfer.flush(false)); + vectorTransfer.close(); + } } } diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 4966413399..84cbf05dc3 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -880,6 +880,65 @@ public void testTypeParser_parse_fromLegacy() throws IOException { assertNull(builder.knnMethodContext.get()); } + public void testKNNVectorFieldMapperMerge_whenModeAndCompressionIsPresent_thenSuccess() throws IOException { + String fieldName = "test-field-name"; + String indexName = "test-index-name"; + + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); + ModelDao modelDao = mock(ModelDao.class); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); + + int dimension = 133; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .field(MODE_PARAMETER, Mode.ON_DISK.getName()) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x32.getName()) + .endObject(); + + KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + fieldName, + xContentBuilderToMap(xContentBuilder), + buildParserContext(indexName, settings) + ); + Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); + KNNVectorFieldMapper knnVectorFieldMapper1 = builder.build(builderContext); + + // merge with itself - should be successful + KNNVectorFieldMapper knnVectorFieldMapperMerge1 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper1); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getKnnMethodContext().get(), + knnVectorFieldMapperMerge1.fieldType().getKnnMappingConfig().getKnnMethodContext().get() + ); + + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getCompressionLevel(), + knnVectorFieldMapperMerge1.fieldType().getKnnMappingConfig().getCompressionLevel() + ); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getMode(), + knnVectorFieldMapperMerge1.fieldType().getKnnMappingConfig().getMode() + ); + + // merge with another mapper of the same field with same context + KNNVectorFieldMapper knnVectorFieldMapper2 = builder.build(builderContext); + KNNVectorFieldMapper knnVectorFieldMapperMerge2 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper2); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getKnnMethodContext().get(), + knnVectorFieldMapperMerge2.fieldType().getKnnMappingConfig().getKnnMethodContext().get() + ); + + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getCompressionLevel(), + knnVectorFieldMapperMerge2.fieldType().getKnnMappingConfig().getCompressionLevel() + ); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getMode(), + knnVectorFieldMapperMerge2.fieldType().getKnnMappingConfig().getMode() + ); + } + public void testKNNVectorFieldMapper_merge_fromKnnMethodContext() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; diff --git a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java index 4c8b44e5b1..13e6756d13 100644 --- a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java +++ b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java @@ -7,16 +7,18 @@ import lombok.SneakyThrows; import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.Assert; import org.junit.Ignore; +import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.KNNRestTestCase; -import org.opensearch.knn.KNNResult; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.mapper.CompressionLevel; import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.index.mapper.ModeBasedResolver; @@ -50,7 +52,7 @@ public class ModeAndCompressionIT extends KNNRestTestCase { private static final int DIMENSION = 16; private static final int NUM_DOCS = 20; - private static final int K = 2; + private static final int K = NUM_DOCS; private final static float[] TEST_VECTOR = new float[] { 1.0f, 2.0f, @@ -210,7 +212,7 @@ public void testDeletedDocsWithSegmentMerge_whenValid_ThenSucceed() { .endObject(); String mapping = builder.toString(); validateIndexWithDeletedDocs(indexName, mapping); - validateSearch(indexName, METHOD_PARAMETER_EF_SEARCH, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH); + validateGreenIndex(indexName); } @SneakyThrows @@ -352,6 +354,19 @@ private void validateIndexWithDeletedDocs(String indexName, String mapping) { refreshIndex(indexName); } + @SneakyThrows + private void validateGreenIndex(String indexName) { + Request request = new Request("GET", "/_cat/indices/" + indexName + "?format=csv"); + Response response = client().performRequest(request); + assertOK(response); + assertEquals( + "The status of index " + indexName + " is not green", + "green", + new String(response.getEntity().getContent().readAllBytes()).split("\n")[0].split(" ")[0] + ); + + } + @SneakyThrows private void setupTrainingIndex() { createBasicKnnIndex(TRAINING_INDEX_NAME, TRAINING_FIELD_NAME, DIMENSION); @@ -391,6 +406,38 @@ private void validateSearch(String indexName, String methodParameterName, int me List knnResults = parseSearchResponse(responseBody, FIELD_NAME); assertTrue(knnResults.size() > 0); + // Do exact search and gather right scores for the documents + Response exactSearchResponse = searchKNNIndex( + indexName, + XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("script_score") + .startObject("query") + .field("match_all") + .startObject() + .endObject() + .endObject() + .startObject("script") + .field("source", "knn_score") + .field("lang", "knn") + .startObject("params") + .field("field", FIELD_NAME) + .field("query_value", TEST_VECTOR) + .field("space_type", SpaceType.L2.getValue()) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(), + K + ); + assertOK(exactSearchResponse); + String exactSearchResponseBody = EntityUtils.toString(exactSearchResponse.getEntity()); + List exactSearchKnnResults = parseSearchResponseScore(exactSearchResponseBody, FIELD_NAME); + assertEquals(NUM_DOCS, exactSearchKnnResults.size()); + Assert.assertEquals(exactSearchKnnResults, knnResults); + // Search with rescore response = searchKNNIndex( indexName,