diff --git a/CHANGES/5864.feature b/CHANGES/5864.feature new file mode 100644 index 00000000000..cc0c0ab48e8 --- /dev/null +++ b/CHANGES/5864.feature @@ -0,0 +1 @@ +Added static typing support using ``Application.state`` (replacing the previous dict-like behaviour of ``Application``). diff --git a/aiohttp/abc.py b/aiohttp/abc.py index 08a3fe02270..0f85168beec 100644 --- a/aiohttp/abc.py +++ b/aiohttp/abc.py @@ -9,32 +9,37 @@ Callable, Dict, Generator, + Generic, Iterable, List, Optional, Tuple, + TypeVar, ) from multidict import CIMultiDict from yarl import URL -from .typedefs import LooseCookies +from .typedefs import LooseCookies, _SafeApplication, _SafeRequest + +_T = TypeVar("_T") if TYPE_CHECKING: # pragma: no cover - from .web_app import Application from .web_exceptions import HTTPException from .web_request import BaseRequest, Request from .web_response import StreamResponse else: - BaseRequest = Request = Application = StreamResponse = None + BaseRequest = StreamResponse = None HTTPException = None + class Request(Generic[_T]): ... + class AbstractRouter(ABC): def __init__(self) -> None: self._frozen = False - def post_init(self, app: Application) -> None: + def post_init(self, app: _SafeApplication) -> None: """Post init stage. Not an abstract method for sake of backward compatibility, @@ -51,19 +56,19 @@ def freeze(self) -> None: self._frozen = True @abstractmethod - async def resolve(self, request: Request) -> "AbstractMatchInfo": + async def resolve(self, request: _SafeRequest) -> "AbstractMatchInfo": """Return MATCH_INFO for given request""" class AbstractMatchInfo(ABC): @property # pragma: no branch @abstractmethod - def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]: + def handler(self) -> Callable[[_SafeRequest], Awaitable[StreamResponse]]: """Execute matched request handler""" @property @abstractmethod - def expect_handler(self) -> Callable[[Request], Awaitable[None]]: + def expect_handler(self) -> Callable[[_SafeRequest], Awaitable[None]]: """Expect handler for 100-continue processing""" @property # pragma: no branch @@ -77,7 +82,7 @@ def get_info(self) -> Dict[str, Any]: @property # pragma: no branch @abstractmethod - def apps(self) -> Tuple[Application, ...]: + def apps(self) -> Tuple[_SafeApplication, ...]: """Stack of nested applications. Top level application is left-most element. @@ -85,7 +90,7 @@ def apps(self) -> Tuple[Application, ...]: """ @abstractmethod - def add_app(self, app: Application) -> None: + def add_app(self, app: _SafeApplication) -> None: """Add application to the nested apps stack.""" @abstractmethod @@ -99,14 +104,14 @@ def freeze(self) -> None: """ -class AbstractView(ABC): +class AbstractView(ABC, Generic[_T]): """Abstract class based view.""" - def __init__(self, request: Request) -> None: + def __init__(self, request: Request[_T]) -> None: self._request = request @property - def request(self) -> Request: + def request(self) -> Request[_T]: """Request instance.""" return self._request diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index ee877fd4def..fe8d974fb14 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -1,7 +1,6 @@ """WebSocket protocol versions 13 and 8.""" import asyncio -import collections import json import random import re @@ -9,7 +8,18 @@ import zlib from enum import IntEnum from struct import Struct -from typing import Any, Callable, List, Optional, Pattern, Set, Tuple, Union, cast +from typing import ( + Any, + Callable, + List, + NamedTuple, + Optional, + Pattern, + Set, + Tuple, + Union, + cast, +) from typing_extensions import Final @@ -78,7 +88,11 @@ class WSMsgType(IntEnum): DEFAULT_LIMIT: Final[int] = 2 ** 16 -_WSMessageBase = collections.namedtuple("_WSMessageBase", ["type", "data", "extra"]) +class _WSMessageBase(NamedTuple): + type: WSMsgType + # To type correctly, this would need some kind of tagged union for each type. + data: Any + extra: Optional[str] class WSMessage(_WSMessageBase): diff --git a/aiohttp/pytest_plugin.py b/aiohttp/pytest_plugin.py index 153db35276b..9ee7f681974 100644 --- a/aiohttp/pytest_plugin.py +++ b/aiohttp/pytest_plugin.py @@ -6,8 +6,6 @@ import pytest -from aiohttp.web import Application - from .test_utils import ( BaseTestServer, RawTestServer, @@ -18,6 +16,8 @@ teardown_test_loop, unused_port as _unused_port, ) +from .typedefs import _SafeApplication +from .web import Application try: import uvloop @@ -342,7 +342,7 @@ def aiohttp_client( clients = [] async def go( - __param: Union[Application, BaseTestServer], + __param: Union[_SafeApplication, BaseTestServer], *, server_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index 06155bae19f..dbe578bc2fa 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -18,6 +18,7 @@ List, Optional, Type, + TypeVar, Union, cast, ) @@ -36,6 +37,7 @@ from .client_ws import ClientWebSocketResponse from .helpers import _SENTINEL, PY_38, sentinel from .http import HttpVersion, RawRequestMessage +from .typedefs import _SafeApplication, _SafeRequest from .web import ( Application, AppRunner, @@ -58,6 +60,8 @@ else: from asynctest import TestCase # type: ignore[no-redef] +_T = TypeVar("_T") + REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin" @@ -205,7 +209,7 @@ async def __aexit__( class TestServer(BaseTestServer): def __init__( self, - app: Application, + app: _SafeApplication, *, scheme: Union[str, _SENTINEL] = sentinel, host: str = "127.0.0.1", @@ -286,8 +290,8 @@ def server(self) -> BaseTestServer: return self._server @property - def app(self) -> Optional[Application]: - return cast(Optional[Application], getattr(self._server, "app", None)) + def app(self) -> Optional[_SafeApplication]: + return cast(Optional[_SafeApplication], getattr(self._server, "app", None)) @property def session(self) -> ClientSession: @@ -410,7 +414,7 @@ class AioHTTPTestCase(TestCase): execute function on the test client using asynchronous methods. """ - async def get_application(self) -> Application: + async def get_application(self) -> _SafeApplication: """ This method should be overridden to return the aiohttp.web.Application @@ -419,7 +423,7 @@ async def get_application(self) -> Application: """ return self.get_app() - def get_app(self) -> Application: + def get_app(self) -> _SafeApplication: """Obsolete method used to constructing web application. Use .get_application() coroutine instead @@ -446,7 +450,7 @@ def tearDown(self) -> None: async def tearDownAsync(self) -> None: await self.client.close() - async def get_server(self, app: Application) -> TestServer: + async def get_server(self, app: _SafeApplication) -> TestServer: """Return a TestServer instance.""" return TestServer(app) @@ -524,16 +528,8 @@ def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> N def _create_app_mock() -> mock.MagicMock: - def get_dict(app: Any, key: str) -> Any: - return app.__app_dict[key] - - def set_dict(app: Any, key: str, value: Any) -> None: - app.__app_dict[key] = value - - app = mock.MagicMock() - app.__app_dict = {} - app.__getitem__ = get_dict - app.__setitem__ = set_dict + app = mock.MagicMock(spec_set=Application) + app.state = {} app.on_response_prepare = Signal(app) app.on_response_prepare.freeze() @@ -561,7 +557,7 @@ def make_mocked_request( match_info: Any = sentinel, version: HttpVersion = HttpVersion(1, 1), closing: bool = False, - app: Any = None, + app: Optional[Application[_T]] = None, writer: Any = sentinel, protocol: Any = sentinel, transport: Any = sentinel, @@ -569,7 +565,7 @@ def make_mocked_request( sslcontext: Optional[SSLContext] = None, client_max_size: int = 1024 ** 2, loop: Any = ..., -) -> Request: +) -> Request[_T]: """Creates mocked web.Request testing purposes. Useful in unit tests, when spinning full web server is overkill or @@ -632,7 +628,7 @@ def make_mocked_request( if payload is sentinel: payload = mock.Mock() - req = Request( + req: Request[_T] = Request( message, payload, protocol, writer, task, loop, client_max_size=client_max_size ) diff --git a/aiohttp/typedefs.py b/aiohttp/typedefs.py index 1b13a4dbd0f..d3a12583031 100644 --- a/aiohttp/typedefs.py +++ b/aiohttp/typedefs.py @@ -5,6 +5,7 @@ Any, Awaitable, Callable, + Dict, Iterable, Mapping, Tuple, @@ -18,18 +19,23 @@ DEFAULT_JSON_DECODER = json.loads if TYPE_CHECKING: # pragma: no cover + from http.cookies import BaseCookie, Morsel + + from .web import Application, Request, StreamResponse + _CIMultiDict = CIMultiDict[str] _CIMultiDictProxy = CIMultiDictProxy[str] _MultiDict = MultiDict[str] _MultiDictProxy = MultiDictProxy[str] - from http.cookies import BaseCookie, Morsel - - from .web import Request, StreamResponse + _SafeApplication = Application[Dict[str, object]] + _SafeRequest = Request[Dict[str, object]] else: _CIMultiDict = CIMultiDict _CIMultiDictProxy = CIMultiDictProxy _MultiDict = MultiDict _MultiDictProxy = MultiDictProxy + _SafeApplication = "Application" + _SafeRequest = "Request" Byteish = Union[bytes, bytearray, memoryview] JSONEncoder = Callable[[Any], str] diff --git a/aiohttp/web.py b/aiohttp/web.py index 0062076268f..0055425b7b7 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -20,6 +20,7 @@ from .abc import AbstractAccessLogger from .log import access_logger +from .typedefs import _SafeApplication from .web_app import Application as Application, CleanupError as CleanupError from .web_exceptions import ( HTTPAccepted as HTTPAccepted, @@ -284,7 +285,7 @@ async def _run_app( - app: Union[Application, Awaitable[Application]], + app: Union[_SafeApplication, Awaitable[_SafeApplication]], *, host: Optional[Union[str, HostSequence]] = None, port: Optional[int] = None, @@ -306,7 +307,7 @@ async def _run_app( if asyncio.iscoroutine(app): app = await app # type: ignore[misc] - app = cast(Application, app) + app = cast(_SafeApplication, app) runner = AppRunner( app, @@ -459,7 +460,7 @@ def _cancel_tasks( def run_app( - app: Union[Application, Awaitable[Application]], + app: Union[Application[Any], Awaitable[Application[Any]]], *, debug: bool = False, host: Optional[Union[str, HostSequence]] = None, diff --git a/aiohttp/web_app.py b/aiohttp/web_app.py index d78d603d38b..f68f3f9a722 100644 --- a/aiohttp/web_app.py +++ b/aiohttp/web_app.py @@ -2,6 +2,7 @@ import logging import warnings from functools import partial, update_wrapper +from types import MappingProxyType from typing import ( # noqa TYPE_CHECKING, Any, @@ -9,15 +10,16 @@ Awaitable, Callable, Dict, + Generic, Iterable, Iterator, List, Mapping, - MutableMapping, Optional, Sequence, Tuple, Type, + TypeVar, Union, cast, ) @@ -28,6 +30,7 @@ from . import hdrs from .log import web_logger +from .typedefs import _SafeApplication, _SafeRequest from .web_middlewares import _fix_request_current_app from .web_request import Request from .web_response import StreamResponse @@ -44,16 +47,20 @@ __all__ = ("Application", "CleanupError") +_T = TypeVar("_T", covariant=True) + if TYPE_CHECKING: # pragma: no cover from .typedefs import Handler - _AppSignal = Signal[Callable[["Application"], Awaitable[None]]] - _RespPrepareSignal = Signal[Callable[[Request, StreamResponse], Awaitable[None]]] + _AppSignal = Signal[Callable[["Application[Any]"], Awaitable[None]]] + _RespPrepareSignal = Signal[ + Callable[[Request[Any], StreamResponse], Awaitable[None]] + ] _Middleware = Callable[[Request, Handler], Awaitable[StreamResponse]] _Middlewares = FrozenList[_Middleware] _MiddlewaresHandlers = Sequence[_Middleware] - _Subapps = List["Application"] + _Subapps = List[_SafeApplication] else: # No type checker mode, skip types _AppSignal = Signal @@ -66,7 +73,7 @@ @final -class Application(MutableMapping[str, Any]): +class Application(Generic[_T]): __slots__ = ( "logger", "_debug", @@ -115,7 +122,8 @@ def __init__( # initialized on freezing self._run_middlewares = None # type: Optional[bool] - self._state = {} # type: Dict[str, Any] + # We cheat here slightly, to allow users to type hint their apps more easily. + self._state: _T = {} # type: ignore[assignment] self._frozen = False self._pre_frozen = False self._subapps = [] # type: _Subapps @@ -129,40 +137,15 @@ def __init__( self._on_cleanup.append(self._cleanup_ctx._on_cleanup) self._client_max_size = client_max_size - def __init_subclass__(cls: Type["Application"]) -> None: + def __init_subclass__(cls: Type["Application[Any]"]) -> None: raise TypeError( "Inheritance class {} from web.Application " "is forbidden".format(cls.__name__) ) - # MutableMapping API - def __eq__(self, other: object) -> bool: return self is other - def __getitem__(self, key: str) -> Any: - return self._state[key] - - def _check_frozen(self) -> None: - if self._frozen: - raise RuntimeError( - "Changing state of started or joined " "application is forbidden" - ) - - def __setitem__(self, key: str, value: Any) -> None: - self._check_frozen() - self._state[key] = value - - def __delitem__(self, key: str) -> None: - self._check_frozen() - del self._state[key] - - def __len__(self) -> int: - return len(self._state) - - def __iter__(self) -> Iterator[str]: - return iter(self._state) - ######## def _set_loop(self, loop: Optional[asyncio.AbstractEventLoop]) -> None: warnings.warn( @@ -171,6 +154,10 @@ def _set_loop(self, loop: Optional[asyncio.AbstractEventLoop]) -> None: stacklevel=2, ) + @property + def state(self) -> _T: + return self._state + @property def pre_frozen(self) -> bool: return self._pre_frozen @@ -209,6 +196,7 @@ def freeze(self) -> None: return self.pre_freeze() + self._state = MappingProxyType(self._state) # type: ignore[arg-type,assignment] self._frozen = True for subapp in self._subapps: subapp.freeze() @@ -222,11 +210,11 @@ def debug(self) -> bool: ) return asyncio.get_event_loop().get_debug() - def _reg_subapp_signals(self, subapp: "Application") -> None: + def _reg_subapp_signals(self, subapp: _SafeApplication) -> None: def reg_handler(signame: str) -> None: subsig = getattr(subapp, signame) - async def handler(app: "Application") -> None: + async def handler(app: _SafeApplication) -> None: await subsig.send(subapp) appsig = getattr(self, signame) @@ -236,7 +224,7 @@ async def handler(app: "Application") -> None: reg_handler("on_shutdown") reg_handler("on_cleanup") - def add_subapp(self, prefix: str, subapp: "Application") -> AbstractResource: + def add_subapp(self, prefix: str, subapp: "Application[Any]") -> AbstractResource: if not isinstance(prefix, str): raise TypeError("Prefix must be str") prefix = prefix.rstrip("/") @@ -246,7 +234,7 @@ def add_subapp(self, prefix: str, subapp: "Application") -> AbstractResource: return self._add_subapp(factory, subapp) def _add_subapp( - self, resource_factory: Callable[[], AbstractResource], subapp: "Application" + self, resource_factory: Callable[[], AbstractResource], subapp: _SafeApplication ) -> AbstractResource: if self.frozen: raise RuntimeError("Cannot add sub application to frozen application") @@ -259,7 +247,7 @@ def _add_subapp( subapp.pre_freeze() return resource - def add_domain(self, domain: str, subapp: "Application") -> AbstractResource: + def add_domain(self, domain: str, subapp: _SafeApplication) -> AbstractResource: if not isinstance(domain, str): raise TypeError("Domain must be str") elif "*" in domain: @@ -323,15 +311,15 @@ async def cleanup(self) -> None: await self.on_cleanup.send(self) else: # If an exception occurs in startup, ensure cleanup contexts are completed. - await self._cleanup_ctx._on_cleanup(self) + await self._cleanup_ctx._on_cleanup(self) # type: ignore[arg-type] def _prepare_middleware(self) -> Iterator[_Middleware]: yield from reversed(self._middlewares) - yield _fix_request_current_app(self) + yield _fix_request_current_app(self) # type: ignore[arg-type] - async def _handle(self, request: Request) -> StreamResponse: + async def _handle(self, request: _SafeRequest) -> StreamResponse: match_info = await self._router.resolve(request) - match_info.add_app(self) + match_info.add_app(self) # type: ignore[arg-type] match_info.freeze() resp = None @@ -354,7 +342,7 @@ async def _handle(self, request: Request) -> StreamResponse: return resp - def __call__(self) -> "Application": + def __call__(self) -> "Application[_T]": """gunicorn compatibility""" return self @@ -372,7 +360,7 @@ def exceptions(self) -> List[BaseException]: if TYPE_CHECKING: # pragma: no cover - _CleanupContextBase = FrozenList[Callable[[Application], AsyncIterator[None]]] + _CleanupContextBase = FrozenList[Callable[[Application[Any]], AsyncIterator[None]]] else: _CleanupContextBase = FrozenList @@ -382,13 +370,13 @@ def __init__(self) -> None: super().__init__() self._exits = [] # type: List[AsyncIterator[None]] - async def _on_startup(self, app: Application) -> None: + async def _on_startup(self, app: _SafeApplication) -> None: for cb in self: it = cb(app).__aiter__() await it.__anext__() self._exits.append(it) - async def _on_cleanup(self, app: Application) -> None: + async def _on_cleanup(self, app: _SafeApplication) -> None: errors = [] for it in reversed(self._exits): try: diff --git a/aiohttp/web_middlewares.py b/aiohttp/web_middlewares.py index 4d28ff76307..c858bf863c9 100644 --- a/aiohttp/web_middlewares.py +++ b/aiohttp/web_middlewares.py @@ -1,8 +1,8 @@ import re import warnings -from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, Type, TypeVar +from typing import Awaitable, Callable, Tuple, Type, TypeVar -from .typedefs import Handler +from .typedefs import Handler, _SafeApplication, _SafeRequest from .web_exceptions import HTTPMove, HTTPPermanentRedirect from .web_request import Request from .web_response import StreamResponse @@ -13,13 +13,12 @@ "normalize_path_middleware", ) -if TYPE_CHECKING: # pragma: no cover - from .web_app import Application - _Func = TypeVar("_Func") -async def _check_request_resolves(request: Request, path: str) -> Tuple[bool, Request]: +async def _check_request_resolves( + request: _SafeRequest, path: str +) -> Tuple[bool, _SafeRequest]: alt_request = request.clone(rel_url=path) match_info = await request.app.router.resolve(alt_request) @@ -85,7 +84,7 @@ def normalize_path_middleware( correct_configuration = not (append_slash and remove_slash) assert correct_configuration, "Cannot both remove and append slash" - async def impl(request: Request, handler: Handler) -> StreamResponse: + async def impl(request: _SafeRequest, handler: Handler) -> StreamResponse: if isinstance(request.match_info.route, SystemRoute): paths_to_check = [] if "?" in request.raw_path: @@ -118,8 +117,8 @@ async def impl(request: Request, handler: Handler) -> StreamResponse: return impl -def _fix_request_current_app(app: "Application") -> _Middleware: - async def impl(request: Request, handler: Handler) -> StreamResponse: +def _fix_request_current_app(app: "_SafeApplication") -> _Middleware: + async def impl(request: _SafeRequest, handler: Handler) -> StreamResponse: with request.match_info.set_current_app(app): return await handler(request) diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index d4d941d26e7..f8a396d908c 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -14,6 +14,7 @@ TYPE_CHECKING, Any, Dict, + Generic, Iterator, Mapping, MutableMapping, @@ -21,6 +22,7 @@ Pattern, Set, Tuple, + TypeVar, Union, cast, ) @@ -70,6 +72,8 @@ from .web_protocol import RequestHandler from .web_urldispatcher import UrlMappingMatchInfo +_T = TypeVar("_T", covariant=True) + @dataclasses.dataclass(frozen=True) class FileField: @@ -842,7 +846,7 @@ async def wait_for_disconnection(self) -> None: self._disconnection_waiters.remove(fut) -class Request(BaseRequest): +class Request(BaseRequest, Generic[_T]): __slots__ = ("_match_info",) @@ -864,7 +868,7 @@ def clone( scheme: Union[str, _SENTINEL] = sentinel, host: Union[str, _SENTINEL] = sentinel, remote: Union[str, _SENTINEL] = sentinel, - ) -> "Request": + ) -> "Request[_T]": ret = super().clone( method=method, rel_url=rel_url, @@ -873,7 +877,7 @@ def clone( host=host, remote=remote, ) - new_ret = cast(Request, ret) + new_ret = cast(Request[_T], ret) new_ret._match_info = self._match_info return new_ret @@ -885,19 +889,18 @@ def match_info(self) -> "UrlMappingMatchInfo": return match_info @property - def app(self) -> "Application": + def app(self) -> "Application[_T]": """Application instance.""" match_info = self._match_info assert match_info is not None - return match_info.current_app + return match_info.current_app # type: ignore[return-value] @property def config_dict(self) -> ChainMapProxy: match_info = self._match_info assert match_info is not None - lst = match_info.apps - app = self.app - idx = lst.index(app) + lst = [app.state for app in match_info.apps] + idx = lst.index(self.app.state) # type: ignore[arg-type] sublist = list(reversed(lst[: idx + 1])) return ChainMapProxy(sublist) diff --git a/aiohttp/web_routedef.py b/aiohttp/web_routedef.py index 787d9cbdeca..2d20b8a32d6 100644 --- a/aiohttp/web_routedef.py +++ b/aiohttp/web_routedef.py @@ -140,7 +140,7 @@ def delete(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef: return route(hdrs.METH_DELETE, path, handler, **kwargs) -def view(path: str, handler: Type[AbstractView], **kwargs: Any) -> RouteDef: +def view(path: str, handler: Type[AbstractView[Any]], **kwargs: Any) -> RouteDef: return route(hdrs.METH_ANY, path, handler, **kwargs) diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py index c5294ffe295..0c31fa05aa1 100644 --- a/aiohttp/web_runner.py +++ b/aiohttp/web_runner.py @@ -2,13 +2,14 @@ import signal import socket from abc import ABC, abstractmethod -from typing import Any, List, Optional, Set, Type +from typing import Any, List, Optional, Set, Type, TypeVar from yarl import URL from .abc import AbstractAccessLogger, AbstractStreamWriter from .http_parser import RawRequestMessage from .streams import StreamReader +from .typedefs import _SafeRequest from .web_app import Application from .web_log import AccessLogger from .web_protocol import RequestHandler @@ -33,6 +34,8 @@ "GracefulExit", ) +_T = TypeVar("_T") + class GracefulExit(SystemExit): code = 1 @@ -358,7 +361,7 @@ class AppRunner(BaseRunner): def __init__( self, - app: Application, + app: Application[_T], *, handle_signals: bool = False, access_log_class: Type[AbstractAccessLogger] = AccessLogger, @@ -388,7 +391,7 @@ def __init__( self._app = app @property - def app(self) -> Application: + def app(self) -> Application[_T]: return self._app async def shutdown(self) -> None: @@ -412,8 +415,8 @@ def _make_request( protocol: RequestHandler, writer: AbstractStreamWriter, task: "asyncio.Task[None]", - _cls: Type[Request] = Request, - ) -> Request: + _cls: Type[_SafeRequest] = Request, + ) -> _SafeRequest: loop = asyncio.get_running_loop() return _cls( message, diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index 3e9a2c22392..af35c5aaed1 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -25,6 +25,7 @@ Sized, Tuple, Type, + TypeVar, Union, cast, ) @@ -36,7 +37,7 @@ from .abc import AbstractMatchInfo, AbstractRouter, AbstractView from .helpers import DEBUG, iscoroutinefunction from .http import HttpVersion11 -from .typedefs import Handler, PathLike +from .typedefs import Handler, PathLike, _SafeApplication, _SafeRequest from .web_exceptions import ( HTTPException, HTTPExpectationFailed, @@ -81,6 +82,7 @@ PATH_SEP: Final[str] = re.escape("/") +_T = TypeVar("_T") _ExpectHandler = Callable[[Request], Awaitable[None]] _Resolve = Tuple[Optional[AbstractMatchInfo], Set[str]] @@ -95,7 +97,7 @@ class _InfoDict(TypedDict, total=False): prefix: str routes: Mapping[str, "AbstractRoute"] - app: "Application" + app: _SafeApplication domain: str @@ -126,7 +128,7 @@ def url_for(self, **kwargs: str) -> URL: """Construct url for resource with additional params.""" @abc.abstractmethod # pragma: no branch - async def resolve(self, request: Request) -> _Resolve: + async def resolve(self, request: _SafeRequest) -> _Resolve: """Resolve resource Return (UrlMappingMatchInfo, allowed_methods) pair.""" @@ -155,7 +157,7 @@ class AbstractRoute(abc.ABC): def __init__( self, method: str, - handler: Union[Handler, Type[AbstractView]], + handler: Union[Handler, Type[AbstractView[Any]]], *, expect_handler: Optional[_ExpectHandler] = None, resource: Optional[AbstractResource] = None, @@ -212,7 +214,7 @@ def get_info(self) -> _InfoDict: def url_for(self, *args: str, **kwargs: str) -> URL: """Construct url for route with additional params.""" - async def handle_expect_header(self, request: Request) -> None: + async def handle_expect_header(self, request: _SafeRequest) -> None: await self._expect_handler(request) @@ -220,8 +222,8 @@ class UrlMappingMatchInfo(BaseDict, AbstractMatchInfo): def __init__(self, match_dict: Dict[str, str], route: AbstractRoute): super().__init__(match_dict) self._route = route - self._apps = [] # type: List[Application] - self._current_app = None # type: Optional[Application] + self._apps: List[_SafeApplication] = [] + self._current_app: Optional[_SafeApplication] = None self._frozen = False @property @@ -244,10 +246,10 @@ def get_info(self) -> _InfoDict: # type: ignore[override] return self._route.get_info() @property - def apps(self) -> Tuple["Application", ...]: + def apps(self) -> Tuple[_SafeApplication, ...]: return tuple(self._apps) - def add_app(self, app: "Application") -> None: + def add_app(self, app: "Application[Any]") -> None: if self._frozen: raise RuntimeError("Cannot change apps stack after .freeze() call") if self._current_app is None: @@ -255,13 +257,13 @@ def add_app(self, app: "Application") -> None: self._apps.insert(0, app) @property - def current_app(self) -> "Application": + def current_app(self) -> _SafeApplication: app = self._current_app assert app is not None return app @contextmanager - def set_current_app(self, app: "Application") -> Generator[None, None, None]: + def set_current_app(self, app: _SafeApplication) -> Generator[None, None, None]: if DEBUG: # pragma: no cover if app not in self._apps: raise RuntimeError( @@ -298,7 +300,7 @@ def __repr__(self) -> str: ) -async def _default_expect_handler(request: Request) -> None: +async def _default_expect_handler(request: _SafeRequest) -> None: """Default handler for Expect header. Just send "100 Continue" to client. @@ -320,7 +322,7 @@ def __init__(self, *, name: Optional[str] = None) -> None: def add_route( self, method: str, - handler: Union[Type[AbstractView], Handler], + handler: Union[Type[AbstractView[Any]], Handler], *, expect_handler: Optional[_ExpectHandler] = None, ) -> "ResourceRoute": @@ -343,7 +345,7 @@ def register_route(self, route: "ResourceRoute") -> None: ), f"Instance of Route class is required, got {route!r}" self._routes.append(route) - async def resolve(self, request: Request) -> _Resolve: + async def resolve(self, request: _SafeRequest) -> _Resolve: allowed_methods = set() # type: Set[str] match_dict = self._match(request.rel_url.raw_path) @@ -612,7 +614,7 @@ def set_options_route(self, handler: Handler) -> None: "OPTIONS", handler, self, expect_handler=self._expect_handler ) - async def resolve(self, request: Request) -> _Resolve: + async def resolve(self, request: _SafeRequest) -> _Resolve: path = request.rel_url.raw_path method = request.method allowed_methods = set(self._routes) @@ -631,7 +633,7 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[AbstractRoute]: return iter(self._routes.values()) - async def _handle(self, request: Request) -> StreamResponse: + async def _handle(self, request: _SafeRequest) -> StreamResponse: rel_url = request.match_info["filename"] try: filename = Path(rel_url) @@ -713,7 +715,7 @@ def __repr__(self) -> str: class PrefixedSubAppResource(PrefixResource): - def __init__(self, prefix: str, app: "Application") -> None: + def __init__(self, prefix: str, app: _SafeApplication) -> None: super().__init__(prefix) self._app = app for resource in app.router.resources(): @@ -730,7 +732,7 @@ def url_for(self, *args: str, **kwargs: str) -> URL: def get_info(self) -> _InfoDict: return {"app": self._app, "prefix": self._prefix} - async def resolve(self, request: Request) -> _Resolve: + async def resolve(self, request: _SafeRequest) -> _Resolve: if ( not request.url.raw_path.startswith(self._prefix + "/") and request.url.raw_path != self._prefix @@ -758,7 +760,7 @@ def __repr__(self) -> str: class AbstractRuleMatching(abc.ABC): @abc.abstractmethod # pragma: no branch - async def match(self, request: Request) -> bool: + async def match(self, request: _SafeRequest) -> bool: """Return bool if the request satisfies the criteria""" @abc.abstractmethod # pragma: no branch @@ -798,7 +800,7 @@ def validation(self, domain: str) -> str: return url.raw_host return f"{url.raw_host}:{url.port}" - async def match(self, request: Request) -> bool: + async def match(self, request: _SafeRequest) -> bool: host = request.headers.get(hdrs.HOST) if not host: return False @@ -828,7 +830,7 @@ def match_domain(self, host: str) -> bool: class MatchedSubAppResource(PrefixedSubAppResource): - def __init__(self, rule: AbstractRuleMatching, app: "Application") -> None: + def __init__(self, rule: AbstractRuleMatching, app: _SafeApplication) -> None: AbstractResource.__init__(self) self._prefix = "" self._app = app @@ -841,7 +843,7 @@ def canonical(self) -> str: def get_info(self) -> _InfoDict: return {"app": self._app, "rule": self._rule} - async def resolve(self, request: Request) -> _Resolve: + async def resolve(self, request: _SafeRequest) -> _Resolve: if not await self._rule.match(request): return None, set() match_info = await self._app.router.resolve(request) @@ -862,7 +864,7 @@ class ResourceRoute(AbstractRoute): def __init__( self, method: str, - handler: Union[Handler, Type[AbstractView]], + handler: Union[Handler, Type[AbstractView[Any]]], resource: AbstractResource, *, expect_handler: Optional[_ExpectHandler] = None, @@ -907,7 +909,7 @@ def name(self) -> Optional[str]: def get_info(self) -> _InfoDict: return {"http_exception": self._http_exception} - async def _handle(self, request: Request) -> StreamResponse: + async def _handle(self, request: _SafeRequest) -> StreamResponse: raise self._http_exception @property @@ -922,7 +924,7 @@ def __repr__(self) -> str: return "".format(self=self) -class View(AbstractView): +class View(AbstractView[_T]): async def _iter(self) -> StreamResponse: if self.request.method not in hdrs.METH_ALL: self._raise_allowed_methods() @@ -982,7 +984,7 @@ def __init__(self) -> None: self._resources = [] # type: List[AbstractResource] self._named_resources = {} # type: Dict[str, AbstractResource] - async def resolve(self, request: Request) -> AbstractMatchInfo: + async def resolve(self, request: _SafeRequest) -> AbstractMatchInfo: method = request.method allowed_methods = set() # type: Set[str] @@ -1072,7 +1074,7 @@ def add_route( self, method: str, path: str, - handler: Union[Handler, Type[AbstractView]], + handler: Union[Handler, Type[AbstractView[Any]]], *, name: Optional[str] = None, expect_handler: Optional[_ExpectHandler] = None, @@ -1169,7 +1171,7 @@ def add_delete(self, path: str, handler: Handler, **kwargs: Any) -> AbstractRout return self.add_route(hdrs.METH_DELETE, path, handler, **kwargs) def add_view( - self, path: str, handler: Type[AbstractView], **kwargs: Any + self, path: str, handler: Type[AbstractView[Any]], **kwargs: Any ) -> AbstractRoute: """ Shortcut for add_route with ANY methods for a class-based view diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 40fd2fca728..92eba3caa30 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -575,12 +575,12 @@ as it results in more compact code:: app.cleanup_ctx.append(persistent_session) async def persistent_session(app): - app['PERSISTENT_SESSION'] = session = aiohttp.ClientSession() + app.state['PERSISTENT_SESSION'] = session = aiohttp.ClientSession() yield await session.close() async def my_request_handler(request): - session = request.app['PERSISTENT_SESSION'] + session = request.app.state['PERSISTENT_SESSION'] async with session.get("http://python.org") as resp: print(resp.status) @@ -593,9 +593,9 @@ can be safely shared between sessions if needed. In the end all you have to do is to close all sessions after `yield` statement:: async def multiple_sessions(app): - app['PERSISTENT_SESSION_1'] = session_1 = aiohttp.ClientSession() - app['PERSISTENT_SESSION_2'] = session_2 = aiohttp.ClientSession() - app['PERSISTENT_SESSION_3'] = session_3 = aiohttp.ClientSession() + app.state['PERSISTENT_SESSION_1'] = session_1 = aiohttp.ClientSession() + app.state['PERSISTENT_SESSION_2'] = session_2 = aiohttp.ClientSession() + app.state['PERSISTENT_SESSION_3'] = session_3 = aiohttp.ClientSession() yield diff --git a/docs/faq.rst b/docs/faq.rst index fbe7354cf49..98465033397 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -62,7 +62,7 @@ other resource you want to share between handlers. :: async def go(request): - db = request.app['db'] + db = request.app.state['db'] cursor = await db.cursor() await cursor.execute('SELECT 42') # ... @@ -72,7 +72,7 @@ other resource you want to share between handlers. async def init_app(): app = Application() db = await create_connection(user='user', password='123') - app['db'] = db + app.state['db'] = db app.router.add_get('/', go) return app @@ -123,7 +123,7 @@ peers. :: await ws.prepare(request) task = asyncio.create_task( read_subscription(ws, - request.app['redis'])) + request.app.state['redis'])) try: async for msg in ws: # handle incoming messages @@ -166,12 +166,12 @@ and call :meth:`aiohttp.web.WebSocketResponse.close` on all of them in ws = web.WebSocketResponse() user_id = authenticate_user(request) await ws.prepare(request) - request.app['websockets'][user_id].add(ws) + request.app.state['websockets'][user_id].add(ws) try: async for msg in ws: ws.send_str(msg.data) finally: - request.app['websockets'][user_id].remove(ws) + request.app.state['websockets'][user_id].remove(ws) return ws @@ -181,7 +181,7 @@ and call :meth:`aiohttp.web.WebSocketResponse.close` on all of them in user_id = authenticate_user(request) ws_closers = [ws.close() - for ws in request.app['websockets'][user_id] + for ws in request.app.state['websockets'][user_id] if not ws.closed] # Watch out, this will keep us from returing the response @@ -196,7 +196,7 @@ and call :meth:`aiohttp.web.WebSocketResponse.close` on all of them in app = web.Application() app.router.add_route('GET', '/echo', echo_handler) app.router.add_route('POST', '/logout', logout_handler) - app['websockets'] = defaultdict(set) + app.state['websockets'] = defaultdict(set) web.run_app(app, host='localhost', port=8080) @@ -278,7 +278,7 @@ deliberate choice. A subapplication is an isolated unit by design. If you need to share a database object, do it explicitly:: - subapp['db'] = mainapp['db'] + subapp.state['db'] = mainapp.state['db'] mainapp.add_subapp('/prefix', subapp) @@ -300,7 +300,7 @@ operations:: await resp.write_eof() # increase the pong count - APP['db'].inc_pong() + APP.state['db'].inc_pong() return resp diff --git a/docs/logging.rst b/docs/logging.rst index f27a77baa8a..82a2f5e2e9d 100644 --- a/docs/logging.rst +++ b/docs/logging.rst @@ -134,7 +134,7 @@ If your logging needs to perform IO you can instead inherit from class AccessLogger(AbstractAsyncAccessLogger): async def log(self, request, response, time): - logging_service = request.app['logging_service'] + logging_service = request.app.state['logging_service'] await logging_service.log(f'{request.remote} ' f'"{request.method} {request.path} ' f'done in {time}s: {response.status}') diff --git a/docs/testing.rst b/docs/testing.rst index 59c1cbe439d..49e928644ac 100644 --- a/docs/testing.rst +++ b/docs/testing.rst @@ -103,10 +103,10 @@ app test client:: async def previous(request): if request.method == 'POST': - request.app['value'] = (await request.post())['value'] + request.app.state['value'] = (await request.post())['value'] return web.Response(body=b'thanks for the data') return web.Response( - body='value: {}'.format(request.app['value']).encode('utf-8')) + body='value: {}'.format(request.app.state['value']).encode('utf-8')) @pytest.fixture def cli(loop, aiohttp_client): @@ -119,10 +119,10 @@ app test client:: resp = await cli.post('/', data={'value': 'foo'}) assert resp.status == 200 assert await resp.text() == 'thanks for the data' - assert cli.server.app['value'] == 'foo' + assert cli.server.app.state['value'] == 'foo' async def test_get_value(cli): - cli.server.app['value'] = 'bar' + cli.server.app.state['value'] = 'bar' resp = await cli.get('/') assert resp.status == 200 assert await resp.text() == 'value: bar' diff --git a/docs/web_advanced.rst b/docs/web_advanced.rst index 0476b5f1faf..3367c2dd9a1 100644 --- a/docs/web_advanced.rst +++ b/docs/web_advanced.rst @@ -250,14 +250,14 @@ Application's config ^^^^^^^^^^^^^^^^^^^^ For storing *global-like* variables, feel free to save them in an -:class:`Application` instance:: +:attr:`Application.state` instance:: - app['my_private_key'] = data + app.state["my_private_key"] = data and get it back in the :term:`web-handler`:: async def handler(request): - data = request.app['my_private_key'] + data = request.app.state["my_private_key"] In case of :ref:`nested applications ` the desired lookup strategy could @@ -269,8 +269,28 @@ be the following: For this please use :attr:`Request.config_dict` read-only property:: async def handler(request): - data = request.config_dict['my_private_key'] + data = request.config_dict["my_private_key"] +Type Annotations +"""""""""""""""" + +To utilize type checking, you should create a :class:`TypedDict` to represent +your config and use this for the application and request generics:: + + class MyState(TypedDict): + my_private_key: str + + Request = web.Request[MyState] + + app: web.Application[MyState] = web.Application() + app.state["my_private_key"] = data + + def handler(request: Request): + data = request.app.state["my_private_key"] + +.. note:: + + :attr:`Request.config_dict` does not support static typing. Request's storage ^^^^^^^^^^^^^^^^^ @@ -556,7 +576,7 @@ engine:: from aiopg.sa import create_engine async def create_aiopg(app): - app['pg_engine'] = await create_engine( + app.state['pg_engine'] = await create_engine( user='postgre', database='postgre', host='localhost', @@ -565,8 +585,8 @@ engine:: ) async def dispose_aiopg(app): - app['pg_engine'].close() - await app['pg_engine'].wait_closed() + app.state['pg_engine'].close() + await app.state['pg_engine'].wait_closed() app.on_startup.append(create_aiopg) app.on_cleanup.append(dispose_aiopg) @@ -598,7 +618,7 @@ knowledge about startup/cleanup pairs and their execution state. The solution is :attr:`Application.cleanup_ctx` usage:: async def pg_engine(app): - app['pg_engine'] = await create_engine( + app.state['pg_engine'] = await create_engine( user='postgre', database='postgre', host='localhost', @@ -606,8 +626,8 @@ The solution is :attr:`Application.cleanup_ctx` usage:: password='' ) yield - app['pg_engine'].close() - await app['pg_engine'].wait_closed() + app.state['pg_engine'].close() + await app.state['pg_engine'].wait_closed() app.cleanup_ctx.append(pg_engine) @@ -685,10 +705,10 @@ use the following explicit technique:: admin.add_routes([web.get('/resource', handler, name='name')]) app.add_subapp('/admin/', admin) - app['admin'] = admin + app.state['admin'] = admin async def handler(request): # main application's handler - admin = request.app['admin'] + admin = request.app.state['admin'] url = admin.router['name'].url_for() .. _aiohttp-web-expect-header: @@ -805,18 +825,18 @@ handler:: import weakref app = web.Application() - app['websockets'] = weakref.WeakSet() + app.state['websockets'] = weakref.WeakSet() async def websocket_handler(request): ws = web.WebSocketResponse() await ws.prepare(request) - request.app['websockets'].add(ws) + request.app.state['websockets'].add(ws) try: async for msg in ws: ... finally: - request.app['websockets'].discard(ws) + request.app.state['websockets'].discard(ws) return ws @@ -825,7 +845,7 @@ Signal handler may look like:: from aiohttp import WSCloseCode async def on_shutdown(app): - for ws in set(app['websockets']): + for ws in set(app.state['websockets']): await ws.close(code=WSCloseCode.GOING_AWAY, message='Server shutdown') @@ -869,7 +889,7 @@ signal handlers as shown in the example below:: ch, *_ = await sub.subscribe('news') async for msg in ch.iter(encoding='utf-8'): # Forward message to all connected websockets: - for ws in app['websockets']: + for ws in app.state['websockets']: ws.send_str('{}: {}'.format(ch.name, msg)) except asyncio.CancelledError: pass @@ -879,12 +899,12 @@ signal handlers as shown in the example below:: async def start_background_tasks(app): - app['redis_listener'] = asyncio.create_task(listen_to_redis(app)) + app.state['redis_listener'] = asyncio.create_task(listen_to_redis(app)) async def cleanup_background_tasks(app): - app['redis_listener'].cancel() - await app['redis_listener'] + app.state['redis_listener'].cancel() + await app.state['redis_listener'] app = web.Application() diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 549e199507c..83e93438325 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -1310,20 +1310,29 @@ or :class:`aiohttp.web.AppRunner`. *Application* contains a *router* instance and a list of callbacks that will be called during application finishing. -:class:`Application` is a :obj:`dict`-like object, so you can use it for +:attr:`Application.state` is a :obj:`dict`-like object, so you can use it for :ref:`sharing data` globally by storing arbitrary properties for later access from a :ref:`handler` via the :attr:`Request.app` property:: app = Application() - app['database'] = await aiopg.create_engine(**db_config) + app.state['database'] = await aiopg.create_engine(**db_config) async def handler(request): - with (await request.app['database']) as conn: + with (await request.app.state['database']) as conn: conn.execute("DELETE * FROM table") -Although :class:`Application` is a :obj:`dict`-like object, it can't be -duplicated like one using :meth:`~aiohttp.web.Application.copy`. +Both :class:`Application` and :class:`Request` are generics and can be typed +with information about the :attr:`Application.state` dict:: + + class MyState(TypedDict): + database: DBType + + app: Application[MyState] = Application() + app.state['database'] = ... + + async def handler(request: Request[MyState]): + request.app.state['database'] .. class:: Application(*, logger=, middlewares=(), \ handler_args=None, client_max_size=1024**2, \ @@ -1354,6 +1363,12 @@ duplicated like one using :meth:`~aiohttp.web.Application.copy`. use asyncio :ref:`asyncio-debug-mode` instead. + .. attribute:: state + + A dict that can be used to store *global-like* variables. + The class is also generic over this variable, so you can define the + type with `Application[MyState]`, where `MyState` is a :class:`TypedDict`. + .. attribute:: router Read-only property that returns *router instance*. diff --git a/examples/background_tasks.py b/examples/background_tasks.py index dab7756ab86..787758d1381 100755 --- a/examples/background_tasks.py +++ b/examples/background_tasks.py @@ -1,38 +1,44 @@ #!/usr/bin/env python3 """Example of aiohttp.web.Application.on_startup signal handler""" import asyncio +from typing import List, TypedDict import aioredis # type: ignore from aiohttp import web -async def websocket_handler(request: web.Request) -> web.StreamResponse: +class StateDict(TypedDict): + redis_listener: asyncio.Task[None] + websockets: List[web.WebSocketResponse] + + +async def websocket_handler(request: web.Request[StateDict]) -> web.StreamResponse: ws = web.WebSocketResponse() await ws.prepare(request) - request.app["websockets"].append(ws) + request.app.state["websockets"].append(ws) try: async for msg in ws: print(msg) await asyncio.sleep(1) finally: - request.app["websockets"].remove(ws) + request.app.state["websockets"].remove(ws) return ws -async def on_shutdown(app: web.Application) -> None: - for ws in app["websockets"]: - await ws.close(code=999, message="Server shutdown") +async def on_shutdown(app: web.Application[StateDict]) -> None: + for ws in app.state["websockets"]: + await ws.close(code=999, message=b"Server shutdown") -async def listen_to_redis(app: web.Application) -> None: +async def listen_to_redis(app: web.Application[StateDict]) -> None: try: loop = asyncio.get_event_loop() sub = await aioredis.create_redis(("localhost", 6379), loop=loop) ch, *_ = await sub.subscribe("news") async for msg in ch.iter(encoding="utf-8"): # Forward message to all connected websockets: - for ws in app["websockets"]: + for ws in app.state["websockets"]: await ws.send_str(f"{ch.name}: {msg}") print(f"message in {ch.name}: {msg}") except asyncio.CancelledError: @@ -44,19 +50,19 @@ async def listen_to_redis(app: web.Application) -> None: print("Redis connection closed.") -async def start_background_tasks(app: web.Application) -> None: - app["redis_listener"] = asyncio.create_task(listen_to_redis(app)) +async def start_background_tasks(app: web.Application[StateDict]) -> None: + app.state["redis_listener"] = asyncio.create_task(listen_to_redis(app)) -async def cleanup_background_tasks(app: web.Application) -> None: +async def cleanup_background_tasks(app: web.Application[StateDict]) -> None: print("cleanup background tasks...") - app["redis_listener"].cancel() - await app["redis_listener"] + app.state["redis_listener"].cancel() + await app.state["redis_listener"] -def init() -> web.Application: - app = web.Application() - app["websockets"] = [] +def init() -> web.Application[StateDict]: + app: web.Application[StateDict] = web.Application() + app.state["websockets"] = [] app.router.add_get("/news", websocket_handler) app.on_startup.append(start_background_tasks) app.on_cleanup.append(cleanup_background_tasks) diff --git a/examples/cli_app.py b/examples/cli_app.py index 5357a0233f4..b5df6f3f3d3 100755 --- a/examples/cli_app.py +++ b/examples/cli_app.py @@ -13,19 +13,23 @@ arguments to the `cli_app:init` function for processing. """ -from argparse import ArgumentParser -from typing import Optional, Sequence +from argparse import ArgumentParser, Namespace +from typing import Optional, Sequence, TypedDict from aiohttp import web -async def display_message(req: web.Request) -> web.StreamResponse: - args = req.app["args"] +class StateDict(TypedDict): + args: Namespace + + +async def display_message(req: web.Request[StateDict]) -> web.StreamResponse: + args = req.app.state["args"] text = "\n".join([args.message] * args.repeat) return web.Response(text=text) -def init(argv: Optional[Sequence[str]]) -> web.Application: +def init(argv: Optional[Sequence[str]]) -> web.Application[StateDict]: arg_parser = ArgumentParser( prog="aiohttp.web ...", description="Application CLI", add_help=False ) @@ -45,8 +49,8 @@ def init(argv: Optional[Sequence[str]]) -> web.Application: args = arg_parser.parse_args(argv) - app = web.Application() - app["args"] = args + app: web.Application[StateDict] = web.Application() + app.state["args"] = args app.router.add_get("/", display_message) return app diff --git a/examples/fake_server.py b/examples/fake_server.py index 065d2d779eb..fb0ab10e402 100755 --- a/examples/fake_server.py +++ b/examples/fake_server.py @@ -3,12 +3,16 @@ import pathlib import socket import ssl -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, TypedDict, Union from aiohttp import ClientSession, TCPConnector, resolver, test_utils, web from aiohttp.abc import AbstractResolver +class EmptyDict(TypedDict): + pass + + class FakeResolver(AbstractResolver): _LOCAL_HOST = {0: "127.0.0.1", socket.AF_INET: "127.0.0.1", socket.AF_INET6: "::1"} @@ -44,7 +48,7 @@ async def close(self) -> None: class FakeFacebook: def __init__(self) -> None: - self.app = web.Application() + self.app: web.Application[EmptyDict] = web.Application() self.app.router.add_routes( [ web.get("/v2.7/me", self.on_me), @@ -68,10 +72,12 @@ async def start(self) -> Dict[str, int]: async def stop(self) -> None: await self.runner.cleanup() - async def on_me(self, request: web.Request) -> web.StreamResponse: + async def on_me(self, request: web.Request[EmptyDict]) -> web.StreamResponse: return web.json_response({"name": "John Doe", "id": "12345678901234567"}) - async def on_my_friends(self, request: web.Request) -> web.StreamResponse: + async def on_my_friends( + self, request: web.Request[EmptyDict] + ) -> web.StreamResponse: return web.json_response( { "data": [ diff --git a/examples/server_simple.py b/examples/server_simple.py index c68e141b33e..82fc1580590 100644 --- a/examples/server_simple.py +++ b/examples/server_simple.py @@ -1,14 +1,19 @@ -# server_simple.py +from typing import TypedDict + from aiohttp import web -async def handle(request: web.Request) -> web.StreamResponse: +class EmptyDict(TypedDict): + pass + + +async def handle(request: web.Request[EmptyDict]) -> web.StreamResponse: name = request.match_info.get("name", "Anonymous") text = "Hello, " + name return web.Response(text=text) -async def wshandle(request: web.Request) -> web.StreamResponse: +async def wshandle(request: web.Request[EmptyDict]) -> web.StreamResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -23,7 +28,7 @@ async def wshandle(request: web.Request) -> web.StreamResponse: return ws -app = web.Application() +app: web.Application[EmptyDict] = web.Application() app.add_routes( [web.get("/", handle), web.get("/echo", wshandle), web.get("/{name}", handle)] ) diff --git a/examples/static_files.py b/examples/static_files.py index 65f6bb9c764..347ded3aa5f 100755 --- a/examples/static_files.py +++ b/examples/static_files.py @@ -1,9 +1,15 @@ #!/usr/bin/env python3 import pathlib +from typing import TypedDict from aiohttp import web -app = web.Application() + +class EmptyDict(TypedDict): + pass + + +app: web.Application[EmptyDict] = web.Application() app.router.add_static("/", pathlib.Path(__file__).parent, show_index=True) web.run_app(app) diff --git a/examples/web_classview.py b/examples/web_classview.py index fc3fe67b851..9db5a035e83 100755 --- a/examples/web_classview.py +++ b/examples/web_classview.py @@ -4,10 +4,15 @@ import functools import json +from typing import TypedDict from aiohttp import web +class EmptyDict(TypedDict): + pass + + class MyView(web.View): async def get(self) -> web.StreamResponse: return web.json_response( @@ -31,7 +36,7 @@ async def post(self) -> web.StreamResponse: ) -async def index(request: web.Request) -> web.StreamResponse: +async def index(request: web.Request[EmptyDict]) -> web.StreamResponse: txt = """ @@ -50,8 +55,8 @@ async def index(request: web.Request) -> web.StreamResponse: return web.Response(text=txt, content_type="text/html") -def init() -> web.Application: - app = web.Application() +def init() -> web.Application[EmptyDict]: + app: web.Application[EmptyDict] = web.Application() app.router.add_get("/", index) app.router.add_get("/get", MyView) app.router.add_post("/post", MyView) diff --git a/examples/web_cookies.py b/examples/web_cookies.py index 6836569183f..4852e86d64d 100755 --- a/examples/web_cookies.py +++ b/examples/web_cookies.py @@ -3,10 +3,15 @@ """ from pprint import pformat -from typing import NoReturn +from typing import NoReturn, TypedDict from aiohttp import web + +class EmptyDict(TypedDict): + pass + + tmpl = """\ @@ -17,26 +22,26 @@ """ -async def root(request: web.Request) -> web.StreamResponse: +async def root(request: web.Request[EmptyDict]) -> web.StreamResponse: resp = web.Response(content_type="text/html") resp.text = tmpl.format(pformat(request.cookies)) return resp -async def login(request: web.Request) -> NoReturn: +async def login(request: web.Request[EmptyDict]) -> NoReturn: exc = web.HTTPFound(location="/") exc.set_cookie("AUTH", "secret") raise exc -async def logout(request: web.Request) -> NoReturn: +async def logout(request: web.Request[EmptyDict]) -> NoReturn: exc = web.HTTPFound(location="/") exc.del_cookie("AUTH") raise exc -def init() -> web.Application: - app = web.Application() +def init() -> web.Application[EmptyDict]: + app: web.Application[EmptyDict] = web.Application() app.router.add_get("/", root) app.router.add_get("/login", login) app.router.add_get("/logout", logout) diff --git a/examples/web_rewrite_headers_middleware.py b/examples/web_rewrite_headers_middleware.py index 149dc28285d..97936182676 100755 --- a/examples/web_rewrite_headers_middleware.py +++ b/examples/web_rewrite_headers_middleware.py @@ -2,15 +2,24 @@ """ Example for rewriting response headers by middleware. """ + +from typing import TypedDict + from aiohttp import web from aiohttp.typedefs import Handler -async def handler(request: web.Request) -> web.StreamResponse: +class EmptyDict(TypedDict): + pass + + +async def handler(request: web.Request[EmptyDict]) -> web.StreamResponse: return web.Response(text="Everything is fine") -async def middleware(request: web.Request, handler: Handler) -> web.StreamResponse: +async def middleware( + request: web.Request[EmptyDict], handler: Handler +) -> web.StreamResponse: try: response = await handler(request) except web.HTTPException as exc: @@ -20,8 +29,8 @@ async def middleware(request: web.Request, handler: Handler) -> web.StreamRespon return response -def init() -> web.Application: - app = web.Application(middlewares=[middleware]) +def init() -> web.Application[EmptyDict]: + app: web.Application[EmptyDict] = web.Application(middlewares=[middleware]) app.router.add_get("/", handler) return app diff --git a/examples/web_srv.py b/examples/web_srv.py index b87f6c43baf..7f40fc59070 100755 --- a/examples/web_srv.py +++ b/examples/web_srv.py @@ -3,11 +3,16 @@ """ import textwrap +from typing import TypedDict from aiohttp import web -async def intro(request: web.Request) -> web.StreamResponse: +class EmptyDict(TypedDict): + pass + + +async def intro(request: web.Request[EmptyDict]) -> web.StreamResponse: txt = textwrap.dedent( """\ Type {url}/hello/John {url}/simple or {url}/change_body @@ -23,18 +28,18 @@ async def intro(request: web.Request) -> web.StreamResponse: return resp -async def simple(request: web.Request) -> web.StreamResponse: +async def simple(request: web.Request[EmptyDict]) -> web.StreamResponse: return web.Response(text="Simple answer") -async def change_body(request: web.Request) -> web.StreamResponse: +async def change_body(request: web.Request[EmptyDict]) -> web.StreamResponse: resp = web.Response() resp.body = b"Body changed" resp.content_type = "text/plain" return resp -async def hello(request: web.Request) -> web.StreamResponse: +async def hello(request: web.Request[EmptyDict]) -> web.StreamResponse: resp = web.StreamResponse() name = request.match_info.get("name", "Anonymous") answer = ("Hello, " + name).encode("utf8") @@ -46,8 +51,8 @@ async def hello(request: web.Request) -> web.StreamResponse: return resp -def init() -> web.Application: - app = web.Application() +def init() -> web.Application[EmptyDict]: + app: web.Application[EmptyDict] = web.Application() app.router.add_get("/", intro) app.router.add_get("/simple", simple) app.router.add_get("/change_body", change_body) diff --git a/examples/web_srv_route_deco.py b/examples/web_srv_route_deco.py index 65a4f8618b2..29ebec1a2e2 100644 --- a/examples/web_srv_route_deco.py +++ b/examples/web_srv_route_deco.py @@ -4,14 +4,20 @@ """ import textwrap +from typing import TypedDict from aiohttp import web + +class EmptyDict(TypedDict): + pass + + routes = web.RouteTableDef() @routes.get("/") -async def intro(request: web.Request) -> web.StreamResponse: +async def intro(request: web.Request[EmptyDict]) -> web.StreamResponse: txt = textwrap.dedent( """\ Type {url}/hello/John {url}/simple or {url}/change_body @@ -28,12 +34,12 @@ async def intro(request: web.Request) -> web.StreamResponse: @routes.get("/simple") -async def simple(request: web.Request) -> web.StreamResponse: +async def simple(request: web.Request[EmptyDict]) -> web.StreamResponse: return web.Response(text="Simple answer") @routes.get("/change_body") -async def change_body(request: web.Request) -> web.StreamResponse: +async def change_body(request: web.Request[EmptyDict]) -> web.StreamResponse: resp = web.Response() resp.body = b"Body changed" resp.content_type = "text/plain" @@ -41,7 +47,7 @@ async def change_body(request: web.Request) -> web.StreamResponse: @routes.get("/hello") -async def hello(request: web.Request) -> web.StreamResponse: +async def hello(request: web.Request[EmptyDict]) -> web.StreamResponse: resp = web.StreamResponse() name = request.match_info.get("name", "Anonymous") answer = ("Hello, " + name).encode("utf8") @@ -53,8 +59,8 @@ async def hello(request: web.Request) -> web.StreamResponse: return resp -def init() -> web.Application: - app = web.Application() +def init() -> web.Application[EmptyDict]: + app: web.Application[EmptyDict] = web.Application() app.router.add_routes(routes) return app diff --git a/examples/web_srv_route_table.py b/examples/web_srv_route_table.py index 4d1acc43c57..755a2a7ddd1 100644 --- a/examples/web_srv_route_table.py +++ b/examples/web_srv_route_table.py @@ -4,11 +4,16 @@ """ import textwrap +from typing import TypedDict from aiohttp import web -async def intro(request: web.Request) -> web.StreamResponse: +class EmptyDict(TypedDict): + pass + + +async def intro(request: web.Request[EmptyDict]) -> web.StreamResponse: txt = textwrap.dedent( """\ Type {url}/hello/John {url}/simple or {url}/change_body @@ -24,18 +29,18 @@ async def intro(request: web.Request) -> web.StreamResponse: return resp -async def simple(request: web.Request) -> web.StreamResponse: +async def simple(request: web.Request[EmptyDict]) -> web.StreamResponse: return web.Response(text="Simple answer") -async def change_body(request: web.Request) -> web.StreamResponse: +async def change_body(request: web.Request[EmptyDict]) -> web.StreamResponse: resp = web.Response() resp.body = b"Body changed" resp.content_type = "text/plain" return resp -async def hello(request: web.Request) -> web.StreamResponse: +async def hello(request: web.Request[EmptyDict]) -> web.StreamResponse: resp = web.StreamResponse() name = request.match_info.get("name", "Anonymous") answer = ("Hello, " + name).encode("utf8") @@ -47,8 +52,8 @@ async def hello(request: web.Request) -> web.StreamResponse: return resp -def init() -> web.Application: - app = web.Application() +def init() -> web.Application[EmptyDict]: + app: web.Application[EmptyDict] = web.Application() app.router.add_routes( [ web.get("/", intro), diff --git a/examples/web_ws.py b/examples/web_ws.py index 24610f09bfc..2052703eec4 100755 --- a/examples/web_ws.py +++ b/examples/web_ws.py @@ -2,15 +2,25 @@ """Example for aiohttp.web websocket server """ +# The extra strict mypy settings are here to help test that `Application[T]` syntax +# is working correctly. A regression will cause mypy to raise an error. +# mypy: disallow-any-expr, disallow-any-unimported, disallow-subclassing-any + import os -from typing import Union +from typing import List, TypedDict, Union, cast from aiohttp import web WS_FILE = os.path.join(os.path.dirname(__file__), "websocket.html") -async def wshandler(request: web.Request) -> Union[web.WebSocketResponse, web.Response]: +class StateDict(TypedDict): + sockets: List[web.WebSocketResponse] + + +async def wshandler( + request: web.Request[StateDict], +) -> Union[web.WebSocketResponse, web.Response]: resp = web.WebSocketResponse() available = resp.can_prepare(request) if not available: @@ -23,34 +33,34 @@ async def wshandler(request: web.Request) -> Union[web.WebSocketResponse, web.Re try: print("Someone joined.") - for ws in request.app["sockets"]: + for ws in request.app.state["sockets"]: await ws.send_str("Someone joined") - request.app["sockets"].append(resp) + request.app.state["sockets"].append(resp) - async for msg in resp: - if msg.type == web.WSMsgType.TEXT: - for ws in request.app["sockets"]: + async for msg in resp: # type: ignore[misc] + if msg.type == web.WSMsgType.TEXT: # type: ignore[misc] + for ws in request.app.state["sockets"]: if ws is not resp: - await ws.send_str(msg.data) + await ws.send_str(cast(str, msg.data)) # type: ignore[misc] else: return resp return resp finally: - request.app["sockets"].remove(resp) + request.app.state["sockets"].remove(resp) print("Someone disconnected.") - for ws in request.app["sockets"]: + for ws in request.app.state["sockets"]: await ws.send_str("Someone disconnected.") -async def on_shutdown(app: web.Application) -> None: - for ws in app["sockets"]: +async def on_shutdown(app: web.Application[StateDict]) -> None: + for ws in app.state["sockets"]: await ws.close() -def init() -> web.Application: - app = web.Application() - app["sockets"] = [] +def init() -> web.Application[StateDict]: + app: web.Application[StateDict] = web.Application() + app.state["sockets"] = [] app.router.add_get("/", wshandler) app.on_shutdown.append(on_shutdown) return app diff --git a/tests/autobahn/server/server.py b/tests/autobahn/server/server.py index d4ca04b1d5f..f2dc2c449ec 100644 --- a/tests/autobahn/server/server.py +++ b/tests/autobahn/server/server.py @@ -1,11 +1,22 @@ #!/usr/bin/env python3 import logging +import sys +from typing import List from aiohttp import WSCloseCode, web +if sys.version_info >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict -async def wshandler(request: web.Request) -> web.WebSocketResponse: + +class StateDict(TypedDict): + websockets: List[web.WebSocketResponse] + + +async def wshandler(request: web.Request[StateDict]) -> web.WebSocketResponse: ws = web.WebSocketResponse(autoclose=False) is_ws = ws.can_prepare(request) if not is_ws: @@ -29,9 +40,9 @@ async def wshandler(request: web.Request) -> web.WebSocketResponse: return ws -async def on_shutdown(app: web.Application) -> None: - for ws in set(app["websockets"]): - await ws.close(code=WSCloseCode.GOING_AWAY, message="Server shutdown") +async def on_shutdown(app: web.Application[StateDict]) -> None: + for ws in set(app.state["websockets"]): + await ws.close(code=WSCloseCode.GOING_AWAY, message=b"Server shutdown") if __name__ == "__main__": @@ -39,7 +50,7 @@ async def on_shutdown(app: web.Application) -> None: level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s" ) - app = web.Application() + app: web.Application[StateDict] = web.Application() app.router.add_route("GET", "/", wshandler) app.on_shutdown.append(on_shutdown) try: diff --git a/tests/test_loop.py b/tests/test_loop.py index 50fe5a8ad69..2f4782811a1 100644 --- a/tests/test_loop.py +++ b/tests/test_loop.py @@ -26,8 +26,8 @@ async def test_subprocess_co(loop: Any) -> None: class TestCase(AioHTTPTestCase): on_startup_called: bool - async def get_application(self) -> web.Application: - app = web.Application() + async def get_application(self) -> web.Application[Any]: + app: web.Application[Any] = web.Application() app.on_startup.append(self.on_startup_hook) return app diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 2bff3e39a9c..170e15c7568 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -67,7 +67,7 @@ async def test_noop() -> None: async def previous(request): if request.method == 'POST': with pytest.warns(DeprecationWarning): - request.app['value'] = (await request.post())['value'] + request.app.state['value'] = (await request.post())['value'] return web.Response(body=b'thanks for the data') else: v = request.app.get('value', 'unknown') diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index 2a540f104e3..49955c09a14 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -191,8 +191,14 @@ def test_make_mocked_request_app() -> None: def test_make_mocked_request_app_can_store_values() -> None: req = make_mocked_request("GET", "/") - req.app["a_field"] = "a_value" - assert req.app["a_field"] == "a_value" + req.app.state["a_field"] = "a_value" + assert req.app.state["a_field"] == "a_value" + + +def test_make_mocked_request_app_access_non_existing() -> None: + req = make_mocked_request("GET", "/") + with pytest.raises(AttributeError): + req.app.foo def test_make_mocked_request_match_info() -> None: diff --git a/tests/test_web_app.py b/tests/test_web_app.py index 12c40293793..6d145d1c383 100644 --- a/tests/test_web_app.py +++ b/tests/test_web_app.py @@ -1,6 +1,6 @@ -# type: ignore import asyncio -from typing import Any +import sys +from typing import Any, AsyncIterator, Callable, Dict, NoReturn from unittest import mock import pytest @@ -9,19 +9,31 @@ from aiohttp.test_utils import make_mocked_coro from aiohttp.typedefs import Handler +if sys.version_info >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict + + +class EmptyDict(TypedDict): + pass + + +_EmptyApplication = web.Application[EmptyDict] + async def test_app_ctor() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() assert app.logger is log.web_logger def test_app_call() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() assert app is app() async def test_app_register_on_finish() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() cb1 = make_mocked_coro(None) cb2 = make_mocked_coro(None) app.on_cleanup.append(cb1) @@ -33,10 +45,10 @@ async def test_app_register_on_finish() -> None: async def test_app_register_coro() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() fut = asyncio.get_event_loop().create_future() - async def cb(app): + async def cb(app: _EmptyApplication) -> None: await asyncio.sleep(0.001) fut.set_result(123) @@ -49,16 +61,16 @@ async def cb(app): def test_logging() -> None: logger = mock.Mock() - app = web.Application() + app: _EmptyApplication = web.Application() app.logger = logger assert app.logger is logger async def test_on_shutdown() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() called = False - async def on_shutdown(app_param): + async def on_shutdown(app_param: _EmptyApplication) -> None: nonlocal called assert app is app_param called = True @@ -70,27 +82,27 @@ async def on_shutdown(app_param): async def test_on_startup() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() long_running1_called = False long_running2_called = False all_long_running_called = False - async def long_running1(app_param): + async def long_running1(app_param: _EmptyApplication) -> None: nonlocal long_running1_called assert app is app_param long_running1_called = True - async def long_running2(app_param): + async def long_running2(app_param: _EmptyApplication) -> None: nonlocal long_running2_called assert app is app_param long_running2_called = True - async def on_startup_all_long_running(app_param): + async def on_startup_all_long_running(app_param: _EmptyApplication) -> None: nonlocal all_long_running_called assert app is app_param all_long_running_called = True - return await asyncio.gather(long_running1(app_param), long_running2(app_param)) + await asyncio.gather(long_running1(app_param), long_running2(app_param)) app.on_startup.append(on_startup_all_long_running) app.freeze() @@ -102,15 +114,15 @@ async def on_startup_all_long_running(app_param): def test_app_delitem() -> None: - app = web.Application() - app["key"] = "value" - assert len(app) == 1 - del app["key"] - assert len(app) == 0 + app: web.Application[Dict[str, str]] = web.Application() + app.state["key"] = "value" + assert len(app.state) == 1 + del app.state["key"] + assert len(app.state) == 0 def test_app_freeze() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() subapp = mock.Mock() subapp._middlewares = () app._subapps.append(subapp) @@ -123,8 +135,8 @@ def test_app_freeze() -> None: def test_equality() -> None: - app1 = web.Application() - app2 = web.Application() + app1: _EmptyApplication = web.Application() + app2: _EmptyApplication = web.Application() assert app1 == app1 assert app1 != app2 @@ -132,13 +144,15 @@ def test_equality() -> None: def test_app_run_middlewares() -> None: - root = web.Application() - sub = web.Application() + root: _EmptyApplication = web.Application() + sub: _EmptyApplication = web.Application() root.add_subapp("/sub", sub) root.freeze() assert root._run_middlewares is False - async def middleware(request, handler: Handler): + async def middleware( + request: web.Request[EmptyDict], handler: Handler + ) -> web.StreamResponse: return await handler(request) root = web.Application(middlewares=[middleware]) @@ -155,8 +169,8 @@ async def middleware(request, handler: Handler): def test_subapp_pre_frozen_after_adding() -> None: - app = web.Application() - subapp = web.Application() + app: _EmptyApplication = web.Application() + subapp: _EmptyApplication = web.Application() app.add_subapp("/prefix", subapp) assert subapp.pre_frozen @@ -166,22 +180,22 @@ def test_subapp_pre_frozen_after_adding() -> None: def test_app_inheritance() -> None: with pytest.raises(TypeError): - class A(web.Application): + class A(web.Application[Any]): # type: ignore[misc] pass def test_app_custom_attr() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() with pytest.raises(AttributeError): - app.custom = None + app.custom = None # type: ignore[attr-defined] async def test_cleanup_ctx() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() out = [] - def f(num): - async def inner(app): + def f(num: int) -> Callable[[_EmptyApplication], AsyncIterator[None]]: + async def inner(app: _EmptyApplication) -> AsyncIterator[None]: out.append("pre_" + str(num)) yield None out.append("post_" + str(num)) @@ -198,13 +212,15 @@ async def inner(app): async def test_cleanup_ctx_exception_on_startup() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() out = [] exc = Exception("fail") - def f(num, fail=False): - async def inner(app): + def f( + num: int, fail: bool = False + ) -> Callable[[_EmptyApplication], AsyncIterator[None]]: + async def inner(app: _EmptyApplication) -> AsyncIterator[None]: out.append("pre_" + str(num)) if fail: raise exc @@ -226,13 +242,15 @@ async def inner(app): async def test_cleanup_ctx_exception_on_cleanup() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() out = [] exc = Exception("fail") - def f(num, fail=False): - async def inner(app): + def f( + num: int, fail: bool = False + ) -> Callable[[_EmptyApplication], AsyncIterator[None]]: + async def inner(app: _EmptyApplication) -> AsyncIterator[None]: out.append("pre_" + str(num)) yield None out.append("post_" + str(num)) @@ -254,16 +272,16 @@ async def inner(app): async def test_cleanup_ctx_cleanup_after_exception() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() ctx_state = None - async def success_ctx(app): + async def success_ctx(app: _EmptyApplication) -> AsyncIterator[None]: nonlocal ctx_state ctx_state = "START" yield ctx_state = "CLEAN" - async def fail_ctx(app): + async def fail_ctx(app: _EmptyApplication) -> AsyncIterator[NoReturn]: raise Exception() yield @@ -280,11 +298,13 @@ async def fail_ctx(app): async def test_cleanup_ctx_exception_on_cleanup_multiple() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() out = [] - def f(num, fail=False): - async def inner(app): + def f( + num: int, fail: bool = False + ) -> Callable[[_EmptyApplication], AsyncIterator[None]]: + async def inner(app: _EmptyApplication) -> AsyncIterator[None]: out.append("pre_" + str(num)) yield None out.append("post_" + str(num)) @@ -309,11 +329,11 @@ async def inner(app): async def test_cleanup_ctx_multiple_yields() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() out = [] - def f(num): - async def inner(app): + def f(num: int) -> Callable[[_EmptyApplication], AsyncIterator[None]]: + async def inner(app: _EmptyApplication) -> AsyncIterator[None]: out.append("pre_" + str(num)) yield None out.append("post_" + str(num)) @@ -332,22 +352,22 @@ async def inner(app): async def test_subapp_chained_config_dict_visibility(aiohttp_client: Any) -> None: - async def main_handler(request): + async def main_handler(request: web.Request[Dict[str, str]]) -> web.Response: assert request.config_dict["key1"] == "val1" assert "key2" not in request.config_dict return web.Response(status=200) - root = web.Application() - root["key1"] = "val1" + root: web.Application[Dict[str, str]] = web.Application() + root.state["key1"] = "val1" root.add_routes([web.get("/", main_handler)]) - async def sub_handler(request): + async def sub_handler(request: web.Request[Dict[str, str]]) -> web.Response: assert request.config_dict["key1"] == "val1" assert request.config_dict["key2"] == "val2" return web.Response(status=201) - sub = web.Application() - sub["key2"] = "val2" + sub: web.Application[Dict[str, str]] = web.Application() + sub.state["key2"] = "val2" sub.add_routes([web.get("/", sub_handler)]) root.add_subapp("/sub", sub) @@ -360,20 +380,20 @@ async def sub_handler(request): async def test_subapp_chained_config_dict_overriding(aiohttp_client: Any) -> None: - async def main_handler(request): + async def main_handler(request: web.Request[Dict[str, str]]) -> web.Response: assert request.config_dict["key"] == "val1" return web.Response(status=200) - root = web.Application() - root["key"] = "val1" + root: web.Application[Dict[str, str]] = web.Application() + root.state["key"] = "val1" root.add_routes([web.get("/", main_handler)]) - async def sub_handler(request): + async def sub_handler(request: web.Request[Dict[str, str]]) -> web.Response: assert request.config_dict["key"] == "val2" return web.Response(status=201) - sub = web.Application() - sub["key"] = "val2" + sub: web.Application[Dict[str, str]] = web.Application() + sub.state["key"] = "val2" sub.add_routes([web.get("/", sub_handler)]) root.add_subapp("/sub", sub) @@ -386,25 +406,24 @@ async def sub_handler(request): async def test_subapp_on_startup(aiohttp_client: Any) -> None: - - subapp = web.Application() + subapp: web.Application[Dict[str, bool]] = web.Application() startup_called = False - async def on_startup(app): + async def on_startup(app: web.Application[Dict[str, bool]]) -> None: nonlocal startup_called startup_called = True - app["startup"] = True + app.state["startup"] = True subapp.on_startup.append(on_startup) ctx_pre_called = False ctx_post_called = False - async def cleanup_ctx(app): + async def cleanup_ctx(app: web.Application[Dict[str, bool]]) -> AsyncIterator[None]: nonlocal ctx_pre_called, ctx_post_called ctx_pre_called = True - app["cleanup"] = True + app.state["cleanup"] = True yield None ctx_post_called = True @@ -412,7 +431,7 @@ async def cleanup_ctx(app): shutdown_called = False - async def on_shutdown(app): + async def on_shutdown(app: web.Application[Dict[str, bool]]) -> None: nonlocal shutdown_called shutdown_called = True @@ -420,13 +439,13 @@ async def on_shutdown(app): cleanup_called = False - async def on_cleanup(app): + async def on_cleanup(app: web.Application[Dict[str, bool]]) -> None: nonlocal cleanup_called cleanup_called = True subapp.on_cleanup.append(on_cleanup) - app = web.Application() + app: web.Application[Dict[str, bool]] = web.Application() app.add_subapp("/subapp", subapp) @@ -460,27 +479,29 @@ async def on_cleanup(app): def test_app_iter() -> None: - app = web.Application() - app["a"] = "1" - app["b"] = "2" - assert sorted(list(app)) == ["a", "b"] + app: web.Application[Dict[str, str]] = web.Application() + app.state["a"] = "1" + app.state["b"] = "2" + assert sorted(list(app.state)) == ["a", "b"] def test_app_forbid_nonslot_attr() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() with pytest.raises(AttributeError): - app.unknow_attr + app.unknow_attr # type: ignore[attr-defined] with pytest.raises(AttributeError): - app.unknow_attr = 1 + app.unknow_attr = 1 # type: ignore[attr-defined] def test_forbid_changing_frozen_app() -> None: - app = web.Application() + app: web.Application[Dict[str, str]] = web.Application() app.freeze() - with pytest.raises(RuntimeError): - app["key"] = "value" + with pytest.raises(TypeError): + app.state["key"] = "value" + with pytest.raises(AttributeError): + app.state = {} # type: ignore[misc] def test_app_boolean() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() assert app diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 6a3b7a3fee5..bb012928577 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -1,11 +1,11 @@ -# type: ignore import asyncio import io import json import pathlib import socket +import sys import zlib -from typing import Any +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, NoReturn, cast from unittest import mock import brotli @@ -15,39 +15,48 @@ import aiohttp from aiohttp import FormData, HttpVersion10, HttpVersion11, TraceConfig, multipart, web +from aiohttp.abc import AbstractResolver from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING from aiohttp.test_utils import make_mocked_coro from aiohttp.typedefs import Handler -try: - import ssl -except ImportError: - ssl = None +if sys.version_info >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict + + +class _EmptyDict(TypedDict): + pass + + +_EmptyApplication = web.Application[_EmptyDict] +_EmptyRequest = web.Request[_EmptyDict] @pytest.fixture -def here(): +def here() -> pathlib.Path: return pathlib.Path(__file__).parent @pytest.fixture -def fname(here: Any): +def fname(here: pathlib.Path) -> pathlib.Path: return here / "conftest.py" -def new_dummy_form(): +def new_dummy_form() -> FormData: form = FormData() form.add_field("name", b"123", content_transfer_encoding="base64") return form async def test_simple_get(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: body = await request.read() assert b"" == body return web.Response(body=b"OK") - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -58,12 +67,12 @@ async def handler(request): async def test_simple_get_with_text(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: body = await request.read() assert b"" == body return web.Response(text="OK", headers={"content-type": "text/plain"}) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -79,11 +88,11 @@ async def test_handler_returns_not_response( asyncio.get_event_loop().set_debug(True) logger = mock.Mock() - async def handler(request): + async def handler(request: _EmptyRequest) -> str: return "abc" - app = web.Application() - app.router.add_get("/", handler) + app: _EmptyApplication = web.Application() + app.router.add_get("/", handler) # type: ignore[arg-type] server = await aiohttp_server(app, logger=logger) client = await aiohttp_client(server) @@ -99,11 +108,11 @@ async def test_handler_returns_none(aiohttp_server: Any, aiohttp_client: Any) -> asyncio.get_event_loop().set_debug(True) logger = mock.Mock() - async def handler(request): + async def handler(request: _EmptyRequest) -> None: return None - app = web.Application() - app.router.add_get("/", handler) + app: _EmptyApplication = web.Application() + app.router.add_get("/", handler) # type: ignore[arg-type] server = await aiohttp_server(app, logger=logger) client = await aiohttp_client(server) @@ -117,10 +126,10 @@ async def handler(request): async def test_head_returns_empty_body(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response(body=b"test") - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_head("/", handler) client = await aiohttp_client(app, version=HttpVersion11) @@ -131,10 +140,10 @@ async def handler(request): async def test_response_before_complete(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response(body=b"OK") - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -147,12 +156,13 @@ async def handler(request): async def test_post_form(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: data = await request.post() - assert {"a": "1", "b": "2", "c": ""} == data + # TODO: Fix comparison overlap. + assert {"a": "1", "b": "2", "c": ""} == data # type: ignore[comparison-overlap] return web.Response(body=b"OK") - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -163,14 +173,14 @@ async def handler(request): async def test_post_text(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: data = await request.text() assert "русский" == data data2 = await request.text() assert data == data2 return web.Response(text=data) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -184,7 +194,7 @@ async def test_post_json(aiohttp_client: Any) -> None: dct = {"key": "текст"} - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: data = await request.json() assert dct == data data2 = await request.json(loads=json.loads) @@ -194,7 +204,7 @@ async def handler(request): resp.body = json.dumps(data).encode("utf8") return resp - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -210,7 +220,7 @@ async def test_multipart(aiohttp_client: Any) -> None: writer.append("test") writer.append_json({"passed": True}) - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: reader = await request.multipart() assert isinstance(reader, multipart.MultipartReader) @@ -222,15 +232,15 @@ async def handler(request): part = await reader.next() assert isinstance(part, multipart.BodyPartReader) assert part.headers["Content-Type"] == "application/json" - thing = await part.json() - assert thing == {"passed": True} + json_thing = await part.json() + assert json_thing == {"passed": True} resp = web.Response() resp.content_type = "application/json" resp.body = b"" return resp - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -243,14 +253,14 @@ async def test_multipart_empty(aiohttp_client: Any) -> None: with multipart.MultipartWriter() as writer: pass - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: reader = await request.multipart() assert isinstance(reader, multipart.MultipartReader) async for part in reader: assert False, f"Unexpected part found in reader: {part!r}" return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -262,9 +272,13 @@ async def handler(request): async def test_multipart_content_transfer_encoding(aiohttp_client: Any) -> None: # For issue #1168 with multipart.MultipartWriter() as writer: - writer.append(b"\x00" * 10, headers={"Content-Transfer-Encoding": "binary"}) + # TODO: Fix arg-type error. + writer.append( + b"\x00" * 10, + headers={"Content-Transfer-Encoding": "binary"}, # type: ignore[arg-type] + ) - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: reader = await request.multipart() assert isinstance(reader, multipart.MultipartReader) @@ -279,7 +293,7 @@ async def handler(request): resp.body = b"" return resp - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -289,10 +303,10 @@ async def handler(request): async def test_render_redirect(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: raise web.HTTPMovedPermanently(location="/path") - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -307,23 +321,24 @@ async def test_post_single_file(aiohttp_client: Any) -> None: here = pathlib.Path(__file__).parent - def check_file(fs): + def check_file(fs: aiohttp.web_request.FileField) -> None: fullname = here / fs.filename with fullname.open("rb") as f: test_data = f.read() data = fs.file.read() assert test_data == data - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: data = await request.post() assert ["data.unknown_mime_type"] == list(data.keys()) for fs in data.values(): + fs = cast(aiohttp.web_request.FileField, fs) check_file(fs) fs.file.close() resp = web.Response(body=b"OK") return resp - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -335,11 +350,12 @@ async def handler(request): async def test_files_upload_with_same_key(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: data = await request.post() files = data.getall("file") file_names = set() for _file in files: + _file = cast(aiohttp.web_request.FileField, _file) assert not _file.file.closed if _file.filename == "test1.jpeg": assert _file.file.read() == b"binary data 1" @@ -351,7 +367,7 @@ async def handler(request): resp = web.Response(body=b"OK") return resp - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -370,23 +386,24 @@ async def test_post_files(aiohttp_client: Any) -> None: here = pathlib.Path(__file__).parent - def check_file(fs): + def check_file(fs: aiohttp.web_request.FileField) -> None: fullname = here / fs.filename with fullname.open("rb") as f: test_data = f.read() data = fs.file.read() assert test_data == data - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: data = await request.post() assert ["data.unknown_mime_type", "conftest.py"] == list(data.keys()) for fs in data.values(): + fs = cast(aiohttp.web_request.FileField, fs) check_file(fs) fs.file.close() resp = web.Response(body=b"OK") return resp - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -397,13 +414,13 @@ async def handler(request): async def test_release_post_data(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: await request.release() chunk = await request.content.readany() assert chunk == b"" return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -412,12 +429,12 @@ async def handler(request): async def test_POST_DATA_with_content_transfer_encoding(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: data = await request.post() assert b"123" == data["name"] return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -429,13 +446,13 @@ async def handler(request): async def test_post_form_with_duplicate_keys(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: data = await request.post() lst = list(data.items()) assert [("a", "1"), ("a", "2")] == lst return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -444,7 +461,7 @@ async def handler(request): def test_repr_for_application() -> None: - app = web.Application() + app: _EmptyApplication = web.Application() assert "".format(id(app)) == repr(app) @@ -459,14 +476,14 @@ async def test_expect_default_handler_unknown(aiohttp_client: Any) -> None: # status. # http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.20 - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: await request.post() pytest.xfail( "Handler should not proceed to this point in case of " "unknown Expect header" ) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -475,7 +492,7 @@ async def handler(request): async def test_100_continue(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: data = await request.post() assert b"123" == data["name"] return web.Response() @@ -483,7 +500,7 @@ async def handler(request): form = FormData() form.add_field("name", b"123", content_transfer_encoding="base64") - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -495,18 +512,18 @@ async def test_100_continue_custom(aiohttp_client: Any) -> None: expect_received = False - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: data = await request.post() assert b"123" == data["name"] return web.Response() - async def expect_handler(request): + async def expect_handler(request: _EmptyRequest) -> None: nonlocal expect_received expect_received = True if request.version == HttpVersion11: await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n") - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler, expect_handler=expect_handler) client = await aiohttp_client(app) @@ -516,19 +533,19 @@ async def expect_handler(request): async def test_100_continue_custom_response(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: data = await request.post() assert b"123", data["name"] return web.Response() - async def expect_handler(request): + async def expect_handler(request: _EmptyRequest) -> None: if request.version == HttpVersion11: if auth_err: raise web.HTTPForbidden() await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n") - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler, expect_handler=expect_handler) client = await aiohttp_client(app) @@ -543,7 +560,7 @@ async def expect_handler(request): async def test_100_continue_for_not_found(aiohttp_client: Any) -> None: - app = web.Application() + app: _EmptyApplication = web.Application() client = await aiohttp_client(app) resp = await client.post("/not_found", data="data", expect100=True) @@ -551,10 +568,10 @@ async def test_100_continue_for_not_found(aiohttp_client: Any) -> None: async def test_100_continue_for_not_allowed(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -563,10 +580,10 @@ async def handler(request): async def test_http11_keep_alive_default(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app, version=HttpVersion11) @@ -578,10 +595,10 @@ async def handler(request): @pytest.mark.xfail async def test_http10_keep_alive_default(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app, version=HttpVersion10) @@ -592,11 +609,11 @@ async def handler(request): async def test_http10_keep_alive_with_headers_close(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: await request.read() return web.Response(body=b"OK") - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app, version=HttpVersion10) @@ -608,11 +625,11 @@ async def handler(request): async def test_http10_keep_alive_with_headers(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: await request.read() return web.Response(body=b"OK") - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app, version=HttpVersion10) @@ -630,13 +647,13 @@ async def test_upload_file(aiohttp_client: Any) -> None: with fname.open("rb") as f: data = f.read() - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: form = await request.post() - raw_data = form["file"].file.read() + raw_data = cast(aiohttp.web_request.FileField, form["file"]).file.read() assert data == raw_data return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -650,13 +667,13 @@ async def test_upload_file_object(aiohttp_client: Any) -> None: with fname.open("rb") as f: data = f.read() - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: form = await request.post() - raw_data = form["file"].file.read() + raw_data = cast(aiohttp.web_request.FileField, form["file"]).file.read() assert data == raw_data return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -671,12 +688,12 @@ async def handler(request): async def test_empty_content_for_query_without_body( method: Any, aiohttp_client: Any ) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: assert not request.body_exists assert not request.can_read_body return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_route(method, "/", handler) client = await aiohttp_client(app) @@ -685,13 +702,13 @@ async def handler(request): async def test_empty_content_for_query_with_body(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: assert request.body_exists assert request.can_read_body body = await request.read() return web.Response(body=body) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -700,12 +717,12 @@ async def handler(request): async def test_get_with_empty_arg(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: assert "arg" in request.query assert "" == request.query["arg"] return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -714,10 +731,10 @@ async def handler(request): async def test_large_header(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -727,10 +744,10 @@ async def handler(request): async def test_large_header_allowed(aiohttp_client: Any, aiohttp_server: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) server = await aiohttp_server(app, max_field_size=81920) client = await aiohttp_client(server) @@ -741,12 +758,12 @@ async def handler(request): async def test_get_with_empty_arg_with_equal(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: assert "arg" in request.query assert "" == request.query["arg"] return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -761,18 +778,18 @@ async def test_response_with_async_gen(aiohttp_client: Any, fname: Any) -> None: data_size = len(data) - async def stream(f_name): + async def stream(f_name: pathlib.Path) -> AsyncIterator[bytes]: with f_name.open("rb") as f: data = f.read(100) while data: yield data data = f.read(100) - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: headers = {"Content-Length": str(data_size)} return web.Response(body=stream(fname), headers=headers) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -792,18 +809,18 @@ async def test_response_with_async_gen_no_params( data_size = len(data) - async def stream(): + async def stream() -> AsyncIterator[bytes]: with fname.open("rb") as f: data = f.read(100) while data: yield data data = f.read(100) - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: headers = {"Content-Length": str(data_size)} return web.Response(body=stream(), headers=headers) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -820,12 +837,12 @@ async def test_response_with_file(aiohttp_client: Any, fname: Any) -> None: with fname.open("rb") as f: data = f.read() - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: nonlocal outer_file_descriptor outer_file_descriptor = fname.open("rb") return web.Response(body=outer_file_descriptor) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -842,7 +859,8 @@ async def handler(request): assert resp.headers.get("Content-Length") == str(len(resp_data)) assert resp.headers.get("Content-Disposition") == expected_content_disposition - outer_file_descriptor.close() + if outer_file_descriptor: + outer_file_descriptor.close() async def test_response_with_file_ctype(aiohttp_client: Any, fname: Any) -> None: @@ -851,7 +869,7 @@ async def test_response_with_file_ctype(aiohttp_client: Any, fname: Any) -> None with fname.open("rb") as f: data = f.read() - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: nonlocal outer_file_descriptor outer_file_descriptor = fname.open("rb") @@ -859,7 +877,7 @@ async def handler(request): body=outer_file_descriptor, headers={"content-type": "text/binary"} ) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -872,7 +890,8 @@ async def handler(request): assert resp.headers.get("Content-Length") == str(len(resp_data)) assert resp.headers.get("Content-Disposition") == expected_content_disposition - outer_file_descriptor.close() + if outer_file_descriptor: + outer_file_descriptor.close() async def test_response_with_payload_disp(aiohttp_client: Any, fname: Any) -> None: @@ -881,14 +900,14 @@ async def test_response_with_payload_disp(aiohttp_client: Any, fname: Any) -> No with fname.open("rb") as f: data = f.read() - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: nonlocal outer_file_descriptor outer_file_descriptor = fname.open("rb") pl = aiohttp.get_payload(outer_file_descriptor) pl.set_content_disposition("inline", filename="test.txt") return web.Response(body=pl, headers={"content-type": "text/binary"}) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -900,14 +919,15 @@ async def handler(request): assert resp.headers.get("Content-Length") == str(len(resp_data)) assert resp.headers.get("Content-Disposition") == 'inline; filename="test.txt"' - outer_file_descriptor.close() + if outer_file_descriptor: + outer_file_descriptor.close() async def test_response_with_payload_stringio(aiohttp_client: Any, fname: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response(body=io.StringIO("test")) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -930,12 +950,12 @@ async def handler(request): async def test_response_with_precompressed_body( aiohttp_client: Any, compressor: Any, encoding: Any ) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: headers = {"Content-Encoding": encoding} data = compressor.compress(b"mydata") + compressor.flush() return web.Response(body=data, headers=headers) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -947,11 +967,11 @@ async def handler(request): async def test_response_with_precompressed_body_brotli(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: headers = {"Content-Encoding": "br"} return web.Response(body=brotli.compress(b"mydata"), headers=headers) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -963,7 +983,7 @@ async def handler(request): async def test_bad_request_payload(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: assert request.method == "POST" with pytest.raises(aiohttp.web.RequestPayloadError): @@ -971,7 +991,7 @@ async def handler(request): return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -980,7 +1000,7 @@ async def handler(request): async def test_stream_response_multiple_chunks(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.StreamResponse: resp = web.StreamResponse() resp.enable_chunked_encoding() await resp.prepare(request) @@ -989,7 +1009,7 @@ async def handler(request): await resp.write(b"z") return resp - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -1001,7 +1021,7 @@ async def handler(request): async def test_start_without_routes(aiohttp_client: Any) -> None: - app = web.Application() + app: _EmptyApplication = web.Application() client = await aiohttp_client(app) resp = await client.get("/") @@ -1009,10 +1029,10 @@ async def test_start_without_routes(aiohttp_client: Any) -> None: async def test_requests_count(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) assert client.server.handler.requests_count == 0 @@ -1031,13 +1051,13 @@ async def handler(request): async def test_redirect_url(aiohttp_client: Any) -> None: - async def redirector(request): + async def redirector(request: _EmptyRequest) -> NoReturn: raise web.HTTPFound(location=URL("/redirected")) - async def redirected(request): + async def redirected(request: _EmptyRequest) -> web.Response: return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/redirector", redirector) app.router.add_get("/redirected", redirected) @@ -1047,11 +1067,11 @@ async def redirected(request): async def test_simple_subapp(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response(text="OK") - app = web.Application() - subapp = web.Application() + app: _EmptyApplication = web.Application() + subapp: _EmptyApplication = web.Application() subapp.router.add_get("/to", handler) app.add_subapp("/path", subapp) @@ -1063,14 +1083,14 @@ async def handler(request): async def test_subapp_reverse_url(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: raise web.HTTPMovedPermanently(location=subapp.router["name"].url_for()) - async def handler2(request): + async def handler2(request: _EmptyRequest) -> web.Response: return web.Response(text="OK") - app = web.Application() - subapp = web.Application() + app: _EmptyApplication = web.Application() + subapp: _EmptyApplication = web.Application() subapp.router.add_get("/to", handler) subapp.router.add_get("/final", handler2, name="name") app.add_subapp("/path", subapp) @@ -1084,16 +1104,16 @@ async def handler2(request): async def test_subapp_reverse_variable_url(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: raise web.HTTPMovedPermanently( location=subapp.router["name"].url_for(part="final") ) - async def handler2(request): + async def handler2(request: _EmptyRequest) -> web.Response: return web.Response(text="OK") - app = web.Application() - subapp = web.Application() + app: _EmptyApplication = web.Application() + subapp: _EmptyApplication = web.Application() subapp.router.add_get("/to", handler) subapp.router.add_get("/{part}", handler2, name="name") app.add_subapp("/path", subapp) @@ -1109,13 +1129,13 @@ async def handler2(request): async def test_subapp_reverse_static_url(aiohttp_client: Any) -> None: fname = "aiohttp.png" - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: raise web.HTTPMovedPermanently( location=subapp.router["name"].url_for(filename=fname) ) - app = web.Application() - subapp = web.Application() + app: _EmptyApplication = web.Application() + subapp: _EmptyApplication = web.Application() subapp.router.add_get("/to", handler) here = pathlib.Path(__file__).parent subapp.router.add_static("/static", here, name="name") @@ -1131,12 +1151,12 @@ async def handler(request): async def test_subapp_app(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: assert request.app is subapp return web.Response(text="OK") - app = web.Application() - subapp = web.Application() + app: _EmptyApplication = web.Application() + subapp: _EmptyApplication = web.Application() subapp.router.add_get("/to", handler) app.add_subapp("/path/", subapp) @@ -1148,11 +1168,11 @@ async def handler(request): async def test_subapp_not_found(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response(text="OK") - app = web.Application() - subapp = web.Application() + app: _EmptyApplication = web.Application() + subapp: _EmptyApplication = web.Application() subapp.router.add_get("/to", handler) app.add_subapp("/path/", subapp) @@ -1162,11 +1182,11 @@ async def handler(request): async def test_subapp_not_found2(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response(text="OK") - app = web.Application() - subapp = web.Application() + app: _EmptyApplication = web.Application() + subapp: _EmptyApplication = web.Application() subapp.router.add_get("/to", handler) app.add_subapp("/path/", subapp) @@ -1176,11 +1196,11 @@ async def handler(request): async def test_subapp_not_allowed(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response(text="OK") - app = web.Application() - subapp = web.Application() + app: _EmptyApplication = web.Application() + subapp: _EmptyApplication = web.Application() subapp.router.add_get("/to", handler) app.add_subapp("/path/", subapp) @@ -1191,12 +1211,12 @@ async def handler(request): async def test_subapp_cannot_add_app_in_handler(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: request.match_info.add_app(app) return web.Response(text="OK") - app = web.Application() - subapp = web.Application() + app: _EmptyApplication = web.Application() + subapp: _EmptyApplication = web.Application() subapp.router.add_get("/to", handler) app.add_subapp("/path/", subapp) @@ -1208,25 +1228,27 @@ async def handler(request): async def test_old_style_subapp_middlewares(aiohttp_client: Any) -> None: order = [] - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response(text="OK") with pytest.warns(DeprecationWarning, match="Middleware decorator is deprecated"): @web.middleware - async def middleware(request, handler: Handler): - order.append((1, request.app["name"])) + async def middleware( + request: web.Request[Dict[str, str]], handler: Handler + ) -> web.StreamResponse: + order.append((1, request.app.state["name"])) resp = await handler(request) assert 200 == resp.status - order.append((2, request.app["name"])) + order.append((2, request.app.state["name"])) return resp - app = web.Application(middlewares=[middleware]) - subapp1 = web.Application(middlewares=[middleware]) - subapp2 = web.Application(middlewares=[middleware]) - app["name"] = "app" - subapp1["name"] = "subapp1" - subapp2["name"] = "subapp2" + app: web.Application[Dict[str, str]] = web.Application(middlewares=[middleware]) + subapp1: web.Application[Dict[str, str]] = web.Application(middlewares=[middleware]) + subapp2: web.Application[Dict[str, str]] = web.Application(middlewares=[middleware]) + app.state["name"] = "app" + subapp1.state["name"] = "subapp1" + subapp2.state["name"] = "subapp2" subapp2.router.add_get("/to", handler) subapp1.add_subapp("/b/", subapp2) @@ -1248,20 +1270,24 @@ async def middleware(request, handler: Handler): async def test_subapp_on_response_prepare(aiohttp_client: Any) -> None: order = [] - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response(text="OK") - def make_signal(app): - async def on_response(request, response): + def make_signal( + app: _EmptyApplication, + ) -> Callable[[_EmptyRequest, web.StreamResponse], Awaitable[None]]: + async def on_response( + request: _EmptyRequest, response: web.StreamResponse + ) -> None: order.append(app) return on_response - app = web.Application() + app: _EmptyApplication = web.Application() app.on_response_prepare.append(make_signal(app)) - subapp1 = web.Application() + subapp1: _EmptyApplication = web.Application() subapp1.on_response_prepare.append(make_signal(subapp1)) - subapp2 = web.Application() + subapp2: _EmptyApplication = web.Application() subapp2.on_response_prepare.append(make_signal(subapp2)) subapp2.router.add_get("/to", handler) subapp1.add_subapp("/b/", subapp2) @@ -1276,14 +1302,14 @@ async def on_response(request, response): async def test_subapp_on_startup(aiohttp_server: Any) -> None: order = [] - async def on_signal(app): + async def on_signal(app: _EmptyApplication) -> None: order.append(app) - app = web.Application() + app: _EmptyApplication = web.Application() app.on_startup.append(on_signal) - subapp1 = web.Application() + subapp1: _EmptyApplication = web.Application() subapp1.on_startup.append(on_signal) - subapp2 = web.Application() + subapp2: _EmptyApplication = web.Application() subapp2.on_startup.append(on_signal) subapp1.add_subapp("/b/", subapp2) app.add_subapp("/a/", subapp1) @@ -1296,14 +1322,14 @@ async def on_signal(app): async def test_subapp_on_shutdown(aiohttp_server: Any) -> None: order = [] - async def on_signal(app): + async def on_signal(app: _EmptyApplication) -> None: order.append(app) - app = web.Application() + app: _EmptyApplication = web.Application() app.on_shutdown.append(on_signal) - subapp1 = web.Application() + subapp1: _EmptyApplication = web.Application() subapp1.on_shutdown.append(on_signal) - subapp2 = web.Application() + subapp2: _EmptyApplication = web.Application() subapp2.on_shutdown.append(on_signal) subapp1.add_subapp("/b/", subapp2) app.add_subapp("/a/", subapp1) @@ -1317,14 +1343,14 @@ async def on_signal(app): async def test_subapp_on_cleanup(aiohttp_server: Any) -> None: order = [] - async def on_signal(app): + async def on_signal(app: _EmptyApplication) -> None: order.append(app) - app = web.Application() + app: _EmptyApplication = web.Application() app.on_cleanup.append(on_signal) - subapp1 = web.Application() + subapp1: _EmptyApplication = web.Application() subapp1.on_cleanup.append(on_signal) - subapp2 = web.Application() + subapp2: _EmptyApplication = web.Application() subapp2.on_cleanup.append(on_signal) subapp1.add_subapp("/b/", subapp2) app.add_subapp("/a/", subapp1) @@ -1350,31 +1376,37 @@ async def on_signal(app): ) async def test_subapp_middleware_context( aiohttp_client: Any, route: Any, expected: Any, middlewares: Any -): +) -> None: values = [] - - def show_app_context(appname): - async def middleware(request, handler: Handler): - values.append("{}: {}".format(appname, request.app["my_value"])) + AppType = web.Application[Dict[str, str]] + RequestType = web.Request[Dict[str, str]] + + def show_app_context( + appname: str, + ) -> Callable[[RequestType, Handler], Awaitable[web.StreamResponse]]: + async def middleware( + request: RequestType, handler: Handler + ) -> web.StreamResponse: + values.append("{}: {}".format(appname, request.app.state["my_value"])) return await handler(request) return middleware - def make_handler(appname): - async def handler(request): - values.append("{}: {}".format(appname, request.app["my_value"])) + def make_handler(appname: str) -> Callable[[RequestType], Awaitable[web.Response]]: + async def handler(request: RequestType) -> web.Response: + values.append("{}: {}".format(appname, request.app.state["my_value"])) return web.Response(text="Ok") return handler - app = web.Application() - app["my_value"] = "root" + app: AppType = web.Application() + app.state["my_value"] = "root" if "A" in middlewares: app.middlewares.append(show_app_context("A")) app.router.add_get("/", make_handler("B")) - subapp = web.Application() - subapp["my_value"] = "sub" + subapp: AppType = web.Application() + subapp.state["my_value"] = "sub" if "C" in middlewares: subapp.middlewares.append(show_app_context("C")) subapp.router.add_get("/", make_handler("D")) @@ -1388,10 +1420,10 @@ async def handler(request): async def test_custom_date_header(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response(headers={"Date": "Sun, 30 Oct 2016 03:13:52 GMT"}) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -1401,13 +1433,13 @@ async def handler(request): async def test_response_prepared_with_clone(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.StreamResponse: cloned = request.clone() resp = web.StreamResponse() await resp.prepare(cloned) return resp - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -1416,12 +1448,12 @@ async def handler(request): async def test_app_max_client_size(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: await request.post() return web.Response(body=b"ok") max_size = 1024 ** 2 - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) data = {"long_string": max_size * "x" + "xxx"} @@ -1438,13 +1470,13 @@ async def handler(request): async def test_app_max_client_size_adjusted(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: await request.post() return web.Response(body=b"ok") default_max_size = 1024 ** 2 custom_max_size = default_max_size * 2 - app = web.Application(client_max_size=custom_max_size) + app: _EmptyApplication = web.Application(client_max_size=custom_max_size) app.router.add_post("/", handler) client = await aiohttp_client(app) data = {"long_string": default_max_size * "x" + "xxx"} @@ -1467,13 +1499,14 @@ async def handler(request): async def test_app_max_client_size_none(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: await request.post() return web.Response(body=b"ok") default_max_size = 1024 ** 2 - custom_max_size = None - app = web.Application(client_max_size=custom_max_size) + app: _EmptyApplication = web.Application( + client_max_size=None # type: ignore[arg-type] + ) app.router.add_post("/", handler) client = await aiohttp_client(app) data = {"long_string": default_max_size * "x" + "xxx"} @@ -1491,11 +1524,11 @@ async def handler(request): async def test_post_max_client_size(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: await request.post() return web.Response() - app = web.Application(client_max_size=10) + app: _EmptyApplication = web.Application(client_max_size=10) app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -1507,15 +1540,15 @@ async def handler(request): assert ( "Maximum request body size 10 exceeded, " "actual body size 1024" in resp_text ) - data["file"].close() + cast(io.BytesIO, data["file"]).close() async def test_post_max_client_size_for_file(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: await request.post() return web.Response() - app = web.Application(client_max_size=2) + app: _EmptyApplication = web.Application(client_max_size=2) app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -1526,12 +1559,12 @@ async def handler(request): async def test_response_with_bodypart(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: reader = await request.multipart() part = await reader.next() return web.Response(body=part) - app = web.Application(client_max_size=2) + app: _EmptyApplication = web.Application(client_max_size=2) app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -1547,12 +1580,12 @@ async def handler(request): async def test_response_with_bodypart_named(aiohttp_client: Any, tmp_path: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: reader = await request.multipart() part = await reader.next() return web.Response(body=part) - app = web.Application(client_max_size=2) + app: _EmptyApplication = web.Application(client_max_size=2) app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -1571,12 +1604,12 @@ async def handler(request): async def test_response_with_bodypart_invalid_name(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: reader = await request.multipart() part = await reader.next() return web.Response(body=part) - app = web.Application(client_max_size=2) + app: _EmptyApplication = web.Application(client_max_size=2) app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -1592,13 +1625,13 @@ async def handler(request): async def test_request_clone(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: r2 = request.clone(method="POST") assert r2.method == "POST" assert r2.match_info is request.match_info return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -1607,7 +1640,7 @@ async def handler(request): async def test_await(aiohttp_server: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.StreamResponse: resp = web.StreamResponse(headers={"content-length": str(4)}) await resp.prepare(request) with pytest.warns(DeprecationWarning): @@ -1618,7 +1651,7 @@ async def handler(request): await resp.write_eof() return resp - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) @@ -1632,10 +1665,10 @@ async def handler(request): async def test_response_context_manager(aiohttp_server: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) session = aiohttp.ClientSession() @@ -1649,10 +1682,10 @@ async def handler(request): async def test_response_context_manager_error(aiohttp_server: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response(text="some text") - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) session = aiohttp.ClientSession() @@ -1665,16 +1698,17 @@ async def handler(request): await resp.read() assert resp.closed + assert session._connector is not None assert len(session._connector._conns) == 1 await session.close() -async def aiohttp_client_api_context_manager(aiohttp_server: Any): - async def handler(request): +async def aiohttp_client_api_context_manager(aiohttp_server: Any) -> None: + async def handler(request: _EmptyRequest) -> web.Response: return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) @@ -1688,7 +1722,7 @@ async def handler(request): async def test_context_manager_close_on_release( aiohttp_server: Any, mocker: Any ) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.StreamResponse: resp = web.StreamResponse() await resp.prepare(request) with pytest.warns(DeprecationWarning): @@ -1696,12 +1730,13 @@ async def handler(request): await asyncio.sleep(10) return resp - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) async with aiohttp.ClientSession() as session: resp = await session.get(server.make_url("/")) + assert resp.connection is not None proto = resp.connection._protocol mocker.spy(proto, "close") async with resp: @@ -1715,14 +1750,14 @@ async def test_iter_any(aiohttp_server: Any) -> None: data = b"0123456789" * 1024 - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: buf = [] async for raw in request.content.iter_any(): buf.append(raw) assert b"".join(buf) == data return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_route("POST", "/", handler) server = await aiohttp_server(app) @@ -1741,10 +1776,10 @@ async def test_request_tracing(aiohttp_server: Any) -> None: on_connection_create_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) on_connection_create_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - async def redirector(request): + async def redirector(request: _EmptyRequest) -> NoReturn: raise web.HTTPFound(location=URL("/redirected")) - async def redirected(request): + async def redirected(request: _EmptyRequest) -> web.Response: return web.Response() trace_config = TraceConfig() @@ -1757,20 +1792,25 @@ async def redirected(request): trace_config.on_dns_resolvehost_start.append(on_dns_resolvehost_start) trace_config.on_dns_resolvehost_end.append(on_dns_resolvehost_end) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/redirector", redirector) app.router.add_get("/redirected", redirected) server = await aiohttp_server(app) - class FakeResolver: + class FakeResolver(AbstractResolver): _LOCAL_HOST = {0: "127.0.0.1", socket.AF_INET: "127.0.0.1"} - def __init__(self, fakes): + def __init__(self, fakes: Dict[str, int]): # fakes -- dns -> port dict self._fakes = fakes self._resolver = aiohttp.DefaultResolver() - async def resolve(self, host, port=0, family=socket.AF_INET): + async def close(self) -> None: + pass + + async def resolve( + self, host: str, port: int = 0, family: int = socket.AF_INET + ) -> List[Dict[str, object]]: fake_port = self._fakes.get(host) if fake_port is not None: return [ @@ -1803,10 +1843,10 @@ async def resolve(self, host, port=0, family=socket.AF_INET): async def test_raise_http_exception(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: raise web.HTTPForbidden() - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) @@ -1815,13 +1855,13 @@ async def handler(request): async def test_request_path(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: assert request.path_qs == "/path%20to?a=1" assert request.path == "/path to" assert request.raw_path == "/path%20to?a=1" return web.Response(body=b"OK") - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/path to", handler) client = await aiohttp_client(app) @@ -1832,10 +1872,10 @@ async def handler(request): async def test_app_add_routes(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.add_routes([web.get("/get", handler)]) client = await aiohttp_client(app) @@ -1844,11 +1884,11 @@ async def handler(request): async def test_request_headers_type(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: assert isinstance(request.headers, CIMultiDictProxy) return web.Response() - app = web.Application() + app: _EmptyApplication = web.Application() app.add_routes([web.get("/get", handler)]) client = await aiohttp_client(app) @@ -1857,10 +1897,10 @@ async def handler(request): async def test_signal_on_error_handler(aiohttp_client: Any) -> None: - async def on_prepare(request, response): + async def on_prepare(request: _EmptyRequest, response: web.StreamResponse) -> None: response.headers["X-Custom"] = "val" - app = web.Application() + app: _EmptyApplication = web.Application() app.on_response_prepare.append(on_prepare) client = await aiohttp_client(app) @@ -1874,7 +1914,7 @@ async def on_prepare(request, response): reason="C based HTTP parser not available", ) async def test_bad_method_for_c_http_parser_not_hangs(aiohttp_client: Any) -> None: - app = web.Application() + app: _EmptyApplication = web.Application() timeout = aiohttp.ClientTimeout(sock_read=0.2) client = await aiohttp_client(app, timeout=timeout) resp = await client.request("GET1", "/") @@ -1882,12 +1922,12 @@ async def test_bad_method_for_c_http_parser_not_hangs(aiohttp_client: Any) -> No async def test_read_bufsize(aiohttp_client: Any) -> None: - async def handler(request): + async def handler(request: _EmptyRequest) -> web.Response: ret = request.content.get_read_buffer_limits() data = await request.text() # read posted data return web.Response(text=f"{data} {ret!r}") - app = web.Application(handler_args={"read_bufsize": 2}) + app: _EmptyApplication = web.Application(handler_args={"read_bufsize": 2}) app.router.add_post("/", handler) client = await aiohttp_client(app) @@ -1903,10 +1943,10 @@ async def handler(request): async def test_response_101_204_no_content_length_http11( status: Any, aiohttp_client: Any ) -> None: - async def handler(_): + async def handler(request: _EmptyRequest) -> web.Response: return web.Response(status=status) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app, version="1.1") resp = await client.get("/") @@ -1914,11 +1954,11 @@ async def handler(_): assert TRANSFER_ENCODING not in resp.headers -async def test_stream_response_headers_204(aiohttp_client: Any): - async def handler(_): +async def test_stream_response_headers_204(aiohttp_client: Any) -> None: + async def handler(request: _EmptyRequest) -> web.StreamResponse: return web.StreamResponse(status=204) - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") @@ -1927,12 +1967,12 @@ async def handler(_): async def test_httpfound_cookies_302(aiohttp_client: Any) -> None: - async def handler(_): + async def handler(request: _EmptyRequest) -> NoReturn: resp = web.HTTPFound("/") resp.set_cookie("my-cookie", "cookie-value") raise resp - app = web.Application() + app: _EmptyApplication = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/", allow_redirects=False)