Skip to content

Commit

Permalink
Changed format for hybrid query results to a single list of scores wi…
Browse files Browse the repository at this point in the history
…th delimiter (#259) (#267)

* Changed approach for storing hybrid query results from compound top docs to signle list of scores with delimiter

Signed-off-by: Martin Gaievski <[email protected]>
(cherry picked from commit 75b59cd)

Co-authored-by: Martin Gaievski <[email protected]>
  • Loading branch information
1 parent 185050a commit b4e4535
Show file tree
Hide file tree
Showing 26 changed files with 902 additions and 385 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* Added Score Normalization and Combination feature ([#241](https://github.com/opensearch-project/neural-search/pull/241/))
### Enhancements
* Changed format for hybrid query results to a single list of scores with delimiter ([#259](https://github.com/opensearch-project/neural-search/pull/259))
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import lombok.extern.log4j.Log4j2;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

/**
* Class stores collection of TopDocs for each sub query from hybrid query. Collection of results is at shard level. We do store
* list of TopDocs and list of ScoreDoc as well as total hits for the shard.
*/
@AllArgsConstructor
@Getter
@ToString(includeFieldNames = true)
@Log4j2
public class CompoundTopDocs {

@Setter
private TotalHits totalHits;
private List<TopDocs> topDocs;
@Setter
private List<ScoreDoc> scoreDocs;

public CompoundTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
initialize(totalHits, topDocs);
}

private void initialize(TotalHits totalHits, List<TopDocs> topDocs) {
this.totalHits = totalHits;
this.topDocs = topDocs;
scoreDocs = cloneLargestScoreDocs(topDocs);
}

/**
* Create new instance from TopDocs by parsing scores of sub-queries. Final format looks like:
* doc_id | magic_number_1
* doc_id | magic_number_2
* ...
* doc_id | magic_number_2
* ...
* doc_id | magic_number_2
* ...
* doc_id | magic_number_1
*
* where doc_id is one of valid ids from result. For example, this is list with results for there sub-queries
*
* 0, 9549511920.4881596047
* 0, 4422440593.9791198149
* 0, 0.8
* 2, 0.5
* 0, 4422440593.9791198149
* 0, 4422440593.9791198149
* 2, 0.7
* 5, 0.65
* 6, 0.15
* 0, 9549511920.4881596047
*/
public CompoundTopDocs(final TopDocs topDocs) {
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
if (Objects.isNull(scoreDocs) || scoreDocs.length < 2) {
initialize(topDocs.totalHits, new ArrayList<>());
return;
}
// skipping first two elements, it's a start-stop element and delimiter for first series
List<TopDocs> topDocsList = new ArrayList<>();
List<ScoreDoc> scoreDocList = new ArrayList<>();
for (int index = 2; index < scoreDocs.length; index++) {
// getting first element of score's series
ScoreDoc scoreDoc = scoreDocs[index];
if (isHybridQueryDelimiterElement(scoreDoc) || isHybridQueryStartStopElement(scoreDoc)) {
ScoreDoc[] subQueryScores = scoreDocList.toArray(new ScoreDoc[0]);
TotalHits totalHits = new TotalHits(subQueryScores.length, TotalHits.Relation.EQUAL_TO);
TopDocs subQueryTopDocs = new TopDocs(totalHits, subQueryScores);
topDocsList.add(subQueryTopDocs);
scoreDocList.clear();
} else {
scoreDocList.add(scoreDoc);
}
}
initialize(topDocs.totalHits, topDocsList);
}

private List<ScoreDoc> cloneLargestScoreDocs(final List<TopDocs> docs) {
if (docs == null) {
return null;
}
ScoreDoc[] maxScoreDocs = new ScoreDoc[0];
int maxLength = -1;
for (TopDocs topDoc : docs) {
if (topDoc == null || topDoc.scoreDocs == null) {
continue;
}
if (topDoc.scoreDocs.length > maxLength) {
maxLength = topDoc.scoreDocs.length;
maxScoreDocs = topDoc.scoreDocs;
}
}
// do deep copy
return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.neuralsearch.processor;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
Expand All @@ -19,8 +21,8 @@
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QuerySearchResult;
Expand Down Expand Up @@ -56,7 +58,8 @@ public <Result extends SearchPhaseResult> void process(
return;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, normalizationTechnique, combinationTechnique);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique);
}

@Override
Expand Down Expand Up @@ -95,19 +98,21 @@ private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPha
}

QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult;
Optional<SearchPhaseResult> optionalSearchPhaseResult = queryPhaseResultConsumer.getAtomicArray()
.asList()
.stream()
.filter(Objects::nonNull)
.findFirst();
return isNotHybridQuery(optionalSearchPhaseResult);
return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery);
}

private boolean isNotHybridQuery(final Optional<SearchPhaseResult> optionalSearchPhaseResult) {
return optionalSearchPhaseResult.isEmpty()
|| Objects.isNull(optionalSearchPhaseResult.get().queryResult())
|| Objects.isNull(optionalSearchPhaseResult.get().queryResult().topDocs())
|| !(optionalSearchPhaseResult.get().queryResult().topDocs().topDocs instanceof CompoundTopDocs);
/**
* Return true if results are from hybrid query.
* @param searchPhaseResult
* @return true if results are from hybrid query
*/
private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
// check for delimiter at the end of the score docs.
return Objects.nonNull(searchPhaseResult.queryResult())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs)
&& searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs.length > 0
&& isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
}

