Skip to content

Commit

Permalink
Make header reading compliant with RFC7230, section 3.2.2 (#270)
Browse files Browse the repository at this point in the history
* Make header reading compliant with RFC7230, section 3.2.2

* Make header reading compliant with RFC7230, section 3.2.2

* Only attempt to parse comma delimited headers if header_type specified

* Add unit tests for multi field headers

* Remove redundant checks

* Fix ambiguity between header fields beginning with same characters

* Fix pep8 errors

* Undo unintentional change

* Add newline to end of file

* Fix more pep8 issues
  • Loading branch information
Croug authored and vimalloc committed Sep 10, 2019
1 parent 391de47 commit cb988e1
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 8 deletions.
28 changes: 20 additions & 8 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import wraps
from datetime import datetime
from calendar import timegm
from re import split

from werkzeug.exceptions import BadRequest

Expand Down Expand Up @@ -170,25 +171,36 @@ def _decode_jwt_from_headers():
header_type = config.header_type

# Verify we have the auth header
jwt_header = request.headers.get(header_name, None)
if not jwt_header:
auth_header = request.headers.get(header_name, None)
if not auth_header:
raise NoAuthorizationError("Missing {} Header".format(header_name))

# Make sure the header is in a valid format that we are expecting, ie
# <HeaderName>: <HeaderType(optional)> <JWT>
jwt_header = None

# Check if header is comma delimited, ie
# <HeaderName>: <field> <value>, <field> <value>, etc...
if header_type:
field_values = split(r',\s*', auth_header)
jwt_header = [s for s in field_values if s.split()[0] == header_type]
if len(jwt_header) < 1:
msg = "Bad {} header. Expected value '{} <JWT>'".format(
header_name,
header_type
)
raise InvalidHeaderError(msg)
jwt_header = jwt_header[0]
else:
jwt_header = auth_header

parts = jwt_header.split()
if not header_type:
if len(parts) != 1:
msg = "Bad {} header. Expected value '<JWT>'".format(header_name)
raise InvalidHeaderError(msg)
encoded_token = parts[0]
else:
if parts[0] != header_type or len(parts) != 2:
msg = "Bad {} header. Expected value '{} <JWT>'".format(
header_name,
header_type
)
raise InvalidHeaderError(msg)
encoded_token = parts[1]

return encoded_token, None
Expand Down
57 changes: 57 additions & 0 deletions tests/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,39 @@ def access_protected():
return app


def test_default_headers(app):
app.config
test_client = app.test_client()

with app.test_request_context():
access_token = create_access_token('username')

# Ensure other authorization types don't work
access_headers = {'Authorization': 'Basic basiccreds'}
response = test_client.get('/protected', headers=access_headers)
expected_json = {'msg': "Bad Authorization header. Expected value 'Bearer <JWT>'"}
assert response.status_code == 422
assert response.get_json() == expected_json

# Ensure default headers work
access_headers = {'Authorization': 'Bearer {}'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure default headers work with multiple field values
access_headers = {'Authorization': 'Bearer {}, Basic creds'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure default headers work with multiple field values in any position
access_headers = {'Authorization': 'Basic creds, Bearer {}'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}


def test_custom_header_name(app):
app.config['JWT_HEADER_NAME'] = 'Foo'
test_client = app.test_client()
Expand All @@ -38,6 +71,18 @@ def test_custom_header_name(app):
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure new headers work with multiple field values
access_headers = {'Foo': 'Bearer {}, Basic randomcredshere'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure new headers work with multiple field values in any position
access_headers = {'Foo': 'Basic randomcredshere, Bearer {}'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}


def test_custom_header_type(app):
app.config['JWT_HEADER_TYPE'] = 'JWT'
Expand All @@ -59,6 +104,18 @@ def test_custom_header_type(app):
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure new headers work with multiple field values
access_headers = {'Authorization': 'JWT {}, Basic creds'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure new headers work with multiple field values in any position
access_headers = {'Authorization': 'Basic creds, JWT {}'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Insure new headers without a type also work
app.config['JWT_HEADER_TYPE'] = ''
access_headers = {'Authorization': access_token}
Expand Down

0 comments on commit cb988e1

Please sign in to comment.