diff --git a/websockets/client.py b/websockets/client.py index 9aea4f6b7..411cf37f5 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -5,11 +5,10 @@ import asyncio import collections.abc -import http.client from .exceptions import InvalidHandshake, InvalidMessage from .handshake import build_request, check_response -from .http import USER_AGENT, read_response +from .http import USER_AGENT, build_headers, read_response from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol from .uri import parse_uri @@ -35,8 +34,7 @@ def write_http_request(self, path, headers): """ self.path = path - self.request_headers = http.client.HTTPMessage() - self.request_headers._headers = headers # HACK + self.request_headers = build_headers(headers) self.raw_request_headers = headers # Since the path and headers only contain ASCII characters, @@ -62,8 +60,7 @@ def read_http_response(self): except ValueError as exc: raise InvalidMessage("Malformed HTTP message") from exc - self.response_headers = http.client.HTTPMessage() - self.response_headers._headers = headers # HACK + self.response_headers = build_headers(headers) self.raw_response_headers = headers return status_code, self.response_headers diff --git a/websockets/http.py b/websockets/http.py index 48ec2f5e8..e71e8c78d 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -8,6 +8,7 @@ """ import asyncio +import http.client import re import sys @@ -191,3 +192,15 @@ def read_line(stream): if not line.endswith(b'\r\n'): raise ValueError("Line without CRLF") return line + + +def build_headers(raw_headers): + """ + Build a date structure for HTTP headers from a list of name - value pairs. + + See also https://github.com/aaugustin/websockets/issues/210. + + """ + headers = http.client.HTTPMessage() + headers._headers = raw_headers # HACK + return headers diff --git a/websockets/server.py b/websockets/server.py index f4da5c59f..43a9c682b 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -6,13 +6,13 @@ import asyncio import collections.abc -import http.client +import http import logging from .compatibility import asyncio_ensure_future from .exceptions import InvalidHandshake, InvalidMessage, InvalidOrigin from .handshake import build_response, check_request -from .http import USER_AGENT, read_request +from .http import USER_AGENT, build_headers, read_request from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -154,8 +154,7 @@ def read_http_request(self): raise InvalidMessage("Malformed HTTP message") from exc self.path = path - self.request_headers = http.client.HTTPMessage() - self.request_headers._headers = headers # HACK + self.request_headers = build_headers(headers) self.raw_request_headers = headers return path, self.request_headers @@ -166,8 +165,7 @@ def write_http_response(self, status, headers): Write status line and headers to the HTTP response. """ - self.response_headers = http.client.HTTPMessage() - self.response_headers._headers = headers # HACK + self.response_headers = build_headers(headers) self.raw_response_headers = headers # Since the status line and headers only contain ASCII characters, diff --git a/websockets/test_http.py b/websockets/test_http.py index 8035451c7..ff2fc0a76 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -2,10 +2,10 @@ import unittest from .http import * -from .http import read_headers # private API +from .http import build_headers, read_headers # private APIs -class HTTPTests(unittest.TestCase): +class HTTPAsyncTests(unittest.TestCase): def setUp(self): super().setUp() @@ -97,3 +97,31 @@ def test_line_ending(self): self.stream.feed_data(b'foo: bar\n\n') with self.assertRaises(ValueError): self.loop.run_until_complete(read_headers(self.stream)) + + +class HTTPSyncTests(unittest.TestCase): + + def test_build_headers(self): + headers = build_headers([ + ('X-Foo', 'Bar'), + ('X-Baz', 'Quux Quux'), + ]) + + self.assertEqual(headers['X-Foo'], 'Bar') + self.assertEqual(headers['X-Bar'], None) + + self.assertEqual(headers.get('X-Bar', ''), '') + self.assertEqual(headers.get('X-Baz', ''), 'Quux Quux') + + def test_build_headers_multi_value(self): + headers = build_headers([ + ('X-Foo', 'Bar'), + ('X-Foo', 'Baz'), + ]) + + # Getting a single value is non-deterministic. + self.assertIn(headers['X-Foo'], ['Bar', 'Baz']) + self.assertIn(headers.get('X-Foo'), ['Bar', 'Baz']) + + # Ordering is deterministic when getting all values. + self.assertEqual(headers.get_all('X-Foo'), ['Bar', 'Baz'])