Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make header reading compliant with RFC7230, section 3.2.2 #270

Merged
merged 11 commits into from
Sep 10, 2019
2 changes: 1 addition & 1 deletion flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def header_name(self):

@property
def header_type(self):
return current_app.config['JWT_HEADER_TYPE']
return current_app.config.get('JWT_HEADER_TYPE')
Croug marked this conversation as resolved.
Show resolved Hide resolved

@property
def query_string_name(self):
Expand Down
28 changes: 20 additions & 8 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import wraps
from datetime import datetime
from calendar import timegm
from re import split

from werkzeug.exceptions import BadRequest

Expand Down Expand Up @@ -170,25 +171,36 @@ def _decode_jwt_from_headers():
header_type = config.header_type

# Verify we have the auth header
jwt_header = request.headers.get(header_name, None)
if not jwt_header:
auth_header = request.headers.get(header_name, None)
if not auth_header:
raise NoAuthorizationError("Missing {} Header".format(header_name))

# Make sure the header is in a valid format that we are expecting, ie
# <HeaderName>: <HeaderType(optional)> <JWT>
jwt_header = None

# Check if header is comma delimited, ie
# <HeaderName>: <field> <value>, <field> <value>, etc...
if header_type:
field_values = split(r',\s*', auth_header)
jwt_header = [s for s in field_values if s.startswith(header_type)]
Croug marked this conversation as resolved.
Show resolved Hide resolved
if len(jwt_header) < 1:
msg = "Bad {} header. Expected value '{} <JWT>'".format(
header_name,
header_type
)
raise InvalidHeaderError(msg)
jwt_header = jwt_header[0]
else:
jwt_header = auth_header

parts = jwt_header.split()
if not header_type:
if len(parts) != 1:
msg = "Bad {} header. Expected value '<JWT>'".format(header_name)
raise InvalidHeaderError(msg)
encoded_token = parts[0]
else:
if parts[0] != header_type or len(parts) != 2:
msg = "Bad {} header. Expected value '{} <JWT>'".format(
header_name,
header_type
)
raise InvalidHeaderError(msg)
encoded_token = parts[1]

return encoded_token, None
Expand Down
59 changes: 58 additions & 1 deletion tests/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,46 @@ def access_protected():
return app


def test_default_headers(app):
app.config
test_client = app.test_client()

with app.test_request_context():
access_token = create_access_token('username')

# Ensure other authorization types don't work
access_headers = {'Authorization': 'Basic basiccreds'}
response = test_client.get('/protected', headers=access_headers)
expected_json = {'msg': "Bad Authorization header. Expected value 'Bearer <JWT>'"}
assert response.status_code == 422
assert response.get_json() == expected_json

# Ensure default headers work
access_headers = {'Authorization': 'Bearer {}'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure default headers work with multiple field values
access_headers = {'Authorization': 'Bearer {}, Basic randomcredshere'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure default headers work with multiple field values in any position
access_headers = {'Authorization': 'Basic randomcredshere, Bearer {}'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}


def test_custom_header_name(app):
app.config['JWT_HEADER_NAME'] = 'Foo'
test_client = app.test_client()

with app.test_request_context():
access_token = create_access_token('username')

# Insure 'default' headers no longer work
access_headers = {'Authorization': 'Bearer {}'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
Expand All @@ -38,6 +71,18 @@ def test_custom_header_name(app):
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure new headers work with multiple field values
access_headers = {'Foo': 'Bearer {}, Basic randomcredshere'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure new headers work with multiple field values in any position
access_headers = {'Foo': 'Basic randomcredshere, Bearer {}'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}


def test_custom_header_type(app):
app.config['JWT_HEADER_TYPE'] = 'JWT'
Expand All @@ -58,6 +103,18 @@ def test_custom_header_type(app):
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure new headers work with multiple field values
access_headers = {'Authorization': 'JWT {}, Basic randomcredshere'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Ensure new headers work with multiple field values in any position
access_headers = {'Authorization': 'Basic randomcredshere, JWT {}'.format(access_token)}
response = test_client.get('/protected', headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {'foo': 'bar'}

# Insure new headers without a type also work
app.config['JWT_HEADER_TYPE'] = ''
Expand Down
2 changes: 1 addition & 1 deletion tests/test_view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,4 +360,4 @@ def test_different_token_algorightm(app):

response = test_client.get(url, headers=make_headers(token))
assert response.status_code == 422
assert response.get_json() == {'msg': 'The specified alg value is not allowed'}
assert response.get_json() == {'msg': 'The specified alg value is not allowed'}