Skip to content

Commit

Permalink
Added FlowFrameworkMaxRequestRetrySetting and applied this to GetMlTa…
Browse files Browse the repository at this point in the history
…skStep

Signed-off-by: Joshua Palis <[email protected]>
  • Loading branch information
joshpalis committed Nov 20, 2023
1 parent 2dcfc9c commit 52f0c04
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting;
import org.opensearch.flowframework.common.FlowFrameworkMaxRequestRetrySetting;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.rest.RestCreateWorkflowAction;
import org.opensearch.flowframework.rest.RestGetWorkflowAction;
Expand Down Expand Up @@ -60,6 +61,7 @@
import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;

Expand All @@ -69,6 +71,7 @@
public class FlowFrameworkPlugin extends Plugin implements ActionPlugin {

private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting;

private ClusterService clusterService;

/**
Expand All @@ -93,10 +96,19 @@ public Collection<Object> createComponents(
Settings settings = environment.settings();
this.clusterService = clusterService;
flowFrameworkFeatureEnabledSetting = new FlowFrameworkFeatureEnabledSetting(clusterService, settings);

FlowFrameworkMaxRequestRetrySetting flowFrameworkMaxRequestRetrySetting = new FlowFrameworkMaxRequestRetrySetting(
clusterService,
settings
);
MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler,
flowFrameworkMaxRequestRetrySetting
);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);

return ImmutableList.of(workflowStepFactory, workflowProcessSorter, flowFrameworkIndicesHandler);
Expand Down Expand Up @@ -132,7 +144,7 @@ public List<RestHandler> getRestHandlers(

@Override
public List<Setting<?>> getSettings() {
List<Setting<?>> settings = ImmutableList.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT);
List<Setting<?>> settings = ImmutableList.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_REQUEST_RETRY);
return settings;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.common;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_REQUEST_RETRY;

/**
* Controls max request retry setting for workflow step transport action APIs
*/
public class FlowFrameworkMaxRequestRetrySetting {

protected volatile Integer maxRetry;

/**
* Instantiate this class.
*
* @param clusterService OpenSearch cluster service
* @param settings OpenSearch settings
*/
public FlowFrameworkMaxRequestRetrySetting(ClusterService clusterService, Settings settings) {
maxRetry = MAX_REQUEST_RETRY.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_REQUEST_RETRY, it -> maxRetry = it);
}

/**
* Gets the maximum number of retries
* @return the maximum number of retries
*/
public int getMaxRetries() {
return maxRetry;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,14 @@ private FlowFrameworkSettings() {}
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

/** This setting sets the maximum number of transport request retries */
public static final Setting<Integer> MAX_REQUEST_RETRY = Setting.intSetting(
"plugins.flow_framework.max_request_retry",
5,
0,
20,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -195,21 +195,25 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
* @param internalListener listener for search request
*/
protected void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow, ActionListener<Boolean> internalListener) {
QueryBuilder query = QueryBuilders.matchAllQuery();
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeOut);
if (!flowFrameworkIndicesHandler.doesIndexExist(CommonValue.GLOBAL_CONTEXT_INDEX)) {
internalListener.onResponse(true);

Check warning on line 199 in src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java#L199

Added line #L199 was not covered by tests
} else {
QueryBuilder query = QueryBuilders.matchAllQuery();
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeOut);

Check warning on line 202 in src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java#L201-L202

Added lines #L201 - L202 were not covered by tests

SearchRequest searchRequest = new SearchRequest(CommonValue.GLOBAL_CONTEXT_INDEX).source(searchSourceBuilder);
SearchRequest searchRequest = new SearchRequest(CommonValue.GLOBAL_CONTEXT_INDEX).source(searchSourceBuilder);

Check warning on line 204 in src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java#L204

Added line #L204 was not covered by tests

client.search(searchRequest, ActionListener.wrap(searchResponse -> {
if (searchResponse.getHits().getTotalHits().value >= maxWorkflow) {
internalListener.onResponse(false);
} else {
internalListener.onResponse(true);
}
}, exception -> {
logger.error("Unable to fetch the workflows {}", exception);
internalListener.onFailure(new FlowFrameworkException("Unable to fetch the workflows", RestStatus.BAD_REQUEST));
}));
client.search(searchRequest, ActionListener.wrap(searchResponse -> {

Check warning on line 206 in src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java#L206

Added line #L206 was not covered by tests
if (searchResponse.getHits().getTotalHits().value >= maxWorkflow) {
internalListener.onResponse(false);

Check warning on line 208 in src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java#L208

Added line #L208 was not covered by tests
} else {
internalListener.onResponse(true);

Check warning on line 210 in src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java#L210

Added line #L210 was not covered by tests
}
}, exception -> {
logger.error("Unable to fetch the workflows {}", exception);
internalListener.onFailure(new FlowFrameworkException("Unable to fetch the workflows", RestStatus.BAD_REQUEST));
}));

Check warning on line 215 in src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java#L212-L215

Added lines #L212 - L215 were not covered by tests
}
}

private void validateWorkflows(Template template) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.FlowFrameworkMaxRequestRetrySetting;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;

import java.util.List;
import java.util.Map;
Expand All @@ -32,39 +33,25 @@
public class GetMLTaskStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(GetMLTaskStep.class);
private FlowFrameworkMaxRequestRetrySetting maxRequestRetrySetting;
private MachineLearningNodeClient mlClient;
static final String NAME = "get_ml_task";

