diff --git a/src/openai/_base_client.py b/src/openai/_base_client.py index 3db8b6fa35..a168301f75 100644 --- a/src/openai/_base_client.py +++ b/src/openai/_base_client.py @@ -74,7 +74,12 @@ RAW_RESPONSE_HEADER, ) from ._streaming import Stream, AsyncStream -from ._exceptions import APIStatusError, APITimeoutError, APIConnectionError +from ._exceptions import ( + APIStatusError, + APITimeoutError, + APIConnectionError, + APIResponseValidationError, +) log: logging.Logger = logging.getLogger(__name__) @@ -518,13 +523,16 @@ def _process_response_data( if cast_to is UnknownResponse: return cast(ResponseT, data) - if inspect.isclass(cast_to) and issubclass(cast_to, ModelBuilderProtocol): - return cast(ResponseT, cast_to.build(response=response, data=data)) + try: + if inspect.isclass(cast_to) and issubclass(cast_to, ModelBuilderProtocol): + return cast(ResponseT, cast_to.build(response=response, data=data)) - if self._strict_response_validation: - return cast(ResponseT, validate_type(type_=cast_to, value=data)) + if self._strict_response_validation: + return cast(ResponseT, validate_type(type_=cast_to, value=data)) - return cast(ResponseT, construct_type(type_=cast_to, value=data)) + return cast(ResponseT, construct_type(type_=cast_to, value=data)) + except pydantic.ValidationError as err: + raise APIResponseValidationError(response=response, body=data) from err @property def qs(self) -> Querystring: diff --git a/src/openai/_models.py b/src/openai/_models.py index 6d5aad5963..5b8c96010f 100644 --- a/src/openai/_models.py +++ b/src/openai/_models.py @@ -263,6 +263,19 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object: return construct_type(value=value, type_=type_) +def is_basemodel(type_: type) -> bool: + """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`""" + origin = get_origin(type_) or type_ + if is_union(type_): + for variant in get_args(type_): + if is_basemodel(variant): + return True + + return False + + return issubclass(origin, BaseModel) or issubclass(origin, GenericModel) + + def construct_type(*, value: object, type_: type) -> object: """Loose coercion to the expected type with construction of nested values. diff --git a/src/openai/_response.py b/src/openai/_response.py index 3cc8fd8cc1..933c37525e 100644 --- a/src/openai/_response.py +++ b/src/openai/_response.py @@ -1,17 +1,17 @@ from __future__ import annotations import inspect +import logging import datetime import functools from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast from typing_extensions import Awaitable, ParamSpec, get_args, override, get_origin import httpx -import pydantic from ._types import NoneType, UnknownResponse, BinaryResponseContent from ._utils import is_given -from ._models import BaseModel +from ._models import BaseModel, is_basemodel from ._constants import RAW_RESPONSE_HEADER from ._exceptions import APIResponseValidationError @@ -23,6 +23,8 @@ P = ParamSpec("P") R = TypeVar("R") +log: logging.Logger = logging.getLogger(__name__) + class APIResponse(Generic[R]): _cast_to: type[R] @@ -174,6 +176,18 @@ def _parse(self) -> R: # in the response, e.g. application/json; charset=utf-8 content_type, *_ = response.headers.get("content-type").split(";") if content_type != "application/json": + if is_basemodel(cast_to): + try: + data = response.json() + except Exception as exc: + log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc) + else: + return self._client._process_response_data( + data=data, + cast_to=cast_to, # type: ignore + response=response, + ) + if self._client._strict_response_validation: raise APIResponseValidationError( response=response, @@ -188,14 +202,11 @@ def _parse(self) -> R: data = response.json() - try: - return self._client._process_response_data( - data=data, - cast_to=cast_to, # type: ignore - response=response, - ) - except pydantic.ValidationError as err: - raise APIResponseValidationError(response=response, body=data) from err + return self._client._process_response_data( + data=data, + cast_to=cast_to, # type: ignore + response=response, + ) @override def __repr__(self) -> str: diff --git a/tests/test_client.py b/tests/test_client.py index e295d193e8..c5dbfe4bfe 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -401,6 +401,27 @@ class Model2(BaseModel): assert isinstance(response, Model1) assert response.foo == 1 + @pytest.mark.respx(base_url=base_url) + def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: + """ + Response that sets Content-Type to something other than application/json but returns json data + """ + + class Model(BaseModel): + foo: int + + respx_mock.get("/foo").mock( + return_value=httpx.Response( + 200, + content=json.dumps({"foo": 2}), + headers={"Content-Type": "application/text"}, + ) + ) + + response = self.client.get("/foo", cast_to=Model) + assert isinstance(response, Model) + assert response.foo == 2 + def test_base_url_env(self) -> None: with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"): client = OpenAI(api_key=api_key, _strict_response_validation=True) @@ -939,6 +960,27 @@ class Model2(BaseModel): assert isinstance(response, Model1) assert response.foo == 1 + @pytest.mark.respx(base_url=base_url) + async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: + """ + Response that sets Content-Type to something other than application/json but returns json data + """ + + class Model(BaseModel): + foo: int + + respx_mock.get("/foo").mock( + return_value=httpx.Response( + 200, + content=json.dumps({"foo": 2}), + headers={"Content-Type": "application/text"}, + ) + ) + + response = await self.client.get("/foo", cast_to=Model) + assert isinstance(response, Model) + assert response.foo == 2 + def test_base_url_env(self) -> None: with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"): client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)