diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 5f8df774..71c72d37 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -1,6 +1,7 @@ from functools import wraps from datetime import datetime from calendar import timegm +from re import split from werkzeug.exceptions import BadRequest @@ -170,12 +171,29 @@ def _decode_jwt_from_headers(): header_type = config.header_type # Verify we have the auth header - jwt_header = request.headers.get(header_name, None) - if not jwt_header: + auth_header = request.headers.get(header_name, None) + if not auth_header: raise NoAuthorizationError("Missing {} Header".format(header_name)) # Make sure the header is in a valid format that we are expecting, ie # : + jwt_header = None + + # Check if header is comma delimited, ie + # : , , etc... + if header_type: + field_values = split(r',\s*', auth_header) + jwt_header = [s for s in field_values if s.split()[0] == header_type] + if len(jwt_header) < 1: + msg = "Bad {} header. Expected value '{} '".format( + header_name, + header_type + ) + raise InvalidHeaderError(msg) + jwt_header = jwt_header[0] + else: + jwt_header = auth_header + parts = jwt_header.split() if not header_type: if len(parts) != 1: @@ -183,12 +201,6 @@ def _decode_jwt_from_headers(): raise InvalidHeaderError(msg) encoded_token = parts[0] else: - if parts[0] != header_type or len(parts) != 2: - msg = "Bad {} header. Expected value '{} '".format( - header_name, - header_type - ) - raise InvalidHeaderError(msg) encoded_token = parts[1] return encoded_token, None diff --git a/tests/test_headers.py b/tests/test_headers.py index db3ffe46..1874d93b 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -19,6 +19,39 @@ def access_protected(): return app +def test_default_headers(app): + app.config + test_client = app.test_client() + + with app.test_request_context(): + access_token = create_access_token('username') + + # Ensure other authorization types don't work + access_headers = {'Authorization': 'Basic basiccreds'} + response = test_client.get('/protected', headers=access_headers) + expected_json = {'msg': "Bad Authorization header. Expected value 'Bearer '"} + assert response.status_code == 422 + assert response.get_json() == expected_json + + # Ensure default headers work + access_headers = {'Authorization': 'Bearer {}'.format(access_token)} + response = test_client.get('/protected', headers=access_headers) + assert response.status_code == 200 + assert response.get_json() == {'foo': 'bar'} + + # Ensure default headers work with multiple field values + access_headers = {'Authorization': 'Bearer {}, Basic creds'.format(access_token)} + response = test_client.get('/protected', headers=access_headers) + assert response.status_code == 200 + assert response.get_json() == {'foo': 'bar'} + + # Ensure default headers work with multiple field values in any position + access_headers = {'Authorization': 'Basic creds, Bearer {}'.format(access_token)} + response = test_client.get('/protected', headers=access_headers) + assert response.status_code == 200 + assert response.get_json() == {'foo': 'bar'} + + def test_custom_header_name(app): app.config['JWT_HEADER_NAME'] = 'Foo' test_client = app.test_client() @@ -38,6 +71,18 @@ def test_custom_header_name(app): assert response.status_code == 200 assert response.get_json() == {'foo': 'bar'} + # Ensure new headers work with multiple field values + access_headers = {'Foo': 'Bearer {}, Basic randomcredshere'.format(access_token)} + response = test_client.get('/protected', headers=access_headers) + assert response.status_code == 200 + assert response.get_json() == {'foo': 'bar'} + + # Ensure new headers work with multiple field values in any position + access_headers = {'Foo': 'Basic randomcredshere, Bearer {}'.format(access_token)} + response = test_client.get('/protected', headers=access_headers) + assert response.status_code == 200 + assert response.get_json() == {'foo': 'bar'} + def test_custom_header_type(app): app.config['JWT_HEADER_TYPE'] = 'JWT' @@ -59,6 +104,18 @@ def test_custom_header_type(app): assert response.status_code == 200 assert response.get_json() == {'foo': 'bar'} + # Ensure new headers work with multiple field values + access_headers = {'Authorization': 'JWT {}, Basic creds'.format(access_token)} + response = test_client.get('/protected', headers=access_headers) + assert response.status_code == 200 + assert response.get_json() == {'foo': 'bar'} + + # Ensure new headers work with multiple field values in any position + access_headers = {'Authorization': 'Basic creds, JWT {}'.format(access_token)} + response = test_client.get('/protected', headers=access_headers) + assert response.status_code == 200 + assert response.get_json() == {'foo': 'bar'} + # Insure new headers without a type also work app.config['JWT_HEADER_TYPE'] = '' access_headers = {'Authorization': access_token}