Skip to content

Commit

Permalink
Patchback/backports/3.9/38b9ec51e52fdcab11bbe322cc66392c599ca183/pr 7…
Browse files Browse the repository at this point in the history
…056 (#7127)

Co-authored-by: Mosquito <[email protected]>
  • Loading branch information
Dreamsorcerer and mosquito authored Dec 11, 2022
1 parent c56cf10 commit 592c1bb
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGES/7056.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added ``handler_cancellation`` parameter to cancel web handler on client disconnection. -- by :user:`mosquito`
4 changes: 4 additions & 0 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ async def _run_app(
handle_signals: bool = True,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
handler_cancellation: bool = False,
) -> None:
# A internal functio to actually do all dirty job for application running
if asyncio.iscoroutine(app):
Expand All @@ -321,6 +322,7 @@ async def _run_app(
access_log_format=access_log_format,
access_log=access_log,
keepalive_timeout=keepalive_timeout,
handler_cancellation=handler_cancellation,
)

await runner.setup()
Expand Down Expand Up @@ -480,6 +482,7 @@ def run_app(
handle_signals: bool = True,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
handler_cancellation: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
"""Run an app locally"""
Expand Down Expand Up @@ -511,6 +514,7 @@ def run_app(
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
handler_cancellation=handler_cancellation,
)
)

Expand Down
6 changes: 6 additions & 0 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:

super().connection_lost(exc)

# Grab value before setting _manager to None.
handler_cancellation = self._manager.handler_cancellation

self._manager = None
self._force_close = True
self._request_factory = None
Expand All @@ -314,6 +317,9 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
if self._waiter is not None:
self._waiter.cancel()

if handler_cancellation and self._task_handler is not None:
self._task_handler.cancel()

self._task_handler = None

if self._payload_parser is not None:
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
handler: _RequestHandler,
*,
request_factory: Optional[_RequestFactory] = None,
handler_cancellation: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
**kwargs: Any
) -> None:
Expand All @@ -27,6 +28,7 @@ def __init__(
self.requests_count = 0
self.request_handler = handler
self.request_factory = request_factory or self._make_request
self.handler_cancellation = handler_cancellation

@property
def connections(self) -> List[RequestHandler]:
Expand Down
11 changes: 10 additions & 1 deletion docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2898,7 +2898,8 @@ Utilities
access_log=aiohttp.log.access_logger, \
handle_signals=True, \
reuse_address=None, \
reuse_port=None)
reuse_port=None, \
handler_cancellation=False)

A high-level function for running an application, serving it until
keyboard interrupt and performing a
Expand Down Expand Up @@ -2992,6 +2993,9 @@ Utilities
this flag when being created. This option is not
supported on Windows.

:param bool handler_cancellation: cancels the web handler task if the client
drops the connection.

.. versionadded:: 3.0

Support *access_log_class* parameter.
Expand All @@ -3002,6 +3006,11 @@ Utilities

Accept a coroutine as *app* parameter.

.. versionadded:: 3.9

Support handler_cancellation parameter (this was the default behaviour
in aiohttp <3.7).

Constants
---------

Expand Down
78 changes: 78 additions & 0 deletions tests/test_web_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from contextlib import suppress
from unittest import mock

import pytest
Expand Down Expand Up @@ -172,3 +173,80 @@ async def handler(request):
)

logger.exception.assert_called_with("Error handling request", exc_info=exc)


async def test_handler_cancellation(aiohttp_unused_port) -> None:
event = asyncio.Event()
port = aiohttp_unused_port()

async def on_request(_: web.Request) -> web.Response:
nonlocal event
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
event.set()
raise
else:
raise web.HTTPInternalServerError()

app = web.Application()
app.router.add_route("GET", "/", on_request)

runner = web.AppRunner(app, handler_cancellation=True)
await runner.setup()

site = web.TCPSite(runner, host="localhost", port=port)

await site.start()

try:
assert runner.server.handler_cancellation, "Flag was not propagated"

async with client.ClientSession(
timeout=client.ClientTimeout(total=0.1)
) as sess:
with pytest.raises(asyncio.TimeoutError):
await sess.get(f"http://localhost:{port}/")

with suppress(asyncio.TimeoutError):
await asyncio.wait_for(event.wait(), timeout=1)
assert event.is_set(), "Request handler hasn't been cancelled"
finally:
await asyncio.gather(runner.shutdown(), site.stop())


async def test_no_handler_cancellation(aiohttp_unused_port) -> None:
timeout_event = asyncio.Event()
done_event = asyncio.Event()
port = aiohttp_unused_port()

async def on_request(_: web.Request) -> web.Response:
nonlocal done_event, timeout_event
await asyncio.wait_for(timeout_event.wait(), timeout=5)
done_event.set()
return web.Response()

app = web.Application()
app.router.add_route("GET", "/", on_request)

runner = web.AppRunner(app)
await runner.setup()

site = web.TCPSite(runner, host="localhost", port=port)

await site.start()

try:
async with client.ClientSession(
timeout=client.ClientTimeout(total=0.1)
) as sess:
with pytest.raises(asyncio.TimeoutError):
await sess.get(f"http://localhost:{port}/")
await asyncio.sleep(0.1)
timeout_event.set()

with suppress(asyncio.TimeoutError):
await asyncio.wait_for(done_event.wait(), timeout=1)
assert done_event.is_set()
finally:
await asyncio.gather(runner.shutdown(), site.stop())

0 comments on commit 592c1bb

Please sign in to comment.