Skip to content

Commit

Permalink
Fixed exception in Hybrid Query for one shard and multiple node (#396)
Browse files Browse the repository at this point in the history
* Use list of original doc ids for fetch results

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Oct 6, 2023
1 parent 2c79d4d commit 8e98167
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333))
### Enhancements
### Bug Fixes
Fixed exception in Hybrid Query for one shard and multiple node ([#396](https://github.com/opensearch-project/neural-search/pull/396))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
package org.opensearch.neuralsearch.processor;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
Expand Down Expand Up @@ -52,6 +52,9 @@ public void execute(
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
) {
// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

// pre-process data
log.debug("Pre-process query results");
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);
Expand All @@ -67,7 +70,7 @@ public void execute(
// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(querySearchResults, queryTopDocs);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds);
}

/**
Expand Down Expand Up @@ -123,7 +126,8 @@ private void updateOriginalQueryResults(final List<QuerySearchResult> querySearc
*/
private void updateOriginalFetchResults(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional
final Optional<FetchSearchResult> fetchSearchResultOptional,
final List<Integer> docIds
) {
if (fetchSearchResultOptional.isEmpty()) {
return;
Expand All @@ -135,14 +139,17 @@ private void updateOriginalFetchResults(
// 3. update original scores to normalized and combined values
// 4. order scores based on normalized and combined values
FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
SearchHits searchHits = fetchSearchResult.hits();
SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult);

// create map of docId to index of search hits. This solves (2), duplicates are from
// delimiter and start/stop elements, they all have same valid doc_id. For this map
// we use doc_id as a key, and all those special elements are collapsed into a single
// key-value pair.
Map<Integer, SearchHit> docIdToSearchHit = Arrays.stream(searchHits.getHits())
.collect(Collectors.toMap(SearchHit::docId, Function.identity(), (a1, a2) -> a1));
Map<Integer, SearchHit> docIdToSearchHit = new HashMap<>();
for (int i = 0; i < searchHitArray.length; i++) {
int originalDocId = docIds.get(i);
docIdToSearchHit.put(originalDocId, searchHitArray[i]);
}

QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;
Expand All @@ -161,4 +168,23 @@ private void updateOriginalFetchResults(
);
fetchSearchResult.hits(updatedSearchHits);
}

private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult) {
SearchHits searchHits = fetchSearchResult.hits();
SearchHit[] searchHitArray = searchHits.getHits();
// validate the both collections are of the same size
if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) {
throw new IllegalStateException("Score normalization processor cannot produce final query result");
}
return searchHitArray;
}

private List<Integer> unprocessedDocIds(final List<QuerySearchResult> querySearchResults) {
List<Integer> docIds = querySearchResults.isEmpty()
? List.of()
: Arrays.stream(querySearchResults.get(0).topDocs().topDocs.scoreDocs)
.map(scoreDoc -> scoreDoc.doc)
.collect(Collectors.toList());
return docIds;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,117 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo
TestUtils.assertQueryResultScores(querySearchResults);
TestUtils.assertFetchResultScores(fetchSearchResult, 4);
}

public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCombination() {
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);

List<QuerySearchResult> querySearchResults = new ArrayList<>();
FetchSearchResult fetchSearchResult = new FetchSearchResult();
int shardId = 0;
SearchShardTarget searchShardTarget = new SearchShardTarget(
"node",
new ShardId("index", "uuid", shardId),
null,
OriginalIndices.NONE
);
QuerySearchResult querySearchResult = new QuerySearchResult();
querySearchResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(
new TotalHits(4, TotalHits.Relation.EQUAL_TO),
new ScoreDoc[] {
createStartStopElementForHybridSearchResults(0),
createDelimiterElementForHybridSearchResults(0),
new ScoreDoc(0, 0.5f),
new ScoreDoc(2, 0.3f),
new ScoreDoc(4, 0.25f),
new ScoreDoc(10, 0.2f),
createStartStopElementForHybridSearchResults(0) }
),
0.5f
),
new DocValueFormat[0]
);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);
querySearchResults.add(querySearchResult);
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "1", Map.of(), Map.of()),
new SearchHit(-1, "2", Map.of(), Map.of()),
new SearchHit(-1, "3", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()), };
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
fetchSearchResult.hits(searchHits);

normalizationProcessorWorkflow.execute(
querySearchResults,
Optional.of(fetchSearchResult),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
);

TestUtils.assertQueryResultScores(querySearchResults);
TestUtils.assertFetchResultScores(fetchSearchResult, 4);
}

public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() {
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);

List<QuerySearchResult> querySearchResults = new ArrayList<>();
FetchSearchResult fetchSearchResult = new FetchSearchResult();
int shardId = 0;
SearchShardTarget searchShardTarget = new SearchShardTarget(
"node",
new ShardId("index", "uuid", shardId),
null,
OriginalIndices.NONE
);
QuerySearchResult querySearchResult = new QuerySearchResult();
querySearchResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(
new TotalHits(4, TotalHits.Relation.EQUAL_TO),
new ScoreDoc[] {
createStartStopElementForHybridSearchResults(0),
createDelimiterElementForHybridSearchResults(0),
new ScoreDoc(0, 0.5f),
new ScoreDoc(2, 0.3f),
new ScoreDoc(4, 0.25f),
new ScoreDoc(10, 0.2f),
createStartStopElementForHybridSearchResults(0) }
),
0.5f
),
new DocValueFormat[0]
);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);
querySearchResults.add(querySearchResult);
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "1", Map.of(), Map.of()),
new SearchHit(-1, "2", Map.of(), Map.of()),
new SearchHit(-1, "3", Map.of(), Map.of()) };
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
fetchSearchResult.hits(searchHits);

expectThrows(
IllegalStateException.class,
() -> normalizationProcessorWorkflow.execute(
querySearchResults,
Optional.of(fetchSearchResult),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
)
);
}
}

0 comments on commit 8e98167

Please sign in to comment.