diff --git a/docs/changelog.rst b/docs/changelog.rst index 26093b198..0bd7910c0 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -14,6 +14,8 @@ Changelog * Made read and write buffer sizes configurable. +* Rewrote HTTP handling for simplicity and performance. + 3.3 ... diff --git a/websockets/client.py b/websockets/client.py index a90939e18..411cf37f5 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -5,11 +5,10 @@ import asyncio import collections.abc -import email.message from .exceptions import InvalidHandshake, InvalidMessage from .handshake import build_request, check_response -from .http import USER_AGENT, read_response +from .http import USER_AGENT, build_headers, read_response from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol from .uri import parse_uri @@ -35,9 +34,7 @@ def write_http_request(self, path, headers): """ self.path = path - self.request_headers = email.message.Message() - for name, value in headers: - self.request_headers[name] = value + self.request_headers = build_headers(headers) self.raw_request_headers = headers # Since the path and headers only contain ASCII characters, @@ -63,10 +60,10 @@ def read_http_response(self): except ValueError as exc: raise InvalidMessage("Malformed HTTP message") from exc - self.response_headers = headers - self.raw_response_headers = list(headers.raw_items()) + self.response_headers = build_headers(headers) + self.raw_response_headers = headers - return status_code, headers + return status_code, self.response_headers def process_subprotocol(self, get_header, subprotocols=None): """ diff --git a/websockets/http.py b/websockets/http.py index 173a35691..e71e8c78d 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -8,8 +8,8 @@ """ import asyncio -import email.parser -import io +import http.client +import re import sys from .version import version as websockets_version @@ -26,6 +26,26 @@ )) +# See https://tools.ietf.org/html/rfc7230#appendix-B. 
+ +# Regex for validating header names. + +_token_re = re.compile(rb'^[-!#$%&\'*+.^_`|~0-9a-zA-Z]+$') + +# Regex for validating header values. + +# We don't attempt to support obsolete line folding. + +# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff). + +# The ABNF is complicated because it attempts to express that optional +# whitespace is ignored. We strip whitespace and don't revalidate that. + +# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 + +_value_re = re.compile(rb'^[\x09\x20-\x7e\x80-\xff]*$') + + @asyncio.coroutine def read_request(stream): """ @@ -34,20 +54,38 @@ def read_request(stream): ``stream`` is an :class:`~asyncio.StreamReader`. Return ``(path, headers)`` where ``path`` is a :class:`str` and - ``headers`` is a :class:`~email.message.Message`. ``path`` isn't - URL-decoded. + ``headers`` is a list of ``(name, value)`` tuples. + + ``path`` isn't URL-decoded or validated in any way. + + Non-ASCII characters are represented with surrogate escapes. Raise an exception if the request isn't well formatted. The request is assumed not to contain a body. """ - request_line, headers = yield from read_message(stream) - method, path, version = request_line[:-2].decode().split(None, 2) - if method != 'GET': - raise ValueError("Unsupported method") - if version != 'HTTP/1.1': - raise ValueError("Unsupported HTTP version") + # https://tools.ietf.org/html/rfc7230#section-3.1.1 + + # Parsing is simple because fixed values are expected for method and + # version and because path isn't checked. Since WebSocket software tends + # to implement HTTP/1.1 strictly, there's little need for lenient parsing. + + # Given the implementation of read_line(), request_line ends with CRLF. 
+ request_line = yield from read_line(stream) + + # This may raise "ValueError: not enough values to unpack" + method, path, version = request_line[:-2].split(b' ', 2) + + if method != b'GET': + raise ValueError("Unsupported HTTP method: %r" % method) + if version != b'HTTP/1.1': + raise ValueError("Unsupported HTTP version: %r" % version) + + path = path.decode('ascii', 'surrogateescape') + + headers = yield from read_headers(stream) + return path, headers @@ -59,45 +97,82 @@ def read_response(stream): ``stream`` is an :class:`~asyncio.StreamReader`. Return ``(status, headers)`` where ``status`` is a :class:`int` and - ``headers`` is a :class:`~email.message.Message`. + ``headers`` is a list of ``(name, value)`` tuples. + + Non-ASCII characters are represented with surrogate escapes. Raise an exception if the request isn't well formatted. The response is assumed not to contain a body. """ - status_line, headers = yield from read_message(stream) - version, status, reason = status_line[:-2].decode().split(" ", 2) - if version != 'HTTP/1.1': - raise ValueError("Unsupported HTTP version") - return int(status), headers + # https://tools.ietf.org/html/rfc7230#section-3.1.2 + + # As in read_request, parsing is simple because a fixed value is expected + # for version, status is a 3-digit number, and reason can be ignored. + + # Given the implementation of read_line(), status_line ends with CRLF. 
+ status_line = yield from read_line(stream) + + # This may raise "ValueError: not enough values to unpack" + version, status, reason = status_line[:-2].split(b' ', 2) + + if version != b'HTTP/1.1': + raise ValueError("Unsupported HTTP version: %r" % version) + # This may raise "ValueError: invalid literal for int() with base 10" + status = int(status) + if not 100 <= status < 1000: + raise ValueError("Unsupported HTTP status code: %d" % status) + if not _value_re.match(reason): + raise ValueError("Invalid HTTP reason phrase: %r" % reason) + + headers = yield from read_headers(stream) + + return status, headers @asyncio.coroutine -def read_message(stream): +def read_headers(stream): """ Read an HTTP message from ``stream``. ``stream`` is an :class:`~asyncio.StreamReader`. Return ``(start_line, headers)`` where ``start_line`` is :class:`bytes` - and ``headers`` is a :class:`~email.message.Message`. + and ``headers`` is a list of ``(name, value)`` tuples. + + Non-ASCII characters are represented with surrogate escapes. The message is assumed not to contain a body. """ - start_line = yield from read_line(stream) - header_lines = io.BytesIO() - for num in range(MAX_HEADERS): - header_line = yield from read_line(stream) - header_lines.write(header_line) - if header_line == b'\r\n': + # https://tools.ietf.org/html/rfc7230#section-3.2 + + # We don't attempt to support obsolete line folding. 
+ + headers = [] + for _ in range(MAX_HEADERS): + line = yield from read_line(stream) + if line == b'\r\n': break + + # This may raise "ValueError: not enough values to unpack" + name, value = line[:-2].split(b':', 1) + if not _token_re.match(name): + raise ValueError("Invalid HTTP header name: %r" % name) + value = value.strip(b' \t') + if not _value_re.match(value): + raise ValueError("Invalid HTTP header value: %r" % value) + + headers.append(( + name.decode('ascii'), # guaranteed to be ASCII at this point + value.decode('ascii', 'surrogateescape'), + )) + else: - raise ValueError("Too many headers") - header_lines.seek(0) - headers = email.parser.BytesHeaderParser().parse(header_lines) - return start_line, headers + raise ValueError("Too many HTTP headers") + + return headers @asyncio.coroutine @@ -108,9 +183,24 @@ def read_line(stream): ``stream`` is an :class:`~asyncio.StreamReader`. """ + # Security: this is bounded by the StreamReader's limit (default = 32kB). line = yield from stream.readline() + # Security: this guarantees header values are small (hardcoded = 4kB) if len(line) > MAX_LINE: raise ValueError("Line too long") + # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 if not line.endswith(b'\r\n'): raise ValueError("Line without CRLF") return line + + +def build_headers(raw_headers): + """ + Build a data structure for HTTP headers from a list of name - value pairs. + + See also https://github.com/aaugustin/websockets/issues/210. 
+ + """ + headers = http.client.HTTPMessage() + headers._headers = raw_headers # HACK + return headers diff --git a/websockets/protocol.py b/websockets/protocol.py index 530f998a0..b0fb7c893 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -92,11 +92,13 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): processed, the request path is available in the :attr:`path` attribute, and the request and response HTTP headers are available: - * as a MIME :class:`~email.message.Message` in the :attr:`request_headers` + * as a :class:`~http.client.HTTPMessage` in the :attr:`request_headers` and :attr:`response_headers` attributes * as an iterable of (name, value) pairs in the :attr:`raw_request_headers` and :attr:`raw_response_headers` attributes + These attributes must be treated as immutable. + If a subprotocol was negotiated, it's available in the :attr:`subprotocol` attribute. diff --git a/websockets/server.py b/websockets/server.py index 119e251e5..43a9c682b 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -6,14 +6,13 @@ import asyncio import collections.abc -import email.message import http import logging from .compatibility import asyncio_ensure_future from .exceptions import InvalidHandshake, InvalidMessage, InvalidOrigin from .handshake import build_response, check_request -from .http import USER_AGENT, read_request +from .http import USER_AGENT, build_headers, read_request from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -155,10 +154,10 @@ def read_http_request(self): raise InvalidMessage("Malformed HTTP message") from exc self.path = path - self.request_headers = headers - self.raw_request_headers = list(headers.raw_items()) + self.request_headers = build_headers(headers) + self.raw_request_headers = headers - return path, headers + return path, self.request_headers @asyncio.coroutine def write_http_response(self, status, headers): @@ -166,9 +165,7 @@ def write_http_response(self, status, headers): 
Write status line and headers to the HTTP response. """ - self.response_headers = email.message.Message() - for name, value in headers: - self.response_headers[name] = value + self.response_headers = build_headers(headers) self.raw_response_headers = headers # Since the status line and headers only contain ASCII characters, diff --git a/websockets/test_http.py b/websockets/test_http.py index b31bd84d0..28ad4a25e 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -2,10 +2,10 @@ import unittest from .http import * -from .http import read_message # private API +from .http import build_headers, read_headers -class HTTPTests(unittest.TestCase): +class HTTPAsyncTests(unittest.TestCase): def setUp(self): super().setUp() @@ -32,7 +32,7 @@ def test_read_request(self): ) path, hdrs = self.loop.run_until_complete(read_request(self.stream)) self.assertEqual(path, '/chat') - self.assertEqual(hdrs['Upgrade'], 'websocket') + self.assertEqual(dict(hdrs)['Upgrade'], 'websocket') def test_read_response(self): # Example from the protocol overview in RFC 6455 @@ -46,32 +46,82 @@ def test_read_response(self): ) status, hdrs = self.loop.run_until_complete(read_response(self.stream)) self.assertEqual(status, 101) - self.assertEqual(hdrs['Upgrade'], 'websocket') + self.assertEqual(dict(hdrs)['Upgrade'], 'websocket') - def test_method(self): + def test_request_method(self): self.stream.feed_data(b'OPTIONS * HTTP/1.1\r\n\r\n') with self.assertRaises(ValueError): self.loop.run_until_complete(read_request(self.stream)) - def test_version(self): + def test_request_version(self): self.stream.feed_data(b'GET /chat HTTP/1.0\r\n\r\n') with self.assertRaises(ValueError): self.loop.run_until_complete(read_request(self.stream)) + + def test_response_version(self): self.stream.feed_data(b'HTTP/1.0 400 Bad Request\r\n\r\n') with self.assertRaises(ValueError): self.loop.run_until_complete(read_response(self.stream)) + def test_response_status(self): + self.stream.feed_data(b'HTTP/1.1 
007 My name is Bond\r\n\r\n') + with self.assertRaises(ValueError): + self.loop.run_until_complete(read_response(self.stream)) + + def test_response_reason(self): + self.stream.feed_data(b'HTTP/1.1 200 \x7f\r\n\r\n') + with self.assertRaises(ValueError): + self.loop.run_until_complete(read_response(self.stream)) + + def test_header_name(self): + self.stream.feed_data(b'foo bar: baz qux\r\n\r\n') + with self.assertRaises(ValueError): + self.loop.run_until_complete(read_headers(self.stream)) + + def test_header_value(self): + self.stream.feed_data(b'foo: \x00\x00\x0f\r\n\r\n') + with self.assertRaises(ValueError): + self.loop.run_until_complete(read_headers(self.stream)) + def test_headers_limit(self): self.stream.feed_data(b'foo: bar\r\n' * 500 + b'\r\n') with self.assertRaises(ValueError): - self.loop.run_until_complete(read_message(self.stream)) + self.loop.run_until_complete(read_headers(self.stream)) def test_line_limit(self): self.stream.feed_data(b'a' * 5000 + b'\r\n\r\n') with self.assertRaises(ValueError): - self.loop.run_until_complete(read_message(self.stream)) + self.loop.run_until_complete(read_headers(self.stream)) def test_line_ending(self): - self.stream.feed_data(b'GET / HTTP/1.1\n\n') + self.stream.feed_data(b'foo: bar\n\n') with self.assertRaises(ValueError): - self.loop.run_until_complete(read_message(self.stream)) + self.loop.run_until_complete(read_headers(self.stream)) + + +class HTTPSyncTests(unittest.TestCase): + + def test_build_headers(self): + headers = build_headers([ + ('X-Foo', 'Bar'), + ('X-Baz', 'Quux Quux'), + ]) + + self.assertEqual(headers['X-Foo'], 'Bar') + self.assertEqual(headers['X-Bar'], None) + + self.assertEqual(headers.get('X-Bar', ''), '') + self.assertEqual(headers.get('X-Baz', ''), 'Quux Quux') + + def test_build_headers_multi_value(self): + headers = build_headers([ + ('X-Foo', 'Bar'), + ('X-Foo', 'Baz'), + ]) + + # Getting a single value is non-deterministic. 
+ self.assertIn(headers['X-Foo'], ['Bar', 'Baz']) + self.assertIn(headers.get('X-Foo'), ['Bar', 'Baz']) + + # Ordering is deterministic when getting all values. + self.assertEqual(headers.get_all('X-Foo'), ['Bar', 'Baz'])