Skip to content

Commit

Permalink
chore(client): fix parsing union responses when non-json is returned (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-app[bot] authored Aug 19, 2024
1 parent 8cde9ad commit 0f2facf
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/openai/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,8 @@ def is_basemodel(type_: type) -> bool:

def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
origin = get_origin(type_) or type_
if not inspect.isclass(origin):
return False
return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)


Expand Down
22 changes: 21 additions & 1 deletion tests/test_legacy_response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import cast
from typing import Any, Union, cast
from typing_extensions import Annotated

import httpx
Expand Down Expand Up @@ -81,3 +81,23 @@ def test_response_parse_annotated_type(client: OpenAI) -> None:
)
assert obj.foo == "hello!"
assert obj.bar == 2


class OtherModel(pydantic.BaseModel):
a: str


@pytest.mark.parametrize("client", [False], indirect=True) # loose validation
def test_response_parse_expect_model_union_non_json_content(client: OpenAI) -> None:
response = LegacyAPIResponse(
raw=httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"}),
client=client,
stream=False,
stream_cls=None,
cast_to=str,
options=FinalRequestOptions.construct(method="get", url="/foo"),
)

obj = response.parse(to=cast(Any, Union[CustomModel, OtherModel]))
assert isinstance(obj, str)
assert obj == "foo"
39 changes: 38 additions & 1 deletion tests/test_response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import List, cast
from typing import Any, List, Union, cast
from typing_extensions import Annotated

import httpx
Expand Down Expand Up @@ -188,3 +188,40 @@ async def test_async_response_parse_annotated_type(async_client: AsyncOpenAI) ->
)
assert obj.foo == "hello!"
assert obj.bar == 2


class OtherModel(BaseModel):
a: str


@pytest.mark.parametrize("client", [False], indirect=True) # loose validation
def test_response_parse_expect_model_union_non_json_content(client: OpenAI) -> None:
response = APIResponse(
raw=httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"}),
client=client,
stream=False,
stream_cls=None,
cast_to=str,
options=FinalRequestOptions.construct(method="get", url="/foo"),
)

obj = response.parse(to=cast(Any, Union[CustomModel, OtherModel]))
assert isinstance(obj, str)
assert obj == "foo"


@pytest.mark.asyncio
@pytest.mark.parametrize("async_client", [False], indirect=True) # loose validation
async def test_async_response_parse_expect_model_union_non_json_content(async_client: AsyncOpenAI) -> None:
response = AsyncAPIResponse(
raw=httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"}),
client=async_client,
stream=False,
stream_cls=None,
cast_to=str,
options=FinalRequestOptions.construct(method="get", url="/foo"),
)

obj = await response.parse(to=cast(Any, Union[CustomModel, OtherModel]))
assert isinstance(obj, str)
assert obj == "foo"

0 comments on commit 0f2facf

Please sign in to comment.