From 534b2ebe343a84450f5789b0408a79450252a4ac Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Wed, 27 Dec 2023 22:47:51 -0800
Subject: [PATCH] Add check for fetch and query result sizes

Signed-off-by: Martin Gaievski
---
 .../processor/NormalizationProcessor.java |  32 ++++
 .../NormalizationProcessorWorkflow.java   |  11 +-
 .../NormalizationProcessorTests.java      | 175 ++++++++++++++++++
 3 files changed, 216 insertions(+), 2 deletions(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
index 0d6742dbe..190bccc46 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
@@ -7,6 +7,7 @@
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;
 
+import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
@@ -18,6 +19,8 @@
 import org.opensearch.action.search.SearchPhaseResults;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.SearchHits;
 import org.opensearch.search.SearchPhaseResult;
 import org.opensearch.search.fetch.FetchSearchResult;
 import org.opensearch.search.internal.SearchContext;
@@ -59,6 +62,10 @@ public void process(
         }
         List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
         Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
+        if (shouldSkipProcessorDueToIncompatibleQueryAndFetchResults(querySearchResults, fetchSearchResult)) {
+            log.debug("Query and fetch results do not match, skipping the normalization processor");
+            return;
+        }
         normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique);
     }
@@ -131,4 +138,29 @@ private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
         Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
         return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult);
     }
+
+    private List<Integer> unprocessedDocIds(final List<QuerySearchResult> querySearchResults) {
+        return querySearchResults.isEmpty()
+            ? List.of()
+            : Arrays.stream(querySearchResults.get(0).topDocs().topDocs.scoreDocs)
+                .map(scoreDoc -> scoreDoc.doc)
+                .collect(Collectors.toList());
+    }
+
+    private boolean shouldSkipProcessorDueToIncompatibleQueryAndFetchResults(
+        final List<QuerySearchResult> querySearchResults,
+        final Optional<FetchSearchResult> fetchSearchResultOptional
+    ) {
+        if (fetchSearchResultOptional.isEmpty()) {
+            return false;
+        }
+        final List<Integer> docIds = unprocessedDocIds(querySearchResults);
+        SearchHits searchHits = fetchSearchResultOptional.get().hits();
+        SearchHit[] searchHitArray = searchHits.getHits();
+        // validate that both collections are of the same size
+        if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) {
+            return true;
+        }
+        return false;
+    }
 }
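
Note: the skip predicate added above boils down to a size comparison between the doc IDs taken from the first shard's query-phase top docs and the hits array returned by the fetch phase. A minimal standalone sketch of that logic follows; the shouldSkip helper and the sample doc IDs are hypothetical, not part of this patch:

    import java.util.List;

    class CompatibilityCheckSketch {
        // Skip normalization when a fetch result is present but its hits
        // cannot be matched one-to-one with the query-phase doc IDs.
        static boolean shouldSkip(List<Integer> queryPhaseDocIds, int[] fetchedDocIds) {
            if (fetchedDocIds == null) {
                return true; // fetch phase produced no hits array at all
            }
            return fetchedDocIds.length != queryPhaseDocIds.size();
        }

        public static void main(String[] args) {
            System.out.println(shouldSkip(List.of(0, 2, 4, 10), new int[] { 0, 2, 4, 10 })); // false: sizes match
            System.out.println(shouldSkip(List.of(0, 2, 4, 10), new int[] { 0, 2 }));        // true: mismatch, skip
        }
    }

An absent fetch result (empty Optional) does not trigger the skip; only a present fetch result with a mismatched hit count does.
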
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
index b8bc86de5..55ec63631 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -173,8 +173,15 @@ private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult) {
         SearchHits searchHits = fetchSearchResult.hits();
         SearchHit[] searchHitArray = searchHits.getHits();
         // validate that both collections are of the same size
-        if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) {
-            throw new IllegalStateException("Score normalization processor cannot produce final query result");
+        if (Objects.isNull(searchHitArray)) {
+            throw new IllegalStateException(
+                "Score normalization processor cannot produce final query result, for the single-shard case fetch did not return any results"
+            );
+        }
+        if (searchHitArray.length != docIds.size()) {
+            throw new IllegalStateException(
+                "Score normalization processor cannot produce final query result, number of fetched documents does not match number of documents from the query phase"
+            );
         }
         return searchHitArray;
     }
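
Note: the workflow change above splits one compound validation into two guard clauses so each failure mode gets its own message; the behavior is otherwise unchanged. A generic sketch of the pattern, with hypothetical names and messages (not the patch's code):

    import java.util.List;
    import java.util.Objects;

    class GuardClauseSketch {
        // Before: one compound check, one generic error.
        // After: separate checks, so the exception says which invariant failed.
        static void validate(int[] hits, List<Integer> docIds) {
            if (Objects.isNull(hits)) {
                throw new IllegalStateException("fetch returned no hits array");
            }
            if (hits.length != docIds.size()) {
                throw new IllegalStateException("expected " + docIds.size() + " hits but fetch returned " + hits.length);
            }
        }
    }
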
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java
index 41348ec49..46ee122d1 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java
@@ -15,6 +15,7 @@
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
 
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
@@ -45,9 +46,13 @@
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
 import org.opensearch.search.DocValueFormat;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.SearchHits;
 import org.opensearch.search.SearchShardTarget;
 import org.opensearch.search.aggregations.InternalAggregation;
 import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
+import org.opensearch.search.fetch.FetchSearchResult;
+import org.opensearch.search.fetch.QueryFetchSearchResult;
 import org.opensearch.search.query.QuerySearchResult;
 import org.opensearch.test.OpenSearchTestCase;
 import org.opensearch.threadpool.TestThreadPool;
@@ -325,4 +330,174 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul
         verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any());
     }
+
+    public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() {
+        NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
+            new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
+        );
+        NormalizationProcessor normalizationProcessor = new NormalizationProcessor(
+            PROCESSOR_TAG,
+            DESCRIPTION,
+            new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD),
+            new ScoreCombinationFactory().createCombination(COMBINATION_METHOD),
+            normalizationProcessorWorkflow
+        );
+
+        SearchRequest searchRequest = new SearchRequest(INDEX_NAME);
+        searchRequest.setBatchedReduceSize(4);
+        AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
+        QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(
+            searchRequest,
+            executor,
+            new NoopCircuitBreaker(CircuitBreaker.REQUEST),
+            searchPhaseController,
+            SearchProgressListener.NOOP,
+            writableRegistry(),
+            10,
+            e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> {
+                curr.addSuppressed(prev);
+                return curr;
+            })
+        );
+        CountDownLatch partialReduceLatch = new CountDownLatch(5);
+        int shardId = 0;
+        SearchShardTarget searchShardTarget = new SearchShardTarget(
+            "node",
+            new ShardId("index", "uuid", shardId),
+            null,
+            OriginalIndices.NONE
+        );
+        QuerySearchResult querySearchResult = new QuerySearchResult();
+        TopDocs topDocs = new TopDocs(
+            new TotalHits(4, TotalHits.Relation.EQUAL_TO),
+            new ScoreDoc[] {
+                createStartStopElementForHybridSearchResults(4),
+                createDelimiterElementForHybridSearchResults(4),
+                new ScoreDoc(0, 0.5f),
+                new ScoreDoc(2, 0.3f),
+                new ScoreDoc(4, 0.25f),
+                new ScoreDoc(10, 0.2f),
+                createStartStopElementForHybridSearchResults(4) }
+        );
+        querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]);
+        querySearchResult.setSearchShardTarget(searchShardTarget);
+        querySearchResult.setShardIndex(shardId);
+
+        FetchSearchResult fetchSearchResult = new FetchSearchResult();
+        fetchSearchResult.setShardIndex(shardId);
+        fetchSearchResult.setSearchShardTarget(searchShardTarget);
+        SearchHit[] searchHitArray = new SearchHit[] {
+            new SearchHit(4, "2", Map.of(), Map.of()),
+            new SearchHit(4, "2", Map.of(), Map.of()),
+            new SearchHit(0, "10", Map.of(), Map.of()),
+            new SearchHit(2, "1", Map.of(), Map.of()),
+            new SearchHit(4, "2", Map.of(), Map.of()),
+            new SearchHit(10, "3", Map.of(), Map.of()),
+            new SearchHit(4, "2", Map.of(), Map.of()) };
+        SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
+        fetchSearchResult.hits(searchHits);
+
+        QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
+        queryFetchSearchResult.setShardIndex(shardId);
+
+        queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);
+
+        SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class);
+        normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext);
+
+        List<QuerySearchResult> querySearchResults = queryPhaseResultConsumer.getAtomicArray()
+            .asList()
+            .stream()
+            .map(result -> result == null ? null : result.queryResult())
+            .collect(Collectors.toList());
+
+        TestUtils.assertQueryResultScores(querySearchResults);
+        verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any());
+    }
+
+    public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenSkipNormalization() {
+        NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
+            new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
+        );
+        NormalizationProcessor normalizationProcessor = new NormalizationProcessor(
+            PROCESSOR_TAG,
+            DESCRIPTION,
+            new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD),
+            new ScoreCombinationFactory().createCombination(COMBINATION_METHOD),
+            normalizationProcessorWorkflow
+        );
+
+        SearchRequest searchRequest = new SearchRequest(INDEX_NAME);
+        searchRequest.setBatchedReduceSize(4);
+        AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
+        QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(
+            searchRequest,
+            executor,
+            new NoopCircuitBreaker(CircuitBreaker.REQUEST),
+            searchPhaseController,
+            SearchProgressListener.NOOP,
+            writableRegistry(),
+            10,
+            e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> {
+                curr.addSuppressed(prev);
+                return curr;
+            })
+        );
+        CountDownLatch partialReduceLatch = new CountDownLatch(5);
+        int shardId = 0;
+        SearchShardTarget searchShardTarget = new SearchShardTarget(
+            "node",
+            new ShardId("index", "uuid", shardId),
+            null,
+            OriginalIndices.NONE
+        );
+        QuerySearchResult querySearchResult = new QuerySearchResult();
+        TopDocs topDocs = new TopDocs(
+            new TotalHits(4, TotalHits.Relation.EQUAL_TO),
+            new ScoreDoc[] {
+                createStartStopElementForHybridSearchResults(4),
+                createDelimiterElementForHybridSearchResults(4),
+                new ScoreDoc(0, 0.5f),
+                new ScoreDoc(2, 0.3f),
+                new ScoreDoc(4, 0.25f),
+                new ScoreDoc(10, 0.2f),
+                createStartStopElementForHybridSearchResults(4) }
+        );
+        querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]);
+        querySearchResult.setSearchShardTarget(searchShardTarget);
+        querySearchResult.setShardIndex(shardId);
+
+        FetchSearchResult fetchSearchResult = new FetchSearchResult();
+        fetchSearchResult.setShardIndex(shardId);
+        fetchSearchResult.setSearchShardTarget(searchShardTarget);
+        SearchHit[] searchHitArray = new SearchHit[] {
+            new SearchHit(0, "10", Map.of(), Map.of()),
+            new SearchHit(2, "1", Map.of(), Map.of()),
+            new SearchHit(4, "2", Map.of(), Map.of()),
+            new SearchHit(10, "3", Map.of(), Map.of()),
+            new SearchHit(0, "10", Map.of(), Map.of()) };
+        SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(5, TotalHits.Relation.EQUAL_TO), 10);
+        fetchSearchResult.hits(searchHits);
+
+        QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
+        queryFetchSearchResult.setShardIndex(shardId);
+
+        queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);
+
+        SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class);
+        normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext);
+
+        List<QuerySearchResult> querySearchResults = queryPhaseResultConsumer.getAtomicArray()
+            .asList()
+            .stream()
+            .map(result -> result == null ? null : result.queryResult())
+            .collect(Collectors.toList());
+
+        assertNotNull(querySearchResults);
+        verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any());
+    }
 }
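
Note: the counts in these two tests line up as follows. The query-phase top docs contain 7 ScoreDoc entries: a start marker, a delimiter, four real documents, and a stop marker. The special hybrid-search elements are counted too, because unprocessedDocIds reads the raw scoreDocs array. Roughly:

    // first test:  unprocessedDocIds(query).size() == 7, fetch hits == 7 -> sizes match, workflow executes
    // second test: unprocessedDocIds(query).size() == 7, fetch hits == 5 -> mismatch, processor skips normalization

That is why the matching-size test supplies 7 SearchHit entries (including ones standing in for the marker positions), while the mismatch test supplies only 5.
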