Skip to content

Commit

Permalink
🗜 Fix deflate compression (#4506)
Browse files Browse the repository at this point in the history
  • Loading branch information
socketpair committed Jan 27, 2020
1 parent 4d03dbb commit ae7395a
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 36 deletions.
1 change: 1 addition & 0 deletions CHANGES/4506.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed 'deflate' compressions. According to RFC 2616 now.
30 changes: 18 additions & 12 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,30 +727,36 @@ def flush(self) -> bytes:
self.decompressor = BrotliDecoder() # type: Any
else:
zlib_mode = (16 + zlib.MAX_WBITS
if encoding == 'gzip' else -zlib.MAX_WBITS)
if encoding == 'gzip' else zlib.MAX_WBITS)
self.decompressor = zlib.decompressobj(wbits=zlib_mode)

def set_exception(self, exc: BaseException) -> None:
self.out.set_exception(exc)

def feed_data(self, chunk: bytes, size: int) -> None:
if not size:
return

self.size += size

# RFC1950
# bits 0..3 = CM = 0b1000 = 8 = "deflate"
# bits 4..7 = CINFO = 1..7 = windows size.
if not self._started_decoding and self.encoding == 'deflate' \
and chunk[0] & 0xf != 8:
# Change the decoder to decompress incorrectly compressed data
# Actually we should issue a warning about non-RFC-compilant data.
self.decompressor = zlib.decompressobj(wbits=-zlib.MAX_WBITS)

try:
chunk = self.decompressor.decompress(chunk)
except Exception:
if not self._started_decoding and self.encoding == 'deflate':
self.decompressor = zlib.decompressobj()
try:
chunk = self.decompressor.decompress(chunk)
except Exception:
raise ContentEncodingError(
'Can not decode content-encoding: %s' % self.encoding)
else:
raise ContentEncodingError(
'Can not decode content-encoding: %s' % self.encoding)
raise ContentEncodingError(
'Can not decode content-encoding: %s' % self.encoding)

self._started_decoding = True

if chunk:
self._started_decoding = True
self.out.feed_data(chunk, len(chunk))

def feed_eof(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def enable_chunking(self) -> None:

def enable_compression(self, encoding: str='deflate') -> None:
zlib_mode = (16 + zlib.MAX_WBITS
if encoding == 'gzip' else -zlib.MAX_WBITS)
if encoding == 'gzip' else zlib.MAX_WBITS)
self._compress = zlib.compressobj(wbits=zlib_mode)

def _write(self, chunk: bytes) -> None:
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ async def _start(self, request: 'BaseRequest') -> AbstractStreamWriter:
return await super()._start(request)

def _compress_body(self, zlib_mode: int) -> None:
assert zlib_mode > 0
compressobj = zlib.compressobj(wbits=zlib_mode)
body_in = self._body
assert body_in is not None
Expand All @@ -683,7 +684,7 @@ async def _do_start_compression(self, coding: ContentCoding) -> None:
# Instead of using _payload_writer.enable_compression,
# compress the whole body
zlib_mode = (16 + zlib.MAX_WBITS
if coding == ContentCoding.gzip else -zlib.MAX_WBITS)
if coding == ContentCoding.gzip else zlib.MAX_WBITS)
body_in = self._body
assert body_in is not None
if self._zlib_executor_size is not None and \
Expand Down
66 changes: 51 additions & 15 deletions tests/test_http_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Tests for aiohttp/protocol.py

import asyncio
import zlib
from unittest import mock

import pytest
Expand Down Expand Up @@ -837,32 +836,66 @@ async def test_http_payload_parser_length(self, stream) -> None:
assert b'12' == b''.join(d for d, _ in out._buffer)
assert b'45' == tail

_comp = zlib.compressobj(wbits=-zlib.MAX_WBITS)
_COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()])

async def test_http_payload_parser_deflate(self, stream) -> None:
length = len(self._COMPRESSED)
# c=compressobj(wbits=15); b''.join([c.compress(b'data'), c.flush()])
COMPRESSED = b'x\x9cKI,I\x04\x00\x04\x00\x01\x9b'

length = len(COMPRESSED)
out = aiohttp.FlowControlDataQueue(stream,
loop=asyncio.get_event_loop())
p = HttpPayloadParser(
out, length=length, compression='deflate')
p.feed_data(self._COMPRESSED)
p = HttpPayloadParser(out, length=length, compression='deflate')
p.feed_data(COMPRESSED)
assert b'data' == b''.join(d for d, _ in out._buffer)
assert out.is_eof()

async def test_http_payload_parser_deflate_no_wbits(self, stream) -> None:
comp = zlib.compressobj()
COMPRESSED = b''.join([comp.compress(b'data'), comp.flush()])
async def test_http_payload_parser_deflate_no_hdrs(self, stream) -> None:
"""Tests incorrectly formed data (no zlib headers) """

# c=compressobj(wbits=-15); b''.join([c.compress(b'data'), c.flush()])
COMPRESSED = b'KI,I\x04\x00'

length = len(COMPRESSED)
out = aiohttp.FlowControlDataQueue(stream,
loop=asyncio.get_event_loop())
p = HttpPayloadParser(
out, length=length, compression='deflate')
p = HttpPayloadParser(out, length=length, compression='deflate')
p.feed_data(COMPRESSED)
assert b'data' == b''.join(d for d, _ in out._buffer)
assert out.is_eof()

async def test_http_payload_parser_deflate_light(self, stream) -> None:
# c=compressobj(wbits=9); b''.join([c.compress(b'data'), c.flush()])
COMPRESSED = b'\x18\x95KI,I\x04\x00\x04\x00\x01\x9b'

