From 5c18ec941a5e6f4a174fca8c08fdaee149118e17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Tue, 3 Dec 2024 14:42:14 +0100 Subject: [PATCH] refactor: improve tools API and fix issues - Improve toolcall regex and JSON parsing - Update parameter order in tools (path before content) - Add better error handling and logging - Improve streaming output handling - Add test cases for tool use parsing --- gptme/llm/__init__.py | 14 +++-- gptme/llm/llm_anthropic.py | 67 +++++++++++++----------- gptme/llm/llm_openai.py | 3 +- gptme/tools/__init__.py | 7 +-- gptme/tools/base.py | 78 +++++++++++++++++++++++----- gptme/tools/patch.py | 8 +-- gptme/tools/save.py | 23 +++++---- gptme/util/__init__.py | 5 +- tests/test_prompt_tools.py | 4 +- tests/test_tool_use.py | 103 +++++++++++++++++++++++++++++++++++-- 10 files changed, 238 insertions(+), 74 deletions(-) diff --git a/gptme/llm/__init__.py b/gptme/llm/__init__.py index 5fff2c45..9fc24bd8 100644 --- a/gptme/llm/__init__.py +++ b/gptme/llm/__init__.py @@ -9,6 +9,9 @@ from ..config import get_config from ..constants import PROMPT_ASSISTANT +from ..message import Message, format_msgs, len_tokens +from ..tools import ToolSpec, ToolUse +from ..util import console from .llm_anthropic import chat as chat_anthropic from .llm_anthropic import get_client as get_anthropic_client from .llm_anthropic import init as init_anthropic @@ -17,15 +20,12 @@ from .llm_openai import get_client as get_openai_client from .llm_openai import init as init_openai from .llm_openai import stream as stream_openai -from ..message import Message, format_msgs, len_tokens from .models import ( MODELS, PROVIDERS_OPENAI, Provider, get_summary_model, ) -from ..tools import ToolSpec, ToolUse -from ..util import console logger = logging.getLogger(__name__) @@ -114,8 +114,12 @@ def print_clear(): if char == "\n": # TODO: make this more robust/general, maybe with a callback that runs on each char/chunk # pause inference on finished code-block, letting user run the command before continuing - tooluses = list(ToolUse.iter_from_content(output)) - if tooluses and any(tooluse.is_runnable for tooluse in tooluses): + tooluses = [ + tooluse + for tooluse in ToolUse.iter_from_content(output) + if tooluse.is_runnable + ] + if tooluses: logger.debug("Found tool use, breaking") break except KeyboardInterrupt: diff --git a/gptme/llm/llm_anthropic.py b/gptme/llm/llm_anthropic.py index 3f13483c..883195d0 100644 --- a/gptme/llm/llm_anthropic.py +++ b/gptme/llm/llm_anthropic.py @@ -16,15 +16,14 @@ from ..message import Message, len_tokens, msgs2dicts from ..tools.base import Parameter, ToolSpec -logger = logging.getLogger(__name__) - - if TYPE_CHECKING: # noreorder import anthropic # fmt: skip import anthropic.types.beta.prompt_caching # fmt: skip from anthropic import Anthropic # fmt: skip +logger = logging.getLogger(__name__) + _anthropic: "Anthropic | None" = None ALLOWED_FILE_EXTS = ["jpg", "jpeg", "png", "gif"] @@ -109,33 +108,41 @@ def stream( if hasattr(chunk, "usage"): # print(chunk.usage) pass - if chunk.type == "content_block_start": - block = chunk.content_block - if isinstance(block, anthropic.types.ToolUseBlock): - tool_use = block - yield f"@{tool_use.name}: " - elif chunk.type == "content_block_delta": - chunk = cast(anthropic.types.RawContentBlockDeltaEvent, chunk) - delta = chunk.delta - if isinstance(delta, anthropic.types.TextDelta): - yield delta.text - elif isinstance(delta, anthropic.types.InputJSONDelta): - yield delta.partial_json - else: - logger.warning("Unknown delta type: %s", delta) - elif chunk.type == "content_block_stop": - pass - elif chunk.type == "text": - # full text message - pass - elif chunk.type == "message_delta": - pass - elif chunk.type == "message_stop": - pass - else: - # print(f"Unknown chunk type: {chunk.type}") - # print(chunk) - pass + match chunk.type: + case "content_block_start": + chunk = cast(anthropic.types.RawContentBlockStartEvent, chunk) + block = chunk.content_block + if isinstance(block, anthropic.types.ToolUseBlock): + tool_use = block + yield f"\n@{tool_use.name}: " + elif isinstance(block, anthropic.types.TextBlock): + if block.text: + logger.warning("unexpected text block: %s", block.text) + else: + print(f"Unknown block type: {block}") + case "content_block_delta": + chunk = cast(anthropic.types.RawContentBlockDeltaEvent, chunk) + delta = chunk.delta + if isinstance(delta, anthropic.types.TextDelta): + yield delta.text + elif isinstance(delta, anthropic.types.InputJSONDelta): + yield delta.partial_json + else: + logger.warning("Unknown delta type: %s", delta) + case "content_block_stop": + pass + case "text": + # full text message + pass + case "message_start": + pass + case "message_delta": + pass + case "message_stop": + pass + case _: + # print(f"Unknown chunk type: {chunk.type}") + pass def _handle_files(message_dicts: list[dict]) -> list[dict]: diff --git a/gptme/llm/llm_openai.py b/gptme/llm/llm_openai.py index 470e2a97..cd65de6f 100644 --- a/gptme/llm/llm_openai.py +++ b/gptme/llm/llm_openai.py @@ -196,7 +196,8 @@ def stream( func = tool_call.function if isinstance(func, ChoiceDeltaToolCallFunction): if func.name: - yield "\n@" + func.name + ": " + print("tool call in openai") + yield f"\n@{func.name}: " if func.arguments: yield func.arguments diff --git a/gptme/tools/__init__.py b/gptme/tools/__init__.py index 62caf3a5..66083dac 100644 --- a/gptme/tools/__init__.py +++ b/gptme/tools/__init__.py @@ -5,9 +5,9 @@ from ..message import Message from .base import ( ConfirmFunc, + ToolFormat, ToolSpec, ToolUse, - ToolFormat, get_tool_format, set_tool_format, ) @@ -18,6 +18,7 @@ from .patch import tool as patch_tool from .python import register_function from .python import tool as python_tool +from .rag import tool as rag_tool from .read import tool as tool_read from .save import tool_append, tool_save from .screenshot import tool as screenshot_tool @@ -26,7 +27,6 @@ from .tmux import tool as tmux_tool from .vision import tool as vision_tool from .youtube import tool as youtube_tool -from .rag import tool as rag_tool logger = logging.getLogger(__name__) @@ -110,7 +110,8 @@ def execute_msg(msg: Message, confirm: ConfirmFunc) -> Generator[Message, None, assert msg.role == "assistant", "Only assistant messages can be executed" for tooluse in ToolUse.iter_from_content(msg.content): - yield from tooluse.execute(confirm) + if tooluse.is_runnable: + yield from tooluse.execute(confirm) # Called often when checking streaming output for executable blocks, diff --git a/gptme/tools/base.py b/gptme/tools/base.py index ec616230..1fa71a13 100644 --- a/gptme/tools/base.py +++ b/gptme/tools/base.py @@ -31,7 +31,45 @@ tool_format: ToolFormat = "markdown" exclusive_mode = False -toolcall_re = re.compile(r"@(\w+):\s*(\{.*?\})") +# Match tool name and start of JSON +toolcall_re = re.compile(r"^@(\w+):\s*({.*)", re.M | re.S) + + +def find_json_end(s: str, start: int) -> int | None: + """Find the end of a JSON object by counting braces""" + stack = [] + in_string = False + escape = False + + for i, c in enumerate(s[start:], start): + if escape: + escape = False + continue + + if c == "\\": + escape = True + elif c == '"' and not escape: + in_string = not in_string + elif not in_string: + if c == "{": + stack.append(c) + elif c == "}": + if not stack: + return None + stack.pop() + if not stack: + return i + 1 + return None + + +def extract_json(content: str, match: re.Match) -> str | None: + """Extract complete JSON object starting from a regex match""" + json_start = match.start(2) # start of the JSON content + json_end = find_json_end(content, json_start) + if json_end is None: + return None + return content[json_start:json_end] + ConfirmFunc = Callable[[str], bool] @@ -295,11 +333,20 @@ def iter_from_content(cls, content: str) -> Generator["ToolUse", None, None]: for tool_use in tool_uses: yield tool_use - # check if its a toolcall + # check if its a toolcall and extract valid JSON if match := toolcall_re.search(content): tool_name = match.group(1) - kwargs = cast(dict[str, str], json_repair.loads(match.group(2))) - yield ToolUse(tool_name, None, None, kwargs=kwargs) + if (json_str := extract_json(content, match)) is not None: + try: + kwargs = json_repair.loads(json_str) + if not isinstance(kwargs, dict): + logger.debug(f"JSON repair result is not a dict: {kwargs}") + return + yield ToolUse( + tool_name, None, None, kwargs=cast(dict[str, str], kwargs) + ) + except json.JSONDecodeError: + logger.debug(f"Failed to parse JSON: {json_str}") @classmethod def _iter_from_markdown(cls, content: str) -> Generator["ToolUse", None, None]: @@ -360,7 +407,7 @@ def to_output(self, tool_format: ToolFormat = "markdown") -> str: elif tool_format == "xml": return self._to_xml() elif tool_format == "tool": - return self._to_json() + return self._to_toolcall() def _to_markdown(self) -> str: assert self.args is not None @@ -373,20 +420,17 @@ def _to_xml(self) -> str: args_str = "" if not args else f" args='{args}'" return f"\n<{self.tool}{args_str}>\n{self.content}\n\n" - def _to_json(self) -> str: + def _to_params(self) -> dict: # noreorder from . import get_tool # fmt: skip - base = {"name": self.tool, "parameters": {}} if self.kwargs is not None: - base["parameters"] = self.kwargs + return self.kwargs elif self.args is not None and self.content is not None: # match positional args with kwargs - tool = get_tool(self.tool) - - if tool: + if tool := get_tool(self.tool): if self.args: - args = [self.content, *self.args] + args = [*self.args, self.content] else: args = [self.content] @@ -394,6 +438,12 @@ def _to_json(self) -> str: for index, param in enumerate(tool.parameters): json_parameters[param.name] = args[index] - base["parameters"] = json_parameters + return json_parameters + return {} + + def _to_json(self) -> str: + return json.dumps({"name": self.tool, "parameters": self._to_params()}) - return json.dumps(base) + def _to_toolcall(self) -> str: + self._to_json() + return f"@{self.tool}: {json.dumps(self._to_params(), indent=2)}" diff --git a/gptme/tools/patch.py b/gptme/tools/patch.py index 1c315fd3..62f031fd 100644 --- a/gptme/tools/patch.py +++ b/gptme/tools/patch.py @@ -256,15 +256,15 @@ def execute_patch( block_types=["patch"], parameters=[ Parameter( - name="patch", + name="path", type="string", - description="The patch to apply.", + description="The path of the file to patch.", required=True, ), Parameter( - name="path", + name="patch", type="string", - description="The path of the file to patch.", + description="The patch to apply.", required=True, ), ], diff --git a/gptme/tools/save.py b/gptme/tools/save.py index 957fbd79..b2b97c26 100644 --- a/gptme/tools/save.py +++ b/gptme/tools/save.py @@ -92,7 +92,12 @@ def execute_save( current = path.read_text() p = Patch(current, content) # TODO: if inefficient save, replace request with patch (and vice versa), or even append - print_preview(p.diff_minimal(), "diff") + diff_str = p.diff_minimal() + if diff_str.strip(): + print_preview(p.diff_minimal(), "diff") + else: + yield Message("system", "File already exists with identical content.") + return if not confirm(f"Save to {fn}?"): # early return @@ -170,15 +175,15 @@ def execute_append( block_types=["save"], parameters=[ Parameter( - name="content", + name="path", type="string", - description="The content to save", + description="The path of the file", required=True, ), Parameter( - name="path", + name="content", type="string", - description="The path of the file", + description="The content to save", required=True, ), ], @@ -195,15 +200,15 @@ def execute_append( block_types=["append"], parameters=[ Parameter( - name="content", + name="path", type="string", - description="The content to append", + description="The path of the file", required=True, ), Parameter( - name="path", + name="content", type="string", - description="The path of the file", + description="The content to append", required=True, ), ], diff --git a/gptme/util/__init__.py b/gptme/util/__init__.py index 6739abac..41979edb 100644 --- a/gptme/util/__init__.py +++ b/gptme/util/__init__.py @@ -35,10 +35,13 @@ def get_tokenizer(model: str): import tiktoken # fmt: skip - global _warned_models + if "gpt-4o" in model: + return tiktoken.get_encoding("o200k_base") + try: return tiktoken.encoding_for_model(model) except KeyError: + global _warned_models if model not in _warned_models: logger.warning( f"No tokenizer for '{model}'. Using tiktoken cl100k_base. Use results only as estimates." diff --git a/tests/test_prompt_tools.py b/tests/test_prompt_tools.py index cd4fbdb3..ba7b7f32 100644 --- a/tests/test_prompt_tools.py +++ b/tests/test_prompt_tools.py @@ -44,7 +44,7 @@ [ "the `shell` tool", "aliases", - '{"name": "shell", "parameters": {"command": "cat file.txt"}}', + '@shell: {"command": "cat file.txt"}', "### Examples", ], [], @@ -57,7 +57,7 @@ "aliases", ], [ - '{"name": "shell", "parameters": {"command": "cat file.txt"}}', + '@shell: {"command": "cat file.txt"}', "### Examples", ], ), diff --git a/tests/test_tool_use.py b/tests/test_tool_use.py index fb721fe3..e43aaf77 100644 --- a/tests/test_tool_use.py +++ b/tests/test_tool_use.py @@ -1,6 +1,7 @@ +import json_repair import pytest -from gptme.tools.base import ToolUse from gptme.tools import init_tools +from gptme.tools.base import ToolUse, extract_json, toolcall_re @pytest.mark.parametrize( @@ -38,16 +39,22 @@ ( "tool", ["test.txt"], - "patch", + "...", None, - """{"name": "patch", "parameters": {"patch": "patch", "path": "test.txt"}}""", + """@patch: { + "path": "test.txt", + "patch": "..." +}""", ), ( "tool", ["test.txt"], "patch", - {"patch": "patch_kwargs", "path": "test_kwargs.txt"}, - """{"name": "patch", "parameters": {"patch": "patch_kwargs", "path": "test_kwargs.txt"}}""", + {"path": "test_kwargs.txt", "patch": "..."}, + """@patch: { + "path": "test_kwargs.txt", + "patch": "..." +}""", ), ], ) @@ -57,3 +64,89 @@ def test_tool_use_output_patch(tool_format, args, content, kwargs, expected): result = ToolUse("patch", args, content, kwargs).to_output(tool_format) assert result == expected + + +@pytest.mark.parametrize( + "content, expected_tool, expected_json", + [ + ( + '@tool: {"param": "value"}', + "tool", + '{"param": "value"}', + ), + ( + '@tool: {"missing": "comma" "key": "value"}', # json_repair can fix this + "tool", + '{"missing": "comma", "key": "value"}', + ), + ( + "@tool: {invalid json}", # json_repair can handle this + "tool", + "{}", + ), + ( + '@tool: {\n "param": "value"\n}', + "tool", + '{\n "param": "value"\n}', + ), + ( + '@tool: {\n "param": "value with\nnewline",\n "another": "value"\n}', + "tool", + '{\n "param": "value with\nnewline",\n "another": "value"\n}', + ), + ( + '@tool: {"param": {"nested": "value"}}', + "tool", + '{"param": {"nested": "value"}}', + ), + ( + '@tool: {"param": {"deeply": {"nested": "value"}}}', + "tool", + '{"param": {"deeply": {"nested": "value"}}}', + ), + ( + '@tool: {"text": "a string with } brace"}', + "tool", + '{"text": "a string with } brace"}', + ), + ( + '@tool: {"text": "a string with \\"quote\\" and } brace"}', + "tool", + '{"text": "a string with \\"quote\\" and } brace"}', + ), + ( + '@save: {"path": "hello.py", "content": "def main():\n print(\\"Hello, World!\\")\n \nif __name__ == \\"__main__\\":\n main()"}', + "save", + '{"path": "hello.py", "content": "def main():\n print(\\"Hello, World!\\")\n \nif __name__ == \\"__main__\\":\n main()"}', + ), + ], +) +def test_toolcall_regex(content, expected_tool, expected_json): + match = toolcall_re.search(content) + assert match is not None + assert match.group(1) == expected_tool + json_str = extract_json(content, match) + assert json_str is not None + # Parse both strings with json_repair to compare structure + expected_dict = json_repair.loads(expected_json) + actual_dict = json_repair.loads(json_str) + assert actual_dict == expected_dict + + +@pytest.mark.parametrize( + "content", + [ + "some text @tool: {'param': 'value'}", # leading characters + "@tool: {", # incomplete JSON + " @tool: {'param': 'value'}", # leading whitespace + '@tool: {"unclosed": "string}', # unclosed string + '@tool: {"unclosed": {', # unclosed nested object + '@tool: {"mismatched": "quote\'}', # mismatched quotes + # TODO: fix these + # "```\n@tool: {'param': 'value'}\n```", # inside codeblock + ], +) +def test_toolcall_regex_invalid(content): + # No ToolUse should be created for invalid content + tool_uses = list(ToolUse.iter_from_content(content)) + assert len(tool_uses) == 0