From 53a8b27365c085227e449812779c5bafadada9fb Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 15 May 2024 02:19:36 +0000 Subject: [PATCH] feat(api): OpenAPI spec update via Stainless API (#39) --- .github/workflows/ci.yml | 16 +--- .stats.yml | 2 +- CONTRIBUTING.md | 2 +- api.md | 12 ++- requirements-dev.lock | 4 +- requirements.lock | 4 +- scripts/format | 2 +- scripts/lint | 4 + scripts/test | 1 - src/groq/_models.py | 20 ++++- src/groq/resources/audio/transcriptions.py | 46 ++++++---- src/groq/resources/audio/translations.py | 46 ++++++---- src/groq/types/audio/__init__.py | 2 + .../audio/transcription_create_response.py | 86 +++++++++++++++++++ .../audio/translation_create_response.py | 71 +++++++++++++++ .../audio/test_transcriptions.py | 18 ++-- .../api_resources/audio/test_translations.py | 18 ++-- tests/test_models.py | 8 +- tests/test_transform.py | 22 ++--- 19 files changed, 288 insertions(+), 96 deletions(-) create mode 100644 src/groq/types/audio/transcription_create_response.py create mode 100644 src/groq/types/audio/translation_create_response.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1e9f15a..6fcd6ae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,20 +25,10 @@ jobs: RYE_INSTALL_OPTION: '--yes' - name: Install dependencies - run: | - rye sync --all-features - - - name: Run ruff - run: | - rye run check:ruff + run: rye sync --all-features - - name: Run type checking - run: | - rye run typecheck - - - name: Ensure importable - run: | - rye run python -c 'import groq' + - name: Run lints + run: ./scripts/lint test: name: test runs-on: ubuntu-latest diff --git a/.stats.yml b/.stats.yml index 8c47545..fc5085b 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,2 +1,2 @@ configured_endpoints: 6 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/groqcloud%2Fgroqcloud-0a3089666368ff1ff668f2a73ea3b40d8b20420d8403a18579a1168dd67f2220.yml +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/groqcloud%2Fgroqcloud-c28de228634e737a173375583a09eef5e0d7fa81fcdf7090d14d194e6ef4fdc5.yml diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 143dd07..abb6cea 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -59,7 +59,7 @@ If you’d like to use the repository from source, you can either install from g To install via git: ```bash -pip install git+ssh://git@github.com/groq/groq-python.git +pip install git+ssh://git@github.com/groq/groq-python#main.git ``` Alternatively, you can build from source and install the wheel file: diff --git a/api.md b/api.md index c2d54fc..a5c9c3f 100644 --- a/api.md +++ b/api.md @@ -25,18 +25,24 @@ from groq.types import Translation Types: ```python -from groq.types.audio import Transcription +from groq.types.audio import Transcription, TranscriptionCreateResponse ``` Methods: -- client.audio.transcriptions.create(\*\*params) -> Transcription +- client.audio.transcriptions.create(\*\*params) -> TranscriptionCreateResponse ## Translations +Types: + +```python +from groq.types.audio import TranslationCreateResponse +``` + Methods: -- client.audio.translations.create(\*\*params) -> Translation +- client.audio.translations.create(\*\*params) -> TranslationCreateResponse # Models diff --git a/requirements-dev.lock b/requirements-dev.lock index bcf3065..0f4ae6c 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -59,9 +59,9 @@ pluggy==1.3.0 # via pytest py==1.11.0 # via pytest -pydantic==2.4.2 +pydantic==2.7.1 # via groq -pydantic-core==2.10.1 +pydantic-core==2.18.2 # via pydantic pyright==1.1.359 pytest==7.1.1 diff --git a/requirements.lock b/requirements.lock index 991544f..2bf6a21 100644 --- a/requirements.lock +++ b/requirements.lock @@ -29,9 +29,9 @@ httpx==0.25.2 idna==3.4 # via anyio # via httpx -pydantic==2.4.2 +pydantic==2.7.1 # via groq -pydantic-core==2.10.1 +pydantic-core==2.18.2 # via pydantic sniffio==1.3.0 # via anyio diff --git a/scripts/format b/scripts/format index 2a9ea46..667ec2d 100755 --- a/scripts/format +++ b/scripts/format @@ -4,5 +4,5 @@ set -e cd "$(dirname "$0")/.." +echo "==> Running formatters" rye run format - diff --git a/scripts/lint b/scripts/lint index 0cc68b5..a65cb43 100755 --- a/scripts/lint +++ b/scripts/lint @@ -4,5 +4,9 @@ set -e cd "$(dirname "$0")/.." +echo "==> Running lints" rye run lint +echo "==> Making sure it imports" +rye run python -c 'import groq' + diff --git a/scripts/test b/scripts/test index be01d04..b3ace90 100755 --- a/scripts/test +++ b/scripts/test @@ -52,6 +52,5 @@ else echo fi -# Run tests echo "==> Running tests" rye run pytest "$@" diff --git a/src/groq/_models.py b/src/groq/_models.py index ff3f54e..75c68cc 100644 --- a/src/groq/_models.py +++ b/src/groq/_models.py @@ -62,7 +62,7 @@ from ._constants import RAW_RESPONSE_HEADER if TYPE_CHECKING: - from pydantic_core.core_schema import ModelField, ModelFieldsSchema + from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema __all__ = ["BaseModel", "GenericModel"] @@ -251,7 +251,9 @@ def model_dump( exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, - warnings: bool = True, + warnings: bool | Literal["none", "warn", "error"] = True, + context: dict[str, Any] | None = None, + serialize_as_any: bool = False, ) -> dict[str, Any]: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump @@ -279,6 +281,10 @@ def model_dump( raise ValueError("round_trip is only supported in Pydantic v2") if warnings != True: raise ValueError("warnings is only supported in Pydantic v2") + if context is not None: + raise ValueError("context is only supported in Pydantic v2") + if serialize_as_any != False: + raise ValueError("serialize_as_any is only supported in Pydantic v2") return super().dict( # pyright: ignore[reportDeprecated] include=include, exclude=exclude, @@ -300,7 +306,9 @@ def model_dump_json( exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, - warnings: bool = True, + warnings: bool | Literal["none", "warn", "error"] = True, + context: dict[str, Any] | None = None, + serialize_as_any: bool = False, ) -> str: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json @@ -324,6 +332,10 @@ def model_dump_json( raise ValueError("round_trip is only supported in Pydantic v2") if warnings != True: raise ValueError("warnings is only supported in Pydantic v2") + if context is not None: + raise ValueError("context is only supported in Pydantic v2") + if serialize_as_any != False: + raise ValueError("serialize_as_any is only supported in Pydantic v2") return super().json( # type: ignore[reportDeprecated] indent=indent, include=include, @@ -550,7 +562,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, field_schema = field["schema"] if field_schema["type"] == "literal": - for entry in field_schema["expected"]: + for entry in cast("LiteralSchema", field_schema)["expected"]: if isinstance(entry, str): mapping[entry] = variant else: diff --git a/src/groq/resources/audio/transcriptions.py b/src/groq/resources/audio/transcriptions.py index 0ab55c5..c235030 100644 --- a/src/groq/resources/audio/transcriptions.py +++ b/src/groq/resources/audio/transcriptions.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import List, Union, Mapping, cast +from typing import Any, List, Union, Mapping, cast from typing_extensions import Literal import httpx @@ -26,7 +26,7 @@ from ..._base_client import ( make_request_options, ) -from ...types.audio.transcription import Transcription +from ...types.audio.transcription_create_response import TranscriptionCreateResponse __all__ = ["TranscriptionsResource", "AsyncTranscriptionsResource"] @@ -56,7 +56,7 @@ def create( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Transcription: + ) -> TranscriptionCreateResponse: """ Transcribes audio into the input language. @@ -115,14 +115,19 @@ def create( # sent to the server will contain a `boundary` parameter, e.g. # multipart/form-data; boundary=---abc-- extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} - return self._post( - "/openai/v1/audio/transcriptions", - body=maybe_transform(body, transcription_create_params.TranscriptionCreateParams), - files=files, - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + return cast( + TranscriptionCreateResponse, + self._post( + "/openai/v1/audio/transcriptions", + body=maybe_transform(body, transcription_create_params.TranscriptionCreateParams), + files=files, + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=cast( + Any, TranscriptionCreateResponse + ), # Union types cannot be passed in as arguments in the type system ), - cast_to=Transcription, ) @@ -151,7 +156,7 @@ async def create( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Transcription: + ) -> TranscriptionCreateResponse: """ Transcribes audio into the input language. @@ -210,14 +215,19 @@ async def create( # sent to the server will contain a `boundary` parameter, e.g. # multipart/form-data; boundary=---abc-- extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} - return await self._post( - "/openai/v1/audio/transcriptions", - body=await async_maybe_transform(body, transcription_create_params.TranscriptionCreateParams), - files=files, - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + return cast( + TranscriptionCreateResponse, + await self._post( + "/openai/v1/audio/transcriptions", + body=await async_maybe_transform(body, transcription_create_params.TranscriptionCreateParams), + files=files, + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=cast( + Any, TranscriptionCreateResponse + ), # Union types cannot be passed in as arguments in the type system ), - cast_to=Transcription, ) diff --git a/src/groq/resources/audio/translations.py b/src/groq/resources/audio/translations.py index 6267909..da09ad0 100644 --- a/src/groq/resources/audio/translations.py +++ b/src/groq/resources/audio/translations.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Union, Mapping, cast +from typing import Any, Union, Mapping, cast from typing_extensions import Literal import httpx @@ -26,7 +26,7 @@ from ..._base_client import ( make_request_options, ) -from ...types.translation import Translation +from ...types.audio.translation_create_response import TranslationCreateResponse __all__ = ["TranslationsResource", "AsyncTranslationsResource"] @@ -54,7 +54,7 @@ def create( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Translation: + ) -> TranslationCreateResponse: """ Translates audio into English. @@ -100,14 +100,19 @@ def create( # sent to the server will contain a `boundary` parameter, e.g. # multipart/form-data; boundary=---abc-- extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} - return self._post( - "/openai/v1/audio/translations", - body=maybe_transform(body, translation_create_params.TranslationCreateParams), - files=files, - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + return cast( + TranslationCreateResponse, + self._post( + "/openai/v1/audio/translations", + body=maybe_transform(body, translation_create_params.TranslationCreateParams), + files=files, + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=cast( + Any, TranslationCreateResponse + ), # Union types cannot be passed in as arguments in the type system ), - cast_to=Translation, ) @@ -134,7 +139,7 @@ async def create( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Translation: + ) -> TranslationCreateResponse: """ Translates audio into English. @@ -180,14 +185,19 @@ async def create( # sent to the server will contain a `boundary` parameter, e.g. # multipart/form-data; boundary=---abc-- extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} - return await self._post( - "/openai/v1/audio/translations", - body=await async_maybe_transform(body, translation_create_params.TranslationCreateParams), - files=files, - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + return cast( + TranslationCreateResponse, + await self._post( + "/openai/v1/audio/translations", + body=await async_maybe_transform(body, translation_create_params.TranslationCreateParams), + files=files, + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=cast( + Any, TranslationCreateResponse + ), # Union types cannot be passed in as arguments in the type system ), - cast_to=Translation, ) diff --git a/src/groq/types/audio/__init__.py b/src/groq/types/audio/__init__.py index ae3a015..29a99fe 100644 --- a/src/groq/types/audio/__init__.py +++ b/src/groq/types/audio/__init__.py @@ -5,3 +5,5 @@ from .transcription import Transcription as Transcription from .translation_create_params import TranslationCreateParams as TranslationCreateParams from .transcription_create_params import TranscriptionCreateParams as TranscriptionCreateParams +from .translation_create_response import TranslationCreateResponse as TranslationCreateResponse +from .transcription_create_response import TranscriptionCreateResponse as TranscriptionCreateResponse diff --git a/src/groq/types/audio/transcription_create_response.py b/src/groq/types/audio/transcription_create_response.py new file mode 100644 index 0000000..55a9484 --- /dev/null +++ b/src/groq/types/audio/transcription_create_response.py @@ -0,0 +1,86 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Union, Optional + +from ..._models import BaseModel +from .transcription import Transcription + +__all__ = [ + "TranscriptionCreateResponse", + "CreateTranscriptionResponseVerboseJson", + "CreateTranscriptionResponseVerboseJsonSegment", + "CreateTranscriptionResponseVerboseJsonWord", +] + + +class CreateTranscriptionResponseVerboseJsonSegment(BaseModel): + id: int + """Unique identifier of the segment.""" + + avg_logprob: float + """Average logprob of the segment. + + If the value is lower than -1, consider the logprobs failed. + """ + + compression_ratio: float + """Compression ratio of the segment. + + If the value is greater than 2.4, consider the compression failed. + """ + + end: float + """End time of the segment in seconds.""" + + no_speech_prob: float + """Probability of no speech in the segment. + + If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this + segment silent. + """ + + seek: int + """Seek offset of the segment.""" + + start: float + """Start time of the segment in seconds.""" + + temperature: float + """Temperature parameter used for generating the segment.""" + + text: str + """Text content of the segment.""" + + tokens: List[int] + """Array of token IDs for the text content.""" + + +class CreateTranscriptionResponseVerboseJsonWord(BaseModel): + end: float + """End time of the word in seconds.""" + + start: float + """Start time of the word in seconds.""" + + word: str + """The text content of the word.""" + + +class CreateTranscriptionResponseVerboseJson(BaseModel): + duration: str + """The duration of the input audio.""" + + language: str + """The language of the input audio.""" + + text: str + """The transcribed text.""" + + segments: Optional[List[CreateTranscriptionResponseVerboseJsonSegment]] = None + """Segments of the transcribed text and their corresponding details.""" + + words: Optional[List[CreateTranscriptionResponseVerboseJsonWord]] = None + """Extracted words and their corresponding timestamps.""" + + +TranscriptionCreateResponse = Union[Transcription, CreateTranscriptionResponseVerboseJson] diff --git a/src/groq/types/audio/translation_create_response.py b/src/groq/types/audio/translation_create_response.py new file mode 100644 index 0000000..3102d4d --- /dev/null +++ b/src/groq/types/audio/translation_create_response.py @@ -0,0 +1,71 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Union, Optional + +from ..._models import BaseModel +from ..translation import Translation + +__all__ = [ + "TranslationCreateResponse", + "CreateTranslationResponseVerboseJson", + "CreateTranslationResponseVerboseJsonSegment", +] + + +class CreateTranslationResponseVerboseJsonSegment(BaseModel): + id: int + """Unique identifier of the segment.""" + + avg_logprob: float + """Average logprob of the segment. + + If the value is lower than -1, consider the logprobs failed. + """ + + compression_ratio: float + """Compression ratio of the segment. + + If the value is greater than 2.4, consider the compression failed. + """ + + end: float + """End time of the segment in seconds.""" + + no_speech_prob: float + """Probability of no speech in the segment. + + If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this + segment silent. + """ + + seek: int + """Seek offset of the segment.""" + + start: float + """Start time of the segment in seconds.""" + + temperature: float + """Temperature parameter used for generating the segment.""" + + text: str + """Text content of the segment.""" + + tokens: List[int] + """Array of token IDs for the text content.""" + + +class CreateTranslationResponseVerboseJson(BaseModel): + duration: str + """The duration of the input audio.""" + + language: str + """The language of the output translation (always `english`).""" + + text: str + """The translated text.""" + + segments: Optional[List[CreateTranslationResponseVerboseJsonSegment]] = None + """Segments of the translated text and their corresponding details.""" + + +TranslationCreateResponse = Union[Translation, CreateTranslationResponseVerboseJson] diff --git a/tests/api_resources/audio/test_transcriptions.py b/tests/api_resources/audio/test_transcriptions.py index b54784b..bb1a427 100644 --- a/tests/api_resources/audio/test_transcriptions.py +++ b/tests/api_resources/audio/test_transcriptions.py @@ -9,7 +9,7 @@ from groq import Groq, AsyncGroq from tests.utils import assert_matches_type -from groq.types.audio import Transcription +from groq.types.audio import TranscriptionCreateResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -23,7 +23,7 @@ def test_method_create(self, client: Groq) -> None: file=b"raw file contents", model="whisper-large-v3", ) - assert_matches_type(Transcription, transcription, path=["response"]) + assert_matches_type(TranscriptionCreateResponse, transcription, path=["response"]) @parametrize def test_method_create_with_all_params(self, client: Groq) -> None: @@ -36,7 +36,7 @@ def test_method_create_with_all_params(self, client: Groq) -> None: temperature=0, timestamp_granularities=["word", "segment"], ) - assert_matches_type(Transcription, transcription, path=["response"]) + assert_matches_type(TranscriptionCreateResponse, transcription, path=["response"]) @parametrize def test_raw_response_create(self, client: Groq) -> None: @@ -48,7 +48,7 @@ def test_raw_response_create(self, client: Groq) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" transcription = response.parse() - assert_matches_type(Transcription, transcription, path=["response"]) + assert_matches_type(TranscriptionCreateResponse, transcription, path=["response"]) @parametrize def test_streaming_response_create(self, client: Groq) -> None: @@ -60,7 +60,7 @@ def test_streaming_response_create(self, client: Groq) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" transcription = response.parse() - assert_matches_type(Transcription, transcription, path=["response"]) + assert_matches_type(TranscriptionCreateResponse, transcription, path=["response"]) assert cast(Any, response.is_closed) is True @@ -74,7 +74,7 @@ async def test_method_create(self, async_client: AsyncGroq) -> None: file=b"raw file contents", model="whisper-large-v3", ) - assert_matches_type(Transcription, transcription, path=["response"]) + assert_matches_type(TranscriptionCreateResponse, transcription, path=["response"]) @parametrize async def test_method_create_with_all_params(self, async_client: AsyncGroq) -> None: @@ -87,7 +87,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncGroq) -> N temperature=0, timestamp_granularities=["word", "segment"], ) - assert_matches_type(Transcription, transcription, path=["response"]) + assert_matches_type(TranscriptionCreateResponse, transcription, path=["response"]) @parametrize async def test_raw_response_create(self, async_client: AsyncGroq) -> None: @@ -99,7 +99,7 @@ async def test_raw_response_create(self, async_client: AsyncGroq) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" transcription = await response.parse() - assert_matches_type(Transcription, transcription, path=["response"]) + assert_matches_type(TranscriptionCreateResponse, transcription, path=["response"]) @parametrize async def test_streaming_response_create(self, async_client: AsyncGroq) -> None: @@ -111,6 +111,6 @@ async def test_streaming_response_create(self, async_client: AsyncGroq) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" transcription = await response.parse() - assert_matches_type(Transcription, transcription, path=["response"]) + assert_matches_type(TranscriptionCreateResponse, transcription, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/audio/test_translations.py b/tests/api_resources/audio/test_translations.py index eae2a01..b6c540b 100644 --- a/tests/api_resources/audio/test_translations.py +++ b/tests/api_resources/audio/test_translations.py @@ -8,8 +8,8 @@ import pytest from groq import Groq, AsyncGroq -from groq.types import Translation from tests.utils import assert_matches_type +from groq.types.audio import TranslationCreateResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -23,7 +23,7 @@ def test_method_create(self, client: Groq) -> None: file=b"raw file contents", model="whisper-1", ) - assert_matches_type(Translation, translation, path=["response"]) + assert_matches_type(TranslationCreateResponse, translation, path=["response"]) @parametrize def test_method_create_with_all_params(self, client: Groq) -> None: @@ -34,7 +34,7 @@ def test_method_create_with_all_params(self, client: Groq) -> None: response_format="string", temperature=0, ) - assert_matches_type(Translation, translation, path=["response"]) + assert_matches_type(TranslationCreateResponse, translation, path=["response"]) @parametrize def test_raw_response_create(self, client: Groq) -> None: @@ -46,7 +46,7 @@ def test_raw_response_create(self, client: Groq) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" translation = response.parse() - assert_matches_type(Translation, translation, path=["response"]) + assert_matches_type(TranslationCreateResponse, translation, path=["response"]) @parametrize def test_streaming_response_create(self, client: Groq) -> None: @@ -58,7 +58,7 @@ def test_streaming_response_create(self, client: Groq) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" translation = response.parse() - assert_matches_type(Translation, translation, path=["response"]) + assert_matches_type(TranslationCreateResponse, translation, path=["response"]) assert cast(Any, response.is_closed) is True @@ -72,7 +72,7 @@ async def test_method_create(self, async_client: AsyncGroq) -> None: file=b"raw file contents", model="whisper-1", ) - assert_matches_type(Translation, translation, path=["response"]) + assert_matches_type(TranslationCreateResponse, translation, path=["response"]) @parametrize async def test_method_create_with_all_params(self, async_client: AsyncGroq) -> None: @@ -83,7 +83,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncGroq) -> N response_format="string", temperature=0, ) - assert_matches_type(Translation, translation, path=["response"]) + assert_matches_type(TranslationCreateResponse, translation, path=["response"]) @parametrize async def test_raw_response_create(self, async_client: AsyncGroq) -> None: @@ -95,7 +95,7 @@ async def test_raw_response_create(self, async_client: AsyncGroq) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" translation = await response.parse() - assert_matches_type(Translation, translation, path=["response"]) + assert_matches_type(TranslationCreateResponse, translation, path=["response"]) @parametrize async def test_streaming_response_create(self, async_client: AsyncGroq) -> None: @@ -107,6 +107,6 @@ async def test_streaming_response_create(self, async_client: AsyncGroq) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" translation = await response.parse() - assert_matches_type(Translation, translation, path=["response"]) + assert_matches_type(TranslationCreateResponse, translation, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/test_models.py b/tests/test_models.py index af5307e..bd5c305 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -31,7 +31,7 @@ class NestedModel(BaseModel): # mismatched types m = NestedModel.construct(nested="hello!") - assert m.nested == "hello!" + assert cast(Any, m.nested) == "hello!" def test_optional_nested_model() -> None: @@ -48,7 +48,7 @@ class NestedModel(BaseModel): # mismatched types m3 = NestedModel.construct(nested={"foo"}) assert isinstance(cast(Any, m3.nested), set) - assert m3.nested == {"foo"} + assert cast(Any, m3.nested) == {"foo"} def test_list_nested_model() -> None: @@ -323,7 +323,7 @@ class Model(BaseModel): assert len(m.items) == 2 assert isinstance(m.items[0], Submodel1) assert m.items[0].level == -1 - assert m.items[1] == 156 + assert cast(Any, m.items[1]) == 156 def test_union_of_lists() -> None: @@ -355,7 +355,7 @@ class Model(BaseModel): assert len(m.items) == 2 assert isinstance(m.items[0], SubModel1) assert m.items[0].level == -1 - assert m.items[1] == 156 + assert cast(Any, m.items[1]) == 156 def test_dict_of_union() -> None: diff --git a/tests/test_transform.py b/tests/test_transform.py index 52cdd4e..b1d6cbc 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -260,20 +260,22 @@ class MyModel(BaseModel): @parametrize @pytest.mark.asyncio async def test_pydantic_model_to_dictionary(use_async: bool) -> None: - assert await transform(MyModel(foo="hi!"), Any, use_async) == {"foo": "hi!"} - assert await transform(MyModel.construct(foo="hi!"), Any, use_async) == {"foo": "hi!"} + assert cast(Any, await transform(MyModel(foo="hi!"), Any, use_async)) == {"foo": "hi!"} + assert cast(Any, await transform(MyModel.construct(foo="hi!"), Any, use_async)) == {"foo": "hi!"} @parametrize @pytest.mark.asyncio async def test_pydantic_empty_model(use_async: bool) -> None: - assert await transform(MyModel.construct(), Any, use_async) == {} + assert cast(Any, await transform(MyModel.construct(), Any, use_async)) == {} @parametrize @pytest.mark.asyncio async def test_pydantic_unknown_field(use_async: bool) -> None: - assert await transform(MyModel.construct(my_untyped_field=True), Any, use_async) == {"my_untyped_field": True} + assert cast(Any, await transform(MyModel.construct(my_untyped_field=True), Any, use_async)) == { + "my_untyped_field": True + } @parametrize @@ -285,7 +287,7 @@ async def test_pydantic_mismatched_types(use_async: bool) -> None: params = await transform(model, Any, use_async) else: params = await transform(model, Any, use_async) - assert params == {"foo": True} + assert cast(Any, params) == {"foo": True} @parametrize @@ -297,7 +299,7 @@ async def test_pydantic_mismatched_object_type(use_async: bool) -> None: params = await transform(model, Any, use_async) else: params = await transform(model, Any, use_async) - assert params == {"foo": {"hello": "world"}} + assert cast(Any, params) == {"foo": {"hello": "world"}} class ModelNestedObjects(BaseModel): @@ -309,7 +311,7 @@ class ModelNestedObjects(BaseModel): async def test_pydantic_nested_objects(use_async: bool) -> None: model = ModelNestedObjects.construct(nested={"foo": "stainless"}) assert isinstance(model.nested, MyModel) - assert await transform(model, Any, use_async) == {"nested": {"foo": "stainless"}} + assert cast(Any, await transform(model, Any, use_async)) == {"nested": {"foo": "stainless"}} class ModelWithDefaultField(BaseModel): @@ -325,19 +327,19 @@ async def test_pydantic_default_field(use_async: bool) -> None: model = ModelWithDefaultField.construct() assert model.with_none_default is None assert model.with_str_default == "foo" - assert await transform(model, Any, use_async) == {} + assert cast(Any, await transform(model, Any, use_async)) == {} # should be included when the default value is explicitly given model = ModelWithDefaultField.construct(with_none_default=None, with_str_default="foo") assert model.with_none_default is None assert model.with_str_default == "foo" - assert await transform(model, Any, use_async) == {"with_none_default": None, "with_str_default": "foo"} + assert cast(Any, await transform(model, Any, use_async)) == {"with_none_default": None, "with_str_default": "foo"} # should be included when a non-default value is explicitly given model = ModelWithDefaultField.construct(with_none_default="bar", with_str_default="baz") assert model.with_none_default == "bar" assert model.with_str_default == "baz" - assert await transform(model, Any, use_async) == {"with_none_default": "bar", "with_str_default": "baz"} + assert cast(Any, await transform(model, Any, use_async)) == {"with_none_default": "bar", "with_str_default": "baz"} class TypedDictIterableUnion(TypedDict):