From 71dc4651d5447b029aff98da8bd54c05c010609e Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Wed, 27 Dec 2023 10:44:14 -0800
Subject: [PATCH 1/6] Allow multiple identical sub-queries in hybrid query,
 removed validation for total hits

Signed-off-by: Martin Gaievski
---
 CHANGELOG.md                                  |  1 +
 .../neuralsearch/query/HybridQueryScorer.java | 29 ++++++---
 .../search/HitsThresholdChecker.java          |  3 -
 .../query/HybridQueryWeightTests.java         | 61 ++++++++++++++++++-
 .../search/HitsTresholdCheckerTests.java      | 10 ++-
 .../HybridTopScoreDocCollectorTests.java      | 51 ++++++++++++++++
 6 files changed, 141 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3f02272c4..6522d54f2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 ### Features
 ### Enhancements
 ### Bug Fixes
+- Multiple identical subqueries in Hybrid query ([#524](https://github.com/opensearch-project/neural-search/pull/524))
 ### Infrastructure
 ### Documentation
 ### Maintenance
diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
index 109a50d05..ca2f06fc3 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
@@ -6,6 +6,7 @@
 package org.opensearch.neuralsearch.query;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -37,7 +38,7 @@ public final class HybridQueryScorer extends Scorer {
 
     private final float[] subScores;
 
-    private final Map<Query, Integer> queryToIndex;
+    private final Map<Query, List<Integer>> queryToIndex;
 
     public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOException {
         super(weight);
@@ -111,24 +112,34 @@ public float[] hybridScores() throws IOException {
         DisiWrapper topList = subScorersPQ.topList();
         for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
             // check if this doc has match in the subQuery. If not, add score as 0.0 and continue
-            if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
+            Scorer scorer = disiWrapper.scorer;
+            if (scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
                 continue;
             }
-            float subScore = disiWrapper.scorer.score();
-            scores[queryToIndex.get(disiWrapper.scorer.getWeight().getQuery())] = subScore;
+            Query query = scorer.getWeight().getQuery();
+            List<Integer> indexes = queryToIndex.get(query);
+            // we need to find the index of first sub-query that hasn't been updated yet
+            int index = indexes.stream()
+                .mapToInt(idx -> idx)
+                .filter(index1 -> Float.compare(scores[index1], 0.0f) == 0)
+                .findFirst()
+                .orElseThrow(() -> new IllegalStateException("cannot collect score for subquery"));
+            scores[index] = scorer.score();
         }
         return scores;
     }
 
-    private Map<Query, Integer> mapQueryToIndex() {
-        Map<Query, Integer> queryToIndex = new HashMap<>();
+    private Map<Query, List<Integer>> mapQueryToIndex() {
+        Map<Query, List<Integer>> queryToIndex = new HashMap<>();
         int idx = 0;
         for (Scorer scorer : subScorers) {
             if (scorer == null) {
                 idx++;
                 continue;
             }
-            queryToIndex.put(scorer.getWeight().getQuery(), idx);
+            Query query = scorer.getWeight().getQuery();
+            queryToIndex.putIfAbsent(query, new ArrayList<>());
+            queryToIndex.get(query).add(idx);
             idx++;
         }
         return queryToIndex;
@@ -137,7 +148,9 @@ private Map<Query, Integer> mapQueryToIndex() {
     private DisiPriorityQueue initializeSubScorersPQ() {
         Objects.requireNonNull(queryToIndex, "should not be null");
         Objects.requireNonNull(subScorers, "should not be null");
-        DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(queryToIndex.size());
+        // we need to count this way in order to include all identical sub-queries
+        int numOfSubQueries = queryToIndex.values().stream().map(List::size).reduce(0, Integer::sum);
+        DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(numOfSubQueries);
         for (Scorer scorer : subScorers) {
             if (scorer == null) {
                 continue;
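A minimal standalone sketch of the bookkeeping this change introduces (plain-Java stand-ins for the Lucene query objects; the string keys below are illustrative, not from the patch): each distinct sub-query maps to the list of positions it occupies among the sub-queries, and a sub-score is written into the first of those positions whose slot is still unset.

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    class DuplicateSubQuerySketch {
        public static void main(String[] args) {
            // three sub-queries; positions 0 and 2 hold an identical query
            List<String> subQueries = List.of("term:word1", "term:word2", "term:word1");

            // mirrors mapQueryToIndex(): query -> every position it occupies
            Map<String, List<Integer>> queryToIndex = new HashMap<>();
            for (int idx = 0; idx < subQueries.size(); idx++) {
                queryToIndex.computeIfAbsent(subQueries.get(idx), q -> new ArrayList<>()).add(idx);
            }
            System.out.println(queryToIndex); // e.g. {term:word1=[0, 2], term:word2=[1]}

            // mirrors hybridScores(): a score goes into the first slot for that
            // query that is still 0.0, so duplicates fill separate slots
            float[] scores = new float[subQueries.size()];
            for (int hit = 0; hit < 2; hit++) { // the duplicated query scores the doc twice
                int index = queryToIndex.get("term:word1")
                    .stream()
                    .filter(i -> Float.compare(scores[i], 0.0f) == 0)
                    .findFirst()
                    .orElseThrow(() -> new IllegalStateException("cannot collect score for subquery"));
                scores[index] = 0.7f;
            }
            // scores is now [0.7, 0.0, 0.7]
        }
    }

Note that the "first slot still at 0.0" convention assumes a matching sub-score is never exactly zero, which holds for the positive scores produced here; it is also why the priority queue above is sized by the summed list sizes rather than by queryToIndex.size().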
diff --git a/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java b/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java
index c8c52320f..dea9c6bae 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java
@@ -24,9 +24,6 @@ public HitsThresholdChecker(int totalHitsThreshold) {
         if (totalHitsThreshold < 0) {
             throw new IllegalArgumentException(String.format(Locale.ROOT, "totalHitsThreshold must be >= 0, got %d", totalHitsThreshold));
         }
-        if (totalHitsThreshold == Integer.MAX_VALUE) {
-            throw new IllegalArgumentException(String.format(Locale.ROOT, "totalHitsThreshold must be less than max integer value"));
-        }
         this.totalHitsThreshold = totalHitsThreshold;
     }
 
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java
index 50198af46..0b9af2bcd 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java
@@ -38,7 +38,10 @@
 
 public class HybridQueryWeightTests extends OpenSearchQueryTestCase {
 
-    static final String TERM_QUERY_TEXT = "keyword";
+    private static final String TERM_QUERY_TEXT = "keyword";
+    private static final String RANGE_FIELD = "date _range";
+    private static final String FROM_TEXT = "123";
+    private static final String TO_TEXT = "456";
 
     @SneakyThrows
     public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() {
@@ -87,6 +90,62 @@ public void 
testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() { directory.close(); } + @SneakyThrows + public void testSubQueries_whenMultipleEqualSubQueries_thenSuccessful() { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId, TERM_QUERY_TEXT, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of( + QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), + QueryBuilders.rangeQuery(RANGE_FIELD) + .from(FROM_TEXT) + .to(TO_TEXT) + .rewrite(mockQueryShardContext) + .rewrite(mockQueryShardContext) + .toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) + ) + ); + IndexSearcher searcher = newSearcher(reader); + Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); + + assertNotNull(weight); + + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); + Scorer scorer = weight.scorer(leafReaderContext); + + assertNotNull(scorer); + + DocIdSetIterator iterator = scorer.iterator(); + int actualDoc = iterator.nextDoc(); + int actualDocId = Integer.parseInt(reader.document(actualDoc).getField("id").stringValue()); + + assertEquals(docId, actualDocId); + + assertTrue(weight.isCacheable(leafReaderContext)); + + Matches matches = weight.matches(leafReaderContext, actualDoc); + MatchesIterator matchesIterator = matches.getMatches(TEXT_FIELD_NAME); + assertTrue(matchesIterator.next()); + + w.close(); + reader.close(); + directory.close(); + } + @SneakyThrows public void testExplain_whenCallExplain_thenFail() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); diff --git a/src/test/java/org/opensearch/neuralsearch/search/HitsTresholdCheckerTests.java b/src/test/java/org/opensearch/neuralsearch/search/HitsTresholdCheckerTests.java index 0a6a12c88..df1198232 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/HitsTresholdCheckerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/HitsTresholdCheckerTests.java @@ -27,7 +27,13 @@ public void testTresholdLimit_whenThresholdNegative_thenFail() { expectThrows(IllegalArgumentException.class, () -> new HitsThresholdChecker(-1)); } - public void testTresholdLimit_whenThresholdMaxValue_thenFail() { - expectThrows(IllegalArgumentException.class, () -> new HitsThresholdChecker(Integer.MAX_VALUE)); + public void testTrackThreshold_whenTrackThresholdSet_thenSuccessful() { + HitsThresholdChecker hitsThresholdChecker = new HitsThresholdChecker(Integer.MAX_VALUE); + assertEquals(ScoreMode.TOP_SCORES, hitsThresholdChecker.scoreMode()); + assertFalse(hitsThresholdChecker.isThresholdReached()); + hitsThresholdChecker.incrementHitCount(); + assertFalse(hitsThresholdChecker.isThresholdReached()); + 
IntStream.rangeClosed(1, 5).forEach((checker) -> hitsThresholdChecker.incrementHitCount());
+        assertFalse(hitsThresholdChecker.isThresholdReached());
     }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java
index 06bbfc416..72cf8be49 100644
--- a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java
@@ -349,4 +349,55 @@ public void testTopDocs_whenMatchedDocsDifferentForEachSubQuery_thenSuccessful()
         reader.close();
         directory.close();
     }
+
+    @SneakyThrows
+    public void testTrackTotalHits_whenTotalHitsSetIntegerMaxValue_thenSuccessful() {
+        final Directory directory = newDirectory();
+        final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
+        ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
+        ft.setOmitNorms(random().nextBoolean());
+        ft.freeze();
+
+        w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft));
+        w.commit();
+
+        DirectoryReader reader = DirectoryReader.open(w);
+
+        LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
+
+        HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector(
+            NUM_DOCS,
+            new HitsThresholdChecker(Integer.MAX_VALUE)
+        );
+        LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext);
+        assertNotNull(leafCollector);
+
+        Weight weight = mock(Weight.class);
+        int[] docIds = new int[] { DOC_ID_1, DOC_ID_2, DOC_ID_3 };
+        Arrays.sort(docIds);
+        final List<Float> scores = Stream.generate(() -> random().nextFloat()).limit(NUM_DOCS).collect(Collectors.toList());
+        HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
+            weight,
+            Arrays.asList(scorer(docIds, scores, fakeWeight(new MatchAllDocsQuery())))
+        );
+
+        leafCollector.setScorer(hybridQueryScorer);
+        List<float[]> hybridScores = new ArrayList<>();
+        DocIdSetIterator iterator = hybridQueryScorer.iterator();
+        int nextDoc = iterator.nextDoc();
+        while (nextDoc != NO_MORE_DOCS) {
+            hybridScores.add(hybridQueryScorer.hybridScores());
+            nextDoc = iterator.nextDoc();
+        }
+        // assert
+        assertEquals(3, hybridScores.size());
+        assertFalse(hybridScores.stream().anyMatch(score -> score[0] <= 0.0));
+
+        w.close();
+        reader.close();
+        directory.close();
+    }
 }
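Taken together, the test changes in this patch pin down the new contract of HitsThresholdChecker: Integer.MAX_VALUE is no longer rejected at construction time and behaves as an effectively unreachable threshold. A short usage sketch, using only the public methods visible in the diffs above:

    // used to throw IllegalArgumentException, now a valid "track all hits" setting
    HitsThresholdChecker checker = new HitsThresholdChecker(Integer.MAX_VALUE);
    checker.incrementHitCount();
    assert !checker.isThresholdReached(); // stays false for any realistic hit count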
From e40477522bf3c53199054ddbcc982e262288cee1 Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Wed, 27 Dec 2023 22:47:51 -0800
Subject: [PATCH 2/6] Add check for fetch and query result sizes

Signed-off-by: Martin Gaievski
---
 .../processor/NormalizationProcessor.java |  39 +++-
 .../NormalizationProcessorWorkflow.java   |  11 +-
 .../NormalizationProcessorTests.java      | 175 ++++++++++++++++++
 3 files changed, 222 insertions(+), 3 deletions(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
index 0d6742dbe..6126a3c56 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
@@ -7,6 +7,7 @@
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;
 
+import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
@@ -18,6 +19,8 @@
 import org.opensearch.action.search.SearchPhaseResults;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.SearchHits;
 import org.opensearch.search.SearchPhaseResult;
 import org.opensearch.search.fetch.FetchSearchResult;
 import org.opensearch.search.internal.SearchContext;
@@ -98,7 +101,16 @@ private boolean shouldSkipProcessor(SearchPha
         }
         QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult;
-        return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery);
+        if (queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery)) {
+            return true;
+        }
+        List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
+        Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
+        if (shouldSkipProcessorDueToIncompatibleQueryAndFetchResults(querySearchResults, fetchSearchResult)) {
+            log.debug("Query and fetch results do not match, normalization processor is skipped");
+            return true;
+        }
+        return false;
     }
 
 /**
@@ -131,4 +143,29 @@ private Optional getFetchS
         Optional<SearchPhaseResult> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
         return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult);
     }
+
+    private boolean shouldSkipProcessorDueToIncompatibleQueryAndFetchResults(
+        final List<QuerySearchResult> querySearchResults,
+        final Optional<FetchSearchResult> fetchSearchResultOptional
+    ) {
+        if (fetchSearchResultOptional.isEmpty()) {
+            return false;
+        }
+        final List<Integer> docIds = unprocessedDocIds(querySearchResults);
+        SearchHits searchHits = fetchSearchResultOptional.get().hits();
+        SearchHit[] searchHitArray = searchHits.getHits();
+        // validate that both collections are of the same size
+        if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) {
+            return true;
+        }
+        return false;
+    }
+
+    private List<Integer> unprocessedDocIds(final List<QuerySearchResult> querySearchResults) {
+        return querySearchResults.isEmpty()
+            ? List.of()
+            : Arrays.stream(querySearchResults.get(0).topDocs().topDocs.scoreDocs)
+                .map(scoreDoc -> scoreDoc.doc)
+                .collect(Collectors.toList());
+    }
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
index b8bc86de5..55ec63631 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -173,8 +173,15 @@ private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchR
         SearchHits searchHits = fetchSearchResult.hits();
         SearchHit[] searchHitArray = searchHits.getHits();
         // validate that both collections are of the same size
-        if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) {
-            throw new IllegalStateException("Score normalization processor cannot produce final query result");
+        if (Objects.isNull(searchHitArray)) {
+            throw new IllegalStateException(
+                "Score normalization processor cannot produce final query result, for one shard case fetch does not have any results"
+            );
+        }
+        if (searchHitArray.length != docIds.size()) {
+            throw new IllegalStateException(
+                "Score normalization processor cannot produce final query result, for one shard case number of fetched documents does not match number of search hits"
+            );
         }
         return searchHitArray;
     }
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java
index 41348ec49..46ee122d1 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java
@@ -15,6 +15,7 @@
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
 
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
@@ -45,9 +46,13 @@
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
 import org.opensearch.search.DocValueFormat;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.SearchHits;
 import org.opensearch.search.SearchShardTarget;
 import org.opensearch.search.aggregations.InternalAggregation;
 import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
+import org.opensearch.search.fetch.FetchSearchResult;
+import org.opensearch.search.fetch.QueryFetchSearchResult;
 import org.opensearch.search.query.QuerySearchResult;
 import org.opensearch.test.OpenSearchTestCase;
 import org.opensearch.threadpool.TestThreadPool;
@@ -325,4 +330,174 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul
 
         verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any());
     }
+
+    public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() {
+        NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
+            new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
+        );
+        NormalizationProcessor normalizationProcessor = new NormalizationProcessor(
+            PROCESSOR_TAG,
+            DESCRIPTION,
+            new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD),
+            new ScoreCombinationFactory().createCombination(COMBINATION_METHOD),
+            normalizationProcessorWorkflow
+        );
+
+        SearchRequest searchRequest = new SearchRequest(INDEX_NAME);
+        searchRequest.setBatchedReduceSize(4);
+        AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
+        QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(
+            searchRequest,
+            executor,
+            new NoopCircuitBreaker(CircuitBreaker.REQUEST),
+            searchPhaseController,
+            SearchProgressListener.NOOP,
+            writableRegistry(),
+            10,
+            e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> {
+                curr.addSuppressed(prev);
+                return curr;
+            })
+        );
+        CountDownLatch partialReduceLatch = new CountDownLatch(5);
+        int shardId = 0;
+        SearchShardTarget searchShardTarget = new SearchShardTarget(
+            "node",
+            new ShardId("index", "uuid", shardId),
+            null,
+            OriginalIndices.NONE
+        );
+        QuerySearchResult querySearchResult = new QuerySearchResult();
+        TopDocs topDocs = new TopDocs(
+            new TotalHits(4, TotalHits.Relation.EQUAL_TO),
+            new ScoreDoc[] {
+                createStartStopElementForHybridSearchResults(4),
+                createDelimiterElementForHybridSearchResults(4),
+                new ScoreDoc(0, 0.5f),
+                new ScoreDoc(2, 0.3f),
+                new ScoreDoc(4, 0.25f),
+                new ScoreDoc(10, 0.2f),
+                createStartStopElementForHybridSearchResults(4) }
+        );
+        querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]);
+        querySearchResult.setSearchShardTarget(searchShardTarget);
+        querySearchResult.setShardIndex(shardId);
+
+        FetchSearchResult fetchSearchResult = new FetchSearchResult();
+        fetchSearchResult.setShardIndex(shardId);
+        fetchSearchResult.setSearchShardTarget(searchShardTarget);
+        SearchHit[] searchHitArray = new SearchHit[] {
+            new SearchHit(4, "2", Map.of(), Map.of()),
+            new SearchHit(4, "2", Map.of(), Map.of()),
+            new SearchHit(0, "10", Map.of(), Map.of()),
+            new SearchHit(2, "1", Map.of(), Map.of()),
+            new SearchHit(4, "2", Map.of(), Map.of()),
+            new SearchHit(10, "3", Map.of(), Map.of()),
+            new SearchHit(4, "2", Map.of(), Map.of()) };
+        SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
+        fetchSearchResult.hits(searchHits);
+
+        QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
+        queryFetchSearchResult.setShardIndex(shardId);
+
+        queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);
+
+        SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class);
+        normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext);
+
+        List<QuerySearchResult> querySearchResults = queryPhaseResultConsumer.getAtomicArray()
+            .asList()
+            .stream()
+            .map(result -> result == null ? null : result.queryResult())
+            .collect(Collectors.toList());
+
+        TestUtils.assertQueryResultScores(querySearchResults);
+        verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any());
+    }
+
+    public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenSkipNormalization() {
+        NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
+            new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
+        );
+        NormalizationProcessor normalizationProcessor = new NormalizationProcessor(
+            PROCESSOR_TAG,
+            DESCRIPTION,
+            new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD),
+            new ScoreCombinationFactory().createCombination(COMBINATION_METHOD),
+            normalizationProcessorWorkflow
+        );
+
+        SearchRequest searchRequest = new SearchRequest(INDEX_NAME);
+        searchRequest.setBatchedReduceSize(4);
+        AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
+        QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(
+            searchRequest,
+            executor,
+            new NoopCircuitBreaker(CircuitBreaker.REQUEST),
+            searchPhaseController,
+            SearchProgressListener.NOOP,
+            writableRegistry(),
+            10,
+            e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> {
+                curr.addSuppressed(prev);
+                return curr;
+            })
+        );
+        CountDownLatch partialReduceLatch = new CountDownLatch(5);
+        int shardId = 0;
+        SearchShardTarget searchShardTarget = new SearchShardTarget(
+            "node",
+            new ShardId("index", "uuid", shardId),
+            null,
+            OriginalIndices.NONE
+        );
+        QuerySearchResult querySearchResult = new QuerySearchResult();
+        TopDocs topDocs = new TopDocs(
+            new TotalHits(4, TotalHits.Relation.EQUAL_TO),
+            new ScoreDoc[] {
+                createStartStopElementForHybridSearchResults(4),
+                createDelimiterElementForHybridSearchResults(4),
+                new ScoreDoc(0, 0.5f),
+                new ScoreDoc(2, 0.3f),
+                new ScoreDoc(4, 0.25f),
+                new ScoreDoc(10, 0.2f),
+                createStartStopElementForHybridSearchResults(4) }
+        );
+        querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]);
+        querySearchResult.setSearchShardTarget(searchShardTarget);
+        querySearchResult.setShardIndex(shardId);
+
+        FetchSearchResult fetchSearchResult = new FetchSearchResult();
+        fetchSearchResult.setShardIndex(shardId);
+        fetchSearchResult.setSearchShardTarget(searchShardTarget);
+        SearchHit[] searchHitArray = new SearchHit[] {
+            new SearchHit(0, "10", Map.of(), Map.of()),
+            new SearchHit(2, "1", Map.of(), Map.of()),
+            new SearchHit(4, "2", Map.of(), Map.of()),
+            new SearchHit(10, "3", Map.of(), Map.of()),
+            new SearchHit(0, "10", Map.of(), Map.of()), };
+        SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(5, TotalHits.Relation.EQUAL_TO), 10);
+        fetchSearchResult.hits(searchHits);
+
+        QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
+        queryFetchSearchResult.setShardIndex(shardId);
+
+        queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);
+
+        SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class);
+        normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext);
+
+        List<QuerySearchResult> querySearchResults = queryPhaseResultConsumer.getAtomicArray()
+            .asList()
+            .stream()
+            .map(result -> result == null ? null : result.queryResult())
+            .collect(Collectors.toList());
+
+        assertNotNull(querySearchResults);
+        verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any());
+    }
 }
From 68018441318fa3c5df989f43f8edc4c82c7c29d0 Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Thu, 28 Dec 2023 12:15:44 -0800
Subject: [PATCH 3/6] Address Navneet's comments

Signed-off-by: Martin Gaievski
---
 CHANGELOG.md                                  |  2 +-
 .../processor/NormalizationProcessor.java     | 15 +++-
 .../NormalizationProcessorWorkflow.java       | 18 +----
 .../neuralsearch/query/HybridQueryScorer.java | 17 ++++-
 .../NormalizationProcessorWorkflowTests.java  | 16 ++---
 .../neuralsearch/query/HybridQueryIT.java     | 69 +++++++++++++++++++
 6 files changed, 108 insertions(+), 29 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6522d54f2..c9896ff13 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 ### Features
 ### Enhancements
 ### Bug Fixes
-- Multiple identical subqueries in Hybrid query ([#524](https://github.com/opensearch-project/neural-search/pull/524))
+- Fixing multiple issues reported in #497 ([#524](https://github.com/opensearch-project/neural-search/pull/524))
 ### Infrastructure
 ### Documentation
 ### Maintenance
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
index 6126a3c56..657b5c30c 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
@@ -9,6 +9,7 @@
 
 import java.util.Arrays;
 import java.util.List;
+import java.util.Locale;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.stream.Collectors;
@@ -155,7 +156,19 @@ private boolean shouldSkipProcessorDueToIncompatibleQueryAndFetchResults(
         SearchHits searchHits = fetchSearchResultOptional.get().hits();
         SearchHit[] searchHitArray = searchHits.getHits();
         // validate that both collections are of the same size
-        if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) {
+        if (Objects.isNull(searchHitArray)) {
+            log.info("array of search hits in fetch phase results is null");
+            return true;
+        }
+        if (searchHitArray.length != docIds.size()) {
+            log.info(
+                String.format(
+                    Locale.ROOT,
+                    "number of documents in fetch results [%d] and query results [%d] is different",
+                    searchHitArray.length,
+                    docIds.size()
+                )
+            );
             return true;
         }
         return false;
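As background for the size comparison being logged here: for hybrid queries the query-phase ScoreDoc array is wrapped in synthetic start/stop and delimiter elements (see the HybridSearchResultFormatUtil helpers used by the tests in this PR), and each element, marker or real hit, is expected to line up with one fetched SearchHit. A rough sketch of that shape, with illustrative doc ids:

    // one shard's query-phase results in hybrid format: markers plus real hits
    ScoreDoc[] scoreDocs = new ScoreDoc[] {
        createStartStopElementForHybridSearchResults(4), // start marker, carries a valid doc id
        createDelimiterElementForHybridSearchResults(4), // separates sub-query sections
        new ScoreDoc(0, 0.5f),
        new ScoreDoc(2, 0.3f),
        createStartStopElementForHybridSearchResults(4)  // stop marker
    };
    // the doc ids handed to the fetch phase include the markers' ids:
    int[] docIds = Arrays.stream(scoreDocs).mapToInt(sd -> sd.doc).toArray(); // [4, 4, 0, 2, 4]
    // a compatible fetch result must therefore contain exactly scoreDocs.length hits;
    // any other length is the mismatch being logged above (and, later in this PR, rejected)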
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
index 55ec63631..5929370be 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -139,7 +139,7 @@ private void updateOriginalFetchResults(
         // 3. update original scores to normalized and combined values
         // 4. order scores based on normalized and combined values
         FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
-        SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult);
+        SearchHit[] searchHitArray = getSearchHits(fetchSearchResult);
 
         // create map of docId to index of search hits. This solves (2), duplicates are from
         // delimiter and start/stop elements, they all have same valid doc_id. For this map
@@ -169,21 +169,9 @@ private void updateOriginalFetchResults(
         fetchSearchResult.hits(updatedSearchHits);
     }
 
-    private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult) {
+    private SearchHit[] getSearchHits(final FetchSearchResult fetchSearchResult) {
         SearchHits searchHits = fetchSearchResult.hits();
-        SearchHit[] searchHitArray = searchHits.getHits();
-        // validate that both collections are of the same size
-        if (Objects.isNull(searchHitArray)) {
-            throw new IllegalStateException(
-                "Score normalization processor cannot produce final query result, for one shard case fetch does not have any results"
-            );
-        }
-        if (searchHitArray.length != docIds.size()) {
-            throw new IllegalStateException(
-                "Score normalization processor cannot produce final query result, for one shard case number of fetched documents does not match number of search hits"
-            );
-        }
-        return searchHitArray;
+        return searchHits.getHits();
     }
 
     private List<Integer> unprocessedDocIds(final List<QuerySearchResult> querySearchResults) {
diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
index ca2f06fc3..57ad4451f 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
@@ -5,6 +5,8 @@
 
 package org.opensearch.neuralsearch.query;
 
+import static java.util.Locale.ROOT;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -118,12 +120,21 @@ public float[] hybridScores() throws IOException {
             }
             Query query = scorer.getWeight().getQuery();
             List<Integer> indexes = queryToIndex.get(query);
-            // we need to find the index of first sub-query that hasn't been updated yet
+            // we need to find the index of first sub-query that hasn't been set yet. 
Such score will have initial value of "0.0" int index = indexes.stream() .mapToInt(idx -> idx) - .filter(index1 -> Float.compare(scores[index1], 0.0f) == 0) + .filter(idx -> Float.compare(scores[idx], 0.0f) == 0) .findFirst() - .orElseThrow(() -> new IllegalStateException("cannot collect score for subquery")); + .orElseThrow( + () -> new IllegalStateException( + String.format( + ROOT, + "cannot set score for one of hybrid search subquery [%s] and document [%d]", + query.toString(), + scorer.docID() + ) + ) + ); scores[index] = scorer.score(); } return scores; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 95c2ba0c2..a8f1d8eb7 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -237,7 +237,7 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom TestUtils.assertFetchResultScores(fetchSearchResult, 4); } - public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() { + public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccess() { NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) ); @@ -282,14 +282,12 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - expectThrows( - IllegalStateException.class, - () -> normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ) + normalizationProcessorWorkflow.execute( + querySearchResults, + Optional.of(fetchSearchResult), + ScoreNormalizationFactory.DEFAULT_METHOD, + ScoreCombinationFactory.DEFAULT_METHOD ); + TestUtils.assertQueryResultScores(querySearchResults); } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 312cb8b3a..36613ef1b 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -174,6 +174,75 @@ public void testComplexQuery_whenMultipleSubqueries_thenSuccessful() { assertEquals(RELATION_EQUAL_TO, total.get("relation")); } + /** + * Tests complex query with multiple nested sub-queries, where soem sub-queries are same + * { + * "query": { + * "hybrid": { + * "queries": [ + * { + * "term": { + * "text": "word1" + * } + * }, + * { + * "term": { + * "text": "word2" + * } + * }, + * { + * "term": { + * "text": "word3" + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testComplexQuery_whenMultipleIdenticalSubQueries_thenSuccessful() { + initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); + TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder 
hybridQueryBuilderThreeTerms = new HybridQueryBuilder();
+        hybridQueryBuilderThreeTerms.add(termQueryBuilder1);
+        hybridQueryBuilderThreeTerms.add(termQueryBuilder2);
+        hybridQueryBuilderThreeTerms.add(termQueryBuilder3);
+
+        Map<String, Object> searchResponseAsMap1 = search(
+            TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
+            hybridQueryBuilderThreeTerms,
+            null,
+            10,
+            Map.of("search_pipeline", SEARCH_PIPELINE)
+        );
+
+        assertEquals(2, getHitCount(searchResponseAsMap1));
+
+        List<Map<String, Object>> hits1NestedList = getNestedHits(searchResponseAsMap1);
+        List<String> ids = new ArrayList<>();
+        List<Double> scores = new ArrayList<>();
+        for (Map<String, Object> oneHit : hits1NestedList) {
+            ids.add((String) oneHit.get("_id"));
+            scores.add((Double) oneHit.get("_score"));
+        }
+
+        // verify that scores are in desc order
+        assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));
+        // verify that all ids are unique
+        assertEquals(Set.copyOf(ids).size(), ids.size());
+
+        Map<String, Object> total = getTotalHits(searchResponseAsMap1);
+        assertNotNull(total.get("value"));
+        assertEquals(2, total.get("value"));
+        assertNotNull(total.get("relation"));
+        assertEquals(RELATION_EQUAL_TO, total.get("relation"));
+    }
+
     @SneakyThrows
     public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult() {
         initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME);
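For reference, the shape of the query this integration test now allows, condensed from the test body above (the constants are the test's own):

    // duplicate sub-queries in one hybrid query, which previously broke score collection (issue #497)
    HybridQueryBuilder hybridQuery = new HybridQueryBuilder();
    hybridQuery.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
    hybridQuery.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4));
    hybridQuery.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); // same as the first
    // each occurrence now gets its own slot in HybridQueryScorer's scores array,
    // so the duplicate no longer triggers "cannot collect score for subquery"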
From 7d6dc4c858569f2747bab249adde9556c062d587 Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Fri, 29 Dec 2023 12:10:15 -0800
Subject: [PATCH 4/6] Throw exception when results of fetch and query phases
 are different

Signed-off-by: Martin Gaievski
---
 .../processor/NormalizationProcessor.java     | 52 +------------------
 .../NormalizationProcessorWorkflow.java       | 18 +++++--
 .../NormalizationProcessorTests.java          | 21 ++++----
 .../NormalizationProcessorWorkflowTests.java  | 16 +++---
 4 files changed, 35 insertions(+), 72 deletions(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
index 657b5c30c..0d6742dbe 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
@@ -7,9 +7,7 @@
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;
 
-import java.util.Arrays;
 import java.util.List;
-import java.util.Locale;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.stream.Collectors;
@@ -20,8 +18,6 @@
 import org.opensearch.action.search.SearchPhaseResults;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
-import org.opensearch.search.SearchHit;
-import org.opensearch.search.SearchHits;
 import org.opensearch.search.SearchPhaseResult;
 import org.opensearch.search.fetch.FetchSearchResult;
 import org.opensearch.search.internal.SearchContext;
@@ -102,16 +98,7 @@ private boolean shouldSkipProcessor(SearchPha
         }
         QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult;
-        if (queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery)) {
-            return true;
-        }
-        List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
-        Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
-        if (shouldSkipProcessorDueToIncompatibleQueryAndFetchResults(querySearchResults, fetchSearchResult)) {
-            log.debug("Query and fetch results do not match, normalization processor is skipped");
-            return true;
-        }
-        return false;
+        return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery);
     }
 
 /**
@@ -144,41 +131,4 @@ private Optional getFetchS
         Optional<SearchPhaseResult> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
         return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult);
     }
-
-    private boolean shouldSkipProcessorDueToIncompatibleQueryAndFetchResults(
-        final List<QuerySearchResult> querySearchResults,
-        final Optional<FetchSearchResult> fetchSearchResultOptional
-    ) {
-        if (fetchSearchResultOptional.isEmpty()) {
-            return false;
-        }
-        final List<Integer> docIds = unprocessedDocIds(querySearchResults);
-        SearchHits searchHits = fetchSearchResultOptional.get().hits();
-        SearchHit[] searchHitArray = searchHits.getHits();
-        // validate that both collections are of the same size
-        if (Objects.isNull(searchHitArray)) {
-            log.info("array of search hits in fetch phase results is null");
-            return true;
-        }
-        if (searchHitArray.length != docIds.size()) {
-            log.info(
-                String.format(
-                    Locale.ROOT,
-                    "number of documents in fetch results [%d] and query results [%d] is different",
-                    searchHitArray.length,
-                    docIds.size()
-                )
-            );
-            return true;
-        }
-        return false;
-    }
-
-    private List<Integer> unprocessedDocIds(final List<QuerySearchResult> querySearchResults) {
-        return querySearchResults.isEmpty()
-            ? List.of()
-            : Arrays.stream(querySearchResults.get(0).topDocs().topDocs.scoreDocs)
-                .map(scoreDoc -> scoreDoc.doc)
-                .collect(Collectors.toList());
-    }
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
index 5929370be..71daeac35 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -139,7 +139,7 @@ private void updateOriginalFetchResults(
         // 3. update original scores to normalized and combined values
         // 4. order scores based on normalized and combined values
         FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
-        SearchHit[] searchHitArray = getSearchHits(fetchSearchResult);
+        SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult);
 
         // create map of docId to index of search hits. This solves (2), duplicates are from
         // delimiter and start/stop elements, they all have same valid doc_id. For this map
@@ -169,9 +169,21 @@ private void updateOriginalFetchResults(
         fetchSearchResult.hits(updatedSearchHits);
     }
 
-    private SearchHit[] getSearchHits(final FetchSearchResult fetchSearchResult) {
+    private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult) {
         SearchHits searchHits = fetchSearchResult.hits();
-        return searchHits.getHits();
+        SearchHit[] searchHitArray = searchHits.getHits();
+        // validate that both collections are of the same size
+        if (Objects.isNull(searchHitArray)) {
+            throw new IllegalStateException(
+                "score normalization processor cannot produce final query result, fetch query phase returns empty results"
+            );
+        }
+        if (searchHitArray.length != docIds.size()) {
+            throw new IllegalStateException(
+                "score normalization processor cannot produce final query result, the number of documents returned by fetch and query phases does not match"
+            );
+        }
+        return searchHitArray;
     }
 
     private List<Integer> unprocessedDocIds(final List<QuerySearchResult> querySearchResults) {
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java
index 46ee122d1..26d9fc808 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java
@@ -5,6 +5,7 @@
 
 package org.opensearch.neuralsearch.processor;
 
+import static org.hamcrest.Matchers.startsWith;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -417,7 +418,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz
         verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any());
     }
 
-    public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenSkipNormalization() {
+    public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() {
         NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
             new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
         );
@@ -489,15 +490,13 @@ public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenSkipNor
         queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);
 
         SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class);
-        normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext);
-
-        List<QuerySearchResult> querySearchResults = queryPhaseResultConsumer.getAtomicArray()
-            .asList()
-            .stream()
-            .map(result -> result == null ? 
null : result.queryResult()) - .collect(Collectors.toList()); - - assertNotNull(querySearchResults); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); + IllegalStateException exception = expectThrows( + IllegalStateException.class, + () -> normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext) + ); + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + startsWith("score normalization processor cannot produce final query result") + ); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index a8f1d8eb7..95c2ba0c2 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -237,7 +237,7 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom TestUtils.assertFetchResultScores(fetchSearchResult, 4); } - public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccess() { + public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() { NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) ); @@ -282,12 +282,14 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + expectThrows( + IllegalStateException.class, + () -> normalizationProcessorWorkflow.execute( + querySearchResults, + Optional.of(fetchSearchResult), + ScoreNormalizationFactory.DEFAULT_METHOD, + ScoreCombinationFactory.DEFAULT_METHOD + ) ); - TestUtils.assertQueryResultScores(querySearchResults); } } From 6e3264d185cd1308e5bef5592952e1a76c9a2de3 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 29 Dec 2023 13:21:37 -0800 Subject: [PATCH 5/6] Extend error message, fix typo in class name Signed-off-by: Martin Gaievski --- .../processor/NormalizationProcessorWorkflow.java | 7 ++++++- ...oldCheckerTests.java => HitsThresholdCheckerTests.java} | 6 +++--- 2 files changed, 9 insertions(+), 4 deletions(-) rename src/test/java/org/opensearch/neuralsearch/search/{HitsTresholdCheckerTests.java => HitsThresholdCheckerTests.java} (87%) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 71daeac35..c322102d5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -180,7 +180,12 @@ private SearchHit[] getSearchHits(final List docIds, final FetchSearchR } if (searchHitArray.length != docIds.size()) { throw new IllegalStateException( - "score normalization processor cannot produce final query result, the number of documents returned by fetch and query phases does not match" + String.format( + Locale.ROOT, + "score normalization processor cannot produce final query result, the 
number of documents after fetch phase [%d] is different from number of documents from query phase [%d]", + searchHitArray.length, + docIds.size() + ) ); } return searchHitArray; diff --git a/src/test/java/org/opensearch/neuralsearch/search/HitsTresholdCheckerTests.java b/src/test/java/org/opensearch/neuralsearch/search/HitsThresholdCheckerTests.java similarity index 87% rename from src/test/java/org/opensearch/neuralsearch/search/HitsTresholdCheckerTests.java rename to src/test/java/org/opensearch/neuralsearch/search/HitsThresholdCheckerTests.java index df1198232..3ce9e3dfe 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/HitsTresholdCheckerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/HitsThresholdCheckerTests.java @@ -10,9 +10,9 @@ import org.apache.lucene.search.ScoreMode; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -public class HitsTresholdCheckerTests extends OpenSearchQueryTestCase { +public class HitsThresholdCheckerTests extends OpenSearchQueryTestCase { - public void testTresholdReached_whenIncrementCount_thenTresholdReached() { + public void testThresholdReached_whenIncrementCount_thenThresholdReached() { HitsThresholdChecker hitsThresholdChecker = new HitsThresholdChecker(5); assertEquals(5, hitsThresholdChecker.getTotalHitsThreshold()); assertEquals(ScoreMode.TOP_SCORES, hitsThresholdChecker.scoreMode()); @@ -23,7 +23,7 @@ public void testTresholdReached_whenIncrementCount_thenTresholdReached() { assertTrue(hitsThresholdChecker.isThresholdReached()); } - public void testTresholdLimit_whenThresholdNegative_thenFail() { + public void testThresholdLimit_whenThresholdNegative_thenFail() { expectThrows(IllegalArgumentException.class, () -> new HitsThresholdChecker(-1)); } From 7ea7d904952e66965be55beefa95c5c4ac490aa7 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 29 Dec 2023 14:08:45 -0800 Subject: [PATCH 6/6] Address Heemin's comments Signed-off-by: Martin Gaievski --- .../org/opensearch/neuralsearch/query/HybridQueryScorer.java | 5 ++--- .../org/opensearch/neuralsearch/query/HybridQueryIT.java | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java index 57ad4451f..e3e6a0862 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -5,13 +5,12 @@ package org.opensearch.neuralsearch.query; -import static java.util.Locale.ROOT; - import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -128,7 +127,7 @@ public float[] hybridScores() throws IOException { .orElseThrow( () -> new IllegalStateException( String.format( - ROOT, + Locale.ROOT, "cannot set score for one of hybrid search subquery [%s] and document [%d]", query.toString(), scorer.docID() diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 36613ef1b..864ebdc68 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -175,7 +175,7 @@ public void testComplexQuery_whenMultipleSubqueries_thenSuccessful() { } /** - * Tests complex query with multiple nested 
sub-queries, where soem sub-queries are same + * Tests complex query with multiple nested sub-queries, where some sub-queries are same * { * "query": { * "hybrid": {