Skip to content

Commit

Permalink
Fix deflate compression (#4506)
Browse files Browse the repository at this point in the history
  • Loading branch information
socketpair committed Jan 17, 2020
1 parent 6d2b136 commit 7dc5091
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGES/4506.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed 'deflate' compressions. According to RFC 2616 now.
6 changes: 4 additions & 2 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def flush(self) -> bytes:
self.decompressor = BrotliDecoder() # type: Any
else:
zlib_mode = (16 + zlib.MAX_WBITS
if encoding == 'gzip' else -zlib.MAX_WBITS)
if encoding == 'gzip' else zlib.MAX_WBITS)
self.decompressor = zlib.decompressobj(wbits=zlib_mode)

def set_exception(self, exc: BaseException) -> None:
Expand All @@ -739,7 +739,9 @@ def feed_data(self, chunk: bytes, size: int) -> None:
chunk = self.decompressor.decompress(chunk)
except Exception:
if not self._started_decoding and self.encoding == 'deflate':
self.decompressor = zlib.decompressobj()
# Try to change the decoder to decompress incorrectly
# compressed data
self.decompressor = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
try:
chunk = self.decompressor.decompress(chunk)
except Exception:
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def enable_chunking(self) -> None:

def enable_compression(self, encoding: str='deflate') -> None:
zlib_mode = (16 + zlib.MAX_WBITS
if encoding == 'gzip' else -zlib.MAX_WBITS)
if encoding == 'gzip' else zlib.MAX_WBITS)
self._compress = zlib.compressobj(wbits=zlib_mode)

def _write(self, chunk: bytes) -> None:
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ async def _start(self, request: 'BaseRequest') -> AbstractStreamWriter:
return await super()._start(request)

def _compress_body(self, zlib_mode: int) -> None:
assert zlib_mode > 0
compressobj = zlib.compressobj(wbits=zlib_mode)
body_in = self._body
assert body_in is not None
Expand All @@ -683,7 +684,7 @@ async def _do_start_compression(self, coding: ContentCoding) -> None:
# Instead of using _payload_writer.enable_compression,
# compress the whole body
zlib_mode = (16 + zlib.MAX_WBITS
if coding == ContentCoding.gzip else -zlib.MAX_WBITS)
if coding == ContentCoding.gzip else zlib.MAX_WBITS)
body_in = self._body
assert body_in is not None
if self._zlib_executor_size is not None and \
Expand Down
17 changes: 8 additions & 9 deletions tests/test_http_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Tests for aiohttp/protocol.py

import asyncio
import zlib
from unittest import mock

import pytest
Expand Down Expand Up @@ -837,22 +836,22 @@ async def test_http_payload_parser_length(self, stream) -> None:
assert b'12' == b''.join(d for d, _ in out._buffer)
assert b'45' == tail

_comp = zlib.compressobj(wbits=-zlib.MAX_WBITS)
_COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()])

async def test_http_payload_parser_deflate(self, stream) -> None:
length = len(self._COMPRESSED)
COMPRESSED = b'x\x9cKI,I\x04\x00\x04\x00\x01\x9b'

length = len(COMPRESSED)
out = aiohttp.FlowControlDataQueue(stream,
loop=asyncio.get_event_loop())
p = HttpPayloadParser(
out, length=length, compression='deflate')
p.feed_data(self._COMPRESSED)
p.feed_data(COMPRESSED)
assert b'data' == b''.join(d for d, _ in out._buffer)
assert out.is_eof()

async def test_http_payload_parser_deflate_no_wbits(self, stream) -> None:
comp = zlib.compressobj()
COMPRESSED = b''.join([comp.compress(b'data'), comp.flush()])
"""Tests incorrectly formed messages """

COMPRESSED = b'KI,I\x04\x00'

length = len(COMPRESSED)
out = aiohttp.FlowControlDataQueue(stream,
Expand Down Expand Up @@ -905,7 +904,7 @@ async def test_feed_data_err(self, stream) -> None:
dbuf.decompressor.decompress.side_effect = exc

with pytest.raises(http_exceptions.ContentEncodingError):
dbuf.feed_data(b'data', 4)
dbuf.feed_data(b'data1', 5)

async def test_feed_eof(self, stream) -> None:
buf = aiohttp.FlowControlDataQueue(stream,
Expand Down
9 changes: 7 additions & 2 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ async def test_write_payload_chunked_filter_mutiple_chunks(
b'2\r\na2\r\n0\r\n\r\n')


compressor = zlib.compressobj(wbits=-zlib.MAX_WBITS)
compressor = zlib.compressobj(wbits=zlib.MAX_WBITS)
COMPRESSED = b''.join([compressor.compress(b'data'), compressor.flush()])


Expand Down Expand Up @@ -148,7 +148,12 @@ async def test_write_payload_deflate_and_chunked(
await msg.write(b'ta')
await msg.write_eof()

assert b'6\r\nKI,I\x04\x00\r\n0\r\n\r\n' == buf
thing = (
b'2\r\nx\x9c\r\n'
b'a\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n'
b'0\r\n\r\n'
)
assert thing == buf


async def test_write_drain(protocol, transport, loop) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ async def test_response_with_precompressed_body_deflate(

async def handler(request):
headers = {'Content-Encoding': 'deflate'}
zcomp = zlib.compressobj(wbits=-zlib.MAX_WBITS)
zcomp = zlib.compressobj(wbits=zlib.MAX_WBITS)
data = zcomp.compress(b'mydata') + zcomp.flush()
return web.Response(body=data, headers=headers)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_web_sendfile_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ async def handler(request):

resp = await client.get('/')
assert resp.status == 200
zcomp = zlib.compressobj(wbits=-zlib.MAX_WBITS)
zcomp = zlib.compressobj(wbits=zlib.MAX_WBITS)
expected_body = zcomp.compress(b'file content\n') + zcomp.flush()
assert expected_body == await resp.read()
assert 'application/octet-stream' == resp.headers['Content-Type']
Expand Down

0 comments on commit 7dc5091

Please sign in to comment.