From 176258d4c7f34955b688a71e405860dc2a06bb50 Mon Sep 17 00:00:00 2001
From: Mosquito <me@mosquito.su>
Date: Sun, 11 Dec 2022 19:22:21 +0300
Subject: [PATCH] Added a configuration flag for enable request task handler
 cancelling when client connection closing.

---
 CHANGES/7056.feature     |  1 +
 aiohttp/web.py           |  4 +++
 aiohttp/web_protocol.py  |  6 ++++
 aiohttp/web_server.py    |  2 ++
 docs/web_reference.rst   | 11 +++++-
 tests/test_web_server.py | 78 ++++++++++++++++++++++++++++++++++++++++
 6 files changed, 101 insertions(+), 1 deletion(-)
 create mode 100644 CHANGES/7056.feature

diff --git a/CHANGES/7056.feature b/CHANGES/7056.feature
new file mode 100644
index 00000000000..102fb4d7938
--- /dev/null
+++ b/CHANGES/7056.feature
@@ -0,0 +1 @@
+Added ``handler_cancellation`` parameter to cancel web handler on client disconnection. -- by :user:`mosquito`
diff --git a/aiohttp/web.py b/aiohttp/web.py
index d2a4b39a0db..1fa8231dcab 100644
--- a/aiohttp/web.py
+++ b/aiohttp/web.py
@@ -307,6 +307,7 @@ async def _run_app(
     handle_signals: bool = True,
     reuse_address: Optional[bool] = None,
     reuse_port: Optional[bool] = None,
+    handler_cancellation: bool = False,
 ) -> None:
     # A internal functio to actually do all dirty job for application running
     if asyncio.iscoroutine(app):
@@ -321,6 +322,7 @@ async def _run_app(
         access_log_format=access_log_format,
         access_log=access_log,
         keepalive_timeout=keepalive_timeout,
+        handler_cancellation=handler_cancellation,
     )
 
     await runner.setup()
@@ -480,6 +482,7 @@ def run_app(
     handle_signals: bool = True,
     reuse_address: Optional[bool] = None,
     reuse_port: Optional[bool] = None,
+    handler_cancellation: bool = False,
     loop: Optional[asyncio.AbstractEventLoop] = None,
 ) -> None:
     """Run an app locally"""
@@ -511,6 +514,7 @@ def run_app(
             handle_signals=handle_signals,
             reuse_address=reuse_address,
             reuse_port=reuse_port,
+            handler_cancellation=handler_cancellation,
         )
     )
 
diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py
index 5d2f947738a..d0ed0591c17 100644
--- a/aiohttp/web_protocol.py
+++ b/aiohttp/web_protocol.py
@@ -297,6 +297,9 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
 
         super().connection_lost(exc)
 
+        # Grab value before setting _manager to None.
+        handler_cancellation = self._manager.handler_cancellation
+
         self._manager = None
         self._force_close = True
         self._request_factory = None
@@ -314,6 +317,9 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
         if self._waiter is not None:
             self._waiter.cancel()
 
+        if handler_cancellation and self._task_handler is not None:
+            self._task_handler.cancel()
+
         self._task_handler = None
 
         if self._payload_parser is not None:
diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py
index 3961157939e..a78d8d50c1c 100644
--- a/aiohttp/web_server.py
+++ b/aiohttp/web_server.py
@@ -18,6 +18,7 @@ def __init__(
         handler: _RequestHandler,
         *,
         request_factory: Optional[_RequestFactory] = None,
+        handler_cancellation: bool = False,
         loop: Optional[asyncio.AbstractEventLoop] = None,
         **kwargs: Any
     ) -> None:
