From 988020e1fb2fce4ca27615340da8dc7e919e919f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?=
Date: Tue, 3 Dec 2024 10:29:07 +0100
Subject: [PATCH] refactor(tools-api): improve anthropic streaming and minor fixes

- Enhance anthropic streaming to handle different chunk types
- Rename global anthropic to _anthropic
- Fix imports and formatting in various files

Part of the tools API refactoring work.
---
 gptme/llm/llm_anthropic.py | 54 +++++++++++++++++++++++++++++---------
 gptme/server/api.py        |  8 +++---
 gptme/tools/python.py      | 16 ++++++-----
 gptme/tools/save.py        |  2 +-
 gptme/tools/shell.py       |  2 +-
 tests/test_cli.py          |  3 ++-
 6 files changed, 60 insertions(+), 25 deletions(-)

diff --git a/gptme/llm/llm_anthropic.py b/gptme/llm/llm_anthropic.py
index 86fdbce5..48abc213 100644
--- a/gptme/llm/llm_anthropic.py
+++ b/gptme/llm/llm_anthropic.py
@@ -7,41 +7,42 @@
     Any,
     Literal,
     TypedDict,
+    cast,
 )
 
+import anthropic.types
 from anthropic import NOT_GIVEN
+from anthropic.types.beta.prompt_caching import PromptCachingBetaToolParam
 from typing_extensions import Required
 
-from ..tools.base import ToolSpec, Parameter
-
 from ..constants import TEMPERATURE, TOP_P
 from ..message import Message, len_tokens, msgs2dicts
+from ..tools.base import Parameter, ToolSpec
 
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from anthropic import Anthropic  # fmt: skip
-    from anthropic.types.beta.prompt_caching import PromptCachingBetaToolParam
 
-anthropic: "Anthropic | None" = None
+_anthropic: "Anthropic | None" = None
 
 ALLOWED_FILE_EXTS = ["jpg", "jpeg", "png", "gif"]
 
 
 def init(config):
-    global anthropic
+    global _anthropic
     api_key = config.get_env_required("ANTHROPIC_API_KEY")
     from anthropic import Anthropic  # fmt: skip
 
