diff --git a/CHANGELOG.md b/CHANGELOG.md index 5377c0296..bfa0f8f18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ### Enhancements - Substitute REST path or body parameters in Workflow Steps ([#525](https://github.com/opensearch-project/flow-framework/pull/525)) +- Added an optional workflow_step param to the get workflow steps API ([#538](https://github.com/opensearch-project/flow-framework/pull/538)) ### Bug Fixes ### Infrastructure diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index c1752a018..f4bec21a7 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -64,6 +64,8 @@ private CommonValue() {} public static final String VALIDATION = "validation"; /** The field name for provision workflow within a use case template*/ public static final String PROVISION_WORKFLOW = "provision"; + /** The field name for workflow steps. This field represents the name of the workflow steps to be fetched. */ + public static final String WORKFLOW_STEP = "workflow_step"; /* * Constants associated with plugin configuration diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java index 28a9ffac4..e25daa0aa 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java @@ -11,8 +11,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; @@ -21,14 +19,18 @@ import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.GetWorkflowStepAction; +import org.opensearch.flowframework.transport.WorkflowRequest; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Locale; +import java.util.Map; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -60,7 +62,7 @@ public List routes() { } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { try { if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { throw new FlowFrameworkException( @@ -69,13 +71,12 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient ); } - ActionRequest request = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - }; - return channel -> client.execute(GetWorkflowStepAction.INSTANCE, request, ActionListener.wrap(response -> { + Map params = request.hasParam(WORKFLOW_STEP) + ? Map.of(WORKFLOW_STEP, request.param(WORKFLOW_STEP)) + : Collections.emptyMap(); + + WorkflowRequest workflowRequest = new WorkflowRequest(null, null, params); + return channel -> client.execute(GetWorkflowStepAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); }, exception -> { diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java index 8b4d8a001..6ec66e3f9 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java @@ -11,21 +11,27 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; -import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.WorkflowValidator; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP; + /** * Transport action to retrieve a workflow step json */ -public class GetWorkflowStepTransportAction extends HandledTransportAction { +public class GetWorkflowStepTransportAction extends HandledTransportAction { private final Logger logger = LogManager.getLogger(GetWorkflowStepTransportAction.class); private final WorkflowStepFactory workflowStepFactory; @@ -47,11 +53,23 @@ public GetWorkflowStepTransportAction( } @Override - protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { try { - WorkflowValidator workflowValidator = this.workflowStepFactory.getWorkflowValidator(); + List steps = request.getParams().size() > 0 + ? Arrays.asList(Strings.splitStringByCommaToArray(request.getParams().get(WORKFLOW_STEP))) + : Collections.emptyList(); + WorkflowValidator workflowValidator; + if (steps.isEmpty()) { + workflowValidator = this.workflowStepFactory.getWorkflowValidator(); + } else { + workflowValidator = this.workflowStepFactory.getWorkflowValidatorByStep(steps); + } listener.onResponse(new GetWorkflowStepResponse(workflowValidator)); } catch (Exception e) { + if (e instanceof FlowFrameworkException) { + logger.error(e.getMessage()); + listener.onFailure(e); + } logger.error("Failed to retrieve workflow step json.", e); listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 6268dab2c..5b3c3c0d8 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -64,7 +64,7 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) * @param template the use case template which describes the workflow * @param params The parameters from the REST path */ - public WorkflowRequest(String workflowId, @Nullable Template template, Map params) { + public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, Map params) { this(workflowId, template, new String[] { "all" }, true, params); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 3fb91d6c6..6242fe926 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -23,8 +23,10 @@ import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.Supplier; import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; @@ -364,6 +366,30 @@ public WorkflowValidator getWorkflowValidator() { return new WorkflowValidator(workflowStepValidators); } + /** + * Get the object of WorkflowValidator consisting of passed workflow steps + * @param steps workflow steps + * @return WorkflowValidator + */ + public WorkflowValidator getWorkflowValidatorByStep(List steps) { + Map workflowStepValidators = new HashMap<>(); + Set invalidSteps = new HashSet<>(steps); + + for (WorkflowSteps mapping : WorkflowSteps.values()) { + String step = mapping.getWorkflowStepName(); + if (steps.contains(step)) { + workflowStepValidators.put(mapping.getWorkflowStepName(), mapping.getWorkflowStepValidator()); + invalidSteps.remove(step); + } + } + + if (!invalidSteps.isEmpty()) { + throw new FlowFrameworkException("Invalid step name: " + invalidSteps, RestStatus.BAD_REQUEST); + } + + return new WorkflowValidator(workflowStepValidators); + } + /** * Create a new instance of a {@link WorkflowStep}. * @param type The type of instance to create diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java index 6854e3c49..59df28a42 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java @@ -9,20 +9,32 @@ package org.opensearch.flowframework.rest; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.flowframework.common.FlowFrameworkSettings; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.WorkflowValidator; +import org.opensearch.flowframework.transport.GetWorkflowStepResponse; +import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.rest.FakeRestChannel; import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.ThreadPool; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -31,12 +43,22 @@ public class RestGetWorkflowStepActionTests extends OpenSearchTestCase { private String getPath; private FlowFrameworkSettings flowFrameworkFeatureEnabledSetting; private NodeClient nodeClient; + private WorkflowStepFactory workflowStepFactory; + private FlowFrameworkSettings flowFrameworkSettings; @Override public void setUp() throws Exception { super.setUp(); + flowFrameworkSettings = mock(FlowFrameworkSettings.class); + when(flowFrameworkSettings.isFlowFrameworkEnabled()).thenReturn(true); + this.getPath = String.format(Locale.ROOT, "%s/%s", WORKFLOW_URI, "_steps"); + ThreadPool threadPool = mock(ThreadPool.class); + MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + + this.workflowStepFactory = new WorkflowStepFactory(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings); flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkSettings.class); when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); this.restGetWorkflowStepAction = new RestGetWorkflowStepAction(flowFrameworkFeatureEnabledSetting); @@ -68,6 +90,46 @@ public void testInvalidRequestWithContent() { assertEquals("request [GET /_plugins/_flow_framework/workflow/_steps] does not support having a body", ex.getMessage()); } + public void testWorkflowSteps() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) + .withPath(this.getPath + "?workflow_step=create_connector") + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + List steps = new ArrayList<>(); + steps.add("create_connector"); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + WorkflowValidator workflowValidator = this.workflowStepFactory.getWorkflowValidatorByStep(steps); + actionListener.onResponse(new GetWorkflowStepResponse(workflowValidator)); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + restGetWorkflowStepAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.OK, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("create_connector")); + } + + public void testFailedWorkflowSteps() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) + .withPath(this.getPath + "?workflow_step=xyz") + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + List steps = new ArrayList<>(); + steps.add("xyz"); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + WorkflowValidator workflowValidator = this.workflowStepFactory.getWorkflowValidatorByStep(steps); + actionListener.onResponse(new GetWorkflowStepResponse(workflowValidator)); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + FlowFrameworkException exception = expectThrows( + FlowFrameworkException.class, + () -> restGetWorkflowStepAction.handleRequest(request, channel, nodeClient) + ); + assertEquals("Invalid step name: [xyz]", exception.getMessage()); + } + public void testFeatureFlagNotEnabled() throws Exception { when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) diff --git a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java index 685198e3d..d4b362746 100644 --- a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java @@ -16,13 +16,17 @@ import org.opensearch.transport.TransportService; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import org.mockito.ArgumentCaptor; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +@SuppressWarnings("unchecked") public class GetWorkflowStepTransportActionTests extends OpenSearchTestCase { private GetWorkflowStepTransportAction getWorkflowStepTransportAction; @@ -45,6 +49,18 @@ public void testGetWorkflowStepAction() throws IOException { ArgumentCaptor stepCaptor = ArgumentCaptor.forClass(GetWorkflowStepResponse.class); verify(listener, times(1)).onResponse(stepCaptor.capture()); + } + + public void testGetWorkflowStepValidator() throws IOException { + Map params = new HashMap<>(); + params.put(WORKFLOW_STEP, "create_connector, delete_model"); + + WorkflowRequest workflowRequest = new WorkflowRequest(null, null, params); + ActionListener listener = mock(ActionListener.class); + getWorkflowStepTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor stepCaptor = ArgumentCaptor.forClass(GetWorkflowStepResponse.class); + verify(listener, times(1)).onResponse(stepCaptor.capture()); + assertEquals(GetWorkflowStepResponse.class, stepCaptor.getValue().getClass()); } }