Skip to content

Commit

Permalink
Handled test case failures
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Dec 1, 2023
1 parent e7465cc commit 0fc58b1
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,6 @@ private CommonValue() {}
public static final String RESOURCE_ID_FIELD = "resource_id";
/** The field name for the ResourceCreated's resource name */
public static final String WORKFLOW_STEP_NAME = "workflow_step_name";
/** LLM Name for registering an agent */
public static final String LLM_FIELD = "llm";
/** The tools' field for an agent */
public static final String TOOLS_FIELD = "tools";
/** The memory field for an agent */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowStep;
import org.opensearch.ml.common.agent.LLMSpec;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -25,8 +24,6 @@
import java.util.Objects;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.buildLLMMap;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap;

Expand Down Expand Up @@ -100,12 +97,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
}
xContentBuilder.endArray();
} else if (e.getValue() instanceof LLMSpec) {
if (LLM_FIELD.equals(e.getKey())) {
xContentBuilder.startObject();
buildLLMMap(xContentBuilder, (LLMSpec) e.getValue());
xContentBuilder.endObject();
}
}
}
xContentBuilder.endObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,26 @@

import java.io.IOException;
import java.time.Instant;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.*;
import static org.opensearch.flowframework.common.CommonValue.AGENT_ID;
import static org.opensearch.flowframework.common.CommonValue.APP_TYPE_FIELD;
import static org.opensearch.flowframework.common.CommonValue.CREATED_TIME;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.LAST_UPDATED_TIME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.MEMORY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TYPE;
import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap;

/**
Expand Down Expand Up @@ -70,7 +84,7 @@ public CompletableFuture<WorkflowData> execute(
ActionListener<MLRegisterAgentResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) {
logger.info("Agent registration successful");
logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId());
registerAgentModelFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(AGENT_ID, mlRegisterAgentResponse.getAgentId())),
Expand Down Expand Up @@ -149,13 +163,17 @@ public void onFailure(Exception e) {
}
}

// Case when modelId can is present in previous node inputs
// Case when modelId is present in previous node inputs
if (llmModelId == null) {
llmModelId = getLlmModelId(previousNodeInputs, outputs);
} else {
}

// Case when modelId is not present at all
if (llmModelId == null) {
registerAgentModelFuture.completeExceptionally(
new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST)
);
return registerAgentModelFuture;
}

LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters);
Expand Down Expand Up @@ -263,11 +281,4 @@ private MLMemorySpec getMLMemorySpec(Object mlMemory) {

}

private Instant getInstant(Object instant, String fieldName) {
if (instant instanceof Instant) {
return (Instant) instant;
}
throw new IllegalArgumentException("[" + fieldName + "] must be of type Instant.");
}

}
14 changes: 12 additions & 2 deletions src/main/java/org/opensearch/flowframework/workflow/ToolStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,21 @@
import org.opensearch.ml.common.agent.MLToolSpec;

import java.io.IOException;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.*;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TYPE;

/**
* Step to register a tool for an agent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@
*/
package org.opensearch.flowframework.model;

import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

public class WorkflowNodeTests extends OpenSearchTestCase {
Expand All @@ -23,11 +21,6 @@ public void setUp() throws Exception {
}

public void testNode() throws IOException {
Map<String, String> parameters = new HashMap<>();
parameters.put("stop", "true");
parameters.put("max", "5");

LLMSpec llmSpec = new LLMSpec("modelId", parameters);

WorkflowNode nodeA = new WorkflowNode(
"A",
Expand All @@ -38,7 +31,6 @@ public void testNode() throws IOException {
Map.entry("bar", Map.of("key", "value")),
Map.entry("baz", new Map<?, ?>[] { Map.of("A", "a"), Map.of("B", "b") }),
Map.entry("processors", new PipelineProcessor[] { new PipelineProcessor("test-type", Map.of("key2", "value2")) }),
Map.entry("llm", llmSpec),
Map.entry("created_time", 1689793598499L)
)
);
Expand Down Expand Up @@ -71,9 +63,6 @@ public void testNode() throws IOException {
assertTrue(json.contains("\"bar\":{\"key\":\"value\"}"));
assertTrue(json.contains("\"processors\":[{\"type\":\"test-type\",\"params\":{\"key2\":\"value2\"}}]"));
assertTrue(json.contains("\"created_time\":1689793598499"));
assertTrue(json.contains("llm\":{"));
assertTrue(json.contains("\"parameters\":{\"stop\":\"true\",\"max\":\"5\""));
assertTrue(json.contains("\"model_id\":\"modelId\""));

WorkflowNode nodeX = WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(json));
assertEquals("A", nodeX.id());
Expand All @@ -88,9 +77,6 @@ public void testNode() throws IOException {
assertEquals(1, ppX.length);
assertEquals("test-type", ppX[0].type());
assertEquals(Map.of("key2", "value2"), ppX[0].params());
LLMSpec llm = (LLMSpec) mapX.get("llm");
assertEquals("modelId", llm.getModelId());
assertEquals(parameters, llm.getParameters());
}

public void testExceptions() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ public void setUp() throws Exception {
Map.entry("name", "test"),
Map.entry("description", "description"),
Map.entry("type", "type"),
Map.entry("llm", llmSpec),
Map.entry("llm.model_id", "xyz"),
Map.entry("llm.parameters", Collections.emptyMap()),
Map.entry("tools", tools),
Map.entry("parameters", Collections.emptyMap()),
Map.entry("memory", mlMemorySpec),
Expand Down

0 comments on commit 0fc58b1

Please sign in to comment.