anthropic[patch]: fix input_tokens when cached (#27125)
baskaryan authored Oct 4, 2024
1 parent 64a16f2 commit 0b8416b
Showing 3 changed files with 43 additions and 8 deletions.
7 changes: 6 additions & 1 deletion libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -1253,7 +1253,12 @@ def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
"cache_creation": getattr(anthropic_usage, "cache_creation_input_tokens", None),
}

input_tokens = getattr(anthropic_usage, "input_tokens", 0)
# Anthropic input_tokens exclude cached token counts.
input_tokens = (
getattr(anthropic_usage, "input_tokens", 0)
+ (input_token_details["cache_read"] or 0)
+ (input_token_details["cache_creation"] or 0)
)
output_tokens = getattr(anthropic_usage, "output_tokens", 0)
return UsageMetadata(
input_tokens=input_tokens,
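A minimal sketch (not part of the commit) of what this hunk changes, using `types.SimpleNamespace` to stand in for the Anthropic SDK's usage model; the values mirror the unit-test fixture below (2 uncached input tokens, 3 cache-creation tokens, 4 cache-read tokens):

# Sketch of the corrected input_tokens accounting; the fake object stands in
# for the Anthropic SDK usage payload read by _create_usage_metadata.
from types import SimpleNamespace

anthropic_usage = SimpleNamespace(
    input_tokens=2,  # Anthropic reports only uncached input tokens here
    output_tokens=1,
    cache_creation_input_tokens=3,  # tokens written to the prompt cache
    cache_read_input_tokens=4,  # tokens served from the prompt cache
)

input_token_details = {
    "cache_read": getattr(anthropic_usage, "cache_read_input_tokens", None),
    "cache_creation": getattr(anthropic_usage, "cache_creation_input_tokens", None),
}
input_tokens = (
    getattr(anthropic_usage, "input_tokens", 0)
    + (input_token_details["cache_read"] or 0)
    + (input_token_details["cache_creation"] or 0)
)
print(input_tokens)  # 9, where the old code would have reported 2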
4 changes: 2 additions & 2 deletions libs/partners/anthropic/tests/unit_tests/test_chat_models.py
@@ -128,9 +128,9 @@ def test__format_output_cached() -> None:
     expected = AIMessage(  # type: ignore[misc]
         "bar",
         usage_metadata={
-            "input_tokens": 2,
+            "input_tokens": 9,
             "output_tokens": 1,
-            "total_tokens": 3,
+            "total_tokens": 10,
             "input_token_details": {"cache_creation": 3, "cache_read": 4},
         },
     )
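As a sanity check on the updated expectations: `UsageMetadata` in `langchain_core` defines `total_tokens` as the sum of input and output tokens, so the expected total moves from 3 to 10 together with the input count:

# Quick arithmetic check on the new expected values (assumes langchain_core
# is installed; UsageMetadata is the TypedDict used in the test above).
from langchain_core.messages.ai import UsageMetadata

expected_usage = UsageMetadata(
    input_tokens=9,  # 2 uncached + 3 cache_creation + 4 cache_read
    output_tokens=1,
    total_tokens=10,
)
assert (
    expected_usage["total_tokens"]
    == expected_usage["input_tokens"] + expected_usage["output_tokens"]
)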
40 changes: 35 additions & 5 deletions
@@ -153,28 +153,58 @@ def test_usage_metadata(self, model: BaseChatModel) -> None:

if "audio_input" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_audio_input()
assert isinstance(msg.usage_metadata["input_token_details"]["audio"], int) # type: ignore[index]
assert msg.usage_metadata is not None
assert msg.usage_metadata["input_token_details"] is not None
assert isinstance(msg.usage_metadata["input_token_details"]["audio"], int)
assert msg.usage_metadata["input_tokens"] >= sum(
(v or 0) # type: ignore[misc]
for v in msg.usage_metadata["input_token_details"].values()
)
if "audio_output" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_audio_output()
assert isinstance(msg.usage_metadata["output_token_details"]["audio"], int) # type: ignore[index]
assert msg.usage_metadata is not None
assert msg.usage_metadata["output_token_details"] is not None
assert isinstance(msg.usage_metadata["output_token_details"]["audio"], int)
assert int(msg.usage_metadata["output_tokens"]) >= sum(
(v or 0) # type: ignore[misc]
for v in msg.usage_metadata["output_token_details"].values()
)
if "reasoning_output" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_reasoning_output()
assert msg.usage_metadata is not None
assert msg.usage_metadata["output_token_details"] is not None
assert isinstance(
msg.usage_metadata["output_token_details"]["reasoning"], # type: ignore[index]
msg.usage_metadata["output_token_details"]["reasoning"],
int,
)
assert msg.usage_metadata["output_tokens"] >= sum(
(v or 0) # type: ignore[misc]
for v in msg.usage_metadata["output_token_details"].values()
)
if "cache_read_input" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_cache_read_input()
assert msg.usage_metadata is not None
assert msg.usage_metadata["input_token_details"] is not None
assert isinstance(
msg.usage_metadata["input_token_details"]["cache_read"], # type: ignore[index]
msg.usage_metadata["input_token_details"]["cache_read"],
int,
)
assert msg.usage_metadata["input_tokens"] >= sum(
(v or 0) # type: ignore[misc]
for v in msg.usage_metadata["input_token_details"].values()
)
if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_cache_creation_input()
assert msg.usage_metadata is not None
assert msg.usage_metadata["input_token_details"] is not None
assert isinstance(
msg.usage_metadata["input_token_details"]["cache_creation"], # type: ignore[index]
msg.usage_metadata["input_token_details"]["cache_creation"],
int,
)
assert msg.usage_metadata["input_tokens"] >= sum(
(v or 0) # type: ignore[misc]
for v in msg.usage_metadata["input_token_details"].values()
)

def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
if not self.returns_usage_metadata:
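The repeated pattern above enforces one invariant per category: the top-level token count must be at least the sum of its detail entries, since the details are subsets of the total. A standalone sketch of the cache-read case, using a hand-built message rather than a live call to `invoke_with_cache_read_input()`:

# Hedged sketch of the invariant the standard tests now assert; the message
# is constructed by hand instead of coming from a real model invocation.
from langchain_core.messages import AIMessage

msg = AIMessage(
    "bar",
    usage_metadata={
        "input_tokens": 9,
        "output_tokens": 1,
        "total_tokens": 10,
        "input_token_details": {"cache_creation": 3, "cache_read": 4},
    },
)
assert msg.usage_metadata is not None
assert isinstance(msg.usage_metadata["input_token_details"]["cache_read"], int)
assert msg.usage_metadata["input_tokens"] >= sum(
    (v or 0) for v in msg.usage_metadata["input_token_details"].values()
)  # 9 >= 3 + 4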
