From ac85af3871242bc9d649924a6ea54afb2f359cfa Mon Sep 17 00:00:00 2001 From: Jeremie Pardou <571533+jrmi@users.noreply.github.com> Date: Tue, 31 Dec 2024 00:10:27 +0100 Subject: [PATCH] fix: broken tool call after editing the file before saving --- gptme/llm/llm_anthropic.py | 22 +++++++++++++++----- gptme/llm/llm_openai.py | 40 ++++++++++++++++++++++++++++++++++++- tests/test_llm_anthropic.py | 8 +++++++- tests/test_llm_openai.py | 6 +++++- 4 files changed, 68 insertions(+), 8 deletions(-) diff --git a/gptme/llm/llm_anthropic.py b/gptme/llm/llm_anthropic.py index 393f6b028..fbc85f55f 100644 --- a/gptme/llm/llm_anthropic.py +++ b/gptme/llm/llm_anthropic.py @@ -200,9 +200,8 @@ def stream( def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]: for message in message_dicts: # Format tool result as expected by the model - if message["role"] == "system" and "call_id" in message: + if message["role"] == "user" and "call_id" in message: modified_message = dict(message) - modified_message["role"] = "user" modified_message["content"] = [ { "type": "tool_result", @@ -358,22 +357,35 @@ def _transform_system_messages( # unless a `call_id` is present, indicating the tool_format is 'tool'. # Tool responses are handled separately by _handle_tool. for i, message in enumerate(messages): - if message.role == "system" and message.call_id is None: + if message.role == "system": + content = ( + f"{message.content}" + if message.call_id is None + else message.content + ) + messages[i] = Message( "user", - content=f"{message.content}", + content=content, files=message.files, # type: ignore + call_id=message.call_id, ) # find consecutive user role messages and merge them together messages_new: list[Message] = [] while messages: message = messages.pop(0) - if messages_new and messages_new[-1].role == "user" and message.role == "user": + if ( + messages_new + and messages_new[-1].role == "user" + and message.role == "user" + and message.call_id == messages_new[-1].call_id + ): messages_new[-1] = Message( "user", content=f"{messages_new[-1].content}\n\n{message.content}", files=messages_new[-1].files + message.files, # type: ignore + call_id=messages_new[-1].call_id, ) else: messages_new.append(message) diff --git a/gptme/llm/llm_openai.py b/gptme/llm/llm_openai.py index 129c62d6d..b0492f630 100644 --- a/gptme/llm/llm_openai.py +++ b/gptme/llm/llm_openai.py @@ -274,6 +274,7 @@ def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]: modified_message["content"] = content if tool_calls: + # Clean content property if empty otherwise the call fails if not content: del modified_message["content"] modified_message["tool_calls"] = tool_calls @@ -283,6 +284,41 @@ def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]: yield message +def _merge_tool_results_with_same_call_id( + messages_dicts: Iterable[dict], +) -> list[dict]: # Generator[dict, None, None]: + """ + When we call a tool, this tool can potentially yield multiple messages. However + the API expect to have only one tool result per tool call. This function tries + to merge subsequent tool results with the same call ID as expected by + the API. + """ + + messages_dicts = iter(messages_dicts) + + messages_new: list[dict] = [] + while message := next(messages_dicts, None): + if messages_new and ( + message["role"] == "tool" + and messages_new[-1]["role"] == "tool" + and message["tool_call_id"] == messages_new[-1]["tool_call_id"] + ): + prev_msg = messages_new[-1] + content = message["content"] + if not isinstance(content, list): + content = {"type": "text", "text": content} + + messages_new[-1] = { + "role": "tool", + "content": prev_msg["content"] + content, + "tool_call_id": prev_msg["tool_call_id"], + } + else: + messages_new.append(message) + + return messages_new + + def _process_file(msg: dict, model: ModelMeta) -> dict: message_content = msg["content"] if model.provider == "deepseek": @@ -423,7 +459,9 @@ def _prepare_messages_for_api( tools_dict = [_spec2tool(tool, model) for tool in tools] if tools else None if tools_dict is not None: - messages_dicts = _handle_tools(messages_dicts) + messages_dicts = _merge_tool_results_with_same_call_id( + _handle_tools(messages_dicts) + ) messages_dicts = _transform_msgs_for_special_provider(messages_dicts, model) diff --git a/tests/test_llm_anthropic.py b/tests/test_llm_anthropic.py index 8a5081e52..869f7cf39 100644 --- a/tests/test_llm_anthropic.py +++ b/tests/test_llm_anthropic.py @@ -97,6 +97,7 @@ def test_message_conversion_with_tools(): content='\nSomething\n\n@save(tool_call_id): {"path": "path.txt", "content": "file_content"}', ), Message(role="system", content="Saved to toto.txt", call_id="tool_call_id"), + Message(role="system", content="(Modified by user)", call_id="tool_call_id"), ] tool_save = get_tool("save") @@ -152,7 +153,12 @@ def test_message_conversion_with_tools(): "content": [ { "type": "tool_result", - "content": [{"type": "text", "text": "Saved to toto.txt"}], + "content": [ + { + "type": "text", + "text": "Saved to toto.txt\n\n(Modified by user)", + } + ], "tool_use_id": "tool_call_id", "cache_control": {"type": "ephemeral"}, } diff --git a/tests/test_llm_openai.py b/tests/test_llm_openai.py index 34d75817f..401ae99f2 100644 --- a/tests/test_llm_openai.py +++ b/tests/test_llm_openai.py @@ -116,6 +116,7 @@ def test_message_conversion_with_tools(): content='\n@save(tool_call_id): {"path": "path.txt", "content": "file_content"}', ), Message(role="system", content="Saved to toto.txt", call_id="tool_call_id"), + Message(role="system", content="(Modified by user)", call_id="tool_call_id"), ] set_default_model("openai/gpt-4o") @@ -193,7 +194,10 @@ def test_message_conversion_with_tools(): }, { "role": "tool", - "content": [{"type": "text", "text": "Saved to toto.txt"}], + "content": [ + {"type": "text", "text": "Saved to toto.txt"}, + {"type": "text", "text": "(Modified by user)"}, + ], "tool_call_id": "tool_call_id", }, ]