Skip to content

Commit

Permalink
keep dryrun option in create workflow
Browse files Browse the repository at this point in the history
Signed-off-by: Jackie Han <[email protected]>
  • Loading branch information
jackiehanyang committed Dec 13, 2023
1 parent 090b16b commit fb9e428
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@
import org.opensearch.flowframework.model.ProvisioningProgress;
import org.opensearch.flowframework.model.State;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import java.util.List;

import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.getUserContext;
Expand Down Expand Up @@ -91,6 +95,20 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
user
);

if (request.isProvision()) {
try {
validateWorkflows(templateWithUser);
} catch (Exception e) {
if (e instanceof FlowFrameworkException) {
logger.error("Workflow validation failed for template : " + templateWithUser.name());
listener.onFailure(e);
} else {
listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
return;
}
}

if (request.getWorkflowId() == null) {
// Throttle incoming requests
checkMaxWorkflows(request.getRequestTimeout(), request.getMaxWorkflows(), ActionListener.wrap(max -> {
Expand Down Expand Up @@ -242,4 +260,11 @@ protected void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow,
}));
}
}

private void validateWorkflows(Template template) throws Exception {
for (Workflow workflow : template.workflows().values()) {
List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null);
workflowProcessSorter.validateGraph(sortedNodes);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.flowframework.TestHelpers;
import org.opensearch.flowframework.common.WorkflowResources;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
Expand All @@ -30,8 +31,10 @@
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

Expand All @@ -50,10 +53,8 @@
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyInt;
Expand All @@ -72,8 +73,8 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase {
private Template template;
private Client client;
private ThreadPool threadPool;
private ClusterSettings clusterSettings;
private ClusterService clusterService;
private static TestThreadPool testThreadPool;

private ParseUtils parseUtils;
private ThreadContext threadContext;
private Settings settings;
Expand All @@ -83,28 +84,39 @@ public void setUp() throws Exception {
super.setUp();
threadPool = mock(ThreadPool.class);
this.client = mock(Client.class);

ClusterService clusterService = mock(ClusterService.class);
settings = Settings.builder()
.put("plugins.flow_framework.max_workflows.", 2)
.put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10))
.build();
.put("plugins.flow_framework.max_workflows.", 2)
.put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10))
.build();
this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
).collect(Collectors.toSet());
clusterSettings = new ClusterSettings(settings, settingsSet);
clusterService = mock(ClusterService.class);
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool, clusterService, settings);

MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class);
WorkflowStepFactory factory = new WorkflowStepFactory(
Settings.EMPTY,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler
);
this.workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);

this.createWorkflowTransportAction = spy(
new CreateWorkflowTransportAction(
mock(TransportService.class),
mock(ActionFilters.class),
workflowProcessSorter,
flowFrameworkIndicesHandler,
settings,
client
)
new CreateWorkflowTransportAction(
mock(TransportService.class),
mock(ActionFilters.class),
workflowProcessSorter,
flowFrameworkIndicesHandler,
settings,
client
)
);
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
// threadContext = mock(ThreadContext.class);
Expand All @@ -123,26 +135,97 @@ public void setUp() throws Exception {
Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges);

this.template = new Template(
"test",
"description",
"use case",
templateVersion,
compatibilityVersions,
Map.of("workflow", workflow),
Map.of(),
TestHelpers.randomUser()
"test",
"description",
"use case",
templateVersion,
compatibilityVersions,
Map.of("workflow", workflow),
Map.of(),
TestHelpers.randomUser()
);
}

public void testDryRunValidation_Success() {
Template validTemplate = generateValidTemplate();

@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest createNewWorkflow = new WorkflowRequest(null, validTemplate, true, null, null);
createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener);
}

