Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: migrated to latest anthropic version v0.42 (prompt caching now stable) #352

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 17 additions & 26 deletions gptme/llm/llm_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
if TYPE_CHECKING:
# noreorder
import anthropic.types # fmt: skip
import anthropic.types.beta.prompt_caching # fmt: skip
from anthropic import Anthropic # fmt: skip
from anthropic.types.beta.prompt_caching import PromptCachingBetaTextBlockParam

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,7 +53,7 @@ def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> s
messages, tools
)

response = _anthropic.beta.prompt_caching.messages.create(
response = _anthropic.messages.create(
model=model,
messages=messages_dicts,
system=system_messages,
Expand All @@ -75,15 +73,14 @@ def stream(
messages: list[Message], model: str, tools: list[ToolSpec] | None
) -> Generator[str, None, None]:
import anthropic.types # fmt: skip
import anthropic.types.beta.prompt_caching # fmt: skip
from anthropic import NOT_GIVEN # fmt: skip

assert _anthropic, "LLM not initialized"
messages_dicts, system_messages, tools_dict = _prepare_messages_for_api(
messages, tools
)

with _anthropic.beta.prompt_caching.messages.stream(
with _anthropic.messages.stream(
model=model,
messages=messages_dicts,
system=system_messages,
Expand Down Expand Up @@ -121,7 +118,7 @@ def stream(
pass
case "message_start":
chunk = cast(
anthropic.types.beta.prompt_caching.RawPromptCachingBetaMessageStartEvent,
anthropic.types.MessageStartEvent,
chunk,
)
logger.debug(chunk.message.usage)
Expand Down Expand Up @@ -200,7 +197,7 @@ def _process_file(message_dict: dict) -> dict:

def _transform_system_messages(
messages: list[Message],
) -> tuple[list[Message], list["PromptCachingBetaTextBlockParam"]]:
) -> tuple[list[Message], list["anthropic.types.TextBlockParam"]]:
"""Transform system messages into Anthropic's expected format.

This function:
Expand All @@ -215,7 +212,7 @@ def _transform_system_messages(
2. Earlier messages in multi-turn conversations

Returns:
tuple[list[Message], list[PromptCachingBetaTextBlockParam]]: Transformed messages and system messages
tuple[list[Message], list[TextBlockParam]]: Transformed messages and system messages
"""
assert messages[0].role == "system"
system_prompt = messages[0].content
Expand Down Expand Up @@ -244,7 +241,7 @@ def _transform_system_messages(
else:
messages_new.append(message)
messages = messages_new
system_messages: list[PromptCachingBetaTextBlockParam] = [
system_messages: list[anthropic.types.TextBlockParam] = [
{
"type": "text",
"text": system_prompt,
Expand Down Expand Up @@ -273,7 +270,7 @@ def parameters2dict(parameters: list[Parameter]) -> dict[str, object]:

def _spec2tool(
spec: ToolSpec,
) -> "anthropic.types.beta.prompt_caching.PromptCachingBetaToolParam":
) -> "anthropic.types.ToolParam":
name = spec.name
if spec.block_types:
name = spec.block_types[0]
Expand All @@ -289,9 +286,9 @@ def _spec2tool(
def _prepare_messages_for_api(
messages: list[Message], tools: list[ToolSpec] | None
) -> tuple[
list["anthropic.types.beta.prompt_caching.PromptCachingBetaMessageParam"],
list["anthropic.types.beta.prompt_caching.PromptCachingBetaTextBlockParam"],
list["anthropic.types.beta.prompt_caching.PromptCachingBetaToolParam"] | None,
list["anthropic.types.MessageParam"],
list["anthropic.types.TextBlockParam"],
list["anthropic.types.ToolParam"] | None,
]:
"""Prepare messages for the Anthropic API.

Expand All @@ -312,13 +309,7 @@ def _prepare_messages_for_api(
- Tool dictionaries (if tools provided)
"""
# noreorder
from anthropic.types.beta.prompt_caching import ( # fmt: skip
PromptCachingBetaImageBlockParam,
PromptCachingBetaMessageParam,
PromptCachingBetaTextBlockParam,
PromptCachingBetaToolResultBlockParam,
PromptCachingBetaToolUseBlockParam,
)
import anthropic.types # fmt: skip

# Transform system messages
messages, system_messages = _transform_system_messages(messages)
Expand All @@ -327,13 +318,13 @@ def _prepare_messages_for_api(
messages_dicts = _handle_files(msgs2dicts(messages))

# Apply cache control to optimize performance
messages_dicts_new: list[PromptCachingBetaMessageParam] = []
messages_dicts_new: list[anthropic.types.MessageParam] = []
for msg in messages_dicts:
content_parts: list[
PromptCachingBetaTextBlockParam
| PromptCachingBetaImageBlockParam
| PromptCachingBetaToolUseBlockParam
| PromptCachingBetaToolResultBlockParam
anthropic.types.TextBlockParam
| anthropic.types.ImageBlockParam
| anthropic.types.ToolUseBlockParam
| anthropic.types.ToolResultBlockParam
] = []
raw_content = (
msg["content"]
Expand All @@ -359,7 +350,7 @@ def _prepare_messages_for_api(
# https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation
for msgp in [msg for msg in messages_dicts_new if msg["role"] == "user"][-2:]:
assert isinstance(msgp["content"], list)
msgp["content"][-1]["cache_control"] = {"type": "ephemeral"}
msgp["content"][-1]["cache_control"] = {"type": "ephemeral"} # type: ignore

# Prepare tools
tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
Expand Down
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ json-repair = "^0.30.2"

# providers
openai = "^1.0"
anthropic = "^0.40"
anthropic = "^0.42"

# tools
ipython = "^8.17.2"
Expand Down
Loading