From 3fc5b302b994923186ed19dd41e80a73d0c9dc55 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 13 Feb 2024 10:31:54 -0800 Subject: [PATCH] fine tune prompt;refactor conversational agent code Signed-off-by: Yaliang Wu --- .../engine/algorithms/agent/AgentUtils.java | 131 ++- .../algorithms/agent/MLChatAgentRunner.java | 948 +++++++++--------- .../MLConversationalFlowAgentRunner.java | 109 +- .../algorithms/agent/MLFlowAgentRunner.java | 3 +- .../algorithms/agent/PromptTemplate.java | 9 +- .../memory/ConversationIndexMessage.java | 2 +- .../ml/engine/memory/MLMemoryManager.java | 2 +- .../ml/engine/tools/CatIndexTool.java | 33 +- .../algorithms/agent/AgentUtilsTest.java | 49 +- .../agent/MLChatAgentRunnerTest.java | 18 +- .../memory/ConversationIndexMessageTest.java | 2 +- .../ml/engine/tools/CatIndexToolTests.java | 7 +- 12 files changed, 755 insertions(+), 558 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 5268f4a559..9ae74b8e31 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -11,8 +11,6 @@ import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES; -import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_PREFIX; -import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_SUFFIX; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES; import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS; @@ -20,6 +18,7 @@ import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -28,6 +27,7 @@ import java.util.regex.Pattern; import org.apache.commons.text.StringSubstitutor; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; @@ -35,6 +35,15 @@ public class AgentUtils { + public static final String SELECTED_TOOLS = "selected_tools"; + public static final String PROMPT_PREFIX = "prompt.prefix"; + public static final String PROMPT_SUFFIX = "prompt.suffix"; + public static final String RESPONSE_FORMAT_INSTRUCTION = "prompt.format_instruction"; + public static final String TOOL_RESPONSE = "prompt.tool_response"; + public static final String PROMPT_CHAT_HISTORY_PREFIX = "prompt.chat_history_prefix"; + public static final String DISABLE_TRACE = "disable_trace"; + public static final String VERBOSE = "verbose"; + public static String addExamplesToPrompt(Map parameters, String prompt) { Map examplesMap = new HashMap<>(); if (parameters.containsKey(EXAMPLES)) { @@ -150,17 +159,46 @@ public static String addContextToPrompt(Map parameters, String p return prompt; } + public static List MODEL_RESPONSE_PATTERNS = List + .of( + 
"\\{\\s*\"thought\":.*?\\s*,\\s*\"action\":.*?\\s*,\\s*\"action_input\":.*?\\}", + "\\{\\s*\"thought\"\\s*:\\s*\".*?\"\\s*,\\s*\"action\"\\s*:\\s*\".*?\"\\s*,\\s*\"action_input\"\\s*:\\s*\".*?\"\\s*}", + "\\{\\s*\"thought\"\\s*:\\s*\".*?\"\\s*,\\s*\"final_answer\"\\s*:\\s*\".*?\"\\s*}" + ); + public static String extractModelResponseJson(String text) { - Pattern pattern = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```"); - Matcher matcher = pattern.matcher(text); + return extractModelResponseJson(text, null); + } - if (matcher.find()) { - return matcher.group(1); + public static String extractModelResponseJson(String text, List llmResponsePatterns) { + Pattern pattern1 = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```"); + Matcher matcher1 = pattern1.matcher(text); + + if (matcher1.find()) { + return matcher1.group(1); } else { + String matchedPart = findMatchedPart(text, MODEL_RESPONSE_PATTERNS); + if (matchedPart == null && llmResponsePatterns != null) { + matchedPart = findMatchedPart(text, llmResponsePatterns); + } + if (matchedPart != null) { + return matchedPart; + } throw new IllegalArgumentException("Model output is invalid"); } } + public static String findMatchedPart(String text, List patternList) { + for (String p : patternList) { + Pattern pattern = Pattern.compile(p); + Matcher matcher = pattern.matcher(text); + if (matcher.find()) { + return matcher.group(); + } + } + return null; + } + public static String outputToOutputString(Object output) throws PrivilegedActionException { String outputString; if (output instanceof ModelTensorOutput) { @@ -179,16 +217,6 @@ public static String outputToOutputString(Object output) throws PrivilegedAction return outputString; } - public static String parseInputFromLLMReturn(Map retMap) { - Object actionInput = retMap.get("action_input"); - if (actionInput instanceof Map) { - return gson.toJson(actionInput); - } else { - return String.valueOf(actionInput); - } - - } - public static int getMessageHistoryLimit(Map params) { String messageHistoryLimitStr = params.get(MESSAGE_HISTORY_LIMIT); return messageHistoryLimitStr != null ? Integer.parseInt(messageHistoryLimitStr) : LAST_N_INTERACTIONS; @@ -197,4 +225,75 @@ public static int getMessageHistoryLimit(Map params) { public static String getToolName(MLToolSpec toolSpec) { return toolSpec.getName() != null ? 
toolSpec.getName() : toolSpec.getType(); } + + public static List getMlToolSpecs(MLAgent mlAgent, Map params) { + String selectedToolsStr = params.get(SELECTED_TOOLS); + List toolSpecs = mlAgent.getTools(); + if (selectedToolsStr != null) { + List selectedTools = gson.fromJson(selectedToolsStr, List.class); + Map toolNameSpecMap = new HashMap<>(); + for (MLToolSpec toolSpec : toolSpecs) { + toolNameSpecMap.put(getToolName(toolSpec), toolSpec); + } + List selectedToolSpecs = new ArrayList<>(); + for (String tool : selectedTools) { + if (toolNameSpecMap.containsKey(tool)) { + selectedToolSpecs.add(toolNameSpecMap.get(tool)); + } + } + toolSpecs = selectedToolSpecs; + } + return toolSpecs; + } + + public static void createTools( + Map toolFactories, + Map params, + List toolSpecs, + Map tools, + Map toolSpecMap + ) { + for (MLToolSpec toolSpec : toolSpecs) { + Tool tool = createTool(toolFactories, params, toolSpec); + tools.put(tool.getName(), tool); + toolSpecMap.put(tool.getName(), toolSpec); + } + } + + public static Tool createTool(Map toolFactories, Map params, MLToolSpec toolSpec) { + if (!toolFactories.containsKey(toolSpec.getType())) { + throw new IllegalArgumentException("Tool not found: " + toolSpec.getType()); + } + Map executeParams = new HashMap<>(); + if (toolSpec.getParameters() != null) { + executeParams.putAll(toolSpec.getParameters()); + } + for (String key : params.keySet()) { + String toolNamePrefix = getToolName(toolSpec) + "."; + if (key.startsWith(toolNamePrefix)) { + executeParams.put(key.replace(toolNamePrefix, ""), params.get(key)); + } + } + Tool tool = toolFactories.get(toolSpec.getType()).create(executeParams); + String toolName = getToolName(toolSpec); + tool.setName(toolName); + + if (toolSpec.getDescription() != null) { + tool.setDescription(toolSpec.getDescription()); + } + if (params.containsKey(toolName + ".description")) { + tool.setDescription(params.get(toolName + ".description")); + } + + return tool; + } + + public static List getToolNames(Map tools) { + final List inputTools = new ArrayList<>(); + for (Map.Entry entry : tools.entrySet()) { + String toolName = entry.getValue().getName(); + inputTools.add(toolName); + } + return inputTools; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index f944aa90e9..9479d8d958 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -7,37 +7,43 @@ import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.isJson; +import static org.opensearch.ml.common.utils.StringUtils.toJson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_CHAT_HISTORY_PREFIX; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX; +import static 
org.opensearch.ml.engine.algorithms.agent.AgentUtils.RESPONSE_FORMAT_INSTRUCTION; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESPONSE; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.VERBOSE; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolNames; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString; -import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn; +import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX; +import java.security.PrivilegedActionException; import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import org.apache.commons.text.StringSubstitutor; import org.opensearch.action.ActionRequest; import org.opensearch.action.StepListener; -import org.opensearch.action.support.GroupedActionListener; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; @@ -56,10 +62,10 @@ import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.engine.tools.MLModelTool; -import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.ml.repackage.com.google.common.collect.Lists; @@ -73,10 +79,8 @@ public class MLChatAgentRunner implements MLAgentRunner { public static final String SESSION_ID = "session_id"; - public static final String PROMPT_PREFIX = "prompt_prefix"; public static final String LLM_TOOL_PROMPT_PREFIX = "LanguageModelTool.prompt_prefix"; public static final String LLM_TOOL_PROMPT_SUFFIX = "LanguageModelTool.prompt_suffix"; - public static final String PROMPT_SUFFIX = "prompt_suffix"; public static final String TOOLS = "tools"; public static final String TOOL_DESCRIPTIONS = "tool_descriptions"; public static final String TOOL_NAMES = "tool_names"; @@ -87,6 +91,12 @@ public class MLChatAgentRunner implements MLAgentRunner { public static final String CONTEXT = "context"; public static final String 
PROMPT = "prompt"; public static final String LLM_RESPONSE = "llm_response"; + public static final String MAX_ITERATION = "max_iteration"; + public static final String THOUGHT = "thought"; + public static final String ACTION = "action"; + public static final String ACTION_INPUT = "action_input"; + public static final String FINAL_ANSWER = "final_answer"; + public static final String THOUGHT_RESPONSE = "thought_response"; private Client client; private Settings settings; @@ -121,6 +131,7 @@ public void run(MLAgent mlAgent, Map params, ActionListenerwrap(memory -> { + // TODO: call runAgent directly if messageHistoryLimit == 0 memory.getMessages(ActionListener.>wrap(r -> { List messageList = new ArrayList<>(); for (Interaction next : r) { @@ -144,7 +155,8 @@ public void run(MLAgent mlAgent, Map params, ActionListener 0) { - chatHistoryBuilder.append("Below is Chat History between Human and AI which sorted by time with asc order:\n"); + String chatHistoryPrefix = params.getOrDefault(PROMPT_CHAT_HISTORY_PREFIX, CHAT_HISTORY_PREFIX); + chatHistoryBuilder.append(chatHistoryPrefix); for (Message message : messageList) { chatHistoryBuilder.append(message.toString()).append("\n"); } @@ -160,34 +172,10 @@ public void run(MLAgent mlAgent, Map params, ActionListener params, ActionListener listener, Memory memory, String sessionId) { - List toolSpecs = mlAgent.getTools(); + List toolSpecs = getMlToolSpecs(mlAgent, params); Map tools = new HashMap<>(); Map toolSpecMap = new HashMap<>(); - for (MLToolSpec toolSpec : toolSpecs) { - Map toolParams = new HashMap<>(); - Map executeParams = new HashMap<>(); - if (toolSpec.getParameters() != null) { - toolParams.putAll(toolSpec.getParameters()); - executeParams.putAll(toolSpec.getParameters()); - } - for (String key : params.keySet()) { - if (key.startsWith(toolSpec.getType() + ".")) { - executeParams.put(key.replace(toolSpec.getType() + ".", ""), params.get(key)); - } - } - log.info("Fetching tool for type: " + toolSpec.getType()); - Tool tool = toolFactories.get(toolSpec.getType()).create(executeParams); - if (toolSpec.getName() != null) { - tool.setName(toolSpec.getName()); - } - - if (toolSpec.getDescription() != null) { - tool.setDescription(toolSpec.getDescription()); - } - String toolName = Optional.ofNullable(tool.getName()).orElse(toolSpec.getType()); - tools.put(toolName, tool); - toolSpecMap.put(toolName, toolSpec); - } + createTools(toolFactories, params, toolSpecs, tools, toolSpecMap); runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, listener); } @@ -201,97 +189,23 @@ private void runReAct( String sessionId, ActionListener listener ) { + final List inputTools = getToolNames(tools); String question = parameters.get(MLAgentExecutor.QUESTION); String parentInteractionId = parameters.get(MLAgentExecutor.PARENT_INTERACTION_ID); - boolean verbose = parameters.containsKey("verbose") && Boolean.parseBoolean(parameters.get("verbose")); - Map tmpParameters = new HashMap<>(); - if (llm.getParameters() != null) { - tmpParameters.putAll(llm.getParameters()); - } - tmpParameters.putAll(parameters); - if (!tmpParameters.containsKey("stop")) { - tmpParameters.put("stop", gson.toJson(new String[] { "\nObservation:", "\n\tObservation:" })); - } - if (!tmpParameters.containsKey("stop_sequences")) { - tmpParameters - .put( - "stop_sequences", - gson - .toJson( - new String[] { - "\n\nHuman:", - "\nObservation:", - "\n\tObservation:", - "\nObservation", - "\n\tObservation", - "\n\nQuestion" } - ) - ); - } - - String prompt = 
parameters.get(PROMPT); - if (prompt == null) { - prompt = PromptTemplate.PROMPT_TEMPLATE; - } - String promptPrefix = parameters.getOrDefault("prompt.prefix", PromptTemplate.PROMPT_TEMPLATE_PREFIX); - tmpParameters.put("prompt.prefix", promptPrefix); - - String promptSuffix = parameters.getOrDefault("prompt.suffix", PromptTemplate.PROMPT_TEMPLATE_SUFFIX); - tmpParameters.put("prompt.suffix", promptSuffix); - - String promptFormatInstruction = parameters.getOrDefault("prompt.format_instruction", PromptTemplate.PROMPT_FORMAT_INSTRUCTION); - tmpParameters.put("prompt.format_instruction", promptFormatInstruction); - if (!tmpParameters.containsKey("prompt.tool_response")) { - tmpParameters.put("prompt.tool_response", PromptTemplate.PROMPT_TEMPLATE_TOOL_RESPONSE); - } - String promptToolResponse = parameters.getOrDefault("prompt.tool_response", PromptTemplate.PROMPT_TEMPLATE_TOOL_RESPONSE); - tmpParameters.put("prompt.tool_response", promptToolResponse); - - StringSubstitutor promptSubstitutor = new StringSubstitutor(tmpParameters, "${parameters.", "}"); - prompt = promptSubstitutor.replace(prompt); - - final List inputTools = new ArrayList<>(); - for (Map.Entry entry : tools.entrySet()) { - String toolName = Optional.ofNullable(entry.getValue().getName()).orElse(entry.getValue().getType()); - inputTools.add(toolName); - } - - prompt = AgentUtils.addPrefixSuffixToPrompt(parameters, prompt); - prompt = AgentUtils.addToolsToPrompt(tools, parameters, inputTools, prompt); - prompt = AgentUtils.addIndicesToPrompt(parameters, prompt); - prompt = AgentUtils.addExamplesToPrompt(parameters, prompt); - prompt = AgentUtils.addChatHistoryToPrompt(parameters, prompt); - prompt = AgentUtils.addContextToPrompt(parameters, prompt); + boolean verbose = parameters.containsKey(VERBOSE) && Boolean.parseBoolean(parameters.get(VERBOSE)); + boolean traceDisabled = parameters.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(parameters.get(DISABLE_TRACE)); + Map tmpParameters = constructLLMParams(llm, parameters); + String prompt = constructLLMPrompt(tools, parameters, inputTools, tmpParameters); tmpParameters.put(PROMPT, prompt); - List modelTensors = new ArrayList<>(); - - List cotModelTensors = new ArrayList<>(); - cotModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - List - .of( - ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), - ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build() - ) - ) - .build() - ); + List traceTensors = createModelTensors(sessionId, parentInteractionId); StringBuilder scratchpadBuilder = new StringBuilder(); - StringSubstitutor tmpSubstitutor = new StringSubstitutor( - ImmutableMap.of(SCRATCHPAD, scratchpadBuilder.toString()), - "${parameters.", - "}" - ); + StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); tmpParameters.put(PROMPT, newPrompt.get()); - - String maxIteration = Optional.ofNullable(tmpParameters.get("max_iteration")).orElse("3"); + String finalPrompt = prompt; // Create root interaction. 
ConversationIndexMemory conversationIndexMemory = (ConversationIndexMemory) memory; @@ -299,22 +213,18 @@ private void runReAct( // Trace number AtomicInteger traceNumber = new AtomicInteger(0); - StepListener firstListener; AtomicReference> lastLlmListener = new AtomicReference<>(); - AtomicBoolean getFinalAnswer = new AtomicBoolean(false); AtomicReference lastThought = new AtomicReference<>(); AtomicReference lastAction = new AtomicReference<>(); AtomicReference lastActionInput = new AtomicReference<>(); + AtomicReference lastToolSelectionResponse = new AtomicReference<>(); Map additionalInfo = new ConcurrentHashMap<>(); - StepListener lastStepListener = null; - int maxIterations = Integer.parseInt(maxIteration) * 2; - - String finalPrompt = prompt; - - firstListener = new StepListener(); + StepListener firstListener = new StepListener(); lastLlmListener.set(firstListener); - lastStepListener = firstListener; + StepListener lastStepListener = firstListener; + + int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, "3")) * 2; for (int i = 0; i < maxIterations; i++) { int finalI = i; StepListener nextStepListener = new StepListener<>(); @@ -324,225 +234,84 @@ private void runReAct( if (finalI % 2 == 0) { MLTaskResponse llmResponse = (MLTaskResponse) output; ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput(); - Map dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); - if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) { - String response = (String) dataAsMap.get("response"); - String thoughtResponse = extractModelResponseJson(response); - dataAsMap = gson.fromJson(thoughtResponse, Map.class); - } - String thought = String.valueOf(dataAsMap.get("thought")); - String action = String.valueOf(dataAsMap.get("action")); - String actionInput = parseInputFromLLMReturn(dataAsMap); - String finalAnswer = (String) dataAsMap.get("final_answer"); - if (!dataAsMap.containsKey("thought")) { - String response = (String) dataAsMap.get("response"); - Pattern pattern = Pattern.compile("```json(.*?)```", Pattern.DOTALL); - Matcher matcher = pattern.matcher(response); - if (matcher.find()) { - String jsonBlock = matcher.group(1); - Map map = gson.fromJson(jsonBlock, Map.class); - thought = String.valueOf(map.get("thought")); - action = String.valueOf(map.get("action")); - actionInput = parseInputFromLLMReturn(map); - finalAnswer = (String) map.get("final_answer"); - } else { - finalAnswer = response; - } - } + List llmResponsePatterns = gson.fromJson(parameters.get("llm_response_pattern"), List.class); + Map modelOutput = parseLLMOutput(tmpModelTensorOutput, llmResponsePatterns); - if (finalI == 0 && !thought.contains("Thought:")) { - sessionMsgAnswerBuilder.append("Thought: "); + String thought = String.valueOf(modelOutput.get(THOUGHT)); + String action = String.valueOf(modelOutput.get(ACTION)); + String actionInput = String.valueOf(modelOutput.get(ACTION_INPUT)); + String thoughtResponse = modelOutput.get(THOUGHT_RESPONSE); + String finalAnswer = modelOutput.get(FINAL_ANSWER); + + if (finalAnswer != null) { + finalAnswer = finalAnswer.trim(); + sendFinalAnswer( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + conversationIndexMemory, + traceNumber, + additionalInfo, + finalAnswer + ); + return; } + sessionMsgAnswerBuilder.append(thought); lastThought.set(thought); - cotModelTensors + lastAction.set(action); + 
lastActionInput.set(actionInput); + lastToolSelectionResponse.set(thoughtResponse); + + traceTensors .add( ModelTensors .builder() - .mlModelTensors( - Collections - .singletonList( - ModelTensor.builder().name("response").result(sessionMsgAnswerBuilder.toString()).build() - ) - ) + .mlModelTensors(List.of(ModelTensor.builder().name("response").result(thoughtResponse).build())) .build() ); - // TODO: check if verbose - modelTensors.addAll(tmpModelTensorOutput.getMlModelOutputs()); - - if (conversationIndexMemory != null) { - ConversationIndexMessage msgTemp = ConversationIndexMessage - .conversationIndexMessageBuilder() - .type(memory.getType()) - .question(question) - .response(thought) - .finalAnswer(false) - .sessionId(sessionId) - .build(); - conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), null); - } - if (finalAnswer != null) { - finalAnswer = finalAnswer.trim(); - String finalAnswer2 = finalAnswer; - // Composite execution response and reply. - final ActionListener executionListener = ActionListener.notifyOnce(ActionListener.wrap(r -> { - cotModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - Collections.singletonList(ModelTensor.builder().name("response").result(finalAnswer2).build()) - ) - .build() - ); - - List finalModelTensors = new ArrayList<>(); - finalModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - List - .of( - ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), - ModelTensor - .builder() - .name(MLAgentExecutor.PARENT_INTERACTION_ID) - .result(parentInteractionId) - .build() - ) - ) - .build() - ); - finalModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - Collections - .singletonList( - ModelTensor - .builder() - .name("response") - .dataAsMap( - ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo) - ) - .build() - ) - ) - .build() - ); - getFinalAnswer.set(true); - if (verbose) { - listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build()); - } else { - listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build()); - } - }, listener::onFailure)); - // Sending execution response by internalListener is after the trace and answer saving. - final GroupedActionListener groupedListener = createGroupedListener(2, executionListener); - if (conversationIndexMemory != null) { - String finalAnswer1 = finalAnswer; - // Create final trace message. - ConversationIndexMessage msgTemp = ConversationIndexMessage - .conversationIndexMessageBuilder() - .type(memory.getType()) - .question(question) - .response(finalAnswer1) - .finalAnswer(true) - .sessionId(sessionId) - .build(); - // Save last trace and update final answer in parallel. 
- conversationIndexMemory - .save( - msgTemp, - parentInteractionId, - traceNumber.addAndGet(1), - null, - ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) - ); - conversationIndexMemory - .getMemoryManager() - .updateInteraction( - parentInteractionId, - ImmutableMap.of(AI_RESPONSE_FIELD, finalAnswer1, ADDITIONAL_INFO_FIELD, additionalInfo), - ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) - ); - } - return; - } - lastAction.set(action); - lastActionInput.set(actionInput); + saveTraceData( + conversationIndexMemory, + memory.getType(), + question, + thoughtResponse, + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + "LLM" + ); - String toolName = action; - for (String key : tools.keySet()) { - if (action.toLowerCase().contains(key.toLowerCase())) { - toolName = key; - } - } - action = toolName; + action = getMatchingTool(tools, action); if (tools.containsKey(action) && inputTools.contains(action)) { - Map toolParams = new HashMap<>(); - Map toolSpecParams = toolSpecMap.get(action).getParameters(); - if (toolSpecParams != null) { - toolParams.putAll(toolSpecParams); - } - if (tools.get(action).useOriginalInput()) { - toolParams.put("input", question); - lastActionInput.set(question); - } else { - toolParams.put("input", actionInput); - } - if (tools.get(action).validate(toolParams)) { - try { - String finalAction = action; - ActionListener toolListener = ActionListener - .wrap(r -> { ((ActionListener) nextStepListener).onResponse(r); }, e -> { - ((ActionListener) nextStepListener) - .onResponse( - String - .format( - Locale.ROOT, - "Failed to run the tool %s with the error message %s.", - finalAction, - e.getMessage() - ) - ); - }); - if (tools.get(action) instanceof MLModelTool) { - Map llmToolTmpParameters = new HashMap<>(); - llmToolTmpParameters.putAll(tmpParameters); - llmToolTmpParameters.putAll(toolSpecMap.get(action).getParameters()); - // TODO: support tool parameter override : langauge_model_tool.prompt - llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput); - tools.get(action).run(llmToolTmpParameters, toolListener); // run tool - } else { - tools.get(action).run(toolParams, toolListener); // run tool - } - } catch (Exception e) { - ((ActionListener) nextStepListener) - .onResponse( - String - .format( - Locale.ROOT, - "Failed to run the tool %s with the error message %s.", - action, - e.getMessage() - ) - ); - } - } else { - String res = String - .format(Locale.ROOT, "Failed to run the tool %s due to wrong input %s.", action, actionInput); - ((ActionListener) nextStepListener).onResponse(res); - } + Map toolParams = constructToolParams( + tools, + toolSpecMap, + question, + lastActionInput, + action, + actionInput + ); + runTool( + tools, + toolSpecMap, + tmpParameters, + (ActionListener) nextStepListener, + action, + actionInput, + toolParams + ); } else { String res = String.format(Locale.ROOT, "Failed to run the tool %s which is unsupported.", action); ((ActionListener) nextStepListener).onResponse(res); StringSubstitutor substitutor = new StringSubstitutor( - ImmutableMap.of(SCRATCHPAD, scratchpadBuilder.toString()), + Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}" ); @@ -550,72 +319,35 @@ private void runReAct( tmpParameters.put(PROMPT, newPrompt.get()); } } else { - MLToolSpec toolSpec = toolSpecMap.get(lastAction.get()); - if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) { - String outputString = outputToOutputString(output); - - String 
toolOutputKey = String.format("%s.output", toolSpec.getType()); - if (additionalInfo.get(toolOutputKey) != null) { - List list = (List) additionalInfo.get(toolOutputKey); - list.add(outputString); - } else { - additionalInfo.put(toolOutputKey, Lists.newArrayList(outputString)); - } - - } - modelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - Collections - .singletonList( - ModelTensor - .builder() - .dataAsMap( - ImmutableMap - .of( - "response", - lastThought.get() + "\nObservation: " + outputToOutputString(output) - ) - ) - .build() - ) - ) - .build() - ); - - String toolResponse = tmpParameters.get("prompt.tool_response"); - StringSubstitutor toolResponseSubstitutor = new StringSubstitutor( - ImmutableMap.of("observation", outputToOutputString(output)), - "${parameters.", - "}" + addToolOutputToAddtionalInfo(toolSpecMap, lastAction, additionalInfo, output); + + String toolResponse = constructToolResponse( + tmpParameters, + lastAction, + lastActionInput, + lastToolSelectionResponse, + output ); - toolResponse = toolResponseSubstitutor.replace(toolResponse); scratchpadBuilder.append(toolResponse).append("\n\n"); - if (conversationIndexMemory != null) { - // String res = "Action: " + lastAction.get() + "\nAction Input: " + lastActionInput + "\nObservation: " + result; - ConversationIndexMessage msgTemp = ConversationIndexMessage - .conversationIndexMessageBuilder() - .type("ReAct") - .question(lastActionInput.get()) - .response(outputToOutputString(output)) - .finalAnswer(false) - .sessionId(sessionId) - .build(); - conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), lastAction.get()); - } - StringSubstitutor substitutor = new StringSubstitutor( - ImmutableMap.of(SCRATCHPAD, scratchpadBuilder.toString()), - "${parameters.", - "}" + saveTraceData( + conversationIndexMemory, + "ReAct", + lastActionInput.get(), + outputToOutputString(output), + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + lastAction.get() ); + + StringSubstitutor substitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder), "${parameters.", "}"); newPrompt.set(substitutor.replace(finalPrompt)); tmpParameters.put(PROMPT, newPrompt.get()); - sessionMsgAnswerBuilder.append("\nObservation: ").append(outputToOutputString(output)); - cotModelTensors + sessionMsgAnswerBuilder.append(outputToOutputString(output)); + traceTensors .add( ModelTensors .builder() @@ -628,56 +360,25 @@ private void runReAct( .build() ); - ActionRequest request = new MLPredictionTaskRequest( - llm.getModelId(), - RemoteInferenceMLInput - .builder() - .algorithm(FunctionName.REMOTE) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build()) - .build() - ); - client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener) nextStepListener); if (finalI == maxIterations - 1) { if (verbose) { - listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build()); + listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(traceTensors).build()); } else { - List finalModelTensors = new ArrayList<>(); - finalModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - List - .of( - ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), - ModelTensor - .builder() - .name(MLAgentExecutor.PARENT_INTERACTION_ID) - .result(parentInteractionId) - .build() - ) - ) - .build() - ); - finalModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - Collections - .singletonList( - 
ModelTensor - .builder() - .name("response") - .dataAsMap(ImmutableMap.of("response", lastThought.get())) - .build() - ) - ) - .build() - ); + List finalModelTensors = createFinalAnswerTensors( + createModelTensors(sessionId, parentInteractionId), + List.of(ModelTensor.builder().name("response").dataAsMap(Map.of("response", lastThought.get())).build()) + ); listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build()); } } else { + ActionRequest request = new MLPredictionTaskRequest( + llm.getModelId(), + RemoteInferenceMLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build()) + .build() + ); client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener) nextStepListener); } } @@ -701,27 +402,372 @@ private void runReAct( client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener); } - private GroupedActionListener createGroupedListener(final int size, final ActionListener listener) { - return new GroupedActionListener<>(new ActionListener>() { - @Override - public void onResponse(final Collection responses) { - CreateInteractionResponse createInteractionResponse = extractResponse(responses, CreateInteractionResponse.class); - log.info("saved message with interaction id: {}", createInteractionResponse.getId()); - UpdateResponse updateResponse = extractResponse(responses, UpdateResponse.class); - log.info("Updated final answer into interaction id: {}", updateResponse.getId()); + private static Map parseLLMOutput(ModelTensorOutput tmpModelTensorOutput, List llmResponsePatterns) { + Map modelOutput = new HashMap<>(); + Map dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); + if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) { + String llmReasoningResponse = (String) dataAsMap.get("response"); + String thoughtResponse = null; + try { + thoughtResponse = extractModelResponseJson(llmReasoningResponse, llmResponsePatterns); + modelOutput.put(THOUGHT_RESPONSE, thoughtResponse); + } catch (IllegalArgumentException e) { + modelOutput.put(THOUGHT_RESPONSE, llmReasoningResponse); + modelOutput.put(FINAL_ANSWER, llmReasoningResponse); + } + if (isJson(thoughtResponse)) { + modelOutput.putAll(getParameterMap(gson.fromJson(thoughtResponse, Map.class))); + } + } else { + extractParams(modelOutput, dataAsMap, THOUGHT); + extractParams(modelOutput, dataAsMap, ACTION); + extractParams(modelOutput, dataAsMap, ACTION_INPUT); + extractParams(modelOutput, dataAsMap, FINAL_ANSWER); + try { + modelOutput.put(THOUGHT_RESPONSE, StringUtils.toJson(dataAsMap)); + } catch (Exception e) { + log.warn("Failed to parse model response", e); + } + } + return modelOutput; + } + + private static void extractParams(Map modelOutput, Map dataAsMap, String paramName) { + if (dataAsMap.containsKey(paramName)) { + modelOutput.put(paramName, toJson(dataAsMap.get(paramName))); + } + } + + private static List createFinalAnswerTensors(List sessionId, List lastThought) { + List finalModelTensors = sessionId; + finalModelTensors.add(ModelTensors.builder().mlModelTensors(lastThought).build()); + return finalModelTensors; + } - listener.onResponse(true); + private static String constructToolResponse( + Map tmpParameters, + AtomicReference lastAction, + AtomicReference lastActionInput, + AtomicReference lastToolSelectionResponse, + Object output + ) throws PrivilegedActionException { + String toolResponse = tmpParameters.get(TOOL_RESPONSE); 
+ StringSubstitutor toolResponseSubstitutor = new StringSubstitutor( + Map + .of( + "llm_tool_selection_response", + lastToolSelectionResponse.get(), + "tool_name", + lastAction.get(), + "tool_input", + lastActionInput.get(), + "observation", + outputToOutputString(output) + ), + "${parameters.", + "}" + ); + toolResponse = toolResponseSubstitutor.replace(toolResponse); + return toolResponse; + } + + private static void addToolOutputToAddtionalInfo( + Map toolSpecMap, + AtomicReference lastAction, + Map additionalInfo, + Object output + ) throws PrivilegedActionException { + MLToolSpec toolSpec = toolSpecMap.get(lastAction.get()); + if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) { + String outputString = outputToOutputString(output); + String toolOutputKey = String.format("%s.output", toolSpec.getType()); + if (additionalInfo.get(toolOutputKey) != null) { + List list = (List) additionalInfo.get(toolOutputKey); + list.add(outputString); + } else { + additionalInfo.put(toolOutputKey, Lists.newArrayList(outputString)); } + } + } - @Override - public void onFailure(final Exception e) { - listener.onFailure(e); + private static void runTool( + Map tools, + Map toolSpecMap, + Map tmpParameters, + ActionListener nextStepListener, + String action, + String actionInput, + Map toolParams + ) { + if (tools.get(action).validate(toolParams)) { + try { + String finalAction = action; + ActionListener toolListener = ActionListener.wrap(r -> { nextStepListener.onResponse(r); }, e -> { + nextStepListener + .onResponse( + String.format(Locale.ROOT, "Failed to run the tool %s with the error message %s.", finalAction, e.getMessage()) + ); + }); + if (tools.get(action) instanceof MLModelTool) { + Map llmToolTmpParameters = new HashMap<>(); + llmToolTmpParameters.putAll(tmpParameters); + llmToolTmpParameters.putAll(toolSpecMap.get(action).getParameters()); + // TODO: support tool parameter override : langauge_model_tool.prompt + llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput); + tools.get(action).run(llmToolTmpParameters, toolListener); // run tool + } else { + tools.get(action).run(toolParams, toolListener); // run tool + } + } catch (Exception e) { + nextStepListener + .onResponse(String.format(Locale.ROOT, "Failed to run the tool %s with the error message %s.", action, e.getMessage())); } - }, size); + } else { + String res = String.format(Locale.ROOT, "Failed to run the tool %s due to wrong input %s.", action, actionInput); + nextStepListener.onResponse(res); + } + } + + private static Map constructToolParams( + Map tools, + Map toolSpecMap, + String question, + AtomicReference lastActionInput, + String action, + String actionInput + ) { + Map toolParams = new HashMap<>(); + Map toolSpecParams = toolSpecMap.get(action).getParameters(); + if (toolSpecParams != null) { + toolParams.putAll(toolSpecParams); + } + if (tools.get(action).useOriginalInput()) { + toolParams.put("input", question); + lastActionInput.set(question); + } else { + toolParams.put("input", actionInput); + } + return toolParams; + } + + private static String getMatchingTool(Map tools, String name) { + String toolName = name; + for (String key : tools.keySet()) { + if (name.toLowerCase().contains(key.toLowerCase())) { + toolName = key; + } + } + return toolName; + } + + private static void saveTraceData( + ConversationIndexMemory conversationIndexMemory, + String memory, + String question, + String thoughtResponse, + String sessionId, + boolean traceDisabled, + String parentInteractionId, + AtomicInteger 
traceNumber, + String origin + ) { + if (conversationIndexMemory != null) { + ConversationIndexMessage msgTemp = ConversationIndexMessage + .conversationIndexMessageBuilder() + .type(memory) + .question(question) + .response(thoughtResponse) + .finalAnswer(false) + .sessionId(sessionId) + .build(); + if (!traceDisabled) { + conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), origin); + } + } + } + + private void sendFinalAnswer( + String sessionId, + ActionListener listener, + String question, + String parentInteractionId, + boolean verbose, + boolean traceDisabled, + List cotModelTensors, + ConversationIndexMemory conversationIndexMemory, + AtomicInteger traceNumber, + Map additionalInfo, + String finalAnswer + ) { + if (conversationIndexMemory != null) { + String copyOfFinalAnswer = finalAnswer; + ActionListener saveTraceListener = ActionListener.wrap(r -> { + conversationIndexMemory + .getMemoryManager() + .updateInteraction( + parentInteractionId, + Map.of(AI_RESPONSE_FIELD, copyOfFinalAnswer, ADDITIONAL_INFO_FIELD, additionalInfo), + ActionListener.wrap(res -> { + returnFinalResponse( + sessionId, + listener, + parentInteractionId, + verbose, + cotModelTensors, + additionalInfo, + copyOfFinalAnswer + ); + }, e -> { listener.onFailure(e); }) + ); + }, e -> { listener.onFailure(e); }); + saveMessage( + conversationIndexMemory, + question, + finalAnswer, + sessionId, + parentInteractionId, + traceNumber, + true, + traceDisabled, + saveTraceListener + ); + } else { + returnFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, additionalInfo, finalAnswer); + } + } + + private static List createModelTensors(String sessionId, String parentInteractionId) { + List cotModelTensors = new ArrayList<>(); + + cotModelTensors + .add( + ModelTensors + .builder() + .mlModelTensors( + List + .of( + ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), + ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build() + ) + ) + .build() + ); + return cotModelTensors; } - @SuppressWarnings("unchecked") - private static A extractResponse(final Collection responses, Class c) { - return (A) responses.stream().filter(c::isInstance).findFirst().get(); + private static String constructLLMPrompt( + Map tools, + Map parameters, + List inputTools, + Map tmpParameters + ) { + String prompt = parameters.getOrDefault(PROMPT, PromptTemplate.PROMPT_TEMPLATE); + StringSubstitutor promptSubstitutor = new StringSubstitutor(tmpParameters, "${parameters.", "}"); + prompt = promptSubstitutor.replace(prompt); + prompt = AgentUtils.addPrefixSuffixToPrompt(parameters, prompt); + prompt = AgentUtils.addToolsToPrompt(tools, parameters, inputTools, prompt); + prompt = AgentUtils.addIndicesToPrompt(parameters, prompt); + prompt = AgentUtils.addExamplesToPrompt(parameters, prompt); + prompt = AgentUtils.addChatHistoryToPrompt(parameters, prompt); + prompt = AgentUtils.addContextToPrompt(parameters, prompt); + return prompt; + } + + private static Map constructLLMParams(LLMSpec llm, Map parameters) { + Map tmpParameters = new HashMap<>(); + if (llm.getParameters() != null) { + tmpParameters.putAll(llm.getParameters()); + } + tmpParameters.putAll(parameters); + if (!tmpParameters.containsKey("stop")) { + tmpParameters.put("stop", gson.toJson(new String[] { "\nObservation:", "\n\tObservation:" })); + } + if (!tmpParameters.containsKey("stop_sequences")) { + tmpParameters + .put( + "stop_sequences", + gson + 
.toJson( + new String[] { + "\n\nHuman:", + "\nObservation:", + "\n\tObservation:", + "\nObservation", + "\n\tObservation", + "\n\nQuestion" } + ) + ); + } + + String promptPrefix = parameters.getOrDefault(PROMPT_PREFIX, PromptTemplate.PROMPT_TEMPLATE_PREFIX); + tmpParameters.put(PROMPT_PREFIX, promptPrefix); + + String promptSuffix = parameters.getOrDefault(PROMPT_SUFFIX, PromptTemplate.PROMPT_TEMPLATE_SUFFIX); + tmpParameters.put(PROMPT_SUFFIX, promptSuffix); + + String promptFormatInstruction = parameters.getOrDefault(RESPONSE_FORMAT_INSTRUCTION, PromptTemplate.PROMPT_FORMAT_INSTRUCTION); + tmpParameters.put(RESPONSE_FORMAT_INSTRUCTION, promptFormatInstruction); + + String promptToolResponse = parameters.getOrDefault(TOOL_RESPONSE, PromptTemplate.PROMPT_TEMPLATE_TOOL_RESPONSE); + tmpParameters.put(TOOL_RESPONSE, promptToolResponse); + return tmpParameters; + } + + private static void returnFinalResponse( + String sessionId, + ActionListener listener, + String parentInteractionId, + boolean verbose, + List cotModelTensors, // AtomicBoolean getFinalAnswer, + Map additionalInfo, + String finalAnswer2 + ) { + cotModelTensors + .add( + ModelTensors.builder().mlModelTensors(List.of(ModelTensor.builder().name("response").result(finalAnswer2).build())).build() + ); + + List finalModelTensors = createFinalAnswerTensors( + createModelTensors(sessionId, parentInteractionId), + List + .of( + ModelTensor + .builder() + .name("response") + .dataAsMap(ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)) + .build() + ) + ); + if (verbose) { + listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build()); + } else { + listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build()); + } + } + + private void saveMessage( + ConversationIndexMemory memory, + String question, + String finalAnswer, + String sessionId, + String parentInteractionId, + AtomicInteger traceNumber, + boolean isFinalAnswer, + boolean traceDisabled, + ActionListener listener + ) { + ConversationIndexMessage msgTemp = ConversationIndexMessage + .conversationIndexMessageBuilder() + .type(memory.getType()) + .question(question) + .response(finalAnswer) + .finalAnswer(isFinalAnswer) + .sessionId(sessionId) + .build(); + if (traceDisabled) { + listener.onResponse(true); + } else { + memory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", listener); + } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java index c2471e4f35..6f374e0bb8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -10,8 +10,10 @@ import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.MEMORY_ID; import static org.opensearch.ml.common.conversation.ActionConstants.PARENT_INTERACTION_ID_FIELD; -import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTool; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit; +import 
static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName; import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; @@ -62,7 +64,7 @@ public class MLConversationalFlowAgentRunner implements MLAgentRunner { public static final String CHAT_HISTORY = "chat_history"; - public static final String SELECTED_TOOLS = "selected_tools"; + private Client client; private Settings settings; private ClusterService clusterService; @@ -156,8 +158,7 @@ private void runAgent( Map firstToolExecuteParams = null; StepListener previousStepListener = null; Map additionalInfo = new ConcurrentHashMap<>(); - String selectedToolsStr = params.get(SELECTED_TOOLS); - List toolSpecs = getMlToolSpecs(mlAgent, selectedToolsStr); + List toolSpecs = getMlToolSpecs(mlAgent, params); if (toolSpecs == null || toolSpecs.size() == 0) { listener.onFailure(new IllegalArgumentException("no tool configured")); @@ -173,7 +174,7 @@ private void runAgent( for (int i = 0; i <= toolSpecs.size(); i++) { if (i == 0) { MLToolSpec toolSpec = toolSpecs.get(i); - Tool tool = createTool(toolSpec); + Tool tool = createTool(toolFactories, params, toolSpec); firstStepListener = new StepListener<>(); previousStepListener = firstStepListener; firstTool = tool; @@ -231,25 +232,6 @@ private void runAgent( } } - private static List getMlToolSpecs(MLAgent mlAgent, String selectedToolsStr) { - List toolSpecs = mlAgent.getTools(); - if (selectedToolsStr != null) { - List selectedTools = gson.fromJson(selectedToolsStr, List.class); - Map toolNameSpecMap = new HashMap<>(); - for (MLToolSpec toolSpec : toolSpecs) { - toolNameSpecMap.put(getToolName(toolSpec), toolSpec); - } - List selectedToolSpecs = new ArrayList<>(); - for (String tool : selectedTools) { - if (toolNameSpecMap.containsKey(tool)) { - selectedToolSpecs.add(toolNameSpecMap.get(tool)); - } - } - toolSpecs = selectedToolSpecs; - } - return toolSpecs; - } - private void processOutput( Map params, ActionListener listener, @@ -271,6 +253,7 @@ private void processOutput( String outputKey = toolName + ".output"; String outputResponse = parseResponse(output); params.put(outputKey, escapeJson(outputResponse)); + boolean traceDisabled = params.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(params.get(DISABLE_TRACE)); if (previousToolSpec.isIncludeOutputInAgentResponse() || finalI == toolSpecs.size()) { if (output instanceof ModelTensorOutput) { @@ -303,32 +286,53 @@ private void processOutput( updateMemoryWithListener(additionalInfo, memorySpec, memoryId, parentInteractionId, updateListener); } } else { - saveMessage(params, memory, outputResponse, memoryId, parentInteractionId, toolName, traceNumber, ActionListener.wrap(r -> { - log.info("saved last trace for interaction " + parentInteractionId + " of flow agent"); - Map updateContent = Map.of(AI_RESPONSE_FIELD, outputResponse, ADDITIONAL_INFO_FIELD, additionalInfo); - memory.update(parentInteractionId, updateContent, updateListener); - }, e -> { - log.error("Failed to update root interaction ", e); - listener.onFailure(e); - })); + saveMessage( + params, + memory, + outputResponse, + memoryId, + parentInteractionId, + toolName, + traceNumber, + traceDisabled, + ActionListener.wrap(r -> { + log.info("saved last trace for interaction " + parentInteractionId + " of flow agent"); + Map updateContent = Map + .of(AI_RESPONSE_FIELD, outputResponse, ADDITIONAL_INFO_FIELD, additionalInfo); + memory.update(parentInteractionId, 
updateContent, updateListener); + }, e -> { + log.error("Failed to update root interaction ", e); + listener.onFailure(e); + }) + ); } } else { if (memory == null) { runNextStep(params, toolSpecs, finalI, nextStepListener); } else { - saveMessage(params, memory, outputResponse, memoryId, parentInteractionId, toolName, traceNumber, ActionListener.wrap(r -> { - runNextStep(params, toolSpecs, finalI, nextStepListener); - }, e -> { - log.error("Failed to update root interaction ", e); - listener.onFailure(e); - })); + saveMessage( + params, + memory, + outputResponse, + memoryId, + parentInteractionId, + toolName, + traceNumber, + traceDisabled, + ActionListener.wrap(r -> { + runNextStep(params, toolSpecs, finalI, nextStepListener); + }, e -> { + log.error("Failed to update root interaction ", e); + listener.onFailure(e); + }) + ); } } } private void runNextStep(Map params, List toolSpecs, int finalI, StepListener nextStepListener) { MLToolSpec toolSpec = toolSpecs.get(finalI); - Tool tool = createTool(toolSpec); + Tool tool = createTool(toolFactories, params, toolSpec); if (finalI < toolSpecs.size()) { tool.run(getToolExecuteParams(toolSpec, params), nextStepListener); } @@ -342,6 +346,7 @@ private void saveMessage( String parentInteractionId, String toolName, AtomicInteger traceNumber, + boolean traceDisabled, ActionListener listener ) { ConversationIndexMessage finalMessage = ConversationIndexMessage @@ -352,7 +357,11 @@ private void saveMessage( .finalAnswer(true) .sessionId(memoryId) .build(); - memory.save(finalMessage, parentInteractionId, traceNumber.addAndGet(1), toolName, listener); + if (traceDisabled) { + listener.onResponse(true); + } else { + memory.save(finalMessage, parentInteractionId, traceNumber.addAndGet(1), toolName, listener); + } } @VisibleForTesting @@ -397,26 +406,6 @@ String parseResponse(Object output) throws IOException { } } - @VisibleForTesting - Tool createTool(MLToolSpec toolSpec) { - Map toolParams = new HashMap<>(); - if (toolSpec.getParameters() != null) { - toolParams.putAll(toolSpec.getParameters()); - } - if (!toolFactories.containsKey(toolSpec.getType())) { - throw new IllegalArgumentException("Tool not found: " + toolSpec.getType()); - } - Tool tool = toolFactories.get(toolSpec.getType()).create(toolParams); - if (toolSpec.getName() != null) { - tool.setName(toolSpec.getName()); - } - - if (toolSpec.getDescription() != null) { - tool.setDescription(toolSpec.getDescription()); - } - return tool; - } - @VisibleForTesting Map getToolExecuteParams(MLToolSpec toolSpec, Map params) { Map executeParams = new HashMap<>(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index 674a1237c6..0f6bba2931 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -6,6 +6,7 @@ package org.opensearch.ml.engine.algorithms.agent; import static org.apache.commons.text.StringEscapeUtils.escapeJson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; import java.io.IOException; import java.security.AccessController; @@ -73,7 +74,7 @@ public MLFlowAgentRunner( @Override public void run(MLAgent mlAgent, Map params, ActionListener listener) { - List toolSpecs = mlAgent.getTools(); + List toolSpecs = getMlToolSpecs(mlAgent, params); StepListener 
firstStepListener = null; Tool firstTool = null; List flowAgentOutput = new ArrayList<>(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java index bbeee117be..58b6f2f26b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java @@ -7,8 +7,11 @@ public class PromptTemplate { public static final String PROMPT_FORMAT_INSTRUCTION = "Human:RESPONSE FORMAT INSTRUCTIONS\n----------------------------\nOutput a JSON markdown code snippet containing a valid JSON object in one of two formats:\n\n**Option 1:**\nUse this if you want the human to use a tool.\nMarkdown code snippet formatted in the following schema:\n\n```json\n{\n \"thought\": string, // think about what to do next: if you know the final answer just return \"Now I know the final answer\", otherwise suggest which tool to use.\n \"action\": string, // The action to take. Must be one of these tool names: [${parameters.tool_names}], do NOT use any other name for action except the tool names.\n \"action_input\": string // The input to the action. May be a stringified object.\n}\n```\n\n**Option #2:**\nUse this if you want to respond directly and conversationally to the human. Markdown code snippet formatted in the following schema:\n\n```json\n{\n \"thought\": \"Now I know the final answer\",\n \"final_answer\": string, // summarize and return the final answer in a sentence with details, don't just return a number or a word.\n}\n```"; public static final String PROMPT_TEMPLATE_SUFFIX = - "Human:TOOLS\n------\nAssistant can ask the user to use tools to look up information that may be helpful in answering the users original question. The tools the human can use are:\n\n${parameters.tool_descriptions}\n\n${parameters.prompt.format_instruction}\n\n${parameters.chat_history}\n\n\nHuman:USER'S INPUT\n--------------------\nHere is the user's input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):\n${parameters.question}\n\n${parameters.scratchpad}"; - public static final String PROMPT_TEMPLATE = "\n\nHuman:${parameters.prompt.prefix}\n\n${parameters.prompt.suffix}\n\nAssistant:"; + "Human:TOOLS\n------\nAssistant can ask Human to use tools to look up information that may be helpful in answering the users original question. The tool response will be listed in \"TOOL RESPONSE of {tool name}:\". If TOOL RESPONSE is enough to answer human's question, Assistant should avoid rerun the same tool. \nAssistant should NEVER suggest run a tool with same input if it's already in TOOL RESPONSE. \nThe tools the human can use are:\n\n${parameters.tool_descriptions}\n\n${parameters.chat_history}\n\n${parameters.prompt.format_instruction}\n\n\nHuman:USER'S INPUT\n--------------------\nHere is the user's input :\n${parameters.question}\n\n${parameters.scratchpad}"; + public static final String PROMPT_TEMPLATE = + "\n\nHuman:${parameters.prompt.prefix}\n\n${parameters.prompt.suffix}\n\nHuman: follow RESPONSE FORMAT INSTRUCTIONS\n\nAssistant:"; public static final String PROMPT_TEMPLATE_TOOL_RESPONSE = - "TOOL RESPONSE: \n---------------------\n${parameters.observation}\n\nUSER'S INPUT\n--------------------\n\nOkay, so what is the response to my last comment? 
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java
index 2a084ee9b9..8436d98059 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java
@@ -34,7 +34,7 @@ public ConversationIndexMessage(String type, String sessionId, String question,
 
     @Override
     public String toString() {
-        return "Human:" + question + "\nAI:" + response;
+        return "Human:" + question + "\nAssistant:" + response;
     }
 
     @Override
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java
index 6219a4b3a6..1dc76711ec 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java
@@ -139,7 +139,7 @@ public void createInteraction(
      * @param actionListener get all the final interactions that are not traces
      */
     public void getFinalInteractions(String conversationId, int lastNInteraction, ActionListener<List<Interaction>> actionListener) {
-        Preconditions.checkArgument(lastNInteraction > 0, "lastN must be at least 1.");
+        Preconditions.checkArgument(lastNInteraction > 0, "History message size must be at least 1.");
         log.debug("Getting Interactions, conversationId {}, lastN {}", conversationId, lastNInteraction);
 
         try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java
index e2c25b93a8..0a45981b9c 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java
@@ -14,6 +14,7 @@
 import java.util.Locale;
 import java.util.Map;
 import java.util.Spliterators;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.StreamSupport;
@@ -120,10 +121,10 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
             StringBuilder sb = new StringBuilder(
                 // Currently using c.value which is short header matching _cat/indices
                 // May prefer to use c.attr.get("desc") for full description
-                table.getHeaders().stream().map(c -> c.value.toString()).collect(Collectors.joining("\t", "", "\n"))
+                table.getHeaders().stream().map(c -> c.value.toString()).collect(Collectors.joining(",", "", "\n"))
             );
             for (List<Table.Cell> row : table.getRows()) {
-                sb.append(row.stream().map(c -> c.value == null ? null : c.value.toString()).collect(Collectors.joining("\t", "", "\n")));
+                sb.append(row.stream().map(c -> c.value == null ? null : c.value.toString()).collect(Collectors.joining(",", "", "\n")));
             }
             @SuppressWarnings("unchecked")
             T response = (T) sb.toString();
@@ -366,16 +367,29 @@ private Table getTableWithHeader() {
         table.startHeaders();
         // First param is cell.value which is currently returned
         // Second param is cell.attr we may want to use attr.desc in the future
+        table.addCell("row", "alias:r;desc:row number");
         table.addCell("health", "alias:h;desc:current health status");
         table.addCell("status", "alias:s;desc:open/close status");
         table.addCell("index", "alias:i,idx;desc:index name");
         table.addCell("uuid", "alias:id,uuid;desc:index uuid");
-        table.addCell("pri", "alias:p,shards.primary,shardsPrimary;text-align:right;desc:number of primary shards");
-        table.addCell("rep", "alias:r,shards.replica,shardsReplica;text-align:right;desc:number of replica shards");
-        table.addCell("docs.count", "alias:dc,docsCount;text-align:right;desc:available docs");
-        table.addCell("docs.deleted", "alias:dd,docsDeleted;text-align:right;desc:deleted docs");
-        table.addCell("store.size", "sibling:pri;alias:ss,storeSize;text-align:right;desc:store size of primaries & replicas");
-        table.addCell("pri.store.size", "text-align:right;desc:store size of primaries");
+        table
+            .addCell(
+                "pri(number of primary shards)",
+                "alias:p,shards.primary,shardsPrimary;text-align:right;desc:number of primary shards"
+            );
+        table
+            .addCell(
+                "rep(number of replica shards)",
+                "alias:r,shards.replica,shardsReplica;text-align:right;desc:number of replica shards"
+            );
+        table.addCell("docs.count(number of available documents)", "alias:dc,docsCount;text-align:right;desc:available docs");
+        table.addCell("docs.deleted(number of deleted documents)", "alias:dd,docsDeleted;text-align:right;desc:deleted docs");
+        table
+            .addCell(
+                "store.size(store size of primary and replica shards)",
+                "sibling:pri;alias:ss,storeSize;text-align:right;desc:store size of primaries & replicas"
+            );
+        table.addCell("pri.store.size(store size of primary shards)", "text-align:right;desc:store size of primaries");
         // Above includes all the default fields for cat indices. See RestIndicesAction for a lot more that could be included.
         table.endHeaders();
         return table;
@@ -388,7 +402,7 @@ private Table buildTable(
         final Map<String, IndexMetadata> indicesMetadatas
     ) {
         final Table table = getTableWithHeader();
-
+        AtomicInteger rowNum = new AtomicInteger(0);
         indicesSettings.forEach((indexName, settings) -> {
             if (!indicesMetadatas.containsKey(indexName)) {
                 // the index exists in the Get Indices response but is not present in the cluster state:
@@ -421,6 +435,7 @@ private Table buildTable(
                 totalStats = indexStats.getTotal();
             }
             table.startRow();
+            table.addCell(rowNum.addAndGet(1));
             table.addCell(health);
             table.addCell(indexState.toString().toLowerCase(Locale.ROOT));
             table.addCell(indexName);
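Note: with the separator switched from tab to comma, the new leading "row" column, and the self-describing header names, CatIndexTool now returns CSV-style text that an LLM can quote directly. For the single red test index used in CatIndexToolTests below, the returned string looks like:

    row,health,status,index,uuid,pri(number of primary shards),rep(number of replica shards),docs.count(number of available documents),docs.deleted(number of deleted documents),store.size(store size of primary and replica shards),pri.store.size(store size of primary shards)
    1,red,open,foo,null,5,1,0,0,0b,0b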
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java
index b348bcc228..f4943db25d 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java
@@ -8,18 +8,19 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThrows;
 import static org.mockito.Mockito.when;
+import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
+import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
 import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY;
 import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT;
 import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES;
 import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES;
-import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_PREFIX;
-import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_SUFFIX;
 
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.Mock;
@@ -119,7 +120,7 @@ public void testAddExamplesToPrompt_WithoutExamples() {
     @Test
     public void testAddPrefixSuffixToPrompt_WithPrefixSuffix() {
         // Setup
-        String initialPrompt = "initial prompt ${parameters.prompt_prefix} main content ${parameters.prompt_suffix}";
+        String initialPrompt = "initial prompt ${parameters.prompt.prefix} main content ${parameters.prompt.suffix}";
         Map<String, String> parameters = new HashMap<>();
         parameters.put(PROMPT_PREFIX, "Prefix: ");
         parameters.put(PROMPT_SUFFIX, " :Suffix");
@@ -137,7 +138,7 @@ public void testAddPrefixSuffixToPrompt_WithPrefixSuffix() {
     @Test
     public void testAddPrefixSuffixToPrompt_WithoutPrefixSuffix() {
         // Setup
-        String initialPrompt = "initial prompt ${parameters.prompt_prefix} main content ${parameters.prompt_suffix}";
+        String initialPrompt = "initial prompt ${parameters.prompt.prefix} main content ${parameters.prompt.suffix}";
         Map<String, String> parameters = new HashMap<>();
 
         // Expected output (should remain unchanged)
@@ -265,4 +266,44 @@ public void testExtractModelResponseJsonWithValidModelOutput() {
         String responseJson = AgentUtils.extractModelResponseJson(text);
         assertEquals("{\"thought\":\"use CatIndexTool to get index first\",\"action\":\"CatIndexTool\"}", responseJson);
     }
+
+    @Test
+    public void test() {
+        String text =
+            "---------------------\n{\n \"thought\": \"Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:\",\n \"final_answer\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n}";
+        String result = AgentUtils.extractModelResponseJson(text);
+        String expectedResult = "{\n"
+            + " \"thought\": \"Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:\",\n"
+            + " \"final_answer\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n"
+            + "}";
+        System.out.println(result);
+        Assert.assertEquals(expectedResult, result);
+    }
+
+    @Test
+    public void test2() {
+        String text =
+            "---------------------```json\n{\n \"thought\": \"Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:\",\n \"final_answer\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n}\n```";
+        String result = AgentUtils.extractModelResponseJson(text);
+        String expectedResult = "{\n"
+            + " \"thought\": \"Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:\",\n"
+            + " \"final_answer\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n"
+            + "}";
+        System.out.println(result);
+        Assert.assertEquals(expectedResult, result);
+    }
+
+    @Test
+    public void test3() {
+        String text =
+            "---------------------\n{\n \"thought\": \"Let me search our index to find population projections\", \n \"action\": \"VectorDBTool\",\n \"action_input\": \"Seattle population projection 2023\"\n}";
+        String result = AgentUtils.extractModelResponseJson(text);
+        String expectedResult = "{\n"
+            + " \"thought\": \"Let me search our index to find population projections\", \n"
+            + " \"action\": \"VectorDBTool\",\n"
+            + " \"action_input\": \"Seattle population projection 2023\"\n"
+            + "}";
+        System.out.println(result);
+        Assert.assertEquals(expectedResult, result);
+    }
 }
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java
index ff3be400dd..3d3538412e 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java
@@ -208,12 +208,12 @@ public void testParsingJsonBlockFromResponse() {
 
         ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
         ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1);
-        ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
-        ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0);
+        ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0);
+        ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
 
         // Verify that the parsed values from JSON block are correctly set
         assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult());
-        assertEquals("Thought: parsed thought", modelTensor1.getResult());
+        assertEquals("conversation_id", modelTensor1.getResult());
         assertEquals("parsed final answer", modelTensor2.getResult());
     }
 
@@ -248,12 +248,12 @@ public void testParsingJsonBlockFromResponse2() {
 
         ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
         ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1);
-        ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
-        ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0);
+        ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0);
+        ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
 
         // Verify that the parsed values from JSON block are correctly set
         assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult());
-        assertEquals("Thought: parsed thought", modelTensor1.getResult());
+        assertEquals("conversation_id", modelTensor1.getResult());
         assertEquals("parsed final answer", modelTensor2.getResult());
     }
 
@@ -288,12 +288,12 @@ public void testParsingJsonBlockFromResponse3() {
 
         ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
         ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1);
-        ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
-        ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0);
+        ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0);
+        ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
 
         // Verify that the parsed values from JSON block are correctly set
         assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult());
-        assertEquals("Thought: parsed thought", modelTensor1.getResult());
+        assertEquals("conversation_id", modelTensor1.getResult());
         assertEquals("parsed final answer", modelTensor2.getResult());
     }
 
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMessageTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMessageTest.java
index 9e91695a5b..ee3fdc2088 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMessageTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMessageTest.java
@@ -30,7 +30,7 @@ public void setUp() {
 
     @Test
     public void testToString() {
-        Assert.assertEquals("Human:question\nAI:response", message.toString());
+        Assert.assertEquals("Human:question\nAssistant:response", message.toString());
     }
 
     @Test
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java
index cffb0ff338..11b29070f3 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java
@@ -226,8 +226,11 @@ public void testRunAsyncIndexStats() throws Exception {
         String header = responseRows[0];
         String fooRow = responseRows[1];
         assertEquals(header.split("\\t").length, fooRow.split("\\t").length);
-        assertEquals("health\tstatus\tindex\tuuid\tpri\trep\tdocs.count\tdocs.deleted\tstore.size\tpri.store.size", header);
-        assertEquals("red\topen\tfoo\tnull\t5\t1\t0\t0\t0b\t0b", fooRow);
+        assertEquals(
+            "row,health,status,index,uuid,pri(number of primary shards),rep(number of replica shards),docs.count(number of available documents),docs.deleted(number of deleted documents),store.size(store size of primary and replica shards),pri.store.size(store size of primary shards)",
+            header
+        );
+        assertEquals("1,red,open,foo,null,5,1,0,0,0b,0b", fooRow);
     }
 
     @Test