From 2cae54c419335dab1802b4c36f033cc171146734 Mon Sep 17 00:00:00 2001 From: Jackie Han Date: Thu, 1 Feb 2024 12:29:38 -0800 Subject: [PATCH] rank result Signed-off-by: Jackie Han --- .../FlowFrameworkResponseProcessor.java | 38 +++++++++++++++---- .../workflow/RankSearchResultStep.java | 31 +++++++++++++++ 2 files changed, 61 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/processor/FlowFrameworkResponseProcessor.java b/src/main/java/org/opensearch/flowframework/processor/FlowFrameworkResponseProcessor.java index 7262f45f3..759ea6708 100644 --- a/src/main/java/org/opensearch/flowframework/processor/FlowFrameworkResponseProcessor.java +++ b/src/main/java/org/opensearch/flowframework/processor/FlowFrameworkResponseProcessor.java @@ -6,6 +6,7 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.model.Template; @@ -13,11 +14,17 @@ import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.transport.GetWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.flowframework.workflow.RankSearchResultStep; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.pipeline.AbstractProcessor; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.search.profile.SearchProfileShardResults; import java.util.Map; +import java.util.concurrent.TimeUnit; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; @@ -57,18 +64,33 @@ protected FlowFrameworkResponseProcessor(String tag, */ @Override public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { - System.out.println("id: " + workflowId); - - GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); - GetResponse getResponse = client.get(getRequest).get(); - System.out.println("source: " + getResponse.getSourceAsString()); + SearchHits hits = response.getHits(); + if (hits.getHits().length == 0) { + logger.info("TotalHits = 0. Returning search response without applying flow framework processor"); + return response; + } + GetResponse getResponse = client.get(new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId)).get(); Template template = Template.parse(getResponse.getSourceAsString()); Workflow searchWorkflow = template.workflows().get(SEARCH_WORKFLOW); WorkflowNode workflowNode = searchWorkflow.nodes().get(0); - System.out.println("node: " + workflowNode.type()); - System.out.println("template: " + template.toJson()); + logger.info("Executing search workflow step: " + workflowNode.type()); + RankSearchResultStep rankSearchResultStep = new RankSearchResultStep(); + long startTime = System.nanoTime(); + SearchHits rankedSearchHits = rankSearchResultStep.rankAsec(hits); + long rankedTimeTookMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime); + + + final SearchResponseSections transformedSearchResponseSections = new InternalSearchResponse(rankedSearchHits, + (InternalAggregations) response.getAggregations(), response.getSuggest(), + new SearchProfileShardResults(response.getProfileResults()), response.isTimedOut(), + response.isTerminatedEarly(), response.getNumReducePhases()); + + final SearchResponse transformedResponse = new SearchResponse(transformedSearchResponseSections, response.getScrollId(), + response.getTotalShards(), response.getSuccessfulShards(), + response.getSkippedShards(), response.getTook().getMillis() + rankedTimeTookMs, response.getShardFailures(), + response.getClusters()); - return response; + return transformedResponse; } /** diff --git a/src/main/java/org/opensearch/flowframework/workflow/RankSearchResultStep.java b/src/main/java/org/opensearch/flowframework/workflow/RankSearchResultStep.java index 564d1a7c0..25c946e32 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RankSearchResultStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RankSearchResultStep.java @@ -1,14 +1,21 @@ package org.opensearch.flowframework.workflow; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import java.util.Arrays; +import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; public class RankSearchResultStep implements WorkflowStep { public RankSearchResultStep() {} public static final String NAME = "rank_search_result"; + /** * Triggers the actual processing of the building block. * @@ -28,6 +35,30 @@ public PlainActionFuture execute(String currentNodeId, return future; } + public SearchHits rankAsec(SearchHits hits) { + List originalHits = Arrays.asList(hits.getHits()); + List sortedHits = originalHits.stream() + .sorted((hit1, hit2) -> { + // Extract the "aa" field values as Double from the source maps of hit1 and hit2 + Double aa1 = toDouble(hit1.getSourceAsMap().get("aa")); + Double aa2 = toDouble(hit2.getSourceAsMap().get("aa")); + + // Use Double.compare to compare the "aa" field values, handling nulls + return Double.compare(aa1 != null ? aa1 : Double.MIN_VALUE, aa2 != null ? aa2 : Double.MIN_VALUE); + }) + .collect(Collectors.toList()); + SearchHits sortedSearchHits = new SearchHits(sortedHits.toArray(new SearchHit[0]), hits.getTotalHits(),hits.getMaxScore()); + return sortedSearchHits; + } + + // Utility method to safely convert an object to Double + private Double toDouble(Object obj) { + if (obj instanceof Number) { + return ((Number) obj).doubleValue(); + } + return null; + } + /** * Gets the name of the workflow step. *