Skip to content

Commit

Permalink
Added strong check on number of weights equals number of sub-queries
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Aug 28, 2023
1 parent d12f480 commit 6624246
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params,
*/
@Override
public float combine(final float[] scores) {
scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);
float combinedScore = 0.0f;
float sumOfWeights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public GeometricMeanScoreCombinationTechnique(final Map<String, Object> params,
*/
@Override
public float combine(final float[] scores) {
scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);
float weightedLnSum = 0;
float sumOfWeights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public HarmonicMeanScoreCombinationTechnique(final Map<String, Object> params, f
*/
@Override
public float combine(final float[] scores) {
scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);
float sumOfWeights = 0;
float sumOfHarmonics = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.processor.combination;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand All @@ -13,11 +14,19 @@
import java.util.Set;
import java.util.stream.Collectors;

import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang3.Range;

import com.google.common.math.DoubleMath;

/**
* Collection of utility methods for score combination technique classes
*/
@Log4j2
class ScoreCombinationUtil {
private static final String PARAM_NAME_WEIGHTS = "weights";
private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f;

/**
* Get collection of weights based on user provided config
Expand All @@ -29,9 +38,11 @@ public List<Float> getWeights(final Map<String, Object> params) {
return List.of();
}
// get weights, we don't need to check for instance as it's done during validation
return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
List<Float> weightsList = ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
.map(Double::floatValue)
.collect(Collectors.toUnmodifiableList());
validateWeights(weightsList);
return weightsList;
}

/**
Expand Down Expand Up @@ -77,4 +88,58 @@ public void validateParams(final Map<String, Object> actualParams, final Set<Str
public float getWeightForSubQuery(final List<Float> weights, final int indexOfSubQuery) {
return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f;
}

/**
* Check if number of weights matches number of queries. This deos no apply for case when
* weights were not provided. This is default setting and in such case every sub-query has equal weight
* @param scores
* @param weights
*/
protected void validateIfWeightsMatchScores(final float[] scores, final List<Float> weights) {
if (weights.isEmpty()) {
return;
}
if (scores.length != weights.size()) {
log.error(
String.format(
Locale.ROOT,
"number of weights [%d] must match number of sub-queries [%d] in hybrid query",
weights.size(),
scores.length
)
);
throw new IllegalArgumentException("number of weights must match number of sub-queries in hybrid query");
}
}

/**
* Check if provided weights are valid for combination. Following is checked:
* - every weight is between 0.0 and 1.0
* - sum of all weights must be equal 1.0
* @param weightsList
*/
private void validateWeights(final List<Float> weightsList) {
boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.between(0.0f, 1.0f).contains(weight));
if (isOutOfRange) {
log.error(
String.format(
Locale.ROOT,
"all weights must be in range [0.0 ... 1.0], submitted weights: %s",
Arrays.toString(weightsList.toArray(new Float[0]))
)
);
throw new IllegalArgumentException("all weights must be in range [0.0 ... 1.0]");
}
float sumOfWeights = weightsList.stream().reduce(0.0f, Float::sum);
if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_SCORE_ASSERTION)) {
log.error(
String.format(
Locale.ROOT,
"sum of weights for combination must be equal to 1.0, submitted weights: %s",
Arrays.toString(weightsList.toArray(new Float[0]))
)
);
throw new IllegalArgumentException("sum of weights for combination must be equal to 1.0");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import org.junit.After;
import org.junit.Before;
import org.opensearch.client.ResponseException;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;
Expand Down Expand Up @@ -96,7 +97,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.4f, 0.3f, 0.3f }))
);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
Expand All @@ -120,7 +121,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.233f, 0.666f, 0.1f }))
);

Map<String, Object> searchResponseWithWeights2AsMap = search(
Expand All @@ -140,18 +141,14 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 1.0f }))
);

Map<String, Object> searchResponseWithWeights3AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
ResponseException exception1 = expectThrows(
ResponseException.class,
() -> search(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, hybridQueryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE))
);

assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 1.0, 0.001);
assertTrue(exception1.getMessage().contains("number of weights must match number of sub-queries in hybrid query"));

// check case when number of weights is more than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
Expand All @@ -160,18 +157,14 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.25f, 0.25f, 0.2f }))
);

Map<String, Object> searchResponseWithWeights4AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
ResponseException exception2 = expectThrows(
ResponseException.class,
() -> search(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, hybridQueryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE))
);

assertWeightedScores(searchResponseWithWeights4AsMap, 1.0, 1.0, 0.001);
assertTrue(exception2.getMessage().contains("number of weights must match number of sub-queries in hybrid query"));
}

/**
Expand Down Expand Up @@ -199,7 +192,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

Expand All @@ -223,7 +216,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
Expand Down Expand Up @@ -265,7 +258,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

Expand All @@ -289,7 +282,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

Expand All @@ -115,7 +115,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder();
Expand All @@ -138,7 +138,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder();
Expand Down Expand Up @@ -180,7 +180,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

Expand All @@ -204,7 +204,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder();
Expand All @@ -227,7 +227,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.carrotsearch.randomizedtesting.RandomizedTest;

public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests {

private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();
Expand All @@ -33,9 +31,7 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {
}

public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE)
.mapToObj(i -> RandomizedTest.randomDouble())
.collect(Collectors.toList());
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList());
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
Expand All @@ -44,20 +40,18 @@ public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores()
}

public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Float> scores = List.of(1.0f, -1.0f, 0.6f);
List<Double> weights = List.of(0.9, 0.2, 0.7);
List<Float> scores = List.of(1.0f, 0.0f, 0.6f);
List<Double> weights = List.of(0.45, 0.15, 0.4);
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
);
float expectedScore = 0.825f;
float expectedScore = 0.69f;
testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore);
}

public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE)
.mapToObj(i -> RandomizedTest.randomDouble())
.collect(Collectors.toList());
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList());
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.carrotsearch.randomizedtesting.RandomizedTest;

public class GeometricMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests {

private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();
Expand All @@ -34,19 +32,17 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {

public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() {
List<Float> scores = List.of(1.0f, 0.5f, 0.3f);
List<Double> weights = List.of(0.9, 0.2, 0.7);
List<Double> weights = List.of(0.45, 0.15, 0.4);
ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
);
float expectedScore = 0.5797f;
float expectedScore = 0.5567f;
testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore);
}

public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE)
.mapToObj(i -> RandomizedTest.randomDouble())
.collect(Collectors.toList());
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList());
ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
Expand All @@ -55,20 +51,18 @@ public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores()
}

public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Float> scores = List.of(1.0f, -1.0f, 0.6f);
List<Double> weights = List.of(0.9, 0.2, 0.7);
List<Float> scores = List.of(1.0f, 0.0f, 0.6f);
List<Double> weights = List.of(0.45, 0.15, 0.4);
ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
);
float expectedScore = 0.7997f;
float expectedScore = 0.7863f;
testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore);
}

public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE)
.mapToObj(i -> RandomizedTest.randomDouble())
.collect(Collectors.toList());
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList());
ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
Expand Down
Loading

0 comments on commit 6624246

Please sign in to comment.