feat(anthropic): improve prompt caching and type safety (#317)
- Add better type hints for Anthropic API types
- Optimize prompt caching with ephemeral cache control (see the sketch after this list)
- Extract message preparation into dedicated function
- Improve documentation
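
For reference, the ephemeral cache-control marker is a small block attached to a content part: everything up to that block becomes a reusable cache prefix. A minimal sketch against the prompt-caching beta endpoint (the model name and prompt text are placeholders, not taken from this commit):

import anthropic

client = anthropic.Anthropic()

response = client.beta.prompt_caching.messages.create(
    model="claude-3-5-sonnet-20241022",  # placeholder model
    max_tokens=1024,
    system=[
        {
            "type": "text",
            "text": "...a long, stable system prompt worth caching...",
            # cache breakpoint: the prefix up to here can be reused
            "cache_control": {"type": "ephemeral"},
        }
    ],
    messages=[{"role": "user", "content": "Hello"}],
)

Cached prefixes are reused across requests for a short window, cutting cost and latency on repeated long prompts.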
ErikBjare authored Dec 10, 2024
1 parent c0eb21f commit a2f06df
Showing 1 changed file with 113 additions and 39 deletions: gptme/llm/llm_anthropic.py
@@ -10,17 +10,18 @@
cast,
)

from typing_extensions import Required

from ..constants import TEMPERATURE, TOP_P
from ..message import Message, len_tokens, msgs2dicts
from ..message import Message, msgs2dicts
from ..tools.base import Parameter, ToolSpec

if TYPE_CHECKING:
# noreorder
import anthropic # fmt: skip
import anthropic.types # fmt: skip
import anthropic.types.beta.prompt_caching # fmt: skip
from anthropic import Anthropic # fmt: skip
from anthropic.types.beta.prompt_caching import (
PromptCachingBetaTextBlockParam,
)

logger = logging.getLogger(__name__)

@@ -44,33 +45,22 @@ def get_client() -> "Anthropic | None":
return _anthropic


class ToolAnthropic(TypedDict):
name: str
description: str
input_schema: dict


class MessagePart(TypedDict, total=False):
type: Required[Literal["text", "image_url"]]
text: str
image_url: str
cache_control: dict[str, str]
class CacheControl(TypedDict):
type: Literal["ephemeral"]
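# Wire format: a content block gains a cache breakpoint by carrying
# "cache_control": {"type": "ephemeral"}; see _prepare_messages_for_api below.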


def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> str:
from anthropic import NOT_GIVEN # fmt: skip

assert _anthropic, "LLM not initialized"
messages, system_messages = _transform_system_messages(messages)

messages_dicts = _handle_files(msgs2dicts(messages))

tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
messages_dicts, system_messages, tools_dict = _prepare_messages_for_api(
messages, tools
)

response = _anthropic.beta.prompt_caching.messages.create(
model=model,
messages=messages_dicts, # type: ignore
system=system_messages, # type: ignore
messages=messages_dicts,
system=system_messages,
temperature=TEMPERATURE,
top_p=TOP_P,
max_tokens=4096,
@@ -89,16 +79,14 @@ def stream(
from anthropic import NOT_GIVEN # fmt: skip

assert _anthropic, "LLM not initialized"
messages, system_messages = _transform_system_messages(messages)

messages_dicts = _handle_files(msgs2dicts(messages))

tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
messages_dicts, system_messages, tools_dict = _prepare_messages_for_api(
messages, tools
)

with _anthropic.beta.prompt_caching.messages.stream(
model=model,
messages=messages_dicts, # type: ignore
system=system_messages, # type: ignore
messages=messages_dicts,
system=system_messages,
temperature=TEMPERATURE,
top_p=TOP_P,
max_tokens=4096,
@@ -210,9 +198,23 @@ def _process_file(message_dict: dict) -> dict:

def _transform_system_messages(
messages: list[Message],
) -> tuple[list[Message], list[MessagePart]]:
# transform system messages into system kwarg for anthropic
# for first system message, transform it into a system kwarg
) -> tuple[list[Message], list["PromptCachingBetaTextBlockParam"]]:
"""Transform system messages into Anthropic's expected format.
This function:
1. Extracts the first system message as the main system prompt
2. Transforms subsequent system messages into <system> tags in user messages
3. Merges consecutive user messages
4. Applies cache control to optimize performance
Note: Anthropic allows up to 4 cache breakpoints in a conversation.
We use this to cache:
1. The system prompt (if long enough)
2. Earlier messages in multi-turn conversations
Returns:
tuple[list[Message], list[PromptCachingBetaTextBlockParam]]: Transformed messages and system messages
"""
assert messages[0].role == "system"
system_prompt = messages[0].content
messages = messages.copy()
@@ -240,19 +242,13 @@ def _transform_system_messages(
else:
messages_new.append(message)
messages = messages_new
system_messages: list[MessagePart] = [
system_messages: list[PromptCachingBetaTextBlockParam] = [
{
"type": "text",
"text": system_prompt,
}
]

# prompt caching for the system prompt, saving cost and reducing latency
# https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
# if system_messages is long (>2048 tokens), we add cache_control
if len_tokens(system_prompt) > 2048 + 500: # margin for tokenizer diff
system_messages[-1]["cache_control"] = {"type": "ephemeral"}

return messages, system_messages


@@ -286,3 +282,81 @@ def _spec2tool(
"description": spec.get_instructions("tool"),
"input_schema": parameters2dict(spec.parameters),
}


def _prepare_messages_for_api(
messages: list[Message], tools: list[ToolSpec] | None
) -> tuple[
list["anthropic.types.beta.prompt_caching.PromptCachingBetaMessageParam"],
list["anthropic.types.beta.prompt_caching.PromptCachingBetaTextBlockParam"],
list["anthropic.types.beta.prompt_caching.PromptCachingBetaToolParam"] | None,
]:
"""Prepare messages for the Anthropic API.
This function:
1. Transforms system messages
2. Handles file attachments
3. Applies cache control
4. Prepares tools
Args:
messages: List of messages to prepare
tools: List of tool specifications
Returns:
tuple containing:
- Prepared message dictionaries
- System messages
- Tool dictionaries (if tools provided)
"""
# noreorder
from anthropic.types.beta.prompt_caching import ( # fmt: skip
PromptCachingBetaImageBlockParam,
PromptCachingBetaMessageParam,
PromptCachingBetaTextBlockParam,
PromptCachingBetaToolResultBlockParam,
PromptCachingBetaToolUseBlockParam,
)

# Transform system messages
messages, system_messages = _transform_system_messages(messages)

# Handle files and convert to dicts
messages_dicts = _handle_files(msgs2dicts(messages))

# Apply cache control to optimize performance
messages_dicts_new: list[PromptCachingBetaMessageParam] = []
for i, msg in enumerate(messages_dicts):
content_parts: list[
PromptCachingBetaTextBlockParam
| PromptCachingBetaImageBlockParam
| PromptCachingBetaToolUseBlockParam
| PromptCachingBetaToolResultBlockParam
] = []
raw_content = (
msg["content"]
if isinstance(msg["content"], list)
else [{"type": "text", "text": msg["content"]}]
)
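# Only text parts of the final message receive a cache breakpoint;
# Anthropic allows at most 4 cache breakpoints per request.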

for part in raw_content:
if isinstance(part, dict):
if part.get("type") == "text" and i == len(messages_dicts) - 1:
content_parts.append(
{
"type": "text",
"text": part["text"],
"cache_control": {"type": "ephemeral"},
}
)
else:
content_parts.append(part) # type: ignore
else:
content_parts.append({"type": "text", "text": str(part)})

messages_dicts_new.append({"role": msg["role"], "content": content_parts})

# Prepare tools
tools_dict = [_spec2tool(tool) for tool in tools] if tools else None

return messages_dicts_new, system_messages, tools_dict
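
A hypothetical end-to-end illustration of the new helper (not part of the commit; it assumes Message(role, content) construction, and the exact <system>-tag wording may differ from the real merging logic):

from gptme.message import Message

messages = [
    Message("system", "You are a helpful assistant."),  # becomes the system kwarg
    Message("user", "Run the tests."),
    Message("system", "Tests passed."),  # folded into a user message as a <system> tag
]

messages_dicts, system_messages, tools_dict = _prepare_messages_for_api(messages, tools=None)

# system_messages == [{"type": "text", "text": "You are a helpful assistant."}]
# The text part of the final message carries "cache_control": {"type": "ephemeral"},
# placing a cache breakpoint at the end of the conversation so earlier turns
# can be reused on the next request.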
