Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: Jackie Han <[email protected]>
  • Loading branch information
jackiehanyang committed Dec 13, 2023
1 parent fb9e428 commit ac9da88
Showing 1 changed file with 119 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

Expand All @@ -53,6 +52,7 @@
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
Expand All @@ -71,10 +71,10 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase {
private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private WorkflowProcessSorter workflowProcessSorter;
private Template template;
private Client client;
private Client client = mock(Client.class);
private ThreadPool threadPool;
private static TestThreadPool testThreadPool;

private ClusterSettings clusterSettings;
private ClusterService clusterService;
private ParseUtils parseUtils;
private ThreadContext threadContext;
private Settings settings;
Expand All @@ -83,41 +83,39 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase {
public void setUp() throws Exception {
super.setUp();
threadPool = mock(ThreadPool.class);
this.client = mock(Client.class);

ClusterService clusterService = mock(ClusterService.class);
settings = Settings.builder()
.put("plugins.flow_framework.max_workflows.", 2)
.put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10))
.build();
this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
.put("plugins.flow_framework.max_workflows.", 2)
.put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10))
.build();
final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
).collect(Collectors.toSet());
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
clusterSettings = new ClusterSettings(settings, settingsSet);
clusterService = mock(ClusterService.class);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);

MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class);
WorkflowStepFactory factory = new WorkflowStepFactory(
Settings.EMPTY,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler
Settings.EMPTY,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler
);
this.workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);

this.workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool, clusterService, settings);
this.createWorkflowTransportAction = spy(
new CreateWorkflowTransportAction(
mock(TransportService.class),
mock(ActionFilters.class),
workflowProcessSorter,
flowFrameworkIndicesHandler,
settings,
client
)
new CreateWorkflowTransportAction(
mock(TransportService.class),
mock(ActionFilters.class),
workflowProcessSorter,
flowFrameworkIndicesHandler,
settings,
client
)
);
// client = mock(Client.class);
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
// threadContext = mock(ThreadContext.class);
when(client.threadPool()).thenReturn(threadPool);
Expand All @@ -135,14 +133,14 @@ public void setUp() throws Exception {
Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges);

this.template = new Template(
"test",
"description",
"use case",
templateVersion,
compatibilityVersions,
Map.of("workflow", workflow),
Map.of(),
TestHelpers.randomUser()
"test",
"description",
"use case",
templateVersion,
compatibilityVersions,
Map.of("workflow", workflow),
Map.of(),
TestHelpers.randomUser()
);
}

Expand All @@ -158,53 +156,53 @@ public void testDryRunValidation_Success() {
public void testDryRunValidation_Failed() {

WorkflowNode createConnector = new WorkflowNode(
"workflow_step_1",
"create_connector",
Map.of(),
Map.ofEntries(
Map.entry("name", ""),
Map.entry("description", ""),
Map.entry("version", ""),
Map.entry("protocol", ""),
Map.entry("parameters", ""),
Map.entry("credential", ""),
Map.entry("actions", "")
)
"workflow_step_1",
"create_connector",
Map.of(),
Map.ofEntries(
Map.entry("name", ""),
Map.entry("description", ""),
Map.entry("version", ""),
Map.entry("protocol", ""),
Map.entry("parameters", ""),
Map.entry("credential", ""),
Map.entry("actions", "")
)
);

WorkflowNode registerModel = new WorkflowNode(
"workflow_step_2",
"register_model",
Map.ofEntries(Map.entry("workflow_step_1", "connector_id")),
Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description"))
"workflow_step_2",
"register_model",
Map.ofEntries(Map.entry("workflow_step_1", "connector_id")),
Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description"))
);

WorkflowNode deployModel = new WorkflowNode(
"workflow_step_3",
"deploy_model",
Map.ofEntries(Map.entry("workflow_step_2", "model_id")),
Map.of()
"workflow_step_3",
"deploy_model",
Map.ofEntries(Map.entry("workflow_step_2", "model_id")),
Map.of()
);

WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id());
WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id());
WorkflowEdge cyclicalEdge = new WorkflowEdge(deployModel.id(), createConnector.id());

Workflow workflow = new Workflow(
Map.of(),
List.of(createConnector, registerModel, deployModel),
List.of(edge1, edge2, cyclicalEdge)
Map.of(),
List.of(createConnector, registerModel, deployModel),
List.of(edge1, edge2, cyclicalEdge)
);

