From b7f638b8c6697199e0d4a23ca3c9e147f12a9baa Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Fri, 1 Dec 2023 16:53:47 -0800 Subject: [PATCH] [Feature/agent_framework] Fetches modelID for RegisterAgent and Tools workflow steps (#235) * Flattened llm field of register agent Signed-off-by: Owais Kazi * Handled optional modelId Signed-off-by: Owais Kazi * Handled modelId for llm Signed-off-by: Owais Kazi * Parsing for parameters field of tools Signed-off-by: Owais Kazi * Handled test case failures Signed-off-by: Owais Kazi * Fixed spotless failure Signed-off-by: Owais Kazi --------- Signed-off-by: Owais Kazi --- build.gradle | 1 - .../flowframework/common/CommonValue.java | 2 - .../flowframework/model/WorkflowNode.java | 16 +--- .../flowframework/util/ParseUtils.java | 67 ---------------- .../workflow/RegisterAgentStep.java | 79 ++++++++++++++----- .../flowframework/workflow/ToolStep.java | 37 +++++++-- .../resources/mappings/workflow-steps.json | 1 - .../model/WorkflowNodeTests.java | 14 ---- .../flowframework/util/ParseUtilsTests.java | 19 ----- .../workflow/RegisterAgentTests.java | 3 +- 10 files changed, 95 insertions(+), 144 deletions(-) diff --git a/build.gradle b/build.gradle index d6e7afbd8..20a9837ca 100644 --- a/build.gradle +++ b/build.gradle @@ -155,7 +155,6 @@ dependencies { implementation "org.opensearch:common-utils:${common_utils_version}" implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1' implementation 'org.bouncycastle:bcprov-jdk18on:1.77' - implementation "com.google.code.gson:gson:2.10.1" // ZipArchive dependencies used for integration tests zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 774660bfd..2343cd305 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -166,8 +166,6 @@ private CommonValue() {} public static final String WORKFLOW_STEP_NAME = "workflow_step_name"; /** The field name for the step ID where a resource is created */ public static final String WORKFLOW_STEP_ID = "workflow_step_id"; - /** LLM Name for registering an agent */ - public static final String LLM_FIELD = "llm"; /** The tools' field for an agent */ public static final String TOOLS_FIELD = "tools"; /** The memory field for an agent */ diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 999ba460f..b942ccb16 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -14,7 +14,6 @@ import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowStep; -import org.opensearch.ml.common.agent.LLMSpec; import java.io.IOException; import java.util.ArrayList; @@ -25,10 +24,7 @@ import java.util.Objects; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD; -import static org.opensearch.flowframework.util.ParseUtils.buildLLMMap; import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; -import static org.opensearch.flowframework.util.ParseUtils.parseLLM; import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap; /** @@ -101,12 +97,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } } xContentBuilder.endArray(); - } else if (e.getValue() instanceof LLMSpec) { - if (LLM_FIELD.equals(e.getKey())) { - xContentBuilder.startObject(); - buildLLMMap(xContentBuilder, (LLMSpec) e.getValue()); - xContentBuilder.endObject(); - } } } xContentBuilder.endObject(); @@ -150,11 +140,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { userInputs.put(inputFieldName, parser.text()); break; case START_OBJECT: - if (LLM_FIELD.equals(inputFieldName)) { - userInputs.put(inputFieldName, parseLLM(parser)); - } else { - userInputs.put(inputFieldName, parseStringToStringMap(parser)); - } + userInputs.put(inputFieldName, parseStringToStringMap(parser)); break; case START_ARRAY: if (PROCESSORS_FIELD.equals(inputFieldName)) { diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 14d113e34..e22017eaf 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -8,7 +8,6 @@ */ package org.opensearch.flowframework.util; -import com.google.gson.Gson; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; @@ -25,9 +24,6 @@ import org.opensearch.ml.common.agent.LLMSpec; import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; import java.time.Instant; import java.util.HashMap; import java.util.Map; @@ -36,7 +32,6 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; -import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; /** * Utility methods for Template parsing @@ -44,12 +39,6 @@ public class ParseUtils { private static final Logger logger = LogManager.getLogger(ParseUtils.class); - public static final Gson gson; - - static { - gson = new Gson(); - } - private ParseUtils() {} /** @@ -117,37 +106,6 @@ public static Map parseStringToStringMap(XContentParser parser) return map; } - // TODO Figure out a way to use the parse method of LLMSpec of ml-commons - /** - * Parses an XContent object representing the object of LLMSpec - * @param parser An XContent parser whose position is at the start of the map object to parse - * @return instance of {@link org.opensearch.ml.common.agent.LLMSpec} - * @throws IOException parsing error - */ - public static LLMSpec parseLLM(XContentParser parser) throws IOException { - String modelId = null; - Map parameters = null; - - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - parser.nextToken(); - - switch (fieldName) { - case MODEL_ID: - modelId = parser.text(); - break; - case PARAMETERS_FIELD: - parameters = getParameterMap(parser.map()); - break; - default: - parser.skipChildren(); - break; - } - } - return LLMSpec.builder().modelId(modelId).parameters(parameters).build(); - } - /** * Parse content parser to {@link java.time.Instant}. * @@ -176,31 +134,6 @@ public static User getUserContext(Client client) { return User.parse(userStr); } - /** - * Generates a parameter map required when the parameter is nested within an object - * @param parameterObjs parameters - * @return a parameters map of type String,String - */ - public static Map getParameterMap(Map parameterObjs) { - Map parameters = new HashMap<>(); - for (String key : parameterObjs.keySet()) { - Object value = parameterObjs.get(key); - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - if (value instanceof String) { - parameters.put(key, (String) value); - } else { - parameters.put(key, gson.toJson(value)); - } - return null; - }); - } catch (PrivilegedActionException e) { - throw new RuntimeException(e); - } - } - return parameters; - } - /** * Creates a XContentParser from a given Registry * diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index 44270d8e6..022d46c22 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; @@ -37,8 +38,8 @@ import static org.opensearch.flowframework.common.CommonValue.CREATED_TIME; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.LAST_UPDATED_TIME_FIELD; -import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD; import static org.opensearch.flowframework.common.CommonValue.MEMORY_FIELD; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; @@ -56,6 +57,9 @@ public class RegisterAgentStep implements WorkflowStep { static final String NAME = "register_agent"; + private static final String LLM_MODEL_ID = "llm.model_id"; + private static final String LLM_PARAMETERS = "llm.parameters"; + private List mlToolSpecList; /** @@ -80,7 +84,7 @@ public CompletableFuture execute( ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { - logger.info("Remote Agent registration successful"); + logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId()); registerAgentModelFuture.complete( new WorkflowData( Map.ofEntries(Map.entry(AGENT_ID, mlRegisterAgentResponse.getAgentId())), @@ -100,7 +104,8 @@ public void onFailure(Exception e) { String name = null; String type = null; String description = null; - LLMSpec llm = null; + String llmModelId = null; + Map llmParameters = Collections.emptyMap(); List tools = new ArrayList<>(); Map parameters = Collections.emptyMap(); MLMemorySpec memory = null; @@ -128,8 +133,11 @@ public void onFailure(Exception e) { case TYPE: type = (String) entry.getValue(); break; - case LLM_FIELD: - llm = getLLMSpec(entry.getValue()); + case LLM_MODEL_ID: + llmModelId = (String) entry.getValue(); + break; + case LLM_PARAMETERS: + llmParameters = getStringToStringMap(entry.getValue(), LLM_PARAMETERS); break; case TOOLS_FIELD: tools = addTools(entry.getValue()); @@ -155,7 +163,22 @@ public void onFailure(Exception e) { } } - if (Stream.of(name, type, llm, tools, parameters, memory, appType).allMatch(x -> x != null)) { + // Case when modelId is present in previous node inputs + if (llmModelId == null) { + llmModelId = getLlmModelId(previousNodeInputs, outputs); + } + + // Case when modelId is not present at all + if (llmModelId == null) { + registerAgentModelFuture.completeExceptionally( + new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST) + ); + return registerAgentModelFuture; + } + + LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters); + + if (Stream.of(name, type, llmSpec).allMatch(x -> x != null)) { MLAgentBuilder builder = MLAgent.builder().name(name); if (description != null) { @@ -163,7 +186,7 @@ public void onFailure(Exception e) { } builder.type(type) - .llm(llm) + .llm(llmSpec) .tools(tools) .parameters(parameters) .memory(memory) @@ -195,11 +218,38 @@ private List addTools(Object tools) { return mlToolSpecList; } - private LLMSpec getLLMSpec(Object llm) { - if (llm instanceof LLMSpec) { - return (LLMSpec) llm; + private String getLlmModelId(Map previousNodeInputs, Map outputs) { + // Case when modelId is already pass in the template + String llmModelId = null; + + // Case when modelId is passed through previousSteps + Optional previousNode = previousNodeInputs.entrySet() + .stream() + .filter(e -> MODEL_ID.equals(e.getValue())) + .map(Map.Entry::getKey) + .findFirst(); + + if (previousNode.isPresent()) { + WorkflowData previousNodeOutput = outputs.get(previousNode.get()); + if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) { + llmModelId = previousNodeOutput.getContent().get(MODEL_ID).toString(); + } } - throw new IllegalArgumentException("[" + LLM_FIELD + "] must be of type LLMSpec."); + return llmModelId; + } + + private LLMSpec getLLMSpec(String llmModelId, Map llmParameters) { + if (llmModelId == null) { + throw new IllegalArgumentException("model id for llm is null"); + } + LLMSpec.LLMSpecBuilder builder = LLMSpec.builder(); + builder.modelId(llmModelId); + if (llmParameters != null) { + builder.parameters(llmParameters); + } + + LLMSpec llmSpec = builder.build(); + return llmSpec; } private MLMemorySpec getMLMemorySpec(Object mlMemory) { @@ -231,11 +281,4 @@ private MLMemorySpec getMLMemorySpec(Object mlMemory) { } - private Instant getInstant(Object instant, String fieldName) { - if (instant instanceof Instant) { - return (Instant) instant; - } - throw new IllegalArgumentException("[" + fieldName + "] must be of type Instant."); - } - } diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index 339142139..af8556289 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -20,15 +20,16 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TYPE; -import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; /** * Step to register a tool for an agent @@ -64,19 +65,19 @@ public CompletableFuture execute( for (Entry entry : content.entrySet()) { switch (entry.getKey()) { case TYPE: - type = (String) content.get(TYPE); + type = (String) entry.getValue(); break; case NAME_FIELD: - name = (String) content.get(NAME_FIELD); + name = (String) entry.getValue(); break; case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); + description = (String) entry.getValue(); break; case PARAMETERS_FIELD: - parameters = getStringToStringMap(content.get(PARAMETERS_FIELD), PARAMETERS_FIELD); + parameters = getToolsParametersMap(entry.getValue(), previousNodeInputs, outputs); break; case INCLUDE_OUTPUT_IN_AGENT_RESPONSE: - includeOutputInAgentResponse = (Boolean) content.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE); + includeOutputInAgentResponse = (Boolean) entry.getValue(); break; default: break; @@ -124,4 +125,28 @@ public CompletableFuture execute( public String getName() { return NAME; } + + private Map getToolsParametersMap( + Object parameters, + Map previousNodeInputs, + Map outputs + ) { + Map parametersMap = (Map) parameters; + Optional previousNode = previousNodeInputs.entrySet() + .stream() + .filter(e -> MODEL_ID.equals(e.getValue())) + .map(Map.Entry::getKey) + .findFirst(); + // Case when modelId is passed through previousSteps and not present already in parameters + if (previousNode.isPresent() && !parametersMap.containsKey(MODEL_ID)) { + WorkflowData previousNodeOutput = outputs.get(previousNode.get()); + if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) { + parametersMap.put(MODEL_ID, previousNodeOutput.getContent().get(MODEL_ID).toString()); + return parametersMap; + } + } + // For other cases where modelId is already present in the parameters or not return the parametersMap + return parametersMap; + } + } diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index eb92ccd5e..b5d09e8cb 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -96,7 +96,6 @@ "inputs":[ "name", "type", - "llm", "tools", "parameters", "memory", diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java index c0011f7ae..b9620c214 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java @@ -8,11 +8,9 @@ */ package org.opensearch.flowframework.model; -import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; -import java.util.HashMap; import java.util.Map; public class WorkflowNodeTests extends OpenSearchTestCase { @@ -23,11 +21,6 @@ public void setUp() throws Exception { } public void testNode() throws IOException { - Map parameters = new HashMap<>(); - parameters.put("stop", "true"); - parameters.put("max", "5"); - - LLMSpec llmSpec = new LLMSpec("modelId", parameters); WorkflowNode nodeA = new WorkflowNode( "A", @@ -38,7 +31,6 @@ public void testNode() throws IOException { Map.entry("bar", Map.of("key", "value")), Map.entry("baz", new Map[] { Map.of("A", "a"), Map.of("B", "b") }), Map.entry("processors", new PipelineProcessor[] { new PipelineProcessor("test-type", Map.of("key2", "value2")) }), - Map.entry("llm", llmSpec), Map.entry("created_time", 1689793598499L) ) ); @@ -71,9 +63,6 @@ public void testNode() throws IOException { assertTrue(json.contains("\"bar\":{\"key\":\"value\"}")); assertTrue(json.contains("\"processors\":[{\"type\":\"test-type\",\"params\":{\"key2\":\"value2\"}}]")); assertTrue(json.contains("\"created_time\":1689793598499")); - assertTrue(json.contains("llm\":{")); - assertTrue(json.contains("\"parameters\":{\"stop\":\"true\",\"max\":\"5\"")); - assertTrue(json.contains("\"model_id\":\"modelId\"")); WorkflowNode nodeX = WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(json)); assertEquals("A", nodeX.id()); @@ -88,9 +77,6 @@ public void testNode() throws IOException { assertEquals(1, ppX.length); assertEquals("test-type", ppX[0].type()); assertEquals(Map.of("key2", "value2"), ppX[0].params()); - LLMSpec llm = (LLMSpec) mapX.get("llm"); - assertEquals("modelId", llm.getModelId()); - assertEquals(parameters, llm.getParameters()); } public void testExceptions() throws IOException { diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index 76334b52b..a5c4253b3 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -12,12 +12,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.OpenSearchTestCase; -import org.junit.Assert; import java.io.IOException; import java.time.Instant; -import java.util.HashMap; -import java.util.Map; public class ParseUtilsTests extends OpenSearchTestCase { public void testToInstant() throws IOException { @@ -57,20 +54,4 @@ public void testToInstantWithNotValue() throws IOException { Instant instant = ParseUtils.parseInstant(parser); assertNull(instant); } - - public void testGetParameterMap() { - Map parameters = new HashMap<>(); - parameters.put("key1", "value1"); - parameters.put("key2", 2); - parameters.put("key3", 2.1); - parameters.put("key4", new int[] { 10, 20 }); - parameters.put("key5", new Object[] { 1.01, "abc" }); - Map parameterMap = ParseUtils.getParameterMap(parameters); - Assert.assertEquals(5, parameterMap.size()); - Assert.assertEquals("value1", parameterMap.get("key1")); - Assert.assertEquals("2", parameterMap.get("key2")); - Assert.assertEquals("2.1", parameterMap.get("key3")); - Assert.assertEquals("[10,20]", parameterMap.get("key4")); - Assert.assertEquals("[1.01,\"abc\"]", parameterMap.get("key5")); - } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java index 0f4b33471..c393be5e4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java @@ -60,7 +60,8 @@ public void setUp() throws Exception { Map.entry("name", "test"), Map.entry("description", "description"), Map.entry("type", "type"), - Map.entry("llm", llmSpec), + Map.entry("llm.model_id", "xyz"), + Map.entry("llm.parameters", Collections.emptyMap()), Map.entry("tools", tools), Map.entry("parameters", Collections.emptyMap()), Map.entry("memory", mlMemorySpec),