diff --git a/src/waitress/receiver.py b/src/waitress/receiver.py index 87852808..6289d1a9 100644 --- a/src/waitress/receiver.py +++ b/src/waitress/receiver.py @@ -14,6 +14,7 @@ """Data Chunk Receiver """ +from waitress.rfc7230 import CHUNK_EXT_RE, ONLY_HEXDIG_RE from waitress.utilities import BadRequest, find_double_newline @@ -110,6 +111,7 @@ def received(self, s): s = b"" else: self.chunk_end = b"" + if pos == 0: # Chop off the terminating CR LF from the chunk s = s[2:] @@ -140,7 +142,14 @@ def received(self, s): semi = line.find(b";") if semi >= 0: - # discard extension info. + extinfo = line[semi:] + valid_ext_info = CHUNK_EXT_RE.match(extinfo) + + if not valid_ext_info: + self.error = BadRequest("Invalid chunk extension") + self.all_chunks_received = True + + break line = line[:semi] try: sz = int(line.strip(), 16) # hexadecimal diff --git a/tests/test_functional.py b/tests/test_functional.py index d1366caf..9e33fc0f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -364,6 +364,28 @@ def test_broken_chunked_encoding(self): self.send_check_error(to_send) self.assertRaises(ConnectionClosed, read_http, fp) + def test_broken_chunked_encoding_invalid_extension(self): + control_line = b"20;invalid=\r\n" # 20 hex = 32 dec + s = b"This string has 32 characters.\r\n" + to_send = b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + to_send += control_line + s + b"\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "400", "Bad Request", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertIn(b"Invalid chunk extension", response_body) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + self.assertEqual(headers["content-type"], "text/plain") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + def test_broken_chunked_encoding_missing_chunk_end(self): control_line = b"20\r\n" # 20 hex = 32 dec s = b"This string has 32 characters.\r\n" diff --git a/tests/test_receiver.py b/tests/test_receiver.py index f55aa68e..a3b6f994 100644 --- a/tests/test_receiver.py +++ b/tests/test_receiver.py @@ -1,5 +1,7 @@ import unittest +import pytest + class TestFixedStreamReceiver(unittest.TestCase): def _makeOne(self, cl, buf): @@ -226,6 +228,41 @@ def test_received_multiple_chunks_split(self): self.assertEqual(inst.error, None) +class TestChunkedReceiverParametrized: + def _makeOne(self, buf): + from waitress.receiver import ChunkedReceiver + + return ChunkedReceiver(buf) + + @pytest.mark.parametrize( + "invalid_extension", [b"\n", b"invalid=", b"\r", b"invalid = true"] + ) + def test_received_invalid_extensions(self, invalid_extension): + from waitress.utilities import BadRequest + + buf = DummyBuffer() + inst = self._makeOne(buf) + data = b"4;" + invalid_extension + b"\r\ntest\r\n" + result = inst.received(data) + assert result == len(data) + assert inst.error.__class__ == BadRequest + assert inst.error.body == "Invalid chunk extension" + + @pytest.mark.parametrize( + "valid_extension", [b"test", b"valid=true", b"valid=true;other=true"] + ) + def test_received_valid_extensions(self, valid_extension): + # While waitress may ignore extensions in Chunked Encoding, we do want + # to make sure that we don't fail when we do encounter one that is + # valid + buf = DummyBuffer() + inst = self._makeOne(buf) + data = b"4;" + valid_extension + b"\r\ntest\r\n" + result = inst.received(data) + assert result == len(data) + assert inst.error == None + + class DummyBuffer: def __init__(self, data=None): if data is None: