Skip to content

Commit

Permalink
[Feature/agent_framework] Fetches modelID for RegisterAgent and Tools…
Browse files Browse the repository at this point in the history
… workflow steps (#235)

* Flattened llm field of register agent

Signed-off-by: Owais Kazi <[email protected]>

* Handled optional modelId

Signed-off-by: Owais Kazi <[email protected]>

* Handled modelId for llm

Signed-off-by: Owais Kazi <[email protected]>

* Parsing for parameters field of tools

Signed-off-by: Owais Kazi <[email protected]>

* Handled test case failures

Signed-off-by: Owais Kazi <[email protected]>

* Fixed spotless failure

Signed-off-by: Owais Kazi <[email protected]>

---------

Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 authored and dbwiddis committed Dec 18, 2023
1 parent ffc0aa4 commit b7f638b
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 144 deletions.
1 change: 0 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ dependencies {
implementation "org.opensearch:common-utils:${common_utils_version}"
implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1'
implementation 'org.bouncycastle:bcprov-jdk18on:1.77'
implementation "com.google.code.gson:gson:2.10.1"

// ZipArchive dependencies used for integration tests
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,6 @@ private CommonValue() {}
public static final String WORKFLOW_STEP_NAME = "workflow_step_name";
/** The field name for the step ID where a resource is created */
public static final String WORKFLOW_STEP_ID = "workflow_step_id";
/** LLM Name for registering an agent */
public static final String LLM_FIELD = "llm";
/** The tools' field for an agent */
public static final String TOOLS_FIELD = "tools";
/** The memory field for an agent */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowStep;
import org.opensearch.ml.common.agent.LLMSpec;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -25,10 +24,7 @@
import java.util.Objects;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.buildLLMMap;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
import static org.opensearch.flowframework.util.ParseUtils.parseLLM;
import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap;

/**
Expand Down Expand Up @@ -101,12 +97,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
}
xContentBuilder.endArray();
} else if (e.getValue() instanceof LLMSpec) {
if (LLM_FIELD.equals(e.getKey())) {
xContentBuilder.startObject();
buildLLMMap(xContentBuilder, (LLMSpec) e.getValue());
xContentBuilder.endObject();
}
}
}
xContentBuilder.endObject();
Expand Down Expand Up @@ -150,11 +140,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
userInputs.put(inputFieldName, parser.text());
break;
case START_OBJECT:
if (LLM_FIELD.equals(inputFieldName)) {
userInputs.put(inputFieldName, parseLLM(parser));
} else {
userInputs.put(inputFieldName, parseStringToStringMap(parser));
}
userInputs.put(inputFieldName, parseStringToStringMap(parser));
break;
case START_ARRAY:
if (PROCESSORS_FIELD.equals(inputFieldName)) {
Expand Down
67 changes: 0 additions & 67 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
*/
package org.opensearch.flowframework.util;

import com.google.gson.Gson;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
Expand All @@ -25,9 +24,6 @@
import org.opensearch.ml.common.agent.LLMSpec;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
Expand All @@ -36,20 +32,13 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

/**
* Utility methods for Template parsing
*/
public class ParseUtils {
private static final Logger logger = LogManager.getLogger(ParseUtils.class);

public static final Gson gson;

static {
gson = new Gson();
}

private ParseUtils() {}

/**
Expand Down Expand Up @@ -117,37 +106,6 @@ public static Map<String, String> parseStringToStringMap(XContentParser parser)
return map;
}

// TODO Figure out a way to use the parse method of LLMSpec of ml-commons
/**
* Parses an XContent object representing the object of LLMSpec
* @param parser An XContent parser whose position is at the start of the map object to parse
* @return instance of {@link org.opensearch.ml.common.agent.LLMSpec}
* @throws IOException parsing error
*/
public static LLMSpec parseLLM(XContentParser parser) throws IOException {
String modelId = null;
Map<String, String> parameters = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case MODEL_ID:
modelId = parser.text();
break;
case PARAMETERS_FIELD:
parameters = getParameterMap(parser.map());
break;
default:
parser.skipChildren();
break;
}
}
return LLMSpec.builder().modelId(modelId).parameters(parameters).build();
}

/**
* Parse content parser to {@link java.time.Instant}.
*
Expand Down Expand Up @@ -176,31 +134,6 @@ public static User getUserContext(Client client) {
return User.parse(userStr);
}

/**
* Generates a parameter map required when the parameter is nested within an object
* @param parameterObjs parameters
* @return a parameters map of type String,String
*/
public static Map<String, String> getParameterMap(Map<String, ?> parameterObjs) {
Map<String, String> parameters = new HashMap<>();
for (String key : parameterObjs.keySet()) {
Object value = parameterObjs.get(key);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
if (value instanceof String) {
parameters.put(key, (String) value);
} else {
parameters.put(key, gson.toJson(value));
}
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
}
return parameters;
}

/**
* Creates a XContentParser from a given Registry
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

Expand All @@ -37,8 +38,8 @@
import static org.opensearch.flowframework.common.CommonValue.CREATED_TIME;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.LAST_UPDATED_TIME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD;
import static org.opensearch.flowframework.common.CommonValue.MEMORY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
Expand All @@ -56,6 +57,9 @@ public class RegisterAgentStep implements WorkflowStep {

static final String NAME = "register_agent";

private static final String LLM_MODEL_ID = "llm.model_id";
private static final String LLM_PARAMETERS = "llm.parameters";

private List<MLToolSpec> mlToolSpecList;

/**
Expand All @@ -80,7 +84,7 @@ public CompletableFuture<WorkflowData> execute(
ActionListener<MLRegisterAgentResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) {
logger.info("Remote Agent registration successful");
logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId());
registerAgentModelFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(AGENT_ID, mlRegisterAgentResponse.getAgentId())),
Expand All @@ -100,7 +104,8 @@ public void onFailure(Exception e) {
String name = null;
String type = null;
String description = null;
LLMSpec llm = null;
String llmModelId = null;
Map<String, String> llmParameters = Collections.emptyMap();
List<MLToolSpec> tools = new ArrayList<>();
Map<String, String> parameters = Collections.emptyMap();
MLMemorySpec memory = null;
Expand Down Expand Up @@ -128,8 +133,11 @@ public void onFailure(Exception e) {
case TYPE:
type = (String) entry.getValue();
break;
case LLM_FIELD:
llm = getLLMSpec(entry.getValue());
case LLM_MODEL_ID:
llmModelId = (String) entry.getValue();
break;
case LLM_PARAMETERS:
llmParameters = getStringToStringMap(entry.getValue(), LLM_PARAMETERS);
break;
case TOOLS_FIELD:
tools = addTools(entry.getValue());
Expand All @@ -155,15 +163,30 @@ public void onFailure(Exception e) {
}
}

if (Stream.of(name, type, llm, tools, parameters, memory, appType).allMatch(x -> x != null)) {
// Case when modelId is present in previous node inputs
if (llmModelId == null) {
llmModelId = getLlmModelId(previousNodeInputs, outputs);
}

// Case when modelId is not present at all
if (llmModelId == null) {
registerAgentModelFuture.completeExceptionally(
new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST)
);
return registerAgentModelFuture;
}

LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters);

if (Stream.of(name, type, llmSpec).allMatch(x -> x != null)) {
MLAgentBuilder builder = MLAgent.builder().name(name);

if (description != null) {
builder.description(description);
}

builder.type(type)
.llm(llm)
.llm(llmSpec)
.tools(tools)
.parameters(parameters)
.memory(memory)
Expand Down Expand Up @@ -195,11 +218,38 @@ private List<MLToolSpec> addTools(Object tools) {
return mlToolSpecList;
}

private LLMSpec getLLMSpec(Object llm) {
if (llm instanceof LLMSpec) {
return (LLMSpec) llm;
private String getLlmModelId(Map<String, String> previousNodeInputs, Map<String, WorkflowData> outputs) {
// Case when modelId is already pass in the template
String llmModelId = null;

// Case when modelId is passed through previousSteps
Optional<String> previousNode = previousNodeInputs.entrySet()
.stream()
.filter(e -> MODEL_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();

if (previousNode.isPresent()) {
WorkflowData previousNodeOutput = outputs.get(previousNode.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) {
llmModelId = previousNodeOutput.getContent().get(MODEL_ID).toString();
}
}
throw new IllegalArgumentException("[" + LLM_FIELD + "] must be of type LLMSpec.");
return llmModelId;
}

private LLMSpec getLLMSpec(String llmModelId, Map<String, String> llmParameters) {
if (llmModelId == null) {
throw new IllegalArgumentException("model id for llm is null");
}
LLMSpec.LLMSpecBuilder builder = LLMSpec.builder();
builder.modelId(llmModelId);
if (llmParameters != null) {
builder.parameters(llmParameters);
}

LLMSpec llmSpec = builder.build();
return llmSpec;
}

private MLMemorySpec getMLMemorySpec(Object mlMemory) {
Expand Down Expand Up @@ -231,11 +281,4 @@ private MLMemorySpec getMLMemorySpec(Object mlMemory) {

}

private Instant getInstant(Object instant, String fieldName) {
if (instant instanceof Instant) {
return (Instant) instant;
}
throw new IllegalArgumentException("[" + fieldName + "] must be of type Instant.");
}

}
37 changes: 31 additions & 6 deletions src/main/java/org/opensearch/flowframework/workflow/ToolStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TYPE;
import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap;

/**
* Step to register a tool for an agent
Expand Down Expand Up @@ -64,19 +65,19 @@ public CompletableFuture<WorkflowData> execute(
for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case TYPE:
type = (String) content.get(TYPE);
type = (String) entry.getValue();
break;
case NAME_FIELD:
name = (String) content.get(NAME_FIELD);
name = (String) entry.getValue();
break;
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
description = (String) entry.getValue();
break;
case PARAMETERS_FIELD:
parameters = getStringToStringMap(content.get(PARAMETERS_FIELD), PARAMETERS_FIELD);
parameters = getToolsParametersMap(entry.getValue(), previousNodeInputs, outputs);
break;
case INCLUDE_OUTPUT_IN_AGENT_RESPONSE:
includeOutputInAgentResponse = (Boolean) content.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE);
includeOutputInAgentResponse = (Boolean) entry.getValue();
break;
default:
break;
Expand Down Expand Up @@ -124,4 +125,28 @@ public CompletableFuture<WorkflowData> execute(
public String getName() {
return NAME;
}

private Map<String, String> getToolsParametersMap(
Object parameters,
Map<String, String> previousNodeInputs,
Map<String, WorkflowData> outputs
) {
Map<String, String> parametersMap = (Map<String, String>) parameters;
Optional<String> previousNode = previousNodeInputs.entrySet()
.stream()
.filter(e -> MODEL_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();
// Case when modelId is passed through previousSteps and not present already in parameters
if (previousNode.isPresent() && !parametersMap.containsKey(MODEL_ID)) {
WorkflowData previousNodeOutput = outputs.get(previousNode.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) {
parametersMap.put(MODEL_ID, previousNodeOutput.getContent().get(MODEL_ID).toString());
return parametersMap;
}
}
// For other cases where modelId is already present in the parameters or not return the parametersMap
return parametersMap;
}

}
1 change: 0 additions & 1 deletion src/main/resources/mappings/workflow-steps.json
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@
"inputs":[
"name",
"type",
"llm",
"tools",
"parameters",
"memory",
Expand Down
Loading

0 comments on commit b7f638b

Please sign in to comment.