Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add geometric mean normalization for scores #239

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.combination;

import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* Abstracts combination of scores based on geometrical mean method
*/
public class GeometricMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

public static final String TECHNIQUE_NAME = "geometric_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
private static final Float ZERO_SCORE = 0.0f;
private final List<Float> weights;
private final ScoreCombinationUtil scoreCombinationUtil;

public GeometricMeanScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {
scoreCombinationUtil = combinationUtil;
scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS);
weights = scoreCombinationUtil.getWeights(params);
}

/**
* Weighted geometric mean method for combining scores.
*
* We use formula below to calculate mean. It's based on fact that logarithm of geometric mean is the
* weighted arithmetic mean of the logarithms of individual scores.
*
* geometric_mean = exp(sum(weight_1*ln(score_1) + .... + weight_n*ln(score_n))/sum(weight_1 + ... + weight_n))
*/
@Override
public float combine(final float[] scores) {
float weightedLnSum = 0;
float sumOfWeights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
float score = scores[indexOfSubQuery];
if (score <= 0) {
// scores 0.0 need to be skipped, ln() of 0 is not defined
continue;
}
float weight = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery);
sumOfWeights += weight;
weightedLnSum += weight * Math.log(score);
}
return sumOfWeights == 0 ? ZERO_SCORE : (float) Math.exp(weightedLnSum / sumOfWeights);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ public class ScoreCombinationFactory {
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new ArithmeticMeanScoreCombinationTechnique(params, scoreCombinationUtil),
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil)
params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil),
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil)
);

/**
Expand Down
98 changes: 97 additions & 1 deletion src/test/java/org/opensearch/neuralsearch/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,27 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.opensearch.test.OpenSearchTestCase.randomFloat;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;

import org.apache.commons.lang3.Range;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.search.query.QuerySearchResult;

public class TestUtils {

private final static String RELATION_EQUAL_TO = "eq";

/**
* Convert an xContentBuilder to a map
* @param xContentBuilder to produce map from
Expand Down Expand Up @@ -58,7 +66,7 @@ public static float[] createRandomVector(int dimension) {
}

/**
* Assert results of hyrdir query after score normalization and combination
* Assert results of hyrdid query after score normalization and combination
* @param querySearchResults collection of query search results after they processed by normalization processor
*/
public static void assertQueryResultScores(List<QuerySearchResult> querySearchResults) {
Expand Down Expand Up @@ -94,4 +102,92 @@ public static void assertQueryResultScores(List<QuerySearchResult> querySearchRe
.orElse(Float.MAX_VALUE);
assertEquals(0.001f, minScoreScoreFromScoreDocs, 0.0f);
}

/**
* Assert results of hybrid query after score normalization and combination
* @param searchResponseWithWeightsAsMap collection of query search results after they processed by normalization processor
* @param expectedMaxScore expected maximum score
* @param expectedMaxMinusOneScore second highest expected score
* @param expectedMinScore expected minimal score
*/
public static void assertWeightedScores(
Map<String, Object> searchResponseWithWeightsAsMap,
double expectedMaxScore,
double expectedMaxMinusOneScore,
double expectedMinScore
) {
assertNotNull(searchResponseWithWeightsAsMap);
Map<String, Object> totalWeights = getTotalHits(searchResponseWithWeightsAsMap);
assertNotNull(totalWeights.get("value"));
assertEquals(4, totalWeights.get("value"));
assertNotNull(totalWeights.get("relation"));
assertEquals(RELATION_EQUAL_TO, totalWeights.get("relation"));
assertTrue(getMaxScore(searchResponseWithWeightsAsMap).isPresent());
assertEquals(expectedMaxScore, getMaxScore(searchResponseWithWeightsAsMap).get(), 0.001f);

List<Double> scoresWeights = new ArrayList<>();
for (Map<String, Object> oneHit : getNestedHits(searchResponseWithWeightsAsMap)) {
scoresWeights.add((Double) oneHit.get("_score"));
}
// verify scores order
assertTrue(IntStream.range(0, scoresWeights.size() - 1).noneMatch(idx -> scoresWeights.get(idx) < scoresWeights.get(idx + 1)));
// verify the scores are normalized with inclusion of weights
assertEquals(expectedMaxScore, scoresWeights.get(0), 0.001);
assertEquals(expectedMaxMinusOneScore, scoresWeights.get(1), 0.001);
assertEquals(expectedMinScore, scoresWeights.get(scoresWeights.size() - 1), 0.001);
}

/**
* Assert results of hybrid query after score normalization and combination
* @param searchResponseAsMap collection of query search results after they processed by normalization processor
* @param totalExpectedDocQty expected total document quantity
* @param minMaxScoreRange range of scores from min to max inclusive
*/
public static void assertHybridSearchResults(
Map<String, Object> searchResponseAsMap,
int totalExpectedDocQty,
float[] minMaxScoreRange
) {
assertNotNull(searchResponseAsMap);
Map<String, Object> total = getTotalHits(searchResponseAsMap);
assertNotNull(total.get("value"));
assertEquals(totalExpectedDocQty, total.get("value"));
assertNotNull(total.get("relation"));
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
assertTrue(getMaxScore(searchResponseAsMap).isPresent());
assertTrue(Range.between(minMaxScoreRange[0], minMaxScoreRange[1]).contains(getMaxScore(searchResponseAsMap).get()));

List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
List<String> ids = new ArrayList<>();
List<Double> scores = new ArrayList<>();
for (Map<String, Object> oneHit : hitsNestedList) {
ids.add((String) oneHit.get("_id"));
scores.add((Double) oneHit.get("_score"));
}
// verify scores order
assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));
// verify the scores are normalized. for l2 scores max score will not be 1.0 so we're checking on a range
assertTrue(
Range.between(minMaxScoreRange[0], minMaxScoreRange[1])
.contains(scores.stream().map(Double::floatValue).max(Double::compare).get())
);

