diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py index 5b4b68e88..99cb6b64c 100644 --- a/starlette/_exception_handler.py +++ b/starlette/_exception_handler.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import typing from starlette._utils import is_async_callable @@ -22,16 +24,14 @@ def _lookup_exception_handler( exc_handlers: ExceptionHandlers, exc: Exception -) -> typing.Optional[ExceptionHandler]: +) -> ExceptionHandler | None: for cls in type(exc).__mro__: if cls in exc_handlers: return exc_handlers[cls] return None -def wrap_app_handling_exceptions( - app: ASGIApp, conn: typing.Union[Request, WebSocket] -) -> ASGIApp: +def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp: exception_handlers: ExceptionHandlers status_handlers: StatusHandlers try: diff --git a/starlette/_utils.py b/starlette/_utils.py index 15ccd92a4..42777c58e 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import functools import re @@ -74,7 +76,7 @@ async def __aenter__(self) -> SupportsAsyncCloseType: self.entered = await self.aw return self.entered - async def __aexit__(self, *args: typing.Any) -> typing.Union[None, bool]: + async def __aexit__(self, *args: typing.Any) -> None | bool: await self.entered.close() return None diff --git a/starlette/applications.py b/starlette/applications.py index c3afaf704..1a4e3d264 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -55,14 +55,14 @@ class Starlette: """ def __init__( - self: "AppType", + self: AppType, debug: bool = False, routes: typing.Sequence[BaseRoute] | None = None, middleware: typing.Sequence[Middleware] | None = None, exception_handlers: typing.Mapping[typing.Any, ExceptionHandler] | None = None, on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, - lifespan: typing.Optional[Lifespan["AppType"]] = None, + lifespan: Lifespan[AppType] | None = None, ) -> None: # The lifespan context function is a newer style that replaces # on_startup / on_shutdown handlers. Use one or the other, not both. @@ -84,7 +84,7 @@ def __init__( def build_middleware_stack(self) -> ASGIApp: debug = self.debug error_handler = None - exception_handlers: typing.Dict[ + exception_handlers: dict[ typing.Any, typing.Callable[[Request, Exception], Response] ] = {} @@ -110,7 +110,7 @@ def build_middleware_stack(self) -> ASGIApp: return app @property - def routes(self) -> typing.List[BaseRoute]: + def routes(self) -> list[BaseRoute]: return self.router.routes def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: @@ -193,7 +193,7 @@ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-ar def route( self, path: str, - methods: typing.List[str] | None = None, + methods: list[str] | None = None, name: str | None = None, include_in_schema: bool = True, ) -> typing.Callable: # type: ignore[type-arg] diff --git a/starlette/authentication.py b/starlette/authentication.py index 07f271c49..e26a8a388 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools import inspect import sys @@ -26,9 +28,9 @@ def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bo def requires( - scopes: typing.Union[str, typing.Sequence[str]], + scopes: str | typing.Sequence[str], status_code: int = 403, - redirect: typing.Optional[str] = None, + redirect: str | None = None, ) -> typing.Callable[ [typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any] ]: @@ -113,12 +115,12 @@ class AuthenticationError(Exception): class AuthenticationBackend: async def authenticate( self, conn: HTTPConnection - ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]: + ) -> tuple[AuthCredentials, BaseUser] | None: raise NotImplementedError() # pragma: no cover class AuthCredentials: - def __init__(self, scopes: typing.Optional[typing.Sequence[str]] = None): + def __init__(self, scopes: typing.Sequence[str] | None = None): self.scopes = [] if scopes is None else list(scopes) diff --git a/starlette/background.py b/starlette/background.py index 4aaf7ae3c..1cbed3b22 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys import typing @@ -29,7 +31,7 @@ async def __call__(self) -> None: class BackgroundTasks(BackgroundTask): - def __init__(self, tasks: typing.Optional[typing.Sequence[BackgroundTask]] = None): + def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None): self.tasks = list(tasks) if tasks else [] def add_task( diff --git a/starlette/concurrency.py b/starlette/concurrency.py index d19020183..215e3a63b 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools import sys import typing @@ -14,7 +16,7 @@ T = typing.TypeVar("T") -async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] # noqa: E501 +async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] # noqa: E501 warnings.warn( "run_until_first_complete is deprecated " "and will be removed in a future version.", diff --git a/starlette/config.py b/starlette/config.py index 1ac49ea85..75a097724 100644 --- a/starlette/config.py +++ b/starlette/config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import typing from pathlib import Path @@ -51,7 +53,7 @@ def __len__(self) -> int: class Config: def __init__( self, - env_file: typing.Optional[typing.Union[str, Path]] = None, + env_file: str | Path | None = None, environ: typing.Mapping[str, str] = environ, env_prefix: str = "", ) -> None: @@ -64,17 +66,15 @@ def __init__( self.file_values = self._read_file(env_file) @typing.overload - def __call__(self, key: str, *, default: None) -> typing.Optional[str]: + def __call__(self, key: str, *, default: None) -> str | None: ... @typing.overload - def __call__(self, key: str, cast: typing.Type[T], default: T = ...) -> T: + def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ... @typing.overload - def __call__( - self, key: str, cast: typing.Type[str] = ..., default: str = ... - ) -> str: + def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ... @typing.overload @@ -87,15 +87,13 @@ def __call__( ... @typing.overload - def __call__( - self, key: str, cast: typing.Type[str] = ..., default: T = ... - ) -> typing.Union[T, str]: + def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ... def __call__( self, key: str, - cast: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None, + cast: typing.Callable[[typing.Any], typing.Any] | None = None, default: typing.Any = undefined, ) -> typing.Any: return self.get(key, cast, default) @@ -103,7 +101,7 @@ def __call__( def get( self, key: str, - cast: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None, + cast: typing.Callable[[typing.Any], typing.Any] | None = None, default: typing.Any = undefined, ) -> typing.Any: key = self.env_prefix + key @@ -117,7 +115,7 @@ def get( return self._perform_cast(key, default, cast) raise KeyError(f"Config '{key}' is missing, and has no default.") - def _read_file(self, file_name: typing.Union[str, Path]) -> typing.Dict[str, str]: + def _read_file(self, file_name: str | Path) -> dict[str, str]: file_values: typing.Dict[str, str] = {} with open(file_name) as input_file: for line in input_file.readlines(): @@ -133,7 +131,7 @@ def _perform_cast( self, key: str, value: typing.Any, - cast: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None, + cast: typing.Callable[[typing.Any], typing.Any] | None = None, ) -> typing.Any: if cast is None or value is None: return value diff --git a/starlette/datastructures.py b/starlette/datastructures.py index e12957f50..e430d09b6 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import typing from shlex import shlex from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit @@ -22,7 +24,7 @@ class URL: def __init__( self, url: str = "", - scope: typing.Optional[Scope] = None, + scope: Scope | None = None, **components: typing.Any, ) -> None: if scope is not None: @@ -86,26 +88,26 @@ def fragment(self) -> str: return self.components.fragment @property - def username(self) -> typing.Union[None, str]: + def username(self) -> None | str: return self.components.username @property - def password(self) -> typing.Union[None, str]: + def password(self) -> None | str: return self.components.password @property - def hostname(self) -> typing.Union[None, str]: + def hostname(self) -> None | str: return self.components.hostname @property - def port(self) -> typing.Optional[int]: + def port(self) -> int | None: return self.components.port @property def is_secure(self) -> bool: return self.scheme in ("https", "wss") - def replace(self, **kwargs: typing.Any) -> "URL": + def replace(self, **kwargs: typing.Any) -> URL: if ( "username" in kwargs or "password" in kwargs @@ -138,19 +140,17 @@ def replace(self, **kwargs: typing.Any) -> "URL": components = self.components._replace(**kwargs) return self.__class__(components.geturl()) - def include_query_params(self, **kwargs: typing.Any) -> "URL": + def include_query_params(self, **kwargs: typing.Any) -> URL: params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) params.update({str(key): str(value) for key, value in kwargs.items()}) query = urlencode(params.multi_items()) return self.replace(query=query) - def replace_query_params(self, **kwargs: typing.Any) -> "URL": + def replace_query_params(self, **kwargs: typing.Any) -> URL: query = urlencode([(str(key), str(value)) for key, value in kwargs.items()]) return self.replace(query=query) - def remove_query_params( - self, keys: typing.Union[str, typing.Sequence[str]] - ) -> "URL": + def remove_query_params(self, keys: str | typing.Sequence[str]) -> "URL": if isinstance(keys, str): keys = [keys] params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) @@ -186,7 +186,7 @@ def __init__(self, path: str, protocol: str = "", host: str = "") -> None: self.protocol = protocol self.host = host - def make_absolute_url(self, base_url: typing.Union[str, URL]) -> URL: + def make_absolute_url(self, base_url: str | URL) -> URL: if isinstance(base_url, str): base_url = URL(base_url) if self.protocol: @@ -223,7 +223,7 @@ def __bool__(self) -> bool: class CommaSeparatedStrings(typing.Sequence[str]): - def __init__(self, value: typing.Union[str, typing.Sequence[str]]): + def __init__(self, value: str | typing.Sequence[str]): if isinstance(value, str): splitter = shlex(value, posix=True) splitter.whitespace = "," @@ -235,7 +235,7 @@ def __init__(self, value: typing.Union[str, typing.Sequence[str]]): def __len__(self) -> int: return len(self._items) - def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any: + def __getitem__(self, index: int | slice) -> typing.Any: return self._items[index] def __iter__(self) -> typing.Iterator[str]: @@ -255,11 +255,9 @@ class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]): def __init__( self, - *args: typing.Union[ - "ImmutableMultiDict[_KeyType, _CovariantValueType]", - typing.Mapping[_KeyType, _CovariantValueType], - typing.Iterable[typing.Tuple[_KeyType, _CovariantValueType]], - ], + *args: ImmutableMultiDict[_KeyType, _CovariantValueType] + | typing.Mapping[_KeyType, _CovariantValueType] + | typing.Iterable[typing.Tuple[_KeyType, _CovariantValueType]], **kwargs: typing.Any, ) -> None: assert len(args) < 2, "Too many arguments." @@ -272,7 +270,7 @@ def __init__( ) if not value: - _items: typing.List[typing.Tuple[typing.Any, typing.Any]] = [] + _items: list[tuple[typing.Any, typing.Any]] = [] elif hasattr(value, "multi_items"): value = typing.cast( ImmutableMultiDict[_KeyType, _CovariantValueType], value @@ -282,15 +280,13 @@ def __init__( value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value) _items = list(value.items()) else: - value = typing.cast( - typing.List[typing.Tuple[typing.Any, typing.Any]], value - ) + value = typing.cast("list[tuple[typing.Any, typing.Any]]", value) _items = list(value) self._dict = {k: v for k, v in _items} self._list = _items - def getlist(self, key: typing.Any) -> typing.List[_CovariantValueType]: + def getlist(self, key: typing.Any) -> list[_CovariantValueType]: return [item_value for item_key, item_value in self._list if item_key == key] def keys(self) -> typing.KeysView[_KeyType]: @@ -302,7 +298,7 @@ def values(self) -> typing.ValuesView[_CovariantValueType]: def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]: return self._dict.items() - def multi_items(self) -> typing.List[typing.Tuple[_KeyType, _CovariantValueType]]: + def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]: return list(self._list) def __getitem__(self, key: _KeyType) -> _CovariantValueType: @@ -340,12 +336,12 @@ def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any: self._list = [(k, v) for k, v in self._list if k != key] return self._dict.pop(key, default) - def popitem(self) -> typing.Tuple[typing.Any, typing.Any]: + def popitem(self) -> tuple[typing.Any, typing.Any]: key, value = self._dict.popitem() self._list = [(k, v) for k, v in self._list if k != key] return key, value - def poplist(self, key: typing.Any) -> typing.List[typing.Any]: + def poplist(self, key: typing.Any) -> list[typing.Any]: values = [v for k, v in self._list if k == key] self.pop(key) return values @@ -361,7 +357,7 @@ def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any: return self[key] - def setlist(self, key: typing.Any, values: typing.List[typing.Any]) -> None: + def setlist(self, key: typing.Any, values: list[typing.Any]) -> None: if not values: self.pop(key, None) else: @@ -375,11 +371,9 @@ def append(self, key: typing.Any, value: typing.Any) -> None: def update( self, - *args: typing.Union[ - "MultiDict", - typing.Mapping[typing.Any, typing.Any], - typing.List[typing.Tuple[typing.Any, typing.Any]], - ], + *args: MultiDict + | typing.Mapping[typing.Any, typing.Any] + | list[tuple[typing.Any, typing.Any]], **kwargs: typing.Any, ) -> None: value = MultiDict(*args, **kwargs) @@ -395,13 +389,11 @@ class QueryParams(ImmutableMultiDict[str, str]): def __init__( self, - *args: typing.Union[ - "ImmutableMultiDict[typing.Any, typing.Any]", - typing.Mapping[typing.Any, typing.Any], - typing.List[typing.Tuple[typing.Any, typing.Any]], - str, - bytes, - ], + *args: ImmutableMultiDict[typing.Any, typing.Any] + | typing.Mapping[typing.Any, typing.Any] + | list[tuple[typing.Any, typing.Any]] + | str + | bytes, **kwargs: typing.Any, ) -> None: assert len(args) < 2, "Too many arguments." @@ -437,9 +429,9 @@ def __init__( self, file: typing.BinaryIO, *, - size: typing.Optional[int] = None, - filename: typing.Optional[str] = None, - headers: "typing.Optional[Headers]" = None, + size: int | None = None, + filename: str | None = None, + headers: Headers | None = None, ) -> None: self.filename = filename self.file = file @@ -447,7 +439,7 @@ def __init__( self.headers = headers or Headers() @property - def content_type(self) -> typing.Optional[str]: + def content_type(self) -> str | None: return self.headers.get("content-type", None) @property @@ -498,12 +490,10 @@ class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]): def __init__( self, - *args: typing.Union[ - "FormData", - typing.Mapping[str, typing.Union[str, UploadFile]], - typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]], - ], - **kwargs: typing.Union[str, UploadFile], + *args: FormData + | typing.Mapping[str, str | UploadFile] + | list[tuple[str, str | UploadFile]], + **kwargs: str | UploadFile, ) -> None: super().__init__(*args, **kwargs) @@ -520,11 +510,11 @@ class Headers(typing.Mapping[str, str]): def __init__( self, - headers: typing.Optional[typing.Mapping[str, str]] = None, - raw: typing.Optional[typing.List[typing.Tuple[bytes, bytes]]] = None, - scope: typing.Optional[typing.MutableMapping[str, typing.Any]] = None, + headers: typing.Mapping[str, str] | None = None, + raw: list[tuple[bytes, bytes]] | None = None, + scope: typing.MutableMapping[str, typing.Any] | None = None, ) -> None: - self._list: typing.List[typing.Tuple[bytes, bytes]] = [] + self._list: list[tuple[bytes, bytes]] = [] if headers is not None: assert raw is None, 'Cannot set both "headers" and "raw".' assert scope is None, 'Cannot set both "headers" and "scope".' @@ -541,22 +531,22 @@ def __init__( self._list = scope["headers"] = list(scope["headers"]) @property - def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]: + def raw(self) -> list[tuple[bytes, bytes]]: return list(self._list) - def keys(self) -> typing.List[str]: # type: ignore[override] + def keys(self) -> list[str]: # type: ignore[override] return [key.decode("latin-1") for key, value in self._list] - def values(self) -> typing.List[str]: # type: ignore[override] + def values(self) -> list[str]: # type: ignore[override] return [value.decode("latin-1") for key, value in self._list] - def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore[override] + def items(self) -> list[tuple[str, str]]: # type: ignore[override] return [ (key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list ] - def getlist(self, key: str) -> typing.List[str]: + def getlist(self, key: str) -> list[str]: get_header_key = key.lower().encode("latin-1") return [ item_value.decode("latin-1") @@ -564,7 +554,7 @@ def getlist(self, key: str) -> typing.List[str]: if item_key == get_header_key ] - def mutablecopy(self) -> "MutableHeaders": + def mutablecopy(self) -> MutableHeaders: return MutableHeaders(raw=self._list[:]) def __getitem__(self, key: str) -> str: @@ -637,13 +627,13 @@ def __delitem__(self, key: str) -> None: for idx in reversed(pop_indexes): del self._list[idx] - def __ior__(self, other: typing.Mapping[str, str]) -> "MutableHeaders": + def __ior__(self, other: typing.Mapping[str, str]) -> MutableHeaders: if not isinstance(other, typing.Mapping): raise TypeError(f"Expected a mapping but got {other.__class__.__name__}") self.update(other) return self - def __or__(self, other: typing.Mapping[str, str]) -> "MutableHeaders": + def __or__(self, other: typing.Mapping[str, str]) -> MutableHeaders: if not isinstance(other, typing.Mapping): raise TypeError(f"Expected a mapping but got {other.__class__.__name__}") new = self.mutablecopy() @@ -651,7 +641,7 @@ def __or__(self, other: typing.Mapping[str, str]) -> "MutableHeaders": return new @property - def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]: + def raw(self) -> list[tuple[bytes, bytes]]: return self._list def setdefault(self, key: str, value: str) -> str: @@ -694,9 +684,9 @@ class State: Used for `request.state` and `app.state`. """ - _state: typing.Dict[str, typing.Any] + _state: dict[str, typing.Any] - def __init__(self, state: typing.Optional[typing.Dict[str, typing.Any]] = None): + def __init__(self, state: dict[str, typing.Any] | None = None): if state is None: state = {} super().__setattr__("_state", state) diff --git a/starlette/endpoints.py b/starlette/endpoints.py index c25dd9db2..57f718824 100644 --- a/starlette/endpoints.py +++ b/starlette/endpoints.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import typing @@ -55,7 +57,7 @@ async def method_not_allowed(self, request: Request) -> Response: class WebSocketEndpoint: - encoding: typing.Optional[str] = None # May be "text", "bytes", or "json". + encoding: str | None = None # May be "text", "bytes", or "json". def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: assert scope["type"] == "websocket" diff --git a/starlette/exceptions.py b/starlette/exceptions.py index a583d93a0..bd3352eb0 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import http import typing import warnings @@ -9,8 +11,8 @@ class HTTPException(Exception): def __init__( self, status_code: int, - detail: typing.Optional[str] = None, - headers: typing.Optional[typing.Dict[str, str]] = None, + detail: str | None = None, + headers: dict[str, str] | None = None, ) -> None: if detail is None: detail = http.HTTPStatus(status_code).phrase @@ -27,7 +29,7 @@ def __repr__(self) -> str: class WebSocketException(Exception): - def __init__(self, code: int, reason: typing.Optional[str] = None) -> None: + def __init__(self, code: int, reason: str | None = None) -> None: self.code = code self.reason = reason or "" @@ -56,5 +58,5 @@ def __getattr__(name: str) -> typing.Any: # pragma: no cover raise AttributeError(f"module '{__name__}' has no attribute '{name}'") -def __dir__() -> typing.List[str]: +def __dir__() -> list[str]: return sorted(list(__all__) + [__deprecated__]) # pragma: no cover diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 905260b98..e2a95e53f 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import typing from dataclasses import dataclass, field from enum import Enum @@ -24,11 +26,11 @@ class FormMessage(Enum): @dataclass class MultipartPart: - content_disposition: typing.Optional[bytes] = None + content_disposition: bytes | None = None field_name: str = "" data: bytes = b"" - file: typing.Optional[UploadFile] = None - item_headers: typing.List[typing.Tuple[bytes, bytes]] = field(default_factory=list) + file: UploadFile | None = None + item_headers: list[tuple[bytes, bytes]] = field(default_factory=list) def _user_safe_decode(src: bytes, codec: str) -> str: @@ -52,7 +54,7 @@ def __init__( ), "The `python-multipart` library must be installed to use form parsing." self.headers = headers self.stream = stream - self.messages: typing.List[typing.Tuple[FormMessage, bytes]] = [] + self.messages: list[tuple[FormMessage, bytes]] = [] def on_field_start(self) -> None: message = (FormMessage.FIELD_START, b"") @@ -89,7 +91,7 @@ async def parse(self) -> FormData: field_name = b"" field_value = b"" - items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = [] + items: list[tuple[str, typing.Union[str, UploadFile]]] = [] # Feed the parser with data from the request. async for chunk in self.stream: @@ -123,8 +125,8 @@ def __init__( headers: Headers, stream: typing.AsyncGenerator[bytes, None], *, - max_files: typing.Union[int, float] = 1000, - max_fields: typing.Union[int, float] = 1000, + max_files: int | float = 1000, + max_fields: int | float = 1000, ) -> None: assert ( multipart is not None @@ -133,16 +135,16 @@ def __init__( self.stream = stream self.max_files = max_files self.max_fields = max_fields - self.items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = [] + self.items: list[tuple[str, str | UploadFile]] = [] self._current_files = 0 self._current_fields = 0 self._current_partial_header_name: bytes = b"" self._current_partial_header_value: bytes = b"" self._current_part = MultipartPart() self._charset = "" - self._file_parts_to_write: typing.List[typing.Tuple[MultipartPart, bytes]] = [] - self._file_parts_to_finish: typing.List[MultipartPart] = [] - self._files_to_close_on_error: typing.List[SpooledTemporaryFile[bytes]] = [] + self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = [] + self._file_parts_to_finish: list[MultipartPart] = [] + self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = [] def on_part_begin(self) -> None: self._current_part = MultipartPart() diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py index ca9752baa..3d0342dc3 100644 --- a/starlette/middleware/__init__.py +++ b/starlette/middleware/__init__.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import sys -from typing import Any, Iterator, Protocol, Type +from typing import Any, Iterator, Protocol if sys.version_info >= (3, 10): # pragma: no cover from typing import ParamSpec @@ -22,7 +24,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class Middleware: def __init__( self, - cls: Type[_MiddlewareClass[P]], + cls: type[_MiddlewareClass[P]], *args: P.args, **kwargs: P.kwargs, ) -> None: diff --git a/starlette/requests.py b/starlette/requests.py index e51223bab..4af63bfc1 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import typing from http import cookies as http_cookies @@ -29,7 +31,7 @@ } -def cookie_parser(cookie_string: str) -> typing.Dict[str, str]: +def cookie_parser(cookie_string: str) -> dict[str, str]: """ This function parses a ``Cookie`` HTTP header into a dict of key/value pairs. @@ -66,7 +68,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]): any functionality that is common to both `Request` and `WebSocket`. """ - def __init__(self, scope: Scope, receive: typing.Optional[Receive] = None) -> None: + def __init__(self, scope: Scope, receive: Receive | None = None) -> None: assert scope["type"] in ("http", "websocket") self.scope = scope @@ -127,11 +129,11 @@ def query_params(self) -> QueryParams: return self._query_params @property - def path_params(self) -> typing.Dict[str, typing.Any]: + def path_params(self) -> dict[str, typing.Any]: return self.scope.get("path_params", {}) @property - def cookies(self) -> typing.Dict[str, str]: + def cookies(self) -> dict[str, str]: if not hasattr(self, "_cookies"): cookies: typing.Dict[str, str] = {} cookie_header = self.headers.get("cookie") @@ -142,7 +144,7 @@ def cookies(self) -> typing.Dict[str, str]: return self._cookies @property - def client(self) -> typing.Optional[Address]: + def client(self) -> Address | None: # client is a 2 item tuple of (host, port), None or missing host_port = self.scope.get("client") if host_port is not None: @@ -150,7 +152,7 @@ def client(self) -> typing.Optional[Address]: return None @property - def session(self) -> typing.Dict[str, typing.Any]: + def session(self) -> dict[str, typing.Any]: assert ( "session" in self.scope ), "SessionMiddleware must be installed to access request.session" @@ -251,10 +253,7 @@ async def json(self) -> typing.Any: return self._json async def _get_form( - self, - *, - max_files: typing.Union[int, float] = 1000, - max_fields: typing.Union[int, float] = 1000, + self, *, max_files: int | float = 1000, max_fields: int | float = 1000 ) -> FormData: if self._form is None: assert ( @@ -284,10 +283,7 @@ async def _get_form( return self._form def form( - self, - *, - max_files: typing.Union[int, float] = 1000, - max_fields: typing.Union[int, float] = 1000, + self, *, max_files: int | float = 1000, max_fields: int | float = 1000 ) -> AwaitableOrContextManager[FormData]: return AwaitableOrContextManagerWrapper( self._get_form(max_files=max_files, max_fields=max_fields) diff --git a/starlette/responses.py b/starlette/responses.py index c99c64f58..419816b7b 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import http.cookies import json import os @@ -28,9 +30,9 @@ def __init__( self, content: typing.Any = None, status_code: int = 200, - headers: typing.Optional[typing.Mapping[str, str]] = None, - media_type: typing.Optional[str] = None, - background: typing.Optional[BackgroundTask] = None, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, ) -> None: self.status_code = status_code if media_type is not None: @@ -46,11 +48,9 @@ def render(self, content: typing.Any) -> bytes: return content return content.encode(self.charset) # type: ignore - def init_headers( - self, headers: typing.Optional[typing.Mapping[str, str]] = None - ) -> None: + def init_headers(self, headers: typing.Mapping[str, str] | None = None) -> None: if headers is None: - raw_headers: typing.List[typing.Tuple[bytes, bytes]] = [] + raw_headers: list[tuple[bytes, bytes]] = [] populate_content_length = True populate_content_type = True else: @@ -89,15 +89,15 @@ def set_cookie( self, key: str, value: str = "", - max_age: typing.Optional[int] = None, - expires: typing.Optional[typing.Union[datetime, str, int]] = None, + max_age: int | None = None, + expires: datetime | str | int | None = None, path: str = "/", - domain: typing.Optional[str] = None, + domain: str | None = None, secure: bool = False, httponly: bool = False, - samesite: typing.Optional[typing.Literal["lax", "strict", "none"]] = "lax", + samesite: typing.Literal["lax", "strict", "none"] | None = "lax", ) -> None: - cookie: "http.cookies.BaseCookie[str]" = http.cookies.SimpleCookie() + cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie() cookie[key] = value if max_age is not None: cookie[key]["max-age"] = max_age @@ -128,10 +128,10 @@ def delete_cookie( self, key: str, path: str = "/", - domain: typing.Optional[str] = None, + domain: str | None = None, secure: bool = False, httponly: bool = False, - samesite: typing.Optional[typing.Literal["lax", "strict", "none"]] = "lax", + samesite: typing.Literal["lax", "strict", "none"] | None = "lax", ) -> None: self.set_cookie( key, @@ -173,9 +173,9 @@ def __init__( self, content: typing.Any, status_code: int = 200, - headers: typing.Optional[typing.Mapping[str, str]] = None, - media_type: typing.Optional[str] = None, - background: typing.Optional[BackgroundTask] = None, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, ) -> None: super().__init__(content, status_code, headers, media_type, background) @@ -192,10 +192,10 @@ def render(self, content: typing.Any) -> bytes: class RedirectResponse(Response): def __init__( self, - url: typing.Union[str, URL], + url: str | URL, status_code: int = 307, - headers: typing.Optional[typing.Mapping[str, str]] = None, - background: typing.Optional[BackgroundTask] = None, + headers: typing.Mapping[str, str] | None = None, + background: BackgroundTask | None = None, ) -> None: super().__init__( content=b"", status_code=status_code, headers=headers, background=background @@ -216,9 +216,9 @@ def __init__( self, content: ContentStream, status_code: int = 200, - headers: typing.Optional[typing.Mapping[str, str]] = None, - media_type: typing.Optional[str] = None, - background: typing.Optional[BackgroundTask] = None, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, ) -> None: if isinstance(content, typing.AsyncIterable): self.body_iterator = content @@ -253,7 +253,7 @@ async def stream_response(self, send: Send) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async with anyio.create_task_group() as task_group: - async def wrap(func: "typing.Callable[[], typing.Awaitable[None]]") -> None: + async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: await func() task_group.cancel_scope.cancel() @@ -269,14 +269,14 @@ class FileResponse(Response): def __init__( self, - path: typing.Union[str, "os.PathLike[str]"], + path: str | os.PathLike[str], status_code: int = 200, - headers: typing.Optional[typing.Mapping[str, str]] = None, - media_type: typing.Optional[str] = None, - background: typing.Optional[BackgroundTask] = None, - filename: typing.Optional[str] = None, - stat_result: typing.Optional[os.stat_result] = None, - method: typing.Optional[str] = None, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + filename: str | None = None, + stat_result: os.stat_result | None = None, + method: str | None = None, content_disposition_type: str = "attachment", ) -> None: self.path = path diff --git a/starlette/routing.py b/starlette/routing.py index d718bb921..b5467bb05 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import functools import inspect @@ -28,7 +30,7 @@ class NoMatchFound(Exception): if no matching route exists. """ - def __init__(self, name: str, path_params: typing.Dict[str, typing.Any]) -> None: + def __init__(self, name: str, path_params: dict[str, typing.Any]) -> None: params = ", ".join(list(path_params.keys())) super().__init__(f'No route exists for name "{name}" and params "{params}".') @@ -106,9 +108,9 @@ def get_name(endpoint: typing.Callable[..., typing.Any]) -> str: def replace_params( path: str, - param_convertors: typing.Dict[str, Convertor[typing.Any]], - path_params: typing.Dict[str, str], -) -> typing.Tuple[str, typing.Dict[str, str]]: + param_convertors: dict[str, Convertor[typing.Any]], + path_params: dict[str, str], +) -> tuple[str, dict[str, str]]: for key, value in list(path_params.items()): if "{" + key + "}" in path: convertor = param_convertors[key] @@ -124,7 +126,7 @@ def replace_params( def compile_path( path: str, -) -> typing.Tuple[typing.Pattern[str], str, typing.Dict[str, Convertor[typing.Any]]]: +) -> tuple[typing.Pattern[str], str, dict[str, Convertor[typing.Any]]]: """ Given a path string, like: "/{username:str}", or a host string, like: "{subdomain}.mydomain.org", return a three-tuple @@ -181,7 +183,7 @@ def compile_path( class BaseRoute: - def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: + def matches(self, scope: Scope) -> tuple[Match, Scope]: raise NotImplementedError() # pragma: no cover def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: @@ -216,10 +218,10 @@ def __init__( path: str, endpoint: typing.Callable[..., typing.Any], *, - methods: typing.Optional[typing.List[str]] = None, - name: typing.Optional[str] = None, + methods: list[str] | None = None, + name: str | None = None, include_in_schema: bool = True, - middleware: typing.Optional[typing.Sequence[Middleware]] = None, + middleware: typing.Sequence[Middleware] | None = None, ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" self.path = path @@ -252,7 +254,7 @@ def __init__( self.path_regex, self.path_format, self.param_convertors = compile_path(path) - def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: + def matches(self, scope: Scope) -> tuple[Match, Scope]: path_params: "typing.Dict[str, typing.Any]" if scope["type"] == "http": route_path = get_route_path(scope) @@ -317,8 +319,8 @@ def __init__( path: str, endpoint: typing.Callable[..., typing.Any], *, - name: typing.Optional[str] = None, - middleware: typing.Optional[typing.Sequence[Middleware]] = None, + name: str | None = None, + middleware: typing.Sequence[Middleware] | None = None, ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" self.path = path @@ -341,7 +343,7 @@ def __init__( self.path_regex, self.path_format, self.param_convertors = compile_path(path) - def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: + def matches(self, scope: Scope) -> tuple[Match, Scope]: path_params: "typing.Dict[str, typing.Any]" if scope["type"] == "websocket": route_path = get_route_path(scope) @@ -387,11 +389,11 @@ class Mount(BaseRoute): def __init__( self, path: str, - app: typing.Optional[ASGIApp] = None, - routes: typing.Optional[typing.Sequence[BaseRoute]] = None, - name: typing.Optional[str] = None, + app: ASGIApp | None = None, + routes: typing.Sequence[BaseRoute] | None = None, + name: str | None = None, *, - middleware: typing.Optional[typing.Sequence[Middleware]] = None, + middleware: typing.Sequence[Middleware] | None = None, ) -> None: assert path == "" or path.startswith("/"), "Routed paths must start with '/'" assert ( @@ -412,7 +414,7 @@ def __init__( ) @property - def routes(self) -> typing.List[BaseRoute]: + def routes(self) -> list[BaseRoute]: return getattr(self._base_app, "routes", []) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: @@ -498,9 +500,7 @@ def __repr__(self) -> str: class Host(BaseRoute): - def __init__( - self, host: str, app: ASGIApp, name: typing.Optional[str] = None - ) -> None: + def __init__(self, host: str, app: ASGIApp, name: str | None = None) -> None: assert not host.startswith("/"), "Host must not start with '/'" self.host = host self.app = app @@ -508,10 +508,10 @@ def __init__( self.host_regex, self.host_format, self.param_convertors = compile_path(host) @property - def routes(self) -> typing.List[BaseRoute]: + def routes(self) -> list[BaseRoute]: return getattr(self.app, "routes", []) - def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: + def matches(self, scope: Scope) -> tuple[Match, Scope]: if scope["type"] in ("http", "websocket"): headers = Headers(scope=scope) host = headers.get("host", "").split(":")[0] @@ -581,10 +581,10 @@ async def __aenter__(self) -> _T: async def __aexit__( self, - exc_type: typing.Optional[typing.Type[BaseException]], - exc_value: typing.Optional[BaseException], - traceback: typing.Optional[types.TracebackType], - ) -> typing.Optional[bool]: + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> bool | None: return self._cm.__exit__(exc_type, exc_value, traceback) @@ -603,7 +603,7 @@ def wrapper(app: typing.Any) -> _AsyncLiftContextManager[typing.Any]: class _DefaultLifespan: - def __init__(self, router: "Router"): + def __init__(self, router: Router): self._router = router async def __aenter__(self) -> None: @@ -619,20 +619,16 @@ def __call__(self: _T, app: object) -> _T: class Router: def __init__( self, - routes: typing.Optional[typing.Sequence[BaseRoute]] = None, + routes: typing.Sequence[BaseRoute] | None = None, redirect_slashes: bool = True, - default: typing.Optional[ASGIApp] = None, - on_startup: typing.Optional[ - typing.Sequence[typing.Callable[[], typing.Any]] - ] = None, - on_shutdown: typing.Optional[ - typing.Sequence[typing.Callable[[], typing.Any]] - ] = None, + default: ASGIApp | None = None, + on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, + on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, # the generic to Lifespan[AppType] is the type of the top level application # which the router cannot know statically, so we use typing.Any - lifespan: typing.Optional[Lifespan[typing.Any]] = None, + lifespan: Lifespan[typing.Any] | None = None, *, - middleware: typing.Optional[typing.Sequence[Middleware]] = None, + middleware: typing.Sequence[Middleware] | None = None, ) -> None: self.routes = [] if routes is None else list(routes) self.redirect_slashes = redirect_slashes @@ -815,13 +811,13 @@ def __eq__(self, other: typing.Any) -> bool: return isinstance(other, Router) and self.routes == other.routes def mount( - self, path: str, app: ASGIApp, name: typing.Optional[str] = None + self, path: str, app: ASGIApp, name: str | None = None ) -> None: # pragma: nocover route = Mount(path, app=app, name=name) self.routes.append(route) def host( - self, host: str, app: ASGIApp, name: typing.Optional[str] = None + self, host: str, app: ASGIApp, name: str | None = None ) -> None: # pragma: no cover route = Host(host, app=app, name=name) self.routes.append(route) @@ -829,11 +825,9 @@ def host( def add_route( self, path: str, - endpoint: typing.Callable[ - [Request], typing.Union[typing.Awaitable[Response], Response] - ], - methods: typing.Optional[typing.List[str]] = None, - name: typing.Optional[str] = None, + endpoint: typing.Callable[[Request], typing.Awaitable[Response] | Response], + methods: list[str] | None = None, + name: str | None = None, include_in_schema: bool = True, ) -> None: # pragma: nocover route = Route( @@ -849,7 +843,7 @@ def add_websocket_route( self, path: str, endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]], - name: typing.Optional[str] = None, + name: str | None = None, ) -> None: # pragma: no cover route = WebSocketRoute(path, endpoint=endpoint, name=name) self.routes.append(route) @@ -857,8 +851,8 @@ def add_websocket_route( def route( self, path: str, - methods: typing.Optional[typing.List[str]] = None, - name: typing.Optional[str] = None, + methods: list[str] | None = None, + name: str | None = None, include_in_schema: bool = True, ) -> typing.Callable: # type: ignore[type-arg] """ @@ -886,9 +880,7 @@ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-ar return decorator - def websocket_route( - self, path: str, name: typing.Optional[str] = None - ) -> typing.Callable: # type: ignore[type-arg] + def websocket_route(self, path: str, name: str | None = None) -> typing.Callable: # type: ignore[type-arg] """ We no longer document this decorator style API, and its usage is discouraged. Instead you should use the following approach: diff --git a/starlette/schemas.py b/starlette/schemas.py index 737f6b029..89fa20b89 100644 --- a/starlette/schemas.py +++ b/starlette/schemas.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import inspect import re import typing @@ -30,14 +32,10 @@ class EndpointInfo(typing.NamedTuple): class BaseSchemaGenerator: - def get_schema( - self, routes: typing.List[BaseRoute] - ) -> typing.Dict[str, typing.Any]: + def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]: raise NotImplementedError() # pragma: no cover - def get_endpoints( - self, routes: typing.List[BaseRoute] - ) -> typing.List[EndpointInfo]: + def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]: """ Given the routes, yields the following information: @@ -48,7 +46,7 @@ def get_endpoints( - func method ready to extract the docstring """ - endpoints_info: typing.List[EndpointInfo] = [] + endpoints_info: list[EndpointInfo] = [] for route in routes: if isinstance(route, (Mount, Host)): @@ -99,7 +97,7 @@ def _remove_converter(self, path: str) -> str: def parse_docstring( self, func_or_method: typing.Callable[..., typing.Any] - ) -> typing.Dict[str, typing.Any]: + ) -> dict[str, typing.Any]: """ Given a function, parse the docstring as YAML and return a dictionary of info. """ @@ -130,12 +128,10 @@ def OpenAPIResponse(self, request: Request) -> Response: class SchemaGenerator(BaseSchemaGenerator): - def __init__(self, base_schema: typing.Dict[str, typing.Any]) -> None: + def __init__(self, base_schema: dict[str, typing.Any]) -> None: self.base_schema = base_schema - def get_schema( - self, routes: typing.List[BaseRoute] - ) -> typing.Dict[str, typing.Any]: + def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]: schema = dict(self.base_schema) schema.setdefault("paths", {}) endpoints_info = self.get_endpoints(routes) diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 52614400b..5d0856ccc 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib.util import os import stat @@ -41,10 +43,8 @@ class StaticFiles: def __init__( self, *, - directory: typing.Optional[PathLike] = None, - packages: typing.Optional[ - typing.List[typing.Union[str, typing.Tuple[str, str]]] - ] = None, + directory: PathLike | None = None, + packages: list[str | tuple[str, str]] | None = None, html: bool = False, check_dir: bool = True, follow_symlink: bool = False, @@ -60,11 +60,9 @@ def __init__( def get_directories( self, - directory: typing.Optional[PathLike] = None, - packages: typing.Optional[ - typing.List[typing.Union[str, typing.Tuple[str, str]]] - ] = None, - ) -> typing.List[PathLike]: + directory: PathLike | None = None, + packages: list[str | tuple[str, str]] | None = None, + ) -> list[PathLike]: """ Given `directory` and `packages` arguments, return a list of all the directories that should be used for serving static files from. @@ -157,9 +155,7 @@ async def get_response(self, path: str, scope: Scope) -> Response: return FileResponse(full_path, stat_result=stat_result, status_code=404) raise HTTPException(status_code=404) - def lookup_path( - self, path: str - ) -> typing.Tuple[str, typing.Optional[os.stat_result]]: + def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]: for directory in self.all_directories: joined_path = os.path.join(directory, path) if self.follow_symlink: diff --git a/starlette/status.py b/starlette/status.py index 1689328a4..2cd5db575 100644 --- a/starlette/status.py +++ b/starlette/status.py @@ -5,8 +5,9 @@ And RFC 2324 - https://tools.ietf.org/html/rfc2324 """ +from __future__ import annotations + import warnings -from typing import List __all__ = ( "HTTP_100_CONTINUE", @@ -195,5 +196,5 @@ def __getattr__(name: str) -> int: raise AttributeError(f"module '{__name__}' has no attribute '{name}'") -def __dir__() -> List[str]: +def __dir__() -> list[str]: return sorted(list(__all__) + list(__deprecated__.keys())) # pragma: no cover diff --git a/starlette/templating.py b/starlette/templating.py index c2078d22d..fe31ab5ee 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import typing import warnings from os import PathLike @@ -27,11 +29,11 @@ class _TemplateResponse(HTMLResponse): def __init__( self, template: typing.Any, - context: typing.Dict[str, typing.Any], + context: dict[str, typing.Any], status_code: int = 200, - headers: typing.Optional[typing.Mapping[str, str]] = None, - media_type: typing.Optional[str] = None, - background: typing.Optional[BackgroundTask] = None, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, ): self.template = template self.context = context @@ -64,11 +66,12 @@ class Jinja2Templates: @typing.overload def __init__( self, - directory: "typing.Union[str, PathLike[typing.AnyStr], typing.Sequence[typing.Union[str, PathLike[typing.AnyStr]]]]", # noqa: E501 + directory: str + | PathLike[typing.AnyStr] + | typing.Sequence[str | PathLike[typing.AnyStr]], *, - context_processors: typing.Optional[ - typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]] - ] = None, + context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] + | None = None, **env_options: typing.Any, ) -> None: ... @@ -77,21 +80,22 @@ def __init__( def __init__( self, *, - env: "jinja2.Environment", - context_processors: typing.Optional[ - typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]] - ] = None, + env: jinja2.Environment, + context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] + | None = None, ) -> None: ... def __init__( self, - directory: "typing.Union[str, PathLike[typing.AnyStr], typing.Sequence[typing.Union[str, PathLike[typing.AnyStr]]], None]" = None, # noqa: E501 + directory: str + | PathLike[typing.AnyStr] + | typing.Sequence[str | PathLike[typing.AnyStr]] + | None = None, *, - context_processors: typing.Optional[ - typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]] - ] = None, - env: typing.Optional["jinja2.Environment"] = None, + context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] + | None = None, + env: jinja2.Environment | None = None, **env_options: typing.Any, ) -> None: if env_options: @@ -111,16 +115,18 @@ def __init__( def _create_env( self, - directory: "typing.Union[str, PathLike[typing.AnyStr], typing.Sequence[typing.Union[str, PathLike[typing.AnyStr]]]]", # noqa: E501 + directory: str + | PathLike[typing.AnyStr] + | typing.Sequence[str | PathLike[typing.AnyStr]], **env_options: typing.Any, - ) -> "jinja2.Environment": + ) -> jinja2.Environment: loader = jinja2.FileSystemLoader(directory) env_options.setdefault("loader", loader) env_options.setdefault("autoescape", True) return jinja2.Environment(**env_options) - def _setup_env_defaults(self, env: "jinja2.Environment") -> None: + def _setup_env_defaults(self, env: jinja2.Environment) -> None: @pass_context def url_for( context: typing.Dict[str, typing.Any], @@ -133,7 +139,7 @@ def url_for( env.globals.setdefault("url_for", url_for) - def get_template(self, name: str) -> "jinja2.Template": + def get_template(self, name: str) -> jinja2.Template: return self.env.get_template(name) @typing.overload @@ -141,11 +147,11 @@ def TemplateResponse( self, request: Request, name: str, - context: typing.Optional[typing.Dict[str, typing.Any]] = None, + context: dict[str, typing.Any] | None = None, status_code: int = 200, - headers: typing.Optional[typing.Mapping[str, str]] = None, - media_type: typing.Optional[str] = None, - background: typing.Optional[BackgroundTask] = None, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, ) -> _TemplateResponse: ... @@ -153,11 +159,11 @@ def TemplateResponse( def TemplateResponse( self, name: str, - context: typing.Optional[typing.Dict[str, typing.Any]] = None, + context: dict[str, typing.Any] | None = None, status_code: int = 200, - headers: typing.Optional[typing.Mapping[str, str]] = None, - media_type: typing.Optional[str] = None, - background: typing.Optional[BackgroundTask] = None, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, ) -> _TemplateResponse: # Deprecated usage ... diff --git a/starlette/testclient.py b/starlette/testclient.py index 2cccb15d1..5b557032e 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -48,7 +48,7 @@ _RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]] -def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> TypeGuard[ASGI3App]: +def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]: if inspect.isclass(app): return hasattr(app, "__await__") return is_async_callable(app) @@ -69,7 +69,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class _AsyncBackend(typing.TypedDict): backend: str - backend_options: typing.Dict[str, typing.Any] + backend_options: dict[str, typing.Any] class _Upgrade(Exception): @@ -167,8 +167,9 @@ def send_text(self, data: str) -> None: def send_bytes(self, data: bytes) -> None: self.send({"type": "websocket.receive", "bytes": data}) - def send_json(self, data: typing.Any, mode: str = "text") -> None: - assert mode in ["text", "binary"] + def send_json( + self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text" + ) -> None: text = json.dumps(data, separators=(",", ":"), ensure_ascii=False) if mode == "text": self.send({"type": "websocket.receive", "text": text}) @@ -396,9 +397,9 @@ def __init__( raise_server_exceptions: bool = True, root_path: str = "", backend: typing.Literal["asyncio", "trio"] = "asyncio", - backend_options: typing.Dict[str, typing.Any] | None = None, + backend_options: dict[str, typing.Any] | None = None, cookies: httpx._types.CookieTypes | None = None, - headers: typing.Dict[str, str] | None = None, + headers: dict[str, str] | None = None, follow_redirects: bool = True, ) -> None: self.async_backend = _AsyncBackend( @@ -410,7 +411,7 @@ def __init__( app = typing.cast(ASGI2App, app) # type: ignore[assignment] asgi_app = _WrapASGI2(app) # type: ignore[arg-type] self.app = asgi_app - self.app_state: typing.Dict[str, typing.Any] = {} + self.app_state: dict[str, typing.Any] = {} transport = _TestClientTransport( self.app, portal_factory=self._portal_factory, @@ -466,22 +467,20 @@ def request( # type: ignore[override] method: str, url: httpx._types.URLTypes, *, - content: typing.Optional[httpx._types.RequestContent] = None, - data: typing.Optional[_RequestData] = None, - files: typing.Optional[httpx._types.RequestFiles] = None, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, json: typing.Any = None, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._types.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | None = None, + allow_redirects: bool | None = None, + timeout: httpx._types.TimeoutTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: url = self._merge_url(url) redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) @@ -505,18 +504,16 @@ def get( # type: ignore[override] self, url: httpx._types.URLTypes, *, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._types.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | None = None, + allow_redirects: bool | None = None, + timeout: httpx._types.TimeoutTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().get( @@ -534,18 +531,16 @@ def options( # type: ignore[override] self, url: httpx._types.URLTypes, *, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._types.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | None = None, + allow_redirects: bool | None = None, + timeout: httpx._types.TimeoutTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().options( @@ -563,18 +558,16 @@ def head( # type: ignore[override] self, url: httpx._types.URLTypes, *, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._types.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | None = None, + allow_redirects: bool | None = None, + timeout: httpx._types.TimeoutTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().head( @@ -592,22 +585,20 @@ def post( # type: ignore[override] self, url: httpx._types.URLTypes, *, - content: typing.Optional[httpx._types.RequestContent] = None, - data: typing.Optional[_RequestData] = None, - files: typing.Optional[httpx._types.RequestFiles] = None, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, json: typing.Any = None, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._types.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | None = None, + allow_redirects: bool | None = None, + timeout: httpx._types.TimeoutTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().post( @@ -629,22 +620,20 @@ def put( # type: ignore[override] self, url: httpx._types.URLTypes, *, - content: typing.Optional[httpx._types.RequestContent] = None, - data: typing.Optional[_RequestData] = None, - files: typing.Optional[httpx._types.RequestFiles] = None, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, json: typing.Any = None, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._types.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | None = None, + allow_redirects: bool | None = None, + timeout: httpx._types.TimeoutTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().put( @@ -666,22 +655,20 @@ def patch( # type: ignore[override] self, url: httpx._types.URLTypes, *, - content: typing.Optional[httpx._types.RequestContent] = None, - data: typing.Optional[_RequestData] = None, - files: typing.Optional[httpx._types.RequestFiles] = None, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, json: typing.Any = None, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._types.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | None = None, + allow_redirects: bool | None = None, + timeout: httpx._types.TimeoutTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().patch( @@ -703,18 +690,16 @@ def delete( # type: ignore[override] self, url: httpx._types.URLTypes, *, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._types.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | None = None, + allow_redirects: bool | None = None, + timeout: httpx._types.TimeoutTypes + | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().delete( @@ -733,7 +718,7 @@ def websocket_connect( url: str, subprotocols: typing.Sequence[str] | None = None, **kwargs: typing.Any, - ) -> "WebSocketTestSession": + ) -> WebSocketTestSession: url = urljoin("ws://testserver", url) headers = kwargs.get("headers", {}) headers.setdefault("connection", "upgrade") @@ -751,7 +736,7 @@ def websocket_connect( return session - def __enter__(self) -> "TestClient": + def __enter__(self) -> TestClient: with contextlib.ExitStack() as stack: self.portal = portal = stack.enter_context( anyio.from_thread.start_blocking_portal(**self.async_backend) @@ -761,12 +746,8 @@ def __enter__(self) -> "TestClient": def reset_portal() -> None: self.portal = None - send1: ObjectSendStream[ - typing.Optional[typing.MutableMapping[str, typing.Any]] - ] - receive1: ObjectReceiveStream[ - typing.Optional[typing.MutableMapping[str, typing.Any]] - ] + send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None] + receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None] send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]] receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] send1, receive1 = anyio.create_memory_object_stream(math.inf) diff --git a/starlette/types.py b/starlette/types.py index 19484301e..f78dd63ae 100644 --- a/starlette/types.py +++ b/starlette/types.py @@ -22,7 +22,7 @@ Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]] HTTPExceptionHandler = typing.Callable[ - ["Request", Exception], typing.Union["Response", typing.Awaitable["Response"]] + ["Request", Exception], "Response | typing.Awaitable[Response]" ] WebSocketExceptionHandler = typing.Callable[ ["WebSocket", Exception], typing.Awaitable[None] diff --git a/starlette/websockets.py b/starlette/websockets.py index 084d93094..850fbf115 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import enum import json import typing @@ -92,8 +94,8 @@ async def send(self, message: Message) -> None: async def accept( self, - subprotocol: typing.Optional[str] = None, - headers: typing.Optional[typing.Iterable[typing.Tuple[bytes, bytes]]] = None, + subprotocol: str | None = None, + headers: typing.Iterable[tuple[bytes, bytes]] | None = None, ) -> None: headers = headers or [] @@ -178,16 +180,14 @@ async def send_json(self, data: typing.Any, mode: str = "text") -> None: else: await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")}) - async def close( - self, code: int = 1000, reason: typing.Optional[str] = None - ) -> None: + async def close(self, code: int = 1000, reason: str | None = None) -> None: await self.send( {"type": "websocket.close", "code": code, "reason": reason or ""} ) class WebSocketClose: - def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: + def __init__(self, code: int = 1000, reason: str | None = None) -> None: self.code = code self.reason = reason or ""