length = len(COMPRESSED)
out = aiohttp.FlowControlDataQueue(stream,
loop=asyncio.get_event_loop())
p = HttpPayloadParser(out, length=length, compression='deflate')
p.feed_data(COMPRESSED)
assert b'data' == b''.join(d for d, _ in out._buffer)
assert out.is_eof()

async def test_http_payload_parser_deflate_split(self, stream) -> None:
out = aiohttp.FlowControlDataQueue(stream,
loop=asyncio.get_event_loop())
p = HttpPayloadParser(out, compression='deflate', readall=True)
# Feeding one correct byte should be enough to choose exact
# deflate decompressor
p.feed_data(b'x', 1)
p.feed_data(b'\x9cKI,I\x04\x00\x04\x00\x01\x9b', 11)
p.feed_eof()
assert b'data' == b''.join(d for d, _ in out._buffer)

async def test_http_payload_parser_deflate_split_err(self, stream) -> None:
out = aiohttp.FlowControlDataQueue(stream,
loop=asyncio.get_event_loop())
p = HttpPayloadParser(out, compression='deflate', readall=True)
# Feeding one wrong byte should be enough to choose exact
# deflate decompressor
p.feed_data(b'K', 1)
p.feed_data(b'I,I\x04\x00', 5)
p.feed_eof()
assert b'data' == b''.join(d for d, _ in out._buffer)

async def test_http_payload_parser_length_zero(self, stream) -> None:
out = aiohttp.FlowControlDataQueue(stream,
loop=asyncio.get_event_loop())
Expand Down Expand Up @@ -892,7 +925,8 @@ async def test_feed_data(self, stream) -> None:
dbuf.decompressor = mock.Mock()
dbuf.decompressor.decompress.return_value = b'line'

dbuf.feed_data(b'data', 4)
# First byte should be b'x' in order code not to change the decoder.
dbuf.feed_data(b'xxxx', 4)
assert [b'line'] == list(d for d, _ in buf._buffer)

async def test_feed_data_err(self, stream) -> None:
Expand All @@ -905,7 +939,9 @@ async def test_feed_data_err(self, stream) -> None:
dbuf.decompressor.decompress.side_effect = exc

with pytest.raises(http_exceptions.ContentEncodingError):
dbuf.feed_data(b'data', 4)
# Should be more than 4 bytes to trigger deflate FSM error.
# Should start with b'x', otherwise code switch mocked decoder.
dbuf.feed_data(b'xsomedata', 9)

async def test_feed_eof(self, stream) -> None:
buf = aiohttp.FlowControlDataQueue(stream,
Expand Down
14 changes: 8 additions & 6 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Tests for aiohttp/http_writer.py
import zlib
from unittest import mock

import pytest
Expand Down Expand Up @@ -117,12 +116,10 @@ async def test_write_payload_chunked_filter_mutiple_chunks(
b'2\r\na2\r\n0\r\n\r\n')


compressor = zlib.compressobj(wbits=-zlib.MAX_WBITS)
COMPRESSED = b''.join([compressor.compress(b'data'), compressor.flush()])


async def test_write_payload_deflate_compression(protocol,
transport, loop) -> None:

COMPRESSED = b'x\x9cKI,I\x04\x00\x04\x00\x01\x9b'
write = transport.write = mock.Mock()
msg = http.StreamWriter(protocol, loop)
msg.enable_compression('deflate')
Expand All @@ -148,7 +145,12 @@ async def test_write_payload_deflate_and_chunked(
await msg.write(b'ta')
await msg.write_eof()

assert b'6\r\nKI,I\x04\x00\r\n0\r\n\r\n' == buf
thing = (
b'2\r\nx\x9c\r\n'
b'a\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n'
b'0\r\n\r\n'
)
assert thing == buf


async def test_write_drain(protocol, transport, loop) -> None:
Expand Down
22 changes: 22 additions & 0 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,28 @@ async def test_response_with_precompressed_body_deflate(

async def handler(request):
headers = {'Content-Encoding': 'deflate'}
zcomp = zlib.compressobj(wbits=zlib.MAX_WBITS)
data = zcomp.compress(b'mydata') + zcomp.flush()
return web.Response(body=data, headers=headers)

app = web.Application()
app.router.add_get('/', handler)
client = await aiohttp_client(app)

resp = await client.get('/')
assert 200 == resp.status
data = await resp.read()
assert b'mydata' == data
assert resp.headers.get('Content-Encoding') == 'deflate'


async def test_response_with_precompressed_body_deflate_no_hdrs(
aiohttp_client) -> None:

async def handler(request):
headers = {'Content-Encoding': 'deflate'}
# Actually, wrong compression format, but
# should be supported for some legacy cases.
zcomp = zlib.compressobj(wbits=-zlib.MAX_WBITS)
data = zcomp.compress(b'mydata') + zcomp.flush()
return web.Response(body=data, headers=headers)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_web_sendfile_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ async def handler(request):

resp = await client.get('/')
assert resp.status == 200
zcomp = zlib.compressobj(wbits=-zlib.MAX_WBITS)
zcomp = zlib.compressobj(wbits=zlib.MAX_WBITS)
expected_body = zcomp.compress(b'file content\n') + zcomp.flush()
assert expected_body == await resp.read()
assert 'application/octet-stream' == resp.headers['Content-Type']
Expand Down

1 comment on commit ae7395a

@socketpair
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@asvetlov I've fixed tests and linters

Please sign in to comment.