// verify that all ids are unique
assertEquals(Set.copyOf(ids).size(), ids.size());
}

private static List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (List<Map<String, Object>>) hitsMap.get("hits");
}

private static Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (Map<String, Object>) hitsMap.get("total");
}

private static Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -623,4 +625,57 @@ protected void deleteSearchPipeline(final String pipelineId) {
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
}

/**
* Find all modesl that are currently deployed in the cluster
* @return set of model ids
*/
@SneakyThrows
protected Set<String> findDeployedModels() {

StringBuilder stringBuilderForContentBody = new StringBuilder();
stringBuilderForContentBody.append("{")
.append("\"query\": { \"match_all\": {} },")
.append(" \"_source\": {")
.append(" \"includes\": [\"model_id\"],")
.append(" \"excludes\": [\"content\", \"model_content\"]")
.append("}}");

Response response = makeRequest(
client(),
"POST",
"/_plugins/_ml/models/_search",
null,
toHttpEntity(stringBuilderForContentBody.toString()),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);

String responseBody = EntityUtils.toString(response.getEntity());

Map<String, Object> models = XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false);
Set<String> modelIds = new HashSet<>();
if (Objects.isNull(models) || models.isEmpty()) {
return modelIds;
}

Map<String, Object> hits = (Map<String, Object>) models.get("hits");
List<Map<String, Object>> innerHitsMap = (List<Map<String, Object>>) hits.get("hits");
return innerHitsMap.stream()
.map(hit -> (Map<String, Object>) hit.get("_source"))
.filter(hitsMap -> !Objects.isNull(hitsMap) && hitsMap.containsKey("model_id"))
.map(hitsMap -> (String) hitsMap.get("model_id"))
.collect(Collectors.toSet());
}

/**
* Get the id for model currently deployed in the cluster. If there are no models deployed or it's more than 1 model
* fail on assertion
* @return id of deployed model
*/
protected String getDeployedModelId() {
Set<String> modelIds = findDeployedModels();
assertEquals(1, modelIds.size());
return modelIds.iterator().next();
}

}
Loading