From 56ccb1d047f3f11f088139e8f97f3197f6f602a5 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Mon, 13 Nov 2023 09:36:08 -0800 Subject: [PATCH] Implemented throttling for max workflows to be created (#151) --- .../flowframework/FlowFrameworkPlugin.java | 9 +- .../FlowFrameworkFeatureEnabledSetting.java | 11 +- .../common/FlowFrameworkSettings.java | 47 ++++++ .../indices/FlowFrameworkIndicesHandler.java | 10 +- .../rest/AbstractSearchWorkflowAction.java | 7 +- .../rest/AbstractWorkflowAction.java | 42 ++++++ .../rest/RestCreateWorkflowAction.java | 18 ++- .../rest/RestProvisionWorkflowAction.java | 2 +- .../CreateWorkflowTransportAction.java | 91 +++++++++--- .../ProvisionWorkflowTransportAction.java | 2 - .../transport/WorkflowRequest.java | 63 +++++++- .../flowframework/util/RestHandlerUtils.java | 1 + .../FlowFrameworkPluginTests.java | 10 +- .../opensearch/flowframework/TestHelpers.java | 17 +++ ...owFrameworkFeatureEnabledSettingTests.java | 2 +- .../rest/RestCreateWorkflowActionTests.java | 20 ++- .../CreateWorkflowTransportActionTests.java | 140 +++++++++++++----- 17 files changed, 400 insertions(+), 92 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java create mode 100644 src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index c52bd2fef..f44b72495 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -56,6 +56,9 @@ import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; /** * An OpenSearch plugin that enables builders to innovate AI apps on OpenSearch. @@ -63,6 +66,7 @@ public class FlowFrameworkPlugin extends Plugin implements ActionPlugin { private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + private ClusterService clusterService; /** * Instantiate this plugin. @@ -84,6 +88,7 @@ public Collection createComponents( Supplier repositoriesServiceSupplier ) { Settings settings = environment.settings(); + this.clusterService = clusterService; flowFrameworkFeatureEnabledSetting = new FlowFrameworkFeatureEnabledSetting(clusterService, settings); MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); @@ -106,7 +111,7 @@ public List getRestHandlers( Supplier nodesInCluster ) { return ImmutableList.of( - new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting), + new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting, settings, clusterService), new RestProvisionWorkflowAction(flowFrameworkFeatureEnabledSetting), new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting) ); @@ -123,7 +128,7 @@ public List getRestHandlers( @Override public List> getSettings() { - List> settings = ImmutableList.of(FlowFrameworkFeatureEnabledSetting.FLOW_FRAMEWORK_ENABLED); + List> settings = ImmutableList.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT); return settings; } diff --git a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkFeatureEnabledSetting.java b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkFeatureEnabledSetting.java index f10068f5b..87f5412a8 100644 --- a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkFeatureEnabledSetting.java +++ b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkFeatureEnabledSetting.java @@ -9,22 +9,15 @@ package org.opensearch.flowframework.common; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; + /** * Controls enabling or disabling features of this plugin */ public class FlowFrameworkFeatureEnabledSetting { - /** This setting enables/disables the Flow Framework REST API */ - public static final Setting FLOW_FRAMEWORK_ENABLED = Setting.boolSetting( - "plugins.flow_framework.enabled", - false, - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); - private volatile Boolean isFlowFrameworkEnabled; /** diff --git a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java new file mode 100644 index 000000000..82bea1abf --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.common; + +import org.opensearch.common.settings.Setting; +import org.opensearch.common.unit.TimeValue; + +/** The common settings of flow framework */ +public class FlowFrameworkSettings { + + private FlowFrameworkSettings() {} + + /** The upper limit of max workflows that can be created */ + public static final int MAX_WORKFLOWS_LIMIT = 10000; + + /** This setting sets max workflows that can be created */ + public static final Setting MAX_WORKFLOWS = Setting.intSetting( + "plugins.flow_framework.max_workflows", + 1000, + 0, + MAX_WORKFLOWS_LIMIT, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** This setting sets the timeout for the request */ + public static final Setting WORKFLOW_REQUEST_TIMEOUT = Setting.positiveTimeSetting( + "plugins.flow_framework.request_timeout", + TimeValue.timeValueSeconds(10), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** This setting enables/disables the Flow Framework REST API */ + public static final Setting FLOW_FRAMEWORK_ENABLED = Setting.boolSetting( + "plugins.flow_framework.enabled", + false, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); +} diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 1b0f7c9d7..2d2008b3d 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -373,24 +373,22 @@ public void updateTemplateInGlobalContext(String documentId, Template template, /** * Updates a document in the workflow state index - * @param indexName the index that we will be updating a document of. * @param documentId the document ID * @param updatedFields the fields to update the global state index with * @param listener action listener */ public void updateFlowFrameworkSystemIndexDoc( - String indexName, String documentId, Map updatedFields, ActionListener listener ) { - if (!doesIndexExist(indexName)) { - String exceptionMessage = "Failed to update document for given workflow due to missing " + indexName + " index"; + if (!doesIndexExist(WORKFLOW_STATE_INDEX)) { + String exceptionMessage = "Failed to update document for given workflow due to missing " + WORKFLOW_STATE_INDEX + " index"; logger.error(exceptionMessage); listener.onFailure(new FlowFrameworkException(exceptionMessage, RestStatus.BAD_REQUEST)); } else { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - UpdateRequest updateRequest = new UpdateRequest(indexName, documentId); + UpdateRequest updateRequest = new UpdateRequest(WORKFLOW_STATE_INDEX, documentId); Map updatedContent = new HashMap<>(); updatedContent.putAll(updatedFields); updateRequest.doc(updatedContent); @@ -398,7 +396,7 @@ public void updateFlowFrameworkSystemIndexDoc( // TODO: decide what condition can be considered as an update conflict and add retry strategy client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { - String errorMessage = "Failed to update " + indexName + " entry : " + documentId; + String errorMessage = "Failed to update " + WORKFLOW_STATE_INDEX + " entry : " + documentId; logger.error(errorMessage, e); listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))); } diff --git a/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java index 0aed4348c..43919bb1c 100644 --- a/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java @@ -30,7 +30,7 @@ import java.util.List; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; -import static org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting.FLOW_FRAMEWORK_ENABLED; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext; /** @@ -38,10 +38,15 @@ */ public abstract class AbstractSearchWorkflowAction extends BaseRestHandler { + /** Url Paths of the routes*/ protected final List urlPaths; + /** Index on search operation needs to be performed*/ protected final String index; + /** Search class name*/ protected final Class clazz; + /** Search action type*/ protected final ActionType actionType; + /** Settings to enable FlowFramework API*/ protected final FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; /** diff --git a/src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java new file mode 100644 index 000000000..251ce750e --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/rest/AbstractWorkflowAction.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.rest.BaseRestHandler; + +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; + +/** + * Abstract action for the rest actions + */ +public abstract class AbstractWorkflowAction extends BaseRestHandler { + /** Timeout for the request*/ + protected volatile TimeValue requestTimeout; + /** Max workflows that can be created*/ + protected volatile Integer maxWorkflows; + + /** + * Instantiates a new AbstractWorkflowAction + * + * @param settings Environment settings + * @param clusterService clusterService + */ + public AbstractWorkflowAction(Settings settings, ClusterService clusterService) { + this.requestTimeout = WORKFLOW_REQUEST_TIMEOUT.get(settings); + this.maxWorkflows = MAX_WORKFLOWS.get(settings); + + clusterService.getClusterSettings().addSettingsUpdateConsumer(WORKFLOW_REQUEST_TIMEOUT, it -> requestTimeout = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOWS, it -> maxWorkflows = it); + } + +} diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index bd872b4ff..7bd5fbfbe 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -12,6 +12,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; @@ -21,7 +23,6 @@ import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; -import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; @@ -32,12 +33,12 @@ import static org.opensearch.flowframework.common.CommonValue.DRY_RUN; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; -import static org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting.FLOW_FRAMEWORK_ENABLED; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; /** * Rest Action to facilitate requests to create and update a use case template */ -public class RestCreateWorkflowAction extends BaseRestHandler { +public class RestCreateWorkflowAction extends AbstractWorkflowAction { private static final Logger logger = LogManager.getLogger(RestCreateWorkflowAction.class); private static final String CREATE_WORKFLOW_ACTION = "create_workflow_action"; @@ -48,8 +49,15 @@ public class RestCreateWorkflowAction extends BaseRestHandler { * Instantiates a new RestCreateWorkflowAction * * @param flowFrameworkFeatureEnabledSetting Whether this API is enabled + * @param settings Environment settings + * @param clusterService clusterService */ - public RestCreateWorkflowAction(FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting) { + public RestCreateWorkflowAction( + FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting, + Settings settings, + ClusterService clusterService + ) { + super(settings, clusterService); this.flowFrameworkFeatureEnabledSetting = flowFrameworkFeatureEnabledSetting; } @@ -85,7 +93,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli Template template = Template.parse(request.content().utf8ToString()); boolean dryRun = request.paramAsBoolean(DRY_RUN, false); - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, dryRun); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, dryRun, requestTimeout, maxWorkflows); return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index 81e4fb606..9e6eb4d01 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -30,7 +30,7 @@ import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; -import static org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting.FLOW_FRAMEWORK_ENABLED; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; /** * Rest action to facilitate requests to provision a workflow from an inline defined or stored use case template diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index a6b809fc8..5781bb412 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -12,13 +12,17 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.action.search.SearchRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.ProvisioningProgress; @@ -27,6 +31,9 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -34,7 +41,6 @@ import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.util.ParseUtils.getUserContext; /** @@ -47,6 +53,7 @@ public class CreateWorkflowTransportAction extends HandledTransportAction { - flowFrameworkIndicesHandler.putInitialStateToWorkflowState( - globalContextResponse.getId(), - user, - ActionListener.wrap(stateResponse -> { - logger.info("create state workflow doc"); - listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); + // Throttle incoming requests + checkMaxWorkflows(request.getRequestTimeout(), request.getMaxWorkflows(), ActionListener.wrap(max -> { + if (!max) { + String errorMessage = "Maximum workflows limit reached " + request.getMaxWorkflows(); + logger.error(errorMessage); + FlowFrameworkException ffe = new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + listener.onFailure(ffe); + return; + } else { + // Create new global context and state index entries + flowFrameworkIndicesHandler.putTemplateToGlobalContext(templateWithUser, ActionListener.wrap(globalContextResponse -> { + flowFrameworkIndicesHandler.putInitialStateToWorkflowState( + globalContextResponse.getId(), + user, + ActionListener.wrap(stateResponse -> { + logger.info("create state workflow doc"); + listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); + }, exception -> { + logger.error("Failed to save workflow state : {}", exception.getMessage()); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); + } + }) + ); }, exception -> { - logger.error("Failed to save workflow state : {}", exception.getMessage()); + logger.error("Failed to save use case template : {}", exception.getMessage()); if (exception instanceof FlowFrameworkException) { listener.onFailure(exception); } else { - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); } - }) - ); - }, exception -> { - logger.error("Failed to save use case template : {}", exception.getMessage()); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); + + })); + } + }, e -> { + logger.error("Failed to updated use case template {} : {}", request.getWorkflowId(), e.getMessage()); + if (e instanceof FlowFrameworkException) { + listener.onFailure(e); } else { - listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } - })); } else { // Update existing entry, full document replacement @@ -132,7 +160,6 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( - WORKFLOW_STATE_INDEX, request.getWorkflowId(), ImmutableMap.of(STATE_FIELD, State.NOT_STARTED, PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.NOT_STARTED), ActionListener.wrap(updateResponse -> { @@ -160,6 +187,30 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener internalListener) { + QueryBuilder query = QueryBuilders.matchAllQuery(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeOut); + + SearchRequest searchRequest = new SearchRequest(CommonValue.GLOBAL_CONTEXT_INDEX).source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + if (searchResponse.getHits().getTotalHits().value >= maxWorkflow) { + internalListener.onResponse(false); + } else { + internalListener.onResponse(true); + } + }, exception -> { + logger.error("Unable to fetch the workflows {}", exception); + internalListener.onFailure(new FlowFrameworkException("Unable to fetch the workflows", RestStatus.BAD_REQUEST)); + })); + } + private void validateWorkflows(Template template) throws Exception { for (Workflow workflow : template.workflows().values()) { List sortedNodes = workflowProcessSorter.sortProcessNodes(workflow); diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 443bbf8a6..4210effb0 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -45,7 +45,6 @@ import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; /** * Transport Action to provision a workflow from a stored use case template @@ -114,7 +113,6 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FlowFrameworkFeatureEnabledSetting.FLOW_FRAMEWORK_ENABLED) + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT) ).collect(Collectors.toSet()); clusterSettings = new ClusterSettings(settings, settingsSet); clusterService = mock(ClusterService.class); @@ -78,10 +80,10 @@ public void testPlugin() throws IOException { 3, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); - assertEquals(3, ffp.getRestHandlers(null, null, null, null, null, null, null).size()); + assertEquals(3, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); assertEquals(3, ffp.getActions().size()); assertEquals(1, ffp.getExecutorBuilders(settings).size()); - assertEquals(1, ffp.getSettings().size()); + assertEquals(3, ffp.getSettings().size()); } } } diff --git a/src/test/java/org/opensearch/flowframework/TestHelpers.java b/src/test/java/org/opensearch/flowframework/TestHelpers.java index 002b59458..9c3f8a07e 100644 --- a/src/test/java/org/opensearch/flowframework/TestHelpers.java +++ b/src/test/java/org/opensearch/flowframework/TestHelpers.java @@ -9,8 +9,16 @@ package org.opensearch.flowframework; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; import org.opensearch.commons.authuser.User; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + import static org.opensearch.test.OpenSearchTestCase.randomAlphaOfLength; public class TestHelpers { @@ -23,4 +31,13 @@ public static User randomUser() { ImmutableList.of("attribute=test") ); } + + public static ClusterSettings clusterSetting(Settings settings, Setting... setting) { + final Set> settingsSet = Stream.concat( + ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), + Sets.newHashSet(setting).stream() + ).collect(Collectors.toSet()); + ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet); + return clusterSettings; + } } diff --git a/src/test/java/org/opensearch/flowframework/common/FlowFrameworkFeatureEnabledSettingTests.java b/src/test/java/org/opensearch/flowframework/common/FlowFrameworkFeatureEnabledSettingTests.java index 9ac16c6f3..232dd71f2 100644 --- a/src/test/java/org/opensearch/flowframework/common/FlowFrameworkFeatureEnabledSettingTests.java +++ b/src/test/java/org/opensearch/flowframework/common/FlowFrameworkFeatureEnabledSettingTests.java @@ -37,7 +37,7 @@ public void setUp() throws Exception { settings = Settings.builder().build(); final Set> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FlowFrameworkFeatureEnabledSetting.FLOW_FRAMEWORK_ENABLED) + Stream.of(FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED) ).collect(Collectors.toSet()); clusterSettings = new ClusterSettings(settings, settingsSet); clusterService = mock(ClusterService.class); diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index 8a3564abe..c81498cea 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -10,6 +10,10 @@ import org.opensearch.Version; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; @@ -30,7 +34,10 @@ import java.util.Map; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; public class RestCreateWorkflowActionTests extends OpenSearchTestCase { @@ -41,11 +48,21 @@ public class RestCreateWorkflowActionTests extends OpenSearchTestCase { private String updateWorkflowPath; private NodeClient nodeClient; private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + private Settings settings; + private ClusterService clusterService; @Override public void setUp() throws Exception { super.setUp(); flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkFeatureEnabledSetting.class); + settings = Settings.builder() + .put(WORKFLOW_REQUEST_TIMEOUT.getKey(), TimeValue.timeValueMillis(10)) + .put(MAX_WORKFLOWS.getKey(), 2) + .build(); + + ClusterSettings clusterSettings = TestHelpers.clusterSetting(settings, WORKFLOW_REQUEST_TIMEOUT, MAX_WORKFLOWS); + clusterService = spy(new ClusterService(settings, clusterSettings, null)); + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); Version templateVersion = Version.fromString("1.0.0"); @@ -69,7 +86,8 @@ public void setUp() throws Exception { // Invalid template configuration, wrong field name this.invalidTemplate = template.toJson().replace("use_case", "invalid"); - this.createWorkflowRestAction = new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + this.createWorkflowRestAction = new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting, settings, clusterService); this.createWorkflowPath = String.format(Locale.ROOT, "%s", WORKFLOW_URI); this.updateWorkflowPath = String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, "workflow_id"); this.nodeClient = mock(NodeClient.class); diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index fbec8a034..18f330c98 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -11,10 +11,13 @@ import org.opensearch.Version; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.TestHelpers; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; @@ -34,9 +37,17 @@ import org.mockito.ArgumentCaptor; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -51,19 +62,27 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private ThreadPool threadPool; private ParseUtils parseUtils; private ThreadContext threadContext; + private Settings settings; @Override public void setUp() throws Exception { super.setUp(); threadPool = mock(ThreadPool.class); + settings = Settings.builder() + .put("plugins.flow_framework.max_workflows.", 2) + .put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10)) + .build(); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool); - this.createWorkflowTransportAction = new CreateWorkflowTransportAction( - mock(TransportService.class), - mock(ActionFilters.class), - workflowProcessSorter, - flowFrameworkIndicesHandler, - client + this.createWorkflowTransportAction = spy( + new CreateWorkflowTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + workflowProcessSorter, + flowFrameworkIndicesHandler, + settings, + client + ) ); // client = mock(Client.class); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); @@ -146,7 +165,7 @@ public void testFailedDryRunValidation() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, true); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, true, null, null); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -154,76 +173,118 @@ public void testFailedDryRunValidation() { assertEquals("No start node detected: all nodes have a predecessor.", exceptionCaptor.getValue().getMessage()); } - public void testFailedToCreateNewWorkflow() { + public void testMaxWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest createNewWorkflow = new WorkflowRequest(null, template); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + false, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) + ); doAnswer(invocation -> { - ActionListener responseListener = invocation.getArgument(1); - responseListener.onFailure(new Exception("Failed to create global_context index")); + ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + checkMaxWorkflowListener.onResponse(false); return null; - }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(Template.class), any()); + }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); - createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); + createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - assertEquals("Failed to create global_context index", exceptionCaptor.getValue().getMessage()); + assertEquals(("Maximum workflows limit reached 1000"), exceptionCaptor.getValue().getMessage()); } - public void testFailedToUpdateWorkflow() { + public void testFailedToCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest updateWorkflow = new WorkflowRequest("1", template); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + false, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) + ); + // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { - ActionListener responseListener = invocation.getArgument(2); - responseListener.onFailure(new Exception("Failed to update use case template")); + ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + checkMaxWorkflowListener.onResponse(true); return null; - }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(any(), any(Template.class), any()); + }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); - createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(new Exception("Failed to create global_context index")); + return null; + }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(Template.class), any()); + + createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - assertEquals("Failed to update use case template", exceptionCaptor.getValue().getMessage()); + assertEquals("Failed to create global_context index", exceptionCaptor.getValue().getMessage()); } - // TODO: Fix these unit tests, manually tested these work but mocks here are wrong - /* public void testCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - ActionListener indexListener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + false, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) + ); - WorkflowRequest createNewWorkflow = new WorkflowRequest(null, template); + // Bypass checkMaxWorkflows and force onResponse + doAnswer(invocation -> { + ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + checkMaxWorkflowListener.onResponse(true); + return null; + }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + // Bypass putTemplateToGlobalContext and force onResponse doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); return null; - }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(Template.class), any()); - - ArgumentCaptor responseCaptorStateIndex = ArgumentCaptor.forClass(IndexResponse.class); - verify(indexListener, times(1)).onResponse(responseCaptorStateIndex.capture()); + }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(), any()); + // Bypass putInitialStateToWorkflowState and force on response doAnswer(invocation -> { - ActionListener responseListener = invocation.getArgument(1); + ActionListener responseListener = invocation.getArgument(2); responseListener.onResponse(new IndexResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "1", 1L, 1L, 1L, true)); return null; - }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(responseCaptorStateIndex.getValue().getId(), null, any()); + }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any()); - createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); + ArgumentCaptor workflowResponseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); - ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); - verify(listener, times(1)).onResponse(responseCaptor.capture()); + verify(listener, times(1)).onResponse(workflowResponseCaptor.capture()); - assertEquals("1", responseCaptor.getValue().getWorkflowId()); + assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); + } + public void testFailedToUpdateWorkflow() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest updateWorkflow = new WorkflowRequest("1", template); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onFailure(new Exception("Failed to update use case template")); + return null; + }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(any(), any(Template.class), any()); + + createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to update use case template", exceptionCaptor.getValue().getMessage()); } public void testUpdateWorkflow() { - @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest updateWorkflow = new WorkflowRequest("1", template); @@ -234,11 +295,16 @@ public void testUpdateWorkflow() { return null; }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(any(), any(Template.class), any()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(2); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(anyString(), any(), any()); + createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); verify(listener, times(1)).onResponse(responseCaptor.capture()); assertEquals("1", responseCaptor.getValue().getWorkflowId()); } - */ }