Skip to content

Commit

Permalink
Use mypy strict (#2180)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcelo Trylesinski <[email protected]>
  • Loading branch information
Viicos and Kludex authored Jul 23, 2023
1 parent 1a71441 commit 7c8ca17
Show file tree
Hide file tree
Showing 31 changed files with 261 additions and 200 deletions.
17 changes: 10 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,21 @@ select = ["E", "F", "I"]
combine-as-imports = true

[tool.mypy]
disallow_untyped_defs = true
strict = true
ignore_missing_imports = true
show_error_codes = true
python_version = "3.8"

[[tool.mypy.overrides]]
module = "starlette.testclient.*"
no_implicit_optional = false
implicit_optional = true

[[tool.mypy.overrides]]
module = "tests.*"
disallow_untyped_defs = false
check_untyped_defs = true
# TODO: Uncomment the following configuration when
# https://github.com/python/mypy/issues/10045 is solved. In the meantime,
# we are calling `mypy tests` directly. Check `scripts/check` for more info.
# [[tool.mypy.overrides]]
# module = "tests.*"
# disallow_untyped_defs = false
# check_untyped_defs = true

[tool.pytest.ini_options]
addopts = "-rxXs --strict-config --strict-markers"
Expand Down
5 changes: 4 additions & 1 deletion scripts/check
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,8 @@ set -x

./scripts/sync-version
${PREFIX}black --check --diff $SOURCE_FILES
${PREFIX}mypy $SOURCE_FILES
# TODO: Use `[[tool.mypy.overrides]]` on the `pyproject.toml` when the mypy issue is solved:
# github.com/python/mypy/issues/10045. Check github.com/encode/starlette/pull/2180 for more info.
${PREFIX}mypy starlette
${PREFIX}mypy tests --disable-error-code no-untyped-def --disable-error-code no-untyped-call
${PREFIX}ruff check $SOURCE_FILES
11 changes: 4 additions & 7 deletions starlette/_exception_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
from starlette.websockets import WebSocket

Handler = typing.Callable[..., typing.Any]
ExceptionHandlers = typing.Dict[typing.Any, Handler]
StatusHandlers = typing.Dict[int, Handler]
ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
StatusHandlers = typing.Dict[int, ExceptionHandler]


def _lookup_exception_handler(
exc_handlers: ExceptionHandlers, exc: Exception
) -> typing.Optional[Handler]:
) -> typing.Optional[ExceptionHandler]:
for cls in type(exc).__mro__:
if cls in exc_handlers:
return exc_handlers[cls]
Expand Down Expand Up @@ -61,7 +59,6 @@ async def sender(message: Message) -> None:
raise RuntimeError(msg) from exc

if scope["type"] == "http":
response: Response
if is_async_callable(handler):
response = await handler(conn, exc)
else:
Expand Down
21 changes: 20 additions & 1 deletion starlette/_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,28 @@
import asyncio
import functools
import sys
import typing

if sys.version_info >= (3, 10): # pragma: no cover
from typing import TypeGuard
else: # pragma: no cover
from typing_extensions import TypeGuard

def is_async_callable(obj: typing.Any) -> bool:
T = typing.TypeVar("T")
AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]


@typing.overload
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]:
...


@typing.overload
def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]:
...


def is_async_callable(obj: typing.Any) -> typing.Any:
while isinstance(obj, functools.partial):
obj = obj.func

Expand Down
76 changes: 35 additions & 41 deletions starlette/applications.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import typing
import warnings

Expand All @@ -9,7 +11,8 @@
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Router
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket

AppType = typing.TypeVar("AppType", bound="Starlette")

Expand Down Expand Up @@ -47,19 +50,11 @@ class Starlette:
def __init__(
self: "AppType",
debug: bool = False,
routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
exception_handlers: typing.Optional[
typing.Mapping[
typing.Any,
typing.Callable[
[Request, Exception],
typing.Union[Response, typing.Awaitable[Response]],
],
]
] = None,
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
routes: typing.Sequence[BaseRoute] | None = None,
middleware: typing.Sequence[Middleware] | None = None,
exception_handlers: typing.Mapping[typing.Any, ExceptionHandler] | None = None,
on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
lifespan: typing.Optional[Lifespan["AppType"]] = None,
) -> None:
# The lifespan context function is a newer style that replaces
Expand Down Expand Up @@ -120,18 +115,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.middleware_stack = self.build_middleware_stack()
await self.middleware_stack(scope, receive, send)

def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover
return self.router.on_event(event_type)
def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
return self.router.on_event(event_type) # pragma: nocover

def mount(
self, path: str, app: ASGIApp, name: typing.Optional[str] = None
) -> None: # pragma: nocover
self.router.mount(path, app=app, name=name)
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
self.router.mount(path, app=app, name=name) # pragma: no cover

def host(
self, host: str, app: ASGIApp, name: typing.Optional[str] = None
) -> None: # pragma: no cover
self.router.host(host, app=app, name=name)
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
self.router.host(host, app=app, name=name) # pragma: no cover

