Skip to content

Commit

Permalink
Fix cancellations being swallowed (#9030) (#9257)
Browse files Browse the repository at this point in the history
Co-authored-by: J. Nick Koston <[email protected]>
(cherry picked from commit 1a77ad9)
  • Loading branch information
Dreamsorcerer authored Sep 23, 2024
1 parent 7ecc9c9 commit 3f1a8b1
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGES/9030.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed (on Python 3.11+) some edge cases where a task cancellation may get incorrectly suppressed -- by :user:`Dreamsorcerer`.
37 changes: 29 additions & 8 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,11 +625,8 @@ async def write_bytes(
"""Support coroutines that yields bytes objects."""
# 100 response
if self._continue is not None:
try:
await writer.drain()
await self._continue
except asyncio.CancelledError:
return
await writer.drain()
await self._continue

protocol = conn.protocol
assert protocol is not None
Expand Down Expand Up @@ -658,6 +655,7 @@ async def write_bytes(
except asyncio.CancelledError:
# Body hasn't been fully sent, so connection can't be reused.
conn.close()
raise
except Exception as underlying_exc:
set_exception(
protocol,
Expand Down Expand Up @@ -764,8 +762,15 @@ async def send(self, conn: "Connection") -> "ClientResponse":

async def close(self) -> None:
if self._writer is not None:
with contextlib.suppress(asyncio.CancelledError):
try:
await self._writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise

def terminate(self) -> None:
if self._writer is not None:
Expand Down Expand Up @@ -1119,7 +1124,15 @@ def _release_connection(self) -> None:

async def _wait_released(self) -> None:
if self._writer is not None:
await self._writer
try:
await self._writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise
self._release_connection()

def _cleanup_writer(self) -> None:
Expand All @@ -1135,7 +1148,15 @@ def _notify_content(self) -> None:

async def wait_for_close(self) -> None:
if self._writer is not None:
await self._writer
try:
await self._writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise
self.release()

async def read(self) -> bytes:
Expand Down
38 changes: 29 additions & 9 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,17 +271,32 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
# down while the handler is still processing a request
# to avoid creating a future for every request.
self._handler_waiter = self._loop.create_future()
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
try:
async with ceil_timeout(timeout):
await self._handler_waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._handler_waiter = None
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise
# Then cancel handler and wait
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
try:
async with ceil_timeout(timeout):
if self._current_request is not None:
self._current_request._cancel(asyncio.CancelledError())

if self._task_handler is not None and not self._task_handler.done():
await self._task_handler
await asyncio.shield(self._task_handler)
except (asyncio.CancelledError, asyncio.TimeoutError):
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise

# force-close non-idle handler
if self._task_handler is not None:
Expand Down Expand Up @@ -517,8 +532,6 @@ async def start(self) -> None:
# wait for next request
self._waiter = loop.create_future()
await self._waiter
except asyncio.CancelledError:
break
finally:
self._waiter = None

Expand All @@ -545,7 +558,7 @@ async def start(self) -> None:
task = loop.create_task(coro)
try:
resp, reset = await task
except (asyncio.CancelledError, ConnectionError):
except ConnectionError:
self.log_debug("Ignored premature client disconnection")
break

Expand All @@ -569,12 +582,19 @@ async def start(self) -> None:
now = loop.time()
end_t = now + lingering_time

with suppress(asyncio.TimeoutError, asyncio.CancelledError):
try:
while not payload.is_eof() and now < end_t:
async with ceil_timeout(end_t - now):
# read and ignore
await payload.readany()
now = loop.time()
except (asyncio.CancelledError, asyncio.TimeoutError):
if (
sys.version_info >= (3, 11)
and (t := asyncio.current_task())
and t.cancelling()
):
raise

# if payload still uncompleted
if not payload.is_eof() and not self._force_close:
Expand All @@ -584,8 +604,8 @@ async def start(self) -> None:
payload.set_exception(_PAYLOAD_ACCESS_ERROR)

except asyncio.CancelledError:
self.log_debug("Ignored premature client disconnection ")
break
self.log_debug("Ignored premature client disconnection")
raise
except Exception as exc:
self.log_exception("Unhandled exception", exc_info=exc)
self.force_close()
Expand Down
19 changes: 18 additions & 1 deletion tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import hashlib
import io
import pathlib
import sys
import urllib.parse
import zlib
from http.cookies import BaseCookie, Morsel, SimpleCookie
Expand Down Expand Up @@ -1213,7 +1214,23 @@ async def test_oserror_on_write_bytes(loop, conn) -> None:
await req.close()


async def test_terminate(loop, conn) -> None:
@pytest.mark.skipif(sys.version_info < (3, 11), reason="Needs Task.cancelling()")
async def test_cancel_close(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None:
req = ClientRequest("get", URL("http://python.org"), loop=loop)
req._writer = asyncio.Future() # type: ignore[assignment]

t = asyncio.create_task(req.close())

# Start waiting on _writer
await asyncio.sleep(0)

t.cancel()
# Cancellation should not be suppressed.
with pytest.raises(asyncio.CancelledError):
await t


async def test_terminate(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None:
req = ClientRequest("get", URL("http://python.org"), loop=loop)

async def _mock_write_bytes(*args, **kwargs):
Expand Down
40 changes: 38 additions & 2 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import pathlib
import socket
import sys
import zlib
from typing import Any, NoReturn, Optional
from unittest import mock
Expand All @@ -22,6 +23,7 @@
web,
)
from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING
from aiohttp.pytest_plugin import AiohttpClient
from aiohttp.test_utils import make_mocked_coro
from aiohttp.typedefs import Handler
from aiohttp.web_protocol import RequestHandler
Expand Down Expand Up @@ -187,8 +189,42 @@ async def handler(request):
await resp.release()


async def test_post_form(aiohttp_client) -> None:
async def handler(request):
@pytest.mark.skipif(sys.version_info < (3, 11), reason="Needs Task.cancelling()")
async def test_cancel_shutdown(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
t = asyncio.create_task(request.protocol.shutdown())
# Ensure it's started waiting
await asyncio.sleep(0)

t.cancel()
# Cancellation should not be suppressed
with pytest.raises(asyncio.CancelledError):
await t

# Repeat for second waiter in shutdown()
with mock.patch.object(request.protocol, "_request_in_progress", False):
with mock.patch.object(request.protocol, "_current_request", None):
t = asyncio.create_task(request.protocol.shutdown())
await asyncio.sleep(0)

t.cancel()
with pytest.raises(asyncio.CancelledError):
await t

return web.Response(body=b"OK")

app = web.Application()
app.router.add_get("/", handler)
client = await aiohttp_client(app)

async with client.get("/") as resp:
assert resp.status == 200
txt = await resp.text()
assert txt == "OK"


async def test_post_form(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
data = await request.post()
assert {"a": "1", "b": "2", "c": ""} == data
return web.Response(body=b"OK")
Expand Down

0 comments on commit 3f1a8b1

Please sign in to comment.