From 32bcde4528bd9f844b56e5e425ade1bb5e05dadb Mon Sep 17 00:00:00 2001
From: lkk <33276950+lkk12014402@users.noreply.github.com>
Date: Thu, 14 Nov 2024 21:26:01 +0800
Subject: [PATCH] fix history content from agent memory. (#899)

* fix history content from agent memory.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

   for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../langchain/src/strategy/react/planner.py | 11 +++--
 .../langchain/src/strategy/react/utils.py   | 40 ++++++++++++++++++-
 2 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/comps/agent/langchain/src/strategy/react/planner.py b/comps/agent/langchain/src/strategy/react/planner.py
index 9771f6220..773cc199c 100644
--- a/comps/agent/langchain/src/strategy/react/planner.py
+++ b/comps/agent/langchain/src/strategy/react/planner.py
@@ -147,6 +147,7 @@ async def non_streaming_run(self, query, config):
 
 from ...persistence import AgentPersistence, PersistenceConfig
 from ...utils import setup_chat_model
+from .utils import assemble_history, assemble_memory, convert_json_to_tool_call
 
 
 class AgentState(TypedDict):
@@ -174,16 +175,20 @@ def __init__(self, tools, args):
         llm = setup_chat_model(args)
         self.tools = tools
         self.chain = prompt | llm | output_parser
+        self.with_memory = args.with_memory
 
     def __call__(self, state):
-        from .utils import assemble_history, convert_json_to_tool_call
 
         print("---CALL Agent node---")
         messages = state["messages"]
 
         # assemble a prompt from messages
-        query = messages[0].content
-        history = assemble_history(messages)
+        if self.with_memory:
+            query, history = assemble_memory(messages)
+            print("@@@ Query: ", history)
+        else:
+            query = messages[0].content
+            history = assemble_history(messages)
         print("@@@ History: ", history)
 
         tools_descriptions = tool_renderer(self.tools)
diff --git a/comps/agent/langchain/src/strategy/react/utils.py b/comps/agent/langchain/src/strategy/react/utils.py
index f303b424a..19e51032b 100644
--- a/comps/agent/langchain/src/strategy/react/utils.py
+++ b/comps/agent/langchain/src/strategy/react/utils.py
@@ -5,7 +5,7 @@
 import uuid
 
 from huggingface_hub import ChatCompletionOutputFunctionDefinition, ChatCompletionOutputToolCall
-from langchain_core.messages import AIMessage, ToolMessage
+from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
 from langchain_core.messages.tool import ToolCall
 from langchain_core.output_parsers import BaseOutputParser
 
@@ -82,3 +82,41 @@ def assemble_history(messages):
                 query_history += f"Assistant Output: {m.content}\n"
 
     return query_history
+
+
+def assemble_memory(messages):
+    """
+    messages: Human, AI, TOOL, AI, TOOL, etc.
+ """ + query = "" + query_id = None + query_history = "" + breaker = "-" * 10 + + # get query + for m in messages[::-1]: + if isinstance(m, HumanMessage): + query = m.content + query_id = m.id + break + + for m in messages: + if isinstance(m, AIMessage): + # if there is tool call + if hasattr(m, "tool_calls") and len(m.tool_calls) > 0: + for tool_call in m.tool_calls: + tool = tool_call["name"] + tc_args = tool_call["args"] + id = tool_call["id"] + tool_output = get_tool_output(messages, id) + query_history += f"Tool Call: {tool} - {tc_args}\nTool Output: {tool_output}\n{breaker}\n" + else: + # did not make tool calls + query_history += f"Assistant Output: {m.content}\n" + + elif isinstance(m, HumanMessage): + if m.id == query_id: + continue + query_history += f"Human Input: {m.content}\n" + + return query, query_history