From 7c8ca177730103bfda4b52dff97b16897956a2e2 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sun, 23 Jul 2023 23:41:50 +0200 Subject: [PATCH] Use mypy `strict` (#2180) Co-authored-by: Marcelo Trylesinski --- pyproject.toml | 17 ++++--- scripts/check | 5 +- starlette/_exception_handler.py | 11 ++--- starlette/_utils.py | 21 ++++++++- starlette/applications.py | 76 ++++++++++++++---------------- starlette/authentication.py | 30 +++++++----- starlette/concurrency.py | 18 +++---- starlette/config.py | 24 +++++----- starlette/convertors.py | 14 +++--- starlette/datastructures.py | 17 ++++--- starlette/endpoints.py | 4 +- starlette/exceptions.py | 2 +- starlette/formparsers.py | 2 +- starlette/middleware/__init__.py | 2 +- starlette/middleware/errors.py | 4 +- starlette/middleware/exceptions.py | 2 +- starlette/middleware/wsgi.py | 12 +++-- starlette/requests.py | 4 +- starlette/responses.py | 2 +- starlette/routing.py | 74 ++++++++++++++++++----------- starlette/schemas.py | 18 ++++--- starlette/staticfiles.py | 2 +- starlette/templating.py | 29 +++++------- starlette/testclient.py | 34 ++++++------- starlette/types.py | 13 +++++ starlette/websockets.py | 4 +- tests/test_convertors.py | 2 +- tests/test_formparsers.py | 7 +-- tests/test_requests.py | 4 +- tests/test_responses.py | 5 +- tests/test_routing.py | 2 +- 31 files changed, 261 insertions(+), 200 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f17ffb09d..cb876b52e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,18 +58,21 @@ select = ["E", "F", "I"] combine-as-imports = true [tool.mypy] -disallow_untyped_defs = true +strict = true ignore_missing_imports = true -show_error_codes = true +python_version = "3.8" [[tool.mypy.overrides]] module = "starlette.testclient.*" -no_implicit_optional = false +implicit_optional = true -[[tool.mypy.overrides]] -module = "tests.*" -disallow_untyped_defs = false -check_untyped_defs = true +# TODO: Uncomment the following configuration when +# https://github.com/python/mypy/issues/10045 is solved. In the meantime, +# we are calling `mypy tests` directly. Check `scripts/check` for more info. +# [[tool.mypy.overrides]] +# module = "tests.*" +# disallow_untyped_defs = false +# check_untyped_defs = true [tool.pytest.ini_options] addopts = "-rxXs --strict-config --strict-markers" diff --git a/scripts/check b/scripts/check index 076ede9eb..cc515ddaf 100755 --- a/scripts/check +++ b/scripts/check @@ -10,5 +10,8 @@ set -x ./scripts/sync-version ${PREFIX}black --check --diff $SOURCE_FILES -${PREFIX}mypy $SOURCE_FILES +# TODO: Use `[[tool.mypy.overrides]]` on the `pyproject.toml` when the mypy issue is solved: +# github.com/python/mypy/issues/10045. Check github.com/encode/starlette/pull/2180 for more info. +${PREFIX}mypy starlette +${PREFIX}mypy tests --disable-error-code no-untyped-def --disable-error-code no-untyped-call ${PREFIX}ruff check $SOURCE_FILES diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py index 8a9beb3b2..ea9ffbe9d 100644 --- a/starlette/_exception_handler.py +++ b/starlette/_exception_handler.py @@ -4,18 +4,16 @@ from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.responses import Response -from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send from starlette.websockets import WebSocket -Handler = typing.Callable[..., typing.Any] -ExceptionHandlers = typing.Dict[typing.Any, Handler] -StatusHandlers = typing.Dict[int, Handler] +ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler] +StatusHandlers = typing.Dict[int, ExceptionHandler] def _lookup_exception_handler( exc_handlers: ExceptionHandlers, exc: Exception -) -> typing.Optional[Handler]: +) -> typing.Optional[ExceptionHandler]: for cls in type(exc).__mro__: if cls in exc_handlers: return exc_handlers[cls] @@ -61,7 +59,6 @@ async def sender(message: Message) -> None: raise RuntimeError(msg) from exc if scope["type"] == "http": - response: Response if is_async_callable(handler): response = await handler(conn, exc) else: diff --git a/starlette/_utils.py b/starlette/_utils.py index 5a6e6965b..f06dd557c 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -1,9 +1,28 @@ import asyncio import functools +import sys import typing +if sys.version_info >= (3, 10): # pragma: no cover + from typing import TypeGuard +else: # pragma: no cover + from typing_extensions import TypeGuard -def is_async_callable(obj: typing.Any) -> bool: +T = typing.TypeVar("T") +AwaitableCallable = typing.Callable[..., typing.Awaitable[T]] + + +@typing.overload +def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: + ... + + +@typing.overload +def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]: + ... + + +def is_async_callable(obj: typing.Any) -> typing.Any: while isinstance(obj, functools.partial): obj = obj.func diff --git a/starlette/applications.py b/starlette/applications.py index 344a4a37f..cef4ace71 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import typing import warnings @@ -9,7 +11,8 @@ from starlette.requests import Request from starlette.responses import Response from starlette.routing import BaseRoute, Router -from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send +from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send +from starlette.websockets import WebSocket AppType = typing.TypeVar("AppType", bound="Starlette") @@ -47,19 +50,11 @@ class Starlette: def __init__( self: "AppType", debug: bool = False, - routes: typing.Optional[typing.Sequence[BaseRoute]] = None, - middleware: typing.Optional[typing.Sequence[Middleware]] = None, - exception_handlers: typing.Optional[ - typing.Mapping[ - typing.Any, - typing.Callable[ - [Request, Exception], - typing.Union[Response, typing.Awaitable[Response]], - ], - ] - ] = None, - on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None, - on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None, + routes: typing.Sequence[BaseRoute] | None = None, + middleware: typing.Sequence[Middleware] | None = None, + exception_handlers: typing.Mapping[typing.Any, ExceptionHandler] | None = None, + on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, + on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, lifespan: typing.Optional[Lifespan["AppType"]] = None, ) -> None: # The lifespan context function is a newer style that replaces @@ -120,18 +115,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: self.middleware_stack = self.build_middleware_stack() await self.middleware_stack(scope, receive, send) - def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover - return self.router.on_event(event_type) + def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg] + return self.router.on_event(event_type) # pragma: nocover - def mount( - self, path: str, app: ASGIApp, name: typing.Optional[str] = None - ) -> None: # pragma: nocover - self.router.mount(path, app=app, name=name) + def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: + self.router.mount(path, app=app, name=name) # pragma: no cover - def host( - self, host: str, app: ASGIApp, name: typing.Optional[str] = None - ) -> None: # pragma: no cover - self.router.host(host, app=app, name=name) + def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: + self.router.host(host, app=app, name=name) # pragma: no cover def add_middleware(self, middleware_class: type, **options: typing.Any) -> None: if self.middleware_stack is not None: # pragma: no cover @@ -140,20 +131,20 @@ def add_middleware(self, middleware_class: type, **options: typing.Any) -> None: def add_exception_handler( self, - exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], - handler: typing.Callable, + exc_class_or_status_code: int | typing.Type[Exception], + handler: ExceptionHandler, ) -> None: # pragma: no cover self.exception_handlers[exc_class_or_status_code] = handler def add_event_handler( - self, event_type: str, func: typing.Callable + self, event_type: str, func: typing.Callable # type: ignore[type-arg] ) -> None: # pragma: no cover self.router.add_event_handler(event_type, func) def add_route( self, path: str, - route: typing.Callable, + route: typing.Callable[[Request], typing.Awaitable[Response] | Response], methods: typing.Optional[typing.List[str]] = None, name: typing.Optional[str] = None, include_in_schema: bool = True, @@ -163,20 +154,23 @@ def add_route( ) def add_websocket_route( - self, path: str, route: typing.Callable, name: typing.Optional[str] = None + self, + path: str, + route: typing.Callable[[WebSocket], typing.Awaitable[None]], + name: str | None = None, ) -> None: # pragma: no cover self.router.add_websocket_route(path, route, name=name) def exception_handler( - self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]] - ) -> typing.Callable: + self, exc_class_or_status_code: int | typing.Type[Exception] + ) -> typing.Callable: # type: ignore[type-arg] warnings.warn( "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501 "Refer to https://www.starlette.io/exceptions/ for the recommended approach.", # noqa: E501 DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 self.add_exception_handler(exc_class_or_status_code, func) return func @@ -185,10 +179,10 @@ def decorator(func: typing.Callable) -> typing.Callable: def route( self, path: str, - methods: typing.Optional[typing.List[str]] = None, - name: typing.Optional[str] = None, + methods: typing.List[str] | None = None, + name: str | None = None, include_in_schema: bool = True, - ) -> typing.Callable: + ) -> typing.Callable: # type: ignore[type-arg] """ We no longer document this decorator style API, and its usage is discouraged. Instead you should use the following approach: @@ -202,7 +196,7 @@ def route( DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 self.router.add_route( path, func, @@ -215,8 +209,8 @@ def decorator(func: typing.Callable) -> typing.Callable: return decorator def websocket_route( - self, path: str, name: typing.Optional[str] = None - ) -> typing.Callable: + self, path: str, name: str | None = None + ) -> typing.Callable: # type: ignore[type-arg] """ We no longer document this decorator style API, and its usage is discouraged. Instead you should use the following approach: @@ -230,13 +224,13 @@ def websocket_route( DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 self.router.add_websocket_route(path, func, name=name) return func return decorator - def middleware(self, middleware_type: str) -> typing.Callable: + def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 """ We no longer document this decorator style API, and its usage is discouraged. Instead you should use the following approach: @@ -253,7 +247,7 @@ def middleware(self, middleware_type: str) -> typing.Callable: middleware_type == "http" ), 'Currently only middleware("http") is supported.' - def decorator(func: typing.Callable) -> typing.Callable: + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 self.add_middleware(BaseHTTPMiddleware, dispatch=func) return func diff --git a/starlette/authentication.py b/starlette/authentication.py index 32713eb17..494c50a57 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -1,15 +1,21 @@ import functools import inspect +import sys import typing from urllib.parse import urlencode +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + from starlette._utils import is_async_callable from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection, Request -from starlette.responses import RedirectResponse, Response +from starlette.responses import RedirectResponse from starlette.websockets import WebSocket -_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable) +_P = ParamSpec("_P") def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool: @@ -23,10 +29,14 @@ def requires( scopes: typing.Union[str, typing.Sequence[str]], status_code: int = 403, redirect: typing.Optional[str] = None, -) -> typing.Callable[[_CallableType], _CallableType]: +) -> typing.Callable[ + [typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any] +]: scopes_list = [scopes] if isinstance(scopes, str) else list(scopes) - def decorator(func: typing.Callable) -> typing.Callable: + def decorator( + func: typing.Callable[_P, typing.Any] + ) -> typing.Callable[_P, typing.Any]: sig = inspect.signature(func) for idx, parameter in enumerate(sig.parameters.values()): if parameter.name == "request" or parameter.name == "websocket": @@ -40,9 +50,7 @@ def decorator(func: typing.Callable) -> typing.Callable: if type_ == "websocket": # Handle websocket functions. (Always async) @functools.wraps(func) - async def websocket_wrapper( - *args: typing.Any, **kwargs: typing.Any - ) -> None: + async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: websocket = kwargs.get( "websocket", args[idx] if idx < len(args) else None ) @@ -58,9 +66,7 @@ async def websocket_wrapper( elif is_async_callable(func): # Handle async request/response functions. @functools.wraps(func) - async def async_wrapper( - *args: typing.Any, **kwargs: typing.Any - ) -> Response: + async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any: request = kwargs.get("request", args[idx] if idx < len(args) else None) assert isinstance(request, Request) @@ -80,7 +86,7 @@ async def async_wrapper( else: # Handle sync request/response functions. @functools.wraps(func) - def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response: + def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any: request = kwargs.get("request", args[idx] if idx < len(args) else None) assert isinstance(request, Request) @@ -97,7 +103,7 @@ def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response: return sync_wrapper - return decorator # type: ignore[return-value] + return decorator class AuthenticationError(Exception): diff --git a/starlette/concurrency.py b/starlette/concurrency.py index 5c76cb3df..ca6033c0f 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -1,21 +1,13 @@ import functools -import sys import typing import warnings -import anyio - -if sys.version_info >= (3, 10): # pragma: no cover - from typing import ParamSpec -else: # pragma: no cover - from typing_extensions import ParamSpec - +import anyio.to_thread T = typing.TypeVar("T") -P = ParamSpec("P") -async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: +async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] # noqa: E501 warnings.warn( "run_until_first_complete is deprecated " "and will be removed in a future version.", @@ -24,7 +16,7 @@ async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) - async with anyio.create_task_group() as task_group: - async def run(func: typing.Callable[[], typing.Coroutine]) -> None: + async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg] # noqa: E501 await func() task_group.cancel_scope.cancel() @@ -32,8 +24,10 @@ async def run(func: typing.Callable[[], typing.Coroutine]) -> None: task_group.start_soon(run, functools.partial(func, **kwargs)) +# TODO: We should use `ParamSpec` here, but mypy doesn't support it yet. +# Check https://github.com/python/mypy/issues/12278 for more details. async def run_in_threadpool( - func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs + func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any ) -> T: if kwargs: # pragma: no cover # run_sync doesn't accept 'kwargs', so bind them in here diff --git a/starlette/config.py b/starlette/config.py index 795232cf6..173955006 100644 --- a/starlette/config.py +++ b/starlette/config.py @@ -1,6 +1,5 @@ import os import typing -from collections.abc import MutableMapping from pathlib import Path @@ -12,16 +11,16 @@ class EnvironError(Exception): pass -class Environ(MutableMapping): - def __init__(self, environ: typing.MutableMapping = os.environ): +class Environ(typing.MutableMapping[str, str]): + def __init__(self, environ: typing.MutableMapping[str, str] = os.environ): self._environ = environ - self._has_been_read: typing.Set[typing.Any] = set() + self._has_been_read: typing.Set[str] = set() - def __getitem__(self, key: typing.Any) -> typing.Any: + def __getitem__(self, key: str) -> str: self._has_been_read.add(key) return self._environ.__getitem__(key) - def __setitem__(self, key: typing.Any, value: typing.Any) -> None: + def __setitem__(self, key: str, value: str) -> None: if key in self._has_been_read: raise EnvironError( f"Attempting to set environ['{key}'], but the value has already been " @@ -29,7 +28,7 @@ def __setitem__(self, key: typing.Any, value: typing.Any) -> None: ) self._environ.__setitem__(key, value) - def __delitem__(self, key: typing.Any) -> None: + def __delitem__(self, key: str) -> None: if key in self._has_been_read: raise EnvironError( f"Attempting to delete environ['{key}'], but the value has already " @@ -37,7 +36,7 @@ def __delitem__(self, key: typing.Any) -> None: ) self._environ.__delitem__(key) - def __iter__(self) -> typing.Iterator: + def __iter__(self) -> typing.Iterator[str]: return iter(self._environ) def __len__(self) -> int: @@ -94,7 +93,7 @@ def __call__( def __call__( self, key: str, - cast: typing.Optional[typing.Callable] = None, + cast: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None, default: typing.Any = undefined, ) -> typing.Any: return self.get(key, cast, default) @@ -102,7 +101,7 @@ def __call__( def get( self, key: str, - cast: typing.Optional[typing.Callable] = None, + cast: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None, default: typing.Any = undefined, ) -> typing.Any: key = self.env_prefix + key @@ -129,7 +128,10 @@ def _read_file(self, file_name: typing.Union[str, Path]) -> typing.Dict[str, str return file_values def _perform_cast( - self, key: str, value: typing.Any, cast: typing.Optional[typing.Callable] = None + self, + key: str, + value: typing.Any, + cast: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None, ) -> typing.Any: if cast is None or value is None: return value diff --git a/starlette/convertors.py b/starlette/convertors.py index 3ade9f7af..3b12ac7a0 100644 --- a/starlette/convertors.py +++ b/starlette/convertors.py @@ -15,7 +15,7 @@ def to_string(self, value: T) -> str: raise NotImplementedError() # pragma: no cover -class StringConvertor(Convertor): +class StringConvertor(Convertor[str]): regex = "[^/]+" def convert(self, value: str) -> str: @@ -28,7 +28,7 @@ def to_string(self, value: str) -> str: return value -class PathConvertor(Convertor): +class PathConvertor(Convertor[str]): regex = ".*" def convert(self, value: str) -> str: @@ -38,7 +38,7 @@ def to_string(self, value: str) -> str: return str(value) -class IntegerConvertor(Convertor): +class IntegerConvertor(Convertor[int]): regex = "[0-9]+" def convert(self, value: str) -> int: @@ -50,7 +50,7 @@ def to_string(self, value: int) -> str: return str(value) -class FloatConvertor(Convertor): +class FloatConvertor(Convertor[float]): regex = r"[0-9]+(\.[0-9]+)?" def convert(self, value: str) -> float: @@ -64,7 +64,7 @@ def to_string(self, value: float) -> str: return ("%0.20f" % value).rstrip("0").rstrip(".") -class UUIDConvertor(Convertor): +class UUIDConvertor(Convertor[uuid.UUID]): regex = "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}" def convert(self, value: str) -> uuid.UUID: @@ -74,7 +74,7 @@ def to_string(self, value: uuid.UUID) -> str: return str(value) -CONVERTOR_TYPES = { +CONVERTOR_TYPES: typing.Dict[str, Convertor[typing.Any]] = { "str": StringConvertor(), "path": PathConvertor(), "int": IntegerConvertor(), @@ -83,5 +83,5 @@ def to_string(self, value: uuid.UUID) -> str: } -def register_url_convertor(key: str, convertor: Convertor) -> None: +def register_url_convertor(key: str, convertor: Convertor[typing.Any]) -> None: CONVERTOR_TYPES[key] = convertor diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 236f9fa43..dc57c2e9f 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -1,5 +1,4 @@ import typing -from collections.abc import Sequence from shlex import shlex from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit @@ -223,7 +222,7 @@ def __bool__(self) -> bool: return bool(self._value) -class CommaSeparatedStrings(Sequence): +class CommaSeparatedStrings(typing.Sequence[str]): def __init__(self, value: typing.Union[str, typing.Sequence[str]]): if isinstance(value, str): splitter = shlex(value, posix=True) @@ -269,7 +268,7 @@ def __init__( if kwargs: value = ( ImmutableMultiDict(value).multi_items() - + ImmutableMultiDict(kwargs).multi_items() # type: ignore[operator] + + ImmutableMultiDict(kwargs).multi_items() ) if not value: @@ -341,12 +340,12 @@ def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any: self._list = [(k, v) for k, v in self._list if k != key] return self._dict.pop(key, default) - def popitem(self) -> typing.Tuple: + def popitem(self) -> typing.Tuple[typing.Any, typing.Any]: key, value = self._dict.popitem() self._list = [(k, v) for k, v in self._list if k != key] return key, value - def poplist(self, key: typing.Any) -> typing.List: + def poplist(self, key: typing.Any) -> typing.List[typing.Any]: values = [v for k, v in self._list if k == key] self.pop(key) return values @@ -362,7 +361,7 @@ def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any: return self[key] - def setlist(self, key: typing.Any, values: typing.List) -> None: + def setlist(self, key: typing.Any, values: typing.List[typing.Any]) -> None: if not values: self.pop(key, None) else: @@ -378,7 +377,7 @@ def update( self, *args: typing.Union[ "MultiDict", - typing.Mapping, + typing.Mapping[typing.Any, typing.Any], typing.List[typing.Tuple[typing.Any, typing.Any]], ], **kwargs: typing.Any, @@ -397,8 +396,8 @@ class QueryParams(ImmutableMultiDict[str, str]): def __init__( self, *args: typing.Union[ - "ImmutableMultiDict", - typing.Mapping, + "ImmutableMultiDict[typing.Any, typing.Any]", + typing.Mapping[typing.Any, typing.Any], typing.List[typing.Tuple[typing.Any, typing.Any]], str, bytes, diff --git a/starlette/endpoints.py b/starlette/endpoints.py index 95cd7640d..c25dd9db2 100644 --- a/starlette/endpoints.py +++ b/starlette/endpoints.py @@ -23,7 +23,7 @@ def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: if getattr(self, method.lower(), None) is not None ] - def __await__(self) -> typing.Generator: + def __await__(self) -> typing.Generator[typing.Any, None, None]: return self.dispatch().__await__() async def dispatch(self) -> None: @@ -63,7 +63,7 @@ def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: self.receive = receive self.send = send - def __await__(self) -> typing.Generator: + def __await__(self) -> typing.Generator[typing.Any, None, None]: return self.dispatch().__await__() async def dispatch(self) -> None: diff --git a/starlette/exceptions.py b/starlette/exceptions.py index cc08ed909..a583d93a0 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -10,7 +10,7 @@ def __init__( self, status_code: int, detail: typing.Optional[str] = None, - headers: typing.Optional[dict] = None, + headers: typing.Optional[typing.Dict[str, str]] = None, ) -> None: if detail is None: detail = http.HTTPStatus(status_code).phrase diff --git a/starlette/formparsers.py b/starlette/formparsers.py index eb3cba5be..5ac2bcc1b 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -142,7 +142,7 @@ def __init__( self._charset = "" self._file_parts_to_write: typing.List[typing.Tuple[MultipartPart, bytes]] = [] self._file_parts_to_finish: typing.List[MultipartPart] = [] - self._files_to_close_on_error: typing.List[SpooledTemporaryFile] = [] + self._files_to_close_on_error: typing.List[SpooledTemporaryFile[bytes]] = [] def on_part_begin(self) -> None: self._current_part = MultipartPart() diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py index 5ac5b96c8..05bd57f04 100644 --- a/starlette/middleware/__init__.py +++ b/starlette/middleware/__init__.py @@ -6,7 +6,7 @@ def __init__(self, cls: type, **options: typing.Any) -> None: self.cls = cls self.options = options - def __iter__(self) -> typing.Iterator: + def __iter__(self) -> typing.Iterator[typing.Any]: as_tuple = (self.cls, self.options) return iter(as_tuple) diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index b9d9c6910..f4c3d6746 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -137,7 +137,9 @@ class ServerErrorMiddleware: def __init__( self, app: ASGIApp, - handler: typing.Optional[typing.Callable] = None, + handler: typing.Optional[ + typing.Callable[[Request, Exception], typing.Any] + ] = None, debug: bool = False, ) -> None: self.app = app diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index 59010c7e6..0124f5c8f 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -26,7 +26,7 @@ def __init__( self._status_handlers: StatusHandlers = {} self._exception_handlers: ExceptionHandlers = { HTTPException: self.http_exception, - WebSocketException: self.websocket_exception, # type: ignore[dict-item] + WebSocketException: self.websocket_exception, } if handlers is not None: for key, value in handlers.items(): diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index d4a117cac..95578c9d2 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -16,7 +16,7 @@ ) -def build_environ(scope: Scope, body: bytes) -> dict: +def build_environ(scope: Scope, body: bytes) -> typing.Dict[str, typing.Any]: """ Builds a scope and request body into a WSGI environ object. """ @@ -63,7 +63,7 @@ def build_environ(scope: Scope, body: bytes) -> dict: class WSGIMiddleware: - def __init__(self, app: typing.Callable) -> None: + def __init__(self, app: typing.Callable[..., typing.Any]) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: @@ -76,7 +76,7 @@ class WSGIResponder: stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]] stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] - def __init__(self, app: typing.Callable, scope: Scope) -> None: + def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None: self.app = app self.scope = scope self.status = None @@ -132,7 +132,11 @@ def start_response( }, ) - def wsgi(self, environ: dict, start_response: typing.Callable) -> None: + def wsgi( + self, + environ: typing.Dict[str, typing.Any], + start_response: typing.Callable[..., typing.Any], + ) -> None: for chunk in self.app(environ, start_response): anyio.from_thread.run( self.stream_send.send, diff --git a/starlette/requests.py b/starlette/requests.py index fff451e23..5c7a4296c 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -147,7 +147,7 @@ def session(self) -> typing.Dict[str, typing.Any]: assert ( "session" in self.scope ), "SessionMiddleware must be installed to access request.session" - return self.scope["session"] + return self.scope["session"] # type: ignore[no-any-return] @property def auth(self) -> typing.Any: @@ -203,7 +203,7 @@ def __init__( @property def method(self) -> str: - return self.scope["method"] + return typing.cast(str, self.scope["method"]) @property def receive(self) -> Receive: diff --git a/starlette/responses.py b/starlette/responses.py index 575caf655..16380db06 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -37,7 +37,7 @@ def __init__( self.body = self.render(content) self.init_headers(headers) - def render(self, content: typing.Any) -> bytes: + def render(self, content: typing.Union[str, bytes, None]) -> bytes: if content is None: return b"" if isinstance(content, bytes): diff --git a/starlette/routing.py b/starlette/routing.py index b50d32a1f..9da6730ca 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException from starlette.middleware import Middleware from starlette.requests import Request -from starlette.responses import PlainTextResponse, RedirectResponse +from starlette.responses import PlainTextResponse, RedirectResponse, Response from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose @@ -54,18 +54,19 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover return inspect.iscoroutinefunction(obj) -def request_response(func: typing.Callable) -> ASGIApp: +def request_response( + func: typing.Callable[[Request], typing.Union[typing.Awaitable[Response], Response]] +) -> ASGIApp: """ Takes a function or coroutine `func(request) -> response`, and returns an ASGI application. """ - is_coroutine = is_async_callable(func) async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive, send) async def app(scope: Scope, receive: Receive, send: Send) -> None: - if is_coroutine: + if is_async_callable(func): response = await func(request) else: response = await run_in_threadpool(func, request) @@ -76,7 +77,9 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: return app -def websocket_session(func: typing.Callable) -> ASGIApp: +def websocket_session( + func: typing.Callable[[WebSocket], typing.Awaitable[None]] +) -> ASGIApp: """ Takes a coroutine `func(session)`, and returns an ASGI application. """ @@ -93,7 +96,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: return app -def get_name(endpoint: typing.Callable) -> str: +def get_name(endpoint: typing.Callable[..., typing.Any]) -> str: if inspect.isroutine(endpoint) or inspect.isclass(endpoint): return endpoint.__name__ return endpoint.__class__.__name__ @@ -101,9 +104,9 @@ def get_name(endpoint: typing.Callable) -> str: def replace_params( path: str, - param_convertors: typing.Dict[str, Convertor], + param_convertors: typing.Dict[str, Convertor[typing.Any]], path_params: typing.Dict[str, str], -) -> typing.Tuple[str, dict]: +) -> typing.Tuple[str, typing.Dict[str, str]]: for key, value in list(path_params.items()): if "{" + key + "}" in path: convertor = param_convertors[key] @@ -119,7 +122,7 @@ def replace_params( def compile_path( path: str, -) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]: +) -> typing.Tuple[typing.Pattern[str], str, typing.Dict[str, Convertor[typing.Any]]]: """ Given a path string, like: "/{username:str}", or a host string, like: "{subdomain}.mydomain.org", return a three-tuple @@ -209,7 +212,7 @@ class Route(BaseRoute): def __init__( self, path: str, - endpoint: typing.Callable, + endpoint: typing.Callable[..., typing.Any], *, methods: typing.Optional[typing.List[str]] = None, name: typing.Optional[str] = None, @@ -301,7 +304,11 @@ def __repr__(self) -> str: class WebSocketRoute(BaseRoute): def __init__( - self, path: str, endpoint: typing.Callable, *, name: typing.Optional[str] = None + self, + path: str, + endpoint: typing.Callable[..., typing.Any], + *, + name: typing.Optional[str] = None, ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" self.path = path @@ -556,12 +563,14 @@ async def __aexit__( def _wrap_gen_lifespan_context( - lifespan_context: typing.Callable[[typing.Any], typing.Generator] -) -> typing.Callable[[typing.Any], typing.AsyncContextManager]: + lifespan_context: typing.Callable[ + [typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any] + ] +) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]: cmgr = contextlib.contextmanager(lifespan_context) @functools.wraps(cmgr) - def wrapper(app: typing.Any) -> _AsyncLiftContextManager: + def wrapper(app: typing.Any) -> _AsyncLiftContextManager[typing.Any]: return _AsyncLiftContextManager(cmgr(app)) return wrapper @@ -587,8 +596,12 @@ def __init__( routes: typing.Optional[typing.Sequence[BaseRoute]] = None, redirect_slashes: bool = True, default: typing.Optional[ASGIApp] = None, - on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None, - on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None, + on_startup: typing.Optional[ + typing.Sequence[typing.Callable[[], typing.Any]] + ] = None, + on_shutdown: typing.Optional[ + typing.Sequence[typing.Callable[[], typing.Any]] + ] = None, # the generic to Lifespan[AppType] is the type of the top level application # which the router cannot know statically, so we use typing.Any lifespan: typing.Optional[Lifespan[typing.Any]] = None, @@ -614,7 +627,7 @@ def __init__( ) if lifespan is None: - self.lifespan_context: Lifespan = _DefaultLifespan(self) + self.lifespan_context: Lifespan[typing.Any] = _DefaultLifespan(self) elif inspect.isasyncgenfunction(lifespan): warnings.warn( @@ -623,7 +636,7 @@ def __init__( DeprecationWarning, ) self.lifespan_context = asynccontextmanager( - lifespan, # type: ignore[arg-type] + lifespan, ) elif inspect.isgeneratorfunction(lifespan): warnings.warn( @@ -632,7 +645,7 @@ def __init__( DeprecationWarning, ) self.lifespan_context = _wrap_gen_lifespan_context( - lifespan, # type: ignore[arg-type] + lifespan, ) else: self.lifespan_context = lifespan @@ -779,7 +792,9 @@ def host( def add_route( self, path: str, - endpoint: typing.Callable, + endpoint: typing.Callable[ + [Request], typing.Union[typing.Awaitable[Response], Response] + ], methods: typing.Optional[typing.List[str]] = None, name: typing.Optional[str] = None, include_in_schema: bool = True, @@ -794,7 +809,10 @@ def add_route( self.routes.append(route) def add_websocket_route( - self, path: str, endpoint: typing.Callable, name: typing.Optional[str] = None + self, + path: str, + endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]], + name: typing.Optional[str] = None, ) -> None: # pragma: no cover route = WebSocketRoute(path, endpoint=endpoint, name=name) self.routes.append(route) @@ -805,7 +823,7 @@ def route( methods: typing.Optional[typing.List[str]] = None, name: typing.Optional[str] = None, include_in_schema: bool = True, - ) -> typing.Callable: + ) -> typing.Callable: # type: ignore[type-arg] """ We no longer document this decorator style API, and its usage is discouraged. Instead you should use the following approach: @@ -819,7 +837,7 @@ def route( DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 self.add_route( path, func, @@ -833,7 +851,7 @@ def decorator(func: typing.Callable) -> typing.Callable: def websocket_route( self, path: str, name: typing.Optional[str] = None - ) -> typing.Callable: + ) -> typing.Callable: # type: ignore[type-arg] """ We no longer document this decorator style API, and its usage is discouraged. Instead you should use the following approach: @@ -847,14 +865,14 @@ def websocket_route( DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 self.add_websocket_route(path, func, name=name) return func return decorator def add_event_handler( - self, event_type: str, func: typing.Callable + self, event_type: str, func: typing.Callable[[], typing.Any] ) -> None: # pragma: no cover assert event_type in ("startup", "shutdown") @@ -863,14 +881,14 @@ def add_event_handler( else: self.on_shutdown.append(func) - def on_event(self, event_type: str) -> typing.Callable: + def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg] warnings.warn( "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501 "Refer to https://www.starlette.io/lifespan/ for recommended approach.", DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 self.add_event_handler(event_type, func) return func diff --git a/starlette/schemas.py b/starlette/schemas.py index 72d93e7d7..737f6b029 100644 --- a/starlette/schemas.py +++ b/starlette/schemas.py @@ -26,11 +26,13 @@ def render(self, content: typing.Any) -> bytes: class EndpointInfo(typing.NamedTuple): path: str http_method: str - func: typing.Callable + func: typing.Callable[..., typing.Any] class BaseSchemaGenerator: - def get_schema(self, routes: typing.List[BaseRoute]) -> dict: + def get_schema( + self, routes: typing.List[BaseRoute] + ) -> typing.Dict[str, typing.Any]: raise NotImplementedError() # pragma: no cover def get_endpoints( @@ -46,7 +48,7 @@ def get_endpoints( - func method ready to extract the docstring """ - endpoints_info: list = [] + endpoints_info: typing.List[EndpointInfo] = [] for route in routes: if isinstance(route, (Mount, Host)): @@ -95,7 +97,9 @@ def _remove_converter(self, path: str) -> str: """ return re.sub(r":\w+}", "}", path) - def parse_docstring(self, func_or_method: typing.Callable) -> dict: + def parse_docstring( + self, func_or_method: typing.Callable[..., typing.Any] + ) -> typing.Dict[str, typing.Any]: """ Given a function, parse the docstring as YAML and return a dictionary of info. """ @@ -126,10 +130,12 @@ def OpenAPIResponse(self, request: Request) -> Response: class SchemaGenerator(BaseSchemaGenerator): - def __init__(self, base_schema: dict) -> None: + def __init__(self, base_schema: typing.Dict[str, typing.Any]) -> None: self.base_schema = base_schema - def get_schema(self, routes: typing.List[BaseRoute]) -> dict: + def get_schema( + self, routes: typing.List[BaseRoute] + ) -> typing.Dict[str, typing.Any]: schema = dict(self.base_schema) schema.setdefault("paths", {}) endpoints_info = self.get_endpoints(routes) diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 4c856063c..2f1f1ddab 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -108,7 +108,7 @@ def get_path(self, scope: Scope) -> str: Given the ASGI scope, return the `path` string to serve up, with OS specific path separators, and any '..', '.' components removed. """ - return os.path.normpath(os.path.join(*scope["path"].split("/"))) + return os.path.normpath(os.path.join(*scope["path"].split("/"))) # type: ignore[no-any-return] # noqa: E501 async def get_response(self, path: str, scope: Scope) -> Response: """ diff --git a/starlette/templating.py b/starlette/templating.py index ffa4133b8..071e8a4bb 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -29,7 +29,7 @@ class _TemplateResponse(Response): def __init__( self, template: typing.Any, - context: dict, + context: typing.Dict[str, typing.Any], status_code: int = 200, headers: typing.Optional[typing.Mapping[str, str]] = None, media_type: typing.Optional[str] = None, @@ -66,11 +66,7 @@ class Jinja2Templates: @typing.overload def __init__( self, - directory: typing.Union[ - str, - PathLike, - typing.Sequence[typing.Union[str, PathLike]], - ], + directory: "typing.Union[str, PathLike[typing.AnyStr], typing.Sequence[typing.Union[str, PathLike[typing.AnyStr]]]]", # noqa: E501 *, context_processors: typing.Optional[ typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]] @@ -92,9 +88,7 @@ def __init__( def __init__( self, - directory: typing.Union[ - str, PathLike, typing.Sequence[typing.Union[str, PathLike]], None - ] = None, + directory: "typing.Union[str, PathLike[typing.AnyStr], typing.Sequence[typing.Union[str, PathLike[typing.AnyStr]]], None]" = None, # noqa: E501 *, context_processors: typing.Optional[ typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]] @@ -117,14 +111,17 @@ def __init__( def _create_env( self, - directory: typing.Union[ - str, PathLike, typing.Sequence[typing.Union[str, PathLike]] - ], + directory: "typing.Union[str, PathLike[typing.AnyStr], typing.Sequence[typing.Union[str, PathLike[typing.AnyStr]]]]", # noqa: E501 **env_options: typing.Any, ) -> "jinja2.Environment": @pass_context - def url_for(context: dict, name: str, /, **path_params: typing.Any) -> URL: - request = context["request"] + def url_for( + context: typing.Dict[str, typing.Any], + name: str, + /, + **path_params: typing.Any, + ) -> URL: + request: Request = context["request"] return request.url_for(name, **path_params) loader = jinja2.FileSystemLoader(directory) @@ -143,7 +140,7 @@ def TemplateResponse( self, request: Request, name: str, - context: typing.Optional[dict] = None, + context: typing.Optional[typing.Dict[str, typing.Any]] = None, status_code: int = 200, headers: typing.Optional[typing.Mapping[str, str]] = None, media_type: typing.Optional[str] = None, @@ -155,7 +152,7 @@ def TemplateResponse( def TemplateResponse( self, name: str, - context: typing.Optional[dict] = None, + context: typing.Optional[typing.Dict[str, typing.Any]] = None, status_code: int = 200, headers: typing.Optional[typing.Mapping[str, str]] = None, media_type: typing.Optional[str] = None, diff --git a/starlette/testclient.py b/starlette/testclient.py index c9ae97a08..bfac4bb9f 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -79,8 +79,8 @@ def __init__( self.scope = scope self.accepted_subprotocol = None self.portal_factory = portal_factory - self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue() - self._send_queue: "queue.Queue[typing.Any]" = queue.Queue() + self._receive_queue: "queue.Queue[Message]" = queue.Queue() + self._send_queue: "queue.Queue[Message | BaseException]" = queue.Queue() self.extra_headers = None def __enter__(self) -> "WebSocketTestSession": @@ -165,12 +165,12 @@ def receive(self) -> Message: def receive_text(self) -> str: message = self.receive() self._raise_on_close(message) - return message["text"] + return typing.cast(str, message["text"]) def receive_bytes(self) -> bytes: message = self.receive() self._raise_on_close(message) - return message["bytes"] + return typing.cast(bytes, message["bytes"]) def receive_json(self, mode: str = "text") -> typing.Any: assert mode in ["text", "binary"] @@ -374,7 +374,7 @@ def __init__( root_path: str = "", backend: str = "asyncio", backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None, - cookies: httpx._client.CookieTypes = None, + cookies: httpx._types.CookieTypes = None, headers: typing.Dict[str, str] = None, follow_redirects: bool = True, ) -> None: @@ -459,7 +459,7 @@ def request( # type: ignore[override] follow_redirects: typing.Optional[bool] = None, allow_redirects: typing.Optional[bool] = None, timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault + httpx._types.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: @@ -469,7 +469,7 @@ def request( # type: ignore[override] method, url, content=content, - data=data, # type: ignore[arg-type] + data=data, files=files, json=json, params=params, @@ -494,7 +494,7 @@ def get( # type: ignore[override] follow_redirects: typing.Optional[bool] = None, allow_redirects: typing.Optional[bool] = None, timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault + httpx._types.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: @@ -523,7 +523,7 @@ def options( # type: ignore[override] follow_redirects: typing.Optional[bool] = None, allow_redirects: typing.Optional[bool] = None, timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault + httpx._types.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: @@ -552,7 +552,7 @@ def head( # type: ignore[override] follow_redirects: typing.Optional[bool] = None, allow_redirects: typing.Optional[bool] = None, timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault + httpx._types.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: @@ -585,7 +585,7 @@ def post( # type: ignore[override] follow_redirects: typing.Optional[bool] = None, allow_redirects: typing.Optional[bool] = None, timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault + httpx._types.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: @@ -593,7 +593,7 @@ def post( # type: ignore[override] return super().post( url, content=content, - data=data, # type: ignore[arg-type] + data=data, files=files, json=json, params=params, @@ -622,7 +622,7 @@ def put( # type: ignore[override] follow_redirects: typing.Optional[bool] = None, allow_redirects: typing.Optional[bool] = None, timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault + httpx._types.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: @@ -630,7 +630,7 @@ def put( # type: ignore[override] return super().put( url, content=content, - data=data, # type: ignore[arg-type] + data=data, files=files, json=json, params=params, @@ -659,7 +659,7 @@ def patch( # type: ignore[override] follow_redirects: typing.Optional[bool] = None, allow_redirects: typing.Optional[bool] = None, timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault + httpx._types.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: @@ -667,7 +667,7 @@ def patch( # type: ignore[override] return super().patch( url, content=content, - data=data, # type: ignore[arg-type] + data=data, files=files, json=json, params=params, @@ -692,7 +692,7 @@ def delete( # type: ignore[override] follow_redirects: typing.Optional[bool] = None, allow_redirects: typing.Optional[bool] = None, timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault + httpx._types.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: diff --git a/starlette/types.py b/starlette/types.py index 713d18a80..19484301e 100644 --- a/starlette/types.py +++ b/starlette/types.py @@ -1,5 +1,10 @@ import typing +if typing.TYPE_CHECKING: + from starlette.requests import Request + from starlette.responses import Response + from starlette.websockets import WebSocket + AppType = typing.TypeVar("AppType") Scope = typing.MutableMapping[str, typing.Any] @@ -15,3 +20,11 @@ [AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]] ] Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]] + +HTTPExceptionHandler = typing.Callable[ + ["Request", Exception], typing.Union["Response", typing.Awaitable["Response"]] +] +WebSocketExceptionHandler = typing.Callable[ + ["WebSocket", Exception], typing.Awaitable[None] +] +ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler] diff --git a/starlette/websockets.py b/starlette/websockets.py index 5aa411824..4704dff72 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -111,7 +111,7 @@ async def receive_text(self) -> str: ) message = await self.receive() self._raise_on_disconnect(message) - return message["text"] + return typing.cast(str, message["text"]) async def receive_bytes(self) -> bytes: if self.application_state != WebSocketState.CONNECTED: @@ -120,7 +120,7 @@ async def receive_bytes(self) -> bytes: ) message = await self.receive() self._raise_on_disconnect(message) - return message["bytes"] + return typing.cast(bytes, message["bytes"]) async def receive_json(self, mode: str = "text") -> typing.Any: if mode not in {"text", "binary"}: diff --git a/tests/test_convertors.py b/tests/test_convertors.py index 72ca9ba12..2a866309f 100644 --- a/tests/test_convertors.py +++ b/tests/test_convertors.py @@ -15,7 +15,7 @@ def refresh_convertor_types(): convertors.CONVERTOR_TYPES = convert_types -class DateTimeConvertor(Convertor): +class DateTimeConvertor(Convertor[datetime]): regex = "[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}(.[0-9]+)?" def convert(self, value: str) -> datetime: diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 502f7809f..77ed776ea 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -5,13 +5,14 @@ import pytest from starlette.applications import Starlette -from starlette.formparsers import MultiPartException, UploadFile, _user_safe_decode +from starlette.datastructures import UploadFile +from starlette.formparsers import MultiPartException, _user_safe_decode from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Mount -class ForceMultipartDict(dict): +class ForceMultipartDict(typing.Dict[typing.Any, typing.Any]): def __bool__(self): return True @@ -43,7 +44,7 @@ async def app(scope, receive, send): async def multi_items_app(scope, receive, send): request = Request(scope, receive) data = await request.form() - output: typing.Dict[str, list] = {} + output: typing.Dict[str, typing.List[typing.Any]] = {} for key, value in data.multi_items(): if key not in output: output[key] = [] diff --git a/tests/test_requests.py b/tests/test_requests.py index a8f62b39e..caf110efe 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -4,8 +4,8 @@ import anyio import pytest -from starlette.datastructures import Address -from starlette.requests import ClientDisconnect, Request, State +from starlette.datastructures import Address, State +from starlette.requests import ClientDisconnect, Request from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.types import Message, Scope diff --git a/tests/test_responses.py b/tests/test_responses.py index 284bda1ef..7535fa641 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -1,6 +1,7 @@ import datetime as dt import os import time +import typing from http.cookies import SimpleCookie import anyio @@ -343,7 +344,9 @@ async def app(scope, receive, send): client = test_client_factory(app) response = client.get("/") - cookie: SimpleCookie = SimpleCookie(response.headers.get("set-cookie")) + cookie: "SimpleCookie[typing.Any]" = SimpleCookie( + response.headers.get("set-cookie") + ) assert cookie["mycookie"]["expires"] == "Thu, 22 Jan 2037 12:00:10 GMT" diff --git a/tests/test_routing.py b/tests/test_routing.py index 24f2bf7d7..7159a4bfc 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -895,7 +895,7 @@ def __call__(self, request): pytest.param(lambda request: ..., "", id="lambda"), ], ) -def test_route_name(endpoint: typing.Callable, expected_name: str): +def test_route_name(endpoint: typing.Callable[..., typing.Any], expected_name: str): assert Route(path="/", endpoint=endpoint).name == expected_name