From 2afe82e3f9af91ad7c5dec189266d2a3b997646e Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 3 Oct 2024 10:22:27 -0700 Subject: [PATCH] Added specific rescore exception, refactor code Signed-off-by: Martin Gaievski --- .../search/query/HybridCollectorManager.java | 29 ++++---- .../HybridSearchRescoreQueryException.java | 17 +++++ .../query/HybridCollectorManagerTests.java | 74 +++++++++++++++++++ 3 files changed, 105 insertions(+), 15 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/exception/HybridSearchRescoreQueryException.java diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index a0d444fe6..f9457f6ca 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -19,7 +19,6 @@ import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.FieldDoc; -import org.opensearch.OpenSearchException; import org.opensearch.common.Nullable; import org.opensearch.common.lucene.search.FilteredCollector; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; @@ -37,6 +36,7 @@ import org.opensearch.search.query.ReduceableSearchResult; import org.opensearch.search.rescore.RescoreContext; import org.opensearch.search.sort.SortAndFormats; +import org.opensearch.neuralsearch.search.query.exception.HybridSearchRescoreQueryException; import java.io.IOException; import java.util.ArrayList; @@ -189,23 +189,22 @@ private TopDocsAndMaxScore getSortedTopDocsAndMaxScore(List topDoc } private TopDocsAndMaxScore getTopDocsAndMaxScore(List topDocs, HybridSearchCollector hybridSearchCollector) { - List rescoredTopDocs = rescore(topDocs); - float maxScore = calculateMaxScore(rescoredTopDocs, hybridSearchCollector.getMaxScore()); - TopDocs finalTopDocs = getNewTopDocs( - getTotalHits(this.trackTotalHitsUpTo, rescoredTopDocs, hybridSearchCollector.getTotalHits()), - rescoredTopDocs - ); + if (shouldRescore()) { + topDocs = rescore(topDocs); + } + float maxScore = calculateMaxScore(topDocs, hybridSearchCollector.getMaxScore()); + TopDocs finalTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()), topDocs); return new TopDocsAndMaxScore(finalTopDocs, maxScore); } - private List rescore(List topDocs) { + private boolean shouldRescore() { List rescoreContexts = searchContext.rescore(); - boolean shouldRescore = Objects.nonNull(rescoreContexts) && !rescoreContexts.isEmpty(); - if (!shouldRescore) { - return topDocs; - } + return Objects.nonNull(rescoreContexts) && !rescoreContexts.isEmpty(); + } + + private List rescore(List topDocs) { List rescoredTopDocs = topDocs; - for (RescoreContext ctx : rescoreContexts) { + for (RescoreContext ctx : searchContext.rescore()) { rescoredTopDocs = rescoredTopDocs(ctx, rescoredTopDocs); } return rescoredTopDocs; @@ -220,8 +219,8 @@ private List rescoredTopDocs(final RescoreContext ctx, final List, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.flush(); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + + RescoreContext rescoreContext = mock(RescoreContext.class); + Rescorer rescorer = mock(Rescorer.class); + when(rescoreContext.rescorer()).thenReturn(rescorer); + when(rescorer.rescore(any(), any(), any())).thenThrow(new IOException("something happened with rescorer")); + List rescoreContexts = List.of(rescoreContext); + when(searchContext.rescore()).thenReturn(rescoreContexts); + + CollectorManager hybridCollectorManager1 = HybridCollectorManager.createHybridCollectorManager(searchContext); + HybridTopScoreDocCollector collector = (HybridTopScoreDocCollector) hybridCollectorManager1.newCollector(); + + Weight weight = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + collector.setWeight(weight); + + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); + LeafCollector leafCollector = collector.getLeafCollector(leafReaderContext); + + BulkScorer scorer = weight.bulkScorer(leafReaderContext); + scorer.score(leafCollector, leafReaderContext.reader().getLiveDocs()); + leafCollector.finish(); + + expectThrows(HybridSearchRescoreQueryException.class, () -> hybridCollectorManager1.reduce(List.of())); + + // release resources + w.close(); + reader.close(); + directory.close(); + } }