diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 988fd8bf2..de3c6fc13 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -28,6 +28,8 @@ import org.opensearch.flowframework.model.ProvisioningProgress; import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -35,6 +37,8 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import java.util.List; + import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; import static org.opensearch.flowframework.util.ParseUtils.getUserContext; @@ -91,6 +95,20 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { @@ -242,4 +260,11 @@ protected void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow, })); } } + + private void validateWorkflows(Template template) throws Exception { + for (Workflow workflow : template.workflows().values()) { + List sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null); + workflowProcessSorter.validateGraph(sortedNodes); + } + } } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 87f897f02..59a1cd926 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -22,6 +22,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; @@ -30,8 +31,10 @@ import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -50,10 +53,8 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.anyInt; @@ -72,8 +73,8 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private Template template; private Client client; private ThreadPool threadPool; - private ClusterSettings clusterSettings; - private ClusterService clusterService; + private static TestThreadPool testThreadPool; + private ParseUtils parseUtils; private ThreadContext threadContext; private Settings settings; @@ -83,28 +84,39 @@ public void setUp() throws Exception { super.setUp(); threadPool = mock(ThreadPool.class); this.client = mock(Client.class); + + ClusterService clusterService = mock(ClusterService.class); settings = Settings.builder() - .put("plugins.flow_framework.max_workflows.", 2) - .put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10)) - .build(); + .put("plugins.flow_framework.max_workflows.", 2) + .put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10)) + .build(); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); final Set> settingsSet = Stream.concat( - ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) ).collect(Collectors.toSet()); - clusterSettings = new ClusterSettings(settings, settingsSet); - clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool, clusterService, settings); + + MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); + WorkflowStepFactory factory = new WorkflowStepFactory( + Settings.EMPTY, + clusterService, + client, + mlClient, + flowFrameworkIndicesHandler + ); + this.workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); + this.createWorkflowTransportAction = spy( - new CreateWorkflowTransportAction( - mock(TransportService.class), - mock(ActionFilters.class), - workflowProcessSorter, - flowFrameworkIndicesHandler, - settings, - client - ) + new CreateWorkflowTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + workflowProcessSorter, + flowFrameworkIndicesHandler, + settings, + client + ) ); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); // threadContext = mock(ThreadContext.class); @@ -123,26 +135,97 @@ public void setUp() throws Exception { Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); this.template = new Template( - "test", - "description", - "use case", - templateVersion, - compatibilityVersions, - Map.of("workflow", workflow), - Map.of(), - TestHelpers.randomUser() + "test", + "description", + "use case", + templateVersion, + compatibilityVersions, + Map.of("workflow", workflow), + Map.of(), + TestHelpers.randomUser() ); } + public void testDryRunValidation_Success() { + Template validTemplate = generateValidTemplate(); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, validTemplate, true, null, null); + createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); + } + + public void testDryRunValidation_Failed() { + + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + "create_connector", + Map.of(), + Map.ofEntries( + Map.entry("name", ""), + Map.entry("description", ""), + Map.entry("version", ""), + Map.entry("protocol", ""), + Map.entry("parameters", ""), + Map.entry("credential", ""), + Map.entry("actions", "") + ) + ); + + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_2", + "register_model", + Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + ); + + WorkflowNode deployModel = new WorkflowNode( + "workflow_step_3", + "deploy_model", + Map.ofEntries(Map.entry("workflow_step_2", "model_id")), + Map.of() + ); + + WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id()); + WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id()); + WorkflowEdge cyclicalEdge = new WorkflowEdge(deployModel.id(), createConnector.id()); + + Workflow workflow = new Workflow( + Map.of(), + List.of(createConnector, registerModel, deployModel), + List.of(edge1, edge2, cyclicalEdge) + ); + + Template cyclicalTemplate = new Template( + "test", + "description", + "use case", + Version.fromString("1.0.0"), + List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")), + Map.of("workflow", workflow), + Map.of(), + TestHelpers.randomUser() + ); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, true, null, null); + + createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("No start node detected: all nodes have a predecessor.", exceptionCaptor.getValue().getMessage()); + } + public void testMaxWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( - null, - template, - false, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + null, + template, + false, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) ); doAnswer(invocation -> { @@ -175,11 +258,11 @@ public void testFailedToCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( - null, - template, - false, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + null, + template, + false, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) ); // Bypass checkMaxWorkflows and force onResponse @@ -212,11 +295,11 @@ public void testCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( - null, - template, - false, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + null, + template, + false, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) ); // Bypass checkMaxWorkflows and force onResponse @@ -298,14 +381,17 @@ public void testUpdateWorkflow() { } public void testCreateWorkflow_withProvisionParam() { + + Template validTemplate = generateValidTemplate(); + @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( - null, - template, - true, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + null, + validTemplate, + true, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) ); // Bypass checkMaxWorkflows and force onResponse @@ -353,14 +439,16 @@ public void testCreateWorkflow_withProvisionParam() { } public void testCreateWorkflow_withFailedProvision() { + Template validTemplate = generateValidTemplate(); + @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( - null, - template, - true, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + null, + validTemplate, + true, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) ); // Bypass checkMaxWorkflows and force onResponse @@ -404,4 +492,51 @@ public void testCreateWorkflow_withFailedProvision() { verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals("failed", exceptionCaptor.getValue().getMessage()); } + + private Template generateValidTemplate() { + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + WorkflowResources.CREATE_CONNECTOR.getWorkflowStep(), + Map.of(), + Map.ofEntries( + Map.entry("name", ""), + Map.entry("description", ""), + Map.entry("version", ""), + Map.entry("protocol", ""), + Map.entry("parameters", ""), + Map.entry("credential", ""), + Map.entry("actions", "") + ) + ); + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_2", + WorkflowResources.REGISTER_REMOTE_MODEL.getWorkflowStep(), + Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + ); + WorkflowNode deployModel = new WorkflowNode( + "workflow_step_3", + WorkflowResources.DEPLOY_MODEL.getWorkflowStep(), + Map.ofEntries(Map.entry("workflow_step_2", "model_id")), + Map.of() + ); + + WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id()); + WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id()); + + Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2)); + + Template validTemplate = new Template( + "test", + "description", + "use case", + Version.fromString("1.0.0"), + List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")), + Map.of("workflow", workflow), + Map.of(), + TestHelpers.randomUser() + ); + + return validTemplate; + } }