Skip to content

Commit

Permalink
Add get_jwt_request_location() function to find out where the request…
Browse files Browse the repository at this point in the history
… JWT was located (#420)

* Add get_jwt_request_location() function to find out where the request JWT was located

Sometimes it is desirable to change behavior of a view based on where the JWT was
located. For example, if the same route is used for cookie-based access or
header-based access, and you want to implicitly refresh cookie-based access tokens.

With this change, a protected view can determine which location (e.g., "cookies",
"headers", "query_string", or "json") was selected as the source of the current
request's JWT.

* Fix linting errors

* fix linting issues
  • Loading branch information
sammck authored May 2, 2021
1 parent fdc0602 commit ef3da3c
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 17 deletions.
1 change: 1 addition & 0 deletions flask_jwt_extended/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .utils import get_jwt
from .utils import get_jwt_header
from .utils import get_jwt_identity
from .utils import get_jwt_request_location
from .utils import get_unverified_jwt_headers
from .utils import set_access_cookies
from .utils import set_refresh_cookies
Expand Down
15 changes: 15 additions & 0 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ def get_jwt_identity():
return get_jwt().get(config.identity_claim_key, None)


def get_jwt_request_location():
"""
In a protected endpoint, this will return the "location" at which the JWT
that is accessing the endpoint was found--e.g., "cookies", "query-string",
"headers", or "json". If no JWT is present due to ``jwt_required(optional=True)``,
None is returned.
:return:
The location of the JWT in the current request; e.g., cookies",
"query-string", "headers", or "json"
"""
location = getattr(_request_ctx_stack.top, "jwt_location", None)
return location


def get_current_user():
"""
In a protected endpoint, this will return the user object for the JWT that
Expand Down
27 changes: 19 additions & 8 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,28 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=

try:
if refresh:
jwt_data, jwt_header = _decode_jwt_from_request(
jwt_data, jwt_header, jwt_location = _decode_jwt_from_request(
locations, fresh, refresh=True
)
else:
jwt_data, jwt_header = _decode_jwt_from_request(locations, fresh)
jwt_data, jwt_header, jwt_location = _decode_jwt_from_request(
locations, fresh
)
except (NoAuthorizationError, InvalidHeaderError):
if not optional:
raise
_request_ctx_stack.top.jwt = {}
_request_ctx_stack.top.jwt_header = {}
_request_ctx_stack.top.jwt_user = {"loaded_user": None}
_request_ctx_stack.top.jwt_location = None
return

# Save these at the very end so that they are only saved in the requet
# context if the token is valid and all callbacks succeed
_request_ctx_stack.top.jwt_user = _load_user(jwt_header, jwt_data)
_request_ctx_stack.top.jwt_header = jwt_header
_request_ctx_stack.top.jwt = jwt_data
_request_ctx_stack.top.jwt_location = jwt_location

return jwt_header, jwt_data

Expand Down Expand Up @@ -235,18 +239,23 @@ def _decode_jwt_from_request(locations, fresh, refresh=False):
locations = config.token_location

# Get the decode functions in the order specified by locations.
# Each entry in this list is a tuple (<location>, <encoded-token-function>)
get_encoded_token_functions = []
for location in locations:
if location == "cookies":
get_encoded_token_functions.append(
lambda: _decode_jwt_from_cookies(refresh)
(location, lambda: _decode_jwt_from_cookies(refresh))
)
elif location == "query_string":
get_encoded_token_functions.append(_decode_jwt_from_query_string)
get_encoded_token_functions.append(
(location, _decode_jwt_from_query_string)
)
elif location == "headers":
get_encoded_token_functions.append(_decode_jwt_from_headers)
get_encoded_token_functions.append((location, _decode_jwt_from_headers))
elif location == "json":
get_encoded_token_functions.append(lambda: _decode_jwt_from_json(refresh))
get_encoded_token_functions.append(
(location, lambda: _decode_jwt_from_json(refresh))
)
else:
raise RuntimeError(f"'{location}' is not a valid location")

Expand All @@ -255,10 +264,12 @@ def _decode_jwt_from_request(locations, fresh, refresh=False):
errors = []
decoded_token = None
jwt_header = None
for get_encoded_token_function in get_encoded_token_functions:
jwt_location = None
for location, get_encoded_token_function in get_encoded_token_functions:
try:
encoded_token, csrf_token = get_encoded_token_function()
decoded_token = decode_token(encoded_token, csrf_token)
jwt_location = location
jwt_header = get_unverified_jwt_headers(encoded_token)
break
except NoAuthorizationError as e:
Expand All @@ -284,4 +295,4 @@ def _decode_jwt_from_request(locations, fresh, refresh=False):
verify_token_not_blocklisted(jwt_header, decoded_token)
custom_verification_for_token(jwt_header, decoded_token)

return decoded_token, jwt_header
return decoded_token, jwt_header, jwt_location
19 changes: 10 additions & 9 deletions tests/test_multiple_token_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from flask import jsonify

from flask_jwt_extended import create_access_token
from flask_jwt_extended import get_jwt_request_location
from flask_jwt_extended import jwt_required
from flask_jwt_extended import JWTManager
from flask_jwt_extended import set_access_cookies
Expand All @@ -25,7 +26,7 @@ def cookie_login():
@app.route("/protected", methods=["GET", "POST"])
@jwt_required()
def access_protected():
return jsonify(foo="bar")
return jsonify(foo="bar", location=get_jwt_request_location())

return app

Expand All @@ -48,7 +49,7 @@ def cookie_login():
@app.route("/protected", methods=["GET", "POST"])
@jwt_required(locations=locations)
def access_protected():
return jsonify(foo="bar")
return jsonify(foo="bar", location=get_jwt_request_location())

return app

Expand All @@ -62,7 +63,7 @@ def test_header_access(app, app_with_locations):
access_headers = {"Authorization": "Bearer {}".format(access_token)}
response = test_client.get("/protected", headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}
assert response.get_json() == {"foo": "bar", "location": "headers"}


def test_cookie_access(app, app_with_locations):
Expand All @@ -71,7 +72,7 @@ def test_cookie_access(app, app_with_locations):
test_client.get("/cookie_login")
response = test_client.get("/protected")
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}
assert response.get_json() == {"foo": "bar", "location": "cookies"}


def test_query_string_access(app, app_with_locations):
Expand All @@ -83,7 +84,7 @@ def test_query_string_access(app, app_with_locations):
url = "/protected?jwt={}".format(access_token)
response = test_client.get(url)
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}
assert response.get_json() == {"foo": "bar", "location": "query_string"}


def test_json_access(app, app_with_locations):
Expand All @@ -94,7 +95,7 @@ def test_json_access(app, app_with_locations):
data = {"access_token": access_token}
response = test_client.post("/protected", json=data)
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}
assert response.get_json() == {"foo": "bar", "location": "json"}


@pytest.mark.parametrize(
Expand Down Expand Up @@ -129,8 +130,8 @@ def test_no_jwt_in_request(app, options):
@pytest.mark.parametrize(
"options",
[
(["cookies", "headers"], 200, None, {"foo": "bar"}),
(["headers", "cookies"], 200, None, {"foo": "bar"}),
(["cookies", "headers"], 200, None, {"foo": "bar", "location": "cookies"}),
(["headers", "cookies"], 200, None, {"foo": "bar", "location": "cookies"}),
],
)
def test_order_of_jwt_locations_in_request(app, options):
Expand All @@ -151,7 +152,7 @@ def test_order_of_jwt_locations_in_request(app, options):
@pytest.mark.parametrize(
"options",
[
(["cookies", "headers"], 200, None, {"foo": "bar"}),
(["cookies", "headers"], 200, None, {"foo": "bar", "location": "cookies"}),
(["headers", "cookies"], 422, ("Invalid header padding"), None),
],
)
Expand Down

0 comments on commit ef3da3c

Please sign in to comment.