diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java index 4f6c6f0cf..fe0d962ca 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java @@ -76,17 +76,23 @@ public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); } + /** + * Verify score correctness by using alternative formula for geometric mean as n-th root of product of weighted scores, + * more details in here https://en.wikipedia.org/wiki/Weighted_geometric_mean + */ private float geometricMean(List scores, List weights) { - assertEquals(scores.size(), weights.size()); - float sumOfWeights = 0; - float weightedSumOfLn = 0; - for (int i = 0; i < scores.size(); i++) { - float score = scores.get(i), weight = weights.get(i).floatValue(); - if (score > 0) { - sumOfWeights += weight; - weightedSumOfLn += weight * Math.log(score); + float product = 1.0f; + float sumOfWeights = 0.0f; + for (int indexOfSubQuery = 0; indexOfSubQuery < scores.size(); indexOfSubQuery++) { + float score = scores.get(indexOfSubQuery); + if (score <= 0) { + // scores 0.0 need to be skipped, ln() of 0 is not defined + continue; } + float weight = weights.get(indexOfSubQuery).floatValue(); + product *= Math.pow(score, weight); + sumOfWeights += weight; } - return sumOfWeights == 0 ? 0f : (float) Math.exp(weightedSumOfLn / sumOfWeights); + return sumOfWeights == 0 ? 0f : (float) Math.pow(product, (float) 1 / sumOfWeights); } }