Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Extract zlib-related logic into a single module #7223

Merged
merged 24 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
fa479be
refactor: Extract zlib-related logic into a single module
mykola-mokhnach Feb 24, 2023
1c80247
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2023
e9f6c38
make mypy happier
mykola-mokhnach Feb 24, 2023
18f0320
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2023
536b54b
Fix line length
mykola-mokhnach Feb 24, 2023
ce4e693
Update tests
mykola-mokhnach Feb 24, 2023
401c01b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2023
1dac3f7
Make mypy happy
mykola-mokhnach Feb 24, 2023
04dba2f
fix defaults
mykola-mokhnach Feb 25, 2023
573ee76
Remove extra types
mykola-mokhnach Mar 3, 2023
c7f7578
Delete the obsolete method
mykola-mokhnach Mar 3, 2023
beb6c90
Address comments
mykola-mokhnach Mar 3, 2023
b08e461
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2023
829bab6
address review comments
mykola-mokhnach Mar 5, 2023
4b274c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2023
a65547d
Tune tests
mykola-mokhnach Mar 5, 2023
e796fe9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2023
9257884
Make mypy happy
mykola-mokhnach Mar 5, 2023
4b85a90
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2023
d9f3c49
Update compression_utils.py
Dreamsorcerer Mar 7, 2023
570ffce
Update compression_utils.py
Dreamsorcerer Mar 7, 2023
662dfcb
Update multipart.py
Dreamsorcerer Mar 7, 2023
96b35a6
Update compression_utils.py
Dreamsorcerer Mar 7, 2023
307eafe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import collections
import re
import string
import zlib
from contextlib import suppress
from enum import IntEnum
from typing import (
Any,
Generic,
List,
NamedTuple,
Expand All @@ -22,7 +20,7 @@
)

from multidict import CIMultiDict, CIMultiDictProxy, istr
from typing_extensions import Final
from typing_extensions import Final, Protocol
from yarl import URL

from . import hdrs
Expand All @@ -41,6 +39,7 @@
from .log import internal_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .typedefs import RawHeaders
from .zlib_utils import ZLibDecompressor

try:
import brotli
Expand All @@ -50,6 +49,18 @@
HAS_BROTLI = False


class Decompressor(Protocol):
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
@property
def eof(self) -> bool:
...
Fixed Show fixed Hide fixed

def decompress_sync(self, data: bytes, max_length: int = ...) -> bytes:
...
Fixed Show fixed Hide fixed

def flush(self, length: int = ...) -> bytes:
...
Fixed Show fixed Hide fixed


