diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 59a1cd926..09df66169 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -34,7 +34,6 @@ import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -53,6 +52,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -71,10 +71,10 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; private WorkflowProcessSorter workflowProcessSorter; private Template template; - private Client client; + private Client client = mock(Client.class); private ThreadPool threadPool; - private static TestThreadPool testThreadPool; - + private ClusterSettings clusterSettings; + private ClusterService clusterService; private ParseUtils parseUtils; private ThreadContext threadContext; private Settings settings; @@ -83,41 +83,39 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); threadPool = mock(ThreadPool.class); - this.client = mock(Client.class); - - ClusterService clusterService = mock(ClusterService.class); settings = Settings.builder() - .put("plugins.flow_framework.max_workflows.", 2) - .put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10)) - .build(); - this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + .put("plugins.flow_framework.max_workflows.", 2) + .put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10)) + .build(); final Set> settingsSet = Stream.concat( - ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) ).collect(Collectors.toSet()); - ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet); + clusterSettings = new ClusterSettings(settings, settingsSet); + clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); WorkflowStepFactory factory = new WorkflowStepFactory( - Settings.EMPTY, - clusterService, - client, - mlClient, - flowFrameworkIndicesHandler + Settings.EMPTY, + clusterService, + client, + mlClient, + flowFrameworkIndicesHandler ); - this.workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); - + this.workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool, clusterService, settings); this.createWorkflowTransportAction = spy( - new CreateWorkflowTransportAction( - mock(TransportService.class), - mock(ActionFilters.class), - workflowProcessSorter, - flowFrameworkIndicesHandler, - settings, - client - ) + new CreateWorkflowTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + workflowProcessSorter, + flowFrameworkIndicesHandler, + settings, + client + ) ); + // client = mock(Client.class); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); // threadContext = mock(ThreadContext.class); when(client.threadPool()).thenReturn(threadPool); @@ -135,14 +133,14 @@ public void setUp() throws Exception { Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); this.template = new Template( - "test", - "description", - "use case", - templateVersion, - compatibilityVersions, - Map.of("workflow", workflow), - Map.of(), - TestHelpers.randomUser() + "test", + "description", + "use case", + templateVersion, + compatibilityVersions, + Map.of("workflow", workflow), + Map.of(), + TestHelpers.randomUser() ); } @@ -158,32 +156,32 @@ public void testDryRunValidation_Success() { public void testDryRunValidation_Failed() { WorkflowNode createConnector = new WorkflowNode( - "workflow_step_1", - "create_connector", - Map.of(), - Map.ofEntries( - Map.entry("name", ""), - Map.entry("description", ""), - Map.entry("version", ""), - Map.entry("protocol", ""), - Map.entry("parameters", ""), - Map.entry("credential", ""), - Map.entry("actions", "") - ) + "workflow_step_1", + "create_connector", + Map.of(), + Map.ofEntries( + Map.entry("name", ""), + Map.entry("description", ""), + Map.entry("version", ""), + Map.entry("protocol", ""), + Map.entry("parameters", ""), + Map.entry("credential", ""), + Map.entry("actions", "") + ) ); WorkflowNode registerModel = new WorkflowNode( - "workflow_step_2", - "register_model", - Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), - Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + "workflow_step_2", + "register_model", + Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) ); WorkflowNode deployModel = new WorkflowNode( - "workflow_step_3", - "deploy_model", - Map.ofEntries(Map.entry("workflow_step_2", "model_id")), - Map.of() + "workflow_step_3", + "deploy_model", + Map.ofEntries(Map.entry("workflow_step_2", "model_id")), + Map.of() ); WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id()); @@ -191,20 +189,20 @@ public void testDryRunValidation_Failed() { WorkflowEdge cyclicalEdge = new WorkflowEdge(deployModel.id(), createConnector.id()); Workflow workflow = new Workflow( - Map.of(), - List.of(createConnector, registerModel, deployModel), - List.of(edge1, edge2, cyclicalEdge) + Map.of(), + List.of(createConnector, registerModel, deployModel), + List.of(edge1, edge2, cyclicalEdge) ); Template cyclicalTemplate = new Template( - "test", - "description", - "use case", - Version.fromString("1.0.0"), - List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")), - Map.of("workflow", workflow), - Map.of(), - TestHelpers.randomUser() + "test", + "description", + "use case", + Version.fromString("1.0.0"), + List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")), + Map.of("workflow", workflow), + Map.of(), + TestHelpers.randomUser() ); @SuppressWarnings("unchecked") @@ -221,11 +219,11 @@ public void testMaxWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( - null, - template, - false, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + null, + template, + false, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) ); doAnswer(invocation -> { @@ -258,11 +256,11 @@ public void testFailedToCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( - null, - template, - false, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + null, + template, + false, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) ); // Bypass checkMaxWorkflows and force onResponse @@ -295,11 +293,11 @@ public void testCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( - null, - template, - false, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + null, + template, + false, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) ); // Bypass checkMaxWorkflows and force onResponse @@ -387,11 +385,11 @@ public void testCreateWorkflow_withProvisionParam() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( - null, - validTemplate, - true, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + null, + validTemplate, + true, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) ); // Bypass checkMaxWorkflows and force onResponse @@ -444,11 +442,11 @@ public void testCreateWorkflow_withFailedProvision() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( - null, - validTemplate, - true, - WORKFLOW_REQUEST_TIMEOUT.get(settings), - MAX_WORKFLOWS.get(settings) + null, + validTemplate, + true, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) ); // Bypass checkMaxWorkflows and force onResponse @@ -495,30 +493,30 @@ public void testCreateWorkflow_withFailedProvision() { private Template generateValidTemplate() { WorkflowNode createConnector = new WorkflowNode( - "workflow_step_1", - WorkflowResources.CREATE_CONNECTOR.getWorkflowStep(), - Map.of(), - Map.ofEntries( - Map.entry("name", ""), - Map.entry("description", ""), - Map.entry("version", ""), - Map.entry("protocol", ""), - Map.entry("parameters", ""), - Map.entry("credential", ""), - Map.entry("actions", "") - ) + "workflow_step_1", + WorkflowResources.CREATE_CONNECTOR.getWorkflowStep(), + Map.of(), + Map.ofEntries( + Map.entry("name", ""), + Map.entry("description", ""), + Map.entry("version", ""), + Map.entry("protocol", ""), + Map.entry("parameters", ""), + Map.entry("credential", ""), + Map.entry("actions", "") + ) ); WorkflowNode registerModel = new WorkflowNode( - "workflow_step_2", - WorkflowResources.REGISTER_REMOTE_MODEL.getWorkflowStep(), - Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), - Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + "workflow_step_2", + WorkflowResources.REGISTER_REMOTE_MODEL.getWorkflowStep(), + Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) ); WorkflowNode deployModel = new WorkflowNode( - "workflow_step_3", - WorkflowResources.DEPLOY_MODEL.getWorkflowStep(), - Map.ofEntries(Map.entry("workflow_step_2", "model_id")), - Map.of() + "workflow_step_3", + WorkflowResources.DEPLOY_MODEL.getWorkflowStep(), + Map.ofEntries(Map.entry("workflow_step_2", "model_id")), + Map.of() ); WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id()); @@ -527,14 +525,14 @@ private Template generateValidTemplate() { Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2)); Template validTemplate = new Template( - "test", - "description", - "use case", - Version.fromString("1.0.0"), - List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")), - Map.of("workflow", workflow), - Map.of(), - TestHelpers.randomUser() + "test", + "description", + "use case", + Version.fromString("1.0.0"), + List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")), + Map.of("workflow", workflow), + Map.of(), + TestHelpers.randomUser() ); return validTemplate;