Skip to content

Commit

Permalink
Merge branch 'feature/agent_framework' into issue259
Browse files Browse the repository at this point in the history
Signed-off-by: Jackie Han <[email protected]>
  • Loading branch information
jackiehanyang authored Dec 12, 2023
2 parents dbf5fc8 + 7bec6e8 commit 090b16b
Show file tree
Hide file tree
Showing 11 changed files with 188 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;

/**
Expand Down Expand Up @@ -106,7 +107,7 @@ public Collection<Object> createComponents(
mlClient,
flowFrameworkIndicesHandler
);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool, clusterService, settings);

return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler);
}
Expand Down Expand Up @@ -144,6 +145,7 @@ public List<Setting<?>> getSettings() {
List<Setting<?>> settings = ImmutableList.of(
FLOW_FRAMEWORK_ENABLED,
MAX_WORKFLOWS,
MAX_WORKFLOW_STEPS,
WORKFLOW_REQUEST_TIMEOUT,
MAX_GET_TASK_REQUEST_RETRY
);
Expand All @@ -158,7 +160,7 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
settings,
PROVISION_THREAD_POOL,
OpenSearchExecutors.allocatedProcessors(settings),
10,
100,
FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_THREAD_POOL
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ private FlowFrameworkSettings() {}

/** The upper limit of max workflows that can be created */
public static final int MAX_WORKFLOWS_LIMIT = 10000;
/** The upper limit of max workflow steps that can be in a single workflow */
public static final int MAX_WORKFLOW_STEPS_LIMIT = 500;

/** This setting sets max workflows that can be created */
public static final Setting<Integer> MAX_WORKFLOWS = Setting.intSetting(
Expand All @@ -29,6 +31,16 @@ private FlowFrameworkSettings() {}
Setting.Property.Dynamic
);

