From 354489d2d4d2665253bb0b387d08d02dd5d3ad4f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Nov 2024 19:35:47 -0600 Subject: [PATCH] [PR #9839/a9a0d84 backport][3.11] Implement zero copy writes in `StreamWriter` (#9847) --- CHANGES/9839.misc.rst | 1 + aiohttp/http_writer.py | 68 ++++++++++++----- tests/test_client_request.py | 13 ++-- tests/test_http_writer.py | 143 ++++++++++++++++++++++++++++++++--- 4 files changed, 191 insertions(+), 34 deletions(-) create mode 100644 CHANGES/9839.misc.rst diff --git a/CHANGES/9839.misc.rst b/CHANGES/9839.misc.rst new file mode 100644 index 00000000000..8bdd50268a7 --- /dev/null +++ b/CHANGES/9839.misc.rst @@ -0,0 +1 @@ +Implemented zero copy writes for ``StreamWriter`` -- by :user:`bdraco`. diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index a1a9860b48d..c6c80edc3c4 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -2,7 +2,16 @@ import asyncio import zlib -from typing import Any, Awaitable, Callable, NamedTuple, Optional, Union # noqa +from typing import ( # noqa + Any, + Awaitable, + Callable, + Iterable, + List, + NamedTuple, + Optional, + Union, +) from multidict import CIMultiDict @@ -76,6 +85,17 @@ def _write(self, chunk: bytes) -> None: raise ClientConnectionResetError("Cannot write to closing transport") transport.write(chunk) + def _writelines(self, chunks: Iterable[bytes]) -> None: + size = 0 + for chunk in chunks: + size += len(chunk) + self.buffer_size += size + self.output_size += size + transport = self._protocol.transport + if transport is None or transport.is_closing(): + raise ClientConnectionResetError("Cannot write to closing transport") + transport.writelines(chunks) + async def write( self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000 ) -> None: @@ -110,10 +130,11 @@ async def write( if chunk: if self.chunked: - chunk_len_pre = ("%x\r\n" % len(chunk)).encode("ascii") - chunk = chunk_len_pre + chunk + b"\r\n" - - self._write(chunk) + self._writelines( + (f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n") + ) + else: + self._write(chunk) if self.buffer_size > LIMIT and drain: self.buffer_size = 0 @@ -142,22 +163,31 @@ async def write_eof(self, chunk: bytes = b"") -> None: await self._on_chunk_sent(chunk) if self._compress: - if chunk: - chunk = await self._compress.compress(chunk) + chunks: List[bytes] = [] + chunks_len = 0 + if chunk and (compressed_chunk := await self._compress.compress(chunk)): + chunks_len = len(compressed_chunk) + chunks.append(compressed_chunk) - chunk += self._compress.flush() - if chunk and self.chunked: - chunk_len = ("%x\r\n" % len(chunk)).encode("ascii") - chunk = chunk_len + chunk + b"\r\n0\r\n\r\n" - else: - if self.chunked: - if chunk: - chunk_len = ("%x\r\n" % len(chunk)).encode("ascii") - chunk = chunk_len + chunk + b"\r\n0\r\n\r\n" - else: - chunk = b"0\r\n\r\n" + flush_chunk = self._compress.flush() + chunks_len += len(flush_chunk) + chunks.append(flush_chunk) + assert chunks_len - if chunk: + if self.chunked: + chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii") + self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n")) + elif len(chunks) > 1: + self._writelines(chunks) + else: + self._write(chunks[0]) + elif self.chunked: + if chunk: + chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii") + self._writelines((chunk_len_pre, chunk, b"\r\n0\r\n\r\n")) + else: + self._write(b"0\r\n\r\n") + elif chunk: self._write(chunk) await self.drain() diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 8947aa38944..870c9666f34 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -6,7 +6,7 @@ import urllib.parse import zlib from http.cookies import BaseCookie, Morsel, SimpleCookie -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Iterable, Optional from unittest import mock import pytest @@ -67,17 +67,18 @@ def protocol(loop, transport): @pytest.fixture -def transport(buf): - transport = mock.Mock() +def transport(buf: bytearray) -> mock.Mock: + transport = mock.create_autospec(asyncio.Transport, spec_set=True, instance=True) def write(chunk): buf.extend(chunk) - async def write_eof(): - pass + def writelines(chunks: Iterable[bytes]) -> None: + for chunk in chunks: + buf.extend(chunk) transport.write.side_effect = write - transport.write_eof.side_effect = write_eof + transport.writelines.side_effect = writelines transport.is_closing.return_value = False return transport diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index d330da48df7..e43b448bc0f 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -1,6 +1,8 @@ # Tests for aiohttp/http_writer.py import array import asyncio +import zlib +from typing import Iterable from unittest import mock import pytest @@ -23,7 +25,12 @@ def transport(buf): def write(chunk): buf.extend(chunk) + def writelines(chunks: Iterable[bytes]) -> None: + for chunk in chunks: + buf.extend(chunk) + transport.write.side_effect = write + transport.writelines.side_effect = writelines transport.is_closing.return_value = False return transport @@ -85,21 +92,53 @@ async def test_write_payload_length(protocol, transport, loop) -> None: assert b"da" == content.split(b"\r\n\r\n", 1)[-1] -async def test_write_payload_chunked_filter(protocol, transport, loop) -> None: - write = transport.write = mock.Mock() +async def test_write_large_payload_deflate_compression_data_in_eof( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + + await msg.write(b"data" * 4096) + assert transport.write.called # type: ignore[attr-defined] + chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] + transport.write.reset_mock() # type: ignore[attr-defined] + assert not transport.writelines.called # type: ignore[attr-defined] + # This payload compresses to 20447 bytes + payload = b"".join( + [bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)] + ) + await msg.write_eof(payload) + assert not transport.write.called # type: ignore[attr-defined] + assert transport.writelines.called # type: ignore[attr-defined] + chunks.extend(transport.writelines.mock_calls[0][1][0]) # type: ignore[attr-defined] + content = b"".join(chunks) + assert zlib.decompress(content) == (b"data" * 4096) + payload + + +async def test_write_payload_chunked_filter( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_chunking() await msg.write(b"da") await msg.write(b"ta") await msg.write_eof() - content = b"".join([c[1][0] for c in list(write.mock_calls)]) + content = b"".join([b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)]) # type: ignore[attr-defined] + content += b"".join([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined] assert content.endswith(b"2\r\nda\r\n2\r\nta\r\n0\r\n\r\n") -async def test_write_payload_chunked_filter_mutiple_chunks(protocol, transport, loop): - write = transport.write = mock.Mock() +async def test_write_payload_chunked_filter_multiple_chunks( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_chunking() await msg.write(b"da") @@ -108,14 +147,14 @@ async def test_write_payload_chunked_filter_mutiple_chunks(protocol, transport, await msg.write(b"at") await msg.write(b"a2") await msg.write_eof() - content = b"".join([c[1][0] for c in list(write.mock_calls)]) + content = b"".join([b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)]) # type: ignore[attr-defined] + content += b"".join([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined] assert content.endswith( b"2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n2\r\na2\r\n0\r\n\r\n" ) async def test_write_payload_deflate_compression(protocol, transport, loop) -> None: - COMPRESSED = b"x\x9cKI,I\x04\x00\x04\x00\x01\x9b" write = transport.write = mock.Mock() msg = http.StreamWriter(protocol, loop) @@ -129,7 +168,30 @@ async def test_write_payload_deflate_compression(protocol, transport, loop) -> N assert COMPRESSED == content.split(b"\r\n\r\n", 1)[-1] -async def test_write_payload_deflate_and_chunked(buf, protocol, transport, loop): +async def test_write_payload_deflate_compression_chunked( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + expected = b"2\r\nx\x9c\r\na\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n0\r\n\r\n" + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + await msg.write(b"data") + await msg.write_eof() + + chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined] + assert all(chunks) + content = b"".join(chunks) + assert content == expected + + +async def test_write_payload_deflate_and_chunked( + buf: bytearray, + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() @@ -142,8 +204,71 @@ async def test_write_payload_deflate_and_chunked(buf, protocol, transport, loop) assert thing == buf -async def test_write_payload_bytes_memoryview(buf, protocol, transport, loop): +async def test_write_payload_deflate_compression_chunked_data_in_eof( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + expected = b"2\r\nx\x9c\r\nd\r\nKI,IL\xcdK\x01\x00\x0b@\x02\xd2\r\n0\r\n\r\n" + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + await msg.write(b"data") + await msg.write_eof(b"end") + + chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined] + assert all(chunks) + content = b"".join(chunks) + assert content == expected + + +async def test_write_large_payload_deflate_compression_chunked_data_in_eof( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + + await msg.write(b"data" * 4096) + # This payload compresses to 1111 bytes + payload = b"".join([bytes((*range(0, i), *range(i, 0, -1))) for i in range(255)]) + await msg.write_eof(payload) + assert not transport.write.called # type: ignore[attr-defined] + chunks = [] + for write_lines_call in transport.writelines.mock_calls: # type: ignore[attr-defined] + chunked_payload = list(write_lines_call[1][0])[1:] + chunked_payload.pop() + chunks.extend(chunked_payload) + + assert all(chunks) + content = b"".join(chunks) + assert zlib.decompress(content) == (b"data" * 4096) + payload + + +async def test_write_payload_deflate_compression_chunked_connection_lost( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + await msg.write(b"data") + with pytest.raises( + ClientConnectionResetError, match="Cannot write to closing transport" + ), mock.patch.object(transport, "is_closing", return_value=True): + await msg.write_eof(b"end") + + +async def test_write_payload_bytes_memoryview( + buf: bytearray, + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: msg = http.StreamWriter(protocol, loop) mv = memoryview(b"abcd")