Skip to content

Commit

Permalink
replace dryrun parameter with provision in create workflow
Browse files Browse the repository at this point in the history
Signed-off-by: Jackie Han <[email protected]>
  • Loading branch information
jackiehanyang committed Dec 12, 2023
1 parent 18706c0 commit 314650a
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.Locale;

import static org.opensearch.flowframework.common.CommonValue.DRY_RUN;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
Expand Down Expand Up @@ -90,9 +91,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli

String workflowId = request.param(WORKFLOW_ID);
Template template = Template.parse(request.content().utf8ToString());
boolean dryRun = request.paramAsBoolean(DRY_RUN, false);
boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false);

WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, dryRun, requestTimeout, maxWorkflows);
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, provision, requestTimeout, maxWorkflows);

return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
Expand All @@ -21,7 +22,9 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.flowframework.common.CommonValue;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand All @@ -35,8 +38,11 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportException;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;

import java.io.IOException;
import java.util.List;

import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD;
Expand All @@ -55,6 +61,8 @@ public class CreateWorkflowTransportAction extends HandledTransportAction<Workfl
private final Client client;
private final Settings settings;

private final TransportService transportService;

/**
* Intantiates a new CreateWorkflowTransportAction
* @param transportService the TransportService
Expand All @@ -74,6 +82,7 @@ public CreateWorkflowTransportAction(
Client client
) {
super(CreateWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new);
this.transportService = transportService;
this.workflowProcessSorter = workflowProcessSorter;
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
this.settings = settings;
Expand All @@ -95,20 +104,6 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
user
);

if (request.isDryRun()) {
try {
validateWorkflows(templateWithUser);
} catch (Exception e) {
if (e instanceof FlowFrameworkException) {
logger.error("Workflow validation failed for template : " + templateWithUser.name());
listener.onFailure(e);
} else {
listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
return;
}
}

if (request.getWorkflowId() == null) {
// Throttle incoming requests
checkMaxWorkflows(request.getRequestTimeout(), request.getMaxWorkflows(), ActionListener.wrap(max -> {
Expand Down Expand Up @@ -147,6 +142,14 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
}
})
);
if (request.isProvision()) {
WorkflowRequest workflowRequest = new WorkflowRequest(globalContextResponse.getId(), null);
transportService.sendRequest(transportService.getLocalNode(),
ProvisionWorkflowAction.NAME,
workflowRequest,
new ActionListenerResponseHandler<>(listener, WorkflowResponse::new)
);
}
}, exception -> {
logger.error("Failed to save use case template : {}", exception.getMessage());
if (exception instanceof FlowFrameworkException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ public class WorkflowRequest extends ActionRequest {
*/
private boolean dryRun;

/**
* Provision flag
*/
private boolean provision;

