diff --git a/aiohttp/web.py b/aiohttp/web.py index 0b465375659..d39ecb79622 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -306,6 +306,7 @@ async def _run_app( handle_signals: bool = True, reuse_address: Optional[bool] = None, reuse_port: Optional[bool] = None, + cancel_handler_on_connection_lost: bool = False ) -> None: # An internal function to actually do all dirty job for application running if asyncio.iscoroutine(app): @@ -320,6 +321,7 @@ async def _run_app( access_log_format=access_log_format, access_log=access_log, keepalive_timeout=keepalive_timeout, + cancel_handler_on_connection_lost=cancel_handler_on_connection_lost, ) await runner.setup() @@ -480,6 +482,7 @@ def run_app( handle_signals: bool = True, reuse_address: Optional[bool] = None, reuse_port: Optional[bool] = None, + cancel_handler_on_connection_lost: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: """Run an app locally""" @@ -512,6 +515,7 @@ def run_app( handle_signals=handle_signals, reuse_address=reuse_address, reuse_port=reuse_port, + cancel_handler_on_connection_lost=cancel_handler_on_connection_lost, ) ) diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 92e2052e08f..3885186dccb 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -313,7 +313,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: super().connection_lost(exc) - should_cancel_task_handler = self._manager.cancel_when_connection_lost + should_cancel_task_handler = self._manager.cancel_handler_on_connection_lost self._manager = None self._force_close = True diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index 0c5ef4cf589..3352d7cb9c2 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -19,7 +19,7 @@ def __init__( *, request_factory: Optional[_RequestFactory] = None, debug: Optional[bool] = None, - cancel_when_connection_lost: bool = False, + cancel_handler_on_connection_lost: bool = False, **kwargs: Any, ) -> None: if debug is not None: @@ -34,7 +34,7 @@ def __init__( self.requests_count = 0 self.request_handler = handler self.request_factory = request_factory or self._make_request - self.cancel_when_connection_lost = cancel_when_connection_lost + self.cancel_handler_on_connection_lost = cancel_handler_on_connection_lost @property def connections(self) -> List[RequestHandler]: diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 8478ac9ca40..113f99709a2 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -2809,7 +2809,8 @@ Utilities access_log=aiohttp.log.access_logger, \ handle_signals=True, \ reuse_address=None, \ - reuse_port=None) + reuse_port=None, \ + cancel_handler_on_connection_lost=False) A high-level function for running an application, serving it until keyboard interrupt and performing a @@ -2904,6 +2905,11 @@ Utilities this flag when being created. This option is not supported on Windows. + :param bool cancel_handler_on_connection_lost: tells the runner whether to + cancel the execution of the + handler task if the client + connection has been closed. + .. versionadded:: 3.0 Support *access_log_class* parameter. diff --git a/tests/test_web_server.py b/tests/test_web_server.py index b97e0fa7b64..aa5119675b1 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -5,6 +5,7 @@ import pytest +import aiohttp from aiohttp import client, helpers, web @@ -207,3 +208,57 @@ async def handler(request): ) logger.exception.assert_called_with("Error handling request", exc_info=exc) + + +async def test_cancel_handler_on_connection_lost(aiohttp_unused_port) -> None: + event = asyncio.Event() + port = aiohttp_unused_port() + + async def on_request(_: web.Request) -> web.Response: + nonlocal event + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + event.set() + raise + else: + raise web.HTTPInternalServerError() + + app = web.Application() + app.router.add_route("GET", "/", on_request) + + runner = web.AppRunner(app, cancel_handler_on_connection_lost=True) + await runner.setup() + + site = web.TCPSite(runner, host="localhost", port=port) + + await site.start() + + async def client_request_maker(): + async with aiohttp.ClientSession( + base_url=f"http://localhost:{port}" + ) as session: + request = session.get("/") + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(request, timeout=0.1) + + try: + assert runner.server.cancel_handler_on_connection_lost, "Flag was not propagated" + await client_request_maker() + + await asyncio.wait_for(event.wait(), timeout=1) + assert event.is_set(), "Request handler hasn't been cancelled" + finally: + await asyncio.gather(runner.shutdown(), site.stop()) + + +async def test_cancel_handler_on_connection_lost_flag_on_runner() -> None: + runner = web.AppRunner(web.Application(), cancel_handler_on_connection_lost=True) + await runner.setup() + assert runner.server.cancel_handler_on_connection_lost, "Flag was not propagated" + await runner.shutdown() + + +async def test_cancel_handler_on_connection_lost_flag_on_site() -> None: + web.run_app