From bb111012706d3ef9edc525be3d8d4df410ad847f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 24 Nov 2023 15:11:06 -0600 Subject: [PATCH] Restore async concurrency safety to websocket compressor (#7865) (#7889) Fixes #7859 (cherry picked from commit 86a23961531103ccc34853f67321c7d0f63797f5) --- CHANGES/7865.bugfix | 1 + aiohttp/compression_utils.py | 22 +++++++---- aiohttp/http_websocket.py | 26 ++++++++----- tests/test_websocket_writer.py | 67 +++++++++++++++++++++++++++++++++- 4 files changed, 97 insertions(+), 19 deletions(-) create mode 100644 CHANGES/7865.bugfix diff --git a/CHANGES/7865.bugfix b/CHANGES/7865.bugfix new file mode 100644 index 00000000000..9a46e124486 --- /dev/null +++ b/CHANGES/7865.bugfix @@ -0,0 +1 @@ +Restore async concurrency safety to websocket compressor diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py index 52791fe5015..9631d377e9a 100644 --- a/aiohttp/compression_utils.py +++ b/aiohttp/compression_utils.py @@ -62,19 +62,25 @@ def __init__( self._compressor = zlib.compressobj( wbits=self._mode, strategy=strategy, level=level ) + self._compress_lock = asyncio.Lock() def compress_sync(self, data: bytes) -> bytes: return self._compressor.compress(data) async def compress(self, data: bytes) -> bytes: - if ( - self._max_sync_chunk_size is not None - and len(data) > self._max_sync_chunk_size - ): - return await asyncio.get_event_loop().run_in_executor( - self._executor, self.compress_sync, data - ) - return self.compress_sync(data) + async with self._compress_lock: + # To ensure the stream is consistent in the event + # there are multiple writers, we need to lock + # the compressor so that only one writer can + # compress at a time. + if ( + self._max_sync_chunk_size is not None + and len(data) > self._max_sync_chunk_size + ): + return await asyncio.get_event_loop().run_in_executor( + self._executor, self.compress_sync, data + ) + return self.compress_sync(data) def flush(self, mode: int = zlib.Z_FINISH) -> bytes: return self._compressor.flush(mode) diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index a94ac2a73dd..f395a27614a 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -635,21 +635,17 @@ async def _send_frame( if (compress or self.compress) and opcode < 8: if compress: # Do not set self._compress if compressing is for this frame - compressobj = ZLibCompressor( - level=zlib.Z_BEST_SPEED, - wbits=-compress, - max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, - ) + compressobj = self._make_compress_obj(compress) else: # self.compress if not self._compressobj: - self._compressobj = ZLibCompressor( - level=zlib.Z_BEST_SPEED, - wbits=-self.compress, - max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, - ) + self._compressobj = self._make_compress_obj(self.compress) compressobj = self._compressobj message = await compressobj.compress(message) + # Its critical that we do not return control to the event + # loop until we have finished sending all the compressed + # data. Otherwise we could end up mixing compressed frames + # if there are multiple coroutines compressing data. message += compressobj.flush( zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH ) @@ -687,10 +683,20 @@ async def _send_frame( self._output_size += len(header) + len(message) + # It is safe to return control to the event loop when using compression + # after this point as we have already sent or buffered all the data. + if self._output_size > self._limit: self._output_size = 0 await self.protocol._drain_helper() + def _make_compress_obj(self, compress: int) -> ZLibCompressor: + return ZLibCompressor( + level=zlib.Z_BEST_SPEED, + wbits=-compress, + max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, + ) + def _write(self, data: bytes) -> None: if self.transport is None or self.transport.is_closing(): raise ConnectionResetError("Cannot write to closing transport") diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index fce3c330d27..8dbbc815fb7 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -1,9 +1,12 @@ +import asyncio import random +from typing import Any, Callable from unittest import mock import pytest -from aiohttp.http import WebSocketWriter +from aiohttp import DataQueue, WSMessage +from aiohttp.http import WebSocketReader, WebSocketWriter from aiohttp.test_utils import make_mocked_coro @@ -104,3 +107,65 @@ async def test_send_compress_text_per_message(protocol, transport) -> None: writer.transport.write.assert_called_with(b"\x81\x04text") await writer.send(b"text", compress=15) writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") + + +@pytest.mark.parametrize( + ("max_sync_chunk_size", "payload_point_generator"), + ( + (16, lambda count: count), + (4096, lambda count: count), + (32, lambda count: 64 + count if count % 2 else count), + ), +) +async def test_concurrent_messages( + protocol: Any, + transport: Any, + max_sync_chunk_size: int, + payload_point_generator: Callable[[int], int], +) -> None: + """Ensure messages are compressed correctly when there are multiple concurrent writers. + + This test generates is parametrized to + + - Generate messages that are larger than patch + WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 16 + where compression will run in the executor + + - Generate messages that are smaller than patch + WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 4096 + where compression will run in the event loop + + - Interleave generated messages with a + WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 32 + where compression will run in the event loop + and in the executor + """ + with mock.patch( + "aiohttp.http_websocket.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", max_sync_chunk_size + ): + writer = WebSocketWriter(protocol, transport, compress=15) + queue: DataQueue[WSMessage] = DataQueue(asyncio.get_running_loop()) + reader = WebSocketReader(queue, 50000) + writers = [] + payloads = [] + for count in range(1, 64 + 1): + point = payload_point_generator(count) + payload = bytes((point,)) * point + payloads.append(payload) + writers.append(writer.send(payload, binary=True)) + await asyncio.gather(*writers) + + for call in writer.transport.write.call_args_list: + call_bytes = call[0][0] + result, _ = reader.feed_data(call_bytes) + assert result is False + msg = await queue.read() + bytes_data: bytes = msg.data + first_char = bytes_data[0:1] + char_val = ord(first_char) + assert len(bytes_data) == char_val + # If we have a concurrency problem, the data + # tends to get mixed up between messages so + # we want to validate that all the bytes are + # the same value + assert bytes_data == bytes_data[0:1] * char_val