
refactor(tools-api): improve anthropic streaming and minor fixes
- Enhance anthropic streaming to handle different chunk types
- Rename global anthropic to _anthropic
- Fix imports and formatting in various files

Part of the tools API refactoring work.
ErikBjare committed Dec 3, 2024
1 parent 67f8f1c commit 988020e
Showing 6 changed files with 60 additions and 25 deletions.
54 changes: 42 additions & 12 deletions gptme/llm/llm_anthropic.py
@@ -7,41 +7,42 @@
     Any,
     Literal,
     TypedDict,
+    cast,
 )

+import anthropic.types
 from anthropic import NOT_GIVEN
-from anthropic.types.beta.prompt_caching import PromptCachingBetaToolParam
 from typing_extensions import Required

-from ..tools.base import ToolSpec, Parameter
-
 from ..constants import TEMPERATURE, TOP_P
 from ..message import Message, len_tokens, msgs2dicts
+from ..tools.base import Parameter, ToolSpec

 logger = logging.getLogger(__name__)


 if TYPE_CHECKING:
     from anthropic import Anthropic  # fmt: skip
+    from anthropic.types.beta.prompt_caching import PromptCachingBetaToolParam

-anthropic: "Anthropic | None" = None
+_anthropic: "Anthropic | None" = None

 ALLOWED_FILE_EXTS = ["jpg", "jpeg", "png", "gif"]


 def init(config):
-    global anthropic
+    global _anthropic
     api_key = config.get_env_required("ANTHROPIC_API_KEY")
     from anthropic import Anthropic  # fmt: skip

-    anthropic = Anthropic(
+    _anthropic = Anthropic(
         api_key=api_key,
         max_retries=5,
     )


 def get_client() -> "Anthropic | None":
-    return anthropic
+    return _anthropic


 class ToolAnthropic(TypedDict):
@@ -58,14 +59,14 @@ class MessagePart(TypedDict, total=False):


 def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> str:
-    assert anthropic, "LLM not initialized"
+    assert _anthropic, "LLM not initialized"
     messages, system_messages = _transform_system_messages(messages)

     messages_dicts = _handle_files(msgs2dicts(messages))

     tools_dict = [_spec2tool(tool) for tool in tools] if tools else None

-    response = anthropic.beta.prompt_caching.messages.create(
+    response = _anthropic.beta.prompt_caching.messages.create(
         model=model,
         messages=messages_dicts,  # type: ignore
         system=system_messages,  # type: ignore
@@ -83,14 +84,14 @@ def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> str:
 def stream(
     messages: list[Message], model: str, tools: list[ToolSpec] | None
 ) -> Generator[str, None, None]:
-    assert anthropic, "LLM not initialized"
+    assert _anthropic, "LLM not initialized"
     messages, system_messages = _transform_system_messages(messages)

     messages_dicts = _handle_files(msgs2dicts(messages))

     tools_dict = [_spec2tool(tool) for tool in tools] if tools else None

-    with anthropic.beta.prompt_caching.messages.stream(
+    with _anthropic.beta.prompt_caching.messages.stream(
         model=model,
         messages=messages_dicts,  # type: ignore
         system=system_messages,  # type: ignore
@@ -99,7 +100,36 @@ def stream(
         max_tokens=4096,
         tools=tools_dict if tools_dict else NOT_GIVEN,
     ) as stream:
-        yield from stream.text_stream
+        for chunk in stream:
+            if hasattr(chunk, "usage"):
+                print(chunk.usage)
+            if chunk.type == "content_block_start":
+                block = chunk.content_block
+                if isinstance(block, anthropic.types.ToolUseBlock):
+                    tool_use = block
+                    yield f"@{tool_use.name}: "
+            elif chunk.type == "content_block_delta":
+                chunk = cast(anthropic.types.RawContentBlockDeltaEvent, chunk)
+                delta = chunk.delta
+                if isinstance(delta, anthropic.types.TextDelta):
+                    yield delta.text
+                elif isinstance(delta, anthropic.types.InputJSONDelta):
+                    yield delta.partial_json
+                else:
+                    logger.warning("Unknown delta type: %s", delta)
+            elif chunk.type == "content_block_stop":
+                pass
+            elif chunk.type == "text":
+                # full text message
+                pass
+            elif chunk.type == "message_delta":
+                pass
+            elif chunk.type == "message_stop":
+                pass
+            else:
+                # print(f"Unknown chunk type: {chunk.type}")
+                # print(chunk)
+                pass


 def _handle_files(message_dicts: list[dict]) -> list[dict]:
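For context, the event-dispatch pattern introduced above can be exercised standalone against the plain (non-prompt-caching) Messages streaming API. A minimal sketch, not part of the commit; it assumes ANTHROPIC_API_KEY is set in the environment, and the model name is only an example:

import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment
with client.messages.stream(
    model="claude-3-5-sonnet-20241022",  # example model name
    max_tokens=256,
    messages=[{"role": "user", "content": "Say hi"}],
) as s:
    for event in s:
        # the same event type the new stream() dispatches on for text output
        if event.type == "content_block_delta" and isinstance(
            event.delta, anthropic.types.TextDelta
        ):
            print(event.delta.text, end="", flush=True)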
8 changes: 5 additions & 3 deletions gptme/server/api.py
@@ -21,9 +21,9 @@
 from ..commands import execute_cmd
 from ..dirs import get_logs_dir
 from ..llm import _stream
+from ..llm.models import get_model
 from ..logmanager import LogManager, get_user_conversations, prepare_messages
 from ..message import Message
-from ..llm.models import get_model
 from ..tools import execute_msg
 from ..tools.base import ToolUse

@@ -116,7 +116,7 @@ def api_conversation_generate(logfile: str):
     # Non-streaming response
     try:
         # Get complete response
-        output = "".join(_stream(msgs, model))
+        output = "".join(_stream(msgs, model, tools=None))

         # Store the message
         msg = Message("assistant", output)
@@ -170,7 +170,9 @@ def generate() -> Generator[str, None, None]:

     # Stream tokens from the model
     logger.debug(f"Starting token stream with model {model}")
-    for char in (char for chunk in _stream(msgs, model) for char in chunk):
+    for char in (
+        char for chunk in _stream(msgs, model, tools=None) for char in chunk
+    ):
         output += char
         # Send each token as a JSON event
         yield f"data: {flask.json.dumps({'role': 'assistant', 'content': char, 'stored': False})}\n\n"
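The generator above emits one character per SSE "data:" event. A hypothetical client sketch for consuming it (the URL, port, and HTTP method are assumptions, not taken from the commit):

import json
import requests

url = "http://localhost:5000/api/conversations/example/generate"  # hypothetical URL
with requests.get(url, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # SSE events arrive as lines of the form: data: {...}
        if line and line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            print(event["content"], end="", flush=True)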
16 changes: 9 additions & 7 deletions gptme/tools/python.py
@@ -10,14 +10,16 @@
 import re
 from collections.abc import Callable, Generator
 from logging import getLogger
-from typing import (
-    TYPE_CHECKING,
-    TypeVar,
-)
+from typing import TYPE_CHECKING, TypeVar

 from ..message import Message
 from ..util import print_preview
-from .base import ConfirmFunc, Parameter, ToolSpec, ToolUse
+from .base import (
+    ConfirmFunc,
+    Parameter,
+    ToolSpec,
+    ToolUse,
+)

 if TYPE_CHECKING:
     from IPython.terminal.embed import InteractiveShellEmbed  # fmt: skip
@@ -166,13 +168,13 @@ def examples(tool_format):
 #### It can write an example and then execute it:
 > User: compute fib 10
 > Assistant: To compute the 10th Fibonacci number, we can execute this code:
-{ToolUse("ipython", [], """
+{ToolUse("ipython", [], '''
 def fib(n):
     if n <= 1:
         return n
     return fib(n - 1) + fib(n - 2)
 fib(10)
-""".strip()).to_output(tool_format)}
+'''.strip()).to_output(tool_format)}
 > System: Executed code block.
 {ToolUse("result", [], "55").to_output()}
 """.strip()
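The quote change in examples() is presumably needed because the surrounding string is an f-string: before Python 3.12 (PEP 701), an expression inside an f-string cannot reuse the enclosing quote style, so the embedded literal switches to '''. A minimal illustration:

# ok: inner literal uses ''' inside a """-quoted f-string
s = f"""result: {'''55'''.strip()}"""
# SyntaxError before Python 3.12: inner literal reuses """
# s = f"""result: {"""55""".strip()}"""
print(s)  # result: 55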
2 changes: 1 addition & 1 deletion gptme/tools/save.py
@@ -29,7 +29,7 @@

 instructions_format_append = {
     "markdown": """
-Use a code block with the language tag: `append <path>`
+Use a code block with the language tag: `append <path>`
 to append the code block content to the file at the given path.""".strip(),
 }

(The removed and added lines render identically here; apparently a whitespace-only formatting fix, consistent with the commit message.)
2 changes: 1 addition & 1 deletion gptme/tools/shell.py
@@ -43,7 +43,7 @@


 instructions = f"""
-The given command will be executed in a stateful bash shell.
+The given command will be executed in a stateful bash shell.
 The shell tool will respond with the output of the execution.
 Do not use EOF/HereDoc syntax to send multiline commands, as the assistant will not be able to handle it.

(Again an apparently whitespace-only change to the removed/added line.)
3 changes: 2 additions & 1 deletion tests/test_cli.py
@@ -378,8 +378,9 @@ def test_tool_format_option(
     args.append(tool_format)
     args.append("test")

-    with patch("gptme.chat.reply", return_value=[]) as mock_reply:
+    with patch("gptme.llm.reply", return_value=[]) as mock_reply:
         result = runner.invoke(gptme.cli.main, args)
         assert result.exit_code == 0

+    mock_reply.assert_called_once()
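On the patch-target change: unittest.mock.patch replaces the attribute on the module named in the target string, so the target must match where the code under test looks the name up at call time. A generic, self-contained illustration (not gptme-specific):

import json
from unittest.mock import patch

# patch() swaps json.dumps on the json module object for the block's duration
with patch("json.dumps", return_value="{}") as mock_dumps:
    assert json.dumps({"a": 1}) == "{}"
mock_dumps.assert_called_once()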
