From 0effd0741818683ba8ad36f158729ac246b0a639 Mon Sep 17 00:00:00 2001
From: Samuel Herman
Date: Sat, 21 Oct 2023 11:35:39 -0700
Subject: [PATCH] review feedback

Signed-off-by: Samuel Herman
---
 .../NormalizationProcessorWorkflow.java       |   1 -
 .../ZScoreNormalizationTechnique.java         |  62 +++----
 .../common/BaseNeuralSearchIT.java            |  15 ++
 .../HybridQueryZScoreIT.java                  |  70 ++++----
 .../processor/NormalizationProcessorIT.java   |  16 --
 .../ZScoreNormalizationTechniqueTests.java    | 164 +++++++++---------
 .../neuralsearch/query/HybridQueryIT.java     |  16 --
 7 files changed, 162 insertions(+), 182 deletions(-)
 rename src/test/java/org/opensearch/neuralsearch/{query => processor}/HybridQueryZScoreIT.java (83%)

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
index d5e898185..9e0069b21 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -52,7 +52,6 @@ public void execute(
         final ScoreNormalizationTechnique normalizationTechnique,
         final ScoreCombinationTechnique combinationTechnique
     ) {
-        log.info("Entering normalization processor workflow");
         // save original state
         List<String> unprocessedDocIds = unprocessedDocIds(querySearchResults);
 
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechnique.java
index bc2ae9a7b..fc97e8a4b 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechnique.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechnique.java
@@ -5,15 +5,18 @@
 package org.opensearch.neuralsearch.processor.normalization;
 
-import com.google.common.primitives.Floats;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+
 import lombok.ToString;
+
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.opensearch.neuralsearch.processor.CompoundTopDocs;
 
-import java.util.Arrays;
-import java.util.List;
-import java.util.Objects;
+import com.google.common.primitives.Floats;
 
 /**
  * Implementation of z-score normalization technique for hybrid query
@@ -24,24 +27,26 @@ TODO: Some todo items that apply here but also on the original normalization techniques on which it is modeled
 {@link L2ScoreNormalizationTechnique} and {@link MinMaxScoreNormalizationTechnique}
 1. Random access to abstract list object is a bad practice both stylistically and from performance perspective and should be removed
 2. Identical sub queries and their distribution between shards is currently completely implicit based on ordering and should be explicit based on identifier
-3. Weird calculation of numOfSubQueries instead of having a more explicit indicator
+3. Implicit calculation of numOfSubQueries instead of having a more explicit upstream indicator/metadata regarding it
 */
 @ToString(onlyExplicitlyIncluded = true)
 public class ZScoreNormalizationTechnique implements ScoreNormalizationTechnique {
     @ToString.Include
     public static final String TECHNIQUE_NAME = "z_score";
     private static final float SINGLE_RESULT_SCORE = 1.0f;
+
     @Override
-    public void normalize(List<CompoundTopDocs> queryTopDocs) {
-        // why are we doing that? is List the list of subqueries for a single shard, or a global list of all subqueries across shards?
-        // If a subquery comes from each shard then when is it combined? that seems weird that combination will do combination of normalized results that each is normalized just based on shard level result
-        int numOfSubQueries = queryTopDocs.stream()
-            .filter(Objects::nonNull)
-            .filter(topDocs -> topDocs.getTopDocs().size() > 0)
-            .findAny()
-            .get()
-            .getTopDocs()
-            .size();
+    public void normalize(final List<CompoundTopDocs> queryTopDocs) {
+        /*
+        TODO: There is an implicit assumption in this calculation that probably needs to be made clearer by passing some metadata with the results.
+        Currently we assume that any single non-empty shard result will contain results for all sub queries, including those with 0 hits.
+        */
+        final Optional<CompoundTopDocs> maybeCompoundTopDocs = queryTopDocs.stream()
+            .filter(Objects::nonNull)
+            .filter(topDocs -> topDocs.getTopDocs().size() > 0)
+            .findAny();
+
+        final int numOfSubQueries = maybeCompoundTopDocs.map(compoundTopDocs -> compoundTopDocs.getTopDocs().size()).orElse(0);
 
         // to be done for each subquery
         float[] sumPerSubquery = findScoreSumPerSubQuery(queryTopDocs, numOfSubQueries);
@@ -67,9 +72,7 @@ public void normalize(List<CompoundTopDocs> queryTopDocs) {
     static private float[] findScoreSumPerSubQuery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
         final float[] sumOfScorePerSubQuery = new float[numOfScores];
         Arrays.fill(sumOfScorePerSubQuery, 0);
-        //TODO: make this better, currently
-        // this is a horrible implementation in particular when it comes to the topDocsPerSubQuery.get(j)
-        // which does a random search on an abstract list type.
+        // TODO: make this clearer and avoid the performance cost of List.get(j) on an abstract List type
         for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
             if (Objects.isNull(compoundQueryTopDocs)) {
                 continue;
@@ -86,9 +89,7 @@ static private long[] findNumberOfElementsPerSubQuery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
         final long[] numberOfElementsPerSubQuery = new long[numOfScores];
         Arrays.fill(numberOfElementsPerSubQuery, 0);
-        //TODO: make this better, currently
-        // this is a horrible implementation in particular when it comes to the topDocsPerSubQuery.get(j)
-        // which does a random search on an abstract list type.
+        // TODO: make this clearer and avoid the performance cost of List.get(j) on an abstract List type
         for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
             if (Objects.isNull(compoundQueryTopDocs)) {
                 continue;
@@ -108,21 +109,22 @@ static private float[] findMeanPerSubquery(final float[] sumPerSubquery, final l
             if (elementsPerSubquery[i] == 0) {
                 meanPerSubQuery[i] = 0;
             } else {
-                meanPerSubQuery[i] = sumPerSubquery[i]/elementsPerSubquery[i];
+                meanPerSubQuery[i] = sumPerSubquery[i] / elementsPerSubquery[i];
             }
         }
         return meanPerSubQuery;
     }
 
-    static private float[] findStdPerSubquery(final List<CompoundTopDocs> queryTopDocs, final float[] meanPerSubQuery, final long[] elementsPerSubquery, final int numOfScores) {
+    static private float[] findStdPerSubquery(
+        final List<CompoundTopDocs> queryTopDocs,
+        final float[] meanPerSubQuery,
+        final long[] elementsPerSubquery,
+        final int numOfScores
+    ) {
         final double[] deltaSumPerSubquery = new double[numOfScores];
         Arrays.fill(deltaSumPerSubquery, 0);
-
-
-        //TODO: make this better, currently
-        // this is a horrible implementation in particular when it comes to the topDocsPerSubQuery.get(j)
-        // which does a random search on an abstract list type.
+        // TODO: make this clearer and avoid the performance cost of List.get(j) on an abstract List type
         for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
             if (Objects.isNull(compoundQueryTopDocs)) {
                 continue;
@@ -147,7 +149,7 @@ static private float[] findStdPerSubquery(final List<CompoundTopDocs> queryTopDo
         return stdPerSubQuery;
     }
 
-    static private float sumScoreDocsArray(ScoreDoc[] scoreDocs) {
+    static private float sumScoreDocsArray(final ScoreDoc[] scoreDocs) {
         float sum = 0;
         for (ScoreDoc scoreDoc : scoreDocs) {
             sum += scoreDoc.score;
@@ -161,6 +163,6 @@ private static float normalizeSingleScore(final float score, final float standar
         if (Floats.compare(mean, score) == 0) {
             return SINGLE_RESULT_SCORE;
         }
-        return (score - mean) / standardDeviation;
+        return (score - mean) / standardDeviation;
     }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
index 9c24e81fd..e6265724f 100644
--- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
@@ -760,4 +760,19 @@ private String registerModelGroup() {
         assertNotNull(modelGroupId);
         return modelGroupId;
     }
+
+    protected List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
+        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
+        return (List<Map<String, Object>>) hitsMap.get("hits");
+    }
+
+    protected Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
+        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
+        return (Map<String, Object>) hitsMap.get("total");
+    }
+
+    protected Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
+        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
+        return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
+    }
 }
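For illustration, a minimal sketch (not part of the patch) of how the promoted helpers might be called from a test extending BaseNeuralSearchIT; index and queryBuilder are hypothetical placeholders, mirroring the search(...) call used in HybridQueryZScoreIT below:

    Map<String, Object> searchResponseAsMap = search(index, queryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE));
    List<Map<String, Object>> hits = getNestedHits(searchResponseAsMap);   // the individual hit maps under hits.hits
    Map<String, Object> totalHits = getTotalHits(searchResponseAsMap);     // the hits.total object (value/relation)
    Optional<Float> maxScore = getMaxScore(searchResponseAsMap);           // empty when hits.max_score is null

Promoting these to the shared base class lets the three duplicated private copies be deleted from HybridQueryZScoreIT, NormalizationProcessorIT, and HybridQueryIT below.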
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryZScoreIT.java b/src/test/java/org/opensearch/neuralsearch/processor/HybridQueryZScoreIT.java
similarity index 83%
rename from src/test/java/org/opensearch/neuralsearch/query/HybridQueryZScoreIT.java
rename to src/test/java/org/opensearch/neuralsearch/processor/HybridQueryZScoreIT.java
index d197fc710..23db97fe2 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryZScoreIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/HybridQueryZScoreIT.java
@@ -3,10 +3,17 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-package org.opensearch.neuralsearch.query;
+package org.opensearch.neuralsearch.processor;
+
+import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION;
+import static org.opensearch.neuralsearch.TestUtils.createRandomVector;
+
+import java.io.IOException;
+import java.util.*;
+import java.util.stream.IntStream;
 
-import com.google.common.primitives.Floats;
 import lombok.SneakyThrows;
+
 import org.junit.After;
 import org.junit.Before;
 import org.opensearch.index.query.BoolQueryBuilder;
@@ -15,13 +22,10 @@
 import org.opensearch.knn.index.SpaceType;
 import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;
 import org.opensearch.neuralsearch.processor.normalization.ZScoreNormalizationTechnique;
+import org.opensearch.neuralsearch.query.HybridQueryBuilder;
+import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
 
-import java.io.IOException;
-import java.util.*;
-import java.util.stream.IntStream;
-
-import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION;
-import static org.opensearch.neuralsearch.TestUtils.createRandomVector;
+import com.google.common.primitives.Floats;
 
 public class HybridQueryZScoreIT extends BaseNeuralSearchIT {
     private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index";
@@ -46,7 +50,12 @@ public void setUp() throws Exception {
         super.setUp();
         updateClusterSettings();
         prepareModel();
-        createSearchPipeline(SEARCH_PIPELINE, ZScoreNormalizationTechnique.TECHNIQUE_NAME, DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, "[0.5,0.5]"));
+        createSearchPipeline(
+            SEARCH_PIPELINE,
+            ZScoreNormalizationTechnique.TECHNIQUE_NAME,
+            DEFAULT_COMBINATION_METHOD,
+            Map.of(PARAM_NAME_WEIGHTS, "[0.5,0.5]")
+        );
     }
 
     @After
@@ -114,25 +123,24 @@ public void testComplexQuery_withZScoreNormalization() {
         String modelId = getDeployedModelId();
 
         NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(
-                TEST_KNN_VECTOR_FIELD_NAME_1,
-                TEST_QUERY_TEXT,
-                modelId,
-                5,
-                null,
-                null
+            TEST_KNN_VECTOR_FIELD_NAME_1,
+            TEST_QUERY_TEXT,
+            modelId,
+            5,
+            null,
+            null
         );
 
         HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder();
         hybridQueryBuilderNeuralThenTerm.add(neuralQueryBuilder);
         hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder);
-
         final Map<String, Object> searchResponseAsMap = search(
-                TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
-                hybridQueryBuilderNeuralThenTerm,
-                null,
-                5,
-                Map.of("search_pipeline", SEARCH_PIPELINE)
+            TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
+            hybridQueryBuilderNeuralThenTerm,
+            null,
+            5,
+            Map.of("search_pipeline", SEARCH_PIPELINE)
         );
 
         assertEquals(2, getHitCount(searchResponseAsMap));
@@ -146,10 +154,11 @@ public void testComplexQuery_withZScoreNormalization() {
         }
 
         assertEquals(2, scores.size());
-        // by design when there are only two results with z score since it's z-score normalized we would expect 1 , -1 to be the corresponding score,
+        // by design, when there are only two results, z-score normalization yields 1 and -1 as the corresponding scores;
         // furthermore the combination logic with weights should make it doc1Score: (1 * w1 + 0.98 * w2)/(w1 + w2), doc2Score: -1 ~ 0
         assertEquals(0.9999, scores.get(0).floatValue(), DELTA_FOR_SCORE_ASSERTION);
-        assertEquals(0 , scores.get(1).floatValue(), DELTA_FOR_SCORE_ASSERTION);
+        assertEquals(0, scores.get(1).floatValue(), DELTA_FOR_SCORE_ASSERTION);
 
         // verify that scores are in desc order
         assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));
@@ -193,19 +202,4 @@ private void initializeIndexIfNotExist() throws IOException {
             assertEquals(2, getDocCount(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME));
         }
     }
-
-    private List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return (List<Map<String, Object>>) hitsMap.get("hits");
-    }
-
-    private Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return (Map<String, Object>) hitsMap.get("total");
-    }
-
-    private Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
-    }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java
index 3cd71e5a1..79db226e1 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java
@@ -12,7 +12,6 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.stream.IntStream;
 
@@ -341,21 +340,6 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
         }
     }
 
-    private List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return (List<Map<String, Object>>) hitsMap.get("hits");
-    }
-
-    private Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return (Map<String, Object>) hitsMap.get("total");
-    }
-
-    private Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
-    }
-
     private void assertQueryResults(Map<String, Object> searchResponseAsMap, int totalExpectedDocQty, boolean assertMinScore) {
         assertQueryResults(searchResponseAsMap, totalExpectedDocQty, assertMinScore, Range.between(0.5f, 1.0f));
     }
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechniqueTests.java
index 1d0c61373..45e350dbb 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechniqueTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechniqueTests.java
@@ -5,83 +5,88 @@
 package org.opensearch.neuralsearch.processor.normalization;
 
+import java.util.List;
+
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHits;
 import org.opensearch.neuralsearch.processor.CompoundTopDocs;
 import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
 
-import java.util.List;
-
 public class ZScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase {
     private static final float DELTA_FOR_ASSERTION = 0.0001f;
 
+    /**
+     * Z-score measures the relative distance from the center of the distribution and hence can also be negative.
+     * When only two values are available, their z-scores will be 1 and -1 respectively.
+     * For more information regarding z-score see
+     * https://www.z-table.com/
+     */
     public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() {
         ZScoreNormalizationTechnique normalizationTechnique = new ZScoreNormalizationTechnique();
         List<CompoundTopDocs> compoundTopDocs = List.of(
-                new CompoundTopDocs(
-                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
-                        List.of(
-                                new TopDocs(
-                                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
-                                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
-                                )
-                        )
-                )
+            new CompoundTopDocs(
+                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
+                    )
+                )
+            )
         );
         normalizationTechnique.normalize(compoundTopDocs);
 
+        // since we only have two scores of 0.5 and 0.2 their z-score numbers will be 1 and -1
         CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs(
-                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
-                List.of(
-                        new TopDocs(
-                                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
-                                new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) }
-                        )
-                )
+            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) })
+            )
         );
         assertNotNull(compoundTopDocs);
         assertEquals(1, compoundTopDocs.size());
         assertNotNull(compoundTopDocs.get(0).getTopDocs());
         assertCompoundTopDocs(
-                new TopDocs(expectedCompoundDocs.getTotalHits(), expectedCompoundDocs.getScoreDocs().toArray(new ScoreDoc[0])),
-                compoundTopDocs.get(0).getTopDocs().get(0)
+            new TopDocs(expectedCompoundDocs.getTotalHits(), expectedCompoundDocs.getScoreDocs().toArray(new ScoreDoc[0])),
+            compoundTopDocs.get(0).getTopDocs().get(0)
         );
     }
 
     public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() {
         ZScoreNormalizationTechnique normalizationTechnique = new ZScoreNormalizationTechnique();
         List<CompoundTopDocs> compoundTopDocs = List.of(
-                new CompoundTopDocs(
-                        new TotalHits(3, TotalHits.Relation.EQUAL_TO),
-                        List.of(
-                                new TopDocs(
-                                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
-                                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
-                                ),
-                                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
-                                new TopDocs(
-                                        new TotalHits(3, TotalHits.Relation.EQUAL_TO),
-                                        new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) }
-                                )
-                        )
-                )
+            new CompoundTopDocs(
+                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
+                    ),
+                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                    new TopDocs(
+                        new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) }
+                    )
+                )
+            )
         );
         normalizationTechnique.normalize(compoundTopDocs);
 
         CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs(
-                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
-                List.of(
-                        new TopDocs(
-                                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
-                                new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) }
-                        ),
-                        new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
-                        new TopDocs(
-                                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
-                                new ScoreDoc[] { new ScoreDoc(3, 0.98058068f), new ScoreDoc(4, 0.39223227f), new ScoreDoc(2, -1.37281295f) }
-                        )
-                )
+            new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(
+                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) }
+                ),
+                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                new TopDocs(
+                    new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] { new ScoreDoc(3, 0.98058068f), new ScoreDoc(4, 0.39223227f), new ScoreDoc(2, -1.37281295f) }
+                )
+            )
         );
         assertNotNull(compoundTopDocs);
         assertEquals(1, compoundTopDocs.size());
@@ -94,57 +99,54 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce
     public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_thenSuccessful() {
         ZScoreNormalizationTechnique normalizationTechnique = new ZScoreNormalizationTechnique();
         List<CompoundTopDocs> compoundTopDocs = List.of(
-                new CompoundTopDocs(
-                        new TotalHits(3, TotalHits.Relation.EQUAL_TO),
-                        List.of(
-                                new TopDocs(
-                                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
-                                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
-                                ),
-                                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
-                                new TopDocs(
-                                        new TotalHits(3, TotalHits.Relation.EQUAL_TO),
-                                        new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) }
-                                )
-                        )
-                ),
-                new CompoundTopDocs(
-                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
-                        List.of(
-                                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
-                                new TopDocs(
-                                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
-                                        new ScoreDoc[] { new ScoreDoc(7, 2.9f), new ScoreDoc(9, 0.7f) }
-                                )
-                        )
-                )
+            new CompoundTopDocs(
+                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
+                    ),
+                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                    new TopDocs(
+                        new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) }
+                    )
+                )
+            ),
+            new CompoundTopDocs(
+                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(7, 2.9f), new ScoreDoc(9, 0.7f) }
+                    )
+                )
+            )
         );
         normalizationTechnique.normalize(compoundTopDocs);
 
         CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs(
-                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
-                List.of(
-                        new TopDocs(
-                                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
-                                new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) }
-                        ),
-                        new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
-                        new TopDocs(
-                                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
-                                new ScoreDoc[] { new ScoreDoc(3, 0.98058068f), new ScoreDoc(4, 0.39223227f), new ScoreDoc(2, -1.37281295f) }
-                        )
-                )
+            new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(
+                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) }
+                ),
+                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                new TopDocs(
+                    new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] { new ScoreDoc(3, 0.98058068f), new ScoreDoc(4, 0.39223227f), new ScoreDoc(2, -1.37281295f) }
+                )
+            )
         );
 
         CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs(
-                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
-                List.of(
-                        new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
-                        new TopDocs(
-                                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
-                                new ScoreDoc[] { new ScoreDoc(7, 1.0f), new ScoreDoc(9, -1.0f) }
-                        )
-                )
+            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                new TopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(7, 1.0f), new ScoreDoc(9, -1.0f) })
+            )
         );
 
         assertNotNull(compoundTopDocs);
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
index eec6955ff..229374730 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
@@ -13,7 +13,6 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.stream.IntStream;
 
@@ -267,19 +266,4 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
             assertEquals(3, getDocCount(TEST_MULTI_DOC_INDEX_NAME));
         }
     }
-
-    private List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return (List<Map<String, Object>>) hitsMap.get("hits");
-    }
-
-    private Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return (Map<String, Object>) hitsMap.get("total");
-    }
-
-    private Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
-    }
 }
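
For reference, a minimal standalone sketch (not part of the patch; the class name is arbitrary) that reproduces the expected values asserted in the tests above, assuming the technique uses the population standard deviation (divide by n), as findStdPerSubquery does:

    import java.util.Arrays;

    public class ZScoreNormalizationCheck {
        public static void main(String[] args) {
            // sub-query scores from the multi-sub-query test case above
            float[] scores = { 0.9f, 0.7f, 0.1f };
            double mean = 0;
            for (float score : scores) {
                mean += score;
            }
            mean /= scores.length;
            double deltaSquaredSum = 0;
            for (float score : scores) {
                deltaSquaredSum += Math.pow(score - mean, 2);
            }
            // population standard deviation (divide by n, not n - 1)
            double std = Math.sqrt(deltaSquaredSum / scores.length);
            double[] normalized = new double[scores.length];
            for (int i = 0; i < scores.length; i++) {
                normalized[i] = (scores[i] - mean) / std;
            }
            // prints approximately [0.98058, 0.39223, -1.37281], matching the test expectations;
            // with exactly two input scores this formula always yields 1 and -1
            System.out.println(Arrays.toString(normalized));
        }
    }

The same arithmetic explains the two-document HybridQueryZScoreIT assertions: the two sub-query results normalize to 1 and -1 per shard, and the weighted combination then maps them to roughly 1 and 0.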