diff --git a/flask_jwt_extended/__init__.py b/flask_jwt_extended/__init__.py index 3386ebcb..6e9069bc 100644 --- a/flask_jwt_extended/__init__.py +++ b/flask_jwt_extended/__init__.py @@ -9,6 +9,7 @@ from .utils import get_jwt from .utils import get_jwt_header from .utils import get_jwt_identity +from .utils import get_jwt_request_location from .utils import get_unverified_jwt_headers from .utils import set_access_cookies from .utils import set_refresh_cookies diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 28753e00..6eb6289f 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -58,6 +58,21 @@ def get_jwt_identity(): return get_jwt().get(config.identity_claim_key, None) +def get_jwt_request_location(): + """ + In a protected endpoint, this will return the "location" at which the JWT + that is accessing the endpoint was found--e.g., "cookies", "query-string", + "headers", or "json". If no JWT is present due to ``jwt_required(optional=True)``, + None is returned. + + :return: + The location of the JWT in the current request; e.g., cookies", + "query-string", "headers", or "json" + """ + location = getattr(_request_ctx_stack.top, "jwt_location", None) + return location + + def get_current_user(): """ In a protected endpoint, this will return the user object for the JWT that diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 9419fe80..35c1bf01 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -60,17 +60,20 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations= try: if refresh: - jwt_data, jwt_header = _decode_jwt_from_request( + jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( locations, fresh, refresh=True ) else: - jwt_data, jwt_header = _decode_jwt_from_request(locations, fresh) + jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( + locations, fresh + ) except (NoAuthorizationError, InvalidHeaderError): if not optional: raise _request_ctx_stack.top.jwt = {} _request_ctx_stack.top.jwt_header = {} _request_ctx_stack.top.jwt_user = {"loaded_user": None} + _request_ctx_stack.top.jwt_location = None return # Save these at the very end so that they are only saved in the requet @@ -78,6 +81,7 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations= _request_ctx_stack.top.jwt_user = _load_user(jwt_header, jwt_data) _request_ctx_stack.top.jwt_header = jwt_header _request_ctx_stack.top.jwt = jwt_data + _request_ctx_stack.top.jwt_location = jwt_location return jwt_header, jwt_data @@ -235,18 +239,23 @@ def _decode_jwt_from_request(locations, fresh, refresh=False): locations = config.token_location # Get the decode functions in the order specified by locations. + # Each entry in this list is a tuple (, ) get_encoded_token_functions = [] for location in locations: if location == "cookies": get_encoded_token_functions.append( - lambda: _decode_jwt_from_cookies(refresh) + (location, lambda: _decode_jwt_from_cookies(refresh)) ) elif location == "query_string": - get_encoded_token_functions.append(_decode_jwt_from_query_string) + get_encoded_token_functions.append( + (location, _decode_jwt_from_query_string) + ) elif location == "headers": - get_encoded_token_functions.append(_decode_jwt_from_headers) + get_encoded_token_functions.append((location, _decode_jwt_from_headers)) elif location == "json": - get_encoded_token_functions.append(lambda: _decode_jwt_from_json(refresh)) + get_encoded_token_functions.append( + (location, lambda: _decode_jwt_from_json(refresh)) + ) else: raise RuntimeError(f"'{location}' is not a valid location") @@ -255,10 +264,12 @@ def _decode_jwt_from_request(locations, fresh, refresh=False): errors = [] decoded_token = None jwt_header = None - for get_encoded_token_function in get_encoded_token_functions: + jwt_location = None + for location, get_encoded_token_function in get_encoded_token_functions: try: encoded_token, csrf_token = get_encoded_token_function() decoded_token = decode_token(encoded_token, csrf_token) + jwt_location = location jwt_header = get_unverified_jwt_headers(encoded_token) break except NoAuthorizationError as e: @@ -284,4 +295,4 @@ def _decode_jwt_from_request(locations, fresh, refresh=False): verify_token_not_blocklisted(jwt_header, decoded_token) custom_verification_for_token(jwt_header, decoded_token) - return decoded_token, jwt_header + return decoded_token, jwt_header, jwt_location diff --git a/tests/test_multiple_token_locations.py b/tests/test_multiple_token_locations.py index 1016a173..ac551d29 100644 --- a/tests/test_multiple_token_locations.py +++ b/tests/test_multiple_token_locations.py @@ -3,6 +3,7 @@ from flask import jsonify from flask_jwt_extended import create_access_token +from flask_jwt_extended import get_jwt_request_location from flask_jwt_extended import jwt_required from flask_jwt_extended import JWTManager from flask_jwt_extended import set_access_cookies @@ -25,7 +26,7 @@ def cookie_login(): @app.route("/protected", methods=["GET", "POST"]) @jwt_required() def access_protected(): - return jsonify(foo="bar") + return jsonify(foo="bar", location=get_jwt_request_location()) return app @@ -48,7 +49,7 @@ def cookie_login(): @app.route("/protected", methods=["GET", "POST"]) @jwt_required(locations=locations) def access_protected(): - return jsonify(foo="bar") + return jsonify(foo="bar", location=get_jwt_request_location()) return app @@ -62,7 +63,7 @@ def test_header_access(app, app_with_locations): access_headers = {"Authorization": "Bearer {}".format(access_token)} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 200 - assert response.get_json() == {"foo": "bar"} + assert response.get_json() == {"foo": "bar", "location": "headers"} def test_cookie_access(app, app_with_locations): @@ -71,7 +72,7 @@ def test_cookie_access(app, app_with_locations): test_client.get("/cookie_login") response = test_client.get("/protected") assert response.status_code == 200 - assert response.get_json() == {"foo": "bar"} + assert response.get_json() == {"foo": "bar", "location": "cookies"} def test_query_string_access(app, app_with_locations): @@ -83,7 +84,7 @@ def test_query_string_access(app, app_with_locations): url = "/protected?jwt={}".format(access_token) response = test_client.get(url) assert response.status_code == 200 - assert response.get_json() == {"foo": "bar"} + assert response.get_json() == {"foo": "bar", "location": "query_string"} def test_json_access(app, app_with_locations): @@ -94,7 +95,7 @@ def test_json_access(app, app_with_locations): data = {"access_token": access_token} response = test_client.post("/protected", json=data) assert response.status_code == 200 - assert response.get_json() == {"foo": "bar"} + assert response.get_json() == {"foo": "bar", "location": "json"} @pytest.mark.parametrize( @@ -129,8 +130,8 @@ def test_no_jwt_in_request(app, options): @pytest.mark.parametrize( "options", [ - (["cookies", "headers"], 200, None, {"foo": "bar"}), - (["headers", "cookies"], 200, None, {"foo": "bar"}), + (["cookies", "headers"], 200, None, {"foo": "bar", "location": "cookies"}), + (["headers", "cookies"], 200, None, {"foo": "bar", "location": "cookies"}), ], ) def test_order_of_jwt_locations_in_request(app, options): @@ -151,7 +152,7 @@ def test_order_of_jwt_locations_in_request(app, options): @pytest.mark.parametrize( "options", [ - (["cookies", "headers"], 200, None, {"foo": "bar"}), + (["cookies", "headers"], 200, None, {"foo": "bar", "location": "cookies"}), (["headers", "cookies"], 422, ("Invalid header padding"), None), ], )