From 8388091e5be50d0b744e9d1dbb498f80dde60bf2 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Thu, 22 Feb 2024 16:08:47 -0800 Subject: [PATCH 1/7] Added optional step param to get the workflow steps API Signed-off-by: Owais Kazi --- .../flowframework/common/CommonValue.java | 2 ++ .../rest/RestGetWorkflowStepAction.java | 26 +++++++++++-------- .../GetWorkflowStepTransportAction.java | 13 +++++++--- .../transport/WorkflowRequest.java | 2 +- .../workflow/WorkflowStepFactory.java | 12 +++++++++ 5 files changed, 39 insertions(+), 16 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index c1752a018..3ef2bb3dd 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -64,6 +64,8 @@ private CommonValue() {} public static final String VALIDATION = "validation"; /** The field name for provision workflow within a use case template*/ public static final String PROVISION_WORKFLOW = "provision"; + /** The field name for workflow steps. This field represents the name of the workflow steps to be fetched. */ + public static final String STEPS = "steps"; /* * Constants associated with plugin configuration diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java index 28a9ffac4..f62267dd2 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java @@ -11,25 +11,27 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.GetWorkflowStepAction; +import org.opensearch.flowframework.transport.WorkflowRequest; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import java.io.IOException; +import java.util.HashMap; import java.util.List; import java.util.Locale; +import java.util.Map; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.opensearch.flowframework.common.CommonValue.*; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; /** @@ -60,7 +62,9 @@ public List routes() { } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + + String[] steps = request.paramAsStringArray(STEPS, Strings.EMPTY_ARRAY); try { if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { throw new FlowFrameworkException( @@ -69,13 +73,13 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient ); } - ActionRequest request = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - }; - return channel -> client.execute(GetWorkflowStepAction.INSTANCE, request, ActionListener.wrap(response -> { + Map params = new HashMap<>(); + for (String step : steps) { + params.put(STEPS, step); + } + + WorkflowRequest workflowRequest = new WorkflowRequest(null, null, params); + return channel -> client.execute(GetWorkflowStepAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); }, exception -> { diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java index 8b4d8a001..6a643f4ec 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java @@ -11,7 +11,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; -import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; @@ -22,10 +21,13 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import java.util.ArrayList; +import java.util.List; + /** * Transport action to retrieve a workflow step json */ -public class GetWorkflowStepTransportAction extends HandledTransportAction { +public class GetWorkflowStepTransportAction extends HandledTransportAction { private final Logger logger = LogManager.getLogger(GetWorkflowStepTransportAction.class); private final WorkflowStepFactory workflowStepFactory; @@ -47,9 +49,12 @@ public GetWorkflowStepTransportAction( } @Override - protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { try { - WorkflowValidator workflowValidator = this.workflowStepFactory.getWorkflowValidator(); + + List steps = new ArrayList<>(request.getParams().values()); + WorkflowValidator workflowValidator = this.workflowStepFactory.getWorkflowValidatorByStep(steps); + listener.onResponse(new GetWorkflowStepResponse(workflowValidator)); } catch (Exception e) { logger.error("Failed to retrieve workflow step json.", e); diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 6268dab2c..5b3c3c0d8 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -64,7 +64,7 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) * @param template the use case template which describes the workflow * @param params The parameters from the REST path */ - public WorkflowRequest(String workflowId, @Nullable Template template, Map params) { + public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, Map params) { this(workflowId, template, new String[] { "all" }, true, params); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 3fb91d6c6..c15a8223c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -364,6 +364,18 @@ public WorkflowValidator getWorkflowValidator() { return new WorkflowValidator(workflowStepValidators); } + public WorkflowValidator getWorkflowValidatorByStep(List steps) { + Map workflowStepValidators = new HashMap<>(); + + for (WorkflowSteps mapping : WorkflowSteps.values()) { + if (steps.contains(mapping.getWorkflowStepName())) { + workflowStepValidators.put(mapping.getWorkflowStepName(), mapping.getWorkflowStepValidator()); + } + } + + return new WorkflowValidator(workflowStepValidators); + } + /** * Create a new instance of a {@link WorkflowStep}. * @param type The type of instance to create From c8cf57fb7d275686e48fa45f831dd726ec386946 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Thu, 22 Feb 2024 17:37:11 -0800 Subject: [PATCH 2/7] Fixed api response Signed-off-by: Owais Kazi --- .../flowframework/rest/RestGetWorkflowStepAction.java | 5 +++-- .../transport/GetWorkflowStepTransportAction.java | 11 +++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java index f62267dd2..5d676fe44 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java @@ -31,7 +31,8 @@ import java.util.Locale; import java.util.Map; -import static org.opensearch.flowframework.common.CommonValue.*; +import static org.opensearch.flowframework.common.CommonValue.STEPS; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; /** @@ -75,7 +76,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli Map params = new HashMap<>(); for (String step : steps) { - params.put(STEPS, step); + params.put(step, STEPS); } WorkflowRequest workflowRequest = new WorkflowRequest(null, null, params); diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java index 6a643f4ec..246c7425d 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java @@ -51,10 +51,13 @@ public GetWorkflowStepTransportAction( @Override protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { try { - - List steps = new ArrayList<>(request.getParams().values()); - WorkflowValidator workflowValidator = this.workflowStepFactory.getWorkflowValidatorByStep(steps); - + List steps = new ArrayList<>(request.getParams().keySet()); + WorkflowValidator workflowValidator; + if (steps.isEmpty()) { + workflowValidator = this.workflowStepFactory.getWorkflowValidator(); + } else { + workflowValidator = this.workflowStepFactory.getWorkflowValidatorByStep(steps); + } listener.onResponse(new GetWorkflowStepResponse(workflowValidator)); } catch (Exception e) { logger.error("Failed to retrieve workflow step json.", e); From be988254e39452f18b71fb91cea61e553bea0a28 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Thu, 22 Feb 2024 23:27:56 -0800 Subject: [PATCH 3/7] Added tests Signed-off-by: Owais Kazi --- .../flowframework/common/CommonValue.java | 2 +- .../rest/RestGetWorkflowStepAction.java | 6 +- .../GetWorkflowStepTransportAction.java | 4 ++ .../workflow/WorkflowStepFactory.java | 9 +++ .../rest/RestGetWorkflowStepActionTests.java | 62 +++++++++++++++++++ .../GetWorkflowStepTransportActionTests.java | 1 - 6 files changed, 79 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 3ef2bb3dd..265f0fdea 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -65,7 +65,7 @@ private CommonValue() {} /** The field name for provision workflow within a use case template*/ public static final String PROVISION_WORKFLOW = "provision"; /** The field name for workflow steps. This field represents the name of the workflow steps to be fetched. */ - public static final String STEPS = "steps"; + public static final String STEP = "step"; /* * Constants associated with plugin configuration diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java index 5d676fe44..472446961 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java @@ -31,7 +31,7 @@ import java.util.Locale; import java.util.Map; -import static org.opensearch.flowframework.common.CommonValue.STEPS; +import static org.opensearch.flowframework.common.CommonValue.STEP; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -65,7 +65,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - String[] steps = request.paramAsStringArray(STEPS, Strings.EMPTY_ARRAY); + String[] steps = request.paramAsStringArray(STEP, Strings.EMPTY_ARRAY); try { if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { throw new FlowFrameworkException( @@ -76,7 +76,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli Map params = new HashMap<>(); for (String step : steps) { - params.put(step, STEPS); + params.put(step, STEP); } WorkflowRequest workflowRequest = new WorkflowRequest(null, null, params); diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java index 246c7425d..0173452b3 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java @@ -60,6 +60,10 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener steps) { Map workflowStepValidators = new HashMap<>(); @@ -373,6 +378,10 @@ public WorkflowValidator getWorkflowValidatorByStep(List steps) { } } + if (workflowStepValidators.isEmpty()) { + throw new FlowFrameworkException("Please only use only valid step name", RestStatus.BAD_REQUEST); + } + return new WorkflowValidator(workflowStepValidators); } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java index 6854e3c49..d5c85d4a0 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java @@ -9,20 +9,32 @@ package org.opensearch.flowframework.rest; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.flowframework.common.FlowFrameworkSettings; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.WorkflowValidator; +import org.opensearch.flowframework.transport.GetWorkflowStepResponse; +import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.rest.FakeRestChannel; import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.ThreadPool; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -31,12 +43,22 @@ public class RestGetWorkflowStepActionTests extends OpenSearchTestCase { private String getPath; private FlowFrameworkSettings flowFrameworkFeatureEnabledSetting; private NodeClient nodeClient; + private WorkflowStepFactory workflowStepFactory; + private FlowFrameworkSettings flowFrameworkSettings; @Override public void setUp() throws Exception { super.setUp(); + flowFrameworkSettings = mock(FlowFrameworkSettings.class); + when(flowFrameworkSettings.isFlowFrameworkEnabled()).thenReturn(true); + this.getPath = String.format(Locale.ROOT, "%s/%s", WORKFLOW_URI, "_steps"); + ThreadPool threadPool = mock(ThreadPool.class); + MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + + this.workflowStepFactory = new WorkflowStepFactory(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings); flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkSettings.class); when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); this.restGetWorkflowStepAction = new RestGetWorkflowStepAction(flowFrameworkFeatureEnabledSetting); @@ -68,6 +90,46 @@ public void testInvalidRequestWithContent() { assertEquals("request [GET /_plugins/_flow_framework/workflow/_steps] does not support having a body", ex.getMessage()); } + public void testWorkflowSteps() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) + .withPath(this.getPath + "?step=create_connector") + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + List steps = new ArrayList<>(); + steps.add("create_connector"); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + WorkflowValidator workflowValidator = this.workflowStepFactory.getWorkflowValidatorByStep(steps); + actionListener.onResponse(new GetWorkflowStepResponse(workflowValidator)); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + restGetWorkflowStepAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.OK, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("create_connector")); + } + + public void testFailedWorkflowSteps() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) + .withPath(this.getPath + "?step=xyz") + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + List steps = new ArrayList<>(); + steps.add("xyz"); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + WorkflowValidator workflowValidator = this.workflowStepFactory.getWorkflowValidatorByStep(steps); + actionListener.onResponse(new GetWorkflowStepResponse(workflowValidator)); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + FlowFrameworkException exception = expectThrows( + FlowFrameworkException.class, + () -> restGetWorkflowStepAction.handleRequest(request, channel, nodeClient) + ); + assertEquals("Please only use only valid step name", exception.getMessage()); + } + public void testFeatureFlagNotEnabled() throws Exception { when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) diff --git a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java index 685198e3d..6a63f2b16 100644 --- a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java @@ -45,6 +45,5 @@ public void testGetWorkflowStepAction() throws IOException { ArgumentCaptor stepCaptor = ArgumentCaptor.forClass(GetWorkflowStepResponse.class); verify(listener, times(1)).onResponse(stepCaptor.capture()); - } } From bbd1d253bfc0e64b1c9c51f3cac48ff73ef2c079 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Thu, 22 Feb 2024 23:37:55 -0800 Subject: [PATCH 4/7] Added CHANGELOG Signed-off-by: Owais Kazi --- CHANGELOG.md | 1 + .../org/opensearch/flowframework/common/CommonValue.java | 2 +- .../flowframework/rest/RestGetWorkflowStepAction.java | 6 +++--- .../flowframework/rest/RestGetWorkflowStepActionTests.java | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5377c0296..bfa0f8f18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ### Enhancements - Substitute REST path or body parameters in Workflow Steps ([#525](https://github.com/opensearch-project/flow-framework/pull/525)) +- Added an optional workflow_step param to the get workflow steps API ([#538](https://github.com/opensearch-project/flow-framework/pull/538)) ### Bug Fixes ### Infrastructure diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 265f0fdea..f4bec21a7 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -65,7 +65,7 @@ private CommonValue() {} /** The field name for provision workflow within a use case template*/ public static final String PROVISION_WORKFLOW = "provision"; /** The field name for workflow steps. This field represents the name of the workflow steps to be fetched. */ - public static final String STEP = "step"; + public static final String WORKFLOW_STEP = "workflow_step"; /* * Constants associated with plugin configuration diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java index 472446961..bd665dd81 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java @@ -31,7 +31,7 @@ import java.util.Locale; import java.util.Map; -import static org.opensearch.flowframework.common.CommonValue.STEP; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -65,7 +65,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - String[] steps = request.paramAsStringArray(STEP, Strings.EMPTY_ARRAY); + String[] steps = request.paramAsStringArray(WORKFLOW_STEP, Strings.EMPTY_ARRAY); try { if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { throw new FlowFrameworkException( @@ -76,7 +76,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli Map params = new HashMap<>(); for (String step : steps) { - params.put(step, STEP); + params.put(step, WORKFLOW_STEP); } WorkflowRequest workflowRequest = new WorkflowRequest(null, null, params); diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java index d5c85d4a0..49424196f 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java @@ -92,7 +92,7 @@ public void testInvalidRequestWithContent() { public void testWorkflowSteps() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) - .withPath(this.getPath + "?step=create_connector") + .withPath(this.getPath + "?workflow_step=create_connector") .build(); FakeRestChannel channel = new FakeRestChannel(request, false, 1); From 8172bbaaf17511784b9f65b386da08413b752f33 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Fri, 23 Feb 2024 15:29:50 -0800 Subject: [PATCH 5/7] Addressed PR comments Signed-off-by: Owais Kazi --- .../rest/RestGetWorkflowStepAction.java | 12 ++++-------- .../transport/GetWorkflowStepTransportAction.java | 12 +++++++++--- .../flowframework/workflow/WorkflowStepFactory.java | 11 ++++++++--- .../rest/RestGetWorkflowStepActionTests.java | 4 ++-- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java index bd665dd81..e25daa0aa 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java @@ -13,7 +13,6 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -26,7 +25,7 @@ import org.opensearch.rest.RestRequest; import java.io.IOException; -import java.util.HashMap; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; @@ -64,8 +63,6 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - - String[] steps = request.paramAsStringArray(WORKFLOW_STEP, Strings.EMPTY_ARRAY); try { if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { throw new FlowFrameworkException( @@ -74,10 +71,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ); } - Map params = new HashMap<>(); - for (String step : steps) { - params.put(step, WORKFLOW_STEP); - } + Map params = request.hasParam(WORKFLOW_STEP) + ? Map.of(WORKFLOW_STEP, request.param(WORKFLOW_STEP)) + : Collections.emptyMap(); WorkflowRequest workflowRequest = new WorkflowRequest(null, null, params); return channel -> client.execute(GetWorkflowStepAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java index 0173452b3..260d0871c 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java @@ -15,15 +15,19 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.WorkflowValidator; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; -import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP; + /** * Transport action to retrieve a workflow step json */ @@ -51,7 +55,9 @@ public GetWorkflowStepTransportAction( @Override protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { try { - List steps = new ArrayList<>(request.getParams().keySet()); + List steps = request.getParams().size() > 0 + ? Arrays.asList(Strings.splitStringByCommaToArray(request.getParams().get(WORKFLOW_STEP))) + : Collections.emptyList(); WorkflowValidator workflowValidator; if (steps.isEmpty()) { workflowValidator = this.workflowStepFactory.getWorkflowValidator(); @@ -61,7 +67,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener steps) { Map workflowStepValidators = new HashMap<>(); + Set invalidSteps = new HashSet<>(steps); for (WorkflowSteps mapping : WorkflowSteps.values()) { - if (steps.contains(mapping.getWorkflowStepName())) { + String step = mapping.getWorkflowStepName(); + if (steps.contains(step)) { workflowStepValidators.put(mapping.getWorkflowStepName(), mapping.getWorkflowStepValidator()); + invalidSteps.remove(step); } } - if (workflowStepValidators.isEmpty()) { - throw new FlowFrameworkException("Please only use only valid step name", RestStatus.BAD_REQUEST); + if (!invalidSteps.isEmpty()) { + throw new FlowFrameworkException("Invalid step name: " + invalidSteps, RestStatus.BAD_REQUEST); } return new WorkflowValidator(workflowStepValidators); diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java index 49424196f..59df28a42 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java @@ -111,7 +111,7 @@ public void testWorkflowSteps() throws Exception { public void testFailedWorkflowSteps() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) - .withPath(this.getPath + "?step=xyz") + .withPath(this.getPath + "?workflow_step=xyz") .build(); FakeRestChannel channel = new FakeRestChannel(request, false, 1); @@ -127,7 +127,7 @@ public void testFailedWorkflowSteps() throws Exception { FlowFrameworkException.class, () -> restGetWorkflowStepAction.handleRequest(request, channel, nodeClient) ); - assertEquals("Please only use only valid step name", exception.getMessage()); + assertEquals("Invalid step name: [xyz]", exception.getMessage()); } public void testFeatureFlagNotEnabled() throws Exception { From dd6e74b741929d50b3b53c91ce2ab26bf30f9e65 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Sat, 24 Feb 2024 10:55:22 -0800 Subject: [PATCH 6/7] Added another test Signed-off-by: Owais Kazi --- .../GetWorkflowStepTransportActionTests.java | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java index 6a63f2b16..d4b362746 100644 --- a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java @@ -16,13 +16,17 @@ import org.opensearch.transport.TransportService; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import org.mockito.ArgumentCaptor; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +@SuppressWarnings("unchecked") public class GetWorkflowStepTransportActionTests extends OpenSearchTestCase { private GetWorkflowStepTransportAction getWorkflowStepTransportAction; @@ -46,4 +50,17 @@ public void testGetWorkflowStepAction() throws IOException { ArgumentCaptor stepCaptor = ArgumentCaptor.forClass(GetWorkflowStepResponse.class); verify(listener, times(1)).onResponse(stepCaptor.capture()); } + + public void testGetWorkflowStepValidator() throws IOException { + Map params = new HashMap<>(); + params.put(WORKFLOW_STEP, "create_connector, delete_model"); + + WorkflowRequest workflowRequest = new WorkflowRequest(null, null, params); + ActionListener listener = mock(ActionListener.class); + getWorkflowStepTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor stepCaptor = ArgumentCaptor.forClass(GetWorkflowStepResponse.class); + verify(listener, times(1)).onResponse(stepCaptor.capture()); + assertEquals(GetWorkflowStepResponse.class, stepCaptor.getValue().getClass()); + + } } From 3429ac3635e9ff829a76842d36e8fecdb89ba8c8 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Mon, 26 Feb 2024 10:20:41 -0800 Subject: [PATCH 7/7] Logged exception message Signed-off-by: Owais Kazi --- .../flowframework/transport/GetWorkflowStepTransportAction.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java index 260d0871c..6ec66e3f9 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java @@ -67,7 +67,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener