Skip to content

Commit

Permalink
Implemented throttling for max workflows to be created
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Nov 7, 2023
1 parent c9950a8 commit 4c57be1
Show file tree
Hide file tree
Showing 15 changed files with 263 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,17 @@

import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;

/**
* An OpenSearch plugin that enables builders to innovate AI apps on OpenSearch.
*/
public class FlowFrameworkPlugin extends Plugin implements ActionPlugin {

private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting;
private ClusterService clusterService;

/**
* Instantiate this plugin.
Expand All @@ -81,6 +85,7 @@ public Collection<Object> createComponents(
Supplier<RepositoriesService> repositoriesServiceSupplier
) {
Settings settings = environment.settings();
this.clusterService = clusterService;
flowFrameworkFeatureEnabledSetting = new FlowFrameworkFeatureEnabledSetting(clusterService, settings);

MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
Expand All @@ -103,7 +108,7 @@ public List<RestHandler> getRestHandlers(
Supplier<DiscoveryNodes> nodesInCluster
) {
return ImmutableList.of(
new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting),
new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting, settings, clusterService),
new RestProvisionWorkflowAction(flowFrameworkFeatureEnabledSetting)
);
}
Expand All @@ -118,7 +123,7 @@ public List<RestHandler> getRestHandlers(

@Override
public List<Setting<?>> getSettings() {
List<Setting<?>> settings = ImmutableList.of(FlowFrameworkFeatureEnabledSetting.FLOW_FRAMEWORK_ENABLED);
List<Setting<?>> settings = ImmutableList.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT);
return settings;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,15 @@
package org.opensearch.flowframework.common;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;

/**
* Controls enabling or disabling features of this plugin
*/
public class FlowFrameworkFeatureEnabledSetting {

/** This setting enables/disables the Flow Framework REST API */
public static final Setting<Boolean> FLOW_FRAMEWORK_ENABLED = Setting.boolSetting(
"plugins.flow_framework.enabled",
false,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

private volatile Boolean isFlowFrameworkEnabled;

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.common;

import org.opensearch.common.settings.Setting;
import org.opensearch.common.unit.TimeValue;

/** The common settings of flow framework */
public class FlowFrameworkSettings {

private FlowFrameworkSettings() {}

/** The upper limit of max workflows that can be created */
public static final int MAX_WORKFLOWS_LIMIT = 34;

/** This setting sets max workflows that can be created */
public static final Setting<Integer> MAX_WORKFLOWS = Setting.intSetting(
"plugins.flow_framework.max_workflows",
0,
0,
MAX_WORKFLOWS_LIMIT,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

/** This setting sets the timeout for the request */
public static final Setting<TimeValue> WORKFLOW_REQUEST_TIMEOUT = Setting.positiveTimeSetting(
"plugins.flow_framework.request_timeout",
TimeValue.timeValueSeconds(10),
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

/** This setting enables/disables the Flow Framework REST API */
public static final Setting<Boolean> FLOW_FRAMEWORK_ENABLED = Setting.boolSetting(
"plugins.flow_framework.enabled",
false,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.rest;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.rest.BaseRestHandler;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;

/**
* Abstract action for the rest actions
*/
public abstract class AbstractWorkflowAction extends BaseRestHandler {
protected volatile TimeValue requestTimeout;
protected volatile Integer maxWorkflows;

/**
* Instantiates a new AbstractWorkflowAction
*
* @param settings Environment settings
* @param clusterService clusterService
*/
public AbstractWorkflowAction(Settings settings, ClusterService clusterService) {
this.requestTimeout = WORKFLOW_REQUEST_TIMEOUT.get(settings);
this.maxWorkflows = MAX_WORKFLOWS.get(settings);

clusterService.getClusterSettings().addSettingsUpdateConsumer(WORKFLOW_REQUEST_TIMEOUT, it -> requestTimeout = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOWS, it -> maxWorkflows = it);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
Expand All @@ -21,7 +23,6 @@
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.transport.CreateWorkflowAction;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;

Expand All @@ -32,12 +33,12 @@
import static org.opensearch.flowframework.common.CommonValue.DRY_RUN;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;

/**
* Rest Action to facilitate requests to create and update a use case template
*/
public class RestCreateWorkflowAction extends BaseRestHandler {
public class RestCreateWorkflowAction extends AbstractWorkflowAction {

private static final Logger logger = LogManager.getLogger(RestCreateWorkflowAction.class);
private static final String CREATE_WORKFLOW_ACTION = "create_workflow_action";
Expand All @@ -48,8 +49,15 @@ public class RestCreateWorkflowAction extends BaseRestHandler {
* Intantiates a new RestCreateWorkflowAction
*
* @param flowFrameworkFeatureEnabledSetting Whether this API is enabled
* @param settings Environment settings
* @param clusterService clusterService
*/
public RestCreateWorkflowAction(FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting) {
public RestCreateWorkflowAction(
FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting,
Settings settings,
ClusterService clusterService
) {
super(settings, clusterService);
this.flowFrameworkFeatureEnabledSetting = flowFrameworkFeatureEnabledSetting;
}

Expand Down Expand Up @@ -85,7 +93,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
Template template = Template.parse(request.content().utf8ToString());
boolean dryRun = request.paramAsBoolean(DRY_RUN, false);

WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, dryRun);
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, dryRun, requestTimeout, maxWorkflows);

return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;

/**
* Rest action to facilitate requests to provision a workflow from an inline defined or stored use case template
Expand Down Expand Up @@ -84,7 +84,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST);
}
// Create request and provision
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null);
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, null, null);
return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS);
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.CommonValue;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.model.ProvisioningProgress;
Expand All @@ -27,6 +32,9 @@
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -47,13 +55,15 @@ public class CreateWorkflowTransportAction extends HandledTransportAction<Workfl
private final WorkflowProcessSorter workflowProcessSorter;
private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private final Client client;
private final Settings settings;

/**
* Intantiates a new CreateWorkflowTransportAction
* @param transportService the TransportService
* @param actionFilters action filters
* @param workflowProcessSorter the workflow process sorter
* @param flowFrameworkIndicesHandler The handler for the global context index
* @param settings Environment settings
* @param client The client used to make the request to OS
*/
@Inject
Expand All @@ -62,11 +72,13 @@ public CreateWorkflowTransportAction(
ActionFilters actionFilters,
WorkflowProcessSorter workflowProcessSorter,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
Settings settings,
Client client
) {
super(CreateWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new);
this.workflowProcessSorter = workflowProcessSorter;
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
this.settings = settings;
this.client = client;
}

Expand Down Expand Up @@ -99,6 +111,22 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
}

if (request.getWorkflowId() == null) {

// Throttle incoming requests
QueryBuilder query = QueryBuilders.matchAllQuery();
TimeValue requestTimeOut = request.getRequestTimeout();
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeOut);

SearchRequest searchRequest = new SearchRequest(CommonValue.GLOBAL_CONTEXT_INDEX).source(searchSourceBuilder);

client.search(
searchRequest,
ActionListener.wrap(
response -> onSearchGlobalContext(response, listener, request.getMaxWorkflows()),
exception -> listener.onFailure(exception)
)
);

// Create new global context and state index entries
flowFrameworkIndicesHandler.putTemplateToGlobalContext(templateWithUser, ActionListener.wrap(globalContextResponse -> {
flowFrameworkIndicesHandler.putInitialStateToWorkflowState(
Expand Down Expand Up @@ -160,6 +188,21 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
}
}

/**
* Checks if the max workflows limit has been reachesd
*
* @param response response of the GC index SearchRequest
* @param listener ActionListener of the SearchRequest
* @param maxWorkflow max workflows
*/
protected void onSearchGlobalContext(SearchResponse response, ActionListener listener, Integer maxWorkflow) {
if (response.getHits().getTotalHits().value >= maxWorkflow) {
String errorMessage = "Maximum workflows limit reached" + maxWorkflow;
logger.error(errorMessage);
listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
}
}

private void validateWorkflows(Template template) throws Exception {
for (Workflow workflow : template.workflows().values()) {
List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow);
Expand Down
Loading

0 comments on commit 4c57be1

Please sign in to comment.