Skip to content

Commit

Permalink
refactor: improve tools API and fix issues
Browse files Browse the repository at this point in the history
- Improve toolcall regex and JSON parsing
- Update parameter order in tools (path before content)
- Add better error handling and logging
- Improve streaming output handling
- Add test cases for tool use parsing
  • Loading branch information
ErikBjare committed Dec 3, 2024
1 parent b426f13 commit 5c18ec9
Show file tree
Hide file tree
Showing 10 changed files with 238 additions and 74 deletions.
14 changes: 9 additions & 5 deletions gptme/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

from ..config import get_config
from ..constants import PROMPT_ASSISTANT
from ..message import Message, format_msgs, len_tokens
from ..tools import ToolSpec, ToolUse
from ..util import console
from .llm_anthropic import chat as chat_anthropic
from .llm_anthropic import get_client as get_anthropic_client
from .llm_anthropic import init as init_anthropic
Expand All @@ -17,15 +20,12 @@
from .llm_openai import get_client as get_openai_client
from .llm_openai import init as init_openai
from .llm_openai import stream as stream_openai
from ..message import Message, format_msgs, len_tokens
from .models import (
MODELS,
PROVIDERS_OPENAI,
Provider,
get_summary_model,
)
from ..tools import ToolSpec, ToolUse
from ..util import console

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -114,8 +114,12 @@ def print_clear():
if char == "\n":
# TODO: make this more robust/general, maybe with a callback that runs on each char/chunk
# pause inference on finished code-block, letting user run the command before continuing
tooluses = list(ToolUse.iter_from_content(output))
if tooluses and any(tooluse.is_runnable for tooluse in tooluses):
tooluses = [
tooluse
for tooluse in ToolUse.iter_from_content(output)
if tooluse.is_runnable
]
if tooluses:
logger.debug("Found tool use, breaking")
break
except KeyboardInterrupt:
Expand Down
67 changes: 37 additions & 30 deletions gptme/llm/llm_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@
from ..message import Message, len_tokens, msgs2dicts
from ..tools.base import Parameter, ToolSpec

logger = logging.getLogger(__name__)


if TYPE_CHECKING:
# noreorder
import anthropic # fmt: skip
import anthropic.types.beta.prompt_caching # fmt: skip
from anthropic import Anthropic # fmt: skip

logger = logging.getLogger(__name__)

_anthropic: "Anthropic | None" = None

