diff --git a/aiohttp/protocol.py b/aiohttp/protocol.py index b5fac701764..daaefafc87a 100644 --- a/aiohttp/protocol.py +++ b/aiohttp/protocol.py @@ -11,6 +11,7 @@ import http.server import itertools import re +import string import sys import zlib from wsgiref.handlers import format_date_time @@ -20,6 +21,7 @@ from aiohttp import multidict from aiohttp.log import internal_log +ASCIISET = set(string.printable) METHRE = re.compile('[A-Z0-9$-_.]+') VERSRE = re.compile('HTTP/(\d+).(\d+)') HDRRE = re.compile('[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]') @@ -572,8 +574,12 @@ def add_header(self, name, value): """Analyze headers. Calculate content length, removes hop headers, etc.""" assert not self.headers_sent, 'headers have been sent already' - assert isinstance(name, str), '{!r} is not a string'.format(name) - assert isinstance(value, str), '{!r} is not a string'.format(value) + assert isinstance(name, str), \ + 'Header name should be a string, got {!r}'.format(name) + assert set(name).issubset(ASCIISET), \ + 'Header name should contain ASCII chars, got {!r}'.format(name) + assert isinstance(value, str), \ + 'Header {!r} should have string value, got {!r}'.format(name, value) name = name.strip().upper() value = value.strip() diff --git a/tests/test_http_protocol.py b/tests/test_http_protocol.py index 27f172f9657..7d7775c3a3d 100644 --- a/tests/test_http_protocol.py +++ b/tests/test_http_protocol.py @@ -76,12 +76,22 @@ def test_add_header_with_spaces(self): self.assertEqual( [('CONTENT-TYPE', 'plain/html')], list(msg.headers.items())) + def test_add_header_non_ascii(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], list(msg.headers)) + + with self.assertRaises(AssertionError): + msg.add_header('тип-контента', 'текст/плейн') + def test_add_header_invalid_value_type(self): msg = protocol.Response(self.transport, 200) self.assertEqual([], list(msg.headers)) with self.assertRaises(AssertionError): - msg.add_header('content-type', b'value') + msg.add_header('content-type', {'test': 'plain'}) + + with self.assertRaises(AssertionError): + msg.add_header(list('content-type'), 'text/plain') def test_add_headers(self): msg = protocol.Response(self.transport, 200)