diff --git a/pyproject.toml b/pyproject.toml index 71a02dc344..8200e84a7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,12 +8,12 @@ authors = [ { name = "OpenAI", email = "support@openai.com" }, ] dependencies = [ - "httpx>=0.23.0, <1", - "pydantic>=1.9.0, <3", - "typing-extensions>=4.5, <5", - "anyio>=3.5.0, <4", - "distro>=1.7.0, <2", - "tqdm > 4" + "httpx>=0.23.0, <1", + "pydantic>=1.9.0, <3", + "typing-extensions>=4.5, <5", + "anyio>=3.5.0, <4", + "distro>=1.7.0, <2", + "tqdm > 4" ] requires-python = ">= 3.7.1" @@ -30,17 +30,17 @@ openai = "openai.cli:main" [tool.rye] managed = true dev-dependencies = [ - "pyright==1.1.326", - "mypy==1.4.1", - "black==23.3.0", - "respx==0.19.2", - "pytest==7.1.1", - "pytest-asyncio==0.21.1", - "ruff==0.0.282", - "isort==5.10.1", - "time-machine==2.9.0", - "nox==2023.4.22", - "types-tqdm > 4" + "pyright==1.1.332", + "mypy==1.4.1", + "black==23.3.0", + "respx==0.19.2", + "pytest==7.1.1", + "pytest-asyncio==0.21.1", + "ruff==0.0.282", + "isort==5.10.1", + "time-machine==2.9.0", + "nox==2023.4.22", + "types-tqdm > 4" ] [tool.rye.scripts] diff --git a/src/openai/_base_client.py b/src/openai/_base_client.py index 962ee0e2d7..c7fb0889b2 100644 --- a/src/openai/_base_client.py +++ b/src/openai/_base_client.py @@ -315,8 +315,11 @@ async def get_next_page(self: AsyncPageT) -> AsyncPageT: return await self._client._request_api_list(self._model, page=self.__class__, options=options) -class BaseClient: - _client: httpx.Client | httpx.AsyncClient +_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient]) + + +class BaseClient(Generic[_HttpxClientT]): + _client: _HttpxClientT _version: str _base_url: URL max_retries: int @@ -730,7 +733,7 @@ def _idempotency_key(self) -> str: return f"stainless-python-retry-{uuid.uuid4()}" -class SyncAPIClient(BaseClient): +class SyncAPIClient(BaseClient[httpx.Client]): _client: httpx.Client _has_custom_http_client: bool _default_stream_cls: type[Stream[Any]] | None = None @@ -1136,7 +1139,7 @@ def get_api_list( return self._request_api_list(model, page, opts) -class AsyncAPIClient(BaseClient): +class AsyncAPIClient(BaseClient[httpx.AsyncClient]): _client: httpx.AsyncClient _has_custom_http_client: bool _default_stream_cls: type[AsyncStream[Any]] | None = None diff --git a/src/openai/_exceptions.py b/src/openai/_exceptions.py index 79ddd6607d..b79ac5fd64 100644 --- a/src/openai/_exceptions.py +++ b/src/openai/_exceptions.py @@ -92,31 +92,31 @@ def __init__(self, request: httpx.Request) -> None: class BadRequestError(APIStatusError): - status_code: Literal[400] = 400 + status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride] class AuthenticationError(APIStatusError): - status_code: Literal[401] = 401 + status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride] class PermissionDeniedError(APIStatusError): - status_code: Literal[403] = 403 + status_code: Literal[403] = 403 # pyright: ignore[reportIncompatibleVariableOverride] class NotFoundError(APIStatusError): - status_code: Literal[404] = 404 + status_code: Literal[404] = 404 # pyright: ignore[reportIncompatibleVariableOverride] class ConflictError(APIStatusError): - status_code: Literal[409] = 409 + status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride] class UnprocessableEntityError(APIStatusError): - status_code: Literal[422] = 422 + status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride] class RateLimitError(APIStatusError): - status_code: Literal[429] = 429 + status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride] class InternalServerError(APIStatusError): diff --git a/tests/test_client.py b/tests/test_client.py index 122845fe3e..553b54e5db 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -29,7 +29,7 @@ api_key = "My API Key" -def _get_params(client: BaseClient) -> dict[str, str]: +def _get_params(client: BaseClient[Any]) -> dict[str, str]: request = client._build_request(FinalRequestOptions(method="get", url="/foo")) url = httpx.URL(request.url) return dict(url.params)