Skip to content

Commit

Permalink
rank result
Browse files Browse the repository at this point in the history
Signed-off-by: Jackie Han <[email protected]>
  • Loading branch information
jackiehanyang committed Feb 1, 2024
1 parent 84c6a68 commit 2cae54c
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,25 @@
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.model.WorkflowNode;
import org.opensearch.flowframework.transport.GetWorkflowAction;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.flowframework.workflow.RankSearchResultStep;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.profile.SearchProfileShardResults;

import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
Expand Down Expand Up @@ -57,18 +64,33 @@ protected FlowFrameworkResponseProcessor(String tag,
*/
@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
System.out.println("id: " + workflowId);

GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId);
GetResponse getResponse = client.get(getRequest).get();
System.out.println("source: " + getResponse.getSourceAsString());
SearchHits hits = response.getHits();
if (hits.getHits().length == 0) {
logger.info("TotalHits = 0. Returning search response without applying flow framework processor");
return response;
}
GetResponse getResponse = client.get(new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId)).get();
Template template = Template.parse(getResponse.getSourceAsString());
Workflow searchWorkflow = template.workflows().get(SEARCH_WORKFLOW);
WorkflowNode workflowNode = searchWorkflow.nodes().get(0);
System.out.println("node: " + workflowNode.type());
System.out.println("template: " + template.toJson());
logger.info("Executing search workflow step: " + workflowNode.type());
RankSearchResultStep rankSearchResultStep = new RankSearchResultStep();
long startTime = System.nanoTime();
SearchHits rankedSearchHits = rankSearchResultStep.rankAsec(hits);
long rankedTimeTookMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime);


final SearchResponseSections transformedSearchResponseSections = new InternalSearchResponse(rankedSearchHits,
(InternalAggregations) response.getAggregations(), response.getSuggest(),
new SearchProfileShardResults(response.getProfileResults()), response.isTimedOut(),
response.isTerminatedEarly(), response.getNumReducePhases());

final SearchResponse transformedResponse = new SearchResponse(transformedSearchResponseSections, response.getScrollId(),
response.getTotalShards(), response.getSuccessfulShards(),
response.getSkippedShards(), response.getTook().getMillis() + rankedTimeTookMs, response.getShardFailures(),
response.getClusters());

return response;
return transformedResponse;
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
package org.opensearch.flowframework.workflow;

import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

public class RankSearchResultStep implements WorkflowStep {

public RankSearchResultStep() {}

public static final String NAME = "rank_search_result";

/**
* Triggers the actual processing of the building block.
*
Expand All @@ -28,6 +35,30 @@ public PlainActionFuture<WorkflowData> execute(String currentNodeId,
return future;
}

public SearchHits rankAsec(SearchHits hits) {
List<SearchHit> originalHits = Arrays.asList(hits.getHits());
List<SearchHit> sortedHits = originalHits.stream()
.sorted((hit1, hit2) -> {
// Extract the "aa" field values as Double from the source maps of hit1 and hit2
Double aa1 = toDouble(hit1.getSourceAsMap().get("aa"));
Double aa2 = toDouble(hit2.getSourceAsMap().get("aa"));

// Use Double.compare to compare the "aa" field values, handling nulls
return Double.compare(aa1 != null ? aa1 : Double.MIN_VALUE, aa2 != null ? aa2 : Double.MIN_VALUE);
})
.collect(Collectors.toList());
SearchHits sortedSearchHits = new SearchHits(sortedHits.toArray(new SearchHit[0]), hits.getTotalHits(),hits.getMaxScore());
return sortedSearchHits;
}

// Utility method to safely convert an object to Double
private Double toDouble(Object obj) {
if (obj instanceof Number) {
return ((Number) obj).doubleValue();
}
return null;
}

/**
* Gets the name of the workflow step.
*
Expand Down

0 comments on commit 2cae54c

Please sign in to comment.