From 176258d4c7f34955b688a71e405860dc2a06bb50 Mon Sep 17 00:00:00 2001 From: Mosquito <me@mosquito.su> Date: Sun, 11 Dec 2022 19:22:21 +0300 Subject: [PATCH] Added a configuration flag for enable request task handler cancelling when client connection closing. --- CHANGES/7056.feature | 1 + aiohttp/web.py | 4 +++ aiohttp/web_protocol.py | 6 ++++ aiohttp/web_server.py | 2 ++ docs/web_reference.rst | 11 +++++- tests/test_web_server.py | 78 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 CHANGES/7056.feature diff --git a/CHANGES/7056.feature b/CHANGES/7056.feature new file mode 100644 index 00000000000..102fb4d7938 --- /dev/null +++ b/CHANGES/7056.feature @@ -0,0 +1 @@ +Added ``handler_cancellation`` parameter to cancel web handler on client disconnection. -- by :user:`mosquito` diff --git a/aiohttp/web.py b/aiohttp/web.py index d2a4b39a0db..1fa8231dcab 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -307,6 +307,7 @@ async def _run_app( handle_signals: bool = True, reuse_address: Optional[bool] = None, reuse_port: Optional[bool] = None, + handler_cancellation: bool = False, ) -> None: # A internal functio to actually do all dirty job for application running if asyncio.iscoroutine(app): @@ -321,6 +322,7 @@ async def _run_app( access_log_format=access_log_format, access_log=access_log, keepalive_timeout=keepalive_timeout, + handler_cancellation=handler_cancellation, ) await runner.setup() @@ -480,6 +482,7 @@ def run_app( handle_signals: bool = True, reuse_address: Optional[bool] = None, reuse_port: Optional[bool] = None, + handler_cancellation: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: """Run an app locally""" @@ -511,6 +514,7 @@ def run_app( handle_signals=handle_signals, reuse_address=reuse_address, reuse_port=reuse_port, + handler_cancellation=handler_cancellation, ) ) diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 5d2f947738a..d0ed0591c17 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -297,6 +297,9 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: super().connection_lost(exc) + # Grab value before setting _manager to None. + handler_cancellation = self._manager.handler_cancellation + self._manager = None self._force_close = True self._request_factory = None @@ -314,6 +317,9 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: if self._waiter is not None: self._waiter.cancel() + if handler_cancellation and self._task_handler is not None: + self._task_handler.cancel() + self._task_handler = None if self._payload_parser is not None: diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index 3961157939e..a78d8d50c1c 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -18,6 +18,7 @@ def __init__( handler: _RequestHandler, *, request_factory: Optional[_RequestFactory] = None, + handler_cancellation: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any ) -> None: @@ -27,6 +28,7 @@ def __init__( self.requests_count = 0 self.request_handler = handler self.request_factory = request_factory or self._make_request + self.handler_cancellation = handler_cancellation @property def connections(self) -> List[RequestHandler]: diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 82bc9e8ebcc..487db4b135a 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -2898,7 +2898,8 @@ Utilities access_log=aiohttp.log.access_logger, \ handle_signals=True, \ reuse_address=None, \ - reuse_port=None) + reuse_port=None, \ + handler_cancellation=False) A high-level function for running an application, serving it until keyboard interrupt and performing a @@ -2992,6 +2993,9 @@ Utilities this flag when being created. This option is not supported on Windows. + :param bool handler_cancellation: cancels the web handler task if the client + drops the connection. + .. versionadded:: 3.0 Support *access_log_class* parameter. @@ -3002,6 +3006,11 @@ Utilities Accept a coroutine as *app* parameter. + .. versionadded:: 3.9 + + Support handler_cancellation parameter (this was the default behaviour + in aiohttp <3.7). + Constants --------- diff --git a/tests/test_web_server.py b/tests/test_web_server.py index 627f0f9f774..73e69831991 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import suppress from unittest import mock import pytest @@ -172,3 +173,80 @@ async def handler(request): ) logger.exception.assert_called_with("Error handling request", exc_info=exc) + + +async def test_handler_cancellation(aiohttp_unused_port) -> None: + event = asyncio.Event() + port = aiohttp_unused_port() + + async def on_request(_: web.Request) -> web.Response: + nonlocal event + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + event.set() + raise + else: + raise web.HTTPInternalServerError() + + app = web.Application() + app.router.add_route("GET", "/", on_request) + + runner = web.AppRunner(app, handler_cancellation=True) + await runner.setup() + + site = web.TCPSite(runner, host="localhost", port=port) + + await site.start() + + try: + assert runner.server.handler_cancellation, "Flag was not propagated" + + async with client.ClientSession( + timeout=client.ClientTimeout(total=0.1) + ) as sess: + with pytest.raises(asyncio.TimeoutError): + await sess.get(f"http://localhost:{port}/") + + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(event.wait(), timeout=1) + assert event.is_set(), "Request handler hasn't been cancelled" + finally: + await asyncio.gather(runner.shutdown(), site.stop()) + + +async def test_no_handler_cancellation(aiohttp_unused_port) -> None: + timeout_event = asyncio.Event() + done_event = asyncio.Event() + port = aiohttp_unused_port() + + async def on_request(_: web.Request) -> web.Response: + nonlocal done_event, timeout_event + await asyncio.wait_for(timeout_event.wait(), timeout=5) + done_event.set() + return web.Response() + + app = web.Application() + app.router.add_route("GET", "/", on_request) + + runner = web.AppRunner(app) + await runner.setup() + + site = web.TCPSite(runner, host="localhost", port=port) + + await site.start() + + try: + async with client.ClientSession( + timeout=client.ClientTimeout(total=0.1) + ) as sess: + with pytest.raises(asyncio.TimeoutError): + await sess.get(f"http://localhost:{port}/") + await asyncio.sleep(0.1) + timeout_event.set() + + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(done_event.wait(), timeout=1) + assert done_event.is_set() + finally: + await asyncio.gather(runner.shutdown(), site.stop())