Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix history content from agent memory. #899

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions comps/agent/langchain/src/strategy/react/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ async def non_streaming_run(self, query, config):

from ...persistence import AgentPersistence, PersistenceConfig
from ...utils import setup_chat_model
from .utils import assemble_history, assemble_memory, convert_json_to_tool_call


class AgentState(TypedDict):
Expand Down Expand Up @@ -174,16 +175,20 @@ def __init__(self, tools, args):
llm = setup_chat_model(args)
self.tools = tools
self.chain = prompt | llm | output_parser
self.with_memory = args.with_memory

def __call__(self, state):
from .utils import assemble_history, convert_json_to_tool_call

print("---CALL Agent node---")
messages = state["messages"]

# assemble a prompt from messages
query = messages[0].content
history = assemble_history(messages)
if self.with_memory:
query, history = assemble_memory(messages)
print("@@@ Query: ", history)
else:
query = messages[0].content
history = assemble_history(messages)
print("@@@ History: ", history)

tools_descriptions = tool_renderer(self.tools)
Expand Down
40 changes: 39 additions & 1 deletion comps/agent/langchain/src/strategy/react/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import uuid

from huggingface_hub import ChatCompletionOutputFunctionDefinition, ChatCompletionOutputToolCall
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.messages.tool import ToolCall
from langchain_core.output_parsers import BaseOutputParser

Expand Down Expand Up @@ -82,3 +82,41 @@ def assemble_history(messages):
query_history += f"Assistant Output: {m.content}\n"

return query_history


def assemble_memory(messages):
    """Assemble the current query and conversation history from agent memory.

    messages: a list of langchain messages in turn order — Human, AI, TOOL,
    AI, TOOL, etc. — possibly spanning multiple past conversation turns.

    Returns:
        (query, query_history) where `query` is the content of the most
        recent HumanMessage (the current user turn) and `query_history` is a
        flattened text transcript of all other messages: each tool call with
        its output, plain assistant replies, and earlier human inputs.
    """
    query = ""
    query_id = None
    query_history = ""
    breaker = "-" * 10

    # The most recent HumanMessage is the current query; scan from the end.
    for m in messages[::-1]:
        if isinstance(m, HumanMessage):
            query = m.content
            query_id = m.id
            break

    for m in messages:
        if isinstance(m, AIMessage):
            # Assistant turn that made tool calls: record each call and its output.
            if hasattr(m, "tool_calls") and len(m.tool_calls) > 0:
                for tool_call in m.tool_calls:
                    tool = tool_call["name"]
                    tc_args = tool_call["args"]
                    tool_call_id = tool_call["id"]  # renamed from `id` to avoid shadowing the builtin
                    tool_output = get_tool_output(messages, tool_call_id)
                    query_history += f"Tool Call: {tool} - {tc_args}\nTool Output: {tool_output}\n{breaker}\n"
            else:
                # Plain assistant reply with no tool calls.
                query_history += f"Assistant Output: {m.content}\n"

        elif isinstance(m, HumanMessage):
            # Skip the current query itself; keep earlier human turns in history.
            if m.id == query_id:
                continue
            query_history += f"Human Input: {m.content}\n"

    return query, query_history