private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
Expand All @@ -119,4 +124,11 @@ private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhase
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}

private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
final SearchPhaseResults<Result> searchPhaseResults
) {
Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,28 @@

package org.opensearch.neuralsearch.processor;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

/**
Expand All @@ -39,6 +48,7 @@ public class NormalizationProcessorWorkflow {
*/
public void execute(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
) {
Expand All @@ -57,6 +67,7 @@ public void execute(
// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(querySearchResults, queryTopDocs);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional);
}

/**
Expand All @@ -67,22 +78,87 @@ public void execute(
private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> querySearchResults) {
List<CompoundTopDocs> queryTopDocs = querySearchResults.stream()
.filter(searchResult -> Objects.nonNull(searchResult.topDocs()))
.filter(searchResult -> searchResult.topDocs().topDocs instanceof CompoundTopDocs)
.map(searchResult -> (CompoundTopDocs) searchResult.topDocs().topDocs)
.map(querySearchResult -> querySearchResult.topDocs().topDocs)
.map(CompoundTopDocs::new)
.collect(Collectors.toList());
if (queryTopDocs.size() != querySearchResults.size()) {
throw new IllegalStateException(
String.format(
Locale.ROOT,
"query results were not formatted correctly by the hybrid query; sizes of querySearchResults [%d] and queryTopDocs [%d] must match",
querySearchResults.size(),
queryTopDocs.size()
)
);
}
return queryTopDocs;
}

private void updateOriginalQueryResults(final List<QuerySearchResult> querySearchResults, final List<CompoundTopDocs> queryTopDocs) {
for (int i = 0; i < querySearchResults.size(); i++) {
QuerySearchResult querySearchResult = querySearchResults.get(i);
if (!(querySearchResult.topDocs().topDocs instanceof CompoundTopDocs) || Objects.isNull(queryTopDocs.get(i))) {
continue;
}
CompoundTopDocs updatedTopDocs = queryTopDocs.get(i);
float maxScore = updatedTopDocs.totalHits.value > 0 ? updatedTopDocs.scoreDocs[0].score : 0.0f;
TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(updatedTopDocs, maxScore);
if (querySearchResults.size() != queryTopDocs.size()) {
throw new IllegalStateException(
String.format(
Locale.ROOT,
"query results were not formatted correctly by the hybrid query; sizes of querySearchResults [%d] and queryTopDocs [%d] must match",
querySearchResults.size(),
queryTopDocs.size()
)
);
}
for (int index = 0; index < querySearchResults.size(); index++) {
QuerySearchResult querySearchResult = querySearchResults.get(index);
CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
float maxScore = updatedTopDocs.getTotalHits().value > 0 ? updatedTopDocs.getScoreDocs().get(0).score : 0.0f;

// create final version of top docs with all updated values
TopDocs topDocs = new TopDocs(updatedTopDocs.getTotalHits(), updatedTopDocs.getScoreDocs().toArray(new ScoreDoc[0]));

TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, maxScore);
querySearchResult.topDocs(updatedTopDocsAndMaxScore, null);
}
}

/**
* A workaround for a single shard case, fetch has happened, and we need to update both fetch and query results
*/
private void updateOriginalFetchResults(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional
) {
if (fetchSearchResultOptional.isEmpty()) {
return;
}
// fetch results have list of document content, that includes start/stop and
// delimiter elements. list is in original order from query searcher. We need to:
// 1. filter out start/stop and delimiter elements
// 2. filter out duplicates from different sub-queries
// 3. update original scores to normalized and combined values
// 4. order scores based on normalized and combined values
FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
SearchHits searchHits = fetchSearchResult.hits();

// create map of docId to index of search hits. This solves (2), duplicates are from
// delimiter and start/stop elements, they all have same valid doc_id. For this map
// we use doc_id as a key, and all those special elements are collapsed into a single
// key-value pair.
Map<Integer, SearchHit> docIdToSearchHit = Arrays.stream(searchHits.getHits())
.collect(Collectors.toMap(SearchHit::docId, Function.identity(), (a1, a2) -> a1));

QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;
// iterate over the normalized/combined scores, that solves (1) and (3)
SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
// get fetched hit content by doc_id
SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
// update score to normalized/combined value (3)
searchHit.score(scoreDoc.score);
return searchHit;
}).toArray(SearchHit[]::new);
SearchHits updatedSearchHits = new SearchHits(
updatedSearchHitArray,
querySearchResult.getTotalHits(),
querySearchResult.getMaxScore()
);
fetchSearchResult.hits(updatedSearchHits);
}
}
Loading

0 comments on commit b4e4535

Please sign in to comment.