[Backport 2.x] Changed format for hybrid query results to a single list of scores with delimiter #266

Merged: 1 commit, merged on Aug 29, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* Added Score Normalization and Combination feature ([#241](https://github.com/opensearch-project/neural-search/pull/241/))
### Enhancements
* Changed format for hybrid query results to a single list of scores with delimiter ([#259](https://github.com/opensearch-project/neural-search/pull/259))
### Bug Fixes
### Infrastructure
### Documentation
CompoundTopDocs.java (new file)
@@ -0,0 +1,120 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import lombok.extern.log4j.Log4j2;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

/**
* Class that stores a collection of TopDocs for each sub-query from a hybrid query. The collection of results is at the shard level. We store
* the list of TopDocs and the list of ScoreDoc, as well as the total hits for the shard.
*/
@AllArgsConstructor
@Getter
@ToString(includeFieldNames = true)
@Log4j2
public class CompoundTopDocs {

@Setter
private TotalHits totalHits;
private List<TopDocs> topDocs;
@Setter
private List<ScoreDoc> scoreDocs;

public CompoundTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
initialize(totalHits, topDocs);
}

private void initialize(TotalHits totalHits, List<TopDocs> topDocs) {
this.totalHits = totalHits;
this.topDocs = topDocs;
scoreDocs = cloneLargestScoreDocs(topDocs);
}

/**
* Create a new instance from TopDocs by parsing the scores of sub-queries. The flat input format looks like:
* doc_id | magic_number_1
* doc_id | magic_number_2
* ...
* doc_id | magic_number_2
* ...
* doc_id | magic_number_2
* ...
* doc_id | magic_number_1
*
* where doc_id is one of the valid ids from the result. For example, this is a list with results for three sub-queries:
*
* 0, 9549511920.4881596047
* 0, 4422440593.9791198149
* 0, 0.8
* 2, 0.5
* 0, 4422440593.9791198149
* 0, 4422440593.9791198149
* 2, 0.7
* 5, 0.65
* 6, 0.15
* 0, 9549511920.4881596047
*/
public CompoundTopDocs(final TopDocs topDocs) {
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
if (Objects.isNull(scoreDocs) || scoreDocs.length < 2) {
initialize(topDocs.totalHits, new ArrayList<>());
return;
}
// skip the first two elements: the start-stop element and the delimiter for the first series
List<TopDocs> topDocsList = new ArrayList<>();
List<ScoreDoc> scoreDocList = new ArrayList<>();
for (int index = 2; index < scoreDocs.length; index++) {
ScoreDoc scoreDoc = scoreDocs[index];
// a delimiter or the trailing start-stop element closes the current sub-query's series of scores
if (isHybridQueryDelimiterElement(scoreDoc) || isHybridQueryStartStopElement(scoreDoc)) {
ScoreDoc[] subQueryScores = scoreDocList.toArray(new ScoreDoc[0]);
TotalHits totalHits = new TotalHits(subQueryScores.length, TotalHits.Relation.EQUAL_TO);
TopDocs subQueryTopDocs = new TopDocs(totalHits, subQueryScores);
topDocsList.add(subQueryTopDocs);
scoreDocList.clear();
} else {
scoreDocList.add(scoreDoc);
}
}
initialize(topDocs.totalHits, topDocsList);
}

private List<ScoreDoc> cloneLargestScoreDocs(final List<TopDocs> docs) {
if (docs == null) {
return null;
}
ScoreDoc[] maxScoreDocs = new ScoreDoc[0];
int maxLength = -1;
for (TopDocs topDoc : docs) {
if (topDoc == null || topDoc.scoreDocs == null) {
continue;
}
if (topDoc.scoreDocs.length > maxLength) {
maxLength = topDoc.scoreDocs.length;
maxScoreDocs = topDoc.scoreDocs;
}
}
// do deep copy
return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).collect(Collectors.toList());
}
}
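
For illustration, here is a minimal, self-contained sketch of the single-list format that `CompoundTopDocs` parses. The sentinel scores are taken from the javadoc example above; in the plugin they are produced and recognized by `HybridSearchResultFormatUtil`, so treat the constants and the class below as placeholders rather than the plugin's API.

```java
import java.util.ArrayList;
import java.util.List;

import org.apache.lucene.search.ScoreDoc;

public class HybridFormatSketch {

    // hypothetical sentinel scores, copied from the CompoundTopDocs javadoc example
    private static final float START_STOP = 9549511920.4881596047f;
    private static final float DELIMITER = 4422440593.9791198149f;

    /** Splits the flat shard-level list into one score list per sub-query. */
    static List<List<ScoreDoc>> splitBySubQuery(ScoreDoc[] flat) {
        List<List<ScoreDoc>> subQueries = new ArrayList<>();
        List<ScoreDoc> current = new ArrayList<>();
        // skip the leading start-stop element and the first delimiter
        for (int i = 2; i < flat.length; i++) {
            ScoreDoc doc = flat[i];
            if (doc.score == DELIMITER || doc.score == START_STOP) {
                // a sentinel closes the current sub-query's series of scores
                subQueries.add(current);
                current = new ArrayList<>();
            } else {
                current.add(doc);
            }
        }
        return subQueries;
    }

    public static void main(String[] args) {
        // flat list for two sub-queries: [start, delim, hits..., delim, hits..., stop]
        ScoreDoc[] flat = new ScoreDoc[] {
            new ScoreDoc(0, START_STOP),
            new ScoreDoc(0, DELIMITER),
            new ScoreDoc(0, 0.8f),
            new ScoreDoc(2, 0.5f),
            new ScoreDoc(0, DELIMITER),
            new ScoreDoc(2, 0.7f),
            new ScoreDoc(5, 0.65f),
            new ScoreDoc(0, START_STOP)
        };
        // prints two groups of ScoreDoc: docs {0, 2} and {2, 5}
        splitBySubQuery(flat).forEach(System.out::println);
    }
}
```

This mirrors the loop in the `CompoundTopDocs(TopDocs)` constructor: the first two elements are always the start-stop sentinel and the first delimiter, and every subsequent sentinel closes one sub-query's series.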
@@ -5,6 +5,8 @@

package org.opensearch.neuralsearch.processor;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
@@ -19,8 +21,8 @@
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QuerySearchResult;
@@ -56,7 +58,8 @@ public <Result extends SearchPhaseResult> void process(
return;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, normalizationTechnique, combinationTechnique);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique);
}

@Override
@@ -95,19 +98,21 @@ private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPha
}

QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult;
Optional<SearchPhaseResult> optionalSearchPhaseResult = queryPhaseResultConsumer.getAtomicArray()
.asList()
.stream()
.filter(Objects::nonNull)
.findFirst();
return isNotHybridQuery(optionalSearchPhaseResult);
return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery);
}

private boolean isNotHybridQuery(final Optional<SearchPhaseResult> optionalSearchPhaseResult) {
return optionalSearchPhaseResult.isEmpty()
|| Objects.isNull(optionalSearchPhaseResult.get().queryResult())
|| Objects.isNull(optionalSearchPhaseResult.get().queryResult().topDocs())
|| !(optionalSearchPhaseResult.get().queryResult().topDocs().topDocs instanceof CompoundTopDocs);
/**
* Return true if results are from a hybrid query.
* @param searchPhaseResult search phase result to check
* @return true if results are from a hybrid query
*/
private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
// check that the first element of the score docs is the hybrid query start-stop sentinel
return Objects.nonNull(searchPhaseResult.queryResult())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs)
&& searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs.length > 0
&& isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
}

private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
@@ -119,4 +124,11 @@ private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhase
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}

private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
final SearchPhaseResults<Result> searchPhaseResults
) {
Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult);
}
}
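
As a side note on the new detection logic: a shard result is treated as hybrid when its very first `ScoreDoc` is the start-stop sentinel, and the processor is skipped only when no shard matches. A minimal sketch of that check, with a placeholder predicate standing in for `HybridSearchResultFormatUtil.isHybridQueryStartStopElement` and the sentinel value again borrowed from the `CompoundTopDocs` javadoc example:

```java
import java.util.List;
import java.util.Objects;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;

public class HybridDetectionSketch {

    // placeholder for HybridSearchResultFormatUtil.isHybridQueryStartStopElement
    static boolean looksLikeStartStop(ScoreDoc doc) {
        return doc.score == 9549511920.4881596047f; // hypothetical sentinel value
    }

    // a shard result is hybrid when its first score doc is the start-stop sentinel
    static boolean isHybrid(TopDocs topDocs) {
        return topDocs != null
            && topDocs.scoreDocs != null
            && topDocs.scoreDocs.length > 0
            && looksLikeStartStop(topDocs.scoreDocs[0]);
    }

    // mirrors the new noneMatch-based skip check: the processor runs unless
    // no shard produced hybrid-formatted results
    static boolean shouldSkip(List<TopDocs> shardResults) {
        return shardResults.stream().filter(Objects::nonNull).noneMatch(HybridDetectionSketch::isHybrid);
    }
}
```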
NormalizationProcessorWorkflow.java
@@ -5,19 +5,28 @@

package org.opensearch.neuralsearch.processor;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

/**
@@ -39,6 +48,7 @@ public class NormalizationProcessorWorkflow {
*/
public void execute(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
) {
@@ -57,6 +67,7 @@ public void execute(
// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(querySearchResults, queryTopDocs);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional);
}

/**
@@ -67,22 +78,87 @@ public void execute(
private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> querySearchResults) {
List<CompoundTopDocs> queryTopDocs = querySearchResults.stream()
.filter(searchResult -> Objects.nonNull(searchResult.topDocs()))
.filter(searchResult -> searchResult.topDocs().topDocs instanceof CompoundTopDocs)
.map(searchResult -> (CompoundTopDocs) searchResult.topDocs().topDocs)
.map(querySearchResult -> querySearchResult.topDocs().topDocs)
.map(CompoundTopDocs::new)
.collect(Collectors.toList());
if (queryTopDocs.size() != querySearchResults.size()) {
throw new IllegalStateException(
String.format(
Locale.ROOT,
"query results were not formatted correctly by the hybrid query; sizes of querySearchResults [%d] and queryTopDocs [%d] must match",
querySearchResults.size(),
queryTopDocs.size()
)
);
}
return queryTopDocs;
}

private void updateOriginalQueryResults(final List<QuerySearchResult> querySearchResults, final List<CompoundTopDocs> queryTopDocs) {
for (int i = 0; i < querySearchResults.size(); i++) {
QuerySearchResult querySearchResult = querySearchResults.get(i);
if (!(querySearchResult.topDocs().topDocs instanceof CompoundTopDocs) || Objects.isNull(queryTopDocs.get(i))) {
continue;
}
CompoundTopDocs updatedTopDocs = queryTopDocs.get(i);
float maxScore = updatedTopDocs.totalHits.value > 0 ? updatedTopDocs.scoreDocs[0].score : 0.0f;
TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(updatedTopDocs, maxScore);
if (querySearchResults.size() != queryTopDocs.size()) {
throw new IllegalStateException(
String.format(
Locale.ROOT,
"query results were not formatted correctly by the hybrid query; sizes of querySearchResults [%d] and queryTopDocs [%d] must match",
querySearchResults.size(),
queryTopDocs.size()
)
);
}
for (int index = 0; index < querySearchResults.size(); index++) {
QuerySearchResult querySearchResult = querySearchResults.get(index);
CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
float maxScore = updatedTopDocs.getTotalHits().value > 0 ? updatedTopDocs.getScoreDocs().get(0).score : 0.0f;

// create final version of top docs with all updated values
TopDocs topDocs = new TopDocs(updatedTopDocs.getTotalHits(), updatedTopDocs.getScoreDocs().toArray(new ScoreDoc[0]));

TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, maxScore);
querySearchResult.topDocs(updatedTopDocsAndMaxScore, null);
}
}

/**
* A workaround for the single-shard case: fetch has already happened, so we need to update both fetch and query results
*/
private void updateOriginalFetchResults(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional
) {
if (fetchSearchResultOptional.isEmpty()) {
return;
}
// fetch results contain the list of document content, which includes the start/stop and
// delimiter elements. The list is in the original order from the query searcher. We need to:
// 1. filter out start/stop and delimiter elements
// 2. filter out duplicates from different sub-queries
// 3. update original scores to normalized and combined values
// 4. order scores based on normalized and combined values
FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
SearchHits searchHits = fetchSearchResult.hits();

// create a map of docId to SearchHit. This solves (2): duplicates come from the
// delimiter and start/stop elements, which all carry the same valid doc_id. Because
// doc_id is the map key, all those special elements are collapsed into a single
// key-value pair.
Map<Integer, SearchHit> docIdToSearchHit = Arrays.stream(searchHits.getHits())
.collect(Collectors.toMap(SearchHit::docId, Function.identity(), (a1, a2) -> a1));

QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;
// iterate over the normalized/combined scores; this solves (1) and (4)
SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
// get fetched hit content by doc_id
SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
// update score to normalized/combined value (3)
searchHit.score(scoreDoc.score);
return searchHit;
}).toArray(SearchHit[]::new);
SearchHits updatedSearchHits = new SearchHits(
updatedSearchHitArray,
querySearchResult.getTotalHits(),
querySearchResult.getMaxScore()
);
fetchSearchResult.hits(updatedSearchHits);
}
}
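
To make the reconciliation above concrete, here is a small self-contained sketch of steps (2) through (4): collapsing duplicate doc_ids with a merge function, then rebuilding the hit list in normalized order with updated scores. It uses a plain `Hit` class instead of OpenSearch's `SearchHit` so it stays runnable on its own; the types and values are illustrative, not the plugin's API.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.apache.lucene.search.ScoreDoc;

public class FetchReconciliationSketch {

    /** Stand-in for org.opensearch.search.SearchHit, for illustration only. */
    static final class Hit {
        final int docId;
        float score;
        Hit(int docId, float score) { this.docId = docId; this.score = score; }
        @Override public String toString() { return docId + ":" + score; }
    }

    public static void main(String[] args) {
        // fetched hits in original order, with duplicates coming from sentinel elements
        Hit[] fetched = { new Hit(0, 1f), new Hit(0, 1f), new Hit(2, 1f), new Hit(5, 1f) };

        // (2) collapse duplicates: keep the first hit seen for each doc_id
        Map<Integer, Hit> byDocId = Arrays.stream(fetched)
            .collect(Collectors.toMap(h -> h.docId, Function.identity(), (first, second) -> first));

        // normalized/combined scores, already sorted by score descending
        ScoreDoc[] normalized = { new ScoreDoc(2, 0.9f), new ScoreDoc(0, 0.6f), new ScoreDoc(5, 0.3f) };

        // (3) and (4): rebuild the hits in normalized order with updated scores
        Hit[] updated = Arrays.stream(normalized).map(scoreDoc -> {
            Hit hit = byDocId.get(scoreDoc.doc);
            hit.score = scoreDoc.score;
            return hit;
        }).toArray(Hit[]::new);

        System.out.println(Arrays.toString(updated)); // [2:0.9, 0:0.6, 5:0.3]
    }
}
```

The `(first, second) -> first` merge function is what lets `Collectors.toMap` tolerate the duplicate doc_ids that the sentinel elements introduce, which is the same trick `updateOriginalFetchResults` uses above.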