@@ -27,6 +28,7 @@ def __init__(
         self.requests_count = 0
         self.request_handler = handler
         self.request_factory = request_factory or self._make_request
+        self.handler_cancellation = handler_cancellation
 
     @property
     def connections(self) -> List[RequestHandler]:
diff --git a/docs/web_reference.rst b/docs/web_reference.rst
index 82bc9e8ebcc..487db4b135a 100644
--- a/docs/web_reference.rst
+++ b/docs/web_reference.rst
@@ -2898,7 +2898,8 @@ Utilities
                       access_log=aiohttp.log.access_logger, \
                       handle_signals=True, \
                       reuse_address=None, \
-                      reuse_port=None)
+                      reuse_port=None, \
+                      handler_cancellation=False)
 
    A high-level function for running an application, serving it until
    keyboard interrupt and performing a
@@ -2992,6 +2993,9 @@ Utilities
                            this flag when being created. This option is not
                            supported on Windows.
 
+   :param bool handler_cancellation: cancels the web handler task if the client
+                                     drops the connection.
+
    .. versionadded:: 3.0
 
       Support *access_log_class* parameter.
@@ -3002,6 +3006,11 @@ Utilities
 
       Accept a coroutine as *app* parameter.
 
+   .. versionadded:: 3.9
+
+      Support handler_cancellation parameter (this was the default behaviour
+      in aiohttp <3.7).
+
 Constants
 ---------
 
diff --git a/tests/test_web_server.py b/tests/test_web_server.py
index 627f0f9f774..73e69831991 100644
--- a/tests/test_web_server.py
+++ b/tests/test_web_server.py
@@ -1,4 +1,5 @@
 import asyncio
+from contextlib import suppress
 from unittest import mock
 
 import pytest
@@ -172,3 +173,80 @@ async def handler(request):
     )
 
     logger.exception.assert_called_with("Error handling request", exc_info=exc)
+
+
+async def test_handler_cancellation(aiohttp_unused_port) -> None:
+    event = asyncio.Event()
+    port = aiohttp_unused_port()
+
+    async def on_request(_: web.Request) -> web.Response:
+        nonlocal event
+        try:
+            await asyncio.sleep(10)
+        except asyncio.CancelledError:
+            event.set()
+            raise
+        else:
+            raise web.HTTPInternalServerError()
+
+    app = web.Application()
+    app.router.add_route("GET", "/", on_request)
+
+    runner = web.AppRunner(app, handler_cancellation=True)
+    await runner.setup()
+
+    site = web.TCPSite(runner, host="localhost", port=port)
+
+    await site.start()
+
+    try:
+        assert runner.server.handler_cancellation, "Flag was not propagated"
+
+        async with client.ClientSession(
+            timeout=client.ClientTimeout(total=0.1)
+        ) as sess:
+            with pytest.raises(asyncio.TimeoutError):
+                await sess.get(f"http://localhost:{port}/")
+
+        with suppress(asyncio.TimeoutError):
+            await asyncio.wait_for(event.wait(), timeout=1)
+        assert event.is_set(), "Request handler hasn't been cancelled"
+    finally:
+        await asyncio.gather(runner.shutdown(), site.stop())
+
+
+async def test_no_handler_cancellation(aiohttp_unused_port) -> None:
+    timeout_event = asyncio.Event()
+    done_event = asyncio.Event()
+    port = aiohttp_unused_port()
+
+    async def on_request(_: web.Request) -> web.Response:
+        nonlocal done_event, timeout_event
+        await asyncio.wait_for(timeout_event.wait(), timeout=5)
+        done_event.set()
+        return web.Response()
+
+    app = web.Application()
+    app.router.add_route("GET", "/", on_request)
+
+    runner = web.AppRunner(app)
+    await runner.setup()
+
+    site = web.TCPSite(runner, host="localhost", port=port)
+
+    await site.start()
+
+    try:
+        async with client.ClientSession(
+            timeout=client.ClientTimeout(total=0.1)
+        ) as sess:
+            with pytest.raises(asyncio.TimeoutError):
+                await sess.get(f"http://localhost:{port}/")
+        await asyncio.sleep(0.1)
+        timeout_event.set()
+
+        with suppress(asyncio.TimeoutError):
+            await asyncio.wait_for(done_event.wait(), timeout=1)
+        assert done_event.is_set()
+    finally:
+        await asyncio.gather(runner.shutdown(), site.stop())