Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added validations for score combination weights in Hybrid Search #265

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Added Score Normalization and Combination feature ([#241](https://github.com/opensearch-project/neural-search/pull/241/))
### Enhancements
* Changed format for hybrid query results to a single list of scores with delimiter ([#259](https://github.com/opensearch-project/neural-search/pull/259))
* Added validations for score combination weights in Hybrid Search ([#265](https://github.com/opensearch-project/neural-search/pull/265))
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params,
*/
@Override
public float combine(final float[] scores) {
scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);
float combinedScore = 0.0f;
float sumOfWeights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public GeometricMeanScoreCombinationTechnique(final Map<String, Object> params,
*/
@Override
public float combine(final float[] scores) {
scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);
float weightedLnSum = 0;
float sumOfWeights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public HarmonicMeanScoreCombinationTechnique(final Map<String, Object> params, f
*/
@Override
public float combine(final float[] scores) {
scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);
float sumOfWeights = 0;
float sumOfHarmonics = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.processor.combination;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand All @@ -13,11 +14,19 @@
import java.util.Set;
import java.util.stream.Collectors;

import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang3.Range;

import com.google.common.math.DoubleMath;

/**
* Collection of utility methods for score combination technique classes
*/
@Log4j2
class ScoreCombinationUtil {
private static final String PARAM_NAME_WEIGHTS = "weights";
private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f;

/**
* Get collection of weights based on user provided config
Expand All @@ -29,9 +38,11 @@ public List<Float> getWeights(final Map<String, Object> params) {
return List.of();
}
// get weights, we don't need to check for instance as it's done during validation
return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
List<Float> weightsList = ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
.map(Double::floatValue)
.collect(Collectors.toUnmodifiableList());
validateWeights(weightsList);
return weightsList;
}

/**
Expand Down Expand Up @@ -77,4 +88,55 @@ public void validateParams(final Map<String, Object> actualParams, final Set<Str
public float getWeightForSubQuery(final List<Float> weights, final int indexOfSubQuery) {
return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f;
}

/**
* Check if number of weights matches number of queries. This does not apply for case when
* weights were not provided, as this is valid default value
* @param scores collection of scores from all sub-queries of a single hybrid search query
* @param weights score combination weights that are defined as part of search result processor
*/
protected void validateIfWeightsMatchScores(final float[] scores, final List<Float> weights) {
if (weights.isEmpty()) {
return;
}
if (scores.length != weights.size()) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"number of weights [%d] must match number of sub-queries [%d] in hybrid query",
weights.size(),
scores.length
)
);
}
}

/**
* Check if provided weights are valid for combination. Following conditions are checked:
* - every weight is between 0.0 and 1.0
* - sum of all weights must be equal 1.0
* @param weightsList
*/
private void validateWeights(final List<Float> weightsList) {
boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.between(0.0f, 1.0f).contains(weight));
if (isOutOfRange) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"all weights must be in range [0.0 ... 1.0], submitted weights: %s",
Arrays.toString(weightsList.toArray(new Float[0]))
)
);
}
float sumOfWeights = weightsList.stream().reduce(0.0f, Float::sum);
if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_SCORE_ASSERTION)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"sum of weights for combination must be equal to 1.0, submitted weights: %s",
Arrays.toString(weightsList.toArray(new Float[0]))
)
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.neuralsearch.processor;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults;
import static org.opensearch.neuralsearch.TestUtils.assertWeightedScores;
import static org.opensearch.neuralsearch.TestUtils.createRandomVector;
Expand All @@ -18,6 +20,7 @@

import org.junit.After;
import org.junit.Before;
import org.opensearch.client.ResponseException;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;
Expand Down Expand Up @@ -96,7 +99,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.4f, 0.3f, 0.3f }))
);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
Expand All @@ -112,15 +115,15 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights1AsMap, 0.375, 0.3125, 0.001);
assertWeightedScores(searchResponseWithWeights1AsMap, 0.4, 0.3, 0.001);

// delete existing pipeline and create a new one with another set of weights
deleteSearchPipeline(SEARCH_PIPELINE);
createSearchPipeline(
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.233f, 0.666f, 0.1f }))
);

Map<String, Object> searchResponseWithWeights2AsMap = search(
Expand All @@ -131,7 +134,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights2AsMap, 0.606, 0.242, 0.001);
assertWeightedScores(searchResponseWithWeights2AsMap, 0.6666, 0.2332, 0.001);

// check case when number of weights is less than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
Expand All @@ -140,18 +143,21 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 1.0f }))
);

Map<String, Object> searchResponseWithWeights3AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
ResponseException exception1 = expectThrows(
ResponseException.class,
() -> search(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, hybridQueryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE))
);
org.hamcrest.MatcherAssert.assertThat(
exception1.getMessage(),
allOf(
containsString("number of weights"),
containsString("must match number of sub-queries"),
containsString("in hybrid query")
)
);

assertWeightedScores(searchResponseWithWeights3AsMap, 0.357, 0.285, 0.001);

// check case when number of weights is more than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
Expand All @@ -160,18 +166,21 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.25f, 0.25f, 0.2f }))
);

Map<String, Object> searchResponseWithWeights4AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
ResponseException exception2 = expectThrows(
ResponseException.class,
() -> search(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, hybridQueryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE))
);
org.hamcrest.MatcherAssert.assertThat(
exception2.getMessage(),
allOf(
containsString("number of weights"),
containsString("must match number of sub-queries"),
containsString("in hybrid query")
)
);

assertWeightedScores(searchResponseWithWeights4AsMap, 0.375, 0.3125, 0.001);
}

/**
Expand Down Expand Up @@ -199,7 +208,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

Expand All @@ -223,7 +232,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
Expand Down Expand Up @@ -265,7 +274,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

Expand All @@ -289,7 +298,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

Expand All @@ -115,7 +115,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder();
Expand All @@ -138,7 +138,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder();
Expand Down Expand Up @@ -180,7 +180,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

Expand All @@ -204,7 +204,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder();
Expand All @@ -227,7 +227,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.carrotsearch.randomizedtesting.RandomizedTest;

public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests {

private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();
Expand All @@ -33,9 +31,7 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {
}

public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE)
.mapToObj(i -> RandomizedTest.randomDouble())
.collect(Collectors.toList());
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList());
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
Expand All @@ -44,20 +40,18 @@ public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores()
}

public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Float> scores = List.of(1.0f, -1.0f, 0.6f);
List<Double> weights = List.of(0.9, 0.2, 0.7);
List<Float> scores = List.of(1.0f, 0.0f, 0.6f);
List<Double> weights = List.of(0.45, 0.15, 0.4);
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
);
float expectedScore = 0.825f;
float expectedScore = 0.69f;
testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore);
}

public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE)
.mapToObj(i -> RandomizedTest.randomDouble())
.collect(Collectors.toList());
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList());
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
Expand Down
Loading