Skip to content

Commit

Permalink
add more unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Feb 16, 2024
1 parent c361ed8 commit 9b683cd
Showing 1 changed file with 93 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.FINAL_ANSWER;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT_RESPONSE;

import java.util.Arrays;
import java.util.HashMap;
Expand Down Expand Up @@ -445,6 +446,98 @@ public void testParseLLMOutput() {
}
}

@Test
public void testParseLLMOutput_MultipleFields() {
Set<String> tools = Set.of("VectorDBTool", "CatIndexTool");
String thought = "Let me run VectorDBTool to get more information";
String toolName = "vectordbtool";
ModelTensorOutput modelTensoOutput = ModelTensorOutput
.builder()
.mlModelOutputs(
List
.of(
ModelTensors
.builder()
.mlModelTensors(
List
.of(
ModelTensor.builder().name("response").dataAsMap(Map.of(THOUGHT, thought, ACTION, toolName)).build()
)
)
.build()
)
)
.build();
Map<String, String> output = AgentUtils.parseLLMOutput(modelTensoOutput, null, tools);
Assert.assertEquals(3, output.size());
Assert.assertEquals(thought, output.get(THOUGHT));
Assert.assertEquals("VectorDBTool", output.get(ACTION));
Set<String> expected = Set
.of(
"{\"action\":\"vectordbtool\",\"thought\":\"Let me run VectorDBTool to get more information\"}",
"{\"thought\":\"Let me run VectorDBTool to get more information\",\"action\":\"vectordbtool\"}"
);
Assert.assertTrue(expected.contains(output.get(THOUGHT_RESPONSE)));
}

@Test
public void testParseLLMOutput_MultipleFields_NoActionAndFinalAnswer() {
Set<String> tools = Set.of("VectorDBTool", "CatIndexTool");
String key1 = "dummy key1";
String value1 = "dummy value1";
String key2 = "dummy key2";
String value2 = "dummy value2";
ModelTensorOutput modelTensoOutput = ModelTensorOutput
.builder()
.mlModelOutputs(
List
.of(
ModelTensors
.builder()
.mlModelTensors(
List.of(ModelTensor.builder().name("response").dataAsMap(Map.of(key1, value1, key2, value2)).build())
)
.build()
)
)
.build();
Map<String, String> output = AgentUtils.parseLLMOutput(modelTensoOutput, null, tools);
Assert.assertEquals(2, output.size());
Assert.assertFalse(output.containsKey(THOUGHT));
Assert.assertFalse(output.containsKey(ACTION));
Set<String> expected = Set
.of(
"{\"dummy key1\":\"dummy value1\",\"dummy key2\":\"dummy value2\"}",
"{\"dummy key2\":\"dummy value2\",\"dummy key1\":\"dummy value1\"}"
);
Assert.assertTrue(expected.contains(output.get(THOUGHT_RESPONSE)));
Assert.assertEquals(output.get(THOUGHT_RESPONSE), output.get(FINAL_ANSWER));
}

@Test
public void testParseLLMOutput_OneFields_NoActionAndFinalAnswer() {
Set<String> tools = Set.of("VectorDBTool", "CatIndexTool");
String thought = "Let me run VectorDBTool to get more information";
ModelTensorOutput modelTensoOutput = ModelTensorOutput
.builder()
.mlModelOutputs(
List
.of(
ModelTensors
.builder()
.mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(Map.of(THOUGHT, thought)).build()))
.build()
)
)
.build();
Map<String, String> output = AgentUtils.parseLLMOutput(modelTensoOutput, null, tools);
Assert.assertEquals(3, output.size());
Assert.assertEquals(thought, output.get(THOUGHT));
Assert.assertFalse(output.containsKey(ACTION));
Assert.assertEquals("{\"thought\":\"Let me run VectorDBTool to get more information\"}", output.get(THOUGHT_RESPONSE));
Assert.assertEquals("{\"thought\":\"Let me run VectorDBTool to get more information\"}", output.get(FINAL_ANSWER));
}

@Test
public void testExtractThought_InvalidResult() {
String text = responseForActionInvalidJson;
Expand Down

0 comments on commit 9b683cd

Please sign in to comment.