/**
* Instantiate this class
* @param mlClient client to instantiate MLClient
* @param maxRequestRetrySetting the max request retry setting
*/
public GetMLTaskStep(MachineLearningNodeClient mlClient) {
public GetMLTaskStep(MachineLearningNodeClient mlClient, FlowFrameworkMaxRequestRetrySetting maxRequestRetrySetting) {
this.mlClient = mlClient;
this.maxRequestRetrySetting = maxRequestRetrySetting;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> getMLTaskFuture = new CompletableFuture<>();

ActionListener<MLTask> actionListener = ActionListener.wrap(response -> {

// TODO : Add retry capability if response status is not COMPLETED :
// https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/158

logger.info("ML Task retrieval successful");
getMLTaskFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())),
data.get(0).getWorkflowId()
)
);
}, exception -> {
logger.error("Failed to retrieve ML Task");
getMLTaskFuture.completeExceptionally(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)));
});

String taskId = null;

for (WorkflowData workflowData : data) {
Expand All @@ -84,7 +71,7 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
logger.error("Failed to retrieve ML Task");
getMLTaskFuture.completeExceptionally(new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST));
} else {
mlClient.getTask(taskId, actionListener);
retryableGetMlTask(data, getMLTaskFuture, taskId, 0);
}

return getMLTaskFuture;
Expand All @@ -95,4 +82,42 @@ public String getName() {
return NAME;
}

private void retryableGetMlTask(List<WorkflowData> data, CompletableFuture<WorkflowData> getMLTaskFuture, String taskId, int retries) {
mlClient.getTask(taskId, ActionListener.wrap(response -> {
if (response.getState() != MLTaskState.COMPLETED) {
throw new IllegalStateException("MLTask is not yet completed");

Check warning on line 88 in src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java#L88

Added line #L88 was not covered by tests
} else {
logger.info("ML Task retrieval successful");
getMLTaskFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(MODEL_ID, response.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, response.getState().name())
),
data.get(0).getWorkflowId()
)
);
}
}, exception -> {
if (shouldRetry(getMLTaskFuture, retries)) {
final int retryAdd = retries + 1;
retryableGetMlTask(data, getMLTaskFuture, taskId, retryAdd);
} else {
logger.error("Failed to retrieve ML Task, maximum retries exceeded");
getMLTaskFuture.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
}
}));
}

private boolean shouldRetry(CompletableFuture<WorkflowData> getMLTaskFuture, int retries) {
try {
Thread.sleep(5000);
} catch (Exception e) {
getMLTaskFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));

Check warning on line 118 in src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java#L117-L118

Added lines #L117 - L118 were not covered by tests
}
return retries < maxRequestRetrySetting.getMaxRetries();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.FlowFrameworkMaxRequestRetrySetting;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.ml.client.MachineLearningNodeClient;
Expand All @@ -33,22 +34,25 @@ public class WorkflowStepFactory {
* @param client The OpenSearch client steps can use
* @param mlClient Machine Learning client to perform ml operations
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
* @param maxRequestRetrySetting FlowFramework Setting to control maximum transport request retries
*/
public WorkflowStepFactory(
ClusterService clusterService,
Client client,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
FlowFrameworkMaxRequestRetrySetting maxRequestRetrySetting
) {
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
populateMap(clusterService, client, mlClient, flowFrameworkIndicesHandler);
populateMap(clusterService, client, mlClient, flowFrameworkIndicesHandler, maxRequestRetrySetting);
}

private void populateMap(
ClusterService clusterService,
Client client,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
FlowFrameworkMaxRequestRetrySetting maxRequestRetrySetting
) {
stepMap.put(NoOpStep.NAME, new NoOpStep());
stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client));
Expand All @@ -58,7 +62,7 @@ private void populateMap(
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient));
stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(mlClient));
stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(mlClient, maxRequestRetrySetting));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -61,7 +62,7 @@ public void setUp() throws Exception {

final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT)
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_REQUEST_RETRY)
).collect(Collectors.toSet());
clusterSettings = new ClusterSettings(settings, settingsSet);
clusterService = mock(ClusterService.class);
Expand All @@ -83,7 +84,7 @@ public void testPlugin() throws IOException {
assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size());
assertEquals(4, ffp.getActions().size());
assertEquals(1, ffp.getExecutorBuilders(settings).size());
assertEquals(3, ffp.getSettings().size());
assertEquals(4, ffp.getSettings().size());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
import org.opensearch.client.Client;
import org.opensearch.client.ClusterAdminClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.common.FlowFrameworkMaxRequestRetrySetting;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
Expand All @@ -21,7 +25,14 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -70,7 +81,25 @@ public void testWorkflowStepFactoryHasValidators() throws IOException {
MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);

WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient, flowFrameworkIndicesHandler);
final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_REQUEST_RETRY)
).collect(Collectors.toSet());
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

FlowFrameworkMaxRequestRetrySetting maxRequestRetrySetting = new FlowFrameworkMaxRequestRetrySetting(
clusterService,
Settings.EMPTY
);

WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler,
maxRequestRetrySetting
);

// Read in workflow-steps.json
WorkflowValidator workflowValidator = WorkflowValidator.parse("mappings/workflow-steps.json");
Expand Down
Loading

0 comments on commit 52f0c04

Please sign in to comment.