public void testDryRunValidation_Failed() {

WorkflowNode createConnector = new WorkflowNode(
"workflow_step_1",
"create_connector",
Map.of(),
Map.ofEntries(
Map.entry("name", ""),
Map.entry("description", ""),
Map.entry("version", ""),
Map.entry("protocol", ""),
Map.entry("parameters", ""),
Map.entry("credential", ""),
Map.entry("actions", "")
)
);

WorkflowNode registerModel = new WorkflowNode(
"workflow_step_2",
"register_model",
Map.ofEntries(Map.entry("workflow_step_1", "connector_id")),
Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description"))
);

WorkflowNode deployModel = new WorkflowNode(
"workflow_step_3",
"deploy_model",
Map.ofEntries(Map.entry("workflow_step_2", "model_id")),
Map.of()
);

WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id());
WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id());
WorkflowEdge cyclicalEdge = new WorkflowEdge(deployModel.id(), createConnector.id());

Workflow workflow = new Workflow(
Map.of(),
List.of(createConnector, registerModel, deployModel),
List.of(edge1, edge2, cyclicalEdge)
);

Template cyclicalTemplate = new Template(
"test",
"description",
"use case",
Version.fromString("1.0.0"),
List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")),
Map.of("workflow", workflow),
Map.of(),
TestHelpers.randomUser()
);

@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, true, null, null);

createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener);
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(exceptionCaptor.capture());
assertEquals("No start node detected: all nodes have a predecessor.", exceptionCaptor.getValue().getMessage());
}

public void testMaxWorkflow() {
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(
null,
template,
false,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
null,
template,
false,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
);

doAnswer(invocation -> {
Expand Down Expand Up @@ -175,11 +258,11 @@ public void testFailedToCreateNewWorkflow() {
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(
null,
template,
false,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
null,
template,
false,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
);

// Bypass checkMaxWorkflows and force onResponse
Expand Down Expand Up @@ -212,11 +295,11 @@ public void testCreateNewWorkflow() {
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(
null,
template,
false,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
null,
template,
false,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
);

// Bypass checkMaxWorkflows and force onResponse
Expand Down Expand Up @@ -298,14 +381,17 @@ public void testUpdateWorkflow() {
}

public void testCreateWorkflow_withProvisionParam() {

Template validTemplate = generateValidTemplate();

@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(
null,
template,
true,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
null,
validTemplate,
true,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
);

// Bypass checkMaxWorkflows and force onResponse
Expand Down Expand Up @@ -353,14 +439,16 @@ public void testCreateWorkflow_withProvisionParam() {
}

public void testCreateWorkflow_withFailedProvision() {
Template validTemplate = generateValidTemplate();

@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(
null,
template,
true,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
null,
validTemplate,
true,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
);

// Bypass checkMaxWorkflows and force onResponse
Expand Down Expand Up @@ -404,4 +492,51 @@ public void testCreateWorkflow_withFailedProvision() {
verify(listener, times(1)).onFailure(exceptionCaptor.capture());
assertEquals("failed", exceptionCaptor.getValue().getMessage());
}

private Template generateValidTemplate() {
WorkflowNode createConnector = new WorkflowNode(
"workflow_step_1",
WorkflowResources.CREATE_CONNECTOR.getWorkflowStep(),
Map.of(),
Map.ofEntries(
Map.entry("name", ""),
Map.entry("description", ""),
Map.entry("version", ""),
Map.entry("protocol", ""),
Map.entry("parameters", ""),
Map.entry("credential", ""),
Map.entry("actions", "")
)
);
WorkflowNode registerModel = new WorkflowNode(
"workflow_step_2",
WorkflowResources.REGISTER_REMOTE_MODEL.getWorkflowStep(),
Map.ofEntries(Map.entry("workflow_step_1", "connector_id")),
Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description"))
);
WorkflowNode deployModel = new WorkflowNode(
"workflow_step_3",
WorkflowResources.DEPLOY_MODEL.getWorkflowStep(),
Map.ofEntries(Map.entry("workflow_step_2", "model_id")),
Map.of()
);

WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id());
WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id());

Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2));

Template validTemplate = new Template(
"test",
"description",
"use case",
Version.fromString("1.0.0"),
List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")),
Map.of("workflow", workflow),
Map.of(),
TestHelpers.randomUser()
);

return validTemplate;
}
}

0 comments on commit fb9e428

Please sign in to comment.