Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_jwt_request_location() function to find out where the request JWT was located #420

Merged
merged 3 commits into from
May 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flask_jwt_extended/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .utils import get_jwt
from .utils import get_jwt_header
from .utils import get_jwt_identity
from .utils import get_jwt_request_location
from .utils import get_unverified_jwt_headers
from .utils import set_access_cookies
from .utils import set_refresh_cookies
Expand Down
15 changes: 15 additions & 0 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ def get_jwt_identity():
return get_jwt().get(config.identity_claim_key, None)


def get_jwt_request_location():
"""
In a protected endpoint, this will return the "location" at which the JWT
that is accessing the endpoint was found--e.g., "cookies", "query-string",
"headers", or "json". If no JWT is present due to ``jwt_required(optional=True)``,
None is returned.

:return:
The location of the JWT in the current request; e.g., cookies",
"query-string", "headers", or "json"
"""
location = getattr(_request_ctx_stack.top, "jwt_location", None)
return location


def get_current_user():
"""
In a protected endpoint, this will return the user object for the JWT that
Expand Down
27 changes: 19 additions & 8 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,28 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=

try:
if refresh:
jwt_data, jwt_header = _decode_jwt_from_request(
jwt_data, jwt_header, jwt_location = _decode_jwt_from_request(
locations, fresh, refresh=True
)
else:
jwt_data, jwt_header = _decode_jwt_from_request(locations, fresh)
jwt_data, jwt_header, jwt_location = _decode_jwt_from_request(
locations, fresh
)
except (NoAuthorizationError, InvalidHeaderError):
if not optional:
raise
_request_ctx_stack.top.jwt = {}
_request_ctx_stack.top.jwt_header = {}
_request_ctx_stack.top.jwt_user = {"loaded_user": None}
_request_ctx_stack.top.jwt_location = None
return

# Save these at the very end so that they are only saved in the requet
# context if the token is valid and all callbacks succeed
_request_ctx_stack.top.jwt_user = _load_user(jwt_header, jwt_data)
_request_ctx_stack.top.jwt_header = jwt_header
_request_ctx_stack.top.jwt = jwt_data
_request_ctx_stack.top.jwt_location = jwt_location

return jwt_header, jwt_data

Expand Down Expand Up @@ -235,18 +239,23 @@ def _decode_jwt_from_request(locations, fresh, refresh=False):
locations = config.token_location

# Get the decode functions in the order specified by locations.
# Each entry in this list is a tuple (<location>, <encoded-token-function>)
get_encoded_token_functions = []
for location in locations:
if location == "cookies":
get_encoded_token_functions.append(
lambda: _decode_jwt_from_cookies(refresh)
(location, lambda: _decode_jwt_from_cookies(refresh))
)
elif location == "query_string":
get_encoded_token_functions.append(_decode_jwt_from_query_string)
get_encoded_token_functions.append(
(location, _decode_jwt_from_query_string)
)
elif location == "headers":
get_encoded_token_functions.append(_decode_jwt_from_headers)
get_encoded_token_functions.append((location, _decode_jwt_from_headers))
elif location == "json":
get_encoded_token_functions.append(lambda: _decode_jwt_from_json(refresh))
get_encoded_token_functions.append(
(location, lambda: _decode_jwt_from_json(refresh))
)
else:
raise RuntimeError(f"'{location}' is not a valid location")

Expand All @@ -255,10 +264,12 @@ def _decode_jwt_from_request(locations, fresh, refresh=False):
errors = []
decoded_token = None
jwt_header = None
for get_encoded_token_function in get_encoded_token_functions:
jwt_location = None
for location, get_encoded_token_function in get_encoded_token_functions:
try:
encoded_token, csrf_token = get_encoded_token_function()
decoded_token = decode_token(encoded_token, csrf_token)
jwt_location = location
jwt_header = get_unverified_jwt_headers(encoded_token)
break
except NoAuthorizationError as e:
Expand All @@ -284,4 +295,4 @@ def _decode_jwt_from_request(locations, fresh, refresh=False):
verify_token_not_blocklisted(jwt_header, decoded_token)
custom_verification_for_token(jwt_header, decoded_token)

return decoded_token, jwt_header
return decoded_token, jwt_header, jwt_location
19 changes: 10 additions & 9 deletions tests/test_multiple_token_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from flask import jsonify

from flask_jwt_extended import create_access_token
from flask_jwt_extended import get_jwt_request_location
from flask_jwt_extended import jwt_required
from flask_jwt_extended import JWTManager
from flask_jwt_extended import set_access_cookies
Expand All @@ -25,7 +26,7 @@ def cookie_login():
@app.route("/protected", methods=["GET", "POST"])
@jwt_required()
def access_protected():
return jsonify(foo="bar")
return jsonify(foo="bar", location=get_jwt_request_location())

return app

Expand All @@ -48,7 +49,7 @@ def cookie_login():
@app.route("/protected", methods=["GET", "POST"])
@jwt_required(locations=locations)
def access_protected():
return jsonify(foo="bar")
return jsonify(foo="bar", location=get_jwt_request_location())

return app

Expand All @@ -62,7 +63,7 @@ def test_header_access(app, app_with_locations):
access_headers = {"Authorization": "Bearer {}".format(access_token)}
response = test_client.get("/protected", headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}
assert response.get_json() == {"foo": "bar", "location": "headers"}


def test_cookie_access(app, app_with_locations):
Expand All @@ -71,7 +72,7 @@ def test_cookie_access(app, app_with_locations):
test_client.get("/cookie_login")
response = test_client.get("/protected")
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}
assert response.get_json() == {"foo": "bar", "location": "cookies"}


def test_query_string_access(app, app_with_locations):
Expand All @@ -83,7 +84,7 @@ def test_query_string_access(app, app_with_locations):
url = "/protected?jwt={}".format(access_token)
response = test_client.get(url)
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}
assert response.get_json() == {"foo": "bar", "location": "query_string"}


def test_json_access(app, app_with_locations):
Expand All @@ -94,7 +95,7 @@ def test_json_access(app, app_with_locations):
data = {"access_token": access_token}
response = test_client.post("/protected", json=data)
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}
assert response.get_json() == {"foo": "bar", "location": "json"}


@pytest.mark.parametrize(
Expand Down Expand Up @@ -129,8 +130,8 @@ def test_no_jwt_in_request(app, options):
@pytest.mark.parametrize(
"options",
[
(["cookies", "headers"], 200, None, {"foo": "bar"}),
(["headers", "cookies"], 200, None, {"foo": "bar"}),
(["cookies", "headers"], 200, None, {"foo": "bar", "location": "cookies"}),
(["headers", "cookies"], 200, None, {"foo": "bar", "location": "cookies"}),
],
)
def test_order_of_jwt_locations_in_request(app, options):
Expand All @@ -151,7 +152,7 @@ def test_order_of_jwt_locations_in_request(app, options):
@pytest.mark.parametrize(
"options",
[
(["cookies", "headers"], 200, None, {"foo": "bar"}),
(["cookies", "headers"], 200, None, {"foo": "bar", "location": "cookies"}),
(["headers", "cookies"], 422, ("Invalid header padding"), None),
],
)
Expand Down