community[patch]: gather token usage info in BedrockChat during generation (#19127)

This PR calculates token usage for prompts and completions directly in the
generation method of BedrockChat. The token usage details are then returned
together with the generations, so that downstream tasks can access them
easily.
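
For example, downstream code can read the counts straight from the `LLMResult`. A minimal sketch, assuming AWS credentials are configured in the environment and using a placeholder model id and region:

```python
from langchain_community.chat_models import BedrockChat
from langchain_core.messages import HumanMessage

# Placeholder model id and region; credentials are assumed to come from the environment.
chat = BedrockChat(model_id="anthropic.claude-v2", region_name="us-east-1")
result = chat.generate([[HumanMessage(content="Hello")]])

# llm_output now carries the model id plus the aggregated token counts.
usage = result.llm_output["usage"]
print(usage["prompt_tokens"], usage["completion_tokens"], usage["total_tokens"])
```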

This makes it possible to define a callback for token tracking and cost
calculation, similar to what happens with OpenAI (see
[OpenAICallbackHandler](https://api.python.langchain.com/en/latest/_modules/langchain_community/callbacks/openai_info.html#OpenAICallbackHandler)).
I plan on adding a BedrockCallbackHandler later.
Keeping track of tokens in a callback is already possible today, but it
requires passing the LLM to the callback, as done here:
https://how.wtf/how-to-count-amazon-bedrock-anthropic-tokens-with-langchain.html.
However, I find the approach of this PR cleaner.
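
For illustration, such a handler could accumulate the new counts in `on_llm_end`. The class name below is hypothetical and not part of this PR; only the `llm_output["usage"]` structure comes from this change:

```python
from typing import Any

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


# Hypothetical handler name; only the llm_output["usage"] keys are introduced by this PR.
class BedrockTokenUsageCallbackHandler(BaseCallbackHandler):
    """Accumulate the token counts that BedrockChat now reports in llm_output."""

    def __init__(self) -> None:
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        usage = (response.llm_output or {}).get("usage", {})
        self.prompt_tokens += usage.get("prompt_tokens", 0)
        self.completion_tokens += usage.get("completion_tokens", 0)
        self.total_tokens += usage.get("total_tokens", 0)
```

Such a handler could then be passed via `callbacks=[handler]` when calling the model, similar to how OpenAI token tracking is done today.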

Thanks for your reviews. FYI @baskaryan, @hwchase17

---------

Co-authored-by: taamedag <[email protected]>
Co-authored-by: Bagatur <[email protected]>
3 people authored and hinthornw committed Apr 26, 2024
1 parent 8d7cc00 commit 50cd31b
Showing 3 changed files with 65 additions and 15 deletions.
31 changes: 23 additions & 8 deletions libs/community/langchain_community/chat_models/bedrock.py
@@ -1,4 +1,5 @@
import re
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from langchain_core.callbacks import (
@@ -234,10 +235,9 @@ def _stream(
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
provider = self._get_provider()
system = None
formatted_messages = None
prompt, system, formatted_messages = None, None, None

if provider == "anthropic":
prompt = None
system, formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
)
@@ -265,17 +265,17 @@ def _generate(
**kwargs: Any,
) -> ChatResult:
completion = ""
llm_output: Dict[str, Any] = {"model_id": self.model_id}

if self.streaming:
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
else:
provider = self._get_provider()
system = None
formatted_messages = None
prompt, system, formatted_messages = None, None, None
params: Dict[str, Any] = {**kwargs}

if provider == "anthropic":
prompt = None
system, formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
)
@@ -287,7 +287,7 @@ def _generate(
if stop:
params["stop_sequences"] = stop

completion = self._prepare_input_and_invoke(
completion, usage_info = self._prepare_input_and_invoke(
prompt=prompt,
stop=stop,
run_manager=run_manager,
@@ -296,10 +296,25 @@ def _generate(
**params,
)

llm_output["usage"] = usage_info

return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=completion))]
generations=[ChatGeneration(message=AIMessage(content=completion))],
llm_output=llm_output,
)

def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
final_usage: Dict[str, int] = defaultdict(int)
final_output = {}
for output in llm_outputs:
output = output or {}
usage = output.pop("usage", {})
for token_type, token_count in usage.items():
final_usage[token_type] += token_count
final_output.update(output)
final_output["usage"] = final_usage
return final_output

def get_num_tokens(self, text: str) -> int:
if self._model_is_anthropic:
return get_num_tokens_anthropic(text)
19 changes: 15 additions & 4 deletions libs/community/langchain_community/llms/bedrock.py
@@ -11,6 +11,7 @@
List,
Mapping,
Optional,
Tuple,
)

from langchain_core.callbacks import (
@@ -141,6 +142,7 @@ def prepare_input(

@classmethod
def prepare_output(cls, provider: str, response: Any) -> dict:
text = ""
if provider == "anthropic":
response_body = json.loads(response.get("body").read().decode())
if "completion" in response_body:
@@ -162,9 +164,17 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
else:
text = response_body.get("results")[0].get("outputText")

headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
return {
"text": text,
"body": response_body,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}

@classmethod
@@ -498,7 +508,7 @@ def _prepare_input_and_invoke(
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
) -> Tuple[str, Dict[str, Any]]:
_model_kwargs = self.model_kwargs or {}

provider = self._get_provider()
@@ -531,7 +541,7 @@ def _prepare_input_and_invoke(
try:
response = self.client.invoke_model(**request_options)

text, body = LLMInputOutputAdapter.prepare_output(
text, body, usage_info = LLMInputOutputAdapter.prepare_output(
provider, response
).values()

@@ -554,7 +564,7 @@ def _prepare_input_and_invoke(
**services_trace,
)

return text
return text, usage_info

def _get_bedrock_services_signal(self, body: dict) -> dict:
"""
@@ -824,9 +834,10 @@ def _call(
completion += chunk.text
return completion

return self._prepare_input_and_invoke(
text, _ = self._prepare_input_and_invoke(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
)
return text

async def _astream(
self,
30 changes: 27 additions & 3 deletions libs/community/tests/integration_tests/chat_models/test_bedrock.py
@@ -1,9 +1,14 @@
"""Test Bedrock chat model."""
from typing import Any
from typing import Any, cast

import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, LLMResult

from langchain_community.chat_models import BedrockChat
@@ -39,6 +44,20 @@ def test_chat_bedrock_generate(chat: BedrockChat) -> None:
assert generation.text == generation.message.content


@pytest.mark.scheduled
def test_chat_bedrock_generate_with_token_usage(chat: BedrockChat) -> None:
"""Test BedrockChat wrapper with generate."""
message = HumanMessage(content="Hello")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
assert isinstance(response.llm_output, dict)

usage = response.llm_output["usage"]
assert usage["prompt_tokens"] == 20
assert usage["completion_tokens"] > 0
assert usage["total_tokens"] > 0


@pytest.mark.scheduled
def test_chat_bedrock_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
Expand Down Expand Up @@ -80,15 +99,18 @@ def on_llm_end(
list(chat.stream("hi"))
generation = callback.saved_things["generation"]
# `Hello!` is two tokens, assert that that is what is returned
assert generation.generations[0][0].text == " Hello!"
assert generation.generations[0][0].text == "Hello!"


@pytest.mark.scheduled
def test_bedrock_streaming(chat: BedrockChat) -> None:
"""Test streaming tokens from OpenAI."""

full = None
for token in chat.stream("I'm Pickle Rick"):
full = token if full is None else full + token
assert isinstance(token.content, str)
assert isinstance(cast(AIMessageChunk, full).content, str)


@pytest.mark.scheduled
@@ -137,3 +159,5 @@ def test_bedrock_invoke(chat: BedrockChat) -> None:
"""Test invoke tokens from BedrockChat."""
result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
assert all([k in result.response_metadata for k in ("usage", "model_id")])
assert result.response_metadata["usage"]["prompt_tokens"] == 13
