From b65d4733cac25c3f0ef4928ee61aa0318155e8d1 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 23 Jul 2023 08:45:06 +0200 Subject: [PATCH] Fix some inconsistencies --- starlette/_compat.py | 4 +++- starlette/_exception_handler.py | 9 +++++---- starlette/_utils.py | 10 +++++++++- starlette/applications.py | 5 +---- starlette/concurrency.py | 14 ++++---------- starlette/middleware/exceptions.py | 2 +- starlette/types.py | 11 ++++++----- 7 files changed, 29 insertions(+), 26 deletions(-) diff --git a/starlette/_compat.py b/starlette/_compat.py index 2e49f9e42..9087a7645 100644 --- a/starlette/_compat.py +++ b/starlette/_compat.py @@ -18,7 +18,9 @@ def md5_hexdigest( data: bytes, *, usedforsecurity: bool = True ) -> str: # pragma: no cover - return hashlib.md5(data, usedforsecurity=usedforsecurity).hexdigest() # type: ignore[call-arg] # noqa: E501 + return hashlib.md5( # type: ignore[call-arg] + data, usedforsecurity=usedforsecurity + ).hexdigest() except TypeError: # pragma: no cover diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py index 3dab0256e..1574e53ce 100644 --- a/starlette/_exception_handler.py +++ b/starlette/_exception_handler.py @@ -8,14 +8,13 @@ from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send from starlette.websockets import WebSocket -AnyExceptionHandler = typing.Callable[..., typing.Any] ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler] -StatusHandlers = typing.Dict[int, AnyExceptionHandler] +StatusHandlers = typing.Dict[int, ExceptionHandler] def _lookup_exception_handler( exc_handlers: ExceptionHandlers, exc: Exception -) -> typing.Optional[AnyExceptionHandler]: +) -> typing.Optional[ExceptionHandler]: for cls in type(exc).__mro__: if cls in exc_handlers: return exc_handlers[cls] @@ -65,7 +64,9 @@ async def sender(message: Message) -> None: if is_async_callable(handler): response = await handler(conn, exc) else: - response = await run_in_threadpool(handler, conn, exc) + response = await run_in_threadpool( + handler, conn, exc # type: ignore[arg-type] + ) await response(scope, receive, sender) elif scope["type"] == "websocket": if is_async_callable(handler): diff --git a/starlette/_utils.py b/starlette/_utils.py index 5a6e6965b..d07312fa6 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -1,9 +1,17 @@ import asyncio import functools +import sys import typing +if sys.version_info >= (3, 10): # pragma: no cover + from typing import TypeGuard +else: + from typing_extensions import TypeGuard -def is_async_callable(obj: typing.Any) -> bool: +AnyAwaitableCallable = typing.Callable[..., typing.Awaitable[typing.Any]] + + +def is_async_callable(obj: typing.Any) -> TypeGuard[AnyAwaitableCallable]: while isinstance(obj, functools.partial): obj = obj.func diff --git a/starlette/applications.py b/starlette/applications.py index 1c315b42a..41db4ac80 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -50,10 +50,7 @@ def __init__( routes: typing.Optional[typing.Sequence[BaseRoute]] = None, middleware: typing.Optional[typing.Sequence[Middleware]] = None, exception_handlers: typing.Optional[ - typing.Mapping[ - typing.Any, - ExceptionHandler, - ] + typing.Mapping[typing.Any, ExceptionHandler] ] = None, on_startup: typing.Optional[ typing.Sequence[typing.Callable[[], typing.Any]] diff --git a/starlette/concurrency.py b/starlette/concurrency.py index 4480247a2..cf5d926da 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -1,18 +1,10 @@ import functools -import sys import typing import warnings -import anyio - -if sys.version_info >= (3, 10): # pragma: no cover - from typing import ParamSpec -else: # pragma: no cover - from typing_extensions import ParamSpec - +import anyio.to_thread T = typing.TypeVar("T") -P = ParamSpec("P") async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] # noqa: E501 @@ -32,8 +24,10 @@ async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ign task_group.start_soon(run, functools.partial(func, **kwargs)) +# TODO: We should use `ParamSpec`, but mypy doesn't support it yet. +# Check https://github.com/python/mypy/issues/12278 for more details. async def run_in_threadpool( - func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs + func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any ) -> T: if kwargs: # pragma: no cover # run_sync doesn't accept 'kwargs', so bind them in here diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index 59010c7e6..0124f5c8f 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -26,7 +26,7 @@ def __init__( self._status_handlers: StatusHandlers = {} self._exception_handlers: ExceptionHandlers = { HTTPException: self.http_exception, - WebSocketException: self.websocket_exception, # type: ignore[dict-item] + WebSocketException: self.websocket_exception, } if handlers is not None: for key, value in handlers.items(): diff --git a/starlette/types.py b/starlette/types.py index a8b3e798a..19484301e 100644 --- a/starlette/types.py +++ b/starlette/types.py @@ -21,9 +21,10 @@ ] Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]] -ExceptionHandler = typing.Union[ - typing.Callable[ - ["Request", Exception], typing.Union["Response", typing.Awaitable["Response"]] - ], - typing.Callable[["WebSocket", Exception], None], +HTTPExceptionHandler = typing.Callable[ + ["Request", Exception], typing.Union["Response", typing.Awaitable["Response"]] ] +WebSocketExceptionHandler = typing.Callable[ + ["WebSocket", Exception], typing.Awaitable[None] +] +ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler]