Skip to content

Commit

Permalink
Fixed static fields initialization in WorkflowStepFactory (#532)
Browse files Browse the repository at this point in the history
Fixed static fields initialization

Signed-off-by: Owais Kazi <[email protected]>
(cherry picked from commit 24bf51a)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] committed Feb 21, 2024
1 parent 4446fb4 commit 1684f8f
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 139 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ public Collection<Object> createComponents(
);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
threadPool,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler,
flowFrameworkSettings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.common.Strings;
import org.opensearch.core.rest.RestStatus;
Expand Down Expand Up @@ -61,37 +59,47 @@ public class WorkflowStepFactory {

private final Map<String, Supplier<WorkflowStep>> stepMap = new HashMap<>();
private static final Logger logger = LogManager.getLogger(WorkflowStepFactory.class);
private static ThreadPool threadPool;
private static MachineLearningNodeClient mlClient;
private static FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private static FlowFrameworkSettings flowFrameworkSettings;

/**
* Instantiate this class.
*
* @param threadPool The OpenSearch thread pool
* @param clusterService The OpenSearch cluster service
* @param client The OpenSearch client steps can use
* @param mlClient Machine Learning client to perform ml operations
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
* @param flowFrameworkSettings common settings of the plugin
*/
public WorkflowStepFactory(
ThreadPool threadPool,
ClusterService clusterService,
Client client,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
FlowFrameworkSettings flowFrameworkSettings
) {
this.threadPool = threadPool;
this.mlClient = mlClient;
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
this.flowFrameworkSettings = flowFrameworkSettings;
// Initialize the WorkflowSteps enum inside the constructor
for (WorkflowSteps workflowStep : WorkflowSteps.values()) {
stepMap.put(workflowStep.getWorkflowStepName(), workflowStep.step());
}
stepMap.put(NoOpStep.NAME, NoOpStep::new);
stepMap.put(
RegisterLocalCustomModelStep.NAME,
() -> new RegisterLocalCustomModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
);
stepMap.put(
RegisterLocalSparseEncodingModelStep.NAME,
() -> new RegisterLocalSparseEncodingModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)

Check warning on line 84 in src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java#L84

Added line #L84 was not covered by tests
);
stepMap.put(
RegisterLocalPretrainedModelStep.NAME,
() -> new RegisterLocalPretrainedModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)

Check warning on line 88 in src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java#L88

Added line #L88 was not covered by tests
);
stepMap.put(RegisterRemoteModelStep.NAME, () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(DeleteModelStep.NAME, () -> new DeleteModelStep(mlClient));
stepMap.put(
DeployModelStep.NAME,
() -> new DeployModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
);
stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient));
stepMap.put(RegisterModelGroupStep.NAME, () -> new RegisterModelGroupStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(ToolStep.NAME, ToolStep::new);
stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(DeleteAgentStep.NAME, () -> new DeleteAgentStep(mlClient));
}

/**
Expand All @@ -101,16 +109,15 @@ public WorkflowStepFactory(
public enum WorkflowSteps {

/** Noop Step */
NOOP("noop", Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), null, NoOpStep::new),
NOOP("noop", Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), null),

/** Create Connector Step */
CREATE_CONNECTOR(
CreateConnectorStep.NAME,
List.of(NAME_FIELD, DESCRIPTION_FIELD, VERSION_FIELD, PROTOCOL_FIELD, PARAMETERS_FIELD, CREDENTIAL_FIELD, ACTIONS_FIELD),
List.of(CONNECTOR_ID),
List.of(OPENSEARCH_ML),
TimeValue.timeValueSeconds(60),
() -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)
TimeValue.timeValueSeconds(60)
),

/** Register Local Custom Model Step */
Expand All @@ -129,8 +136,7 @@ public enum WorkflowSteps {
),
List.of(MODEL_ID, REGISTER_MODEL_STATUS),
List.of(OPENSEARCH_ML),
TimeValue.timeValueSeconds(60),
() -> new RegisterLocalCustomModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
TimeValue.timeValueSeconds(60)
),

/** Register Local Sparse Encoding Model Step */
Expand All @@ -139,8 +145,7 @@ public enum WorkflowSteps {
List.of(NAME_FIELD, VERSION_FIELD, MODEL_FORMAT, FUNCTION_NAME, MODEL_CONTENT_HASH_VALUE, URL),
List.of(MODEL_ID, REGISTER_MODEL_STATUS),
List.of(OPENSEARCH_ML),
TimeValue.timeValueSeconds(60),
() -> new RegisterLocalSparseEncodingModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
TimeValue.timeValueSeconds(60)
),

/** Register Local Pretrained Model Step */
Expand All @@ -149,8 +154,7 @@ public enum WorkflowSteps {
List.of(NAME_FIELD, VERSION_FIELD, MODEL_FORMAT),
List.of(MODEL_ID, REGISTER_MODEL_STATUS),
List.of(OPENSEARCH_ML),
TimeValue.timeValueSeconds(60),
() -> new RegisterLocalPretrainedModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
TimeValue.timeValueSeconds(60)
),

/** Register Remote Model Step */
Expand All @@ -159,8 +163,7 @@ public enum WorkflowSteps {
List.of(NAME_FIELD, CONNECTOR_ID),
List.of(MODEL_ID, REGISTER_MODEL_STATUS),
List.of(OPENSEARCH_ML),
null,
() -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)
null
),