__all__ = (
"HeadersParser",
"HttpParser",
Expand Down Expand Up @@ -859,6 +870,7 @@ def __init__(self, out: StreamReader, encoding: Optional[str]) -> None:
self.encoding = encoding
self._started_decoding = False

self.decompressor: Decompressor
if encoding == "br":
if not HAS_BROTLI: # pragma: no cover
raise ContentEncodingError(
Expand All @@ -872,21 +884,26 @@ class BrotliDecoder:
# are for 'brotlipy' and bottom branches for 'Brotli'
def __init__(self) -> None:
self._obj = brotli.Decompressor()
self._is_at_eof = False

def decompress(self, data: bytes) -> bytes:
def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes:
if hasattr(self._obj, "decompress"):
return cast(bytes, self._obj.decompress(data))
return cast(bytes, self._obj.process(data))

def flush(self) -> bytes:
def flush(self, length: int = 0) -> bytes:
self._is_at_eof = True
if hasattr(self._obj, "flush"):
return cast(bytes, self._obj.flush())
return b""

self.decompressor: Any = BrotliDecoder()
@property
def eof(self) -> bool:
return self._is_at_eof

self.decompressor = BrotliDecoder()
else:
zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS
self.decompressor = zlib.decompressobj(wbits=zlib_mode)
self.decompressor = ZLibDecompressor(encoding=encoding)

def set_exception(self, exc: BaseException) -> None:
self.out.set_exception(exc)
Expand All @@ -907,10 +924,12 @@ def feed_data(self, chunk: bytes, size: int) -> None:
):
# Change the decoder to decompress incorrectly compressed data
# Actually we should issue a warning about non-RFC-compliant data.
self.decompressor = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
self.decompressor = ZLibDecompressor(
encoding=self.encoding, suppress_deflate_header=True
)

try:
chunk = self.decompressor.decompress(chunk)
chunk = self.decompressor.decompress_sync(chunk)
except Exception:
raise ContentEncodingError(
"Can not decode content-encoding: %s" % self.encoding
Expand Down
16 changes: 9 additions & 7 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS
from .streams import DataQueue
from .zlib_utils import ZLibCompressor, ZLibDecompressor

__all__ = (
"WS_CLOSED_MESSAGE",
Expand Down Expand Up @@ -270,7 +271,7 @@ def __init__(
self._payload_length = 0
self._payload_length_flag = 0
self._compressed: Optional[bool] = None
self._decompressobj: Any = None # zlib.decompressobj actually
self._decompressobj: Optional[ZLibDecompressor] = None
self._compress = compress

def feed_eof(self) -> None:
Expand All @@ -290,7 +291,7 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
for fin, opcode, payload, compressed in self.parse_frame(data):
if compressed and not self._decompressobj:
self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
if opcode == WSMsgType.CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
Expand Down Expand Up @@ -375,8 +376,9 @@ def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
# Decompression must be done after all packets have been
# received.
if compressed:
assert self._decompressobj is not None
self._partial.extend(_WS_DEFLATE_TRAILING)
payload_merged = self._decompressobj.decompress(
payload_merged = self._decompressobj.decompress_sync(
self._partial, self._max_msg_size
)
if self._decompressobj.unconsumed_tail:
Expand Down Expand Up @@ -604,16 +606,16 @@ async def _send_frame(
if (compress or self.compress) and opcode < 8:
if compress:
# Do not set self._compress if compressing is for this frame
compressobj = zlib.compressobj(level=zlib.Z_BEST_SPEED, wbits=-compress)
compressobj = ZLibCompressor(level=zlib.Z_BEST_SPEED, wbits=-compress)
else: # self.compress
if not self._compressobj:
self._compressobj = zlib.compressobj(
self._compressobj = ZLibCompressor(
level=zlib.Z_BEST_SPEED, wbits=-self.compress
)
compressobj = self._compressobj

message = compressobj.compress(message)
message = message + compressobj.flush(
message = await compressobj.compress(message)
message += compressobj.flush(
zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
)
if message.endswith(_WS_DEFLATE_TRAILING):
Expand Down
12 changes: 6 additions & 6 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS
from .zlib_utils import ZLibCompressor

__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")

Expand Down Expand Up @@ -43,7 +44,7 @@ def __init__(
self.output_size = 0

self._eof = False
self._compress: Any = None
self._compress: Optional[ZLibCompressor] = None
self._drain_waiter = None

self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
Expand All @@ -63,8 +64,7 @@ def enable_chunking(self) -> None:
def enable_compression(
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
) -> None:
zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS
self._compress = zlib.compressobj(wbits=zlib_mode, strategy=strategy)
self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)

def _write(self, chunk: bytes) -> None:
size = len(chunk)
Expand Down Expand Up @@ -93,7 +93,7 @@ async def write(
chunk = chunk.cast("c")

if self._compress is not None:
chunk = self._compress.compress(chunk)
chunk = await self._compress.compress(chunk)
if not chunk:
return

Expand Down Expand Up @@ -138,9 +138,9 @@ async def write_eof(self, chunk: bytes = b"") -> None:

if self._compress:
if chunk:
chunk = self._compress.compress(chunk)
chunk = await self._compress.compress(chunk)

chunk = chunk + self._compress.flush()
chunk += self._compress.flush()
if chunk and self.chunked:
chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
chunk = chunk_len + chunk + b"\r\n0\r\n\r\n"
Expand Down
28 changes: 16 additions & 12 deletions aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
payload_type,
)
from .streams import StreamReader
from .zlib_utils import ZLibCompressor, ZLibDecompressor

__all__ = (
"MultipartReader",
Expand Down Expand Up @@ -491,15 +492,15 @@ def decode(self, data: bytes) -> bytes:

def _decode_content(self, data: bytes) -> bytes:
encoding = self.headers.get(CONTENT_ENCODING, "").lower()

if encoding == "deflate":
return zlib.decompress(data, -zlib.MAX_WBITS)
elif encoding == "gzip":
return zlib.decompress(data, 16 + zlib.MAX_WBITS)
elif encoding == "identity":
if encoding == "identity":
return data
else:
raise RuntimeError(f"unknown content encoding: {encoding}")
if encoding in ("deflate", "gzip"):
return ZLibDecompressor(
encoding=encoding,
suppress_deflate_header=True,
).decompress_sync(data)

raise RuntimeError(f"unknown content encoding: {encoding}")

def _decode_content_transfer(self, data: bytes) -> bytes:
encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
Expand Down Expand Up @@ -976,7 +977,7 @@ class MultipartPayloadWriter:
def __init__(self, writer: Any) -> None:
self._writer = writer
self._encoding: Optional[str] = None
self._compress: Any = None
self._compress: Optional[ZLibCompressor] = None
self._encoding_buffer: Optional[bytearray] = None

def enable_encoding(self, encoding: str) -> None:
Expand All @@ -989,8 +990,11 @@ def enable_encoding(self, encoding: str) -> None:
def enable_compression(
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
) -> None:
zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else -zlib.MAX_WBITS
self._compress = zlib.compressobj(wbits=zlib_mode, strategy=strategy)
self._compress = ZLibCompressor(
encoding=encoding,
suppress_deflate_header=True,
strategy=strategy,
)

async def write_eof(self) -> None:
if self._compress is not None:
Expand All @@ -1006,7 +1010,7 @@ async def write_eof(self) -> None:
async def write(self, chunk: bytes) -> None:
if self._compress is not None:
if chunk:
chunk = self._compress.compress(chunk)
chunk = await self._compress.compress(chunk)
if not chunk:
return

Expand Down
33 changes: 17 additions & 16 deletions aiohttp/web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11
from .payload import Payload
from .typedefs import JSONEncoder, LooseHeaders
from .zlib_utils import ZLibCompressor

__all__ = ("ContentCoding", "StreamResponse", "Response", "json_response")

Expand Down Expand Up @@ -706,26 +707,26 @@ async def _do_start_compression(self, coding: ContentCoding) -> None:
if coding != ContentCoding.identity:
# Instead of using _payload_writer.enable_compression,
# compress the whole body
zlib_mode = (
16 + zlib.MAX_WBITS if coding == ContentCoding.gzip else zlib.MAX_WBITS
compressor = ZLibCompressor(
encoding=str(coding.value),
max_sync_chunk_size=self._zlib_executor_size,
executor=self._zlib_executor,
)
body_in = self._body
assert body_in is not None
if (
self._zlib_executor_size is not None
and len(body_in) > self._zlib_executor_size
):
await asyncio.get_event_loop().run_in_executor(
self._zlib_executor, self._compress_body, zlib_mode
assert self._body is not None
if self._zlib_executor_size is None and len(self._body) > 1024 * 1024:
warnings.warn(
"Synchronous compression of large response bodies "
f"({len(self._body)} bytes) might block the async event loop. "
"Consider providing a custom value to zlib_executor_size/"
"zlib_executor response properties or disabling compression on it."
)
else:
self._compress_body(zlib_mode)
mykola-mokhnach marked this conversation as resolved.
Show resolved Hide resolved

body_out = self._compressed_body
assert body_out is not None
self._compressed_body = (
await compressor.compress(self._body) + compressor.flush()
)
assert self._compressed_body is not None

self._headers[hdrs.CONTENT_ENCODING] = coding.value
self._headers[hdrs.CONTENT_LENGTH] = str(len(body_out))
self._headers[hdrs.CONTENT_LENGTH] = str(len(self._compressed_body))


def json_response(
Expand Down
Loading