Skip to content

Commit

Permalink
chore(internal): support parsing Annotated types (#377)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot authored Mar 8, 2024
1 parent 9dc7d15 commit f44efd5
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 16 deletions.
15 changes: 12 additions & 3 deletions src/anthropic/_legacy_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pydantic

from ._types import NoneType
from ._utils import is_given
from ._utils import is_given, extract_type_arg, is_annotated_type
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
Expand Down Expand Up @@ -174,6 +174,10 @@ def elapsed(self) -> datetime.timedelta:
return self.http_response.elapsed

def _parse(self, *, to: type[_T] | None = None) -> R | _T:
# unwrap `Annotated[T, ...]` -> `T`
if to and is_annotated_type(to):
to = extract_type_arg(to, 0)

if self._stream:
if to:
if not is_stream_class_type(to):
Expand Down Expand Up @@ -215,6 +219,11 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
)

cast_to = to if to is not None else self._cast_to

# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(cast_to):
cast_to = extract_type_arg(cast_to, 0)

if cast_to is NoneType:
return cast(R, None)

Expand Down Expand Up @@ -315,7 +324,7 @@ def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIRespon

@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]:
extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "true"

kwargs["extra_headers"] = extra_headers
Expand All @@ -332,7 +341,7 @@ def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P

@functools.wraps(func)
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]:
extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "true"

