openai[patch]: add usage metadata details (#27080)
baskaryan authored Oct 3, 2024
1 parent 546dc44 commit c09da53
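
This patch routes OpenAI's token-usage payload through a new _create_usage_metadata helper so that AIMessage.usage_metadata also carries input_token_details (audio, cache_read) and output_token_details (audio, reasoning), for both invoke and streaming. A minimal sketch of the resulting behavior, assuming an OpenAI API key is configured; the token counts below are illustrative:

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")
msg = llm.invoke("What is LangChain?")
print(msg.usage_metadata)
# Illustrative output with this change applied:
# {'input_tokens': 13, 'output_tokens': 52, 'total_tokens': 65,
#  'input_token_details': {'audio': 0, 'cache_read': 0},
#  'output_token_details': {'audio': 0, 'reasoning': 0}}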
Showing 2 changed files with 90 additions and 16 deletions.
55 changes: 40 additions & 15 deletions libs/partners/openai/langchain_openai/chat_models/base.py
@@ -63,7 +63,11 @@
ToolMessage,
ToolMessageChunk,
)
-from langchain_core.messages.ai import UsageMetadata
+from langchain_core.messages.ai import (
+    InputTokenDetails,
+    OutputTokenDetails,
+    UsageMetadata,
+)
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.openai_tools import (
@@ -286,16 +290,10 @@ def _convert_chunk_to_generation_chunk(
) -> Optional[ChatGenerationChunk]:
    token_usage = chunk.get("usage")
    choices = chunk.get("choices", [])

    usage_metadata: Optional[UsageMetadata] = (
-        UsageMetadata(
-            input_tokens=token_usage.get("prompt_tokens", 0),
-            output_tokens=token_usage.get("completion_tokens", 0),
-            total_tokens=token_usage.get("total_tokens", 0),
-        )
-        if token_usage
-        else None
+        _create_usage_metadata(token_usage) if token_usage else None
    )

    if len(choices) == 0:
        # logprobs is implicitly None
        generation_chunk = ChatGenerationChunk(
@@ -721,15 +719,11 @@ def _create_chat_result(
        if response_dict.get("error"):
            raise ValueError(response_dict.get("error"))

-        token_usage = response_dict.get("usage", {})
+        token_usage = response_dict.get("usage")
        for res in response_dict["choices"]:
            message = _convert_dict_to_message(res["message"])
            if token_usage and isinstance(message, AIMessage):
-                message.usage_metadata = {
-                    "input_tokens": token_usage.get("prompt_tokens", 0),
-                    "output_tokens": token_usage.get("completion_tokens", 0),
-                    "total_tokens": token_usage.get("total_tokens", 0),
-                }
+                message.usage_metadata = _create_usage_metadata(token_usage)
            generation_info = generation_info or {}
            generation_info["finish_reason"] = (
                res.get("finish_reason")
@@ -2160,3 +2154,34 @@ class OpenAIRefusalError(Exception):
.. versionadded:: 0.1.21
"""


def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata:
    input_tokens = oai_token_usage.get("prompt_tokens", 0)
    output_tokens = oai_token_usage.get("completion_tokens", 0)
    total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens)
    input_token_details: dict = {
        "audio": oai_token_usage.get("prompt_tokens_details", {}).get("audio_tokens"),
        "cache_read": oai_token_usage.get("prompt_tokens_details", {}).get(
            "cached_tokens"
        ),
    }
    output_token_details: dict = {
        "audio": oai_token_usage.get("completion_tokens_details", {}).get(
            "audio_tokens"
        ),
        "reasoning": oai_token_usage.get("completion_tokens_details", {}).get(
            "reasoning_tokens"
        ),
    }
    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        input_token_details=InputTokenDetails(
            **{k: v for k, v in input_token_details.items() if v is not None}
        ),
        output_token_details=OutputTokenDetails(
            **{k: v for k, v in output_token_details.items() if v is not None}
        ),
    )
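
For reference, a rough sketch of the mapping the helper above performs on a raw OpenAI usage payload (field values are illustrative; detail fields that come back as None are simply dropped):

usage = {
    "prompt_tokens": 2152,
    "completion_tokens": 89,
    "total_tokens": 2241,
    "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 1920},
    "completion_tokens_details": {"audio_tokens": 0, "reasoning_tokens": 64},
}
_create_usage_metadata(usage)
# {'input_tokens': 2152, 'output_tokens': 89, 'total_tokens': 2241,
#  'input_token_details': {'audio': 0, 'cache_read': 1920},
#  'output_token_details': {'audio': 0, 'reasoning': 64}}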
Second changed file (standard interface integration tests):
@@ -1,12 +1,16 @@
"""Standard LangChain interface tests"""

-from typing import Type
+from pathlib import Path
+from typing import List, Literal, Type, cast

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

from langchain_openai import ChatOpenAI

REPO_ROOT_DIR = Path(__file__).parents[6]


class TestOpenAIStandard(ChatModelIntegrationTests):
    @property
@@ -20,3 +24,48 @@ def chat_model_params(self) -> dict:
    @property
    def supports_image_inputs(self) -> bool:
        return True

    @property
    def supported_usage_metadata_details(
        self,
    ) -> List[
        Literal[
            "audio_input",
            "audio_output",
            "reasoning_output",
            "cache_read_input",
            "cache_creation_input",
        ]
    ]:
        return ["reasoning_output", "cache_read_input"]

    def invoke_with_cache_read_input(self, *, stream: bool = False) -> AIMessage:
        with open(REPO_ROOT_DIR / "README.md", "r") as f:
            readme = f.read()

        input_ = f"""What's langchain? Here's the langchain README:
{readme}
"""
        llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)
        _invoke(llm, input_, stream)
        # invoke twice so the second call can read from the prompt cache
        return _invoke(llm, input_, stream)

    def invoke_with_reasoning_output(self, *, stream: bool = False) -> AIMessage:
        llm = ChatOpenAI(model="o1-mini", stream_usage=True, temperature=1)
        input_ = (
            "explain the relationship between the 2008/9 economic crisis and the "
            "startup ecosystem in the early 2010s"
        )
        return _invoke(llm, input_, stream)


def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:
    if stream:
        full = None
        for chunk in llm.stream(input_):
            full = full + chunk if full else chunk  # type: ignore[operator]
        return cast(AIMessage, full)
    else:
        return cast(AIMessage, llm.invoke(input_))
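
Note that the chunk accumulation in _invoke relies on AIMessageChunk addition merging usage_metadata across chunks, so with stream_usage=True a streamed run should surface the same totals and token details that a plain invoke call would report.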
