diff --git a/CHANGELOG.md b/CHANGELOG.md index 17f47690d..7c504d813 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Added Score Normalization and Combination feature ([#241](https://github.com/opensearch-project/neural-search/pull/241/)) ### Enhancements * Changed format for hybrid query results to a single list of scores with delimiter ([#259](https://github.com/opensearch-project/neural-search/pull/259)) +* Added validations for score combination weights in Hybrid Search ([#265](https://github.com/opensearch-project/neural-search/pull/265)) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index cfafeb3e5..e656beca3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -38,6 +38,7 @@ public ArithmeticMeanScoreCombinationTechnique(final Map params, */ @Override public float combine(final float[] scores) { + scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights); float combinedScore = 0.0f; float sumOfWeights = 0; for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java index 4e7a8ca9e..2a78d5ac6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java @@ -40,6 +40,7 @@ public GeometricMeanScoreCombinationTechnique(final Map params, */ @Override public float combine(final float[] scores) { + scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights); float weightedLnSum = 0; float sumOfWeights = 0; for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java index 9f913b2ef..0b45fb616 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java @@ -38,6 +38,7 @@ public HarmonicMeanScoreCombinationTechnique(final Map params, f */ @Override public float combine(final float[] scores) { + scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights); float sumOfWeights = 0; float sumOfHarmonics = 0; for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java index 35e097f7f..ed82a62ea 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.processor.combination; +import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.Map; @@ -13,11 +14,19 @@ import java.util.Set; import java.util.stream.Collectors; +import lombok.extern.log4j.Log4j2; + +import org.apache.commons.lang3.Range; + +import com.google.common.math.DoubleMath; + /** * Collection of utility methods for score combination technique classes */ +@Log4j2 class ScoreCombinationUtil { private static final String PARAM_NAME_WEIGHTS = "weights"; + private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f; /** * Get collection of weights based on user provided config @@ -29,9 +38,11 @@ public List getWeights(final Map params) { return List.of(); } // get weights, we don't need to check for instance as it's done during validation - return ((List) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream() + List weightsList = ((List) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream() .map(Double::floatValue) .collect(Collectors.toUnmodifiableList()); + validateWeights(weightsList); + return weightsList; } /** @@ -77,4 +88,55 @@ public void validateParams(final Map actualParams, final Set weights, final int indexOfSubQuery) { return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f; } + + /** + * Check if number of weights matches number of queries. This does not apply for case when + * weights were not provided, as this is valid default value + * @param scores collection of scores from all sub-queries of a single hybrid search query + * @param weights score combination weights that are defined as part of search result processor + */ + protected void validateIfWeightsMatchScores(final float[] scores, final List weights) { + if (weights.isEmpty()) { + return; + } + if (scores.length != weights.size()) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "number of weights [%d] must match number of sub-queries [%d] in hybrid query", + weights.size(), + scores.length + ) + ); + } + } + + /** + * Check if provided weights are valid for combination. Following conditions are checked: + * - every weight is between 0.0 and 1.0 + * - sum of all weights must be equal 1.0 + * @param weightsList + */ + private void validateWeights(final List weightsList) { + boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.between(0.0f, 1.0f).contains(weight)); + if (isOutOfRange) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "all weights must be in range [0.0 ... 1.0], submitted weights: %s", + Arrays.toString(weightsList.toArray(new Float[0])) + ) + ); + } + float sumOfWeights = weightsList.stream().reduce(0.0f, Float::sum); + if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_SCORE_ASSERTION)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "sum of weights for combination must be equal to 1.0, submitted weights: %s", + Arrays.toString(weightsList.toArray(new Float[0])) + ) + ); + } + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java index df15f40fd..03b77549a 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -5,6 +5,8 @@ package org.opensearch.neuralsearch.processor; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults; import static org.opensearch.neuralsearch.TestUtils.assertWeightedScores; import static org.opensearch.neuralsearch.TestUtils.createRandomVector; @@ -18,6 +20,7 @@ import org.junit.After; import org.junit.Before; +import org.opensearch.client.ResponseException; import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.index.SpaceType; import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; @@ -96,7 +99,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.4f, 0.3f, 0.3f })) ); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -112,7 +115,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertWeightedScores(searchResponseWithWeights1AsMap, 0.375, 0.3125, 0.001); + assertWeightedScores(searchResponseWithWeights1AsMap, 0.4, 0.3, 0.001); // delete existing pipeline and create a new one with another set of weights deleteSearchPipeline(SEARCH_PIPELINE); @@ -120,7 +123,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.233f, 0.666f, 0.1f })) ); Map searchResponseWithWeights2AsMap = search( @@ -131,7 +134,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertWeightedScores(searchResponseWithWeights2AsMap, 0.606, 0.242, 0.001); + assertWeightedScores(searchResponseWithWeights2AsMap, 0.6666, 0.2332, 0.001); // check case when number of weights is less than number of sub-queries // delete existing pipeline and create a new one with another set of weights @@ -140,18 +143,21 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 1.0f })) ); - Map searchResponseWithWeights3AsMap = search( - TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) + ResponseException exception1 = expectThrows( + ResponseException.class, + () -> search(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, hybridQueryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE)) + ); + org.hamcrest.MatcherAssert.assertThat( + exception1.getMessage(), + allOf( + containsString("number of weights"), + containsString("must match number of sub-queries"), + containsString("in hybrid query") + ) ); - - assertWeightedScores(searchResponseWithWeights3AsMap, 0.357, 0.285, 0.001); // check case when number of weights is more than number of sub-queries // delete existing pipeline and create a new one with another set of weights @@ -160,18 +166,21 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.25f, 0.25f, 0.2f })) ); - Map searchResponseWithWeights4AsMap = search( - TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) + ResponseException exception2 = expectThrows( + ResponseException.class, + () -> search(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, hybridQueryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE)) + ); + org.hamcrest.MatcherAssert.assertThat( + exception2.getMessage(), + allOf( + containsString("number of weights"), + containsString("must match number of sub-queries"), + containsString("in hybrid query") + ) ); - - assertWeightedScores(searchResponseWithWeights4AsMap, 0.375, 0.3125, 0.001); } /** @@ -199,7 +208,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); String modelId = getDeployedModelId(); @@ -223,7 +232,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); @@ -265,7 +274,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, GEOMETRIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); String modelId = getDeployedModelId(); @@ -289,7 +298,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, GEOMETRIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java index 36af1a712..6f98e8d5e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -91,7 +91,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); String modelId = getDeployedModelId(); @@ -115,7 +115,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); @@ -138,7 +138,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, GEOMETRIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); @@ -180,7 +180,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); String modelId = getDeployedModelId(); @@ -204,7 +204,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); @@ -227,7 +227,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, GEOMETRIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java index 842df736d..125930007 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java @@ -12,8 +12,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import com.carrotsearch.randomizedtesting.RandomizedTest; - public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); @@ -33,9 +31,7 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { } public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() { - List weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil @@ -44,20 +40,18 @@ public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List scores = List.of(1.0f, -1.0f, 0.6f); - List weights = List.of(0.9, 0.2, 0.7); + List scores = List.of(1.0f, 0.0f, 0.6f); + List weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float expectedScore = 0.825f; + float expectedScore = 0.69f; testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java index fe0d962ca..3f70c229f 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java @@ -12,8 +12,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import com.carrotsearch.randomizedtesting.RandomizedTest; - public class GeometricMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); @@ -34,19 +32,17 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { List scores = List.of(1.0f, 0.5f, 0.3f); - List weights = List.of(0.9, 0.2, 0.7); + List weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float expectedScore = 0.5797f; + float expectedScore = 0.5567f; testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() { - List weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil @@ -55,20 +51,18 @@ public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List scores = List.of(1.0f, -1.0f, 0.6f); - List weights = List.of(0.9, 0.2, 0.7); + List scores = List.of(1.0f, 0.0f, 0.6f); + List weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float expectedScore = 0.7997f; + float expectedScore = 0.7863f; testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java index 8187123a1..7b1b07f64 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java @@ -12,8 +12,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import com.carrotsearch.randomizedtesting.RandomizedTest; - public class HarmonicMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); @@ -34,30 +32,28 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { List scores = List.of(1.0f, 0.5f, 0.3f); - List weights = List.of(0.9, 0.2, 0.7); + List weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float expecteScore = 0.4954f; - testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expecteScore); + float expectedScore = 0.48f; + testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List scores = List.of(1.0f, -1.0f, 0.6f); - List weights = List.of(0.9, 0.2, 0.7); + List scores = List.of(1.0f, 0.0f, 0.6f); + List weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float expectedScore = 0.7741f; + float expectedScore = 0.7611f; testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java new file mode 100644 index 000000000..ca13cb2eb --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; + +import java.util.List; +import java.util.Map; + +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +public class ScoreCombinationUtilTests extends OpenSearchQueryTestCase { + + public void testCombinationWeights_whenEmptyInputPassed_thenCreateEmptyWeightCollection() { + ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + List weights = scoreCombinationUtil.getWeights(Map.of()); + assertNotNull(weights); + assertTrue(weights.isEmpty()); + } + + public void testCombinationWeights_whenWeightsArePassed_thenSuccessful() { + ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + List weights = scoreCombinationUtil.getWeights(Map.of("weights", List.of(0.4, 0.6))); + assertNotNull(weights); + assertEquals(2, weights.size()); + assertTrue(weights.containsAll(List.of(0.4f, 0.6f))); + } + + public void testCombinationWeights_whenInvalidWeightsArePassed_thenFail() { + ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + + IllegalArgumentException exception1 = expectThrows( + IllegalArgumentException.class, + () -> scoreCombinationUtil.getWeights(Map.of("weights", List.of(2.4))) + ); + assertTrue(exception1.getMessage().contains("all weights must be in range")); + + IllegalArgumentException exception2 = expectThrows( + IllegalArgumentException.class, + () -> scoreCombinationUtil.getWeights(Map.of("weights", List.of(0.4, 0.5, 0.6))) + ); + assertTrue(exception2.getMessage().contains("sum of weights for combination must be equal to 1.0")); + } + + public void testWeightsValidation_whenNumberOfScoresDifferentFromNumberOfWeights_thenFail() { + ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + IllegalArgumentException exception1 = expectThrows( + IllegalArgumentException.class, + () -> scoreCombinationUtil.validateIfWeightsMatchScores(new float[] { 0.6f, 0.5f }, List.of(0.4f, 0.2f, 0.4f)) + ); + org.hamcrest.MatcherAssert.assertThat( + exception1.getMessage(), + allOf( + containsString("number of weights"), + containsString("must match number of sub-queries"), + containsString("in hybrid query") + ) + ); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java index 83bb0e7bb..a1ddefe16 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java @@ -132,17 +132,14 @@ public void testNormalizationProcessor_whenWithCombinationParams_thenSuccessful( String tag = "tag"; String description = "description"; boolean ignoreFailure = false; + double weight1 = RandomizedTest.randomDouble(); + double weight2 = 1.0f - weight1; Map config = new HashMap<>(); config.put(NORMALIZATION_CLAUSE, new HashMap<>(Map.of("technique", "min_max"))); config.put( COMBINATION_CLAUSE, new HashMap<>( - Map.of( - TECHNIQUE, - "arithmetic_mean", - PARAMETERS, - new HashMap<>(Map.of("weights", Arrays.asList(RandomizedTest.randomDouble(), RandomizedTest.randomDouble()))) - ) + Map.of(TECHNIQUE, "arithmetic_mean", PARAMETERS, new HashMap<>(Map.of("weights", Arrays.asList(weight1, weight2)))) ) ); Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); @@ -160,6 +157,43 @@ public void testNormalizationProcessor_whenWithCombinationParams_thenSuccessful( assertEquals("normalization-processor", normalizationProcessor.getType()); } + @SneakyThrows + public void testWeightsParams_whenInvalidValues_thenFail() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put(NORMALIZATION_CLAUSE, new HashMap<>(Map.of("technique", "min_max"))); + config.put( + COMBINATION_CLAUSE, + new HashMap<>( + Map.of( + TECHNIQUE, + "arithmetic_mean", + PARAMETERS, + new HashMap<>( + Map.of( + "weights", + Arrays.asList(RandomizedTest.randomDouble(), RandomizedTest.randomDouble(), RandomizedTest.randomDouble()) + ) + ) + ) + ) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> normalizationProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue(exception.getMessage().contains("sum of weights for combination must be equal to 1.0")); + } + public void testInputValidation_whenInvalidNormalizationClause_thenFail() { NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()),