fix: cleanup and fixes for tools API #303

Merged: 4 commits, Dec 3, 2024
gptme/chat.py (9 changes: 8 additions & 1 deletion)
@@ -127,7 +127,14 @@ def confirm_func(msg) -> bool:
while True:
try:
set_interruptible()
response_msgs = list(step(manager.log, stream, confirm_func, tool_format=tool_format))
response_msgs = list(
step(
manager.log,
stream,
confirm_func,
tool_format=tool_format,
)
)
except KeyboardInterrupt:
console.log("Interrupted. Stopping current execution.")
manager.append(Message("system", "Interrupted"))
gptme/llm/__init__.py (14 changes: 9 additions & 5 deletions)
@@ -9,6 +9,9 @@

from ..config import get_config
from ..constants import PROMPT_ASSISTANT
from ..message import Message, format_msgs, len_tokens
from ..tools import ToolSpec, ToolUse
from ..util import console
from .llm_anthropic import chat as chat_anthropic
from .llm_anthropic import get_client as get_anthropic_client
from .llm_anthropic import init as init_anthropic
@@ -17,15 +20,12 @@
from .llm_openai import get_client as get_openai_client
from .llm_openai import init as init_openai
from .llm_openai import stream as stream_openai
from ..message import Message, format_msgs, len_tokens
from .models import (
MODELS,
PROVIDERS_OPENAI,
Provider,
get_summary_model,
)
from ..tools import ToolSpec, ToolUse
from ..util import console

logger = logging.getLogger(__name__)

@@ -114,8 +114,12 @@ def print_clear():
if char == "\n":
# TODO: make this more robust/general, maybe with a callback that runs on each char/chunk
# pause inference on finished code-block, letting user run the command before continuing
tooluses = list(ToolUse.iter_from_content(output))
if tooluses and any(tooluse.is_runnable for tooluse in tooluses):
tooluses = [
tooluse
for tooluse in ToolUse.iter_from_content(output)
if tooluse.is_runnable
]
if tooluses:
logger.debug("Found tool use, breaking")
break
except KeyboardInterrupt:
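
The streaming loop above now filters directly to runnable tool uses rather than collecting every tool use and then checking `any(...)`. A minimal sketch of the equivalence, using a stand-in class instead of gptme's real `ToolUse`:

```python
from dataclasses import dataclass


# Stand-in for gptme's ToolUse; only the attribute used by the check is modeled.
@dataclass
class FakeToolUse:
    name: str
    is_runnable: bool


def iter_from_content(output: str):
    # Pretend parser yielding one non-runnable and one runnable tool use.
    yield FakeToolUse("read", False)
    yield FakeToolUse("shell", True)


output = "...model output containing a finished code block..."

# Old check: materialize everything, then test whether any use is runnable.
all_uses = list(iter_from_content(output))
break_old = bool(all_uses) and any(t.is_runnable for t in all_uses)

# New check: keep only runnable uses; a non-empty list means "pause inference".
runnable = [t for t in iter_from_content(output) if t.is_runnable]
break_new = bool(runnable)

assert break_old == break_new
print(break_new)  # True
```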
gptme/llm/llm_anthropic.py (68 changes: 38 additions & 30 deletions)
@@ -16,15 +16,14 @@
from ..message import Message, len_tokens, msgs2dicts
from ..tools.base import Parameter, ToolSpec

logger = logging.getLogger(__name__)


if TYPE_CHECKING:
# noreorder
import anthropic # fmt: skip
import anthropic.types.beta.prompt_caching # fmt: skip
from anthropic import Anthropic # fmt: skip

logger = logging.getLogger(__name__)

_anthropic: "Anthropic | None" = None

ALLOWED_FILE_EXTS = ["jpg", "jpeg", "png", "gif"]
@@ -107,34 +106,43 @@ def stream(
) as stream:
for chunk in stream:
if hasattr(chunk, "usage"):
print(chunk.usage)
if chunk.type == "content_block_start":
block = chunk.content_block
if isinstance(block, anthropic.types.ToolUseBlock):
tool_use = block
yield f"@{tool_use.name}: "
elif chunk.type == "content_block_delta":
chunk = cast(anthropic.types.RawContentBlockDeltaEvent, chunk)
delta = chunk.delta
if isinstance(delta, anthropic.types.TextDelta):
yield delta.text
elif isinstance(delta, anthropic.types.InputJSONDelta):
yield delta.partial_json
else:
logger.warning("Unknown delta type: %s", delta)
elif chunk.type == "content_block_stop":
pass
elif chunk.type == "text":
# full text message
pass
elif chunk.type == "message_delta":
pass
elif chunk.type == "message_stop":
pass
else:
# print(f"Unknown chunk type: {chunk.type}")
# print(chunk)
# print(chunk.usage)
pass
match chunk.type:
case "content_block_start":
chunk = cast(anthropic.types.RawContentBlockStartEvent, chunk)
block = chunk.content_block
if isinstance(block, anthropic.types.ToolUseBlock):
tool_use = block
yield f"\n@{tool_use.name}: "
elif isinstance(block, anthropic.types.TextBlock):
if block.text:
logger.warning("unexpected text block: %s", block.text)
else:
print(f"Unknown block type: {block}")
case "content_block_delta":
chunk = cast(anthropic.types.RawContentBlockDeltaEvent, chunk)
delta = chunk.delta
if isinstance(delta, anthropic.types.TextDelta):
yield delta.text
elif isinstance(delta, anthropic.types.InputJSONDelta):
yield delta.partial_json
else:
logger.warning("Unknown delta type: %s", delta)
case "content_block_stop":
pass
case "text":
# full text message
pass
case "message_start":
pass
case "message_delta":
pass
case "message_stop":
pass
case _:
# print(f"Unknown chunk type: {chunk.type}")
pass


def _handle_files(message_dicts: list[dict]) -> list[dict]:
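
The chunk handler above switches from an if/elif chain to a `match` on `chunk.type`, casting each event to its concrete type and prefixing tool-use announcements with a newline. A simplified sketch of that dispatch pattern, using a made-up `Chunk` dataclass rather than the real anthropic event classes:

```python
from dataclasses import dataclass


# Made-up stand-in for the anthropic streaming events; only `type` and a bit of
# payload are modeled, so this is a sketch of the dispatch, not the real API.
@dataclass
class Chunk:
    type: str
    text: str = ""
    tool_name: str = ""


def handle(chunk: Chunk) -> str | None:
    match chunk.type:
        case "content_block_start" if chunk.tool_name:
            return f"\n@{chunk.tool_name}: "   # announce a tool call on its own line
        case "content_block_delta":
            return chunk.text                  # stream text/JSON fragments as they arrive
        case "message_start" | "message_delta" | "message_stop" | "content_block_stop":
            return None                        # bookkeeping events carry no streamed text
        case _:
            return None                        # unknown chunk types are ignored


print(repr(handle(Chunk("content_block_delta", text="hello"))))      # 'hello'
print(repr(handle(Chunk("content_block_start", tool_name="save"))))  # '\n@save: '
```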
gptme/llm/llm_openai.py (2 changes: 1 addition & 1 deletion)
@@ -196,7 +196,7 @@ def stream(
func = tool_call.function
if isinstance(func, ChoiceDeltaToolCallFunction):
if func.name:
yield "\n@" + func.name + ": "
yield f"\n@{func.name}: "
if func.arguments:
yield func.arguments

gptme/prompts.py (2 changes: 1 addition & 1 deletion)
@@ -16,8 +16,8 @@
from .__version__ import __version__
from .config import get_config, get_project_config
from .message import Message
from .util import document_prompt_function, get_project_dir
from .tools import ToolFormat
from .util import document_prompt_function, get_project_dir

PromptType = Literal["full", "short"]

gptme/tools/__init__.py (7 changes: 4 additions & 3 deletions)
@@ -5,9 +5,9 @@
from ..message import Message
from .base import (
ConfirmFunc,
ToolFormat,
ToolSpec,
ToolUse,
ToolFormat,
get_tool_format,
set_tool_format,
)
@@ -18,6 +18,7 @@
from .patch import tool as patch_tool
from .python import register_function
from .python import tool as python_tool
from .rag import tool as rag_tool
from .read import tool as tool_read
from .save import tool_append, tool_save
from .screenshot import tool as screenshot_tool
@@ -26,7 +27,6 @@
from .tmux import tool as tmux_tool
from .vision import tool as vision_tool
from .youtube import tool as youtube_tool
from .rag import tool as rag_tool

logger = logging.getLogger(__name__)

@@ -110,7 +110,8 @@ def execute_msg(msg: Message, confirm: ConfirmFunc) -> Generator[Message, None,
assert msg.role == "assistant", "Only assistant messages can be executed"

for tooluse in ToolUse.iter_from_content(msg.content):
yield from tooluse.execute(confirm)
if tooluse.is_runnable:
yield from tooluse.execute(confirm)


# Called often when checking streaming output for executable blocks,
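
`execute_msg` now checks `tooluse.is_runnable` before executing, so blocks without a registered runnable tool are skipped rather than passed to `execute()`. A rough sketch of the guard with dummy objects (not the real `ToolUse` API):

```python
# Dummy objects standing in for ToolUse; only is_runnable/execute are modeled.
class DummyToolUse:
    def __init__(self, runnable: bool):
        self.is_runnable = runnable

    def execute(self, confirm):
        yield f"executed (confirmed={confirm('run?')})"


def execute_all(tooluses, confirm=lambda _: True):
    for tooluse in tooluses:
        # Non-runnable uses (e.g. a code block with no registered tool) are
        # skipped instead of being handed to execute().
        if tooluse.is_runnable:
            yield from tooluse.execute(confirm)


print(list(execute_all([DummyToolUse(False), DummyToolUse(True)])))
# ['executed (confirmed=True)']
```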
gptme/tools/base.py (89 changes: 73 additions & 16 deletions)
@@ -1,13 +1,20 @@
import json_repair
import json
import logging
import re
import types
from collections.abc import Callable, Generator
from dataclasses import dataclass, field
from textwrap import indent
from typing import Any, Literal, Protocol, TypeAlias, cast, get_origin
from typing import (
Any,
Literal,
Protocol,
TypeAlias,
cast,
get_origin,
)

import json_repair
from lxml import etree

from ..codeblock import Codeblock
@@ -24,7 +31,45 @@
tool_format: ToolFormat = "markdown"
exclusive_mode = False

toolcall_re = re.compile(r"@(\w+): (\{.*?\})")
# Match tool name and start of JSON
toolcall_re = re.compile(r"^@(\w+):\s*({.*)", re.M | re.S)


def find_json_end(s: str, start: int) -> int | None:
"""Find the end of a JSON object by counting braces"""
stack = []
in_string = False
escape = False

for i, c in enumerate(s[start:], start):
if escape:
escape = False
continue

if c == "\\":
escape = True
elif c == '"' and not escape:
in_string = not in_string
elif not in_string:
if c == "{":
stack.append(c)
elif c == "}":
if not stack:
return None
stack.pop()
if not stack:
return i + 1
return None


def extract_json(content: str, match: re.Match) -> str | None:
"""Extract complete JSON object starting from a regex match"""
json_start = match.start(2) # start of the JSON content
json_end = find_json_end(content, json_start)
if json_end is None:
return None
return content[json_start:json_end]


ConfirmFunc = Callable[[str], bool]
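
The new `toolcall_re` only anchors the tool name and the opening brace; `find_json_end` and `extract_json` then walk the text, counting braces and skipping braces inside strings, so nested objects are captured in full where the old non-greedy `(\{.*?\})` pattern stopped at the first closing brace. A usage sketch, assuming a gptme checkout with this change installed (the tool name and payload are invented):

```python
from gptme.tools.base import extract_json, toolcall_re  # helpers added in this PR

content = '@mytool: {"path": "out.txt", "opts": {"force": true}}\nmore assistant text'

match = toolcall_re.search(content)
assert match is not None
print(match.group(1))  # mytool

# Brace-aware extraction returns the whole object, including the nested braces
# that a non-greedy regex would have cut short.
print(extract_json(content, match))  # {"path": "out.txt", "opts": {"force": true}}
```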

@@ -288,11 +333,20 @@ def iter_from_content(cls, content: str) -> Generator["ToolUse", None, None]:
for tool_use in tool_uses:
yield tool_use

# check if its a toolcall
# check if its a toolcall and extract valid JSON
if match := toolcall_re.search(content):
tool_name = match.group(1)
kwargs = cast(dict[str, str], json_repair.loads(match.group(2)))
yield ToolUse(tool_name, None, None, kwargs=kwargs)
if (json_str := extract_json(content, match)) is not None:
try:
kwargs = json_repair.loads(json_str)
if not isinstance(kwargs, dict):
logger.debug(f"JSON repair result is not a dict: {kwargs}")
return
yield ToolUse(
tool_name, None, None, kwargs=cast(dict[str, str], kwargs)
)
except json.JSONDecodeError:
logger.debug(f"Failed to parse JSON: {json_str}")

@classmethod
def _iter_from_markdown(cls, content: str) -> Generator["ToolUse", None, None]:
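
With the JSON extracted, `iter_from_content` above only yields a `ToolUse` when `json_repair` produces a dict; other results are logged at debug level and dropped. A small sketch of that repair-and-check step (the malformed input is invented):

```python
import json_repair  # third-party dependency already used by gptme

# json_repair attempts to fix malformed model output, e.g. a missing closing brace.
raw = '{"path": "out.txt", "content": "hi"'
kwargs = json_repair.loads(raw)

# The new code only accepts dict results; anything else is logged and skipped.
if isinstance(kwargs, dict):
    print(kwargs)  # repaired dict with the original keys
else:
    print("not a dict, skipping")
```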
@@ -353,7 +407,7 @@ def to_output(self, tool_format: ToolFormat = "markdown") -> str:
elif tool_format == "xml":
return self._to_xml()
elif tool_format == "tool":
return self._to_json()
return self._to_toolcall()

def _to_markdown(self) -> str:
assert self.args is not None
Expand All @@ -366,27 +420,30 @@ def _to_xml(self) -> str:
args_str = "" if not args else f" args='{args}'"
return f"<tool-use>\n<{self.tool}{args_str}>\n{self.content}\n</{self.tool}>\n</tool-use>"

def _to_json(self) -> str:
def _to_params(self) -> dict:
# noreorder
from . import get_tool # fmt: skip

base = {"name": self.tool, "parameters": {}}
if self.kwargs is not None:
base["parameters"] = self.kwargs
return self.kwargs
elif self.args is not None and self.content is not None:
# match positional args with kwargs
tool = get_tool(self.tool)

if tool:
if tool := get_tool(self.tool):
if self.args:
args = [self.content, *self.args]
args = [*self.args, self.content]
else:
args = [self.content]

json_parameters: dict[str, str] = {}
for index, param in enumerate(tool.parameters):
json_parameters[param.name] = args[index]

base["parameters"] = json_parameters
return json_parameters
return {}

def _to_json(self) -> str:
return json.dumps({"name": self.tool, "parameters": self._to_params()})

return json.dumps(base)
def _to_toolcall(self) -> str:
self._to_json()
return f"@{self.tool}: {json.dumps(self._to_params(), indent=2)}"
gptme/tools/patch.py (8 changes: 4 additions & 4 deletions)
@@ -256,15 +256,15 @@ def execute_patch(
block_types=["patch"],
parameters=[
Parameter(
name="patch",
name="path",
type="string",
description="The patch to apply.",
description="The path of the file to patch.",
required=True,
),
Parameter(
name="path",
name="patch",
type="string",
description="The path of the file to patch.",
description="The patch to apply.",
required=True,
),
],