Skip to content

Commit

Permalink
Fix some inconsistencies
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Jul 23, 2023
1 parent 6ec2072 commit b65d473
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 26 deletions.
4 changes: 3 additions & 1 deletion starlette/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
def md5_hexdigest(
data: bytes, *, usedforsecurity: bool = True
) -> str: # pragma: no cover
return hashlib.md5(data, usedforsecurity=usedforsecurity).hexdigest() # type: ignore[call-arg] # noqa: E501
return hashlib.md5( # type: ignore[call-arg]
data, usedforsecurity=usedforsecurity
).hexdigest()

except TypeError: # pragma: no cover

Expand Down
9 changes: 5 additions & 4 deletions starlette/_exception_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
from starlette.websockets import WebSocket

AnyExceptionHandler = typing.Callable[..., typing.Any]
ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
StatusHandlers = typing.Dict[int, AnyExceptionHandler]
StatusHandlers = typing.Dict[int, ExceptionHandler]


def _lookup_exception_handler(
exc_handlers: ExceptionHandlers, exc: Exception
) -> typing.Optional[AnyExceptionHandler]:
) -> typing.Optional[ExceptionHandler]:
for cls in type(exc).__mro__:
if cls in exc_handlers:
return exc_handlers[cls]
Expand Down Expand Up @@ -65,7 +64,9 @@ async def sender(message: Message) -> None:
if is_async_callable(handler):
response = await handler(conn, exc)
else:
response = await run_in_threadpool(handler, conn, exc)
response = await run_in_threadpool(
handler, conn, exc # type: ignore[arg-type]
)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
if is_async_callable(handler):
Expand Down
10 changes: 9 additions & 1 deletion starlette/_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import asyncio
import functools
import sys
import typing

if sys.version_info >= (3, 10): # pragma: no cover
from typing import TypeGuard
else:
from typing_extensions import TypeGuard

def is_async_callable(obj: typing.Any) -> bool:
AnyAwaitableCallable = typing.Callable[..., typing.Awaitable[typing.Any]]


def is_async_callable(obj: typing.Any) -> TypeGuard[AnyAwaitableCallable]:
while isinstance(obj, functools.partial):
obj = obj.func

Expand Down
5 changes: 1 addition & 4 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,7 @@ def __init__(
routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
exception_handlers: typing.Optional[
typing.Mapping[
typing.Any,
ExceptionHandler,
]
typing.Mapping[typing.Any, ExceptionHandler]
] = None,
on_startup: typing.Optional[
typing.Sequence[typing.Callable[[], typing.Any]]
Expand Down
14 changes: 4 additions & 10 deletions starlette/concurrency.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
import functools
import sys
import typing
import warnings

import anyio

if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec

import anyio.to_thread

T = typing.TypeVar("T")
P = ParamSpec("P")


async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] # noqa: E501
Expand All @@ -32,8 +24,10 @@ async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ign
task_group.start_soon(run, functools.partial(func, **kwargs))


# TODO: We should use `ParamSpec`, but mypy doesn't support it yet.
# Check https://github.com/python/mypy/issues/12278 for more details.
async def run_in_threadpool(
func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
) -> T:
if kwargs: # pragma: no cover
# run_sync doesn't accept 'kwargs', so bind them in here
Expand Down
2 changes: 1 addition & 1 deletion starlette/middleware/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
self._status_handlers: StatusHandlers = {}
self._exception_handlers: ExceptionHandlers = {
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception, # type: ignore[dict-item]
WebSocketException: self.websocket_exception,
}
if handlers is not None:
for key, value in handlers.items():
Expand Down
11 changes: 6 additions & 5 deletions starlette/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
]
Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]

ExceptionHandler = typing.Union[
typing.Callable[
["Request", Exception], typing.Union["Response", typing.Awaitable["Response"]]
],
typing.Callable[["WebSocket", Exception], None],
HTTPExceptionHandler = typing.Callable[
["Request", Exception], typing.Union["Response", typing.Awaitable["Response"]]
]
WebSocketExceptionHandler = typing.Callable[
["WebSocket", Exception], typing.Awaitable[None]
]
ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler]

0 comments on commit b65d473

Please sign in to comment.