From 4b554d35449f923362c91625dada729327c8f677 Mon Sep 17 00:00:00 2001 From: VIKASH TIWARI Date: Thu, 3 Oct 2024 15:48:35 -0700 Subject: [PATCH] Score Fix for Binary Quantized Vector and Setting Default value in case of shard level rescoring is disabled for oversampling factor Signed-off-by: VIKASH TIWARI --- CHANGELOG.md | 1 + .../org/opensearch/knn/index/KNNSettings.java | 7 +- .../knn/index/mapper/CompressionLevel.java | 39 +++---- .../opensearch/knn/index/query/KNNWeight.java | 4 + .../nativelib/NativeEngineKnnVectorQuery.java | 6 +- .../index/query/rescore/RescoreContext.java | 63 ++++++++++- .../knn/index/KNNSettingsTests.java | 10 +- .../index/mapper/CompressionLevelTests.java | 76 +++++-------- .../knn/index/query/KNNQueryBuilderTests.java | 3 +- .../knn/index/query/KNNWeightTests.java | 106 ++++++++++++++++++ .../NativeEngineKNNVectorQueryTests.java | 6 +- .../query/rescore/RescoreContextTests.java | 79 ++++++++----- 12 files changed, 285 insertions(+), 115 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6871074f0..44e29a0ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Adding Support to Enable/Disble Share level Rescoring and Update Oversampling Factor[#2172](https://github.com/opensearch-project/k-NN/pull/2172) ### Bug Fixes * KNN80DocValues should only be considered for BinaryDocValues fields [#2147](https://github.com/opensearch-project/k-NN/pull/2147) +* Score Fix for Binary Quantized Vector and Setting Default value in case of shard level rescoring is disabled for oversampling factor[#2183](https://github.com/opensearch-project/k-NN/pull/2183) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 1753140e6..f5980879a 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -92,6 +92,7 @@ public class KNNSettings { /** * Default setting values + * */ public static final boolean KNN_DEFAULT_FAISS_AVX2_DISABLED_VALUE = false; public static final boolean KNN_DEFAULT_FAISS_AVX512_DISABLED_VALUE = false; @@ -113,7 +114,7 @@ public class KNNSettings { public static final Integer KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // Quantization state cache limit cannot exceed // 10% of the JVM heap public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60; - public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = true; + public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = false; /** * Settings Definition @@ -554,12 +555,12 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) { .getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE); } - public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) { + public static boolean isShardLevelRescoringEnabledForDiskBasedVector(String indexName) { return KNNSettings.state().clusterService.state() .getMetadata() .index(indexName) .getSettings() - .getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, true); + .getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, false); } public void initialize(Client client, ClusterService clusterService) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java index c9a169efc..ab583a2e0 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java +++ b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java @@ -25,9 +25,9 @@ public enum CompressionLevel { x1(1, "1x", null, Collections.emptySet()), x2(2, "2x", null, Collections.emptySet()), x4(4, "4x", null, Collections.emptySet()), - x8(8, "8x", new RescoreContext(2.0f), Set.of(Mode.ON_DISK)), - x16(16, "16x", new RescoreContext(3.0f), Set.of(Mode.ON_DISK)), - x32(32, "32x", new RescoreContext(3.0f), Set.of(Mode.ON_DISK)); + x8(8, "8x", new RescoreContext(2.0f, false), Set.of(Mode.ON_DISK)), + x16(16, "16x", new RescoreContext(3.0f, false), Set.of(Mode.ON_DISK)), + x32(32, "32x", new RescoreContext(3.0f, false), Set.of(Mode.ON_DISK)); // Internally, an empty string is easier to deal with them null. However, from the mapping, // we do not want users to pass in the empty string and instead want null. So we make the conversion here @@ -97,32 +97,33 @@ public static boolean isConfigured(CompressionLevel compressionLevel) { /** * Returns the appropriate {@link RescoreContext} based on the given {@code mode} and {@code dimension}. * - *

If the {@code mode} is present in the valid {@code modesForRescore} set, the method adjusts the oversample factor based on the - * {@code dimension} value: + *

If the {@code mode} is present in the valid {@code modesForRescore} set, the method checks the value of + * {@code dimension}: *

- * If the {@code mode} is not present in the {@code modesForRescore} set, the method returns {@code null}. + * If the {@code mode} is not valid, the method returns {@code null}. * * @param mode The {@link Mode} for which to retrieve the {@link RescoreContext}. * @param dimension The dimensional value that determines the {@link RescoreContext} behavior. - * @return A {@link RescoreContext} with the appropriate oversample factor based on the dimension, or {@code null} if the mode - * is not valid. + * @return A {@link RescoreContext} with an oversample factor of 5.0f if {@code dimension} is less than + * or equal to 1000, the default {@link RescoreContext} if greater, or {@code null} if the mode + * is invalid. */ public RescoreContext getDefaultRescoreContext(Mode mode, int dimension) { if (modesForRescore.contains(mode)) { // Adjust RescoreContext based on dimension - if (dimension >= RescoreContext.DIMENSION_THRESHOLD_1000) { - // No oversampling for dimensions >= 1000 - return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_1000).build(); - } else if (dimension >= RescoreContext.DIMENSION_THRESHOLD_768) { - // 2x oversampling for dimensions >= 768 but < 1000 - return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_768).build(); + if (dimension <= RescoreContext.DIMENSION_THRESHOLD) { + // For dimensions <= 1000, return a RescoreContext with 5.0f oversample factor + return RescoreContext.builder() + .oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_DIMENSION_THRESHOLD) + .userProvided(false) + .build(); } else { - // 3x oversampling for dimensions < 768 - return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_768).build(); + return defaultRescoreContext; } } return null; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 1c31ed725..c75d400a7 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -375,6 +375,10 @@ private Map doANNSearch( return null; } + if (quantizedVector != null) { + return Arrays.stream(results) + .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), SpaceType.HAMMING))); + } return Arrays.stream(results) .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); } diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index adb2875d5..c97a0d061 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -61,9 +61,11 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo if (rescoreContext == null) { perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK); } else { - int firstPassK = rescoreContext.getFirstPassK(finalK); + boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName()); + int dimension = knnQuery.getQueryVector().length; + int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension); perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK); - if (KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()) == false) { + if (isShardLevelRescoringEnabled == true) { ResultUtil.reduceToTopK(perLeafResults, firstPassK); } diff --git a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java index a2563b2a6..0f8c59499 100644 --- a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java +++ b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java @@ -39,21 +39,74 @@ public final class RescoreContext { @Builder.Default private float oversampleFactor = DEFAULT_OVERSAMPLE_FACTOR; + /** + * Flag to track whether the oversample factor is user-provided or default. The Reason to introduce + * this is to set default when Shard Level rescoring is false, + * else we end up overriding user provided value in NativeEngineKnnVectorQuery + * + * + * This flag is crucial to differentiate between user-defined oversample factors and system-assigned + * default values. The behavior of oversampling logic, especially when shard-level rescoring is disabled, + * depends on whether the user explicitly provided an oversample factor or whether the system is using + * a default value. + * + * When shard-level rescoring is disabled, the system applies dimension-based oversampling logic, + * overriding any default values. However, if the user provides their own oversample factor, the system + * should respect the user’s input and avoid overriding it with the dimension-based logic. + * + * This flag is set to {@code true} when the oversample factor is provided by the user, ensuring + * that their value is not overridden. It is set to {@code false} when the oversample factor is + * determined by system defaults (e.g., through a compression level or automatic logic). The system + * then applies its own oversampling rules if necessary. + * + * Key scenarios: + * - If {@code userProvided} is {@code true} and shard-level rescoring is disabled, the user's + * oversample factor is used as is, without applying the dimension-based logic. + * - If {@code userProvided} is {@code false}, the system applies dimension-based oversampling + * when shard-level rescoring is disabled. + * + * This flag enables flexibility, allowing the system to handle both user-defined and default + * behaviors, ensuring the correct oversampling logic is applied based on the context. + */ + @Builder.Default + private boolean userProvided = true; + /** * * @return default RescoreContext */ public static RescoreContext getDefault() { - return RescoreContext.builder().build(); + return RescoreContext.builder().oversampleFactor(DEFAULT_OVERSAMPLE_FACTOR).userProvided(false).build(); } /** - * Gets the number of results to return for the first pass of rescoring. + * Calculates the number of results to return for the first pass of rescoring (firstPassK). + * This method considers whether shard-level rescoring is enabled and adjusts the oversample factor + * based on the vector dimension if shard-level rescoring is disabled. * - * @param finalK The final number of results to return for the entire shard - * @return The number of results to return for the first pass of rescoring + * @param finalK The final number of results to return for the entire shard. + * @param isShardLevelRescoringEnabled A boolean flag indicating whether shard-level rescoring is enabled. + * If true, the dimension-based oversampling logic is bypassed. + * @param dimension The dimension of the vector. This is used to determine the oversampling factor when + * shard-level rescoring is disabled. + * @return The number of results to return for the first pass of rescoring, adjusted by the oversample factor. */ - public int getFirstPassK(int finalK) { + public int getFirstPassK(int finalK, boolean isShardLevelRescoringEnabled, int dimension) { + // Only apply default dimension-based oversampling logic when: + // 1. Shard-level rescoring is disabled + // 2. The oversample factor was not provided by the user + if (!isShardLevelRescoringEnabled && !userProvided) { + // Apply new dimension-based oversampling logic when shard-level rescoring is disabled + if (dimension >= DIMENSION_THRESHOLD_1000) { + oversampleFactor = OVERSAMPLE_FACTOR_1000; // No oversampling for dimensions >= 1000 + } else if (dimension >= DIMENSION_THRESHOLD_768) { + oversampleFactor = OVERSAMPLE_FACTOR_768; // 2x oversampling for dimensions >= 768 and < 1000 + } else { + oversampleFactor = OVERSAMPLE_FACTOR_BELOW_768; // 3x oversampling for dimensions < 768 + } + } + // The calculation for firstPassK remains the same, applying the oversample factor return Math.min(MAX_FIRST_PASS_RESULTS, Math.max(MIN_FIRST_PASS_RESULTS, (int) Math.ceil(finalK * oversampleFactor))); } + } diff --git a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java index fd25699cc..c7a8e7ed8 100644 --- a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java @@ -159,7 +159,7 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() { } @SneakyThrows - public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() { + public void testShardLevelRescoringEnabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() { Node mockNode = createMockNode(Collections.emptyMap()); mockNode.start(); ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class); @@ -167,14 +167,14 @@ public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefau mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet(); KNNSettings.state().setClusterService(clusterService); - boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME); + boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME); mockNode.close(); - assertTrue(shardLevelRescoringDisabled); + assertFalse(shardLevelRescoringDisabled); } @SneakyThrows public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingApplied() { - boolean userDefinedRescoringDisabled = false; + boolean userDefinedRescoringDisabled = true; Node mockNode = createMockNode(Collections.emptyMap()); mockNode.start(); ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class); @@ -188,7 +188,7 @@ public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingA mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet(); - boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME); + boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME); mockNode.close(); assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled); } diff --git a/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java b/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java index 57372b11e..e882d6697 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java @@ -45,83 +45,57 @@ public void testGetDefaultRescoreContext() { // Test rescore context for ON_DISK mode Mode mode = Mode.ON_DISK; - // Test various dimensions based on the updated oversampling logic - int belowThresholdDimension = 500; // A dimension below 768 - int between768and1000Dimension = 800; // A dimension between 768 and 1000 - int above1000Dimension = 1500; // A dimension above 1000 + int belowThresholdDimension = 500; // A dimension below the threshold + int aboveThresholdDimension = 1500; // A dimension above the threshold - // Compression level x32 with dimension < 768 should have an oversample factor of 3.0f + // x32 with dimension <= 1000 should have an oversample factor of 5.0f RescoreContext rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, belowThresholdDimension); assertNotNull(rescoreContext); - assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); - - // Compression level x32 with dimension between 768 and 1000 should have an oversample factor of 2.0f - rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, between768and1000Dimension); - assertNotNull(rescoreContext); - assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f); + assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); - // Compression level x32 with dimension > 1000 should have no oversampling (1.0f) - rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, above1000Dimension); - assertNotNull(rescoreContext); - assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f); - - // Compression level x16 with dimension < 768 should have an oversample factor of 3.0f - rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension); + // x32 with dimension > 1000 should have an oversample factor of 3.0f + rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, aboveThresholdDimension); assertNotNull(rescoreContext); assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); - // Compression level x16 with dimension between 768 and 1000 should have an oversample factor of 2.0f - rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, between768and1000Dimension); + // x16 with dimension <= 1000 should have an oversample factor of 5.0f + rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension); assertNotNull(rescoreContext); - assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f); + assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); - // Compression level x16 with dimension > 1000 should have no oversampling (1.0f) - rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, above1000Dimension); + // x16 with dimension > 1000 should have an oversample factor of 3.0f + rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, aboveThresholdDimension); assertNotNull(rescoreContext); - assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f); + assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); - // Compression level x8 with dimension < 768 should have an oversample factor of 3.0f + // x8 with dimension <= 1000 should have an oversample factor of 5.0f rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, belowThresholdDimension); assertNotNull(rescoreContext); - assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); - - // Compression level x8 with dimension between 768 and 1000 should have an oversample factor of 2.0f - rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, between768and1000Dimension); + assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); + // x8 with dimension > 1000 should have an oversample factor of 2.0f + rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, aboveThresholdDimension); assertNotNull(rescoreContext); assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f); - // Compression level x8 with dimension > 1000 should have no oversampling (1.0f) - rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, above1000Dimension); - assertNotNull(rescoreContext); - assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f); - - // Compression level x4 with dimension < 768 should return null (no RescoreContext) + // x4 with dimension <= 1000 should have an oversample factor of 5.0f (though it doesn't have its own RescoreContext) rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); - - // Compression level x4 with dimension > 1000 should return null (no RescoreContext) - rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, above1000Dimension); + // x4 with dimension > 1000 should return null (no RescoreContext is configured for x4) + rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, aboveThresholdDimension); assertNull(rescoreContext); - - // Compression level x2 with dimension < 768 should return null + // Other compression levels should behave similarly with respect to dimension rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); - - // Compression level x2 with dimension > 1000 should return null - rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, above1000Dimension); + // x2 with dimension > 1000 should return null + rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, aboveThresholdDimension); assertNull(rescoreContext); - - // Compression level x1 with dimension < 768 should return null rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); - - // Compression level x1 with dimension > 1000 should return null - rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, above1000Dimension); + // x1 with dimension > 1000 should return null + rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, aboveThresholdDimension); assertNull(rescoreContext); - - // NOT_CONFIGURED mode should return null for any dimension + // NOT_CONFIGURED with dimension <= 1000 should return a RescoreContext with an oversample factor of 5.0f rescoreContext = CompressionLevel.NOT_CONFIGURED.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); } - } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 3db03085b..b28b790d1 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -912,7 +912,8 @@ private void assertRescore(Version version, RescoreContext expectedRescoreContex } if (expectedRescoreContext != null) { - assertEquals(expectedRescoreContext, actualRescoreContext); + assertNotNull(actualRescoreContext); + assertEquals(expectedRescoreContext.getOversampleFactor(), actualRescoreContext.getOversampleFactor(), 0.0f); } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index f92f32406..2a2c3ed4d 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -79,6 +79,7 @@ import static java.util.Collections.emptyMap; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyFloat; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; @@ -516,6 +517,111 @@ public void testANNWithFilterQuery_whenDoingANNBinary_thenSuccess() { validateANNWithFilterQuery_whenDoingANN_thenSuccess(true); } + @SneakyThrows + public void testScorerWithQuantizedVector() { + // Given + int k = 3; + byte[] quantizedVector = new byte[] { 1, 2, 3 }; // Mocked quantized vector + float[] queryVector = new float[] { 0.1f, 0.3f }; + + // Mock the JNI service to return KNNQueryResults + KNNQueryResult[] knnQueryResults = new KNNQueryResult[] { + new KNNQueryResult(1, 10.0f), // Mock result with id 1 and score 10 + new KNNQueryResult(2, 20.0f) // Mock result with id 2 and score 20 + }; + jniServiceMockedStatic.when( + () -> JNIService.queryBinaryIndex(anyLong(), eq(quantizedVector), eq(k), any(), any(), any(), anyInt(), any()) + ).thenReturn(knnQueryResults); + + KNNEngine knnEngine = mock(KNNEngine.class); + when(knnEngine.score(anyFloat(), eq(SpaceType.HAMMING))).thenAnswer(invocation -> { + Float score = invocation.getArgument(0); + return 1 / (1 + score); + }); + + // Build the KNNQuery object + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .k(k) + .indexName(INDEX_NAME) + .vectorDataType(VectorDataType.BINARY) // Simulate binary vector type for quantization + .build(); + + final float boost = 1.0F; + final KNNWeight knnWeight = new KNNWeight(query, boost); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(fieldInfo); + + when(fieldInfo.attributes()).thenReturn(Map.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.HAMMING.getValue())); + + FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + when(path.toString()).thenReturn("/fake/directory"); + + SegmentInfo segmentInfo = new SegmentInfo( + directory, // The directory where the segment is stored + Version.LATEST, // Lucene version + Version.LATEST, // Version of the segment info + "0", // Segment name + 100, // Max document count for this segment + false, // Is this a compound file segment + false, // Is this a merged segment + KNNCodecVersion.current().getDefaultCodecDelegate(), // Codec delegate for KNN + Map.of(), // Diagnostics map + new byte[StringHelper.ID_LENGTH], // Segment ID + Map.of(), // Attributes + Sort.RELEVANCE // Default sort order + ); + + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + try (MockedStatic knnCodecUtilMockedStatic = mockStatic(KNNCodecUtil.class)) { + List engineFiles = List.of("_0_1_target_field.faiss"); + knnCodecUtilMockedStatic.when(() -> KNNCodecUtil.getEngineFiles(anyString(), anyString(), eq(segmentInfo))) + .thenReturn(engineFiles); + + try (MockedStatic quantizationUtilMockedStatic = mockStatic(SegmentLevelQuantizationUtil.class)) { + quantizationUtilMockedStatic.when(() -> SegmentLevelQuantizationUtil.quantizeVector(any(), any())) + .thenReturn(quantizedVector); + + // When: Call the scorer method + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + // Then: Ensure scorer is not null + assertNotNull(knnScorer); + + // Verify that JNIService.queryBinaryIndex is called with the quantized vector + jniServiceMockedStatic.verify( + () -> JNIService.queryBinaryIndex(anyLong(), eq(quantizedVector), eq(k), any(), any(), any(), anyInt(), any()), + times(1) + ); + + // Iterate over the results and ensure they are scored with SpaceType.HAMMING + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + while (docIdSetIterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + int docId = docIdSetIterator.docID(); + float expectedScore = knnEngine.score(knnQueryResults[docId - 1].getScore(), SpaceType.HAMMING); + float actualScore = knnScorer.score(); + // Check if the score is calculated using HAMMING + assertEquals(expectedScore, actualScore, 0.01f); // Tolerance for floating-point comparison + } + } + } + } + public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean isBinary) throws IOException { // Given int k = 3; diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 7fd96c6df..53873e15f 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -103,6 +103,8 @@ public void setUp() throws Exception { // Set ClusterService in KNNSettings KNNSettings.state().setClusterService(clusterService); + when(knnQuery.getQueryVector()).thenReturn(new float[] { 1.0f, 2.0f, 3.0f }); // Example vector + } @SneakyThrows @@ -166,7 +168,7 @@ public void testRescoreWhenShardLevelRescoringEnabled() { ) { // When shard-level re-scoring is enabled - mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false); + mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true); // Mock ResultUtil to return valid TopDocs mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(any(), anyInt())) @@ -250,7 +252,7 @@ public void testRescore() { ) { // When shard-level re-scoring is enabled - mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(true); + mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true); mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod); mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(eq(rescoredLeaf1Results), anyInt())).thenAnswer(t -> topDocs1); diff --git a/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java b/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java index fd94667db..2b309e4ab 100644 --- a/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java +++ b/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java @@ -14,47 +14,72 @@ public class RescoreContextTests extends KNNTestCase { public void testGetFirstPassK() { float oversample = 2.6f; - RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).build(); + RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); int finalK = 100; - assertEquals(260, rescoreContext.getFirstPassK(finalK)); - finalK = 1; - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); - finalK = 0; - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); - finalK = MAX_FIRST_PASS_RESULTS; - assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); - } - - public void testGetFirstPassKWithMinPassK() { - float oversample = 2.6f; - RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).build(); + boolean isShardLevelRescoringEnabled = true; + int dimension = 500; - // Case 1: Test with a finalK that results in a value greater than MIN_FIRST_PASS_RESULTS - int finalK = 100; - assertEquals(260, rescoreContext.getFirstPassK(finalK)); + // Case 1: Test with standard oversample factor when shard-level rescoring is enabled + assertEquals(260, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); // Case 2: Test with a very small finalK that should result in a value less than MIN_FIRST_PASS_RESULTS finalK = 1; - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); - // Case 3: Test with finalK = 0, should return 0 + // Case 3: Test with finalK = 0, should return MIN_FIRST_PASS_RESULTS finalK = 0; - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); // Case 4: Test with finalK = MAX_FIRST_PASS_RESULTS, should cap at MAX_FIRST_PASS_RESULTS finalK = MAX_FIRST_PASS_RESULTS; - assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); + assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); + } - // Case 5: Test where finalK * oversample is smaller than MIN_FIRST_PASS_RESULTS + public void testGetFirstPassKWithDimensionBasedOversampling() { + int finalK = 100; + int dimension; + + // Case 1: Test no oversampling for dimensions >= 1000 when shard-level rescoring is disabled + dimension = 1000; + RescoreContext rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensuring dimension-based logic applies + assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // No oversampling + + // Case 2: Test 2x oversampling for dimensions >= 768 but < 1000 when shard-level rescoring is disabled + dimension = 800; + rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over + assertEquals(200, rescoreContext.getFirstPassK(finalK, false, dimension)); // 2x oversampling + + // Case 3: Test 3x oversampling for dimensions < 768 when shard-level rescoring is disabled + dimension = 700; + rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over + assertEquals(300, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling + + // Case 4: Shard-level rescoring enabled, oversample factor should be used as provided by the user (ignore dimension) + rescoreContext = RescoreContext.builder().oversampleFactor(5.0f).userProvided(true).build(); // Provided by user + dimension = 500; + assertEquals(500, rescoreContext.getFirstPassK(finalK, true, dimension)); // User-defined oversample factor should be used + + // Case 5: Test finalK where oversampling factor results in a value less than MIN_FIRST_PASS_RESULTS finalK = 10; - oversample = 0.5f; // This will result in 5, which is less than MIN_FIRST_PASS_RESULTS - rescoreContext = RescoreContext.builder().oversampleFactor(oversample).build(); - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); + dimension = 700; + rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure dimension-based logic applies + assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling results in 30 + } + + public void testGetFirstPassKWithMinPassK() { + float oversample = 0.5f; + RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided + boolean isShardLevelRescoringEnabled = false; + + // Case 1: Test where finalK * oversample is smaller than MIN_FIRST_PASS_RESULTS + int finalK = 10; + int dimension = 700; + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); - // Case 6: Test where finalK * oversample results in exactly MIN_FIRST_PASS_RESULTS + // Case 2: Test where finalK * oversample results in exactly MIN_FIRST_PASS_RESULTS finalK = 100; oversample = 1.0f; // This will result in exactly 100 (MIN_FIRST_PASS_RESULTS) - rescoreContext = RescoreContext.builder().oversampleFactor(oversample).build(); - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); + rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); } }