Skip to content

Commit

Permalink
Add *args to Middleware and improve its type hints (#2381)
Browse files Browse the repository at this point in the history
Co-authored-by: Paweł Rubin <[email protected]>
  • Loading branch information
pawelrubin and Paweł Rubin authored Dec 20, 2023
1 parent 23c81da commit 866a15f
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 36 deletions.
18 changes: 13 additions & 5 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import typing
import warnings

from typing_extensions import ParamSpec

from starlette.datastructures import State, URLPath
from starlette.middleware import Middleware
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
Expand All @@ -15,6 +17,7 @@
from starlette.websockets import WebSocket

AppType = typing.TypeVar("AppType", bound="Starlette")
P = ParamSpec("P")


class Starlette:
Expand Down Expand Up @@ -98,8 +101,8 @@ def build_middleware_stack(self) -> ASGIApp:
)

app = self.router
for cls, options in reversed(middleware):
app = cls(app=app, **options)
for cls, args, kwargs in reversed(middleware):
app = cls(app=app, *args, **kwargs)
return app

@property
Expand All @@ -124,10 +127,15 @@ def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
self.router.host(host, app=app, name=name) # pragma: no cover

def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
def add_middleware(
self,
middleware_class: typing.Type[_MiddlewareClass[P]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
if self.middleware_stack is not None: # pragma: no cover
raise RuntimeError("Cannot add middleware after an application has started")
self.user_middleware.insert(0, Middleware(middleware_class, **options))
self.user_middleware.insert(0, Middleware(middleware_class, *args, **kwargs))

def add_exception_handler(
self,
Expand Down
35 changes: 28 additions & 7 deletions starlette/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,38 @@
import typing
from typing import Any, Iterator, Protocol, Type

from typing_extensions import ParamSpec

from starlette.types import ASGIApp, Receive, Scope, Send

P = ParamSpec("P")


class _MiddlewareClass(Protocol[P]):
def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None:
... # pragma: no cover

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
... # pragma: no cover


class Middleware:
def __init__(self, cls: type, **options: typing.Any) -> None:
def __init__(
self,
cls: Type[_MiddlewareClass[P]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
self.cls = cls
self.options = options
self.args = args
self.kwargs = kwargs

def __iter__(self) -> typing.Iterator[typing.Any]:
as_tuple = (self.cls, self.options)
def __iter__(self) -> Iterator[Any]:
as_tuple = (self.cls, self.args, self.kwargs)
return iter(as_tuple)

def __repr__(self) -> str:
class_name = self.__class__.__name__
option_strings = [f"{key}={value!r}" for key, value in self.options.items()]
args_repr = ", ".join([self.cls.__name__] + option_strings)
args_strings = [f"{value!r}" for value in self.args]
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings)
return f"{class_name}({args_repr})"
16 changes: 8 additions & 8 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ def __init__(
self.app = endpoint

if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)

if methods is None:
self.methods = None
Expand Down Expand Up @@ -335,8 +335,8 @@ def __init__(
self.app = endpoint

if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)

self.path_regex, self.path_format, self.param_convertors = compile_path(path)

Expand Down Expand Up @@ -404,8 +404,8 @@ def __init__(
self._base_app = Router(routes=routes)
self.app = self._base_app
if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
self.name = name
self.path_regex, self.path_format, self.param_convertors = compile_path(
self.path + "/{path:path}"
Expand Down Expand Up @@ -672,8 +672,8 @@ def __init__(

self.middleware_stack = self.app
if middleware:
for cls, options in reversed(middleware):
self.middleware_stack = cls(self.middleware_stack, **options)
for cls, args, kwargs in reversed(middleware):
self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)

async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
Expand Down
6 changes: 3 additions & 3 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import contextvars
from contextlib import AsyncExitStack
from typing import AsyncGenerator, Awaitable, Callable, List, Union
from typing import Any, AsyncGenerator, Awaitable, Callable, List, Type, Union

import anyio
import pytest

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
Expand Down Expand Up @@ -196,7 +196,7 @@ async def dispatch(self, request, call_next):
),
],
)
def test_contextvars(test_client_factory, middleware_cls: type):
def test_contextvars(test_client_factory, middleware_cls: Type[_MiddlewareClass[Any]]):
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
# contextvars (it propagates them forwards but not backwards)
Expand Down
22 changes: 17 additions & 5 deletions tests/middleware/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
from starlette.middleware import Middleware
from starlette.types import ASGIApp, Receive, Scope, Send


class CustomMiddleware:
pass
class CustomMiddleware: # pragma: no cover
def __init__(self, app: ASGIApp, foo: str, *, bar: int) -> None:
self.app = app
self.foo = foo
self.bar = bar

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)

def test_middleware_repr():
middleware = Middleware(CustomMiddleware)
assert repr(middleware) == "Middleware(CustomMiddleware)"

def test_middleware_repr() -> None:
middleware = Middleware(CustomMiddleware, "foo", bar=123)
assert repr(middleware) == "Middleware(CustomMiddleware, 'foo', bar=123)"


def test_middleware_iter() -> None:
cls, args, kwargs = Middleware(CustomMiddleware, "foo", bar=123)
assert (cls, args, kwargs) == (CustomMiddleware, ("foo",), {"bar": 123})
12 changes: 6 additions & 6 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Callable
from typing import AsyncIterator, Callable

import anyio
import httpx
Expand All @@ -15,7 +15,7 @@
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.types import ASGIApp
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket


Expand Down Expand Up @@ -499,8 +499,8 @@ class NoOpMiddleware:
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, *args: Any):
await self.app(*args)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)

class SimpleInitializableMiddleware:
counter = 0
Expand All @@ -509,8 +509,8 @@ def __init__(self, app: ASGIApp):
self.app = app
SimpleInitializableMiddleware.counter += 1

async def __call__(self, *args: Any):
await self.app(*args)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)

def get_app() -> ASGIApp:
app = Starlette()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from starlette.endpoints import HTTPEndpoint
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.responses import JSONResponse
from starlette.routing import Route, WebSocketRoute
from starlette.websockets import WebSocketDisconnect
Expand Down Expand Up @@ -327,7 +327,7 @@ def test_authentication_redirect(test_client_factory):
assert response.json() == {"authenticated": True, "user": "tomchristie"}


def on_auth_error(request: Request, exc: Exception):
def on_auth_error(request: HTTPConnection, exc: AuthenticationError):
return JSONResponse({"error": str(exc)}, status_code=401)


Expand Down

0 comments on commit 866a15f

Please sign in to comment.