Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explainability in hybrid query #970

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
### Enhancements
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.TextChunkingProcessor;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.ExplanationResponseProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
Expand Down Expand Up @@ -80,6 +82,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin,
private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();
public static final String EXPLANATION_RESPONSE_KEY = "explanation_response";

@Override
public Collection<Object> createComponents(
Expand Down Expand Up @@ -181,7 +184,9 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRespon
) {
return Map.of(
RerankProcessor.TYPE,
new RerankProcessorFactory(clientAccessor, parameters.searchPipelineService.getClusterService())
new RerankProcessorFactory(clientAccessor, parameters.searchPipelineService.getClusterService()),
ExplanationResponseProcessor.TYPE,
new ExplanationResponseProcessorFactory()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import lombok.Setter;
import lombok.ToString;
import lombok.extern.log4j.Log4j2;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.query.QuerySearchResult;

/**
* Class stores collection of TopDocs for each sub query from hybrid query. Collection of results is at shard level. We do store
Expand All @@ -37,15 +39,23 @@ public class CompoundTopDocs {
private List<TopDocs> topDocs;
@Setter
private List<ScoreDoc> scoreDocs;
@Getter
private SearchShard searchShard;

public CompoundTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs, final boolean isSortEnabled) {
initialize(totalHits, topDocs, isSortEnabled);
public CompoundTopDocs(
final TotalHits totalHits,
final List<TopDocs> topDocs,
final boolean isSortEnabled,
final SearchShard searchShard
) {
initialize(totalHits, topDocs, isSortEnabled, searchShard);
}

private void initialize(TotalHits totalHits, List<TopDocs> topDocs, boolean isSortEnabled) {
private void initialize(TotalHits totalHits, List<TopDocs> topDocs, boolean isSortEnabled, SearchShard searchShard) {
this.totalHits = totalHits;
this.topDocs = topDocs;
scoreDocs = cloneLargestScoreDocs(topDocs, isSortEnabled);
this.searchShard = searchShard;
}

/**
Expand All @@ -72,14 +82,17 @@ private void initialize(TotalHits totalHits, List<TopDocs> topDocs, boolean isSo
* 6, 0.15
* 0, 9549511920.4881596047
*/
public CompoundTopDocs(final TopDocs topDocs) {
public CompoundTopDocs(final QuerySearchResult querySearchResult) {
final TopDocs topDocs = querySearchResult.topDocs().topDocs;
final SearchShardTarget searchShardTarget = querySearchResult.getSearchShardTarget();
SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget);
boolean isSortEnabled = false;
if (topDocs instanceof TopFieldDocs) {
isSortEnabled = true;
}
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
if (Objects.isNull(scoreDocs) || scoreDocs.length < 2) {
initialize(topDocs.totalHits, new ArrayList<>(), isSortEnabled);
initialize(topDocs.totalHits, new ArrayList<>(), isSortEnabled, searchShard);
return;
}
// skipping first two elements, it's a start-stop element and delimiter for first series
Expand All @@ -103,7 +116,7 @@ public CompoundTopDocs(final TopDocs topDocs) {
scoreDocList.add(scoreDoc);
}
}
initialize(topDocs.totalHits, topDocsList, isSortEnabled);
initialize(topDocs.totalHits, topDocsList, isSortEnabled, searchShard);
}

private List<ScoreDoc> cloneLargestScoreDocs(final List<TopDocs> docs, boolean isSortEnabled) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.lucene.search.Explanation;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchResponseProcessor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY;
import static org.opensearch.neuralsearch.processor.explain.ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR;

/**
* Processor to add explanation details to search response
*/
@Getter
@AllArgsConstructor
public class ExplanationResponseProcessor implements SearchResponseProcessor {

public static final String TYPE = "explanation_response_processor";

private final String description;
private final String tag;
private final boolean ignoreFailure;

/**
* Add explanation details to search response if it is present in request context
*/
@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response) {
return processResponse(request, response, null);
}

/**
* Combines explanation from processor with search hits level explanations and adds it to search response
*/
@Override
public SearchResponse processResponse(
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
final SearchRequest request,
final SearchResponse response,
final PipelineProcessingContext requestContext
) {
if (Objects.isNull(requestContext)
|| (Objects.isNull(requestContext.getAttribute(EXPLANATION_RESPONSE_KEY)))
|| requestContext.getAttribute(EXPLANATION_RESPONSE_KEY) instanceof ExplanationPayload == false) {
return response;
}
// Extract explanation payload from context
ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLANATION_RESPONSE_KEY);
Map<ExplanationPayload.PayloadType, Object> explainPayload = explanationPayload.getExplainPayload();
if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
// for score normalization, processor level explanations will be sorted in scope of each shard,
// and we are merging both into a single sorted list
SearchHits searchHits = response.getHits();
SearchHit[] searchHitsArray = searchHits.getHits();
// create a map of searchShard and list of indexes of search hit objects in search hits array
// the list will keep original order of sorting as per final search results
Map<SearchShard, List<Integer>> searchHitsByShard = new HashMap<>();
// we keep index for each shard, where index is a position in searchHitsByShard list
Map<SearchShard, Integer> explainsByShardCount = new HashMap<>();
// Build initial shard mappings
for (int i = 0; i < searchHitsArray.length; i++) {
SearchHit searchHit = searchHitsArray[i];
SearchShardTarget searchShardTarget = searchHit.getShard();
SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget);
searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i);
explainsByShardCount.putIfAbsent(searchShard, -1);
}
// Process normalization details if available in correct format
if (explainPayload.get(NORMALIZATION_PROCESSOR) instanceof Map<?, ?>) {
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
@SuppressWarnings("unchecked")
Map<SearchShard, List<CombinedExplanationDetails>> combinedExplainDetails = (Map<
SearchShard,
List<CombinedExplanationDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR);
// Process each search hit to add processor level explanations
for (SearchHit searchHit : searchHitsArray) {
SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard());
int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1;
CombinedExplanationDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard);
// Extract various explanation components
Explanation queryLevelExplanation = searchHit.getExplanation();
ExplanationDetails normalizationExplanation = combinedExplainDetail.getNormalizationExplanations();
ExplanationDetails combinationExplanation = combinedExplainDetail.getCombinationExplanations();
// Create normalized explanations for each detail
Explanation[] normalizedExplanation = new Explanation[queryLevelExplanation.getDetails().length];
for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) {
normalizedExplanation[i] = Explanation.match(
// normalized score
normalizationExplanation.getScoreDetails().get(i).getKey(),
// description of normalized score
normalizationExplanation.getScoreDetails().get(i).getValue(),
// shard level details
queryLevelExplanation.getDetails()[i]
);
}
// Create and set final explanation combining all components
Explanation finalExplanation = Explanation.match(
searchHit.getScore(),
// combination level explanation is always a single detail
combinationExplanation.getScoreDetails().get(0).getValue(),
normalizedExplanation
);
searchHit.explanation(finalExplanation);
explainsByShardCount.put(searchShard, explanationIndexByShard);
}
}
}
return response;
}

@Override
public String getType() {
return TYPE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QuerySearchResult;

Expand All @@ -43,22 +44,57 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor {

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* are set as part of class constructor. This method is called when there is no pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty());
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
* @param <Result>
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext,
final PipelineProcessingContext requestContext
) {
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
}

private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWorkflow(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
) {
if (shouldSkipProcessor(searchPhaseResult)) {
log.debug("Query results are not compatible with normalization processor");
return;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique);
boolean explain = Objects.nonNull(searchPhaseContext.getRequest().source().explain())
&& searchPhaseContext.getRequest().source().explain();
NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResult)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.explain(explain)
.pipelineProcessingContext(requestContextOptional.orElse(null))
.build();
normalizationWorkflow.execute(request);
}

@Override
Expand Down
Loading
Loading