From b92861fa2d15c07c035bbeac254e18456c8e20a0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 7 Aug 2024 16:58:55 -0500 Subject: [PATCH] Fix timer handle churn in websocket heartbeat (#8608) Co-authored-by: Sam Bull (cherry picked from commit c4acabc836ab969e95199aa976e85c01df720a27) --- CHANGES/8608.misc.rst | 3 + aiohttp/client_ws.py | 115 ++++++++++++++++--------- aiohttp/helpers.py | 23 +++-- aiohttp/web_ws.py | 100 ++++++++++++--------- tests/test_client_ws.py | 48 +++++++++-- tests/test_client_ws_functional.py | 81 ++++++++++++++++- tests/test_web_websocket_functional.py | 59 ++++++++++++- 7 files changed, 331 insertions(+), 98 deletions(-) create mode 100644 CHANGES/8608.misc.rst diff --git a/CHANGES/8608.misc.rst b/CHANGES/8608.misc.rst new file mode 100644 index 00000000000..76e845bf997 --- /dev/null +++ b/CHANGES/8608.misc.rst @@ -0,0 +1,3 @@ +Improved websocket performance when messages are sent or received frequently -- by :user:`bdraco`. + +The WebSocket heartbeat scheduling algorithm was improved to reduce the ``asyncio`` scheduling overhead by decreasing the number of ``asyncio.TimerHandle`` creations and cancellations. diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index c1a2c4641ba..516ad586f70 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -6,7 +6,7 @@ from .client_exceptions import ClientError, ServerTimeoutError from .client_reqrep import ClientResponse -from .helpers import call_later, set_result +from .helpers import calculate_timeout_when, set_result from .http import ( WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE, @@ -62,6 +62,7 @@ def __init__( self._autoping = autoping self._heartbeat = heartbeat self._heartbeat_cb: Optional[asyncio.TimerHandle] = None + self._heartbeat_when: float = 0.0 if heartbeat is not None: self._pong_heartbeat = heartbeat / 2.0 self._pong_response_cb: Optional[asyncio.TimerHandle] = None @@ -75,52 +76,64 @@ def __init__( self._reset_heartbeat() def _cancel_heartbeat(self) -> None: - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = None - + self._cancel_pong_response_cb() if self._heartbeat_cb is not None: self._heartbeat_cb.cancel() self._heartbeat_cb = None - def _reset_heartbeat(self) -> None: - self._cancel_heartbeat() + def _cancel_pong_response_cb(self) -> None: + if self._pong_response_cb is not None: + self._pong_response_cb.cancel() + self._pong_response_cb = None - if self._heartbeat is not None: - self._heartbeat_cb = call_later( - self._send_heartbeat, - self._heartbeat, - self._loop, - timeout_ceil_threshold=( - self._conn._connector._timeout_ceil_threshold - if self._conn is not None - else 5 - ), - ) + def _reset_heartbeat(self) -> None: + if self._heartbeat is None: + return + self._cancel_pong_response_cb() + loop = self._loop + assert loop is not None + conn = self._conn + timeout_ceil_threshold = ( + conn._connector._timeout_ceil_threshold if conn is not None else 5 + ) + now = loop.time() + when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold) + self._heartbeat_when = when + if self._heartbeat_cb is None: + # We do not cancel the previous heartbeat_cb here because + # it generates a significant amount of TimerHandle churn + # which causes asyncio to rebuild the heap frequently. + # Instead _send_heartbeat() will reschedule the next + # heartbeat if it fires too early. + self._heartbeat_cb = loop.call_at(when, self._send_heartbeat) def _send_heartbeat(self) -> None: - if self._heartbeat is not None and not self._closed: - # fire-and-forget a task is not perfect but maybe ok for - # sending ping. Otherwise we need a long-living heartbeat - # task in the class. - self._loop.create_task(self._writer.ping()) - - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = call_later( - self._pong_not_received, - self._pong_heartbeat, - self._loop, - timeout_ceil_threshold=( - self._conn._connector._timeout_ceil_threshold - if self._conn is not None - else 5 - ), + self._heartbeat_cb = None + loop = self._loop + now = loop.time() + if now < self._heartbeat_when: + # Heartbeat fired too early, reschedule + self._heartbeat_cb = loop.call_at( + self._heartbeat_when, self._send_heartbeat ) + return + + # fire-and-forget a task is not perfect but maybe ok for + # sending ping. Otherwise we need a long-living heartbeat + # task in the class. + loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable] + + conn = self._conn + timeout_ceil_threshold = ( + conn._connector._timeout_ceil_threshold if conn is not None else 5 + ) + when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold) + self._cancel_pong_response_cb() + self._pong_response_cb = loop.call_at(when, self._pong_not_received) def _pong_not_received(self) -> None: if not self._closed: - self._closed = True + self._set_closed() self._close_code = WSCloseCode.ABNORMAL_CLOSURE self._exception = ServerTimeoutError() self._response.close() @@ -129,6 +142,22 @@ def _pong_not_received(self) -> None: WSMessage(WSMsgType.ERROR, self._exception, None) ) + def _set_closed(self) -> None: + """Set the connection to closed. + + Cancel any heartbeat timers and set the closed flag. + """ + self._closed = True + self._cancel_heartbeat() + + def _set_closing(self) -> None: + """Set the connection to closing. + + Cancel any heartbeat timers and set the closing flag. + """ + self._closing = True + self._cancel_heartbeat() + @property def closed(self) -> bool: return self._closed @@ -193,13 +222,12 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo if self._waiting and not self._closing: assert self._loop is not None self._close_wait = self._loop.create_future() - self._closing = True + self._set_closing() self._reader.feed_data(WS_CLOSING_MESSAGE, 0) await self._close_wait if not self._closed: - self._cancel_heartbeat() - self._closed = True + self._set_closed() try: await self._writer.close(code, message) except asyncio.CancelledError: @@ -266,7 +294,8 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: await self.close() return WSMessage(WSMsgType.CLOSED, None, None) except ClientError: - self._closed = True + # Likely ServerDisconnectedError when connection is lost + self._set_closed() self._close_code = WSCloseCode.ABNORMAL_CLOSURE return WS_CLOSED_MESSAGE except WebSocketError as exc: @@ -275,18 +304,18 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: return WSMessage(WSMsgType.ERROR, exc, None) except Exception as exc: self._exception = exc - self._closing = True + self._set_closing() self._close_code = WSCloseCode.ABNORMAL_CLOSURE await self.close() return WSMessage(WSMsgType.ERROR, exc, None) if msg.type is WSMsgType.CLOSE: - self._closing = True + self._set_closing() self._close_code = msg.data if not self._closed and self._autoclose: await self.close() elif msg.type is WSMsgType.CLOSING: - self._closing = True + self._set_closing() elif msg.type is WSMsgType.PING and self._autoping: await self.pong(msg.data) continue diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index b3cc1b6b6e6..437c871e8f7 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -586,12 +586,23 @@ def call_later( loop: asyncio.AbstractEventLoop, timeout_ceil_threshold: float = 5, ) -> Optional[asyncio.TimerHandle]: - if timeout is not None and timeout > 0: - when = loop.time() + timeout - if timeout > timeout_ceil_threshold: - when = ceil(when) - return loop.call_at(when, cb) - return None + if timeout is None or timeout <= 0: + return None + now = loop.time() + when = calculate_timeout_when(now, timeout, timeout_ceil_threshold) + return loop.call_at(when, cb) + + +def calculate_timeout_when( + loop_time: float, + timeout: float, + timeout_ceiling_threshold: float, +) -> float: + """Calculate when to execute a timeout.""" + when = loop_time + timeout + if timeout > timeout_ceiling_threshold: + return ceil(when) + return when class TimeoutHandle: diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index b74bfd688c9..9f71d147997 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -11,7 +11,7 @@ from . import hdrs from .abc import AbstractStreamWriter -from .helpers import call_later, set_exception, set_result +from .helpers import calculate_timeout_when, set_exception, set_result from .http import ( WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE, @@ -89,6 +89,7 @@ def __init__( self._autoclose = autoclose self._autoping = autoping self._heartbeat = heartbeat + self._heartbeat_when = 0.0 self._heartbeat_cb: Optional[asyncio.TimerHandle] = None if heartbeat is not None: self._pong_heartbeat = heartbeat / 2.0 @@ -97,57 +98,76 @@ def __init__( self._max_msg_size = max_msg_size def _cancel_heartbeat(self) -> None: - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = None - + self._cancel_pong_response_cb() if self._heartbeat_cb is not None: self._heartbeat_cb.cancel() self._heartbeat_cb = None - def _reset_heartbeat(self) -> None: - self._cancel_heartbeat() + def _cancel_pong_response_cb(self) -> None: + if self._pong_response_cb is not None: + self._pong_response_cb.cancel() + self._pong_response_cb = None - if self._heartbeat is not None: - assert self._loop is not None - self._heartbeat_cb = call_later( - self._send_heartbeat, - self._heartbeat, - self._loop, - timeout_ceil_threshold=( - self._req._protocol._timeout_ceil_threshold - if self._req is not None - else 5 - ), - ) + def _reset_heartbeat(self) -> None: + if self._heartbeat is None: + return + self._cancel_pong_response_cb() + req = self._req + timeout_ceil_threshold = ( + req._protocol._timeout_ceil_threshold if req is not None else 5 + ) + loop = self._loop + assert loop is not None + now = loop.time() + when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold) + self._heartbeat_when = when + if self._heartbeat_cb is None: + # We do not cancel the previous heartbeat_cb here because + # it generates a significant amount of TimerHandle churn + # which causes asyncio to rebuild the heap frequently. + # Instead _send_heartbeat() will reschedule the next + # heartbeat if it fires too early. + self._heartbeat_cb = loop.call_at(when, self._send_heartbeat) def _send_heartbeat(self) -> None: - if self._heartbeat is not None and not self._closed: - assert self._loop is not None - # fire-and-forget a task is not perfect but maybe ok for - # sending ping. Otherwise we need a long-living heartbeat - # task in the class. - self._loop.create_task(self._writer.ping()) # type: ignore[union-attr] - - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = call_later( - self._pong_not_received, - self._pong_heartbeat, - self._loop, - timeout_ceil_threshold=( - self._req._protocol._timeout_ceil_threshold - if self._req is not None - else 5 - ), + self._heartbeat_cb = None + loop = self._loop + assert loop is not None and self._writer is not None + now = loop.time() + if now < self._heartbeat_when: + # Heartbeat fired too early, reschedule + self._heartbeat_cb = loop.call_at( + self._heartbeat_when, self._send_heartbeat ) + return + + # fire-and-forget a task is not perfect but maybe ok for + # sending ping. Otherwise we need a long-living heartbeat + # task in the class. + loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable] + + req = self._req + timeout_ceil_threshold = ( + req._protocol._timeout_ceil_threshold if req is not None else 5 + ) + when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold) + self._cancel_pong_response_cb() + self._pong_response_cb = loop.call_at(when, self._pong_not_received) def _pong_not_received(self) -> None: if self._req is not None and self._req.transport is not None: - self._closed = True + self._set_closed() self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) self._exception = asyncio.TimeoutError() + def _set_closed(self) -> None: + """Set the connection to closed. + + Cancel any heartbeat timers and set the closed flag. + """ + self._closed = True + self._cancel_heartbeat() + async def prepare(self, request: BaseRequest) -> AbstractStreamWriter: # make pre-check to don't hide it by do_handshake() exceptions if self._payload_writer is not None: @@ -387,7 +407,7 @@ async def close( if self._closed: return False - self._closed = True + self._set_closed() try: await self._writer.close(code, message) writer = self._payload_writer @@ -431,6 +451,7 @@ def _set_closing(self, code: WSCloseCode) -> None: """Set the close code and mark the connection as closing.""" self._closing = True self._close_code = code + self._cancel_heartbeat() def _set_code_close_transport(self, code: WSCloseCode) -> None: """Set the close code and close the transport.""" @@ -543,5 +564,6 @@ def _cancel(self, exc: BaseException) -> None: # web_protocol calls this from connection_lost # or when the server is shutting down. self._closing = True + self._cancel_heartbeat() if self._reader is not None: set_exception(self._reader, exc) diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index ebc9d910c1a..a790fba43ec 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -9,6 +9,7 @@ import aiohttp from aiohttp import client, hdrs +from aiohttp.client_exceptions import ServerDisconnectedError from aiohttp.http import WS_KEY from aiohttp.streams import EofStream from aiohttp.test_utils import make_mocked_coro @@ -404,21 +405,56 @@ async def test_close_eofstream(loop, ws_key, key_data) -> None: await session.close() -async def test_close_exc(loop, ws_key, key_data) -> None: - resp = mock.Mock() - resp.status = 101 - resp.headers = { +async def test_close_connection_lost( + loop: asyncio.AbstractEventLoop, ws_key: bytes, key_data: bytes +) -> None: + """Test the websocket client handles the connection being closed out from under it.""" + mresp = mock.Mock(spec_set=client.ClientResponse) + mresp.status = 101 + mresp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } - resp.connection.protocol.read_timeout = None + mresp.connection.protocol.read_timeout = None + with mock.patch("aiohttp.client.WebSocketWriter"), mock.patch( + "aiohttp.client.os" + ) as m_os, mock.patch("aiohttp.client.ClientSession.request") as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(mresp) + + session = aiohttp.ClientSession() + resp = await session.ws_connect("http://test.org") + assert not resp.closed + + exc = ServerDisconnectedError() + resp._reader.set_exception(exc) + + msg = await resp.receive() + assert msg.type is aiohttp.WSMsgType.CLOSED + assert resp.closed + + await session.close() + + +async def test_close_exc( + loop: asyncio.AbstractEventLoop, ws_key: bytes, key_data: bytes +) -> None: + mresp = mock.Mock() + mresp.status = 101 + mresp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + } + mresp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() - m_req.return_value.set_result(resp) + m_req.return_value.set_result(mresp) writer = mock.Mock() WebSocketWriter.return_value = writer writer.close = make_mocked_coro() diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index dc474f96c39..5abaf0fefbf 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1,6 +1,6 @@ import asyncio import sys -from typing import Any +from typing import Any, NoReturn import pytest @@ -599,7 +599,8 @@ async def handler(request): assert ping_received -async def test_heartbeat_no_pong(aiohttp_client) -> None: +async def test_heartbeat_no_pong(aiohttp_client: AiohttpClient) -> None: + """Test that the connection is closed if no pong is received without sending messages.""" ping_received = False async def handler(request): @@ -624,7 +625,81 @@ async def handler(request): assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE -async def test_heartbeat_no_pong_concurrent_receive(aiohttp_client: Any) -> None: +async def test_heartbeat_no_pong_after_receive_many_messages( + aiohttp_client: AiohttpClient, +) -> None: + """Test that the connection is closed if no pong is received after receiving many messages.""" + ping_received = False + + async def handler(request: web.Request) -> NoReturn: + nonlocal ping_received + ws = web.WebSocketResponse(autoping=False) + await ws.prepare(request) + for _ in range(5): + await ws.send_str("test") + await asyncio.sleep(0.05) + for _ in range(5): + await ws.send_str("test") + msg = await ws.receive() + ping_received = msg.type is aiohttp.WSMsgType.PING + await ws.receive() + assert False + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + resp = await client.ws_connect("/", heartbeat=0.1) + + for _ in range(10): + test_msg = await resp.receive() + assert test_msg.data == "test" + # Connection should be closed roughly after 1.5x heartbeat. + + await asyncio.sleep(0.2) + assert ping_received + assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE + + +async def test_heartbeat_no_pong_after_send_many_messages( + aiohttp_client: AiohttpClient, +) -> None: + """Test that the connection is closed if no pong is received after sending many messages.""" + ping_received = False + + async def handler(request: web.Request) -> NoReturn: + nonlocal ping_received + ws = web.WebSocketResponse(autoping=False) + await ws.prepare(request) + for _ in range(10): + msg = await ws.receive() + assert msg.data == "test" + assert msg.type is aiohttp.WSMsgType.TEXT + msg = await ws.receive() + ping_received = msg.type is aiohttp.WSMsgType.PING + await ws.receive() + assert False + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + resp = await client.ws_connect("/", heartbeat=0.1) + + for _ in range(5): + await resp.send_str("test") + await asyncio.sleep(0.05) + for _ in range(5): + await resp.send_str("test") + # Connection should be closed roughly after 1.5x heartbeat. + await asyncio.sleep(0.2) + assert ping_received + assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE + + +async def test_heartbeat_no_pong_concurrent_receive( + aiohttp_client: AiohttpClient, +) -> None: ping_received = False async def handler(request): diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index ce338cdf92d..15ef33e3648 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -722,7 +722,64 @@ async def handler(request): await ws.close() -async def test_server_ws_async_for(loop, aiohttp_server) -> None: +async def test_heartbeat_no_pong_send_many_messages( + loop: Any, aiohttp_client: Any +) -> None: + """Test no pong after sending many messages.""" + + async def handler(request): + ws = web.WebSocketResponse(heartbeat=0.05) + await ws.prepare(request) + for _ in range(10): + await ws.send_str("test") + + await ws.receive() + return ws + + app = web.Application() + app.router.add_get("/", handler) + + client = await aiohttp_client(app) + ws = await client.ws_connect("/", autoping=False) + for _ in range(10): + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.TEXT + assert msg.data == "test" + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.PING + await ws.close() + + +async def test_heartbeat_no_pong_receive_many_messages( + loop: Any, aiohttp_client: Any +) -> None: + """Test no pong after receiving many messages.""" + + async def handler(request): + ws = web.WebSocketResponse(heartbeat=0.05) + await ws.prepare(request) + for _ in range(10): + server_msg = await ws.receive() + assert server_msg.type is aiohttp.WSMsgType.TEXT + + await ws.receive() + return ws + + app = web.Application() + app.router.add_get("/", handler) + + client = await aiohttp_client(app) + ws = await client.ws_connect("/", autoping=False) + for _ in range(10): + await ws.send_str("test") + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.PING + await ws.close() + + +async def test_server_ws_async_for(loop: Any, aiohttp_server: Any) -> None: closed = loop.create_future() async def handler(request):