Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make header reading compliant with RFC7230, section 3.2.2 #270

Merged
merged 11 commits into from
Sep 10, 2019
28 changes: 20 additions & 8 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import wraps
from datetime import datetime
from calendar import timegm
from re import split

from werkzeug.exceptions import BadRequest

Expand Down Expand Up @@ -170,25 +171,36 @@ def _decode_jwt_from_headers():
header_type = config.header_type

# Verify we have the auth header
jwt_header = request.headers.get(header_name, None)
if not jwt_header:
auth_header = request.headers.get(header_name, None)
if not auth_header:
raise NoAuthorizationError("Missing {} Header".format(header_name))

# Make sure the header is in a valid format that we are expecting, ie
# <HeaderName>: <HeaderType(optional)> <JWT>
jwt_header = None

# Check if header is comma delimited, ie
# <HeaderName>: <field> <value>, <field> <value>, etc...
if header_type:
field_values = split(r',\s*', auth_header)
jwt_header = [s for s in field_values if s.split()[0] == header_type]
if len(jwt_header) < 1:
msg = "Bad {} header. Expected value '{} <JWT>'".format(
header_name,
header_type
)
raise InvalidHeaderError(msg)
jwt_header = jwt_header[0]
else:
jwt_header = auth_header

parts = jwt_header.split()
if not header_type:
if len(parts) != 1:
msg = "Bad {} header. Expected value '<JWT>'".format(header_name)
raise InvalidHeaderError(msg)
encoded_token = parts[0]
else:
if parts[0] != header_type or len(parts) != 2:
msg = "Bad {} header. Expected value '{} <JWT>'".format(
header_name,
header_type
)
raise InvalidHeaderError(msg)
encoded_token = parts[1]

return encoded_token, None
Expand Down
57 changes: 57 additions & 0 deletions tests/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,39 @@ def access_protected():
return app


def test_default_headers(app):
app.config
test_client = app.test_client()

with app.test_request_context():
access_token = create_access_token('username')

# Ensure other authorization types don't work
access_headers = {'Authorization': 'Basic basiccreds'}
response = test_client.get('/protected', headers=access_headers)
expected_json = {'msg': "Bad Authorization header. Expected value 'Bearer <JWT>'"}
assert response.status_code == 422
assert response.get_json() == expected_json

# Ensure default headers work
access_headers = {'Authorization': 'Bearer {}'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure default headers work with multiple field values
access_headers = {'Authorization': 'Bearer {}, Basic creds'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure default headers work with multiple field values in any position
access_headers = {'Authorization': 'Basic creds, Bearer {}'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}


def test_custom_header_name(app):
app.config['JWT_HEADER_NAME'] = 'Foo'
test_client = app.test_client()
Expand All @@ -38,6 +71,18 @@ def test_custom_header_name(app):
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure new headers work with multiple field values
access_headers = {'Foo': 'Bearer {}, Basic randomcredshere'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure new headers work with multiple field values in any position
access_headers = {'Foo': 'Basic randomcredshere, Bearer {}'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}


def test_custom_header_type(app):
app.config['JWT_HEADER_TYPE'] = 'JWT'
Expand All @@ -59,6 +104,18 @@ def test_custom_header_type(app):
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure new headers work with multiple field values
access_headers = {'Authorization': 'JWT {}, Basic creds'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure new headers work with multiple field values in any position
access_headers = {'Authorization': 'Basic creds, JWT {}'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Insure new headers without a type also work
app.config['JWT_HEADER_TYPE'] = ''
access_headers = {'Authorization': access_token}
Expand Down