def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
if self.middleware_stack is not None: # pragma: no cover
Expand All @@ -140,20 +131,20 @@ def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:

def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
handler: typing.Callable,
exc_class_or_status_code: int | typing.Type[Exception],
handler: ExceptionHandler,
) -> None: # pragma: no cover
self.exception_handlers[exc_class_or_status_code] = handler

def add_event_handler(
self, event_type: str, func: typing.Callable
self, event_type: str, func: typing.Callable # type: ignore[type-arg]
) -> None: # pragma: no cover
self.router.add_event_handler(event_type, func)

def add_route(
self,
path: str,
route: typing.Callable,
route: typing.Callable[[Request], typing.Awaitable[Response] | Response],
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
include_in_schema: bool = True,
Expand All @@ -163,20 +154,23 @@ def add_route(
)

def add_websocket_route(
self, path: str, route: typing.Callable, name: typing.Optional[str] = None
self,
path: str,
route: typing.Callable[[WebSocket], typing.Awaitable[None]],
name: str | None = None,
) -> None: # pragma: no cover
self.router.add_websocket_route(path, route, name=name)

def exception_handler(
self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]]
) -> typing.Callable:
self, exc_class_or_status_code: int | typing.Type[Exception]
) -> typing.Callable: # type: ignore[type-arg]
warnings.warn(
"The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
"Refer to https://www.starlette.io/exceptions/ for the recommended approach.", # noqa: E501
DeprecationWarning,
)

def decorator(func: typing.Callable) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
self.add_exception_handler(exc_class_or_status_code, func)
return func

Expand All @@ -185,10 +179,10 @@ def decorator(func: typing.Callable) -> typing.Callable:
def route(
self,
path: str,
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
methods: typing.List[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
) -> typing.Callable:
) -> typing.Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
Expand All @@ -202,7 +196,7 @@ def route(
DeprecationWarning,
)

def decorator(func: typing.Callable) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
self.router.add_route(
path,
func,
Expand All @@ -215,8 +209,8 @@ def decorator(func: typing.Callable) -> typing.Callable:
return decorator

def websocket_route(
self, path: str, name: typing.Optional[str] = None
) -> typing.Callable:
self, path: str, name: str | None = None
) -> typing.Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
Expand All @@ -230,13 +224,13 @@ def websocket_route(
DeprecationWarning,
)

def decorator(func: typing.Callable) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
self.router.add_websocket_route(path, func, name=name)
return func

return decorator

def middleware(self, middleware_type: str) -> typing.Callable:
def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
Expand All @@ -253,7 +247,7 @@ def middleware(self, middleware_type: str) -> typing.Callable:
middleware_type == "http"
), 'Currently only middleware("http") is supported.'

def decorator(func: typing.Callable) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
self.add_middleware(BaseHTTPMiddleware, dispatch=func)
return func

Expand Down
30 changes: 18 additions & 12 deletions starlette/authentication.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import functools
import inspect
import sys
import typing
from urllib.parse import urlencode

if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec

from starlette._utils import is_async_callable
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse, Response
from starlette.responses import RedirectResponse
from starlette.websockets import WebSocket

_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable)
_P = ParamSpec("_P")


def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
Expand All @@ -23,10 +29,14 @@ def requires(
scopes: typing.Union[str, typing.Sequence[str]],
status_code: int = 403,
redirect: typing.Optional[str] = None,
) -> typing.Callable[[_CallableType], _CallableType]:
) -> typing.Callable[
[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]
]:
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)

def decorator(func: typing.Callable) -> typing.Callable:
def decorator(
func: typing.Callable[_P, typing.Any]
) -> typing.Callable[_P, typing.Any]:
sig = inspect.signature(func)
for idx, parameter in enumerate(sig.parameters.values()):
if parameter.name == "request" or parameter.name == "websocket":
Expand All @@ -40,9 +50,7 @@ def decorator(func: typing.Callable) -> typing.Callable:
if type_ == "websocket":
# Handle websocket functions. (Always async)
@functools.wraps(func)
async def websocket_wrapper(
*args: typing.Any, **kwargs: typing.Any
) -> None:
async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
websocket = kwargs.get(
"websocket", args[idx] if idx < len(args) else None
)
Expand All @@ -58,9 +66,7 @@ async def websocket_wrapper(
elif is_async_callable(func):
# Handle async request/response functions.
@functools.wraps(func)
async def async_wrapper(
*args: typing.Any, **kwargs: typing.Any
) -> Response:
async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
request = kwargs.get("request", args[idx] if idx < len(args) else None)
assert isinstance(request, Request)

Expand All @@ -80,7 +86,7 @@ async def async_wrapper(
else:
# Handle sync request/response functions.
@functools.wraps(func)
def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
request = kwargs.get("request", args[idx] if idx < len(args) else None)
assert isinstance(request, Request)

Expand All @@ -97,7 +103,7 @@ def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:

return sync_wrapper

return decorator # type: ignore[return-value]
return decorator


class AuthenticationError(Exception):
Expand Down
Loading

0 comments on commit 7c8ca17

Please sign in to comment.