Skip to content

Commit

Permalink
fine tune code; fix some bug
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Feb 17, 2024
1 parent 9b683cd commit 26e1b86
Showing 1 changed file with 16 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,23 +185,15 @@ private void runReAct(
String sessionId,
ActionListener<Object> listener
) {
final List<String> inputTools = getToolNames(tools);
String question = parameters.get(MLAgentExecutor.QUESTION);
String parentInteractionId = parameters.get(MLAgentExecutor.PARENT_INTERACTION_ID);
boolean verbose = parameters.containsKey(VERBOSE) && Boolean.parseBoolean(parameters.get(VERBOSE));
boolean traceDisabled = parameters.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(parameters.get(DISABLE_TRACE));

Map<String, String> tmpParameters = constructLLMParams(llm, parameters);
String prompt = constructLLMPrompt(tools, inputTools, tmpParameters);
String prompt = constructLLMPrompt(tools, tmpParameters);
tmpParameters.put(PROMPT, prompt);

List<ModelTensors> traceTensors = createModelTensors(sessionId, parentInteractionId);

StringBuilder scratchpadBuilder = new StringBuilder();
StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}");
AtomicReference<String> newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt));
tmpParameters.put(PROMPT, newPrompt.get());
String finalPrompt = prompt;
final String finalPrompt = prompt;

// Create root interaction.
ConversationIndexMemory conversationIndexMemory = (ConversationIndexMemory) memory;
Expand All @@ -220,6 +212,12 @@ private void runReAct(
lastLlmListener.set(firstListener);
StepListener<?> lastStepListener = firstListener;

StringBuilder scratchpadBuilder = new StringBuilder();
StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}");
AtomicReference<String> newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt));
tmpParameters.put(PROMPT, newPrompt.get());

List<ModelTensors> traceTensors = createModelTensors(sessionId, parentInteractionId);
int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, "3")) * 2;
for (int i = 0; i < maxIterations; i++) {
int finalI = i;
Expand All @@ -230,7 +228,7 @@ private void runReAct(
if (finalI % 2 == 0) {
MLTaskResponse llmResponse = (MLTaskResponse) output;
ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput();
List<String> llmResponsePatterns = gson.fromJson(parameters.get("llm_response_pattern"), List.class);
List<String> llmResponsePatterns = gson.fromJson(tmpParameters.get("llm_response_pattern"), List.class);
Map<String, String> modelOutput = parseLLMOutput(tmpModelTensorOutput, llmResponsePatterns, tools.keySet());

String thought = String.valueOf(modelOutput.get(THOUGHT));
Expand Down Expand Up @@ -283,7 +281,7 @@ private void runReAct(
"LLM"
);

if (tools.containsKey(action) && inputTools.contains(action)) {
if (tools.containsKey(action)) {
Map<String, String> toolParams = constructToolParams(
tools,
toolSpecMap,
Expand Down Expand Up @@ -603,12 +601,12 @@ private static List<ModelTensors> createModelTensors(String sessionId, String pa
return cotModelTensors;
}

private static String constructLLMPrompt(Map<String, Tool> tools, List<String> inputTools, Map<String, String> tmpParameters) {
private static String constructLLMPrompt(Map<String, Tool> tools, Map<String, String> tmpParameters) {
String prompt = tmpParameters.getOrDefault(PROMPT, PromptTemplate.PROMPT_TEMPLATE);
StringSubstitutor promptSubstitutor = new StringSubstitutor(tmpParameters, "${parameters.", "}");
prompt = promptSubstitutor.replace(prompt);
prompt = AgentUtils.addPrefixSuffixToPrompt(tmpParameters, prompt);
prompt = AgentUtils.addToolsToPrompt(tools, tmpParameters, inputTools, prompt);
prompt = AgentUtils.addToolsToPrompt(tools, tmpParameters, getToolNames(tools), prompt);
prompt = AgentUtils.addIndicesToPrompt(tmpParameters, prompt);
prompt = AgentUtils.addExamplesToPrompt(tmpParameters, prompt);
prompt = AgentUtils.addChatHistoryToPrompt(tmpParameters, prompt);
Expand Down Expand Up @@ -642,17 +640,10 @@ private static Map<String, String> constructLLMParams(LLMSpec llm, Map<String, S
);
}

String promptPrefix = parameters.getOrDefault(PROMPT_PREFIX, PromptTemplate.PROMPT_TEMPLATE_PREFIX);
tmpParameters.put(PROMPT_PREFIX, promptPrefix);

String promptSuffix = parameters.getOrDefault(PROMPT_SUFFIX, PromptTemplate.PROMPT_TEMPLATE_SUFFIX);
tmpParameters.put(PROMPT_SUFFIX, promptSuffix);

String promptFormatInstruction = parameters.getOrDefault(RESPONSE_FORMAT_INSTRUCTION, PromptTemplate.PROMPT_FORMAT_INSTRUCTION);
tmpParameters.put(RESPONSE_FORMAT_INSTRUCTION, promptFormatInstruction);

String promptToolResponse = parameters.getOrDefault(TOOL_RESPONSE, PromptTemplate.PROMPT_TEMPLATE_TOOL_RESPONSE);
tmpParameters.put(TOOL_RESPONSE, promptToolResponse);
tmpParameters.putIfAbsent(PROMPT_PREFIX, PromptTemplate.PROMPT_TEMPLATE_PREFIX);
tmpParameters.putIfAbsent(PROMPT_SUFFIX, PromptTemplate.PROMPT_TEMPLATE_SUFFIX);
tmpParameters.putIfAbsent(RESPONSE_FORMAT_INSTRUCTION, PromptTemplate.PROMPT_FORMAT_INSTRUCTION);
tmpParameters.putIfAbsent(TOOL_RESPONSE, PromptTemplate.PROMPT_TEMPLATE_TOOL_RESPONSE);
return tmpParameters;
}

Expand Down

0 comments on commit 26e1b86

Please sign in to comment.