diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 90a022d3c03..b497b5a7f08 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -5,6 +5,7 @@ import re import string import tempfile +import types import warnings from email.utils import parsedate from types import MappingProxyType @@ -23,17 +24,17 @@ FileField = collections.namedtuple( 'Field', 'name filename file content_type headers') -_TCHAR = string.digits + string.ascii_letters + r"!#$%&'*+\-.^_`|~" -# notice the escape of '-' to prevent interpretation as range +_TCHAR = string.digits + string.ascii_letters + r"!#$%&'*+.^_`|~-" +# '-' at the end to prevent interpretation as range in a char class _TOKEN = r'[{tchar}]*'.format(tchar=_TCHAR) _QDTEXT = r'[{}]'.format( - r''.join(chr(c) for c in (0x09, 0x20, 0x21, *range(0x23, 0x7F)))) + r''.join(chr(c) for c in (0x09, 0x20, 0x21) + tuple(range(0x23, 0x7F)))) # qdtext includes 0x5C to escape 0x5D ('\]') # qdtext excludes obs-text (because obsoleted, and encoding not specified) -_QUOTED_PAIR = r'\\[\t {tchar}]'.format(tchar=_TCHAR) +_QUOTED_PAIR = r'\\[\t !-~]' _QUOTED_STRING = r'"(?:{quoted_pair}|{qdtext})*"'.format( qdtext=_QDTEXT, quoted_pair=_QUOTED_PAIR) @@ -42,13 +43,12 @@ r'[bB][yY]|[fF][oO][rR]|[hH][oO][sS][tT]|[pP][rR][oO][tT][oO]') _FORWARDED_PAIR = ( - r'^ *({forwarded_params})=({token}|{quoted_string}) *$'.format( + r'^({forwarded_params})=({token}|{quoted_string})$'.format( forwarded_params=_FORWARDED_PARAMS, token=_TOKEN, quoted_string=_QUOTED_STRING)) -# allow whitespace as specified in RFC 7239 section 7.1 -_QUOTED_PAIR_REPLACE_RE = re.compile(r'\\([\t {tchar}])'.format(tchar=_TCHAR)) +_QUOTED_PAIR_REPLACE_RE = re.compile(r'\\([\t !-~])') # same pattern as _QUOTED_PAIR but contains a capture group _FORWARDED_PAIR_RE = re.compile(_FORWARDED_PAIR) @@ -183,55 +183,72 @@ def secure(self): @reify def forwarded(self): - """ A frozendict containing parsed Forwarded header(s). 
+ """ A tuple containing all parsed Forwarded header(s). Makes an effort to parse Forwarded headers as specified by RFC 7239: - - It adds all parameters (by, for, host, proto) in the order it finds - them; starting at the topmost / first added 'Forwarded' header, at - the leftmost / first-added parwameter. - - It checks that the value has valid syntax in general as specified in - section 4: either a 'token' or a 'quoted-string'. + - It adds one (immutable) dictionary per Forwarded 'field-value', i.e. + per proxy. The first element corresponds to the data in the Forwarded + field-value added by the first proxy encountered by the client. Each + subsequent item corresponds to those added by later proxies. + - It checks that every value has valid syntax in general as specified + in section 4: either a 'token' or a 'quoted-string'. - It un-escapes found escape sequences. - It does NOT validate 'by' and 'for' contents as specified in section 6. - It does NOT validate 'host' contents (Host ABNF). - It does NOT validate 'proto' contents for valid URI scheme names. - Returns a dict(by=tuple(...), for=tuple(...), host=tuple(...), - proto=tuple(...), ) + Returns a tuple containing zero or more immutable dicts """ - params = MultiDict({'by': [], 'for': [], 'host': [], 'proto': []}) - for forwarded_elm in self._message.headers.getall(hdrs.FORWARDED, ()): - forwarded_pairs = (_FORWARDED_PAIR_RE.findall(pair) - for pair in forwarded_elm.split(';')) - for forwarded_pair in forwarded_pairs: - if len(forwarded_pair) != 1: - # non-compliant syntax, ignore - continue - param, value = forwarded_pair[0] - if value and value[0] == '"': - # quoted string: replace quotes and escape - # sequences - value = _QUOTED_PAIR_REPLACE_RE.sub( - r'\1', value[1:-1]) - params[param.lower()].append(value) - return MultiDictProxy(params) + def _parse_forwarded(forwarded_headers): + for field_value in forwarded_headers: + # by=...;for=..., For=..., BY=... 
+ for forwarded_elm in field_value.split(','): + # by=...;for=... + fvparams = dict() + forwarded_pairs = ( + _FORWARDED_PAIR_RE.findall(pair) + for pair in forwarded_elm.strip().split(';')) + for forwarded_pair in forwarded_pairs: + # by=... + if len(forwarded_pair) != 1: + # non-compliant syntax + break + param, value = forwarded_pair[0] + if param.lower() in fvparams: + # duplicate param in field-value + break + if value and value[0] == '"': + # quoted string: replace quotes and escape + # sequences + value = _QUOTED_PAIR_REPLACE_RE.sub( + r'\1', value[1:-1]) + fvparams[param.lower()] = value + else: + yield types.MappingProxyType(fvparams) + continue + yield dict() + + return tuple( + _parse_forwarded(self._message.headers.getall(hdrs.FORWARDED, ()))) @reify def _scheme(self): - proto = 'http' + proto = None if self._transport.get_extra_info('sslcontext'): proto = 'https' elif self._secure_proxy_ssl_header is not None: header, value = self._secure_proxy_ssl_header if self.headers.get(header) == value: proto = 'https' - elif self.forwarded['proto']: - proto = self.forwarded['proto'][0] - elif hdrs.X_FORWARDED_PROTO in self._message.headers: - proto = self._message.headers[hdrs.X_FORWARDED_PROTO] - return proto + else: + proto = next( + (f['proto'] for f in self.forwarded if 'proto' in f), None + ) + if not proto and hdrs.X_FORWARDED_PROTO in self._message.headers: + proto = self._message.headers[hdrs.X_FORWARDED_PROTO] + return proto or 'http' @property def method(self): @@ -261,10 +278,10 @@ def host(self): Returns str, or None if no hostname is found in the headers. 
""" - host = None - if self.forwarded['host']: - host = self.forwarded['host'][0] - elif hdrs.X_FORWARDED_HOST in self._message.headers: + host = next( + (f['host'] for f in self.forwarded if 'host' in f), None + ) + if not host and hdrs.X_FORWARDED_HOST in self._message.headers: host = self._message.headers[hdrs.X_FORWARDED_HOST] elif hdrs.HOST in self._message.headers: host = self._message.headers[hdrs.HOST] diff --git a/tests/test_web_request.py b/tests/test_web_request.py index 57a846a65aa..440cfd71e59 100644 --- a/tests/test_web_request.py +++ b/tests/test_web_request.py @@ -249,55 +249,63 @@ def test_https_scheme_by_secure_proxy_ssl_header_false_test(make_request): def test_single_forwarded_header(make_request): - header = 'by=identifier; for=identifier; host=identifier; proto=identifier' + header = 'by=identifier;for=identifier;host=identifier;proto=identifier' req = make_request('GET', '/', headers=CIMultiDict({'Forwarded': header})) - assert req.forwarded['by'] == ['identifier'] - assert req.forwarded['for'] == ['identifier'] - assert req.forwarded['host'] == ['identifier'] - assert req.forwarded['proto'] == ['identifier'] + assert req.forwarded[0]['by'] == 'identifier' + assert req.forwarded[0]['for'] == 'identifier' + assert req.forwarded[0]['host'] == 'identifier' + assert req.forwarded[0]['proto'] == 'identifier' def test_single_forwarded_header_camelcase(make_request): - header = 'bY=identifier; fOr=identifier; HOst=identifier; pRoTO=identifier' + header = 'bY=identifier;fOr=identifier;HOst=identifier;pRoTO=identifier' req = make_request('GET', '/', headers=CIMultiDict({'Forwarded': header})) - assert req.forwarded['by'] == ['identifier'] - assert req.forwarded['for'] == ['identifier'] - assert req.forwarded['host'] == ['identifier'] - assert req.forwarded['proto'] == ['identifier'] + assert req.forwarded[0]['by'] == 'identifier' + assert req.forwarded[0]['for'] == 'identifier' + assert req.forwarded[0]['host'] == 'identifier' + assert 
req.forwarded[0]['proto'] == 'identifier' def test_single_forwarded_header_single_param(make_request): header = 'BY=identifier' req = make_request('GET', '/', headers=CIMultiDict({'Forwarded': header})) - assert req.forwarded['by'] == ['identifier'] + assert req.forwarded[0]['by'] == 'identifier' def test_single_forwarded_header_multiple_param(make_request): - header = 'By=identifier1;BY=identifier2; By=identifier3; BY=identifier4' + header = 'By=identifier1,BY=identifier2, By=identifier3 , BY=identifier4' req = make_request('GET', '/', headers=CIMultiDict({'Forwarded': header})) - assert req.forwarded['by'] == ['identifier1', 'identifier2', 'identifier3', - 'identifier4'] + assert len(req.forwarded) == 4 + assert req.forwarded[0]['by'] == 'identifier1' + assert req.forwarded[1]['by'] == 'identifier2' + assert req.forwarded[2]['by'] == 'identifier3' + assert req.forwarded[3]['by'] == 'identifier4' def test_single_forwarded_header_quoted_escaped(make_request): - header = 'Proto=identifier; pROTO="\lala lan\d\~ 123\!&"' + header = 'BY=identifier;pROTO="\lala lan\d\~ 123\!&"' req = make_request('GET', '/', headers=CIMultiDict({'Forwarded': header})) - assert req.forwarded['proto'] == ['identifier', 'lala land~ 123!&'] + assert req.forwarded[0]['by'] == 'identifier' + assert req.forwarded[0]['proto'] == 'lala land~ 123!&' def test_multiple_forwarded_headers(make_request): headers = CIMultiDict() - headers.add('Forwarded', 'By=identifier1;BY=identifier2') - headers.add('Forwarded', 'By=identifier3; BY=identifier4') + headers.add('Forwarded', 'By=identifier1;for=identifier2, BY=identifier3') + headers.add('Forwarded', 'By=identifier4;fOr=identifier5') req = make_request('GET', '/', headers=headers) - assert req.forwarded['by'] == ['identifier1', 'identifier2', 'identifier3', - 'identifier4'] + assert len(req.forwarded) == 3 + assert req.forwarded[0]['by'] == 'identifier1' + assert req.forwarded[0]['for'] == 'identifier2' + assert req.forwarded[1]['by'] == 'identifier3' + 
assert req.forwarded[2]['by'] == 'identifier4' + assert req.forwarded[2]['for'] == 'identifier5' def test_https_scheme_by_forwarded_header(make_request): req = make_request('GET', '/', headers=CIMultiDict( - {'Forwarded': 'by=; for=; host=; proto=https'})) + {'Forwarded': 'by=;for=;host=;proto=https'})) assert "https" == req.scheme assert req.secure is True @@ -324,10 +332,10 @@ def test_https_scheme_by_x_forwarded_proto_header_no_tls(make_request): def test_host_by_forwarded_header(make_request): - req = make_request('GET', '/', - headers=CIMultiDict( - {'Forwarded': 'by=; for=; host' - '=example.com; proto=https'})) + headers = CIMultiDict() + headers.add('Forwarded', 'By=identifier1;for=identifier2, BY=identifier3') + headers.add('Forwarded', 'by=;for=;host=example.com') + req = make_request('GET', '/', headers=headers) assert req.host == 'example.com'