forked from opensearch-project/neural-search
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reciprocal Rank Fusion (RRF) normalization technique in hybrid query (o…
…pensearch-project#874) * initial commit of RRF Signed-off-by: Isaac Johnson <[email protected]> Co-authored-by: Varun Jain <[email protected]> Signed-off-by: Martin Gaievski <[email protected]>
- Loading branch information
1 parent
4875dd5
commit 4cc2141
Showing
34 changed files
with
1,168 additions
and
97 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
35 changes: 35 additions & 0 deletions
35
src/main/java/org/opensearch/neuralsearch/processor/NormalizationExecuteDTO.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.processor; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.NonNull; | ||
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; | ||
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; | ||
import org.opensearch.search.fetch.FetchSearchResult; | ||
import org.opensearch.search.query.QuerySearchResult; | ||
|
||
import java.util.List; | ||
import java.util.Optional; | ||
|
||
/** | ||
* DTO object to hold data in NormalizationProcessorWorkflow class | ||
* in NormalizationProcessorWorkflow. | ||
*/ | ||
@AllArgsConstructor | ||
@Builder | ||
@Getter | ||
public class NormalizationExecuteDTO { | ||
@NonNull | ||
private List<QuerySearchResult> querySearchResults; | ||
@NonNull | ||
private Optional<FetchSearchResult> fetchSearchResultOptional; | ||
@NonNull | ||
private ScoreNormalizationTechnique normalizationTechnique; | ||
@NonNull | ||
private ScoreCombinationTechnique combinationTechnique; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
26 changes: 26 additions & 0 deletions
26
src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.processor; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.NonNull; | ||
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* DTO object to hold data required for score normalization. | ||
*/ | ||
@AllArgsConstructor | ||
@Builder | ||
@Getter | ||
public class NormalizeScoresDTO { | ||
@NonNull | ||
private List<CompoundTopDocs> queryTopDocs; | ||
@NonNull | ||
private ScoreNormalizationTechnique normalizationTechnique; | ||
} |
139 changes: 139 additions & 0 deletions
139
src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.processor; | ||
|
||
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; | ||
|
||
import java.util.stream.Collectors; | ||
|
||
import java.util.List; | ||
import java.util.Objects; | ||
import java.util.Optional; | ||
|
||
import lombok.Getter; | ||
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; | ||
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; | ||
import org.opensearch.search.fetch.FetchSearchResult; | ||
import org.opensearch.search.query.QuerySearchResult; | ||
|
||
import org.opensearch.action.search.QueryPhaseResultConsumer; | ||
import org.opensearch.action.search.SearchPhaseContext; | ||
import org.opensearch.action.search.SearchPhaseName; | ||
import org.opensearch.action.search.SearchPhaseResults; | ||
import org.opensearch.search.SearchPhaseResult; | ||
import org.opensearch.search.internal.SearchContext; | ||
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* Processor for implementing reciprocal rank fusion technique on post | ||
* query search results. Updates query results with | ||
* normalized and combined scores for next phase (typically it's FETCH) | ||
* by using ranks from individual subqueries to calculate 'normalized' | ||
* scores before combining results from subqueries into final results | ||
*/ | ||
@Log4j2 | ||
@AllArgsConstructor | ||
public class RRFProcessor implements SearchPhaseResultsProcessor { | ||
public static final String TYPE = "score-ranker-processor"; | ||
|
||
@Getter | ||
private final String tag; | ||
@Getter | ||
private final String description; | ||
private final ScoreNormalizationTechnique normalizationTechnique; | ||
private final ScoreCombinationTechnique combinationTechnique; | ||
private final NormalizationProcessorWorkflow normalizationWorkflow; | ||
|
||
/** | ||
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage | ||
* are set as part of class constructor | ||
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution | ||
* @param searchPhaseContext {@link SearchContext} | ||
*/ | ||
@Override | ||
public <Result extends SearchPhaseResult> void process( | ||
final SearchPhaseResults<Result> searchPhaseResult, | ||
final SearchPhaseContext searchPhaseContext | ||
) { | ||
if (shouldSkipProcessor(searchPhaseResult)) { | ||
log.debug("Query results are not compatible with RRF processor"); | ||
return; | ||
} | ||
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); | ||
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult); | ||
|
||
// make data transfer object to pass in, execute will get object with 4 or 5 fields, depending | ||
// on coming from NormalizationProcessor or RRFProcessor | ||
NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() | ||
.querySearchResults(querySearchResults) | ||
.fetchSearchResultOptional(fetchSearchResult) | ||
.normalizationTechnique(normalizationTechnique) | ||
.combinationTechnique(combinationTechnique) | ||
.build(); | ||
normalizationWorkflow.execute(normalizationExecuteDTO); | ||
} | ||
|
||
@Override | ||
public SearchPhaseName getBeforePhase() { | ||
return SearchPhaseName.QUERY; | ||
} | ||
|
||
@Override | ||
public SearchPhaseName getAfterPhase() { | ||
return SearchPhaseName.FETCH; | ||
} | ||
|
||
@Override | ||
public String getType() { | ||
return TYPE; | ||
} | ||
|
||
@Override | ||
public boolean isIgnoreFailure() { | ||
return false; | ||
} | ||
|
||
private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) { | ||
if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) { | ||
return true; | ||
} | ||
|
||
return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery); | ||
} | ||
|
||
/** | ||
* Return true if results are from hybrid query. | ||
* @param searchPhaseResult | ||
* @return true if results are from hybrid query | ||
*/ | ||
private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) { | ||
// check for delimiter at the end of the score docs. | ||
return Objects.nonNull(searchPhaseResult.queryResult()) | ||
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs()) | ||
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs) | ||
&& searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs.length > 0 | ||
&& isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]); | ||
} | ||
|
||
private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults( | ||
final SearchPhaseResults<Result> results | ||
) { | ||
return results.getAtomicArray() | ||
.asList() | ||
.stream() | ||
.map(result -> result == null ? null : result.queryResult()) | ||
.collect(Collectors.toList()); | ||
} | ||
|
||
private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults( | ||
final SearchPhaseResults<Result> searchPhaseResults | ||
) { | ||
Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst(); | ||
return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult); | ||
} | ||
} |
32 changes: 32 additions & 0 deletions
32
.../java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.processor.combination; | ||
|
||
import lombok.ToString; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
import java.util.Map; | ||
|
||
@Log4j2 | ||
/** | ||
* Abstracts combination of scores based on reciprocal rank fusion algorithm | ||
*/ | ||
@ToString(onlyExplicitlyIncluded = true) | ||
public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique { | ||
@ToString.Include | ||
public static final String TECHNIQUE_NAME = "rrf"; | ||
|
||
// Not currently using weights for RRF, no need to modify or verify these params | ||
public RRFScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {} | ||
|
||
@Override | ||
public float combine(final float[] scores) { | ||
float sumScores = 0.0f; | ||
for (float score : scores) { | ||
sumScores += score; | ||
} | ||
return sumScores; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.