Commit
Added validations for score combination weights in Hybrid Search (#265) (#268)

* Added a strict check that the number of weights equals the number of sub-queries

Signed-off-by: Martin Gaievski <[email protected]>
(cherry picked from commit 685d5d6)

Co-authored-by: Martin Gaievski <[email protected]>
opensearch-trigger-bot[bot] and martin-gaievski authored Aug 30, 2023

1 parent e564d4d commit 40479e9
Showing 12 changed files with 231 additions and 74 deletions.
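Before the per-file diffs, here is a minimal standalone sketch of the three constraints this change enforces on the combination weights: one weight per sub-query, each weight in [0.0, 1.0], and the weights summing to 1.0 within a small tolerance. The class and method names below are illustrative only; in the actual change the range and sum checks run when the weights are parsed in ScoreCombinationUtil.getWeights, while the count check runs in each technique's combine method.

import java.util.List;
import java.util.Locale;

// Illustrative sketch only; the real logic is in ScoreCombinationUtil further down.
class WeightValidationSketch {
    private static final float DELTA = 0.01f; // mirrors DELTA_FOR_SCORE_ASSERTION

    static void validate(final float[] scores, final List<Float> weights) {
        if (weights.isEmpty()) {
            return; // no weights provided is a valid default
        }
        // 1. one weight per sub-query
        if (scores.length != weights.size()) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "number of weights [%d] must match number of sub-queries [%d] in hybrid query",
                    weights.size(),
                    scores.length
                )
            );
        }
        // 2. every weight must be in [0.0, 1.0]
        if (weights.stream().anyMatch(weight -> weight < 0.0f || weight > 1.0f)) {
            throw new IllegalArgumentException("all weights must be in range [0.0 ... 1.0]");
        }
        // 3. weights must sum to 1.0, within a small tolerance
        float sumOfWeights = weights.stream().reduce(0.0f, Float::sum);
        if (Math.abs(1.0f - sumOfWeights) > DELTA) {
            throw new IllegalArgumentException("sum of weights for combination must be equal to 1.0");
        }
    }
}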
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Added Score Normalization and Combination feature ([#241](https://github.com/opensearch-project/neural-search/pull/241/))
### Enhancements
* Changed format for hybrid query results to a single list of scores with delimiter ([#259](https://github.com/opensearch-project/neural-search/pull/259))
* Added validations for score combination weights in Hybrid Search ([#265](https://github.com/opensearch-project/neural-search/pull/265))
### Bug Fixes
### Infrastructure
### Documentation
ArithmeticMeanScoreCombinationTechnique.java
@@ -38,6 +38,7 @@ public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params,
*/
@Override
public float combine(final float[] scores) {
scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);
float combinedScore = 0.0f;
float sumOfWeights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
GeometricMeanScoreCombinationTechnique.java
@@ -40,6 +40,7 @@ public GeometricMeanScoreCombinationTechnique(final Map<String, Object> params,
*/
@Override
public float combine(final float[] scores) {
scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);
float weightedLnSum = 0;
float sumOfWeights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
HarmonicMeanScoreCombinationTechnique.java
@@ -38,6 +38,7 @@ public HarmonicMeanScoreCombinationTechnique(final Map<String, Object> params, f
*/
@Override
public float combine(final float[] scores) {
scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);
float sumOfWeights = 0;
float sumOfHarmonics = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
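The three technique classes above each call the new validateIfWeightsMatchScores before combining normalized sub-query scores as weighted means. As a reference, here is a standalone sketch of the standard weighted arithmetic, geometric, and harmonic means. The handling of missing or zero sub-query scores (negative scores skipped for the arithmetic mean, non-positive scores skipped for the geometric and harmonic means) is inferred from the unit test expectations further down and is an assumption, not a copy of the plugin code; the sketch also assumes one weight per score, which is exactly what the new validation guarantees when weights are supplied.

import java.util.List;

// Standalone sketch of the weighted means used by the three techniques above.
final class WeightedMeansSketch {

    static float arithmeticMean(float[] scores, List<Float> weights) {
        float weightedSum = 0.0f, sumOfWeights = 0.0f;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] < 0.0f) continue; // assumed: a negative score means "no result from this sub-query"
            weightedSum += weights.get(i) * scores[i];
            sumOfWeights += weights.get(i);
        }
        return sumOfWeights == 0.0f ? 0.0f : weightedSum / sumOfWeights;
    }

    static float geometricMean(float[] scores, List<Float> weights) {
        float weightedLnSum = 0.0f, sumOfWeights = 0.0f;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] <= 0.0f) continue; // assumed: skip sub-queries without a positive score
            weightedLnSum += weights.get(i) * (float) Math.log(scores[i]);
            sumOfWeights += weights.get(i);
        }
        return sumOfWeights == 0.0f ? 0.0f : (float) Math.exp(weightedLnSum / sumOfWeights);
    }

    static float harmonicMean(float[] scores, List<Float> weights) {
        float sumOfWeights = 0.0f, sumOfHarmonics = 0.0f;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] <= 0.0f) continue; // assumed: skip sub-queries without a positive score
            sumOfWeights += weights.get(i);
            sumOfHarmonics += weights.get(i) / scores[i];
        }
        return sumOfHarmonics == 0.0f ? 0.0f : sumOfWeights / sumOfHarmonics;
    }
}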
ScoreCombinationUtil.java
@@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.processor.combination;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@@ -13,11 +14,19 @@
import java.util.Set;
import java.util.stream.Collectors;

import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang3.Range;

import com.google.common.math.DoubleMath;

/**
* Collection of utility methods for score combination technique classes
*/
@Log4j2
class ScoreCombinationUtil {
private static final String PARAM_NAME_WEIGHTS = "weights";
private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f;

/**
* Get collection of weights based on user provided config
@@ -29,9 +38,11 @@ public List<Float> getWeights(final Map<String, Object> params) {
return List.of();
}
// get weights, we don't need to check for instance as it's done during validation
return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
List<Float> weightsList = ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
.map(Double::floatValue)
.collect(Collectors.toUnmodifiableList());
validateWeights(weightsList);
return weightsList;
}

/**
@@ -77,4 +88,55 @@ public void validateParams(final Map<String, Object> actualParams, final Set<Str
public float getWeightForSubQuery(final List<Float> weights, final int indexOfSubQuery) {
return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f;
}

/**
* Check if the number of weights matches the number of sub-queries. This does not apply when
* weights were not provided, as that is a valid default.
* @param scores collection of scores from all sub-queries of a single hybrid search query
* @param weights score combination weights that are defined as part of the search results processor
*/
protected void validateIfWeightsMatchScores(final float[] scores, final List<Float> weights) {
if (weights.isEmpty()) {
return;
}
if (scores.length != weights.size()) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"number of weights [%d] must match number of sub-queries [%d] in hybrid query",
weights.size(),
scores.length
)
);
}
}

/**
* Check if provided weights are valid for combination. The following conditions are checked:
* - every weight is between 0.0 and 1.0
* - the sum of all weights must equal 1.0
* @param weightsList collection of weights to validate
*/
private void validateWeights(final List<Float> weightsList) {
boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.between(0.0f, 1.0f).contains(weight));
if (isOutOfRange) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"all weights must be in range [0.0 ... 1.0], submitted weights: %s",
Arrays.toString(weightsList.toArray(new Float[0]))
)
);
}
float sumOfWeights = weightsList.stream().reduce(0.0f, Float::sum);
if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_SCORE_ASSERTION)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"sum of weights for combination must be equal to 1.0, submitted weights: %s",
Arrays.toString(weightsList.toArray(new Float[0]))
)
);
}
}
}
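Note that the sum check uses DoubleMath.fuzzyEquals with DELTA_FOR_SCORE_ASSERTION = 0.01f, so the weights only need to sum to 1.0 within that tolerance. A small illustration with hypothetical weight sums (not taken from the tests):

import com.google.common.math.DoubleMath;

// Illustrative only: how the 0.01 tolerance plays out for a few hypothetical weight sums.
public class FuzzySumExample {
    public static void main(String[] args) {
        float delta = 0.01f;
        System.out.println(DoubleMath.fuzzyEquals(1.0f, 0.5f + 0.495f, delta));          // true  (sum 0.995)
        System.out.println(DoubleMath.fuzzyEquals(1.0f, 0.5f + 0.3f, delta));            // false (sum 0.8)
        System.out.println(DoubleMath.fuzzyEquals(1.0f, 0.4f + 0.3f + 0.305f, delta));   // true  (sum 1.005)
    }
}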
@@ -5,6 +5,8 @@

package org.opensearch.neuralsearch.processor;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults;
import static org.opensearch.neuralsearch.TestUtils.assertWeightedScores;
import static org.opensearch.neuralsearch.TestUtils.createRandomVector;
@@ -18,6 +20,7 @@

import org.junit.After;
import org.junit.Before;
import org.opensearch.client.ResponseException;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;
@@ -96,7 +99,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.4f, 0.3f, 0.3f }))
);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
@@ -112,15 +115,15 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights1AsMap, 0.375, 0.3125, 0.001);
assertWeightedScores(searchResponseWithWeights1AsMap, 0.4, 0.3, 0.001);

// delete existing pipeline and create a new one with another set of weights
deleteSearchPipeline(SEARCH_PIPELINE);
createSearchPipeline(
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.233f, 0.666f, 0.1f }))
);

Map<String, Object> searchResponseWithWeights2AsMap = search(
@@ -131,7 +134,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights2AsMap, 0.606, 0.242, 0.001);
assertWeightedScores(searchResponseWithWeights2AsMap, 0.6666, 0.2332, 0.001);

// check case when number of weights is less than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
@@ -140,18 +143,21 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 1.0f }))
);

Map<String, Object> searchResponseWithWeights3AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
ResponseException exception1 = expectThrows(
ResponseException.class,
() -> search(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, hybridQueryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE))
);
org.hamcrest.MatcherAssert.assertThat(
exception1.getMessage(),
allOf(
containsString("number of weights"),
containsString("must match number of sub-queries"),
containsString("in hybrid query")
)
);

assertWeightedScores(searchResponseWithWeights3AsMap, 0.357, 0.285, 0.001);

// check case when number of weights is more than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
@@ -160,18 +166,21 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.25f, 0.25f, 0.2f }))
);

Map<String, Object> searchResponseWithWeights4AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
ResponseException exception2 = expectThrows(
ResponseException.class,
() -> search(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, hybridQueryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE))
);
org.hamcrest.MatcherAssert.assertThat(
exception2.getMessage(),
allOf(
containsString("number of weights"),
containsString("must match number of sub-queries"),
containsString("in hybrid query")
)
);

assertWeightedScores(searchResponseWithWeights4AsMap, 0.375, 0.3125, 0.001);
}

/**
@@ -199,7 +208,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

@@ -223,7 +232,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
@@ -265,7 +274,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

@@ -289,7 +298,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
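Given the message format introduced in ScoreCombinationUtil, and assuming the hybrid query built in testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful has three sub-queries (as the three-element weight arrays used earlier in that method suggest), the two ResponseException messages asserted there would contain, for example:

number of weights [1] must match number of sub-queries [3] in hybrid query
number of weights [4] must match number of sub-queries [3] in hybrid query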
@@ -91,7 +91,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

@@ -115,7 +115,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder();
@@ -138,7 +138,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder();
@@ -180,7 +180,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

@@ -204,7 +204,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder();
@@ -227,7 +227,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder();
ArithmeticMeanScoreCombinationTechniqueTests.java
@@ -12,8 +12,6 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.carrotsearch.randomizedtesting.RandomizedTest;

public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests {

private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();
@@ -33,9 +31,7 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {
}

public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE)
.mapToObj(i -> RandomizedTest.randomDouble())
.collect(Collectors.toList());
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList());
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
@@ -44,20 +40,18 @@ public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores()
}

public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Float> scores = List.of(1.0f, -1.0f, 0.6f);
List<Double> weights = List.of(0.9, 0.2, 0.7);
List<Float> scores = List.of(1.0f, 0.0f, 0.6f);
List<Double> weights = List.of(0.45, 0.15, 0.4);
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
);
float expectedScore = 0.825f;
float expectedScore = 0.69f;
testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore);
}

public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE)
.mapToObj(i -> RandomizedTest.randomDouble())
.collect(Collectors.toList());
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList());
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
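A quick check of the new expected value in testLogic_whenNotAllScoresAndWeightsPresent above, assuming a weighted arithmetic mean that keeps the 0.0 score in the sum:

(0.45 * 1.0 + 0.15 * 0.0 + 0.4 * 0.6) / (0.45 + 0.15 + 0.4) = (0.45 + 0.0 + 0.24) / 1.0 = 0.69

which matches expectedScore = 0.69f.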
GeometricMeanScoreCombinationTechniqueTests.java
@@ -12,8 +12,6 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.carrotsearch.randomizedtesting.RandomizedTest;

public class GeometricMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests {

private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();
@@ -34,19 +32,17 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {

public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() {
List<Float> scores = List.of(1.0f, 0.5f, 0.3f);
List<Double> weights = List.of(0.9, 0.2, 0.7);
List<Double> weights = List.of(0.45, 0.15, 0.4);
ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
);
float expectedScore = 0.5797f;
float expectedScore = 0.5567f;
testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore);
}

public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE)
.mapToObj(i -> RandomizedTest.randomDouble())
.collect(Collectors.toList());
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList());
ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
@@ -55,20 +51,18 @@ public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores()
}

public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Float> scores = List.of(1.0f, -1.0f, 0.6f);
List<Double> weights = List.of(0.9, 0.2, 0.7);
List<Float> scores = List.of(1.0f, 0.0f, 0.6f);
List<Double> weights = List.of(0.45, 0.15, 0.4);
ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
);
float expectedScore = 0.7997f;
float expectedScore = 0.7863f;
testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore);
}

public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE)
.mapToObj(i -> RandomizedTest.randomDouble())
.collect(Collectors.toList());
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList());
ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
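Checking the two new geometric-mean expectations above, assuming the weighted geometric mean exp(sum(w_i * ln(s_i)) / sum(w_i)) with the 0.0 score skipped (an inference from the expected values, not a statement about the plugin code):

exp(0.45 * ln(1.0) + 0.15 * ln(0.5) + 0.4 * ln(0.3)) = exp(-0.5856) ≈ 0.5568, matching expectedScore = 0.5567f within the test delta
exp((0.45 * ln(1.0) + 0.4 * ln(0.6)) / (0.45 + 0.4)) = exp(-0.2404) ≈ 0.7863, matching expectedScore = 0.7863f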
HarmonicMeanScoreCombinationTechniqueTests.java
@@ -12,8 +12,6 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.carrotsearch.randomizedtesting.RandomizedTest;

public class HarmonicMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests {

private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();
@@ -34,30 +32,28 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {

public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() {
List<Float> scores = List.of(1.0f, 0.5f, 0.3f);
List<Double> weights = List.of(0.9, 0.2, 0.7);
List<Double> weights = List.of(0.45, 0.15, 0.4);
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
);
float expecteScore = 0.4954f;
testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expecteScore);
float expectedScore = 0.48f;
testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore);
}

public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Float> scores = List.of(1.0f, -1.0f, 0.6f);
List<Double> weights = List.of(0.9, 0.2, 0.7);
List<Float> scores = List.of(1.0f, 0.0f, 0.6f);
List<Double> weights = List.of(0.45, 0.15, 0.4);
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
);
float expectedScore = 0.7741f;
float expectedScore = 0.7611f;
testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore);
}

public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE)
.mapToObj(i -> RandomizedTest.randomDouble())
.collect(Collectors.toList());
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList());
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
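Similarly for the new harmonic-mean expectations above, assuming the weighted harmonic mean sum(w_i) / sum(w_i / s_i) with the 0.0 score skipped:

1.0 / (0.45 / 1.0 + 0.15 / 0.5 + 0.4 / 0.3) = 1.0 / 2.0833 = 0.48, matching expectedScore = 0.48f
(0.45 + 0.4) / (0.45 / 1.0 + 0.4 / 0.6) = 0.85 / 1.1167 ≈ 0.7612, matching expectedScore = 0.7611f within the test delta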
ScoreCombinationUtilTests.java
@@ -0,0 +1,64 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.combination;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;

import java.util.List;
import java.util.Map;

import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;

public class ScoreCombinationUtilTests extends OpenSearchQueryTestCase {

public void testCombinationWeights_whenEmptyInputPassed_thenCreateEmptyWeightCollection() {
ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();
List<Float> weights = scoreCombinationUtil.getWeights(Map.of());
assertNotNull(weights);
assertTrue(weights.isEmpty());
}

public void testCombinationWeights_whenWeightsArePassed_thenSuccessful() {
ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();
List<Float> weights = scoreCombinationUtil.getWeights(Map.of("weights", List.of(0.4, 0.6)));
assertNotNull(weights);
assertEquals(2, weights.size());
assertTrue(weights.containsAll(List.of(0.4f, 0.6f)));
}

public void testCombinationWeights_whenInvalidWeightsArePassed_thenFail() {
ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();

IllegalArgumentException exception1 = expectThrows(
IllegalArgumentException.class,
() -> scoreCombinationUtil.getWeights(Map.of("weights", List.of(2.4)))
);
assertTrue(exception1.getMessage().contains("all weights must be in range"));

IllegalArgumentException exception2 = expectThrows(
IllegalArgumentException.class,
() -> scoreCombinationUtil.getWeights(Map.of("weights", List.of(0.4, 0.5, 0.6)))
);
assertTrue(exception2.getMessage().contains("sum of weights for combination must be equal to 1.0"));
}

public void testWeightsValidation_whenNumberOfScoresDifferentFromNumberOfWeights_thenFail() {
ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();
IllegalArgumentException exception1 = expectThrows(
IllegalArgumentException.class,
() -> scoreCombinationUtil.validateIfWeightsMatchScores(new float[] { 0.6f, 0.5f }, List.of(0.4f, 0.2f, 0.4f))
);
org.hamcrest.MatcherAssert.assertThat(
exception1.getMessage(),
allOf(
containsString("number of weights"),
containsString("must match number of sub-queries"),
containsString("in hybrid query")
)
);
}
}
@@ -132,17 +132,14 @@ public void testNormalizationProcessor_whenWithCombinationParams_thenSuccessful(
String tag = "tag";
String description = "description";
boolean ignoreFailure = false;
double weight1 = RandomizedTest.randomDouble();
double weight2 = 1.0f - weight1;
Map<String, Object> config = new HashMap<>();
config.put(NORMALIZATION_CLAUSE, new HashMap<>(Map.of("technique", "min_max")));
config.put(
COMBINATION_CLAUSE,
new HashMap<>(
Map.of(
TECHNIQUE,
"arithmetic_mean",
PARAMETERS,
new HashMap<>(Map.of("weights", Arrays.asList(RandomizedTest.randomDouble(), RandomizedTest.randomDouble())))
)
Map.of(TECHNIQUE, "arithmetic_mean", PARAMETERS, new HashMap<>(Map.of("weights", Arrays.asList(weight1, weight2))))
)
);
Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class);
@@ -160,6 +157,43 @@ public void testNormalizationProcessor_whenWithCombinationParams_thenSuccessful(
assertEquals("normalization-processor", normalizationProcessor.getType());
}

@SneakyThrows
public void testWeightsParams_whenInvalidValues_thenFail() {
NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()),
new ScoreNormalizationFactory(),
new ScoreCombinationFactory()
);
final Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>();
String tag = "tag";
String description = "description";
boolean ignoreFailure = false;
Map<String, Object> config = new HashMap<>();
config.put(NORMALIZATION_CLAUSE, new HashMap<>(Map.of("technique", "min_max")));
config.put(
COMBINATION_CLAUSE,
new HashMap<>(
Map.of(
TECHNIQUE,
"arithmetic_mean",
PARAMETERS,
new HashMap<>(
Map.of(
"weights",
Arrays.asList(RandomizedTest.randomDouble(), RandomizedTest.randomDouble(), RandomizedTest.randomDouble())
)
)
)
)
);
Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class);
IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> normalizationProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext)
);
assertTrue(exception.getMessage().contains("sum of weights for combination must be equal to 1.0"));
}

public void testInputValidation_whenInvalidNormalizationClause_thenFail() {
NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()),
