From 2d7eb8c00816fad84b217d76c154696a6b5098a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Rubin?= Date: Mon, 18 Dec 2023 16:20:24 +0100 Subject: [PATCH] Use Protocol instead of Callable. --- starlette/applications.py | 6 +++--- starlette/middleware/__init__.py | 16 ++++++++++++---- tests/middleware/test_base.py | 6 +++--- tests/middleware/test_middleware.py | 13 ++++++++----- tests/test_applications.py | 12 ++++++------ 5 files changed, 32 insertions(+), 21 deletions(-) diff --git a/starlette/applications.py b/starlette/applications.py index c8a5665ee..3e1086d98 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -3,10 +3,10 @@ import typing import warnings -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import ParamSpec from starlette.datastructures import State, URLPath -from starlette.middleware import Middleware +from starlette.middleware import Middleware, _MiddlewareClass from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.errors import ServerErrorMiddleware from starlette.middleware.exceptions import ExceptionMiddleware @@ -129,7 +129,7 @@ def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: def add_middleware( self, - middleware_class: typing.Callable[Concatenate[ASGIApp, P], typing.Any], + middleware_class: typing.Type[_MiddlewareClass[P]], *args: P.args, **kwargs: P.kwargs, ) -> None: diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py index b77f65367..880e301eb 100644 --- a/starlette/middleware/__init__.py +++ b/starlette/middleware/__init__.py @@ -1,16 +1,24 @@ -from typing import Any, Callable, Iterator +from typing import Any, Iterator, Protocol, Type -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import ParamSpec -from starlette.types import ASGIApp +from starlette.types import ASGIApp, Receive, Scope, Send P = ParamSpec("P") +class _MiddlewareClass(Protocol[P]): + def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None: + ... # pragma: no cover + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + ... # pragma: no cover + + class Middleware: def __init__( self, - cls: Callable[Concatenate[ASGIApp, P], Any], + cls: Type[_MiddlewareClass[P]], *args: P.args, **kwargs: P.kwargs, ) -> None: diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 650f4aee1..4d51f34bf 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,13 +1,13 @@ import contextvars from contextlib import AsyncExitStack -from typing import AsyncGenerator, Awaitable, Callable, List, Union +from typing import Any, AsyncGenerator, Awaitable, Callable, List, Type, Union import anyio import pytest from starlette.applications import Starlette from starlette.background import BackgroundTask -from starlette.middleware import Middleware +from starlette.middleware import Middleware, _MiddlewareClass from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import PlainTextResponse, Response, StreamingResponse @@ -196,7 +196,7 @@ async def dispatch(self, request, call_next): ), ], ) -def test_contextvars(test_client_factory, middleware_cls: type): +def test_contextvars(test_client_factory, middleware_cls: Type[_MiddlewareClass[Any]]): # this has to be an async endpoint because Starlette calls run_in_threadpool # on sync endpoints which has it's own set of peculiarities w.r.t propagating # contextvars (it propagates them forwards but not backwards) diff --git a/tests/middleware/test_middleware.py b/tests/middleware/test_middleware.py index 61442cda1..c6cf1fa1c 100644 --- a/tests/middleware/test_middleware.py +++ b/tests/middleware/test_middleware.py @@ -1,19 +1,22 @@ from starlette.middleware import Middleware -from starlette.types import ASGIApp +from starlette.types import ASGIApp, Receive, Scope, Send -class CustomMiddleware: - def __init__(self, app: ASGIApp, foo: str, *, bar: int) -> None: # pragma: no cover +class CustomMiddleware: # pragma: no cover + def __init__(self, app: ASGIApp, foo: str, *, bar: int) -> None: self.app = app self.foo = foo self.bar = bar + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + await self.app(scope, receive, send) -def test_middleware_repr(): + +def test_middleware_repr() -> None: middleware = Middleware(CustomMiddleware, "foo", bar=123) assert repr(middleware) == "Middleware(CustomMiddleware, 'foo', bar=123)" -def test_middleware_iter(): +def test_middleware_iter() -> None: cls, args, kwargs = Middleware(CustomMiddleware, "foo", bar=123) assert (cls, args, kwargs) == (CustomMiddleware, ("foo",), {"bar": 123}) diff --git a/tests/test_applications.py b/tests/test_applications.py index e30ec9295..6d0118b53 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,6 +1,6 @@ import os from contextlib import asynccontextmanager -from typing import Any, AsyncIterator, Callable +from typing import AsyncIterator, Callable import anyio import httpx @@ -15,7 +15,7 @@ from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles -from starlette.types import ASGIApp +from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket @@ -499,8 +499,8 @@ class NoOpMiddleware: def __init__(self, app: ASGIApp): self.app = app - async def __call__(self, *args: Any): - await self.app(*args) + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + await self.app(scope, receive, send) class SimpleInitializableMiddleware: counter = 0 @@ -509,8 +509,8 @@ def __init__(self, app: ASGIApp): self.app = app SimpleInitializableMiddleware.counter += 1 - async def __call__(self, *args: Any): - await self.app(*args) + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + await self.app(scope, receive, send) def get_app() -> ASGIApp: app = Starlette()