Skip to content

Commit

Permalink
Fix WebSocket ping tasks being prematurely garbage collected (#8641)
Browse files Browse the repository at this point in the history
(cherry picked from commit 0a88bab)
  • Loading branch information
bdraco committed Aug 8, 2024
1 parent 68e8496 commit ba50518
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CHANGES/8641.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fixed WebSocket ping tasks being prematurely garbage collected -- by :user:`bdraco`.

There was a small risk that WebSocket ping tasks would be prematurely garbage collected because the event loop only holds a weak reference to the task. The garbage collection risk has been fixed by holding a strong reference to the task. Additionally, the task is now scheduled eagerly with Python 3.12+ to increase the chance it can be completed immediately and avoid having to hold any references to the task.
25 changes: 20 additions & 5 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
self._exception: Optional[BaseException] = None
self._compress = compress
self._client_notakeover = client_notakeover
self._ping_task: Optional[asyncio.Task[None]] = None

self._reset_heartbeat()

Expand All @@ -80,6 +81,9 @@ def _cancel_heartbeat(self) -> None:
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None

def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
Expand Down Expand Up @@ -118,11 +122,6 @@ def _send_heartbeat(self) -> None:
)
return

# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable]

conn = self._conn
timeout_ceil_threshold = (
conn._connector._timeout_ceil_threshold if conn is not None else 5
Expand All @@ -131,6 +130,22 @@ def _send_heartbeat(self) -> None:
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)

if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to send the ping
# immediately to avoid having to schedule
# the task on the event loop.
ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True)
else:
ping_task = loop.create_task(self._writer.ping())

if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)

def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
self._ping_task = None

def _pong_not_received(self) -> None:
if not self._closed:
self._set_closed()
Expand Down
25 changes: 20 additions & 5 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,16 @@ def __init__(
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
self._compress = compress
self._max_msg_size = max_msg_size
self._ping_task: Optional[asyncio.Task[None]] = None

def _cancel_heartbeat(self) -> None:
self._cancel_pong_response_cb()
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None

def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
Expand Down Expand Up @@ -141,11 +145,6 @@ def _send_heartbeat(self) -> None:
)
return

# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable]

req = self._req
timeout_ceil_threshold = (
req._protocol._timeout_ceil_threshold if req is not None else 5
Expand All @@ -154,6 +153,22 @@ def _send_heartbeat(self) -> None:
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)

if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to send the ping
# immediately to avoid having to schedule
# the task on the event loop.
ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True)
else:
ping_task = loop.create_task(self._writer.ping())

if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)

def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
self._ping_task = None

def _pong_not_received(self) -> None:
if self._req is not None and self._req.transport is not None:
self._set_closed()
Expand Down
50 changes: 48 additions & 2 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import sys
from typing import Any, NoReturn
from unittest import mock

import pytest

Expand Down Expand Up @@ -727,8 +728,53 @@ async def handler(request):
assert isinstance(msg.data, ServerTimeoutError)


async def test_send_recv_compress(aiohttp_client: Any) -> None:
async def handler(request):
async def test_close_websocket_while_ping_inflight(
aiohttp_client: AiohttpClient,
) -> None:
"""Test closing the websocket while a ping is in-flight."""
ping_received = False

async def handler(request: web.Request) -> NoReturn:
nonlocal ping_received
ws = web.WebSocketResponse(autoping=False)
await ws.prepare(request)
msg = await ws.receive()
assert msg.type is aiohttp.WSMsgType.BINARY
msg = await ws.receive()
ping_received = msg.type is aiohttp.WSMsgType.PING
await ws.receive()
assert False

app = web.Application()
app.router.add_route("GET", "/", handler)

client = await aiohttp_client(app)
resp = await client.ws_connect("/", heartbeat=0.1)
await resp.send_bytes(b"ask")

cancelled = False
ping_stated = False

async def delayed_ping() -> None:
nonlocal cancelled, ping_stated
ping_stated = True
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
cancelled = True
raise

with mock.patch.object(resp._writer, "ping", delayed_ping):
await asyncio.sleep(0.1)

await resp.close()
await asyncio.sleep(0)
assert ping_stated is True
assert cancelled is True


async def test_send_recv_compress(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)

Expand Down

0 comments on commit ba50518

Please sign in to comment.