diff --git a/flask_jwt_extended/__init__.py b/flask_jwt_extended/__init__.py index 92c486c7..b907420d 100644 --- a/flask_jwt_extended/__init__.py +++ b/flask_jwt_extended/__init__.py @@ -1,14 +1,15 @@ from .jwt_manager import JWTManager -from .view_decorators import ( - fresh_jwt_required, jwt_optional, jwt_refresh_token_required, jwt_required, - verify_fresh_jwt_in_request, verify_jwt_in_request, - verify_jwt_in_request_optional, verify_jwt_refresh_token_in_request -) from .utils import ( create_access_token, create_refresh_token, current_user, decode_token, get_csrf_token, get_current_user, get_jti, get_jwt_claims, get_jwt_identity, get_raw_jwt, set_access_cookies, set_refresh_cookies, unset_access_cookies, - unset_jwt_cookies, unset_refresh_cookies + unset_jwt_cookies, unset_refresh_cookies, get_unverified_jwt_headers, + get_raw_jwt_header +) +from .view_decorators import ( + fresh_jwt_required, jwt_optional, jwt_refresh_token_required, jwt_required, + verify_fresh_jwt_in_request, verify_jwt_in_request, + verify_jwt_in_request_optional, verify_jwt_refresh_token_in_request ) __version__ = '3.23.0' diff --git a/flask_jwt_extended/default_callbacks.py b/flask_jwt_extended/default_callbacks.py index fc031769..c7785c17 100644 --- a/flask_jwt_extended/default_callbacks.py +++ b/flask_jwt_extended/default_callbacks.py @@ -22,6 +22,17 @@ def default_user_claims_callback(userdata): return {} +def default_jwt_headers_callback(default_headers): + """ + By default header typically consists of two parts: the type of the token, + which is JWT, and the signing algorithm being used, such as HMAC SHA256 + or RSA. But we don't set the default header here we set it as empty which + further by default set while encoding the token + :return: default we set None here + """ + return None + + def default_user_identity_callback(userdata): """ By default, we use the passed in object directly as the jwt identity. diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index ab488a0f..04f3df8f 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -5,6 +5,7 @@ ExpiredSignatureError, InvalidTokenError, InvalidAudienceError, InvalidIssuerError, DecodeError ) + try: from flask import _app_ctx_stack as ctx_stack except ImportError: # pragma: no cover @@ -22,8 +23,8 @@ default_unauthorized_callback, default_needs_fresh_token_callback, default_revoked_token_callback, default_user_loader_error_callback, default_claims_verification_callback, default_verify_claims_failed_callback, - default_decode_key_callback, default_encode_key_callback -) + default_decode_key_callback, default_encode_key_callback, + default_jwt_headers_callback) from flask_jwt_extended.tokens import ( encode_refresh_token, encode_access_token ) @@ -64,6 +65,7 @@ def __init__(self, app=None): self._verify_claims_failed_callback = default_verify_claims_failed_callback self._decode_key_callback = default_decode_key_callback self._encode_key_callback = default_encode_key_callback + self._jwt_additional_header_callback = default_jwt_headers_callback # Register this extension with the flask app now (if it is provided) if app is not None: @@ -454,13 +456,33 @@ def encode_key_loader(self, callback): self._encode_key_callback = callback return callback - def _create_refresh_token(self, identity, expires_delta=None, user_claims=None): + def additional_headers_loader(self, callback): + """ + This decorator sets the callback function for adding custom headers to an + access token when :func:`~flask_jwt_extended.create_access_token` is + called. By default, two headers will be added the type of the token, which is JWT, + and the signing algorithm being used, such as HMAC SHA256 or RSA. + + *HINT*: The callback function must be a function that takes **no** argument, + which is the object passed into + :func:`~flask_jwt_extended.create_access_token`, and returns the custom + claims you want included in the access tokens. This returned claims + must be *JSON serializable*. + """ + self._jwt_additional_header_callback = callback + return callback + + def _create_refresh_token(self, identity, expires_delta=None, user_claims=None, + headers=None): if expires_delta is None: expires_delta = config.refresh_expires if user_claims is None and config.user_claims_in_refresh_token: user_claims = self._user_claims_callback(identity) + if headers is None: + headers = self._jwt_additional_header_callback(identity) + refresh_token = encode_refresh_token( identity=self._user_identity_callback(identity), secret=self._encode_key_callback(identity), @@ -470,17 +492,22 @@ def _create_refresh_token(self, identity, expires_delta=None, user_claims=None): csrf=config.csrf_protect, identity_claim_key=config.identity_claim_key, user_claims_key=config.user_claims_key, - json_encoder=config.json_encoder + json_encoder=config.json_encoder, + headers=headers ) return refresh_token - def _create_access_token(self, identity, fresh=False, expires_delta=None, user_claims=None): + def _create_access_token(self, identity, fresh=False, expires_delta=None, + user_claims=None, headers=None): if expires_delta is None: expires_delta = config.access_expires if user_claims is None: user_claims = self._user_claims_callback(identity) + if headers is None: + headers = self._jwt_additional_header_callback(identity) + access_token = encode_access_token( identity=self._user_identity_callback(identity), secret=self._encode_key_callback(identity), @@ -491,6 +518,7 @@ def _create_access_token(self, identity, fresh=False, expires_delta=None, user_c csrf=config.csrf_protect, identity_claim_key=config.identity_claim_key, user_claims_key=config.user_claims_key, - json_encoder=config.json_encoder + json_encoder=config.json_encoder, + headers=headers ) return access_token diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index 17aae323..f6fa7f41 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -1,6 +1,5 @@ import datetime import uuid - from calendar import timegm import jwt @@ -14,7 +13,7 @@ def _create_csrf_token(): def _encode_jwt(additional_token_data, expires_delta, secret, algorithm, - json_encoder=None): + json_encoder=None, headers=None): uid = _create_csrf_token() now = datetime.datetime.utcnow() token_data = { @@ -28,13 +27,13 @@ def _encode_jwt(additional_token_data, expires_delta, secret, algorithm, token_data['exp'] = now + expires_delta token_data.update(additional_token_data) encoded_token = jwt.encode(token_data, secret, algorithm, - json_encoder=json_encoder).decode('utf-8') + json_encoder=json_encoder, headers=headers).decode('utf-8') return encoded_token def encode_access_token(identity, secret, algorithm, expires_delta, fresh, user_claims, csrf, identity_claim_key, user_claims_key, - json_encoder=None): + json_encoder=None, headers=None): """ Creates a new encoded (utf-8) access token. @@ -54,6 +53,7 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh, (boolean) :param identity_claim_key: Which key should be used to store the identity :param user_claims_key: Which key should be used to store the user claims + :param headers: valid dict for specifying additional headers in JWT header section :return: Encoded access token """ @@ -74,12 +74,12 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh, if csrf: token_data['csrf'] = _create_csrf_token() return _encode_jwt(token_data, expires_delta, secret, algorithm, - json_encoder=json_encoder) + json_encoder=json_encoder, headers=headers) def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims, csrf, identity_claim_key, user_claims_key, - json_encoder=None): + json_encoder=None, headers=None): """ Creates a new encoded (utf-8) refresh token. @@ -95,6 +95,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims (boolean) :param identity_claim_key: Which key should be used to store the identity :param user_claims_key: Which key should be used to store the user claims + :param headers: valid dict for specifying additional headers in JWT header section :return: Encoded refresh token """ token_data = { @@ -109,7 +110,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims if csrf: token_data['csrf'] = _create_csrf_token() return _encode_jwt(token_data, expires_delta, secret, algorithm, - json_encoder=json_encoder) + json_encoder=json_encoder, headers=headers) def decode_jwt(encoded_token, secret, algorithms, identity_claim_key, diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index b70c52f8..b04d49d2 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -1,7 +1,8 @@ +from warnings import warn + from flask import current_app -from werkzeug.local import LocalProxy from jwt import ExpiredSignatureError -from warnings import warn +from werkzeug.local import LocalProxy try: from flask import _app_ctx_stack as ctx_stack @@ -29,6 +30,15 @@ def get_raw_jwt(): return getattr(ctx_stack.top, 'jwt', {}) +def get_raw_jwt_header(): + """ + In a protected endpoint, this will return the python dictionary which has + the JWT headers values. If no + JWT is currently present, an empty dict is returned instead. + """ + return getattr(ctx_stack.top, 'jwt_header', {}) + + def get_jwt_identity(): """ In a protected endpoint, this will return the identity of the JWT that is @@ -132,7 +142,8 @@ def _get_jwt_manager(): "application before using this method") -def create_access_token(identity, fresh=False, expires_delta=None, user_claims=None): +def create_access_token(identity, fresh=False, expires_delta=None, user_claims=None, + headers=None): """ Create a new access token. @@ -153,13 +164,17 @@ def create_access_token(identity, fresh=False, expires_delta=None, user_claims=N 'JWT_ACCESS_TOKEN_EXPIRES` config value (see :ref:`Configuration Options`) :param user_claims: Optional JSON serializable to override user claims. + :param headers: Optional, valid dict for specifying additional headers in JWT + header section :return: An encoded access token """ jwt_manager = _get_jwt_manager() - return jwt_manager._create_access_token(identity, fresh, expires_delta, user_claims) + return jwt_manager._create_access_token(identity, fresh, expires_delta, user_claims, + headers=headers) -def create_refresh_token(identity, expires_delta=None, user_claims=None): +def create_refresh_token(identity, expires_delta=None, user_claims=None, + headers=None): """ Creates a new refresh token. @@ -175,10 +190,13 @@ def create_refresh_token(identity, expires_delta=None, user_claims=None): 'JWT_REFRESH_TOKEN_EXPIRES` config value (see :ref:`Configuration Options`) :param user_claims: Optional JSON serializable to override user claims. + :param headers: Optional, valid dict for specifying additional headers in JWT + header section :return: An encoded refresh token """ jwt_manager = _get_jwt_manager() - return jwt_manager._create_refresh_token(identity, expires_delta, user_claims) + return jwt_manager._create_refresh_token(identity, expires_delta, user_claims, + headers=headers) def has_user_loader(): @@ -396,3 +414,15 @@ def unset_refresh_cookies(response): domain=config.cookie_domain, path=config.refresh_csrf_cookie_path, samesite=config.cookie_samesite) + + +def get_unverified_jwt_headers(encoded_token): + """ + Returns the Headers of an encoded JWT without verifying the actual signature of JWT. + Note: The signature is not verified so the header parameters + should not be fully trusted until signature verification is complete + + :param encoded_token: The encoded JWT to get the Header from. + :return: JWT header parameters as python dict() + """ + return jwt.get_unverified_header(encoded_token) diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 71c72d37..7b319597 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -18,7 +18,7 @@ ) from flask_jwt_extended.utils import ( decode_token, has_user_loader, user_loader, verify_token_claims, - verify_token_not_blacklisted, verify_token_type + verify_token_not_blacklisted, verify_token_type, get_unverified_jwt_headers ) @@ -29,8 +29,9 @@ def verify_jwt_in_request(): no token or if the token is invalid. """ if request.method not in config.exempt_methods: - jwt_data = _decode_jwt_from_request(request_type='access') + jwt_data, jwt_header = _decode_jwt_from_request(request_type='access') ctx_stack.top.jwt = jwt_data + ctx_stack.top.jwt_header = jwt_header verify_token_claims(jwt_data) _load_user(jwt_data[config.identity_claim_key]) @@ -48,8 +49,9 @@ def verify_jwt_in_request_optional(): """ try: if request.method not in config.exempt_methods: - jwt_data = _decode_jwt_from_request(request_type='access') + jwt_data, jwt_header = _decode_jwt_from_request(request_type='access') ctx_stack.top.jwt = jwt_data + ctx_stack.top.jwt_header = jwt_header verify_token_claims(jwt_data) _load_user(jwt_data[config.identity_claim_key]) except (NoAuthorizationError, InvalidHeaderError): @@ -63,8 +65,9 @@ def verify_fresh_jwt_in_request(): token is not marked as fresh. """ if request.method not in config.exempt_methods: - jwt_data = _decode_jwt_from_request(request_type='access') + jwt_data, jwt_header = _decode_jwt_from_request(request_type='access') ctx_stack.top.jwt = jwt_data + ctx_stack.top.jwt_header = jwt_header fresh = jwt_data['fresh'] if isinstance(fresh, bool): if not fresh: @@ -83,8 +86,9 @@ def verify_jwt_refresh_token_in_request(): exception if there is no token or the token is invalid. """ if request.method not in config.exempt_methods: - jwt_data = _decode_jwt_from_request(request_type='refresh') + jwt_data, jwt_header = _decode_jwt_from_request(request_type='refresh') ctx_stack.top.jwt = jwt_data + ctx_stack.top.jwt_header = jwt_header _load_user(jwt_data[config.identity_claim_key]) @@ -283,10 +287,12 @@ def _decode_jwt_from_request(request_type): # in one place to be valid (not every location). errors = [] decoded_token = None + jwt_header = None for get_encoded_token_function in get_encoded_token_functions: try: encoded_token, csrf_token = get_encoded_token_function() decoded_token = decode_token(encoded_token, csrf_token) + jwt_header = get_unverified_jwt_headers(encoded_token) break except NoAuthorizationError as e: errors.append(str(e)) @@ -309,4 +315,4 @@ def _decode_jwt_from_request(request_type): verify_token_type(decoded_token, expected_type=request_type) verify_token_not_blacklisted(decoded_token, request_type) - return decoded_token + return decoded_token, jwt_header diff --git a/tests/test_decode_tokens.py b/tests/test_decode_tokens.py index f9cf6710..f1e98d76 100644 --- a/tests/test_decode_tokens.py +++ b/tests/test_decode_tokens.py @@ -13,7 +13,7 @@ from flask_jwt_extended import ( JWTManager, create_access_token, decode_token, create_refresh_token, - get_jti + get_jti, get_unverified_jwt_headers ) from flask_jwt_extended.config import config from flask_jwt_extended.exceptions import JWTDecodeError @@ -286,3 +286,12 @@ def test_malformed_token(app): with pytest.raises(DecodeError): with app.test_request_context(): decode_token(invalid_token) + + +def test_jwt_headers(app): + jwt_header = {"foo": "bar"} + with app.test_request_context(): + access_token = create_access_token('username', headers=jwt_header) + refresh_token = create_refresh_token('username', headers=jwt_header) + assert get_unverified_jwt_headers(access_token)["foo"] == "bar" + assert get_unverified_jwt_headers(refresh_token)["foo"] == "bar" diff --git a/tests/test_jwt_header_loader.py b/tests/test_jwt_header_loader.py new file mode 100644 index 00000000..d09a1168 --- /dev/null +++ b/tests/test_jwt_header_loader.py @@ -0,0 +1,123 @@ +import pytest +from flask import Flask, jsonify + +from flask_jwt_extended import ( + JWTManager, create_access_token, jwt_required, + jwt_refresh_token_required, create_refresh_token, get_raw_jwt_header +) +from tests.utils import get_jwt_manager, make_headers + + +@pytest.fixture(scope='function') +def app(): + app = Flask(__name__) + app.config['JWT_SECRET_KEY'] = 'foobarbaz' + JWTManager(app) + + @app.route('/protected', methods=['GET']) + @jwt_required + def get_claims(): + return jsonify(get_raw_jwt_header()) + + @app.route('/protected2', methods=['GET']) + @jwt_refresh_token_required + def get_refresh_claims(): + return jsonify(get_raw_jwt_header()) + + return app + + +def test_jwt_headers_in_access_token(app): + jwt = get_jwt_manager(app) + + @jwt.additional_headers_loader + def add_jwt_headers(identity): + return {'foo': 'bar'} + + with app.test_request_context(): + access_token = create_access_token('username') + + test_client = app.test_client() + response = test_client.get('/protected', headers=make_headers(access_token)) + assert response.get_json().get("foo") == "bar" + assert response.status_code == 200 + + +def test_non_serializable_user_claims(app): + jwt = get_jwt_manager(app) + + @jwt.additional_headers_loader + def add_jwt_headers(identity): + return app + + with pytest.raises(TypeError): + with app.test_request_context(): + create_access_token('username') + + +def test_jwt_headers_in_refresh_token(app): + jwt = get_jwt_manager(app) + + @jwt.additional_headers_loader + def add_jwt_headers(identity): + return {'foo': 'bar'} + + with app.test_request_context(): + refresh_token = create_refresh_token('username') + + test_client = app.test_client() + response = test_client.get('/protected2', headers=make_headers(refresh_token)) + assert response.get_json().get("foo") == "bar" + assert response.status_code == 200 + + +def test_jwt_header_in_refresh_token_specified_at_creation(app): + with app.test_request_context(): + refresh_token = create_refresh_token('username', headers={'foo': 'bar'}) + + test_client = app.test_client() + response = test_client.get('/protected2', headers=make_headers(refresh_token)) + assert response.get_json().get("foo") == "bar" + assert response.status_code == 200 + + +def test_jwt_header_in_access_token_specified_at_creation(app): + with app.test_request_context(): + access_token = create_access_token('username', headers={'foo': 'bar'}) + + test_client = app.test_client() + response = test_client.get('/protected', headers=make_headers(access_token)) + assert response.get_json().get("foo") == "bar" + assert response.status_code == 200 + + +def test_jwt_header_in_access_token_specified_at_creation_override(app): + jwt = get_jwt_manager(app) + + @jwt.additional_headers_loader + def add_jwt_headers(identity): + return {'ping': 'pong'} + + with app.test_request_context(): + access_token = create_access_token('username', headers={'foo': 'bar'}) + + test_client = app.test_client() + response = test_client.get('/protected', headers=make_headers(access_token)) + assert response.get_json().get("foo") == "bar" + assert response.status_code == 200 + + +def test_jwt_header_in_refresh_token_specified_at_creation_override(app): + jwt = get_jwt_manager(app) + + @jwt.additional_headers_loader + def add_jwt_headers(identity): + return {'ping': 'pong'} + + with app.test_request_context(): + access_token = create_refresh_token('username', headers={'foo': 'bar'}) + + test_client = app.test_client() + response = test_client.get('/protected2', headers=make_headers(access_token)) + assert response.get_json().get("foo") == "bar" + assert response.status_code == 200 diff --git a/tests/utils.py b/tests/utils.py index f211915f..ebe05b6f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,13 +3,14 @@ from flask_jwt_extended.config import config -def encode_token(app, token_data): +def encode_token(app, token_data, headers=None): with app.test_request_context(): token = jwt.encode( token_data, config.decode_key, algorithm=config.algorithm, - json_encoder=config.json_encoder + json_encoder=config.json_encoder, + headers=headers ) return token.decode('utf-8')