/** This setting sets max workflows that can be created */
public static final Setting<Integer> MAX_WORKFLOW_STEPS = Setting.intSetting(
"plugins.flow_framework.max_workflow_steps",
50,
1,
MAX_WORKFLOW_STEPS_LIMIT,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

/** This setting sets the timeout for the request */
public static final Setting<TimeValue> WORKFLOW_REQUEST_TIMEOUT = Setting.positiveTimeSetting(
"plugins.flow_framework.request_timeout",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
import java.util.Objects;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
import static org.opensearch.flowframework.util.ParseUtils.parseStringToObjectMap;
import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap;

/**
Expand Down Expand Up @@ -93,7 +95,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
} else {
for (Map<?, ?> map : (Map<?, ?>[]) e.getValue()) {
buildStringToStringMap(xContentBuilder, map);
buildStringToObjectMap(xContentBuilder, map);
}
}
xContentBuilder.endArray();
Expand Down Expand Up @@ -150,9 +152,9 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
}
userInputs.put(inputFieldName, processorList.toArray(new PipelineProcessor[0]));
} else {
List<Map<String, String>> mapList = new ArrayList<>();
List<Map<String, Object>> mapList = new ArrayList<>();
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
mapList.add(parseStringToStringMap(parser));
mapList.add(parseStringToObjectMap(parser));
}
userInputs.put(inputFieldName, mapList.toArray(new Map[0]));
}
Expand Down
43 changes: 43 additions & 0 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,25 @@ public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map<?
xContentBuilder.endObject();
}

/**
* Builds an XContent object representing a map of String keys to Object values.
*
* @param xContentBuilder An XContent builder whose position is at the start of the map object to build
* @param map A map as key-value String to Object.
* @throws IOException on a build failure
*/
public static void buildStringToObjectMap(XContentBuilder xContentBuilder, Map<?, ?> map) throws IOException {
xContentBuilder.startObject();
for (Entry<?, ?> e : map.entrySet()) {
if (e.getValue() instanceof String) {
xContentBuilder.field((String) e.getKey(), (String) e.getValue());
} else {
xContentBuilder.field((String) e.getKey(), e.getValue());
}
}
xContentBuilder.endObject();
}

/**
* Builds an XContent object representing a LLMSpec.
*
Expand Down Expand Up @@ -117,6 +136,30 @@ public static Map<String, String> parseStringToStringMap(XContentParser parser)
return map;
}

/**
* Parses an XContent object representing a map of String keys to Object values.
* The Object value here can either be a string or a map
* @param parser An XContent parser whose position is at the start of the map object to parse
* @return A map as identified by the key-value pairs in the XContent
* @throws IOException on a parse failure
*/
public static Map<String, Object> parseStringToObjectMap(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
Map<String, Object> map = new HashMap<>();
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();
if (parser.currentToken() == XContentParser.Token.START_OBJECT) {
// If the current token is a START_OBJECT, parse it as Map<String, String>
map.put(fieldName, parseStringToStringMap(parser));
} else {
// Otherwise, parse it as a string
map.put(fieldName, parser.text());
}
}
return map;
}

/**
* Parse content parser to {@link java.time.Instant}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;

import java.util.Map;
Expand Down Expand Up @@ -59,24 +58,6 @@ public AbstractRetryableWorkflowStep(
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
}

/**
* Completes the future for either deploy or register local model step
* @param resourceName resource name for the given step
* @param nodeId node ID of the given step
* @param workflowId workflow ID of the given workflow
* @param response Response from ml commons get Task API
* @param future CompletableFuture of the given step
*/
public void completeFuture(String resourceName, String nodeId, String workflowId, MLTask response, CompletableFuture future) {
future.complete(
new WorkflowData(
Map.ofEntries(Map.entry(resourceName, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())),
workflowId,
nodeId
)
);
}

/**
* Retryable get ml task
* @param workflowId the workflow id
Expand Down Expand Up @@ -110,25 +91,36 @@ void retryableGetMlTask(
try {
logger.info(workflowStep + " successful for {} and modelId {}", workflowId, response.getModelId());
String resourceName = WorkflowResources.getResourceByWorkflowStep(getName());
String id;
if (getName().equals(WorkflowResources.DEPLOY_MODEL.getWorkflowStep())) {
completeFuture(resourceName, nodeId, workflowId, response, future);
id = response.getModelId();
} else {
flowFrameworkIndicesHandler.updateResourceInStateIndex(
workflowId,
nodeId,
getName(),
response.getTaskId(),
ActionListener.wrap(updateResponse -> {
logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex());
completeFuture(resourceName, nodeId, workflowId, response, future);
}, exception -> {
logger.error("Failed to update new created resource", exception);
future.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
})
);
id = response.getTaskId();
}
flowFrameworkIndicesHandler.updateResourceInStateIndex(
workflowId,
nodeId,
getName(),
id,
ActionListener.wrap(updateResponse -> {
logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex());
future.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(resourceName, response.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, response.getState().name())
),
workflowId,
nodeId
)
);
}, exception -> {
logger.error("Failed to update new created resource", exception);
future.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
})
);
} catch (Exception e) {
logger.error("Failed to parse and update new created resource", e);
future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ public CompletableFuture<WorkflowData> execute(
Map<String, String> previousNodeInputs
) throws IOException {

String workflowId = currentNodeInputs.getWorkflowId();

CompletableFuture<WorkflowData> registerAgentModelFuture = new CompletableFuture<>();

ActionListener<MLRegisterAgentResponse> actionListener = new ActionListener<>() {
Expand All @@ -92,7 +94,7 @@ public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) {
String resourceName = WorkflowResources.getResourceByWorkflowStep(getName());
logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId());
flowFrameworkIndicesHandler.updateResourceInStateIndex(
currentNodeInputs.getWorkflowId(),
workflowId,
currentNodeId,
getName(),
mlRegisterAgentResponse.getAgentId(),
Expand All @@ -101,7 +103,7 @@ public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) {
registerAgentModelFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(resourceName, mlRegisterAgentResponse.getAgentId())),
currentNodeInputs.getWorkflowId(),
workflowId,
currentNodeId
)
);
Expand Down Expand Up @@ -168,12 +170,15 @@ public void onFailure(Exception e) {
// Case when modelId is not present at all
if (llmModelId == null) {
registerAgentModelFuture.completeExceptionally(
new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST)
new FlowFrameworkException(
"llm model id is not provided for workflow: " + workflowId + " on node: " + currentNodeId,
RestStatus.BAD_REQUEST
)
);
return registerAgentModelFuture;
}

LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters);
LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters, workflowId, currentNodeId);

MLAgentBuilder builder = MLAgent.builder().name(name);

Expand Down Expand Up @@ -246,9 +251,12 @@ private String getLlmModelId(Map<String, String> previousNodeInputs, Map<String,
return llmModelId;
}

private LLMSpec getLLMSpec(String llmModelId, Map<String, String> llmParameters) {
private LLMSpec getLLMSpec(String llmModelId, Map<String, String> llmParameters, String workflowId, String currentNodeId) {
if (llmModelId == null) {
throw new FlowFrameworkException("model id for llm is null", RestStatus.BAD_REQUEST);
throw new FlowFrameworkException(
"model id for llm is null for workflow: " + workflowId + " on node: " + currentNodeId,
RestStatus.BAD_REQUEST
);
}
LLMSpec.LLMSpecBuilder builder = LLMSpec.builder();
builder.modelId(llmModelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
Expand All @@ -32,6 +34,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS;
import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE;
import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_FIELD;
import static org.opensearch.flowframework.model.WorkflowNode.USER_INPUTS_FIELD;
Expand All @@ -45,16 +48,26 @@ public class WorkflowProcessSorter {

private WorkflowStepFactory workflowStepFactory;
private ThreadPool threadPool;
private Integer maxWorkflowSteps;

/**
* Instantiate this class.
*
* @param workflowStepFactory The factory which matches template step types to instances.
* @param threadPool The OpenSearch Thread pool to pass to process nodes.
* @param clusterService The OpenSearch cluster service.
* @param settings OpenSerch settings
*/
public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool threadPool) {
public WorkflowProcessSorter(
WorkflowStepFactory workflowStepFactory,
ThreadPool threadPool,
ClusterService clusterService,
Settings settings
) {
this.workflowStepFactory = workflowStepFactory;
this.threadPool = threadPool;
this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it);
}