ALLOWED_FILE_EXTS = ["jpg", "jpeg", "png", "gif"]
Expand Down Expand Up @@ -109,33 +108,41 @@ def stream(
if hasattr(chunk, "usage"):
# print(chunk.usage)
pass
if chunk.type == "content_block_start":
block = chunk.content_block
if isinstance(block, anthropic.types.ToolUseBlock):
tool_use = block
yield f"@{tool_use.name}: "
elif chunk.type == "content_block_delta":
chunk = cast(anthropic.types.RawContentBlockDeltaEvent, chunk)
delta = chunk.delta
if isinstance(delta, anthropic.types.TextDelta):
yield delta.text
elif isinstance(delta, anthropic.types.InputJSONDelta):
yield delta.partial_json
else:
logger.warning("Unknown delta type: %s", delta)
elif chunk.type == "content_block_stop":
pass
elif chunk.type == "text":
# full text message
pass
elif chunk.type == "message_delta":
pass
elif chunk.type == "message_stop":
pass
else:
# print(f"Unknown chunk type: {chunk.type}")
# print(chunk)
pass
match chunk.type:
case "content_block_start":
chunk = cast(anthropic.types.RawContentBlockStartEvent, chunk)
block = chunk.content_block
if isinstance(block, anthropic.types.ToolUseBlock):
tool_use = block
yield f"\n@{tool_use.name}: "
elif isinstance(block, anthropic.types.TextBlock):
if block.text:
logger.warning("unexpected text block: %s", block.text)
else:
print(f"Unknown block type: {block}")
case "content_block_delta":
chunk = cast(anthropic.types.RawContentBlockDeltaEvent, chunk)
delta = chunk.delta
if isinstance(delta, anthropic.types.TextDelta):
yield delta.text
elif isinstance(delta, anthropic.types.InputJSONDelta):
yield delta.partial_json
else:
logger.warning("Unknown delta type: %s", delta)
case "content_block_stop":
pass
case "text":
# full text message
pass
case "message_start":
pass
case "message_delta":
pass
case "message_stop":
pass
case _:
# print(f"Unknown chunk type: {chunk.type}")
pass


def _handle_files(message_dicts: list[dict]) -> list[dict]:
Expand Down
3 changes: 2 additions & 1 deletion gptme/llm/llm_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ def stream(
func = tool_call.function
if isinstance(func, ChoiceDeltaToolCallFunction):
if func.name:
yield "\n@" + func.name + ": "
print("tool call in openai")
yield f"\n@{func.name}: "
if func.arguments:
yield func.arguments

Expand Down
7 changes: 4 additions & 3 deletions gptme/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from ..message import Message
from .base import (
ConfirmFunc,
ToolFormat,
ToolSpec,
ToolUse,
ToolFormat,
get_tool_format,
set_tool_format,
)
Expand All @@ -18,6 +18,7 @@
from .patch import tool as patch_tool
from .python import register_function
from .python import tool as python_tool
from .rag import tool as rag_tool
from .read import tool as tool_read
from .save import tool_append, tool_save
from .screenshot import tool as screenshot_tool
Expand All @@ -26,7 +27,6 @@
from .tmux import tool as tmux_tool
from .vision import tool as vision_tool
from .youtube import tool as youtube_tool
from .rag import tool as rag_tool

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -110,7 +110,8 @@ def execute_msg(msg: Message, confirm: ConfirmFunc) -> Generator[Message, None,
assert msg.role == "assistant", "Only assistant messages can be executed"

for tooluse in ToolUse.iter_from_content(msg.content):
yield from tooluse.execute(confirm)
if tooluse.is_runnable:
yield from tooluse.execute(confirm)


# Called often when checking streaming output for executable blocks,
Expand Down
78 changes: 64 additions & 14 deletions gptme/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,45 @@
tool_format: ToolFormat = "markdown"
exclusive_mode = False

toolcall_re = re.compile(r"@(\w+):\s*(\{.*?\})")
# Match tool name and start of JSON
toolcall_re = re.compile(r"^@(\w+):\s*({.*)", re.M | re.S)


def find_json_end(s: str, start: int) -> int | None:
"""Find the end of a JSON object by counting braces"""
stack = []
in_string = False
escape = False

for i, c in enumerate(s[start:], start):
if escape:
escape = False
continue

if c == "\\":
escape = True
elif c == '"' and not escape:
in_string = not in_string
elif not in_string:
if c == "{":
stack.append(c)
elif c == "}":
if not stack:
return None
stack.pop()
if not stack:
return i + 1
return None


def extract_json(content: str, match: re.Match) -> str | None:
    """Extract complete JSON object starting from a regex match."""
    # Group 2 of the toolcall regex captures from the opening brace onward;
    # find_json_end locates the matching close brace (or None if unterminated).
    begin = match.start(2)
    end = find_json_end(content, begin)
    return None if end is None else content[begin:end]


ConfirmFunc = Callable[[str], bool]

Expand Down Expand Up @@ -295,11 +333,20 @@ def iter_from_content(cls, content: str) -> Generator["ToolUse", None, None]:
for tool_use in tool_uses:
yield tool_use

# check if its a toolcall
# check if its a toolcall and extract valid JSON
if match := toolcall_re.search(content):
tool_name = match.group(1)
kwargs = cast(dict[str, str], json_repair.loads(match.group(2)))
yield ToolUse(tool_name, None, None, kwargs=kwargs)
if (json_str := extract_json(content, match)) is not None:
try:
kwargs = json_repair.loads(json_str)
if not isinstance(kwargs, dict):
logger.debug(f"JSON repair result is not a dict: {kwargs}")
return
yield ToolUse(
tool_name, None, None, kwargs=cast(dict[str, str], kwargs)
)
except json.JSONDecodeError:
logger.debug(f"Failed to parse JSON: {json_str}")

@classmethod
def _iter_from_markdown(cls, content: str) -> Generator["ToolUse", None, None]:
Expand Down Expand Up @@ -360,7 +407,7 @@ def to_output(self, tool_format: ToolFormat = "markdown") -> str:
elif tool_format == "xml":
return self._to_xml()
elif tool_format == "tool":
return self._to_json()
return self._to_toolcall()

def _to_markdown(self) -> str:
assert self.args is not None
Expand All @@ -373,27 +420,30 @@ def _to_xml(self) -> str:
args_str = "" if not args else f" args='{args}'"
return f"<tool-use>\n<{self.tool}{args_str}>\n{self.content}\n</{self.tool}>\n</tool-use>"

def _to_json(self) -> str:
def _to_params(self) -> dict:
# noreorder
from . import get_tool # fmt: skip

base = {"name": self.tool, "parameters": {}}
if self.kwargs is not None:
base["parameters"] = self.kwargs
return self.kwargs
elif self.args is not None and self.content is not None:
# match positional args with kwargs
tool = get_tool(self.tool)

if tool:
if tool := get_tool(self.tool):
if self.args:
args = [self.content, *self.args]
args = [*self.args, self.content]
else:
args = [self.content]

json_parameters: dict[str, str] = {}
for index, param in enumerate(tool.parameters):
json_parameters[param.name] = args[index]

base["parameters"] = json_parameters
return json_parameters
return {}

def _to_json(self) -> str:
return json.dumps({"name": self.tool, "parameters": self._to_params()})

return json.dumps(base)
def _to_toolcall(self) -> str:
self._to_json()
return f"@{self.tool}: {json.dumps(self._to_params(), indent=2)}"
8 changes: 4 additions & 4 deletions gptme/tools/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,15 +256,15 @@ def execute_patch(
block_types=["patch"],
parameters=[
Parameter(
name="patch",
name="path",
type="string",
description="The patch to apply.",
description="The path of the file to patch.",
required=True,
),
Parameter(
name="path",
name="patch",
type="string",
description="The path of the file to patch.",
description="The patch to apply.",
required=True,
),
],
Expand Down
23 changes: 14 additions & 9 deletions gptme/tools/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,12 @@ def execute_save(
current = path.read_text()
p = Patch(current, content)
# TODO: if inefficient save, replace request with patch (and vice versa), or even append
print_preview(p.diff_minimal(), "diff")
diff_str = p.diff_minimal()
if diff_str.strip():
print_preview(p.diff_minimal(), "diff")
else:
yield Message("system", "File already exists with identical content.")
return

if not confirm(f"Save to {fn}?"):
# early return
Expand Down Expand Up @@ -170,15 +175,15 @@ def execute_append(
block_types=["save"],
parameters=[
Parameter(
name="content",
name="path",
type="string",
description="The content to save",
description="The path of the file",
required=True,
),
Parameter(
name="path",
name="content",
type="string",
description="The path of the file",
description="The content to save",
required=True,
),
],
Expand All @@ -195,15 +200,15 @@ def execute_append(
block_types=["append"],
parameters=[
Parameter(
name="content",
name="path",
type="string",
description="The content to append",
description="The path of the file",
required=True,
),
Parameter(
name="path",
name="content",
type="string",
description="The path of the file",
description="The content to append",
required=True,
),
],
Expand Down
5 changes: 4 additions & 1 deletion gptme/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@
def get_tokenizer(model: str):
import tiktoken # fmt: skip

global _warned_models
if "gpt-4o" in model:
return tiktoken.get_encoding("o200k_base")

try:
return tiktoken.encoding_for_model(model)
except KeyError:
global _warned_models
if model not in _warned_models:
logger.warning(
f"No tokenizer for '{model}'. Using tiktoken cl100k_base. Use results only as estimates."
Expand Down
Loading

0 comments on commit 5c18ec9

Please sign in to comment.