From 6624246b95c24bfb9c777c665816d450c461abf6 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 28 Aug 2023 14:53:24 -0700 Subject: [PATCH] Added strong check on number of weights equals number of sub-queries Signed-off-by: Martin Gaievski --- ...ithmeticMeanScoreCombinationTechnique.java | 1 + ...eometricMeanScoreCombinationTechnique.java | 1 + ...HarmonicMeanScoreCombinationTechnique.java | 1 + .../combination/ScoreCombinationUtil.java | 67 ++++++++++++++++++- .../processor/ScoreCombinationIT.java | 41 +++++------- .../processor/ScoreNormalizationIT.java | 12 ++-- ...ticMeanScoreCombinationTechniqueTests.java | 16 ++--- ...ricMeanScoreCombinationTechniqueTests.java | 20 ++---- ...nicMeanScoreCombinationTechniqueTests.java | 18 ++--- .../NormalizationProcessorFactoryTests.java | 46 +++++++++++-- 10 files changed, 151 insertions(+), 72 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index cfafeb3e5..e656beca3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -38,6 +38,7 @@ public ArithmeticMeanScoreCombinationTechnique(final Map params, */ @Override public float combine(final float[] scores) { + scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights); float combinedScore = 0.0f; float sumOfWeights = 0; for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java index 4e7a8ca9e..2a78d5ac6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java @@ -40,6 +40,7 @@ public GeometricMeanScoreCombinationTechnique(final Map params, */ @Override public float combine(final float[] scores) { + scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights); float weightedLnSum = 0; float sumOfWeights = 0; for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java index 9f913b2ef..0b45fb616 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java @@ -38,6 +38,7 @@ public HarmonicMeanScoreCombinationTechnique(final Map params, f */ @Override public float combine(final float[] scores) { + scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights); float sumOfWeights = 0; float sumOfHarmonics = 0; for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java index 35e097f7f..319c1fcbd 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.processor.combination; +import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.Map; @@ -13,11 +14,19 @@ import java.util.Set; import java.util.stream.Collectors; +import lombok.extern.log4j.Log4j2; + +import org.apache.commons.lang3.Range; + +import com.google.common.math.DoubleMath; + /** * Collection of utility methods for score combination technique classes */ +@Log4j2 class ScoreCombinationUtil { private static final String PARAM_NAME_WEIGHTS = "weights"; + private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f; /** * Get collection of weights based on user provided config @@ -29,9 +38,11 @@ public List getWeights(final Map params) { return List.of(); } // get weights, we don't need to check for instance as it's done during validation - return ((List) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream() + List weightsList = ((List) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream() .map(Double::floatValue) .collect(Collectors.toUnmodifiableList()); + validateWeights(weightsList); + return weightsList; } /** @@ -77,4 +88,58 @@ public void validateParams(final Map actualParams, final Set weights, final int indexOfSubQuery) { return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f; } + + /** + * Check if number of weights matches number of queries. This deos no apply for case when + * weights were not provided. This is default setting and in such case every sub-query has equal weight + * @param scores + * @param weights + */ + protected void validateIfWeightsMatchScores(final float[] scores, final List weights) { + if (weights.isEmpty()) { + return; + } + if (scores.length != weights.size()) { + log.error( + String.format( + Locale.ROOT, + "number of weights [%d] must match number of sub-queries [%d] in hybrid query", + weights.size(), + scores.length + ) + ); + throw new IllegalArgumentException("number of weights must match number of sub-queries in hybrid query"); + } + } + + /** + * Check if provided weights are valid for combination. Following is checked: + * - every weight is between 0.0 and 1.0 + * - sum of all weights must be equal 1.0 + * @param weightsList + */ + private void validateWeights(final List weightsList) { + boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.between(0.0f, 1.0f).contains(weight)); + if (isOutOfRange) { + log.error( + String.format( + Locale.ROOT, + "all weights must be in range [0.0 ... 1.0], submitted weights: %s", + Arrays.toString(weightsList.toArray(new Float[0])) + ) + ); + throw new IllegalArgumentException("all weights must be in range [0.0 ... 1.0]"); + } + float sumOfWeights = weightsList.stream().reduce(0.0f, Float::sum); + if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_SCORE_ASSERTION)) { + log.error( + String.format( + Locale.ROOT, + "sum of weights for combination must be equal to 1.0, submitted weights: %s", + Arrays.toString(weightsList.toArray(new Float[0])) + ) + ); + throw new IllegalArgumentException("sum of weights for combination must be equal to 1.0"); + } + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java index e56532b52..473d80b7e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -18,6 +18,7 @@ import org.junit.After; import org.junit.Before; +import org.opensearch.client.ResponseException; import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.index.SpaceType; import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; @@ -96,7 +97,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.4f, 0.3f, 0.3f })) ); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -120,7 +121,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.233f, 0.666f, 0.1f })) ); Map searchResponseWithWeights2AsMap = search( @@ -140,18 +141,14 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 1.0f })) ); - Map searchResponseWithWeights3AsMap = search( - TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) + ResponseException exception1 = expectThrows( + ResponseException.class, + () -> search(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, hybridQueryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE)) ); - - assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 1.0, 0.001); + assertTrue(exception1.getMessage().contains("number of weights must match number of sub-queries in hybrid query")); // check case when number of weights is more than number of sub-queries // delete existing pipeline and create a new one with another set of weights @@ -160,18 +157,14 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.25f, 0.25f, 0.2f })) ); - Map searchResponseWithWeights4AsMap = search( - TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) + ResponseException exception2 = expectThrows( + ResponseException.class, + () -> search(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, hybridQueryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE)) ); - - assertWeightedScores(searchResponseWithWeights4AsMap, 1.0, 1.0, 0.001); + assertTrue(exception2.getMessage().contains("number of weights must match number of sub-queries in hybrid query")); } /** @@ -199,7 +192,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); String modelId = getDeployedModelId(); @@ -223,7 +216,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); @@ -265,7 +258,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, GEOMETRIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); String modelId = getDeployedModelId(); @@ -289,7 +282,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, GEOMETRIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java index 64b3fe07f..7b05b86ee 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -91,7 +91,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); String modelId = getDeployedModelId(); @@ -115,7 +115,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); @@ -138,7 +138,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, GEOMETRIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); @@ -180,7 +180,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); String modelId = getDeployedModelId(); @@ -204,7 +204,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); @@ -227,7 +227,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, GEOMETRIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java index 842df736d..125930007 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java @@ -12,8 +12,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import com.carrotsearch.randomizedtesting.RandomizedTest; - public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); @@ -33,9 +31,7 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { } public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() { - List weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil @@ -44,20 +40,18 @@ public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List scores = List.of(1.0f, -1.0f, 0.6f); - List weights = List.of(0.9, 0.2, 0.7); + List scores = List.of(1.0f, 0.0f, 0.6f); + List weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float expectedScore = 0.825f; + float expectedScore = 0.69f; testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java index fe0d962ca..3f70c229f 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java @@ -12,8 +12,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import com.carrotsearch.randomizedtesting.RandomizedTest; - public class GeometricMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); @@ -34,19 +32,17 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { List scores = List.of(1.0f, 0.5f, 0.3f); - List weights = List.of(0.9, 0.2, 0.7); + List weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float expectedScore = 0.5797f; + float expectedScore = 0.5567f; testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() { - List weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil @@ -55,20 +51,18 @@ public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List scores = List.of(1.0f, -1.0f, 0.6f); - List weights = List.of(0.9, 0.2, 0.7); + List scores = List.of(1.0f, 0.0f, 0.6f); + List weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float expectedScore = 0.7997f; + float expectedScore = 0.7863f; testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java index 8187123a1..7b1b07f64 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java @@ -12,8 +12,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import com.carrotsearch.randomizedtesting.RandomizedTest; - public class HarmonicMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); @@ -34,30 +32,28 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { List scores = List.of(1.0f, 0.5f, 0.3f); - List weights = List.of(0.9, 0.2, 0.7); + List weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float expecteScore = 0.4954f; - testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expecteScore); + float expectedScore = 0.48f; + testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List scores = List.of(1.0f, -1.0f, 0.6f); - List weights = List.of(0.9, 0.2, 0.7); + List scores = List.of(1.0f, 0.0f, 0.6f); + List weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float expectedScore = 0.7741f; + float expectedScore = 0.7611f; testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java index 83bb0e7bb..a1ddefe16 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java @@ -132,17 +132,14 @@ public void testNormalizationProcessor_whenWithCombinationParams_thenSuccessful( String tag = "tag"; String description = "description"; boolean ignoreFailure = false; + double weight1 = RandomizedTest.randomDouble(); + double weight2 = 1.0f - weight1; Map config = new HashMap<>(); config.put(NORMALIZATION_CLAUSE, new HashMap<>(Map.of("technique", "min_max"))); config.put( COMBINATION_CLAUSE, new HashMap<>( - Map.of( - TECHNIQUE, - "arithmetic_mean", - PARAMETERS, - new HashMap<>(Map.of("weights", Arrays.asList(RandomizedTest.randomDouble(), RandomizedTest.randomDouble()))) - ) + Map.of(TECHNIQUE, "arithmetic_mean", PARAMETERS, new HashMap<>(Map.of("weights", Arrays.asList(weight1, weight2)))) ) ); Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); @@ -160,6 +157,43 @@ public void testNormalizationProcessor_whenWithCombinationParams_thenSuccessful( assertEquals("normalization-processor", normalizationProcessor.getType()); } + @SneakyThrows + public void testWeightsParams_whenInvalidValues_thenFail() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put(NORMALIZATION_CLAUSE, new HashMap<>(Map.of("technique", "min_max"))); + config.put( + COMBINATION_CLAUSE, + new HashMap<>( + Map.of( + TECHNIQUE, + "arithmetic_mean", + PARAMETERS, + new HashMap<>( + Map.of( + "weights", + Arrays.asList(RandomizedTest.randomDouble(), RandomizedTest.randomDouble(), RandomizedTest.randomDouble()) + ) + ) + ) + ) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> normalizationProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue(exception.getMessage().contains("sum of weights for combination must be equal to 1.0")); + } + public void testInputValidation_whenInvalidNormalizationClause_thenFail() { NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()),