Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patchback/backports/3.9/38b9ec51e52fdcab11bbe322cc66392c599ca183/pr 7056 #7127

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/7056.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added ``handler_cancellation`` parameter to cancel web handler on client disconnection. -- by :user:`mosquito`
4 changes: 4 additions & 0 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ async def _run_app(
handle_signals: bool = True,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
handler_cancellation: bool = False,
) -> None:
# A internal functio to actually do all dirty job for application running
if asyncio.iscoroutine(app):
Expand All @@ -321,6 +322,7 @@ async def _run_app(
access_log_format=access_log_format,
access_log=access_log,
keepalive_timeout=keepalive_timeout,
handler_cancellation=handler_cancellation,
)

await runner.setup()
Expand Down Expand Up @@ -480,6 +482,7 @@ def run_app(
handle_signals: bool = True,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
handler_cancellation: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
"""Run an app locally"""
Expand Down Expand Up @@ -511,6 +514,7 @@ def run_app(
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
handler_cancellation=handler_cancellation,
)
)

Expand Down
6 changes: 6 additions & 0 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:

super().connection_lost(exc)

# Grab value before setting _manager to None.
handler_cancellation = self._manager.handler_cancellation

self._manager = None
self._force_close = True
self._request_factory = None
Expand All @@ -314,6 +317,9 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
if self._waiter is not None:
self._waiter.cancel()

if handler_cancellation and self._task_handler is not None:
self._task_handler.cancel()

self._task_handler = None

if self._payload_parser is not None:
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
handler: _RequestHandler,
*,
request_factory: Optional[_RequestFactory] = None,
handler_cancellation: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
**kwargs: Any
) -> None:
Expand All @@ -27,6 +28,7 @@ def __init__(
self.requests_count = 0
self.request_handler = handler
self.request_factory = request_factory or self._make_request
self.handler_cancellation = handler_cancellation

@property
def connections(self) -> List[RequestHandler]:
Expand Down
11 changes: 10 additions & 1 deletion docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2898,7 +2898,8 @@ Utilities
access_log=aiohttp.log.access_logger, \
handle_signals=True, \
reuse_address=None, \
reuse_port=None)
reuse_port=None, \
handler_cancellation=False)

A high-level function for running an application, serving it until
keyboard interrupt and performing a
Expand Down Expand Up @@ -2992,6 +2993,9 @@ Utilities
this flag when being created. This option is not
supported on Windows.

:param bool handler_cancellation: cancels the web handler task if the client
drops the connection.

.. versionadded:: 3.0

Support *access_log_class* parameter.
Expand All @@ -3002,6 +3006,11 @@ Utilities

Accept a coroutine as *app* parameter.

.. versionadded:: 3.9

Support handler_cancellation parameter (this was the default behaviour
in aiohttp <3.7).

Constants
---------

Expand Down
78 changes: 78 additions & 0 deletions tests/test_web_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from contextlib import suppress
from unittest import mock

import pytest
Expand Down Expand Up @@ -172,3 +173,80 @@ async def handler(request):
)

logger.exception.assert_called_with("Error handling request", exc_info=exc)


async def test_handler_cancellation(aiohttp_unused_port) -> None:
event = asyncio.Event()
port = aiohttp_unused_port()

async def on_request(_: web.Request) -> web.Response:
nonlocal event
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
event.set()
raise
else:
raise web.HTTPInternalServerError()

app = web.Application()
app.router.add_route("GET", "/", on_request)

runner = web.AppRunner(app, handler_cancellation=True)
await runner.setup()

site = web.TCPSite(runner, host="localhost", port=port)

await site.start()

try:
assert runner.server.handler_cancellation, "Flag was not propagated"

async with client.ClientSession(
timeout=client.ClientTimeout(total=0.1)
) as sess:
with pytest.raises(asyncio.TimeoutError):
await sess.get(f"http://localhost:{port}/")

with suppress(asyncio.TimeoutError):
await asyncio.wait_for(event.wait(), timeout=1)
assert event.is_set(), "Request handler hasn't been cancelled"
finally:
await asyncio.gather(runner.shutdown(), site.stop())


async def test_no_handler_cancellation(aiohttp_unused_port) -> None:
timeout_event = asyncio.Event()
done_event = asyncio.Event()
port = aiohttp_unused_port()

async def on_request(_: web.Request) -> web.Response:
nonlocal done_event, timeout_event
await asyncio.wait_for(timeout_event.wait(), timeout=5)
done_event.set()
return web.Response()

app = web.Application()
app.router.add_route("GET", "/", on_request)

runner = web.AppRunner(app)
await runner.setup()

site = web.TCPSite(runner, host="localhost", port=port)

await site.start()

try:
async with client.ClientSession(
timeout=client.ClientTimeout(total=0.1)
) as sess:
with pytest.raises(asyncio.TimeoutError):
await sess.get(f"http://localhost:{port}/")
await asyncio.sleep(0.1)
timeout_event.set()

with suppress(asyncio.TimeoutError):
await asyncio.wait_for(done_event.wait(), timeout=1)
assert done_event.is_set()
finally:
await asyncio.gather(runner.shutdown(), site.stop())