From a2f06df8c067dbd3440b62e22c9a73d963ca659c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?=
Date: Tue, 10 Dec 2024 21:33:11 +0100
Subject: [PATCH] feat(anthropic): improve prompt caching and type safety
 (#317)

- Add better type hints for Anthropic API types
- Optimize prompt caching with ephemeral cache control
- Extract message preparation into dedicated function
- Improve documentation
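
For illustration, the cache-control change marks the text part(s) of the
final message in the conversation with an ephemeral breakpoint, so a
prepared message has roughly this shape (a sketch of the payload, not
verbatim output):

    # Hypothetical prepared message after this change; the last
    # message's text part carries an ephemeral cache breakpoint.
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "...",
                "cache_control": {"type": "ephemeral"},
            }
        ],
    }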
---
 gptme/llm/llm_anthropic.py | 152 +++++++++++++++++++++++++++----------
 1 file changed, 113 insertions(+), 39 deletions(-)

diff --git a/gptme/llm/llm_anthropic.py b/gptme/llm/llm_anthropic.py
index 3fb22d61..9fb5a63e 100644
--- a/gptme/llm/llm_anthropic.py
+++ b/gptme/llm/llm_anthropic.py
@@ -10,17 +10,18 @@
     cast,
 )

-from typing_extensions import Required
-
 from ..constants import TEMPERATURE, TOP_P
-from ..message import Message, len_tokens, msgs2dicts
+from ..message import Message, msgs2dicts
 from ..tools.base import Parameter, ToolSpec

 if TYPE_CHECKING:
     # noreorder
-    import anthropic  # fmt: skip
+    import anthropic.types  # fmt: skip
     import anthropic.types.beta.prompt_caching  # fmt: skip
     from anthropic import Anthropic  # fmt: skip
+    from anthropic.types.beta.prompt_caching import (
+        PromptCachingBetaTextBlockParam,
+    )

 logger = logging.getLogger(__name__)

@@ -44,33 +45,22 @@ def get_client() -> "Anthropic | None":
     return _anthropic


-class ToolAnthropic(TypedDict):
-    name: str
-    description: str
-    input_schema: dict
-
-
-class MessagePart(TypedDict, total=False):
-    type: Required[Literal["text", "image_url"]]
-    text: str
-    image_url: str
-    cache_control: dict[str, str]
+class CacheControl(TypedDict):
+    type: Literal["ephemeral"]


 def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> str:
     from anthropic import NOT_GIVEN  # fmt: skip

     assert _anthropic, "LLM not initialized"
-    messages, system_messages = _transform_system_messages(messages)
-
-    messages_dicts = _handle_files(msgs2dicts(messages))
-
-    tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
+    messages_dicts, system_messages, tools_dict = _prepare_messages_for_api(
+        messages, tools
+    )

     response = _anthropic.beta.prompt_caching.messages.create(
         model=model,
-        messages=messages_dicts,  # type: ignore
-        system=system_messages,  # type: ignore
+        messages=messages_dicts,
+        system=system_messages,
         temperature=TEMPERATURE,
         top_p=TOP_P,
         max_tokens=4096,
@@ -89,16 +79,14 @@ def stream(
     from anthropic import NOT_GIVEN  # fmt: skip

     assert _anthropic, "LLM not initialized"
-    messages, system_messages = _transform_system_messages(messages)
-
-    messages_dicts = _handle_files(msgs2dicts(messages))
-
-    tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
+    messages_dicts, system_messages, tools_dict = _prepare_messages_for_api(
+        messages, tools
+    )

     with _anthropic.beta.prompt_caching.messages.stream(
         model=model,
-        messages=messages_dicts,  # type: ignore
-        system=system_messages,  # type: ignore
+        messages=messages_dicts,
+        system=system_messages,
         temperature=TEMPERATURE,
         top_p=TOP_P,
         max_tokens=4096,
@@ -210,9 +198,23 @@ def _process_file(message_dict: dict) -> dict:

 def _transform_system_messages(
     messages: list[Message],
-) -> tuple[list[Message], list[MessagePart]]:
-    # transform system messages into system kwarg for anthropic
-    # for first system message, transform it into a system kwarg
+) -> tuple[list[Message], list["PromptCachingBetaTextBlockParam"]]:
+    """Transform system messages into Anthropic's expected format.
+
+    This function:
+    1. Extracts the first system message as the main system prompt
+    2. Transforms subsequent system messages into <system> tags in user messages
+    3. Merges consecutive user messages
+    4. Applies cache control to optimize performance
+
+    Note: Anthropic allows up to 4 cache breakpoints in a conversation.
+    We use this to cache:
+    1. The system prompt (if long enough)
+    2. Earlier messages in multi-turn conversations
+
+    Returns:
+        tuple[list[Message], list[PromptCachingBetaTextBlockParam]]: Transformed messages and system messages
+    """
     assert messages[0].role == "system"
     system_prompt = messages[0].content
     messages = messages.copy()
@@ -240,19 +242,13 @@ def _transform_system_messages(
         else:
             messages_new.append(message)
     messages = messages_new
-    system_messages: list[MessagePart] = [
+    system_messages: list[PromptCachingBetaTextBlockParam] = [
         {
             "type": "text",
             "text": system_prompt,
         }
     ]

-    # prompt caching for the system prompt, saving cost and reducing latency
-    # https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
-    # if system_messages is long (>2048 tokens), we add cache_control
-    if len_tokens(system_prompt) > 2048 + 500:  # margin for tokenizer diff
-        system_messages[-1]["cache_control"] = {"type": "ephemeral"}
-
     return messages, system_messages

@@ -286,3 +282,81 @@ def _spec2tool(
         "description": spec.get_instructions("tool"),
         "input_schema": parameters2dict(spec.parameters),
     }
+
+
+def _prepare_messages_for_api(
+    messages: list[Message], tools: list[ToolSpec] | None
+) -> tuple[
+    list["anthropic.types.beta.prompt_caching.PromptCachingBetaMessageParam"],
+    list["anthropic.types.beta.prompt_caching.PromptCachingBetaTextBlockParam"],
+    list["anthropic.types.beta.prompt_caching.PromptCachingBetaToolParam"] | None,
+]:
+    """Prepare messages for the Anthropic API.
+
+    This function:
+    1. Transforms system messages
+    2. Handles file attachments
+    3. Applies cache control
+    4. Prepares tools
+
+    Args:
+        messages: List of messages to prepare
+        tools: List of tool specifications
+
+    Returns:
+        tuple containing:
+        - Prepared message dictionaries
+        - System messages
+        - Tool dictionaries (if tools provided)
+    """
+    # noreorder
+    from anthropic.types.beta.prompt_caching import (  # fmt: skip
+        PromptCachingBetaImageBlockParam,
+        PromptCachingBetaMessageParam,
+        PromptCachingBetaTextBlockParam,
+        PromptCachingBetaToolResultBlockParam,
+        PromptCachingBetaToolUseBlockParam,
+    )
+
+    # Transform system messages
+    messages, system_messages = _transform_system_messages(messages)
+
+    # Handle files and convert to dicts
+    messages_dicts = _handle_files(msgs2dicts(messages))
+
+    # Apply cache control to optimize performance
+    messages_dicts_new: list[PromptCachingBetaMessageParam] = []
+    for i, msg in enumerate(messages_dicts):
+        content_parts: list[
+            PromptCachingBetaTextBlockParam
+            | PromptCachingBetaImageBlockParam
+            | PromptCachingBetaToolUseBlockParam
+            | PromptCachingBetaToolResultBlockParam
+        ] = []
+        raw_content = (
+            msg["content"]
+            if isinstance(msg["content"], list)
+            else [{"type": "text", "text": msg["content"]}]
+        )
+
+        for part in raw_content:
+            if isinstance(part, dict):
+                if part.get("type") == "text" and i == len(messages_dicts) - 1:
+                    content_parts.append(
+                        {
+                            "type": "text",
+                            "text": part["text"],
+                            "cache_control": {"type": "ephemeral"},
+                        }
+                    )
+                else:
+                    content_parts.append(part)  # type: ignore
+            else:
+                content_parts.append({"type": "text", "text": str(part)})
+
+        messages_dicts_new.append({"role": msg["role"], "content": content_parts})
+
+    # Prepare tools
+    tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
+
+    return messages_dicts_new, system_messages, tools_dict
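
A minimal usage sketch of the new helper, mirroring what chat() does after
this patch (the conversation content and model name are placeholders, not
part of the diff):

    from gptme.message import Message
    from gptme.llm.llm_anthropic import _prepare_messages_for_api

    messages = [
        Message("system", "You are a helpful assistant."),
        Message("user", "Write a hello world script."),
    ]
    # Transforms system messages, attaches any files, applies ephemeral
    # cache control to the last message, and prepares tool dicts.
    messages_dicts, system_messages, tools_dict = _prepare_messages_for_api(
        messages, tools=None
    )
    # The prepared parts map directly onto the prompt-caching beta endpoint:
    #   client.beta.prompt_caching.messages.create(
    #       model="claude-3-5-sonnet-20241022",
    #       messages=messages_dicts,
    #       system=system_messages,
    #       max_tokens=4096,
    #   )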