From ac85af3871242bc9d649924a6ea54afb2f359cfa Mon Sep 17 00:00:00 2001
From: Jeremie Pardou <571533+jrmi@users.noreply.github.com>
Date: Tue, 31 Dec 2024 00:10:27 +0100
Subject: [PATCH] fix: broken tool call after editing the file before saving
---
gptme/llm/llm_anthropic.py | 22 +++++++++++++++-----
gptme/llm/llm_openai.py | 40 ++++++++++++++++++++++++++++++++++++-
tests/test_llm_anthropic.py | 8 +++++++-
tests/test_llm_openai.py | 6 +++++-
4 files changed, 68 insertions(+), 8 deletions(-)
diff --git a/gptme/llm/llm_anthropic.py b/gptme/llm/llm_anthropic.py
index 393f6b028..fbc85f55f 100644
--- a/gptme/llm/llm_anthropic.py
+++ b/gptme/llm/llm_anthropic.py
@@ -200,9 +200,8 @@ def stream(
def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
for message in message_dicts:
# Format tool result as expected by the model
- if message["role"] == "system" and "call_id" in message:
+ if message["role"] == "user" and "call_id" in message:
modified_message = dict(message)
- modified_message["role"] = "user"
modified_message["content"] = [
{
"type": "tool_result",
@@ -358,22 +357,35 @@ def _transform_system_messages(
# unless a `call_id` is present, indicating the tool_format is 'tool'.
# Tool responses are handled separately by _handle_tool.
for i, message in enumerate(messages):
- if message.role == "system" and message.call_id is None:
+ if message.role == "system":
+ content = (
+ f"{message.content}"
+ if message.call_id is None
+ else message.content
+ )
+
messages[i] = Message(
"user",
- content=f"{message.content}",
+ content=content,
files=message.files, # type: ignore
+ call_id=message.call_id,
)
# find consecutive user role messages and merge them together
messages_new: list[Message] = []
while messages:
message = messages.pop(0)
- if messages_new and messages_new[-1].role == "user" and message.role == "user":
+ if (
+ messages_new
+ and messages_new[-1].role == "user"
+ and message.role == "user"
+ and message.call_id == messages_new[-1].call_id
+ ):
messages_new[-1] = Message(
"user",
content=f"{messages_new[-1].content}\n\n{message.content}",
files=messages_new[-1].files + message.files, # type: ignore
+ call_id=messages_new[-1].call_id,
)
else:
messages_new.append(message)
diff --git a/gptme/llm/llm_openai.py b/gptme/llm/llm_openai.py
index 129c62d6d..b0492f630 100644
--- a/gptme/llm/llm_openai.py
+++ b/gptme/llm/llm_openai.py
@@ -274,6 +274,7 @@ def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
modified_message["content"] = content
if tool_calls:
+        # Clean the content property if empty, otherwise the API call fails
if not content:
del modified_message["content"]
modified_message["tool_calls"] = tool_calls
@@ -283,6 +284,41 @@ def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
yield message
+def _merge_tool_results_with_same_call_id(
+ messages_dicts: Iterable[dict],
+) -> list[dict]:
+ """
+ When we call a tool, this tool can potentially yield multiple messages. However
+    the API expects only one tool result per tool call. This function tries
+ to merge subsequent tool results with the same call ID as expected by
+ the API.
+ """
+
+ messages_dicts = iter(messages_dicts)
+
+ messages_new: list[dict] = []
+ while message := next(messages_dicts, None):
+ if messages_new and (
+ message["role"] == "tool"
+ and messages_new[-1]["role"] == "tool"
+ and message["tool_call_id"] == messages_new[-1]["tool_call_id"]
+ ):
+ prev_msg = messages_new[-1]
+ content = message["content"]
+ if not isinstance(content, list):
+ content = {"type": "text", "text": content}
+
+ messages_new[-1] = {
+ "role": "tool",
+ "content": prev_msg["content"] + content,
+ "tool_call_id": prev_msg["tool_call_id"],
+ }
+ else:
+ messages_new.append(message)
+
+ return messages_new
+
+
def _process_file(msg: dict, model: ModelMeta) -> dict:
message_content = msg["content"]
if model.provider == "deepseek":
@@ -423,7 +459,9 @@ def _prepare_messages_for_api(
tools_dict = [_spec2tool(tool, model) for tool in tools] if tools else None
if tools_dict is not None:
- messages_dicts = _handle_tools(messages_dicts)
+ messages_dicts = _merge_tool_results_with_same_call_id(
+ _handle_tools(messages_dicts)
+ )
messages_dicts = _transform_msgs_for_special_provider(messages_dicts, model)
diff --git a/tests/test_llm_anthropic.py b/tests/test_llm_anthropic.py
index 8a5081e52..869f7cf39 100644
--- a/tests/test_llm_anthropic.py
+++ b/tests/test_llm_anthropic.py
@@ -97,6 +97,7 @@ def test_message_conversion_with_tools():
content='\nSomething\n\n@save(tool_call_id): {"path": "path.txt", "content": "file_content"}',
),
Message(role="system", content="Saved to toto.txt", call_id="tool_call_id"),
+ Message(role="system", content="(Modified by user)", call_id="tool_call_id"),
]
tool_save = get_tool("save")
@@ -152,7 +153,12 @@ def test_message_conversion_with_tools():
"content": [
{
"type": "tool_result",
- "content": [{"type": "text", "text": "Saved to toto.txt"}],
+ "content": [
+ {
+ "type": "text",
+ "text": "Saved to toto.txt\n\n(Modified by user)",
+ }
+ ],
"tool_use_id": "tool_call_id",
"cache_control": {"type": "ephemeral"},
}
diff --git a/tests/test_llm_openai.py b/tests/test_llm_openai.py
index 34d75817f..401ae99f2 100644
--- a/tests/test_llm_openai.py
+++ b/tests/test_llm_openai.py
@@ -116,6 +116,7 @@ def test_message_conversion_with_tools():
content='\n@save(tool_call_id): {"path": "path.txt", "content": "file_content"}',
),
Message(role="system", content="Saved to toto.txt", call_id="tool_call_id"),
+ Message(role="system", content="(Modified by user)", call_id="tool_call_id"),
]
set_default_model("openai/gpt-4o")
@@ -193,7 +194,10 @@ def test_message_conversion_with_tools():
},
{
"role": "tool",
- "content": [{"type": "text", "text": "Saved to toto.txt"}],
+ "content": [
+ {"type": "text", "text": "Saved to toto.txt"},
+ {"type": "text", "text": "(Modified by user)"},
+ ],
"tool_call_id": "tool_call_id",
},
]