Skip to content

Commit

Permalink
fix history content from agent memory. (#899)
Browse files Browse the repository at this point in the history
* fix history content from agent memory.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lkk12014402 and pre-commit-ci[bot] authored Nov 14, 2024
1 parent 0dbf577 commit 32bcde4
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
11 changes: 8 additions & 3 deletions comps/agent/langchain/src/strategy/react/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ async def non_streaming_run(self, query, config):

from ...persistence import AgentPersistence, PersistenceConfig
from ...utils import setup_chat_model
from .utils import assemble_history, assemble_memory, convert_json_to_tool_call


class AgentState(TypedDict):
Expand Down Expand Up @@ -174,16 +175,20 @@ def __init__(self, tools, args):
llm = setup_chat_model(args)
self.tools = tools
self.chain = prompt | llm | output_parser
self.with_memory = args.with_memory

def __call__(self, state):
from .utils import assemble_history, convert_json_to_tool_call

print("---CALL Agent node---")
messages = state["messages"]

# assemble a prompt from messages
query = messages[0].content
history = assemble_history(messages)
if self.with_memory:
query, history = assemble_memory(messages)
print("@@@ Query: ", history)
else:
query = messages[0].content
history = assemble_history(messages)
print("@@@ History: ", history)

tools_descriptions = tool_renderer(self.tools)
Expand Down
40 changes: 39 additions & 1 deletion comps/agent/langchain/src/strategy/react/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import uuid

from huggingface_hub import ChatCompletionOutputFunctionDefinition, ChatCompletionOutputToolCall
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.messages.tool import ToolCall
from langchain_core.output_parsers import BaseOutputParser

Expand Down Expand Up @@ -82,3 +82,41 @@ def assemble_history(messages):
query_history += f"Assistant Output: {m.content}\n"

return query_history


def assemble_memory(messages):
"""
messages: Human, AI, TOOL, AI, TOOL, etc.
"""
query = ""
query_id = None
query_history = ""
breaker = "-" * 10

# get query
for m in messages[::-1]:
if isinstance(m, HumanMessage):
query = m.content
query_id = m.id
break

for m in messages:
if isinstance(m, AIMessage):
# if there is tool call
if hasattr(m, "tool_calls") and len(m.tool_calls) > 0:
for tool_call in m.tool_calls:
tool = tool_call["name"]
tc_args = tool_call["args"]
id = tool_call["id"]
tool_output = get_tool_output(messages, id)
query_history += f"Tool Call: {tool} - {tc_args}\nTool Output: {tool_output}\n{breaker}\n"
else:
# did not make tool calls
query_history += f"Assistant Output: {m.content}\n"

elif isinstance(m, HumanMessage):
if m.id == query_id:
continue
query_history += f"Human Input: {m.content}\n"

return query, query_history

0 comments on commit 32bcde4

Please sign in to comment.