Skip to content

Commit

Permalink
Fix WebSocket ping tasks being prematurely garbage collected
Browse files Browse the repository at this point in the history
The event loop only keeps weak references to tasks, we need to
hold a strong reference to ensure that the ping task is not
prematurely garbage collected.
https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task

In almost all cases the ping can be done synchronously if
the task is created eagerly which avoids scheduling the ping
task on the event loop.

fixes #8614
  • Loading branch information
bdraco committed Aug 8, 2024
1 parent c4acabc commit 62a99de
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 21 deletions.
14 changes: 5 additions & 9 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
HeadersMixin,
TimerNoop,
basicauth_from_netrc,
create_eager_task,
is_expected_content_type,
netrc_from_env,
parse_mimetype,
Expand Down Expand Up @@ -668,15 +669,10 @@ async def send(self, conn: "Connection") -> "ClientResponse":
await writer.write_headers(status_line, self.headers)
coro = self.write_bytes(writer, conn)

if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to write
# bytes immediately to avoid having to schedule
# the task on the event loop.
task = asyncio.Task(coro, loop=self.loop, eager_start=True)
else:
task = self.loop.create_task(coro)

self._writer = task
# Optimization for Python 3.12+, try to write
# bytes immediately to avoid having to schedule
# the task on the event loop.
self._writer = create_eager_task(coro, self.loop)
response_class = self.response_class
assert response_class is not None
self.response = response_class(
Expand Down
21 changes: 15 additions & 6 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .client_exceptions import ClientError, ServerTimeoutError
from .client_reqrep import ClientResponse
from .helpers import calculate_timeout_when, set_result
from .helpers import calculate_timeout_when, create_eager_task, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
Expand Down Expand Up @@ -82,6 +82,7 @@ def __init__(
self._exception: Optional[BaseException] = None
self._compress = compress
self._client_notakeover = client_notakeover
self._ping_task: Optional[asyncio.Task[None]] = None

self._reset_heartbeat()

Expand All @@ -90,6 +91,9 @@ def _cancel_heartbeat(self) -> None:
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None

def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
Expand Down Expand Up @@ -128,11 +132,6 @@ def _send_heartbeat(self) -> None:
)
return

# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable]

conn = self._conn
timeout_ceil_threshold = (
conn._connector._timeout_ceil_threshold if conn is not None else 5
Expand All @@ -141,6 +140,16 @@ def _send_heartbeat(self) -> None:
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)

self._ping_task = create_eager_task(self._writer.ping(), loop)
if self._ping_task.done():
self._ping_task = None
else:
self._ping_task.add_done_callback(self._ping_task_done)

def _ping_task_done(self, task: asyncio.Task[None]) -> None:
"""Callback for when the ping task completes."""
self._ping_task = None

def _pong_not_received(self) -> None:
if not self._closed:
self._set_closed()
Expand Down
15 changes: 15 additions & 0 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,21 @@ def weakref_handle(
return None


def create_eager_task(
coro: Callable[..., Any],
loop: asyncio.AbstractEventLoop,
) -> asyncio.Task:
"""Create a task that will be scheduled immediately if possible."""
if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to write
# bytes immediately to avoid having to schedule
# the task on the event loop.
return asyncio.Task(coro, loop=loop, eager_start=True)
# For older python versions, we need to schedule the task
# on the event loop as eager_start is not available.
return loop.create_task(coro)


def call_later(
cb: Callable[[], Any],
timeout: Optional[float],
Expand Down
27 changes: 21 additions & 6 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@

from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import calculate_timeout_when, set_exception, set_result
from .helpers import (
calculate_timeout_when,
create_eager_task,
set_exception,
set_result,
)
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
Expand Down Expand Up @@ -80,6 +85,7 @@ class WebSocketResponse(StreamResponse):
"_pong_response_cb",
"_compress",
"_max_msg_size",
"_ping_task",
)

def __init__(
Expand Down Expand Up @@ -120,12 +126,16 @@ def __init__(
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
self._compress = compress
self._max_msg_size = max_msg_size
self._ping_task: Optional[asyncio.Task[None]] = None

def _cancel_heartbeat(self) -> None:
self._cancel_pong_response_cb()
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None

def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
Expand Down Expand Up @@ -165,11 +175,6 @@ def _send_heartbeat(self) -> None:
)
return

# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable]

req = self._req
timeout_ceil_threshold = (
req._protocol._timeout_ceil_threshold if req is not None else 5
Expand All @@ -178,6 +183,16 @@ def _send_heartbeat(self) -> None:
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)

self._ping_task = create_eager_task(self._writer.ping(), loop)
if self._ping_task.done():
self._ping_task = None
else:
self._ping_task.add_done_callback(self._ping_task_done)

def _ping_task_done(self, task: asyncio.Task[None]) -> None:
"""Callback for when the ping task completes."""
self._ping_task = None

def _pong_not_received(self) -> None:
if self._req is not None and self._req.transport is not None:
self._set_closed()
Expand Down

0 comments on commit 62a99de

Please sign in to comment.