feat(anthropic): improve prompt caching and type safety #317

Merged · 1 commit · Dec 10, 2024
152 changes: 113 additions & 39 deletions gptme/llm/llm_anthropic.py
@@ -10,17 +10,18 @@
     cast,
 )
 
-from typing_extensions import Required
-
 from ..constants import TEMPERATURE, TOP_P
-from ..message import Message, len_tokens, msgs2dicts
+from ..message import Message, msgs2dicts
 from ..tools.base import Parameter, ToolSpec
 
 if TYPE_CHECKING:
     # noreorder
     import anthropic  # fmt: skip
     import anthropic.types  # fmt: skip
     import anthropic.types.beta.prompt_caching  # fmt: skip
     from anthropic import Anthropic  # fmt: skip
+    from anthropic.types.beta.prompt_caching import (
+        PromptCachingBetaTextBlockParam,
+    )
 
 logger = logging.getLogger(__name__)

@@ -44,33 +45,22 @@ def get_client() -> "Anthropic | None":
     return _anthropic
 
 
-class ToolAnthropic(TypedDict):
-    name: str
-    description: str
-    input_schema: dict
-
-
-class MessagePart(TypedDict, total=False):
-    type: Required[Literal["text", "image_url"]]
-    text: str
-    image_url: str
-    cache_control: dict[str, str]
+class CacheControl(TypedDict):
+    type: Literal["ephemeral"]
 
 
 def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> str:
     from anthropic import NOT_GIVEN  # fmt: skip
 
     assert _anthropic, "LLM not initialized"
-    messages, system_messages = _transform_system_messages(messages)
-
-    messages_dicts = _handle_files(msgs2dicts(messages))
-
-    tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
+    messages_dicts, system_messages, tools_dict = _prepare_messages_for_api(
+        messages, tools
+    )
 
     response = _anthropic.beta.prompt_caching.messages.create(
         model=model,
-        messages=messages_dicts,  # type: ignore
-        system=system_messages,  # type: ignore
+        messages=messages_dicts,
+        system=system_messages,
         temperature=TEMPERATURE,
         top_p=TOP_P,
         max_tokens=4096,
@@ -89,16 +79,14 @@ def stream(
     from anthropic import NOT_GIVEN  # fmt: skip
 
     assert _anthropic, "LLM not initialized"
-    messages, system_messages = _transform_system_messages(messages)
-
-    messages_dicts = _handle_files(msgs2dicts(messages))
-
-    tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
+    messages_dicts, system_messages, tools_dict = _prepare_messages_for_api(
+        messages, tools
+    )
 
     with _anthropic.beta.prompt_caching.messages.stream(
         model=model,
-        messages=messages_dicts,  # type: ignore
-        system=system_messages,  # type: ignore
+        messages=messages_dicts,
+        system=system_messages,
         temperature=TEMPERATURE,
         top_p=TOP_P,
         max_tokens=4096,
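
Both chat() and stream() now share this preparation step and feed its results directly to the prompt-caching beta endpoint. Below is a minimal sketch of that call pattern (not taken from the PR), assuming an anthropic SDK version that still exposes the beta.prompt_caching namespace, with an example model name:

import anthropic

client = anthropic.Anthropic()  # assumes ANTHROPIC_API_KEY is set in the environment

# A long system prompt marked "ephemeral" so the server can cache it.
system = [
    {
        "type": "text",
        "text": "You are a helpful assistant.\n" + "Reference material...\n" * 500,
        "cache_control": {"type": "ephemeral"},
    }
]

response = client.beta.prompt_caching.messages.create(
    model="claude-3-5-sonnet-20241022",  # example model name, not from the PR
    max_tokens=4096,
    system=system,
    messages=[{"role": "user", "content": "Summarize the reference material."}],
)
print(response.content[0].text)

Subsequent calls that reuse the same cached prefix should then be cheaper and lower-latency, which is the point of the breakpoints this PR adds.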
@@ -210,9 +198,23 @@ def _process_file(message_dict: dict) -> dict:
 
 def _transform_system_messages(
     messages: list[Message],
-) -> tuple[list[Message], list[MessagePart]]:
-    # transform system messages into system kwarg for anthropic
-    # for first system message, transform it into a system kwarg
+) -> tuple[list[Message], list["PromptCachingBetaTextBlockParam"]]:
+    """Transform system messages into Anthropic's expected format.
+
+    This function:
+    1. Extracts the first system message as the main system prompt
+    2. Transforms subsequent system messages into <system> tags in user messages
+    3. Merges consecutive user messages
+    4. Applies cache control to optimize performance
+
+    Note: Anthropic allows up to 4 cache breakpoints in a conversation.
+    We use this to cache:
+    1. The system prompt (if long enough)
+    2. Earlier messages in multi-turn conversations
+
+    Returns:
+        tuple[list[Message], list[PromptCachingBetaTextBlockParam]]: Transformed messages and system messages
+    """
     assert messages[0].role == "system"
     system_prompt = messages[0].content
     messages = messages.copy()
@@ -240,19 +242,13 @@ def _transform_system_messages(
         else:
             messages_new.append(message)
     messages = messages_new
-    system_messages: list[MessagePart] = [
+    system_messages: list[PromptCachingBetaTextBlockParam] = [
         {
             "type": "text",
             "text": system_prompt,
         }
     ]
 
-    # prompt caching for the system prompt, saving cost and reducing latency
-    # https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
-    # if system_messages is long (>2048 tokens), we add cache_control
-    if len_tokens(system_prompt) > 2048 + 500:  # margin for tokenizer diff
-        system_messages[-1]["cache_control"] = {"type": "ephemeral"}
-
     return messages, system_messages
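
The docstring above summarizes the transformation; note also that the old token-threshold heuristic (len_tokens(system_prompt) > 2048 + 500) is dropped in this hunk, since cache control is now applied in _prepare_messages_for_api below. A rough standalone sketch of the described transformation, using plain (role, content) tuples rather than gptme's Message class, and not the PR's exact implementation:

def transform(msgs: list[tuple[str, str]]):
    """Sketch: the first system message becomes the system prompt; later
    system messages are wrapped in <system> tags and merged into adjacent
    user messages."""
    assert msgs[0][0] == "system"
    system_prompt = msgs[0][1]
    merged: list[tuple[str, str]] = []
    for role, content in msgs[1:]:
        if role == "system":
            role, content = "user", f"<system>{content}</system>"
        if merged and role == "user" and merged[-1][0] == "user":
            # merge consecutive user messages
            merged[-1] = ("user", merged[-1][1] + "\n\n" + content)
        else:
            merged.append((role, content))
    return merged, [{"type": "text", "text": system_prompt}]

messages, system = transform(
    [
        ("system", "You are gptme."),
        ("user", "Run the tests."),
        ("system", "Tests passed."),
        ("user", "Now commit."),
    ]
)
# messages == [("user", "Run the tests.\n\n<system>Tests passed.</system>\n\nNow commit.")]
# system == [{"type": "text", "text": "You are gptme."}]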


@@ -286,3 +282,81 @@ def _spec2tool(
         "description": spec.get_instructions("tool"),
         "input_schema": parameters2dict(spec.parameters),
     }
+
+
+def _prepare_messages_for_api(
+    messages: list[Message], tools: list[ToolSpec] | None
+) -> tuple[
+    list["anthropic.types.beta.prompt_caching.PromptCachingBetaMessageParam"],
+    list["anthropic.types.beta.prompt_caching.PromptCachingBetaTextBlockParam"],
+    list["anthropic.types.beta.prompt_caching.PromptCachingBetaToolParam"] | None,
+]:
+    """Prepare messages for the Anthropic API.
+
+    This function:
+    1. Transforms system messages
+    2. Handles file attachments
+    3. Applies cache control
+    4. Prepares tools
+
+    Args:
+        messages: List of messages to prepare
+        tools: List of tool specifications
+
+    Returns:
+        tuple containing:
+        - Prepared message dictionaries
+        - System messages
+        - Tool dictionaries (if tools provided)
+    """
+    # noreorder
+    from anthropic.types.beta.prompt_caching import (  # fmt: skip
+        PromptCachingBetaImageBlockParam,
+        PromptCachingBetaMessageParam,
+        PromptCachingBetaTextBlockParam,
+        PromptCachingBetaToolResultBlockParam,
+        PromptCachingBetaToolUseBlockParam,
+    )
+
+    # Transform system messages
+    messages, system_messages = _transform_system_messages(messages)
+
+    # Handle files and convert to dicts
+    messages_dicts = _handle_files(msgs2dicts(messages))
+
+    # Apply cache control to optimize performance
+    messages_dicts_new: list[PromptCachingBetaMessageParam] = []
+    for i, msg in enumerate(messages_dicts):
+        content_parts: list[
+            PromptCachingBetaTextBlockParam
+            | PromptCachingBetaImageBlockParam
+            | PromptCachingBetaToolUseBlockParam
+            | PromptCachingBetaToolResultBlockParam
+        ] = []
+        raw_content = (
+            msg["content"]
+            if isinstance(msg["content"], list)
+            else [{"type": "text", "text": msg["content"]}]
+        )
+
+        for part in raw_content:
+            if isinstance(part, dict):
+                if part.get("type") == "text" and i == len(messages_dicts) - 1:
[Inline review comment from the author, on the line above] Might want to review this later to cache at the optimal locations, especially considering RAG and "live" messages. Should probably also add a cache step at the system prompt.

+                    content_parts.append(
+                        {
+                            "type": "text",
+                            "text": part["text"],
+                            "cache_control": {"type": "ephemeral"},
+                        }
+                    )
+                else:
+                    content_parts.append(part)  # type: ignore
+            else:
+                content_parts.append({"type": "text", "text": str(part)})
+
+        messages_dicts_new.append({"role": msg["role"], "content": content_parts})
+
+    # Prepare tools
+    tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
+
+    return messages_dicts_new, system_messages, tools_dict
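
For reference, here is a condensed, runnable sketch of the cache-breakpoint loop above, stripped of the SDK types. It is illustrative only; as the review comment notes, the breakpoint placement (currently the last message's text blocks) may be revisited:

def apply_cache_control(messages_dicts: list[dict]) -> list[dict]:
    """Normalize string content to text blocks and mark the final message's
    text blocks "ephemeral", creating a cache breakpoint at the end of the
    conversation prefix."""
    out: list[dict] = []
    for i, msg in enumerate(messages_dicts):
        parts = (
            msg["content"]
            if isinstance(msg["content"], list)
            else [{"type": "text", "text": msg["content"]}]
        )
        new_parts = []
        for part in parts:
            if (
                isinstance(part, dict)
                and part.get("type") == "text"
                and i == len(messages_dicts) - 1
            ):
                part = {**part, "cache_control": {"type": "ephemeral"}}
            new_parts.append(part)
        out.append({"role": msg["role"], "content": new_parts})
    return out

print(apply_cache_control([{"role": "user", "content": "hello"}]))
# [{'role': 'user', 'content': [{'type': 'text', 'text': 'hello',
#   'cache_control': {'type': 'ephemeral'}}]}]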