Skip to content

Commit

Permalink
feat: Improve Middleware type annotations.
Browse files Browse the repository at this point in the history
Use ParamSpec to provide concrete type annotations for middleware's parameters.
  • Loading branch information
Paweł Rubin committed Dec 18, 2023
1 parent 23c81da commit 0f4c8c5
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
17 changes: 15 additions & 2 deletions starlette/applications.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
import typing
import warnings

Expand All @@ -14,7 +15,14 @@
from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket

if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec


AppType = typing.TypeVar("AppType", bound="Starlette")
P = ParamSpec("P")


class Starlette:
Expand Down Expand Up @@ -124,10 +132,15 @@ def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
self.router.host(host, app=app, name=name) # pragma: no cover

def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
def add_middleware(
self,
middleware_class: typing.Callable[typing.Concatenate[ASGIApp, P], typing.Any],
*args: P.args,
**options: P.kwargs,
) -> None:
if self.middleware_stack is not None: # pragma: no cover
raise RuntimeError("Cannot add middleware after an application has started")
self.user_middleware.insert(0, Middleware(middleware_class, **options))
self.user_middleware.insert(0, Middleware(middleware_class, *args, **options))

def add_exception_handler(
self,
Expand Down
22 changes: 19 additions & 3 deletions starlette/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
import typing
import sys
from typing import Any, Callable, Concatenate, Iterator

from starlette.types import ASGIApp

if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec


P = ParamSpec("P")


class Middleware:
def __init__(self, cls: type, **options: typing.Any) -> None:
def __init__(
self,
cls: Callable[Concatenate[ASGIApp, P], Any],
*args: P.args,
**options: P.kwargs,
) -> None:
self.cls = cls
self.options = options

def __iter__(self) -> typing.Iterator[typing.Any]:
def __iter__(self) -> Iterator[Any]:
as_tuple = (self.cls, self.options)
return iter(as_tuple)

Expand Down

0 comments on commit 0f4c8c5

Please sign in to comment.