From 4aa7294e7cc97aa7e82f58a972359ff651788f1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9A=D0=BE=D1=80=D0=B5=D0=BD=D0=B1=D0=B5=D1=80=D0=B3=20?= =?UTF-8?q?=E2=98=A2=EF=B8=8F=20=20=D0=9C=D0=B0=D1=80=D0=BA?= Date: Fri, 17 Jan 2020 19:26:20 +0500 Subject: [PATCH] Fix deflate compression (#4506) --- CHANGES/4506.bugfix | 1 + aiohttp/http_parser.py | 31 ++++++++------ aiohttp/http_writer.py | 2 +- aiohttp/web_response.py | 3 +- tests/test_http_parser.py | 61 +++++++++++++++++++++------ tests/test_http_writer.py | 14 +++--- tests/test_web_functional.py | 22 ++++++++++ tests/test_web_sendfile_functional.py | 2 +- 8 files changed, 101 insertions(+), 35 deletions(-) create mode 100644 CHANGES/4506.bugfix diff --git a/CHANGES/4506.bugfix b/CHANGES/4506.bugfix new file mode 100644 index 00000000000..eaf4bb88aac --- /dev/null +++ b/CHANGES/4506.bugfix @@ -0,0 +1 @@ +Fixed 'deflate' compressions. According to RFC 2616 now. diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index f2881c3bf11..e757f3fea9c 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -727,30 +727,37 @@ def flush(self) -> bytes: self.decompressor = BrotliDecoder() # type: Any else: zlib_mode = (16 + zlib.MAX_WBITS - if encoding == 'gzip' else -zlib.MAX_WBITS) + if encoding == 'gzip' else zlib.MAX_WBITS) self.decompressor = zlib.decompressobj(wbits=zlib_mode) def set_exception(self, exc: BaseException) -> None: self.out.set_exception(exc) def feed_data(self, chunk: bytes, size: int) -> None: + if not size: + return + self.size += size + + # RFC1950 + # bits 0..3 = CM = 0b1000 = 8 = "deflate" + # bits 4..7 = CINFO = 1..7 = windows size. + if not self._started_decoding \ + and self.encoding == 'deflate' \ + and chunk[0] & 0xf != 8: + # Change the decoder to decompress incorrectly compressed data + # Actually we should issue a warning about non-RFC-compilant data. + self.decompressor = zlib.decompressobj(wbits=-zlib.MAX_WBITS) + try: chunk = self.decompressor.decompress(chunk) except Exception: - if not self._started_decoding and self.encoding == 'deflate': - self.decompressor = zlib.decompressobj() - try: - chunk = self.decompressor.decompress(chunk) - except Exception: - raise ContentEncodingError( - 'Can not decode content-encoding: %s' % self.encoding) - else: - raise ContentEncodingError( - 'Can not decode content-encoding: %s' % self.encoding) + raise ContentEncodingError( + 'Can not decode content-encoding: %s' % self.encoding) + + self._started_decoding = True if chunk: - self._started_decoding = True self.out.feed_data(chunk, len(chunk)) def feed_eof(self) -> None: diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index 7e27fbf6a43..102fb3ef2f4 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -55,7 +55,7 @@ def enable_chunking(self) -> None: def enable_compression(self, encoding: str='deflate') -> None: zlib_mode = (16 + zlib.MAX_WBITS - if encoding == 'gzip' else -zlib.MAX_WBITS) + if encoding == 'gzip' else zlib.MAX_WBITS) self._compress = zlib.compressobj(wbits=zlib_mode) def _write(self, chunk: bytes) -> None: diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index 4dc64976839..fdb7aa4d359 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -669,6 +669,7 @@ async def _start(self, request: 'BaseRequest') -> AbstractStreamWriter: return await super()._start(request) def _compress_body(self, zlib_mode: int) -> None: + assert zlib_mode > 0 compressobj = zlib.compressobj(wbits=zlib_mode) body_in = self._body assert body_in is not None @@ -683,7 +684,7 @@ async def _do_start_compression(self, coding: ContentCoding) -> None: # Instead of using _payload_writer.enable_compression, # compress the whole body zlib_mode = (16 + zlib.MAX_WBITS - if coding == ContentCoding.gzip else -zlib.MAX_WBITS) + if coding == ContentCoding.gzip else zlib.MAX_WBITS) body_in = self._body assert body_in is not None if self._zlib_executor_size is not None and \ diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index 19fe9be7a3c..cd3373767d9 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -1,7 +1,6 @@ # Tests for aiohttp/protocol.py import asyncio -import zlib from unittest import mock import pytest @@ -837,32 +836,66 @@ async def test_http_payload_parser_length(self, stream) -> None: assert b'12' == b''.join(d for d, _ in out._buffer) assert b'45' == tail - _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) - _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) - async def test_http_payload_parser_deflate(self, stream) -> None: - length = len(self._COMPRESSED) + # c=compressobj(wbits=15); b''.join([c.compress(b'data'), c.flush()]) + COMPRESSED = b'x\x9cKI,I\x04\x00\x04\x00\x01\x9b' + + length = len(COMPRESSED) out = aiohttp.FlowControlDataQueue(stream, loop=asyncio.get_event_loop()) - p = HttpPayloadParser( - out, length=length, compression='deflate') - p.feed_data(self._COMPRESSED) + p = HttpPayloadParser(out, length=length, compression='deflate') + p.feed_data(COMPRESSED) assert b'data' == b''.join(d for d, _ in out._buffer) assert out.is_eof() - async def test_http_payload_parser_deflate_no_wbits(self, stream) -> None: - comp = zlib.compressobj() - COMPRESSED = b''.join([comp.compress(b'data'), comp.flush()]) + async def test_http_payload_parser_deflate_no_hdrs(self, stream) -> None: + """Tests incorrectly formed data (no zlib headers) """ + + # c=compressobj(wbits=-15); b''.join([c.compress(b'data'), c.flush()]) + COMPRESSED = b'KI,I\x04\x00' length = len(COMPRESSED) out = aiohttp.FlowControlDataQueue(stream, loop=asyncio.get_event_loop()) - p = HttpPayloadParser( - out, length=length, compression='deflate') + p = HttpPayloadParser(out, length=length, compression='deflate') p.feed_data(COMPRESSED) assert b'data' == b''.join(d for d, _ in out._buffer) assert out.is_eof() + async def test_http_payload_parser_deflate_light(self, stream) -> None: + # c=compressobj(wbits=9); b''.join([c.compress(b'data'), c.flush()]) + COMPRESSED = b'\x18\x95KI,I\x04\x00\x04\x00\x01\x9b' + + length = len(COMPRESSED) + out = aiohttp.FlowControlDataQueue(stream, + loop=asyncio.get_event_loop()) + p = HttpPayloadParser(out, length=length, compression='deflate') + p.feed_data(COMPRESSED) + assert b'data' == b''.join(d for d, _ in out._buffer) + assert out.is_eof() + + async def test_http_payload_parser_deflate_split(self, stream) -> None: + out = aiohttp.FlowControlDataQueue(stream, + loop=asyncio.get_event_loop()) + p = HttpPayloadParser(out, compression='deflate', readall=True) + # Feeding one correct byte should be enough to choose exact + # deflate decompressor + p.feed_data(b'x', 1) + p.feed_data(b'\x9cKI,I\x04\x00\x04\x00\x01\x9b', 11) + p.feed_eof() + assert b'data' == b''.join(d for d, _ in out._buffer) + + async def test_http_payload_parser_deflate_split_err(self, stream) -> None: + out = aiohttp.FlowControlDataQueue(stream, + loop=asyncio.get_event_loop()) + p = HttpPayloadParser(out, compression='deflate', readall=True) + # Feeding one wrong byte should be enough to choose exact + # deflate decompressor + p.feed_data(b'K', 1) + p.feed_data(b'I,I\x04\x00', 5) + p.feed_eof() + assert b'data' == b''.join(d for d, _ in out._buffer) + async def test_http_payload_parser_length_zero(self, stream) -> None: out = aiohttp.FlowControlDataQueue(stream, loop=asyncio.get_event_loop()) @@ -905,7 +938,7 @@ async def test_feed_data_err(self, stream) -> None: dbuf.decompressor.decompress.side_effect = exc with pytest.raises(http_exceptions.ContentEncodingError): - dbuf.feed_data(b'data', 4) + dbuf.feed_data(b'somedata', 8) # Should be more than 4 bytes async def test_feed_eof(self, stream) -> None: buf = aiohttp.FlowControlDataQueue(stream, diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 2f8085f8a85..ae10fb08413 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -1,5 +1,4 @@ # Tests for aiohttp/http_writer.py -import zlib from unittest import mock import pytest @@ -117,12 +116,10 @@ async def test_write_payload_chunked_filter_mutiple_chunks( b'2\r\na2\r\n0\r\n\r\n') -compressor = zlib.compressobj(wbits=-zlib.MAX_WBITS) -COMPRESSED = b''.join([compressor.compress(b'data'), compressor.flush()]) - - async def test_write_payload_deflate_compression(protocol, transport, loop) -> None: + + COMPRESSED = b'x\x9cKI,I\x04\x00\x04\x00\x01\x9b' write = transport.write = mock.Mock() msg = http.StreamWriter(protocol, loop) msg.enable_compression('deflate') @@ -148,7 +145,12 @@ async def test_write_payload_deflate_and_chunked( await msg.write(b'ta') await msg.write_eof() - assert b'6\r\nKI,I\x04\x00\r\n0\r\n\r\n' == buf + thing = ( + b'2\r\nx\x9c\r\n' + b'a\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n' + b'0\r\n\r\n' + ) + assert thing == buf async def test_write_drain(protocol, transport, loop) -> None: diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 4a8a71370de..d261fe49d53 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -953,6 +953,28 @@ async def test_response_with_precompressed_body_deflate( async def handler(request): headers = {'Content-Encoding': 'deflate'} + zcomp = zlib.compressobj(wbits=zlib.MAX_WBITS) + data = zcomp.compress(b'mydata') + zcomp.flush() + return web.Response(body=data, headers=headers) + + app = web.Application() + app.router.add_get('/', handler) + client = await aiohttp_client(app) + + resp = await client.get('/') + assert 200 == resp.status + data = await resp.read() + assert b'mydata' == data + assert resp.headers.get('Content-Encoding') == 'deflate' + + +async def test_response_with_precompressed_body_deflate_no_hdrs( + aiohttp_client) -> None: + + async def handler(request): + headers = {'Content-Encoding': 'deflate'} + # Actually, wrong compression format, but + # should be supported for some legacy cases. zcomp = zlib.compressobj(wbits=-zlib.MAX_WBITS) data = zcomp.compress(b'mydata') + zcomp.flush() return web.Response(body=data, headers=headers) diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index c2be5dbff0d..02aceb69f7b 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -742,7 +742,7 @@ async def handler(request): resp = await client.get('/') assert resp.status == 200 - zcomp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + zcomp = zlib.compressobj(wbits=zlib.MAX_WBITS) expected_body = zcomp.compress(b'file content\n') + zcomp.flush() assert expected_body == await resp.read() assert 'application/octet-stream' == resp.headers['Content-Type']