Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix WebSocket ping tasks being prematurely garbage collected #8641

Merged
merged 20 commits into from
Aug 8, 2024
Merged
3 changes: 3 additions & 0 deletions CHANGES/8641.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fixed WebSocket ping tasks being prematurely garbage collected -- by :user:`bdraco`.

There was a small risk that WebSocket ping tasks would be prematurely garbage collected because the event loop only holds a weak reference to the task. This has been fixed by holding a strong reference to the task until it completes. Additionally, on Python 3.12+ the task is now started eagerly, which increases the chance that it completes immediately so that no reference to it needs to be held at all.
14 changes: 5 additions & 9 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
HeadersMixin,
TimerNoop,
basicauth_from_netrc,
create_eager_task,
is_expected_content_type,
netrc_from_env,
parse_mimetype,
Expand Down Expand Up @@ -668,15 +669,10 @@ async def send(self, conn: "Connection") -> "ClientResponse":
await writer.write_headers(status_line, self.headers)
coro = self.write_bytes(writer, conn)

if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to write
# bytes immediately to avoid having to schedule
# the task on the event loop.
task = asyncio.Task(coro, loop=self.loop, eager_start=True)
else:
task = self.loop.create_task(coro)

self._writer = task
# Optimization for Python 3.12+, try to write
# bytes immediately to avoid having to schedule
# the task on the event loop.
self._writer = create_eager_task(coro, self.loop)
response_class = self.response_class
assert response_class is not None
self.response = response_class(
Expand Down
20 changes: 14 additions & 6 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .client_exceptions import ClientError, ServerTimeoutError
from .client_reqrep import ClientResponse
from .helpers import calculate_timeout_when, set_result
from .helpers import calculate_timeout_when, create_eager_task, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
Expand Down Expand Up @@ -82,6 +82,7 @@ def __init__(
self._exception: Optional[BaseException] = None
self._compress = compress
self._client_notakeover = client_notakeover
self._ping_task: Optional[asyncio.Task[None]] = None

self._reset_heartbeat()

Expand All @@ -90,6 +91,9 @@ def _cancel_heartbeat(self) -> None:
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None

def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
Expand Down Expand Up @@ -128,11 +132,6 @@ def _send_heartbeat(self) -> None:
)
return

# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable]

conn = self._conn
timeout_ceil_threshold = (
conn._connector._timeout_ceil_threshold if conn is not None else 5
Expand All @@ -141,6 +140,15 @@ def _send_heartbeat(self) -> None:
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)

ping_task = create_eager_task(self._writer.ping(), loop)
bdraco marked this conversation as resolved.
Show resolved Hide resolved
if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)

def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
self._ping_task = None

def _pong_not_received(self) -> None:
if not self._closed:
self._set_closed()
Expand Down
15 changes: 15 additions & 0 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Any,
Callable,
ContextManager,
Coroutine,
Dict,
Generic,
Iterable,
Expand Down Expand Up @@ -592,6 +593,20 @@ def weakref_handle(
return None


def create_eager_task(
    coro: Coroutine[Any, Any, None],
    loop: asyncio.AbstractEventLoop,
) -> "asyncio.Task[None]":
    """Create a task that will be run immediately if possible.

    On Python 3.12+ the task is constructed with ``eager_start=True``, so
    the coroutine begins executing right away and the returned task may
    already be done. On older versions the task is scheduled on *loop* as
    usual.

    :param coro: the coroutine to wrap in a task.
    :param loop: the event loop the task belongs to.
    :returns: the created task; callers can check ``task.done()`` to see
        whether it already finished.
    """
    if sys.version_info >= (3, 12):
        # Optimization for Python 3.12+: try to start eagerly
        # to avoid being scheduled on the event loop.
        return asyncio.Task(coro, loop=loop, eager_start=True)
    # For older Python versions, we need to schedule the task
    # on the event loop as eager_start is not available.
    return loop.create_task(coro)


def call_later(
cb: Callable[[], Any],
timeout: Optional[float],
Expand Down
26 changes: 20 additions & 6 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@

from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import calculate_timeout_when, set_exception, set_result
from .helpers import (
calculate_timeout_when,
create_eager_task,
set_exception,
set_result,
)
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
Expand Down Expand Up @@ -80,6 +85,7 @@ class WebSocketResponse(StreamResponse):
"_pong_response_cb",
"_compress",
"_max_msg_size",
"_ping_task",
)

def __init__(
Expand Down Expand Up @@ -120,12 +126,16 @@ def __init__(
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
self._compress = compress
self._max_msg_size = max_msg_size
self._ping_task: Optional[asyncio.Task[None]] = None

def _cancel_heartbeat(self) -> None:
self._cancel_pong_response_cb()
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None

def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
Expand Down Expand Up @@ -165,11 +175,6 @@ def _send_heartbeat(self) -> None:
)
return

# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable]

req = self._req
timeout_ceil_threshold = (
req._protocol._timeout_ceil_threshold if req is not None else 5
Expand All @@ -178,6 +183,15 @@ def _send_heartbeat(self) -> None:
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)

ping_task = create_eager_task(self._writer.ping(), loop)
if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)

def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
self._ping_task = None

def _pong_not_received(self) -> None:
if self._req is not None and self._req.transport is not None:
self._set_closed()
Expand Down
45 changes: 45 additions & 0 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,51 @@ async def handler(request: web.Request) -> NoReturn:
assert isinstance(msg.data, ServerTimeoutError)


async def test_close_websocket_while_ping_inflight(
    aiohttp_client: AiohttpClient,
) -> None:
    """Test closing the websocket while a ping is in-flight.

    The client's ``ping`` is replaced with a coroutine that sleeps, so the
    heartbeat ping is still pending when the websocket is closed. Closing
    must cancel that pending ping.
    """
    # NOTE(review): ping_received is recorded by the handler but is never
    # asserted on below.
    ping_received = False

    async def handler(request: web.Request) -> NoReturn:
        nonlocal ping_received
        # autoping is disabled so PING frames are delivered to receive().
        ws = web.WebSocketResponse(autoping=False)
        await ws.prepare(request)
        msg = await ws.receive()
        assert msg.type is aiohttp.WSMsgType.BINARY
        msg = await ws.receive()
        ping_received = msg.type is aiohttp.WSMsgType.PING
        await ws.receive()
        assert False

    app = web.Application()
    app.router.add_route("GET", "/", handler)

    client = await aiohttp_client(app)
    resp = await client.ws_connect("/", heartbeat=0.1)
    await resp.send_bytes(b"ask")

    cancelled = False
    ping_started = False

    async def delayed_ping() -> None:
        """Stand-in for the writer's ping that stays pending for a while."""
        nonlocal cancelled, ping_started
        ping_started = True
        try:
            # Sleep much longer than the heartbeat interval so the ping
            # is still in-flight when the websocket gets closed.
            await asyncio.sleep(1)
        except asyncio.CancelledError:
            cancelled = True
            raise

    # Wait one heartbeat interval with the slow ping patched in, so the
    # heartbeat starts a ping that cannot complete.
    with mock.patch.object(resp._writer, "ping", delayed_ping):
        await asyncio.sleep(0.1)

    await resp.close()
    # Yield to the event loop once so the cancellation reaches the task.
    await asyncio.sleep(0)
    assert ping_started is True
    assert cancelled is True


async def test_send_recv_compress(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
Expand Down
Loading