diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 21754937d180f..27d4adf06817a 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -63,7 +63,11 @@
     ToolMessage,
     ToolMessageChunk,
 )
-from langchain_core.messages.ai import UsageMetadata
+from langchain_core.messages.ai import (
+    InputTokenDetails,
+    OutputTokenDetails,
+    UsageMetadata,
+)
 from langchain_core.messages.tool import tool_call_chunk
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.output_parsers.openai_tools import (
@@ -286,16 +290,10 @@ def _convert_chunk_to_generation_chunk(
 ) -> Optional[ChatGenerationChunk]:
     token_usage = chunk.get("usage")
     choices = chunk.get("choices", [])
+
     usage_metadata: Optional[UsageMetadata] = (
-        UsageMetadata(
-            input_tokens=token_usage.get("prompt_tokens", 0),
-            output_tokens=token_usage.get("completion_tokens", 0),
-            total_tokens=token_usage.get("total_tokens", 0),
-        )
-        if token_usage
-        else None
+        _create_usage_metadata(token_usage) if token_usage else None
     )
-
     if len(choices) == 0:
         # logprobs is implicitly None
         generation_chunk = ChatGenerationChunk(
@@ -721,15 +719,11 @@ def _create_chat_result(
         if response_dict.get("error"):
             raise ValueError(response_dict.get("error"))
 
-        token_usage = response_dict.get("usage", {})
+        token_usage = response_dict.get("usage")
         for res in response_dict["choices"]:
             message = _convert_dict_to_message(res["message"])
             if token_usage and isinstance(message, AIMessage):
-                message.usage_metadata = {
-                    "input_tokens": token_usage.get("prompt_tokens", 0),
-                    "output_tokens": token_usage.get("completion_tokens", 0),
-                    "total_tokens": token_usage.get("total_tokens", 0),
-                }
+                message.usage_metadata = _create_usage_metadata(token_usage)
             generation_info = generation_info or {}
             generation_info["finish_reason"] = (
                 res.get("finish_reason")
@@ -2160,3 +2154,34 @@ class OpenAIRefusalError(Exception):
 
     .. versionadded:: 0.1.21
     """
+
+
+def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata:
+    input_tokens = oai_token_usage.get("prompt_tokens", 0)
+    output_tokens = oai_token_usage.get("completion_tokens", 0)
+    total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens)
+    input_token_details: dict = {
+        "audio": oai_token_usage.get("prompt_tokens_details", {}).get("audio_tokens"),
+        "cache_read": oai_token_usage.get("prompt_tokens_details", {}).get(
+            "cached_tokens"
+        ),
+    }
+    output_token_details: dict = {
+        "audio": oai_token_usage.get("completion_tokens_details", {}).get(
+            "audio_tokens"
+        ),
+        "reasoning": oai_token_usage.get("completion_tokens_details", {}).get(
+            "reasoning_tokens"
+        ),
+    }
+    return UsageMetadata(
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        total_tokens=total_tokens,
+        input_token_details=InputTokenDetails(
+            **{k: v for k, v in input_token_details.items() if v is not None}
+        ),
+        output_token_details=OutputTokenDetails(
+            **{k: v for k, v in output_token_details.items() if v is not None}
+        ),
+    )
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py
index 22b305c9753d4..b91b590ad7cd9 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py
@@ -1,12 +1,16 @@
 """Standard LangChain interface tests"""
 
-from typing import Type
+from pathlib import Path
+from typing import List, Literal, Type, cast
 
 from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessage
 from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
 
 from langchain_openai import ChatOpenAI
 
+REPO_ROOT_DIR = Path(__file__).parents[6]
+
 
 class TestOpenAIStandard(ChatModelIntegrationTests):
     @property
@@ -20,3 +24,48 @@ def chat_model_params(self) -> dict:
     @property
     def supports_image_inputs(self) -> bool:
         return True
+
+    @property
+    def supported_usage_metadata_details(
+        self,
+    ) -> List[
+        Literal[
+            "audio_input",
+            "audio_output",
+            "reasoning_output",
+            "cache_read_input",
+            "cache_creation_input",
+        ]
+    ]:
+        return ["reasoning_output", "cache_read_input"]
+
+    def invoke_with_cache_read_input(self, *, stream: bool = False) -> AIMessage:
+        with open(REPO_ROOT_DIR / "README.md", "r") as f:
+            readme = f.read()
+
+        input_ = f"""What's langchain? Here's the langchain README:
+
+        {readme}
+        """
+        llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)
+        _invoke(llm, input_, stream)
+        # invoke twice so first invocation is cached
+        return _invoke(llm, input_, stream)
+
+    def invoke_with_reasoning_output(self, *, stream: bool = False) -> AIMessage:
+        llm = ChatOpenAI(model="o1-mini", stream_usage=True, temperature=1)
+        input_ = (
+            "explain the relationship between the 2008/9 economic crisis and the "
+            "startup ecosystem in the early 2010s"
+        )
+        return _invoke(llm, input_, stream)
+
+
+def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:
+    if stream:
+        full = None
+        for chunk in llm.stream(input_):
+            full = full + chunk if full else chunk  # type: ignore[operator]
+        return cast(AIMessage, full)
+    else:
+        return cast(AIMessage, llm.invoke(input_))
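
For reference, a rough usage sketch (not part of the change itself) of how the new _create_usage_metadata helper maps an OpenAI usage payload into UsageMetadata; the field names come from the diff above, while the payload values are made up for illustration:

# Illustrative only: exercises the _create_usage_metadata helper added in the diff.
# The usage dict below mimics the OpenAI fields the helper reads; values are invented.
from langchain_openai.chat_models.base import _create_usage_metadata

usage = {
    "prompt_tokens": 110,
    "completion_tokens": 50,
    "total_tokens": 160,
    "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 100},
    "completion_tokens_details": {"audio_tokens": 0, "reasoning_tokens": 30},
}

metadata = _create_usage_metadata(usage)
# UsageMetadata is a TypedDict, so it behaves like a plain dict at runtime.
assert metadata["input_tokens"] == 110
assert metadata["input_token_details"]["cache_read"] == 100
assert metadata["output_token_details"]["reasoning"] == 30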