Template cyclicalTemplate = new Template(
"test",
"description",
"use case",
Version.fromString("1.0.0"),
List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")),
Map.of("workflow", workflow),
Map.of(),
TestHelpers.randomUser()
"test",
"description",
"use case",
Version.fromString("1.0.0"),
List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")),
Map.of("workflow", workflow),
Map.of(),
TestHelpers.randomUser()
);

@SuppressWarnings("unchecked")
Expand All @@ -221,11 +219,11 @@ public void testMaxWorkflow() {
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(
null,
template,
false,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
null,
template,
false,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
);

doAnswer(invocation -> {
Expand Down Expand Up @@ -258,11 +256,11 @@ public void testFailedToCreateNewWorkflow() {
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(
null,
template,
false,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
null,
template,
false,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
);

// Bypass checkMaxWorkflows and force onResponse
Expand Down Expand Up @@ -295,11 +293,11 @@ public void testCreateNewWorkflow() {
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(
null,
template,
false,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
null,
template,
false,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
);

// Bypass checkMaxWorkflows and force onResponse
Expand Down Expand Up @@ -387,11 +385,11 @@ public void testCreateWorkflow_withProvisionParam() {
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(
null,
validTemplate,
true,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
null,
validTemplate,
true,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
);

// Bypass checkMaxWorkflows and force onResponse
Expand Down Expand Up @@ -444,11 +442,11 @@ public void testCreateWorkflow_withFailedProvision() {
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(
null,
validTemplate,
true,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
null,
validTemplate,
true,
WORKFLOW_REQUEST_TIMEOUT.get(settings),
MAX_WORKFLOWS.get(settings)
);

// Bypass checkMaxWorkflows and force onResponse
Expand Down Expand Up @@ -495,30 +493,30 @@ public void testCreateWorkflow_withFailedProvision() {

private Template generateValidTemplate() {
WorkflowNode createConnector = new WorkflowNode(
"workflow_step_1",
WorkflowResources.CREATE_CONNECTOR.getWorkflowStep(),
Map.of(),
Map.ofEntries(
Map.entry("name", ""),
Map.entry("description", ""),
Map.entry("version", ""),
Map.entry("protocol", ""),
Map.entry("parameters", ""),
Map.entry("credential", ""),
Map.entry("actions", "")
)
"workflow_step_1",
WorkflowResources.CREATE_CONNECTOR.getWorkflowStep(),
Map.of(),
Map.ofEntries(
Map.entry("name", ""),
Map.entry("description", ""),
Map.entry("version", ""),
Map.entry("protocol", ""),
Map.entry("parameters", ""),
Map.entry("credential", ""),
Map.entry("actions", "")
)
);
WorkflowNode registerModel = new WorkflowNode(
"workflow_step_2",
WorkflowResources.REGISTER_REMOTE_MODEL.getWorkflowStep(),
Map.ofEntries(Map.entry("workflow_step_1", "connector_id")),
Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description"))
"workflow_step_2",
WorkflowResources.REGISTER_REMOTE_MODEL.getWorkflowStep(),
Map.ofEntries(Map.entry("workflow_step_1", "connector_id")),
Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description"))
);
WorkflowNode deployModel = new WorkflowNode(
"workflow_step_3",
WorkflowResources.DEPLOY_MODEL.getWorkflowStep(),
Map.ofEntries(Map.entry("workflow_step_2", "model_id")),
Map.of()
"workflow_step_3",
WorkflowResources.DEPLOY_MODEL.getWorkflowStep(),
Map.ofEntries(Map.entry("workflow_step_2", "model_id")),
Map.of()
);

WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id());
Expand All @@ -527,14 +525,14 @@ private Template generateValidTemplate() {
Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2));

Template validTemplate = new Template(
"test",
"description",
"use case",
Version.fromString("1.0.0"),
List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")),
Map.of("workflow", workflow),
Map.of(),
TestHelpers.randomUser()
"test",
"description",
"use case",
Version.fromString("1.0.0"),
List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")),
Map.of("workflow", workflow),
Map.of(),
TestHelpers.randomUser()
);

return validTemplate;
Expand Down

0 comments on commit ac9da88

Please sign in to comment.