Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature/agent_framework] Fetches modelID for RegisterAgent and Tools workflow steps #235

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ dependencies {
implementation "org.opensearch:common-utils:${common_utils_version}"
implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1'
implementation 'org.bouncycastle:bcprov-jdk18on:1.77'
implementation "com.google.code.gson:gson:2.10.1"

configurations.all {
resolutionStrategy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,6 @@ private CommonValue() {}
public static final String WORKFLOW_STEP_NAME = "workflow_step_name";
/** The field name for the step ID where a resource is created */
public static final String WORKFLOW_STEP_ID = "workflow_step_id";
/** LLM Name for registering an agent */
public static final String LLM_FIELD = "llm";
/** The tools' field for an agent */
public static final String TOOLS_FIELD = "tools";
/** The memory field for an agent */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowStep;
import org.opensearch.ml.common.agent.LLMSpec;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -25,10 +24,7 @@
import java.util.Objects;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.buildLLMMap;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
import static org.opensearch.flowframework.util.ParseUtils.parseLLM;
import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap;

/**
Expand Down Expand Up @@ -101,12 +97,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
}
xContentBuilder.endArray();
} else if (e.getValue() instanceof LLMSpec) {
if (LLM_FIELD.equals(e.getKey())) {
xContentBuilder.startObject();
buildLLMMap(xContentBuilder, (LLMSpec) e.getValue());
xContentBuilder.endObject();
}
}
}
xContentBuilder.endObject();
Expand Down Expand Up @@ -150,11 +140,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
userInputs.put(inputFieldName, parser.text());
break;
case START_OBJECT:
if (LLM_FIELD.equals(inputFieldName)) {
userInputs.put(inputFieldName, parseLLM(parser));
} else {
userInputs.put(inputFieldName, parseStringToStringMap(parser));
}
userInputs.put(inputFieldName, parseStringToStringMap(parser));
break;
case START_ARRAY:
if (PROCESSORS_FIELD.equals(inputFieldName)) {
Expand Down
67 changes: 0 additions & 67 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
*/
package org.opensearch.flowframework.util;

import com.google.gson.Gson;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
Expand All @@ -25,9 +24,6 @@
import org.opensearch.ml.common.agent.LLMSpec;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
Expand All @@ -36,20 +32,13 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

/**
* Utility methods for Template parsing
*/
public class ParseUtils {
private static final Logger logger = LogManager.getLogger(ParseUtils.class);

public static final Gson gson;

static {
gson = new Gson();
}

private ParseUtils() {}

/**
Expand Down Expand Up @@ -117,37 +106,6 @@ public static Map<String, String> parseStringToStringMap(XContentParser parser)
return map;
}

// TODO Figure out a way to use the parse method of LLMSpec of ml-commons
/**
* Parses an XContent object representing the object of LLMSpec
* @param parser An XContent parser whose position is at the start of the map object to parse
* @return instance of {@link org.opensearch.ml.common.agent.LLMSpec}
* @throws IOException parsing error
*/
public static LLMSpec parseLLM(XContentParser parser) throws IOException {
String modelId = null;
Map<String, String> parameters = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case MODEL_ID:
modelId = parser.text();
break;
case PARAMETERS_FIELD:
parameters = getParameterMap(parser.map());
break;
default:
parser.skipChildren();
break;
}
}
return LLMSpec.builder().modelId(modelId).parameters(parameters).build();
}

/**
* Parse content parser to {@link java.time.Instant}.
*
Expand Down Expand Up @@ -176,31 +134,6 @@ public static User getUserContext(Client client) {
return User.parse(userStr);
}

/**
* Generates a parameter map required when the parameter is nested within an object
* @param parameterObjs parameters
* @return a parameters map of type String,String
*/
public static Map<String, String> getParameterMap(Map<String, ?> parameterObjs) {
Map<String, String> parameters = new HashMap<>();
for (String key : parameterObjs.keySet()) {
Object value = parameterObjs.get(key);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
if (value instanceof String) {
parameters.put(key, (String) value);
} else {
parameters.put(key, gson.toJson(value));
}
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
}
return parameters;
}

/**
* Creates a XContentParser from a given Registry
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

Expand All @@ -37,8 +38,8 @@
import static org.opensearch.flowframework.common.CommonValue.CREATED_TIME;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.LAST_UPDATED_TIME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD;
import static org.opensearch.flowframework.common.CommonValue.MEMORY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
Expand All @@ -56,6 +57,9 @@ public class RegisterAgentStep implements WorkflowStep {

static final String NAME = "register_agent";

private static final String LLM_MODEL_ID = "llm.model_id";
private static final String LLM_PARAMETERS = "llm.parameters";

private List<MLToolSpec> mlToolSpecList;

/**
Expand All @@ -80,7 +84,7 @@ public CompletableFuture<WorkflowData> execute(
ActionListener<MLRegisterAgentResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) {
logger.info("Remote Agent registration successful");
logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId());
registerAgentModelFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(AGENT_ID, mlRegisterAgentResponse.getAgentId())),
Expand All @@ -100,7 +104,8 @@ public void onFailure(Exception e) {
String name = null;
String type = null;
String description = null;
LLMSpec llm = null;
String llmModelId = null;
Map<String, String> llmParameters = Collections.emptyMap();
List<MLToolSpec> tools = new ArrayList<>();
Map<String, String> parameters = Collections.emptyMap();
MLMemorySpec memory = null;
Expand Down Expand Up @@ -128,8 +133,11 @@ public void onFailure(Exception e) {
case TYPE:
type = (String) entry.getValue();
break;
case LLM_FIELD:
llm = getLLMSpec(entry.getValue());
case LLM_MODEL_ID:
llmModelId = (String) entry.getValue();
break;
case LLM_PARAMETERS:
llmParameters = getStringToStringMap(entry.getValue(), LLM_PARAMETERS);
break;
case TOOLS_FIELD:
tools = addTools(entry.getValue());
Expand All @@ -155,15 +163,30 @@ public void onFailure(Exception e) {
}
}

if (Stream.of(name, type, llm, tools, parameters, memory, appType).allMatch(x -> x != null)) {
// Case when modelId is present in previous node inputs
if (llmModelId == null) {
llmModelId = getLlmModelId(previousNodeInputs, outputs);
}

// Case when modelId is not present at all
if (llmModelId == null) {
registerAgentModelFuture.completeExceptionally(
joshpalis marked this conversation as resolved.
Show resolved Hide resolved
new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST)
);
return registerAgentModelFuture;
}

LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters);

if (Stream.of(name, type, llmSpec).allMatch(x -> x != null)) {
MLAgentBuilder builder = MLAgent.builder().name(name);

if (description != null) {
builder.description(description);
}

builder.type(type)
.llm(llm)
.llm(llmSpec)
.tools(tools)
.parameters(parameters)
.memory(memory)
Expand Down Expand Up @@ -195,11 +218,38 @@ private List<MLToolSpec> addTools(Object tools) {
return mlToolSpecList;
}

private LLMSpec getLLMSpec(Object llm) {
if (llm instanceof LLMSpec) {
return (LLMSpec) llm;
private String getLlmModelId(Map<String, String> previousNodeInputs, Map<String, WorkflowData> outputs) {
// Case when modelId is already pass in the template
String llmModelId = null;

// Case when modelId is passed through previousSteps
Optional<String> previousNode = previousNodeInputs.entrySet()
.stream()
.filter(e -> MODEL_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();

if (previousNode.isPresent()) {
WorkflowData previousNodeOutput = outputs.get(previousNode.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) {
llmModelId = previousNodeOutput.getContent().get(MODEL_ID).toString();
}
}
throw new IllegalArgumentException("[" + LLM_FIELD + "] must be of type LLMSpec.");
return llmModelId;
}

private LLMSpec getLLMSpec(String llmModelId, Map<String, String> llmParameters) {
if (llmModelId == null) {
throw new IllegalArgumentException("model id for llm is null");
}
LLMSpec.LLMSpecBuilder builder = LLMSpec.builder();
builder.modelId(llmModelId);
if (llmParameters != null) {
builder.parameters(llmParameters);
}

LLMSpec llmSpec = builder.build();
return llmSpec;
}

private MLMemorySpec getMLMemorySpec(Object mlMemory) {
Expand Down Expand Up @@ -231,11 +281,4 @@ private MLMemorySpec getMLMemorySpec(Object mlMemory) {

}

private Instant getInstant(Object instant, String fieldName) {
if (instant instanceof Instant) {
return (Instant) instant;
}
throw new IllegalArgumentException("[" + fieldName + "] must be of type Instant.");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TYPE;
import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap;

/**
* Step to register a tool for an agent
Expand Down Expand Up @@ -64,19 +65,19 @@ public CompletableFuture<WorkflowData> execute(
for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case TYPE:
type = (String) content.get(TYPE);
type = (String) entry.getValue();
break;
case NAME_FIELD:
name = (String) content.get(NAME_FIELD);
name = (String) entry.getValue();
break;
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
description = (String) entry.getValue();
break;
case PARAMETERS_FIELD:
parameters = getStringToStringMap(content.get(PARAMETERS_FIELD), PARAMETERS_FIELD);
parameters = getToolsParametersMap(entry.getValue(), previousNodeInputs, outputs);
break;
case INCLUDE_OUTPUT_IN_AGENT_RESPONSE:
includeOutputInAgentResponse = (Boolean) content.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE);
includeOutputInAgentResponse = (Boolean) entry.getValue();
break;
default:
break;
Expand Down Expand Up @@ -124,4 +125,28 @@ public CompletableFuture<WorkflowData> execute(
public String getName() {
return NAME;
}

private Map<String, String> getToolsParametersMap(
Object parameters,
Map<String, String> previousNodeInputs,
Map<String, WorkflowData> outputs
) {
Map<String, String> parametersMap = (Map<String, String>) parameters;
Optional<String> previousNode = previousNodeInputs.entrySet()
.stream()
.filter(e -> MODEL_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();
// Case when modelId is passed through previousSteps and not present already in parameters
if (previousNode.isPresent() && !parametersMap.containsKey(MODEL_ID)) {
WorkflowData previousNodeOutput = outputs.get(previousNode.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) {
parametersMap.put(MODEL_ID, previousNodeOutput.getContent().get(MODEL_ID).toString());
return parametersMap;
}
}
// For other cases where modelId is already present in the parameters or not return the parametersMap
return parametersMap;
}

}
1 change: 0 additions & 1 deletion src/main/resources/mappings/workflow-steps.json
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@
"inputs":[
"name",
"type",
"llm",
"tools",
"parameters",
"memory",
Expand Down
Loading
Loading