/** Register Model Group Step */
Expand All @@ -169,94 +172,42 @@ public enum WorkflowSteps {
List.of(NAME_FIELD),
List.of(MODEL_GROUP_ID, MODEL_GROUP_STATUS),
List.of(OPENSEARCH_ML),
null,
() -> new RegisterModelGroupStep(mlClient, flowFrameworkIndicesHandler)
null
),

/** Deploy Model Step */
DEPLOY_MODEL(
DeployModelStep.NAME,
List.of(MODEL_ID),
List.of(MODEL_ID),
List.of(OPENSEARCH_ML),
TimeValue.timeValueSeconds(15),
() -> new DeployModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
),
DEPLOY_MODEL(DeployModelStep.NAME, List.of(MODEL_ID), List.of(MODEL_ID), List.of(OPENSEARCH_ML), TimeValue.timeValueSeconds(15)),

/** Undeploy Model Step */
UNDEPLOY_MODEL(
UndeployModelStep.NAME,
List.of(MODEL_ID),
List.of(SUCCESS),
List.of(OPENSEARCH_ML),
null,
() -> new UndeployModelStep(mlClient)
),
UNDEPLOY_MODEL(UndeployModelStep.NAME, List.of(MODEL_ID), List.of(SUCCESS), List.of(OPENSEARCH_ML), null),

/** Delete Model Step */
DELETE_MODEL(
DeleteModelStep.NAME,
List.of(MODEL_ID),
List.of(MODEL_ID),
List.of(OPENSEARCH_ML),
null,
() -> new DeleteModelStep(mlClient)
),
DELETE_MODEL(DeleteModelStep.NAME, List.of(MODEL_ID), List.of(MODEL_ID), List.of(OPENSEARCH_ML), null),

/** Delete Connector Step */
DELETE_CONNECTOR(
DeleteConnectorStep.NAME,
List.of(CONNECTOR_ID),
List.of(CONNECTOR_ID),
List.of(OPENSEARCH_ML),
null,
() -> new DeleteConnectorStep(mlClient)
),
DELETE_CONNECTOR(DeleteConnectorStep.NAME, List.of(CONNECTOR_ID), List.of(CONNECTOR_ID), List.of(OPENSEARCH_ML), null),

/** Register Agent Step */
REGISTER_AGENT(
RegisterAgentStep.NAME,
List.of(NAME_FIELD, TYPE),
List.of(AGENT_ID),
List.of(OPENSEARCH_ML),
null,
() -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler)
),
REGISTER_AGENT(RegisterAgentStep.NAME, List.of(NAME_FIELD, TYPE), List.of(AGENT_ID), List.of(OPENSEARCH_ML), null),

/** Delete Agent Step */
DELETE_AGENT(
DeleteAgentStep.NAME,
List.of(AGENT_ID),
List.of(AGENT_ID),
List.of(OPENSEARCH_ML),
null,
() -> new DeleteAgentStep(mlClient)
),
DELETE_AGENT(DeleteAgentStep.NAME, List.of(AGENT_ID), List.of(AGENT_ID), List.of(OPENSEARCH_ML), null),

/** Create Tool Step */
CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null, ToolStep::new);
CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null);

private final String workflowStepName;
private final List<String> inputs;
private final List<String> outputs;
private final List<String> requiredPlugins;
private final TimeValue timeout;
private final Supplier<WorkflowStep> workflowStep;

WorkflowSteps(
String workflowStepName,
List<String> inputs,
List<String> outputs,
List<String> requiredPlugins,
TimeValue timeout,
Supplier<WorkflowStep> workflowStep
) {

WorkflowSteps(String workflowStepName, List<String> inputs, List<String> outputs, List<String> requiredPlugins, TimeValue timeout) {
this.workflowStepName = workflowStepName;
this.inputs = List.copyOf(inputs);
this.outputs = List.copyOf(outputs);
this.requiredPlugins = requiredPlugins;
this.timeout = timeout;
this.workflowStep = workflowStep;
}

/**
Expand Down Expand Up @@ -299,14 +250,6 @@ public TimeValue timeout() {
return timeout;
}

/**
* Get the step
* @return the step
*/
public Supplier<WorkflowStep> step() {
return workflowStep;
}

/**
* Get the workflow step validator object
* @return the WorkflowStepValidator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,6 @@
*/
package org.opensearch.flowframework.model;

import org.opensearch.client.AdminClient;
import org.opensearch.client.Client;
import org.opensearch.client.ClusterAdminClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
Expand All @@ -27,14 +20,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -174,26 +160,11 @@ public void testParseWorkflowValidator() throws IOException {
public void testWorkflowStepFactoryHasValidators() throws IOException {

ThreadPool threadPool = mock(ThreadPool.class);
ClusterService clusterService = mock(ClusterService.class);
ClusterAdminClient clusterAdminClient = mock(ClusterAdminClient.class);
AdminClient adminClient = mock(AdminClient.class);
Client client = mock(Client.class);
when(client.admin()).thenReturn(adminClient);
when(adminClient.cluster()).thenReturn(clusterAdminClient);
MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);

final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, TASK_REQUEST_RETRY_DURATION)
).collect(Collectors.toSet());
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
threadPool,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler,
flowFrameworkSettings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,7 @@ public static void setup() throws IOException {
FLOW_FRAMEWORK_THREAD_POOL_PREFIX + DEPROVISION_WORKFLOW_THREAD_POOL
)
);
WorkflowStepFactory factory = new WorkflowStepFactory(
testThreadPool,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler,
flowFrameworkSettings
);
WorkflowStepFactory factory = new WorkflowStepFactory(testThreadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings);
workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, client, flowFrameworkSettings);
}

Expand Down

0 comments on commit 1684f8f

Please sign in to comment.