diff --git a/starlette/applications.py b/starlette/applications.py index 554a25e65..3e1086d98 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -3,8 +3,10 @@ import typing import warnings +from typing_extensions import ParamSpec + from starlette.datastructures import State, URLPath -from starlette.middleware import Middleware +from starlette.middleware import Middleware, _MiddlewareClass from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.errors import ServerErrorMiddleware from starlette.middleware.exceptions import ExceptionMiddleware @@ -15,6 +17,7 @@ from starlette.websockets import WebSocket AppType = typing.TypeVar("AppType", bound="Starlette") +P = ParamSpec("P") class Starlette: @@ -98,8 +101,8 @@ def build_middleware_stack(self) -> ASGIApp: ) app = self.router - for cls, options in reversed(middleware): - app = cls(app=app, **options) + for cls, args, kwargs in reversed(middleware): + app = cls(app=app, *args, **kwargs) return app @property @@ -124,10 +127,15 @@ def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: self.router.host(host, app=app, name=name) # pragma: no cover - def add_middleware(self, middleware_class: type, **options: typing.Any) -> None: + def add_middleware( + self, + middleware_class: typing.Type[_MiddlewareClass[P]], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: if self.middleware_stack is not None: # pragma: no cover raise RuntimeError("Cannot add middleware after an application has started") - self.user_middleware.insert(0, Middleware(middleware_class, **options)) + self.user_middleware.insert(0, Middleware(middleware_class, *args, **kwargs)) def add_exception_handler( self, diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py index 05bd57f04..880e301eb 100644 --- a/starlette/middleware/__init__.py +++ b/starlette/middleware/__init__.py @@ -1,17 +1,38 @@ -import typing +from typing import Any, Iterator, Protocol, Type + +from typing_extensions import ParamSpec + +from starlette.types import ASGIApp, Receive, Scope, Send + +P = ParamSpec("P") + + +class _MiddlewareClass(Protocol[P]): + def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None: + ... # pragma: no cover + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + ... # pragma: no cover class Middleware: - def __init__(self, cls: type, **options: typing.Any) -> None: + def __init__( + self, + cls: Type[_MiddlewareClass[P]], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: self.cls = cls - self.options = options + self.args = args + self.kwargs = kwargs - def __iter__(self) -> typing.Iterator[typing.Any]: - as_tuple = (self.cls, self.options) + def __iter__(self) -> Iterator[Any]: + as_tuple = (self.cls, self.args, self.kwargs) return iter(as_tuple) def __repr__(self) -> str: class_name = self.__class__.__name__ - option_strings = [f"{key}={value!r}" for key, value in self.options.items()] - args_repr = ", ".join([self.cls.__name__] + option_strings) + args_strings = [f"{value!r}" for value in self.args] + option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()] + args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings) return f"{class_name}({args_repr})" diff --git a/starlette/routing.py b/starlette/routing.py index 9a2134957..c8c854d2c 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -238,8 +238,8 @@ def __init__( self.app = endpoint if middleware is not None: - for cls, options in reversed(middleware): - self.app = cls(app=self.app, **options) + for cls, args, kwargs in reversed(middleware): + self.app = cls(app=self.app, *args, **kwargs) if methods is None: self.methods = None @@ -335,8 +335,8 @@ def __init__( self.app = endpoint if middleware is not None: - for cls, options in reversed(middleware): - self.app = cls(app=self.app, **options) + for cls, args, kwargs in reversed(middleware): + self.app = cls(app=self.app, *args, **kwargs) self.path_regex, self.path_format, self.param_convertors = compile_path(path) @@ -404,8 +404,8 @@ def __init__( self._base_app = Router(routes=routes) self.app = self._base_app if middleware is not None: - for cls, options in reversed(middleware): - self.app = cls(app=self.app, **options) + for cls, args, kwargs in reversed(middleware): + self.app = cls(app=self.app, *args, **kwargs) self.name = name self.path_regex, self.path_format, self.param_convertors = compile_path( self.path + "/{path:path}" @@ -672,8 +672,8 @@ def __init__( self.middleware_stack = self.app if middleware: - for cls, options in reversed(middleware): - self.middleware_stack = cls(self.middleware_stack, **options) + for cls, args, kwargs in reversed(middleware): + self.middleware_stack = cls(self.middleware_stack, *args, **kwargs) async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "websocket": diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 650f4aee1..4d51f34bf 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,13 +1,13 @@ import contextvars from contextlib import AsyncExitStack -from typing import AsyncGenerator, Awaitable, Callable, List, Union +from typing import Any, AsyncGenerator, Awaitable, Callable, List, Type, Union import anyio import pytest from starlette.applications import Starlette from starlette.background import BackgroundTask -from starlette.middleware import Middleware +from starlette.middleware import Middleware, _MiddlewareClass from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import PlainTextResponse, Response, StreamingResponse @@ -196,7 +196,7 @@ async def dispatch(self, request, call_next): ), ], ) -def test_contextvars(test_client_factory, middleware_cls: type): +def test_contextvars(test_client_factory, middleware_cls: Type[_MiddlewareClass[Any]]): # this has to be an async endpoint because Starlette calls run_in_threadpool # on sync endpoints which has it's own set of peculiarities w.r.t propagating # contextvars (it propagates them forwards but not backwards) diff --git a/tests/middleware/test_middleware.py b/tests/middleware/test_middleware.py index f4d7a32f0..c6cf1fa1c 100644 --- a/tests/middleware/test_middleware.py +++ b/tests/middleware/test_middleware.py @@ -1,10 +1,22 @@ from starlette.middleware import Middleware +from starlette.types import ASGIApp, Receive, Scope, Send -class CustomMiddleware: - pass +class CustomMiddleware: # pragma: no cover + def __init__(self, app: ASGIApp, foo: str, *, bar: int) -> None: + self.app = app + self.foo = foo + self.bar = bar + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + await self.app(scope, receive, send) -def test_middleware_repr(): - middleware = Middleware(CustomMiddleware) - assert repr(middleware) == "Middleware(CustomMiddleware)" + +def test_middleware_repr() -> None: + middleware = Middleware(CustomMiddleware, "foo", bar=123) + assert repr(middleware) == "Middleware(CustomMiddleware, 'foo', bar=123)" + + +def test_middleware_iter() -> None: + cls, args, kwargs = Middleware(CustomMiddleware, "foo", bar=123) + assert (cls, args, kwargs) == (CustomMiddleware, ("foo",), {"bar": 123}) diff --git a/tests/test_applications.py b/tests/test_applications.py index e30ec9295..6d0118b53 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,6 +1,6 @@ import os from contextlib import asynccontextmanager -from typing import Any, AsyncIterator, Callable +from typing import AsyncIterator, Callable import anyio import httpx @@ -15,7 +15,7 @@ from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles -from starlette.types import ASGIApp +from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket @@ -499,8 +499,8 @@ class NoOpMiddleware: def __init__(self, app: ASGIApp): self.app = app - async def __call__(self, *args: Any): - await self.app(*args) + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + await self.app(scope, receive, send) class SimpleInitializableMiddleware: counter = 0 @@ -509,8 +509,8 @@ def __init__(self, app: ASGIApp): self.app = app SimpleInitializableMiddleware.counter += 1 - async def __call__(self, *args: Any): - await self.app(*args) + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + await self.app(scope, receive, send) def get_app() -> ASGIApp: app = Starlette() diff --git a/tests/test_authentication.py b/tests/test_authentication.py index af0beafd0..150482a1b 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -15,7 +15,7 @@ from starlette.endpoints import HTTPEndpoint from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.responses import JSONResponse from starlette.routing import Route, WebSocketRoute from starlette.websockets import WebSocketDisconnect @@ -327,7 +327,7 @@ def test_authentication_redirect(test_client_factory): assert response.json() == {"authenticated": True, "user": "tomchristie"} -def on_auth_error(request: Request, exc: Exception): +def on_auth_error(request: HTTPConnection, exc: AuthenticationError): return JSONResponse({"error": str(exc)}, status_code=401)