/**
Expand All @@ -64,6 +77,20 @@ public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool
* @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list.
*/
public List<ProcessNode> sortProcessNodes(Workflow workflow, String workflowId) {
if (workflow.nodes().size() > this.maxWorkflowSteps) {
throw new FlowFrameworkException(
"Workflow "
+ workflowId
+ " has "
+ workflow.nodes().size()
+ " nodes, which exceeds the maximum of "
+ this.maxWorkflowSteps
+ ". Change the setting ["
+ MAX_WORKFLOW_STEPS.getKey()
+ "] to increase this.",
RestStatus.BAD_REQUEST
);
}
List<WorkflowNode> sortedNodes = topologicalSort(workflow.nodes(), workflow.edges());

List<ProcessNode> nodes = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -62,7 +63,7 @@ public void setUp() throws Exception {

final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
).collect(Collectors.toSet());
clusterSettings = new ClusterSettings(settings, settingsSet);
clusterService = mock(ClusterService.class);
Expand All @@ -84,7 +85,7 @@ public void testPlugin() throws IOException {
assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size());
assertEquals(4, ffp.getActions().size());
assertEquals(1, ffp.getExecutorBuilders(settings).size());
assertEquals(4, ffp.getSettings().size());
assertEquals(5, ffp.getSettings().size());
}
}
}
Loading

0 comments on commit 090b16b

Please sign in to comment.