kwargs["extra_headers"] = extra_headers
Expand Down
14 changes: 13 additions & 1 deletion src/anthropic/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,16 @@
AnyMapping,
HttpxRequestFiles,
)
from ._utils import is_list, is_given, is_mapping, parse_date, parse_datetime, strip_not_given
from ._utils import (
is_list,
is_given,
is_mapping,
parse_date,
parse_datetime,
strip_not_given,
extract_type_arg,
is_annotated_type,
)
from ._compat import (
PYDANTIC_V2,
ConfigDict,
Expand Down Expand Up @@ -275,6 +284,9 @@ def construct_type(*, value: object, type_: type) -> object:
If the given value does not match the expected type then it is returned as-is.
"""
# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(type_):
type_ = extract_type_arg(type_, 0)

# we need to use the origin class for any types that are subscripted generics
# e.g. Dict[str, object]
Expand Down
27 changes: 18 additions & 9 deletions src/anthropic/_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import pydantic

from ._types import NoneType
from ._utils import is_given, extract_type_var_from_base
from ._utils import is_given, extract_type_arg, is_annotated_type, extract_type_var_from_base
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
Expand Down Expand Up @@ -121,6 +121,10 @@ def __repr__(self) -> str:
)

def _parse(self, *, to: type[_T] | None = None) -> R | _T:
# unwrap `Annotated[T, ...]` -> `T`
if to and is_annotated_type(to):
to = extract_type_arg(to, 0)

if self._is_sse_stream:
if to:
if not is_stream_class_type(to):
Expand Down Expand Up @@ -162,6 +166,11 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
)

cast_to = to if to is not None else self._cast_to

# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(cast_to):
cast_to = extract_type_arg(cast_to, 0)

if cast_to is NoneType:
return cast(R, None)

Expand Down Expand Up @@ -634,7 +643,7 @@ def to_streamed_response_wrapper(func: Callable[P, R]) -> Callable[P, ResponseCo

@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[APIResponse[R]]:
extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "stream"

kwargs["extra_headers"] = extra_headers
Expand All @@ -655,7 +664,7 @@ def async_to_streamed_response_wrapper(

@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[AsyncAPIResponse[R]]:
extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "stream"

kwargs["extra_headers"] = extra_headers
Expand All @@ -679,7 +688,7 @@ def to_custom_streamed_response_wrapper(

@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[_APIResponseT]:
extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "stream"
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls

Expand All @@ -704,7 +713,7 @@ def async_to_custom_streamed_response_wrapper(

@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[_AsyncAPIResponseT]:
extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "stream"
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls

Expand All @@ -724,7 +733,7 @@ def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]

@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]:
extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "raw"

kwargs["extra_headers"] = extra_headers
Expand All @@ -741,7 +750,7 @@ def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P

@functools.wraps(func)
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncAPIResponse[R]:
extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "raw"

kwargs["extra_headers"] = extra_headers
Expand All @@ -763,7 +772,7 @@ def to_custom_raw_response_wrapper(

@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT:
extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "raw"
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls

Expand All @@ -786,7 +795,7 @@ def async_to_custom_raw_response_wrapper(

@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]:
extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "raw"
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls

Expand Down
19 changes: 19 additions & 0 deletions tests/test_legacy_response.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
from typing import cast
from typing_extensions import Annotated

import httpx
import pytest
Expand Down Expand Up @@ -63,3 +65,20 @@ def test_response_parse_custom_model(client: Anthropic) -> None:
obj = response.parse(to=CustomModel)
assert obj.foo == "hello!"
assert obj.bar == 2


def test_response_parse_annotated_type(client: Anthropic) -> None:
response = LegacyAPIResponse(
raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
client=client,
stream=False,
stream_cls=None,
cast_to=str,
options=FinalRequestOptions.construct(method="get", url="/foo"),
)

obj = response.parse(
to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
)
assert obj.foo == "hello!"
assert obj.bar == 2
16 changes: 14 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import json
from typing import Any, Dict, List, Union, Optional, cast
from datetime import datetime, timezone
from typing_extensions import Literal
from typing_extensions import Literal, Annotated

import pytest
import pydantic
from pydantic import Field

from anthropic._compat import PYDANTIC_V2, parse_obj, model_dump, model_json
from anthropic._models import BaseModel
from anthropic._models import BaseModel, construct_type


class BasicModel(BaseModel):
Expand Down Expand Up @@ -571,3 +571,15 @@ class OurModel(BaseModel):
foo: Optional[str] = None

takes_pydantic(OurModel())


def test_annotated_types() -> None:
class Model(BaseModel):
value: str

m = construct_type(
value={"value": "foo"},
type_=cast(Any, Annotated[Model, "random metadata"]),
)
assert isinstance(m, Model)
assert m.value == "foo"
37 changes: 36 additions & 1 deletion tests/test_response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from typing import List
from typing import List, cast
from typing_extensions import Annotated

import httpx
import pytest
Expand Down Expand Up @@ -157,3 +158,37 @@ async def test_async_response_parse_custom_model(async_client: AsyncAnthropic) -
obj = await response.parse(to=CustomModel)
assert obj.foo == "hello!"
assert obj.bar == 2


def test_response_parse_annotated_type(client: Anthropic) -> None:
response = APIResponse(
raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
client=client,
stream=False,
stream_cls=None,
cast_to=str,
options=FinalRequestOptions.construct(method="get", url="/foo"),
)

obj = response.parse(
to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
)
assert obj.foo == "hello!"
assert obj.bar == 2


async def test_async_response_parse_annotated_type(async_client: AsyncAnthropic) -> None:
response = AsyncAPIResponse(
raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
client=async_client,
stream=False,
stream_cls=None,
cast_to=str,
options=FinalRequestOptions.construct(method="get", url="/foo"),
)

obj = await response.parse(
to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
)
assert obj.foo == "hello!"
assert obj.bar == 2
6 changes: 6 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
is_list,
is_list_type,
is_union_type,
extract_type_arg,
is_annotated_type,
)
from anthropic._compat import PYDANTIC_V2, field_outer_type, get_model_fields
from anthropic._models import BaseModel
Expand Down Expand Up @@ -49,6 +51,10 @@ def assert_matches_type(
path: list[str],
allow_none: bool = False,
) -> None:
# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(type_):
type_ = extract_type_arg(type_, 0)

if allow_none and value is None:
return

Expand Down

0 comments on commit f44efd5

Please sign in to comment.