Skip to content

Commit

Permalink
Add support for local cache in hybrid query (#663)
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Apr 3, 2024
1 parent 50a6dcf commit cc6a6b2
Show file tree
Hide file tree
Showing 5 changed files with 405 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
- Add support for request_cache flag in hybrid query ([#663](https://github.com/opensearch-project/neural-search/pull/663))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,13 @@ private void updateOriginalFetchResults(
// 3. update original scores to normalized and combined values
// 4. order scores based on normalized and combined values
FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult);
// checking case when results are cached
boolean requestCache = Objects.nonNull(querySearchResults)
&& !querySearchResults.isEmpty()
&& Objects.nonNull(querySearchResults.get(0).getShardSearchRequest().requestCache())
&& querySearchResults.get(0).getShardSearchRequest().requestCache();

SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult, requestCache);

// create map of docId to index of search hits. This solves (2), duplicates are from
// delimiter and start/stop elements, they all have same valid doc_id. For this map
Expand Down Expand Up @@ -168,7 +174,7 @@ private void updateOriginalFetchResults(
fetchSearchResult.hits(updatedSearchHits);
}

private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult) {
private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult, final boolean requestCache) {
SearchHits searchHits = fetchSearchResult.hits();
SearchHit[] searchHitArray = searchHits.getHits();
// validate the both collections are of the same size
Expand All @@ -177,7 +183,9 @@ private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchR
"score normalization processor cannot produce final query result, fetch query phase returns empty results"
);
}
if (searchHitArray.length != docIds.size()) {
// in case of cached request results of fetch and query may be different, only restriction is
// that number of query results size is greater or equal size of fetch results
if ((!requestCache && searchHitArray.length != docIds.size()) || requestCache && docIds.size() < searchHitArray.length) {
throw new IllegalStateException(
String.format(
Locale.ROOT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.fetch.QueryFetchSearchResult;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
Expand Down Expand Up @@ -401,6 +402,9 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz

QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
queryFetchSearchResult.setShardIndex(shardId);
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
querySearchResult.setShardSearchRequest(shardSearchRequest);

queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);

Expand Down Expand Up @@ -485,6 +489,9 @@ public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail()

QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
queryFetchSearchResult.setShardIndex(shardId);
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
when(shardSearchRequest.requestCache()).thenReturn(Boolean.FALSE);
querySearchResult.setShardSearchRequest(shardSearchRequest);

queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
*/
package org.opensearch.neuralsearch.processor;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

Expand All @@ -29,6 +31,7 @@
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.test.OpenSearchTestCase;

Expand Down Expand Up @@ -156,6 +159,9 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo
);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
querySearchResult.setShardSearchRequest(shardSearchRequest);
querySearchResults.add(querySearchResult);
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(0, "10", Map.of(), Map.of()),
Expand Down Expand Up @@ -213,6 +219,9 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom
);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
querySearchResult.setShardSearchRequest(shardSearchRequest);
querySearchResults.add(querySearchResult);
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(-1, "10", Map.of(), Map.of()),
Expand All @@ -236,7 +245,7 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom
TestUtils.assertFetchResultScores(fetchSearchResult, 4);
}

public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() {
public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() {
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);
Expand Down Expand Up @@ -270,15 +279,11 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then
);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
when(shardSearchRequest.requestCache()).thenReturn(Boolean.FALSE);
querySearchResult.setShardSearchRequest(shardSearchRequest);
querySearchResults.add(querySearchResult);
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "1", Map.of(), Map.of()),
new SearchHit(-1, "2", Map.of(), Map.of()),
new SearchHit(-1, "3", Map.of(), Map.of()) };
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
SearchHits searchHits = getSearchHits();
fetchSearchResult.hits(searchHits);

expectThrows(
Expand All @@ -291,4 +296,68 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then
)
);
}

public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccessful() {
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);

List<QuerySearchResult> querySearchResults = new ArrayList<>();
FetchSearchResult fetchSearchResult = new FetchSearchResult();
int shardId = 0;
SearchShardTarget searchShardTarget = new SearchShardTarget(
"node",
new ShardId("index", "uuid", shardId),
null,
OriginalIndices.NONE
);
QuerySearchResult querySearchResult = new QuerySearchResult();
querySearchResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(
new TotalHits(4, TotalHits.Relation.EQUAL_TO),
new ScoreDoc[] {
createStartStopElementForHybridSearchResults(0),
createDelimiterElementForHybridSearchResults(0),
new ScoreDoc(0, 0.5f),
new ScoreDoc(2, 0.3f),
new ScoreDoc(4, 0.25f),
new ScoreDoc(10, 0.2f),
createStartStopElementForHybridSearchResults(0) }
),
0.5f
),
new DocValueFormat[0]
);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
querySearchResult.setShardSearchRequest(shardSearchRequest);
querySearchResults.add(querySearchResult);
SearchHits searchHits = getSearchHits();
fetchSearchResult.hits(searchHits);

normalizationProcessorWorkflow.execute(
querySearchResults,
Optional.of(fetchSearchResult),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
);

TestUtils.assertQueryResultScores(querySearchResults);
TestUtils.assertFetchResultScores(fetchSearchResult, 4);
}

private static SearchHits getSearchHits() {
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "1", Map.of(), Map.of()),
new SearchHit(-1, "2", Map.of(), Map.of()),
new SearchHit(-1, "3", Map.of(), Map.of()) };
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
return searchHits;
}
}
Loading

0 comments on commit cc6a6b2

Please sign in to comment.