Skip to content

Commit

Permalink
fix: migrated to latest anthropic version v0.42 (prompt caching now s…
Browse files Browse the repository at this point in the history
…table) (#352)
  • Loading branch information
ErikBjare authored Dec 18, 2024
1 parent b4d0c41 commit 0ecf045
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 32 deletions.
43 changes: 17 additions & 26 deletions gptme/llm/llm_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
if TYPE_CHECKING:
# noreorder
import anthropic.types # fmt: skip
import anthropic.types.beta.prompt_caching # fmt: skip
from anthropic import Anthropic # fmt: skip
from anthropic.types.beta.prompt_caching import PromptCachingBetaTextBlockParam

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,7 +53,7 @@ def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> s
messages, tools
)

response = _anthropic.beta.prompt_caching.messages.create(
response = _anthropic.messages.create(
model=model,
messages=messages_dicts,
system=system_messages,
Expand All @@ -75,15 +73,14 @@ def stream(
messages: list[Message], model: str, tools: list[ToolSpec] | None
) -> Generator[str, None, None]:
import anthropic.types # fmt: skip
import anthropic.types.beta.prompt_caching # fmt: skip
from anthropic import NOT_GIVEN # fmt: skip

assert _anthropic, "LLM not initialized"
messages_dicts, system_messages, tools_dict = _prepare_messages_for_api(
messages, tools
)

with _anthropic.beta.prompt_caching.messages.stream(
with _anthropic.messages.stream(
model=model,
messages=messages_dicts,
system=system_messages,
Expand Down Expand Up @@ -121,7 +118,7 @@ def stream(
pass
case "message_start":
chunk = cast(
anthropic.types.beta.prompt_caching.RawPromptCachingBetaMessageStartEvent,
anthropic.types.MessageStartEvent,
chunk,
)
logger.debug(chunk.message.usage)
Expand Down Expand Up @@ -200,7 +197,7 @@ def _process_file(message_dict: dict) -> dict:

def _transform_system_messages(
messages: list[Message],
) -> tuple[list[Message], list["PromptCachingBetaTextBlockParam"]]:
) -> tuple[list[Message], list["anthropic.types.TextBlockParam"]]:
"""Transform system messages into Anthropic's expected format.
This function:
Expand All @@ -215,7 +212,7 @@ def _transform_system_messages(
2. Earlier messages in multi-turn conversations
Returns:
tuple[list[Message], list[PromptCachingBetaTextBlockParam]]: Transformed messages and system messages
tuple[list[Message], list[TextBlockParam]]: Transformed messages and system messages
"""
assert messages[0].role == "system"
system_prompt = messages[0].content
Expand Down Expand Up @@ -244,7 +241,7 @@ def _transform_system_messages(
else:
messages_new.append(message)
messages = messages_new
system_messages: list[PromptCachingBetaTextBlockParam] = [
system_messages: list[anthropic.types.TextBlockParam] = [
{
"type": "text",
"text": system_prompt,
Expand Down Expand Up @@ -273,7 +270,7 @@ def parameters2dict(parameters: list[Parameter]) -> dict[str, object]:

def _spec2tool(
spec: ToolSpec,
) -> "anthropic.types.beta.prompt_caching.PromptCachingBetaToolParam":
) -> "anthropic.types.ToolParam":
name = spec.name
if spec.block_types:
name = spec.block_types[0]
Expand All @@ -289,9 +286,9 @@ def _spec2tool(
def _prepare_messages_for_api(
messages: list[Message], tools: list[ToolSpec] | None
) -> tuple[
list["anthropic.types.beta.prompt_caching.PromptCachingBetaMessageParam"],
list["anthropic.types.beta.prompt_caching.PromptCachingBetaTextBlockParam"],
list["anthropic.types.beta.prompt_caching.PromptCachingBetaToolParam"] | None,
list["anthropic.types.MessageParam"],
list["anthropic.types.TextBlockParam"],
list["anthropic.types.ToolParam"] | None,
]:
"""Prepare messages for the Anthropic API.
Expand All @@ -312,13 +309,7 @@ def _prepare_messages_for_api(
- Tool dictionaries (if tools provided)
"""
# noreorder
from anthropic.types.beta.prompt_caching import ( # fmt: skip
PromptCachingBetaImageBlockParam,
PromptCachingBetaMessageParam,
PromptCachingBetaTextBlockParam,
PromptCachingBetaToolResultBlockParam,
PromptCachingBetaToolUseBlockParam,
)
import anthropic.types # fmt: skip

# Transform system messages
messages, system_messages = _transform_system_messages(messages)
Expand All @@ -327,13 +318,13 @@ def _prepare_messages_for_api(
messages_dicts = _handle_files(msgs2dicts(messages))

# Apply cache control to optimize performance
messages_dicts_new: list[PromptCachingBetaMessageParam] = []
messages_dicts_new: list[anthropic.types.MessageParam] = []
for msg in messages_dicts:
content_parts: list[
PromptCachingBetaTextBlockParam
| PromptCachingBetaImageBlockParam
| PromptCachingBetaToolUseBlockParam
| PromptCachingBetaToolResultBlockParam
anthropic.types.TextBlockParam
| anthropic.types.ImageBlockParam
| anthropic.types.ToolUseBlockParam
| anthropic.types.ToolResultBlockParam
] = []
raw_content = (
msg["content"]
Expand All @@ -359,7 +350,7 @@ def _prepare_messages_for_api(
# https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation
for msgp in [msg for msg in messages_dicts_new if msg["role"] == "user"][-2:]:
assert isinstance(msgp["content"], list)
msgp["content"][-1]["cache_control"] = {"type": "ephemeral"}
msgp["content"][-1]["cache_control"] = {"type": "ephemeral"} # type: ignore

# Prepare tools
tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
Expand Down
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ json-repair = "^0.30.2"

# providers
openai = "^1.0"
anthropic = "^0.40"
anthropic = "^0.42"

# tools
ipython = "^8.17.2"
Expand Down

0 comments on commit 0ecf045

Please sign in to comment.