From 2c1401c87620d648a0f6cbabc321c932698551a1 Mon Sep 17 00:00:00 2001 From: Jackie Han Date: Thu, 7 Dec 2023 16:17:40 -0800 Subject: [PATCH 1/4] Includ workflow id and current node id in the exception message (#262) Includ workflow id and current node id in the exception message during registe agent step Signed-off-by: Jackie Han --- .../workflow/RegisterAgentStep.java | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index 0e3a1c7c6..f9f17b4f0 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -83,6 +83,8 @@ public CompletableFuture execute( Map previousNodeInputs ) throws IOException { + String workflowId = currentNodeInputs.getWorkflowId(); + CompletableFuture registerAgentModelFuture = new CompletableFuture<>(); ActionListener actionListener = new ActionListener<>() { @@ -92,7 +94,7 @@ public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId()); flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), + workflowId, currentNodeId, getName(), mlRegisterAgentResponse.getAgentId(), @@ -101,7 +103,7 @@ public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { registerAgentModelFuture.complete( new WorkflowData( Map.ofEntries(Map.entry(resourceName, mlRegisterAgentResponse.getAgentId())), - currentNodeInputs.getWorkflowId(), + workflowId, currentNodeId ) ); @@ -168,12 +170,15 @@ public void onFailure(Exception e) { // Case when modelId is not present at all if (llmModelId == null) { registerAgentModelFuture.completeExceptionally( - new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST) + new FlowFrameworkException( + "llm model id is not provided for workflow: " + workflowId + " on node: " + currentNodeId, + RestStatus.BAD_REQUEST + ) ); return registerAgentModelFuture; } - LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters); + LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters, workflowId, currentNodeId); MLAgentBuilder builder = MLAgent.builder().name(name); @@ -246,9 +251,12 @@ private String getLlmModelId(Map previousNodeInputs, Map llmParameters) { + private LLMSpec getLLMSpec(String llmModelId, Map llmParameters, String workflowId, String currentNodeId) { if (llmModelId == null) { - throw new FlowFrameworkException("model id for llm is null", RestStatus.BAD_REQUEST); + throw new FlowFrameworkException( + "model id for llm is null for workflow: " + workflowId + " on node: " + currentNodeId, + RestStatus.BAD_REQUEST + ); } LLMSpec.LLMSpecBuilder builder = LLMSpec.builder(); builder.modelId(llmModelId); From bde10f5f200a87c16d19bdde006eac6fe7f8b57c Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Mon, 11 Dec 2023 16:04:00 -0800 Subject: [PATCH 2/4] Change thread queue to 100 and fix headers parsing bug (#265) change thread queue to 100 and fix headers bug Signed-off-by: Amit Galitzky --- .../flowframework/FlowFrameworkPlugin.java | 2 +- .../flowframework/model/WorkflowNode.java | 8 ++-- .../flowframework/util/ParseUtils.java | 43 +++++++++++++++++++ .../flowframework/util/ParseUtilsTests.java | 10 +++++ 4 files changed, 59 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 14df7e17e..a1c75043d 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -158,7 +158,7 @@ public List> getExecutorBuilders(Settings settings) { settings, PROVISION_THREAD_POOL, OpenSearchExecutors.allocatedProcessors(settings), - 10, + 100, FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_THREAD_POOL ) ); diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 42d59e07f..706cd2c62 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -24,7 +24,9 @@ import java.util.Objects; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap; import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; +import static org.opensearch.flowframework.util.ParseUtils.parseStringToObjectMap; import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap; /** @@ -93,7 +95,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } } else { for (Map map : (Map[]) e.getValue()) { - buildStringToStringMap(xContentBuilder, map); + buildStringToObjectMap(xContentBuilder, map); } } xContentBuilder.endArray(); @@ -150,9 +152,9 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { } userInputs.put(inputFieldName, processorList.toArray(new PipelineProcessor[0])); } else { - List> mapList = new ArrayList<>(); + List> mapList = new ArrayList<>(); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - mapList.add(parseStringToStringMap(parser)); + mapList.add(parseStringToObjectMap(parser)); } userInputs.put(inputFieldName, mapList.toArray(new Map[0])); } diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 9e3b8d067..6e1a506a1 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -84,6 +84,25 @@ public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map map) throws IOException { + xContentBuilder.startObject(); + for (Entry e : map.entrySet()) { + if (e.getValue() instanceof String) { + xContentBuilder.field((String) e.getKey(), (String) e.getValue()); + } else { + xContentBuilder.field((String) e.getKey(), e.getValue()); + } + } + xContentBuilder.endObject(); + } + /** * Builds an XContent object representing a LLMSpec. * @@ -117,6 +136,30 @@ public static Map parseStringToStringMap(XContentParser parser) return map; } + /** + * Parses an XContent object representing a map of String keys to Object values. + * The Object value here can either be a string or a map + * @param parser An XContent parser whose position is at the start of the map object to parse + * @return A map as identified by the key-value pairs in the XContent + * @throws IOException on a parse failure + */ + public static Map parseStringToObjectMap(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + Map map = new HashMap<>(); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + if (parser.currentToken() == XContentParser.Token.START_OBJECT) { + // If the current token is a START_OBJECT, parse it as Map + map.put(fieldName, parseStringToStringMap(parser)); + } else { + // Otherwise, parse it as a string + map.put(fieldName, parser.text()); + } + } + return map; + } + /** * Parse content parser to {@link java.time.Instant}. * diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index 94fe7b01e..02222b9aa 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -60,6 +60,16 @@ public void testToInstantWithNotValue() throws IOException { assertNull(instant); } + public void testBuildAndParseStringToStringMap() throws IOException { + Map stringMap = Map.ofEntries(Map.entry("one", "two")); + XContentBuilder builder = XContentFactory.jsonBuilder(); + ParseUtils.buildStringToStringMap(builder, stringMap); + XContentParser parser = this.createParser(builder); + parser.nextToken(); + Map parsedMap = ParseUtils.parseStringToStringMap(parser); + assertEquals(stringMap.get("one"), parsedMap.get("one")); + } + public void testGetInputsFromPreviousSteps() { WorkflowData currentNodeInputs = new WorkflowData( Map.ofEntries(Map.entry("content1", 1), Map.entry("param1", 2), Map.entry("content3", "${{step1.output1}}")), From fa96284c0a0fa6b129be2e40b66b3cfdb1cab604 Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Mon, 11 Dec 2023 17:46:56 -0800 Subject: [PATCH 3/4] Update resources_created with deploy model: (#275) add deploy model resource Signed-off-by: Amit Galitzky --- .../AbstractRetryableWorkflowStep.java | 62 ++++++++----------- 1 file changed, 27 insertions(+), 35 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java index f807c752a..121f477bb 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java @@ -20,7 +20,6 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import java.util.Map; @@ -59,24 +58,6 @@ public AbstractRetryableWorkflowStep( this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } - /** - * Completes the future for either deploy or register local model step - * @param resourceName resource name for the given step - * @param nodeId node ID of the given step - * @param workflowId workflow ID of the given workflow - * @param response Response from ml commons get Task API - * @param future CompletableFuture of the given step - */ - public void completeFuture(String resourceName, String nodeId, String workflowId, MLTask response, CompletableFuture future) { - future.complete( - new WorkflowData( - Map.ofEntries(Map.entry(resourceName, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())), - workflowId, - nodeId - ) - ); - } - /** * Retryable get ml task * @param workflowId the workflow id @@ -110,25 +91,36 @@ void retryableGetMlTask( try { logger.info(workflowStep + " successful for {} and modelId {}", workflowId, response.getModelId()); String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + String id; if (getName().equals(WorkflowResources.DEPLOY_MODEL.getWorkflowStep())) { - completeFuture(resourceName, nodeId, workflowId, response, future); + id = response.getModelId(); } else { - flowFrameworkIndicesHandler.updateResourceInStateIndex( - workflowId, - nodeId, - getName(), - response.getTaskId(), - ActionListener.wrap(updateResponse -> { - logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - completeFuture(resourceName, nodeId, workflowId, response, future); - }, exception -> { - logger.error("Failed to update new created resource", exception); - future.completeExceptionally( - new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) - ); - }) - ); + id = response.getTaskId(); } + flowFrameworkIndicesHandler.updateResourceInStateIndex( + workflowId, + nodeId, + getName(), + id, + ActionListener.wrap(updateResponse -> { + logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); + future.complete( + new WorkflowData( + Map.ofEntries( + Map.entry(resourceName, response.getModelId()), + Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) + ), + workflowId, + nodeId + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + future.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); } catch (Exception e) { logger.error("Failed to parse and update new created resource", e); future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); From 7bec6e8eb89cc45a1447b209e9cae2bd58dd84e5 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Mon, 11 Dec 2023 19:04:56 -0800 Subject: [PATCH 4/4] Add setting to limit max workflow steps (#266) * Rename repo name in code files and CI (#182) Signed-off-by: owaiskazi19 * Update actions/setup-java action to v4 (#219) Signed-off-by: mend-for-github.aaakk.us.kg[bot] Co-authored-by: mend-for-github.aaakk.us.kg[bot] <50673670+mend-for-github.aaakk.us.kg[bot]@users.noreply.github.com> * Use only pluginZip publication of Apache Maven artifacts (#226) * Use only pluginZip publication of Apache Maven artifacts Signed-off-by: Andriy Redko * Address code review comments Signed-off-by: Andriy Redko * Address code review comments Signed-off-by: Andriy Redko --------- Signed-off-by: Andriy Redko * Integration test infrastructure set up (#230) * Initial integ test framework modification, sets up integration test cluster and fixes ./gradlew run Signed-off-by: Joshua Palis * spotless Signed-off-by: Joshua Palis * Updating DEVELOPER_GUIDE Signed-off-by: Joshua Palis --------- Signed-off-by: Joshua Palis * Update dependency com.diffplug.spotless:spotless-plugin-gradle to v6.23.2 (#229) Signed-off-by: mend-for-github.aaakk.us.kg[bot] Co-authored-by: mend-for-github.aaakk.us.kg[bot] <50673670+mend-for-github.aaakk.us.kg[bot]@users.noreply.github.com> * Update to Gradle 8.5 (#227) Signed-off-by: Andriy Redko Co-authored-by: Owais Kazi * Update dependency com.diffplug.spotless:spotless-plugin-gradle to v6.23.3 (#252) Signed-off-by: mend-for-github.aaakk.us.kg[bot] Co-authored-by: mend-for-github.aaakk.us.kg[bot] <50673670+mend-for-github.aaakk.us.kg[bot]@users.noreply.github.com> * Update dependency org.eclipse.platform:org.eclipse.core.runtime to v3.30.0 (#255) Signed-off-by: mend-for-github.aaakk.us.kg[bot] Co-authored-by: mend-for-github.aaakk.us.kg[bot] <50673670+mend-for-github.aaakk.us.kg[bot]@users.noreply.github.com> * Add setting to limit max workflow steps Signed-off-by: Daniel Widdis --------- Signed-off-by: owaiskazi19 Signed-off-by: mend-for-github.aaakk.us.kg[bot] Signed-off-by: Andriy Redko Signed-off-by: Joshua Palis Signed-off-by: Daniel Widdis Co-authored-by: Owais Kazi Co-authored-by: mend-for-github.aaakk.us.kg[bot] <50673670+mend-for-github.aaakk.us.kg[bot]@users.noreply.github.com> Co-authored-by: Andriy Redko Co-authored-by: Joshua Palis --- .../flowframework/FlowFrameworkPlugin.java | 4 ++- .../common/FlowFrameworkSettings.java | 12 ++++++++ .../workflow/WorkflowProcessSorter.java | 29 ++++++++++++++++++- .../FlowFrameworkPluginTests.java | 5 ++-- .../CreateWorkflowTransportActionTests.java | 22 ++++++++++++-- .../workflow/WorkflowProcessSorterTests.java | 25 ++++++++++++++-- 6 files changed, 88 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index a1c75043d..513984c68 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -63,6 +63,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; /** @@ -106,7 +107,7 @@ public Collection createComponents( mlClient, flowFrameworkIndicesHandler ); - WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); + WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool, clusterService, settings); return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler); } @@ -144,6 +145,7 @@ public List> getSettings() { List> settings = ImmutableList.of( FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, + MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY ); diff --git a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java index 1824197e8..536fa2c73 100644 --- a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java +++ b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java @@ -18,6 +18,8 @@ private FlowFrameworkSettings() {} /** The upper limit of max workflows that can be created */ public static final int MAX_WORKFLOWS_LIMIT = 10000; + /** The upper limit of max workflow steps that can be in a single workflow */ + public static final int MAX_WORKFLOW_STEPS_LIMIT = 500; /** This setting sets max workflows that can be created */ public static final Setting MAX_WORKFLOWS = Setting.intSetting( @@ -29,6 +31,16 @@ private FlowFrameworkSettings() {} Setting.Property.Dynamic ); + /** This setting sets max workflows that can be created */ + public static final Setting MAX_WORKFLOW_STEPS = Setting.intSetting( + "plugins.flow_framework.max_workflow_steps", + 50, + 1, + MAX_WORKFLOW_STEPS_LIMIT, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + /** This setting sets the timeout for the request */ public static final Setting WORKFLOW_REQUEST_TIMEOUT = Setting.positiveTimeSetting( "plugins.flow_framework.request_timeout", diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 3e8b77f9d..da362383b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -10,6 +10,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -32,6 +34,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_FIELD; import static org.opensearch.flowframework.model.WorkflowNode.USER_INPUTS_FIELD; @@ -45,16 +48,26 @@ public class WorkflowProcessSorter { private WorkflowStepFactory workflowStepFactory; private ThreadPool threadPool; + private Integer maxWorkflowSteps; /** * Instantiate this class. * * @param workflowStepFactory The factory which matches template step types to instances. * @param threadPool The OpenSearch Thread pool to pass to process nodes. + * @param clusterService The OpenSearch cluster service. + * @param settings OpenSerch settings */ - public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool threadPool) { + public WorkflowProcessSorter( + WorkflowStepFactory workflowStepFactory, + ThreadPool threadPool, + ClusterService clusterService, + Settings settings + ) { this.workflowStepFactory = workflowStepFactory; this.threadPool = threadPool; + this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it); } /** @@ -64,6 +77,20 @@ public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool * @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list. */ public List sortProcessNodes(Workflow workflow, String workflowId) { + if (workflow.nodes().size() > this.maxWorkflowSteps) { + throw new FlowFrameworkException( + "Workflow " + + workflowId + + " has " + + workflow.nodes().size() + + " nodes, which exceeds the maximum of " + + this.maxWorkflowSteps + + ". Change the setting [" + + MAX_WORKFLOW_STEPS.getKey() + + "] to increase this.", + RestStatus.BAD_REQUEST + ); + } List sortedNodes = topologicalSort(workflow.nodes(), workflow.edges()); List nodes = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index e3827e0b3..2585ffb09 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -29,6 +29,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -62,7 +63,7 @@ public void setUp() throws Exception { final Set> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) ).collect(Collectors.toSet()); clusterSettings = new ClusterSettings(settings, settingsSet); clusterService = mock(ClusterService.class); @@ -84,7 +85,7 @@ public void testPlugin() throws IOException { assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); assertEquals(4, ffp.getActions().size()); assertEquals(1, ffp.getExecutorBuilders(settings).size()); - assertEquals(4, ffp.getSettings().size()); + assertEquals(5, ffp.getSettings().size()); } } } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 70c066c0e..9b664b729 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -13,6 +13,9 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; @@ -34,18 +37,24 @@ import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.mockito.ArgumentCaptor; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -61,6 +70,8 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private Template template; private Client client = mock(Client.class); private ThreadPool threadPool; + private ClusterSettings clusterSettings; + private ClusterService clusterService; private ParseUtils parseUtils; private ThreadContext threadContext; private Settings settings; @@ -73,8 +84,15 @@ public void setUp() throws Exception { .put("plugins.flow_framework.max_workflows.", 2) .put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10)) .build(); + final Set> settingsSet = Stream.concat( + ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + ).collect(Collectors.toSet()); + clusterSettings = new ClusterSettings(settings, settingsSet); + clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool); + this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool, clusterService, settings); this.createWorkflowTransportAction = spy( new CreateWorkflowTransportAction( mock(TransportService.class), diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 8103f4fbf..d1590acd8 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -16,6 +16,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.TemplateTestJsonUtil; @@ -32,6 +33,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -41,6 +43,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.edge; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.node; @@ -79,11 +82,12 @@ public static void setup() { MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + Settings settings = Settings.builder().put("plugins.flow_framework.max_workflow_steps", 5).build(); final Set> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) ).collect(Collectors.toSet()); - ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet); + ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(client.admin()).thenReturn(adminClient); @@ -96,7 +100,7 @@ public static void setup() { mlClient, flowFrameworkIndicesHandler ); - workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool); + workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, settings); } @AfterClass @@ -245,6 +249,21 @@ public void testExceptions() throws IOException { ex = assertThrows(FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("A")), Collections.emptyList()))); assertEquals("Duplicate node id A.", ex.getMessage()); assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); + + ex = assertThrows( + FlowFrameworkException.class, + () -> parse(workflow(List.of(node("A"), node("B"), node("C"), node("D"), node("E"), node("F")), Collections.emptyList())) + ); + String message = String.format( + Locale.ROOT, + "Workflow %s has %d nodes, which exceeds the maximum of %d. Change the setting [%s] to increase this.", + "123", + 6, + 5, + FlowFrameworkSettings.MAX_WORKFLOW_STEPS.getKey() + ); + assertEquals(message, ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); } public void testSuccessfulGraphValidation() throws Exception {