Skip to content

Commit

Permalink
Encapsulate creation of HTTP headers.
Browse files Browse the repository at this point in the history
Since this part is hacky and likely to change in the future (#210),
wrap it into a single function and add tests for the public API we
really care about.
  • Loading branch information
aaugustin committed Jul 17, 2017
1 parent dcd9b1a commit b57a410
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 14 deletions.
9 changes: 3 additions & 6 deletions websockets/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

import asyncio
import collections.abc
import http.client

from .exceptions import InvalidHandshake, InvalidMessage
from .handshake import build_request, check_response
from .http import USER_AGENT, read_response
from .http import USER_AGENT, build_headers, read_response
from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol
from .uri import parse_uri

Expand All @@ -35,8 +34,7 @@ def write_http_request(self, path, headers):
"""
self.path = path
self.request_headers = http.client.HTTPMessage()
self.request_headers._headers = headers # HACK
self.request_headers = build_headers(headers)
self.raw_request_headers = headers

# Since the path and headers only contain ASCII characters,
Expand All @@ -62,8 +60,7 @@ def read_http_response(self):
except ValueError as exc:
raise InvalidMessage("Malformed HTTP message") from exc

self.response_headers = http.client.HTTPMessage()
self.response_headers._headers = headers # HACK
self.response_headers = build_headers(headers)
self.raw_response_headers = headers

return status_code, self.response_headers
Expand Down
13 changes: 13 additions & 0 deletions websockets/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import asyncio
import http.client
import re
import sys

Expand Down Expand Up @@ -191,3 +192,15 @@ def read_line(stream):
if not line.endswith(b'\r\n'):
raise ValueError("Line without CRLF")
return line


def build_headers(raw_headers):
"""
Build a date structure for HTTP headers from a list of name - value pairs.
See also https://github.com/aaugustin/websockets/issues/210.
"""
headers = http.client.HTTPMessage()
headers._headers = raw_headers # HACK
return headers
10 changes: 4 additions & 6 deletions websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

import asyncio
import collections.abc
import http.client
import http
import logging

from .compatibility import asyncio_ensure_future
from .exceptions import InvalidHandshake, InvalidMessage, InvalidOrigin
from .handshake import build_response, check_request
from .http import USER_AGENT, read_request
from .http import USER_AGENT, build_headers, read_request
from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol


Expand Down Expand Up @@ -154,8 +154,7 @@ def read_http_request(self):
raise InvalidMessage("Malformed HTTP message") from exc

self.path = path
self.request_headers = http.client.HTTPMessage()
self.request_headers._headers = headers # HACK
self.request_headers = build_headers(headers)
self.raw_request_headers = headers

return path, self.request_headers
Expand All @@ -166,8 +165,7 @@ def write_http_response(self, status, headers):
Write status line and headers to the HTTP response.
"""
self.response_headers = http.client.HTTPMessage()
self.response_headers._headers = headers # HACK
self.response_headers = build_headers(headers)
self.raw_response_headers = headers

# Since the status line and headers only contain ASCII characters,
Expand Down
32 changes: 30 additions & 2 deletions websockets/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import unittest

from .http import *
from .http import read_headers # private API
from .http import build_headers, read_headers # private APIs


class HTTPTests(unittest.TestCase):
class HTTPAsyncTests(unittest.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -97,3 +97,31 @@ def test_line_ending(self):
self.stream.feed_data(b'foo: bar\n\n')
with self.assertRaises(ValueError):
self.loop.run_until_complete(read_headers(self.stream))


class HTTPSyncTests(unittest.TestCase):

def test_build_headers(self):
headers = build_headers([
('X-Foo', 'Bar'),
('X-Baz', 'Quux Quux'),
])

self.assertEqual(headers['X-Foo'], 'Bar')
self.assertEqual(headers['X-Bar'], None)

self.assertEqual(headers.get('X-Bar', ''), '')
self.assertEqual(headers.get('X-Baz', ''), 'Quux Quux')

def test_build_headers_multi_value(self):
headers = build_headers([
('X-Foo', 'Bar'),
('X-Foo', 'Baz'),
])

# Getting a single value is non-deterministic.
self.assertIn(headers['X-Foo'], ['Bar', 'Baz'])
self.assertIn(headers.get('X-Foo'), ['Bar', 'Baz'])

# Ordering is deterministic when getting all values.
self.assertEqual(headers.get_all('X-Foo'), ['Bar', 'Baz'])

0 comments on commit b57a410

Please sign in to comment.