/**
* Timeout for request
*/
Expand Down Expand Up @@ -77,20 +82,20 @@ public WorkflowRequest(
* Instantiates a new WorkflowRequest
* @param workflowId the documentId of the workflow
* @param template the use case template which describes the workflow
* @param dryRun flag to indicate if validation is necessary
* @param provision flag to indicate if provision is necessary
* @param requestTimeout timeout of the request
* @param maxWorkflows max number of workflows
*/
public WorkflowRequest(
@Nullable String workflowId,
@Nullable Template template,
boolean dryRun,
boolean provision,
@Nullable TimeValue requestTimeout,
@Nullable Integer maxWorkflows
) {
this.workflowId = workflowId;
this.template = template;
this.dryRun = dryRun;
this.provision = provision;
this.requestTimeout = requestTimeout;
this.maxWorkflows = maxWorkflows;
}
Expand All @@ -105,7 +110,7 @@ public WorkflowRequest(StreamInput in) throws IOException {
this.workflowId = in.readOptionalString();
String templateJson = in.readOptionalString();
this.template = templateJson == null ? null : Template.parse(templateJson);
this.dryRun = in.readBoolean();
this.provision = in.readBoolean();
this.requestTimeout = in.readOptionalTimeValue();
this.maxWorkflows = in.readOptionalInt();
}
Expand All @@ -129,11 +134,11 @@ public Template getTemplate() {
}

/**
* Gets the dry run validation flag
* @return the dry run boolean
* Gets the provision flag
* @return the provision boolean
*/
public boolean isDryRun() {
return this.dryRun;
public boolean isProvision() {
return this.provision;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,67 +114,67 @@ public void setUp() throws Exception {
);
}

public void testFailedDryRunValidation() {

WorkflowNode createConnector = new WorkflowNode(
"workflow_step_1",
"create_connector",
Map.of(),
Map.ofEntries(
Map.entry("name", ""),
Map.entry("description", ""),
Map.entry("version", ""),
Map.entry("protocol", ""),
Map.entry("parameters", ""),
Map.entry("credential", ""),
Map.entry("actions", "")
)
);

WorkflowNode registerModel = new WorkflowNode(
"workflow_step_2",
"register_model",
Map.ofEntries(Map.entry("workflow_step_1", "connector_id")),
Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description"))
);

WorkflowNode deployModel = new WorkflowNode(
"workflow_step_3",
"deploy_model",
Map.ofEntries(Map.entry("workflow_step_2", "model_id")),
Map.of()
);

WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id());
WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id());
WorkflowEdge cyclicalEdge = new WorkflowEdge(deployModel.id(), createConnector.id());

Workflow workflow = new Workflow(
Map.of(),
List.of(createConnector, registerModel, deployModel),
List.of(edge1, edge2, cyclicalEdge)
);

Template cyclicalTemplate = new Template(
"test",
"description",
"use case",
Version.fromString("1.0.0"),
List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")),
Map.of("workflow", workflow),
Map.of(),
TestHelpers.randomUser()
);

@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, true, null, null);

createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener);
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(exceptionCaptor.capture());
assertEquals("No start node detected: all nodes have a predecessor.", exceptionCaptor.getValue().getMessage());
}
// public void testFailedDryRunValidation() {
//
// WorkflowNode createConnector = new WorkflowNode(
// "workflow_step_1",
// "create_connector",
// Map.of(),
// Map.ofEntries(
// Map.entry("name", ""),
// Map.entry("description", ""),
// Map.entry("version", ""),
// Map.entry("protocol", ""),
// Map.entry("parameters", ""),
// Map.entry("credential", ""),
// Map.entry("actions", "")
// )
// );
//
// WorkflowNode registerModel = new WorkflowNode(
// "workflow_step_2",
// "register_model",
// Map.ofEntries(Map.entry("workflow_step_1", "connector_id")),
// Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description"))
// );
//
// WorkflowNode deployModel = new WorkflowNode(
// "workflow_step_3",
// "deploy_model",
// Map.ofEntries(Map.entry("workflow_step_2", "model_id")),
// Map.of()
// );
//
// WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id());
// WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id());
// WorkflowEdge cyclicalEdge = new WorkflowEdge(deployModel.id(), createConnector.id());
//
// Workflow workflow = new Workflow(
// Map.of(),
// List.of(createConnector, registerModel, deployModel),
// List.of(edge1, edge2, cyclicalEdge)
// );
//
// Template cyclicalTemplate = new Template(
// "test",
// "description",
// "use case",
// Version.fromString("1.0.0"),
// List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")),
// Map.of("workflow", workflow),
// Map.of(),
// TestHelpers.randomUser()
// );
//
// @SuppressWarnings("unchecked")
// ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
// WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, true, null, null);
//
// createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener);
// ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
// verify(listener, times(1)).onFailure(exceptionCaptor.capture());
// assertEquals("No start node detected: all nodes have a predecessor.", exceptionCaptor.getValue().getMessage());
// }

public void testMaxWorkflow() {
@SuppressWarnings("unchecked")
Expand Down

0 comments on commit 314650a

Please sign in to comment.