Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.13] Add support for local cache in hybrid query #675

Merged
merged 1 commit into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
- Add support for request_cache flag in hybrid query ([#663](https://github.com/opensearch-project/neural-search/pull/663))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,13 @@ private void updateOriginalFetchResults(
// 3. update original scores to normalized and combined values
// 4. order scores based on normalized and combined values
FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult);
// checking case when results are cached
boolean requestCache = Objects.nonNull(querySearchResults)
&& !querySearchResults.isEmpty()
&& Objects.nonNull(querySearchResults.get(0).getShardSearchRequest().requestCache())
&& querySearchResults.get(0).getShardSearchRequest().requestCache();

SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult, requestCache);

// create map of docId to index of search hits. This solves (2), duplicates are from
// delimiter and start/stop elements, they all have same valid doc_id. For this map
Expand Down Expand Up @@ -168,7 +174,7 @@ private void updateOriginalFetchResults(
fetchSearchResult.hits(updatedSearchHits);
}

private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult) {
private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult, final boolean requestCache) {
SearchHits searchHits = fetchSearchResult.hits();
SearchHit[] searchHitArray = searchHits.getHits();
// validate the both collections are of the same size
Expand All @@ -177,7 +183,9 @@ private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchR
"score normalization processor cannot produce final query result, fetch query phase returns empty results"
);
}
if (searchHitArray.length != docIds.size()) {
// in case of cached request results of fetch and query may be different, only restriction is
// that number of query results size is greater or equal size of fetch results
if ((!requestCache && searchHitArray.length != docIds.size()) || requestCache && docIds.size() < searchHitArray.length) {
throw new IllegalStateException(
String.format(
Locale.ROOT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.fetch.QueryFetchSearchResult;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
Expand Down Expand Up @@ -401,6 +402,9 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz

QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
queryFetchSearchResult.setShardIndex(shardId);
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
querySearchResult.setShardSearchRequest(shardSearchRequest);

queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);

Expand Down Expand Up @@ -485,6 +489,9 @@ public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail()

QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
queryFetchSearchResult.setShardIndex(shardId);
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
when(shardSearchRequest.requestCache()).thenReturn(Boolean.FALSE);
querySearchResult.setShardSearchRequest(shardSearchRequest);

queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
*/
package org.opensearch.neuralsearch.processor;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

Expand All @@ -29,6 +31,7 @@
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.test.OpenSearchTestCase;

Expand Down Expand Up @@ -156,6 +159,9 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo
);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
querySearchResult.setShardSearchRequest(shardSearchRequest);
querySearchResults.add(querySearchResult);
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(0, "10", Map.of(), Map.of()),
Expand Down Expand Up @@ -213,6 +219,9 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom
);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
querySearchResult.setShardSearchRequest(shardSearchRequest);
querySearchResults.add(querySearchResult);
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(-1, "10", Map.of(), Map.of()),
Expand All @@ -236,7 +245,7 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom
TestUtils.assertFetchResultScores(fetchSearchResult, 4);
}

public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() {
public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() {
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);
Expand Down Expand Up @@ -270,15 +279,11 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then
);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
when(shardSearchRequest.requestCache()).thenReturn(Boolean.FALSE);
querySearchResult.setShardSearchRequest(shardSearchRequest);
querySearchResults.add(querySearchResult);
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "1", Map.of(), Map.of()),
new SearchHit(-1, "2", Map.of(), Map.of()),
new SearchHit(-1, "3", Map.of(), Map.of()) };
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
SearchHits searchHits = getSearchHits();
fetchSearchResult.hits(searchHits);

expectThrows(
Expand All @@ -291,4 +296,68 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then
)
);
}

public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccessful() {
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);

List<QuerySearchResult> querySearchResults = new ArrayList<>();
FetchSearchResult fetchSearchResult = new FetchSearchResult();
int shardId = 0;
SearchShardTarget searchShardTarget = new SearchShardTarget(
"node",
new ShardId("index", "uuid", shardId),
null,
OriginalIndices.NONE
);
QuerySearchResult querySearchResult = new QuerySearchResult();
querySearchResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(
new TotalHits(4, TotalHits.Relation.EQUAL_TO),
new ScoreDoc[] {
createStartStopElementForHybridSearchResults(0),
createDelimiterElementForHybridSearchResults(0),
new ScoreDoc(0, 0.5f),
new ScoreDoc(2, 0.3f),
new ScoreDoc(4, 0.25f),
new ScoreDoc(10, 0.2f),
createStartStopElementForHybridSearchResults(0) }
),
0.5f
),
new DocValueFormat[0]
);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
querySearchResult.setShardSearchRequest(shardSearchRequest);
querySearchResults.add(querySearchResult);
SearchHits searchHits = getSearchHits();
fetchSearchResult.hits(searchHits);

normalizationProcessorWorkflow.execute(
querySearchResults,
Optional.of(fetchSearchResult),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
);

TestUtils.assertQueryResultScores(querySearchResults);
TestUtils.assertFetchResultScores(fetchSearchResult, 4);
}

private static SearchHits getSearchHits() {
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "1", Map.of(), Map.of()),
new SearchHit(-1, "2", Map.of(), Map.of()),
new SearchHit(-1, "3", Map.of(), Map.of()) };
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
return searchHits;
}
}
Loading
Loading