Skip to content

Commit

Permalink
fine tune prompt;refactor conversational agent code
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Feb 13, 2024
1 parent 38566f5 commit 3fc5b30
Show file tree
Hide file tree
Showing 12 changed files with 755 additions and 558 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_SUFFIX;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -28,13 +27,23 @@
import java.util.regex.Pattern;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.Tool;

public class AgentUtils {

public static final String SELECTED_TOOLS = "selected_tools";
public static final String PROMPT_PREFIX = "prompt.prefix";
public static final String PROMPT_SUFFIX = "prompt.suffix";
public static final String RESPONSE_FORMAT_INSTRUCTION = "prompt.format_instruction";
public static final String TOOL_RESPONSE = "prompt.tool_response";
public static final String PROMPT_CHAT_HISTORY_PREFIX = "prompt.chat_history_prefix";
public static final String DISABLE_TRACE = "disable_trace";
public static final String VERBOSE = "verbose";

public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
Map<String, String> examplesMap = new HashMap<>();
if (parameters.containsKey(EXAMPLES)) {
Expand Down Expand Up @@ -150,17 +159,46 @@ public static String addContextToPrompt(Map<String, String> parameters, String p
return prompt;
}

public static List<String> MODEL_RESPONSE_PATTERNS = List
.of(
"\\{\\s*\"thought\":.*?\\s*,\\s*\"action\":.*?\\s*,\\s*\"action_input\":.*?\\}",
"\\{\\s*\"thought\"\\s*:\\s*\".*?\"\\s*,\\s*\"action\"\\s*:\\s*\".*?\"\\s*,\\s*\"action_input\"\\s*:\\s*\".*?\"\\s*}",
"\\{\\s*\"thought\"\\s*:\\s*\".*?\"\\s*,\\s*\"final_answer\"\\s*:\\s*\".*?\"\\s*}"
);

public static String extractModelResponseJson(String text) {
Pattern pattern = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```");
Matcher matcher = pattern.matcher(text);
return extractModelResponseJson(text, null);
}

if (matcher.find()) {
return matcher.group(1);
public static String extractModelResponseJson(String text, List<String> llmResponsePatterns) {
Pattern pattern1 = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```");
Matcher matcher1 = pattern1.matcher(text);

if (matcher1.find()) {
return matcher1.group(1);
} else {
String matchedPart = findMatchedPart(text, MODEL_RESPONSE_PATTERNS);
if (matchedPart == null && llmResponsePatterns != null) {
matchedPart = findMatchedPart(text, llmResponsePatterns);
}
if (matchedPart != null) {
return matchedPart;
}
throw new IllegalArgumentException("Model output is invalid");
}
}

public static String findMatchedPart(String text, List<String> patternList) {
for (String p : patternList) {
Pattern pattern = Pattern.compile(p);
Matcher matcher = pattern.matcher(text);
if (matcher.find()) {
return matcher.group();
}
}
return null;
}

public static String outputToOutputString(Object output) throws PrivilegedActionException {
String outputString;
if (output instanceof ModelTensorOutput) {
Expand All @@ -179,16 +217,6 @@ public static String outputToOutputString(Object output) throws PrivilegedAction
return outputString;
}

public static String parseInputFromLLMReturn(Map<String, ?> retMap) {
Object actionInput = retMap.get("action_input");
if (actionInput instanceof Map) {
return gson.toJson(actionInput);
} else {
return String.valueOf(actionInput);
}

}

public static int getMessageHistoryLimit(Map<String, String> params) {
String messageHistoryLimitStr = params.get(MESSAGE_HISTORY_LIMIT);
return messageHistoryLimitStr != null ? Integer.parseInt(messageHistoryLimitStr) : LAST_N_INTERACTIONS;
Expand All @@ -197,4 +225,75 @@ public static int getMessageHistoryLimit(Map<String, String> params) {
public static String getToolName(MLToolSpec toolSpec) {
return toolSpec.getName() != null ? toolSpec.getName() : toolSpec.getType();
}

public static List<MLToolSpec> getMlToolSpecs(MLAgent mlAgent, Map<String, String> params) {
String selectedToolsStr = params.get(SELECTED_TOOLS);
List<MLToolSpec> toolSpecs = mlAgent.getTools();
if (selectedToolsStr != null) {
List<String> selectedTools = gson.fromJson(selectedToolsStr, List.class);
Map<String, MLToolSpec> toolNameSpecMap = new HashMap<>();
for (MLToolSpec toolSpec : toolSpecs) {
toolNameSpecMap.put(getToolName(toolSpec), toolSpec);
}
List<MLToolSpec> selectedToolSpecs = new ArrayList<>();
for (String tool : selectedTools) {
if (toolNameSpecMap.containsKey(tool)) {
selectedToolSpecs.add(toolNameSpecMap.get(tool));
}
}
toolSpecs = selectedToolSpecs;
}
return toolSpecs;
}

public static void createTools(
Map<String, Tool.Factory> toolFactories,
Map<String, String> params,
List<MLToolSpec> toolSpecs,
Map<String, Tool> tools,
Map<String, MLToolSpec> toolSpecMap
) {
for (MLToolSpec toolSpec : toolSpecs) {
Tool tool = createTool(toolFactories, params, toolSpec);
tools.put(tool.getName(), tool);
toolSpecMap.put(tool.getName(), toolSpec);
}
}

public static Tool createTool(Map<String, Tool.Factory> toolFactories, Map<String, String> params, MLToolSpec toolSpec) {
if (!toolFactories.containsKey(toolSpec.getType())) {
throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
}
Map<String, String> executeParams = new HashMap<>();
if (toolSpec.getParameters() != null) {
executeParams.putAll(toolSpec.getParameters());
}
for (String key : params.keySet()) {
String toolNamePrefix = getToolName(toolSpec) + ".";
if (key.startsWith(toolNamePrefix)) {
executeParams.put(key.replace(toolNamePrefix, ""), params.get(key));
}
}
Tool tool = toolFactories.get(toolSpec.getType()).create(executeParams);
String toolName = getToolName(toolSpec);
tool.setName(toolName);

if (toolSpec.getDescription() != null) {
tool.setDescription(toolSpec.getDescription());
}
if (params.containsKey(toolName + ".description")) {
tool.setDescription(params.get(toolName + ".description"));
}

return tool;
}

public static List<String> getToolNames(Map<String, Tool> tools) {
final List<String> inputTools = new ArrayList<>();
for (Map.Entry<String, Tool> entry : tools.entrySet()) {
String toolName = entry.getValue().getName();
inputTools.add(toolName);
}
return inputTools;
}
}
Loading

0 comments on commit 3fc5b30

Please sign in to comment.