-    anthropic = Anthropic(
+    _anthropic = Anthropic(
         api_key=api_key,
         max_retries=5,
     )
 
 
 def get_client() -> "Anthropic | None":
-    return anthropic
+    return _anthropic
 
 
 class ToolAnthropic(TypedDict):
@@ -58,14 +59,14 @@ class MessagePart(TypedDict, total=False):
 
 
 def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> str:
-    assert anthropic, "LLM not initialized"
+    assert _anthropic, "LLM not initialized"
     messages, system_messages = _transform_system_messages(messages)
 
     messages_dicts = _handle_files(msgs2dicts(messages))
 
     tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
 
-    response = anthropic.beta.prompt_caching.messages.create(
+    response = _anthropic.beta.prompt_caching.messages.create(
         model=model,
         messages=messages_dicts,  # type: ignore
         system=system_messages,  # type: ignore
@@ -83,14 +84,14 @@ def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> s
 def stream(
     messages: list[Message], model: str, tools: list[ToolSpec] | None
 ) -> Generator[str, None, None]:
-    assert anthropic, "LLM not initialized"
+    assert _anthropic, "LLM not initialized"
     messages, system_messages = _transform_system_messages(messages)
 
     messages_dicts = _handle_files(msgs2dicts(messages))
 
     tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
 
-    with anthropic.beta.prompt_caching.messages.stream(
+    with _anthropic.beta.prompt_caching.messages.stream(
         model=model,
         messages=messages_dicts,  # type: ignore
         system=system_messages,  # type: ignore
@@ -99,7 +100,36 @@ def stream(
         max_tokens=4096,
         tools=tools_dict if tools_dict else NOT_GIVEN,
     ) as stream:
-        yield from stream.text_stream
+        for chunk in stream:
+            if hasattr(chunk, "usage"):
+                print(chunk.usage)
+            if chunk.type == "content_block_start":
+                block = chunk.content_block
+                if isinstance(block, anthropic.types.ToolUseBlock):
+                    tool_use = block
+                    yield f"@{tool_use.name}: "
+            elif chunk.type == "content_block_delta":
+                chunk = cast(anthropic.types.RawContentBlockDeltaEvent, chunk)
+                delta = chunk.delta
+                if isinstance(delta, anthropic.types.TextDelta):
+                    yield delta.text
+                elif isinstance(delta, anthropic.types.InputJSONDelta):
+                    yield delta.partial_json
+                else:
+                    logger.warning("Unknown delta type: %s", delta)
+            elif chunk.type == "content_block_stop":
+                pass
+            elif chunk.type == "text":
+                # full text message
+                pass
+            elif chunk.type == "message_delta":
+                pass
+            elif chunk.type == "message_stop":
+                pass
+            else:
+                # print(f"Unknown chunk type: {chunk.type}")
+                # print(chunk)
+                pass
 
 
 def _handle_files(message_dicts: list[dict]) -> list[dict]:
diff --git a/gptme/server/api.py b/gptme/server/api.py
index 59ab5cde..0da6b1a9 100644
--- a/gptme/server/api.py
+++ b/gptme/server/api.py
@@ -21,9 +21,9 @@
 from ..commands import execute_cmd
 from ..dirs import get_logs_dir
 from ..llm import _stream
+from ..llm.models import get_model
 from ..logmanager import LogManager, get_user_conversations, prepare_messages
 from ..message import Message
-from ..llm.models import get_model
 from ..tools import execute_msg
 from ..tools.base import ToolUse
 
@@ -116,7 +116,7 @@ def api_conversation_generate(logfile: str):
     # Non-streaming response
     try:
         # Get complete response
-        output = "".join(_stream(msgs, model))
+        output = "".join(_stream(msgs, model, tools=None))
 
         # Store the message
         msg = Message("assistant", output)
@@ -170,7 +170,9 @@ def generate() -> Generator[str, None, None]:
 
             # Stream tokens from the model
             logger.debug(f"Starting token stream with model {model}")
-            for char in (char for chunk in _stream(msgs, model) for char in chunk):
+            for char in (
+                char for chunk in _stream(msgs, model, tools=None) for char in chunk
+            ):
                 output += char
                 # Send each token as a JSON event
                 yield f"data: {flask.json.dumps({'role': 'assistant', 'content': char, 'stored': False})}\n\n"
diff --git a/gptme/tools/python.py b/gptme/tools/python.py
index 46b21656..6afdb23f 100644
--- a/gptme/tools/python.py
+++ b/gptme/tools/python.py
@@ -10,14 +10,16 @@
 import re
 from collections.abc import Callable, Generator
 from logging import getLogger
-from typing import (
-    TYPE_CHECKING,
-    TypeVar,
-)
+from typing import TYPE_CHECKING, TypeVar
 
 from ..message import Message
 from ..util import print_preview
-from .base import ConfirmFunc, Parameter, ToolSpec, ToolUse
+from .base import (
+    ConfirmFunc,
+    Parameter,
+    ToolSpec,
+    ToolUse,
+)
 
 if TYPE_CHECKING:
     from IPython.terminal.embed import InteractiveShellEmbed  # fmt: skip
@@ -166,13 +168,13 @@ def examples(tool_format):
 #### It can write an example and then execute it:
 > User: compute fib 10
 > Assistant: To compute the 10th Fibonacci number, we can execute this code:
-{ToolUse("ipython", [], """
+{ToolUse("ipython", [], '''
 def fib(n):
     if n <= 1:
         return n
     return fib(n - 1) + fib(n - 2)
 fib(10)
-""".strip()).to_output(tool_format)}
+'''.strip()).to_output(tool_format)}
 > System: Executed code block.
 {ToolUse("result", [], "55").to_output()}
 """.strip()
diff --git a/gptme/tools/save.py b/gptme/tools/save.py
index 2e9e6633..b68990cb 100644
--- a/gptme/tools/save.py
+++ b/gptme/tools/save.py
@@ -29,7 +29,7 @@
 
 instructions_format_append = {
     "markdown": """
-Use a code block with the language tag: `append <path>` 
+Use a code block with the language tag: `append <path>`
 to append the code block content to the file at the given path.""".strip(),
 }
 
diff --git a/gptme/tools/shell.py b/gptme/tools/shell.py
index defd5260..30c14170 100644
--- a/gptme/tools/shell.py
+++ b/gptme/tools/shell.py
@@ -43,7 +43,7 @@
 
 instructions = f"""
-The given command will be executed in a stateful bash shell. 
+The given command will be executed in a stateful bash shell.
 The shell tool will respond with the output of the execution.
 Do not use EOF/HereDoc syntax to send multiline commands, as the assistant will not be able to handle it.
 
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 7605996b..4c6ce408 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -378,8 +378,9 @@ def test_tool_format_option(
     args.append(tool_format)
     args.append("test")
 
-    with patch("gptme.chat.reply", return_value=[]) as mock_reply:
+    with patch("gptme.llm.reply", return_value=[]) as mock_reply:
         result = runner.invoke(gptme.cli.main, args)
+        assert result.exit_code == 0
         mock_reply.assert_called_once()
 