Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support JWT types other than refresh and access #401

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions flask_jwt_extended/internal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ def user_lookup(*args, **kwargs):
return jwt_manager._user_lookup_callback(*args, **kwargs)


def verify_token_type(decoded_token, expected_type):
if decoded_token["type"] != expected_type:
raise WrongTokenError("Only {} tokens are allowed".format(expected_type))
def verify_token_type(decoded_token, refresh):
if not refresh and decoded_token["type"] == "refresh":
raise WrongTokenError("Only non-refresh tokens are allowed")
elif refresh and decoded_token["type"] != "refresh":
raise WrongTokenError("Only refresh tokens are allowed")


def verify_token_not_blocklisted(jwt_header, jwt_data, request_type):
def verify_token_not_blocklisted(jwt_header, jwt_data):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks for cleaning this up 👍

jwt_manager = get_jwt_manager()
if jwt_manager._token_in_blocklist_callback(jwt_header, jwt_data):
raise RevokedTokenError(jwt_header, jwt_data)
Expand Down
3 changes: 0 additions & 3 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,6 @@ def _decode_jwt(
if "type" not in decoded_token:
decoded_token["type"] = "access"

if decoded_token["type"] not in ("access", "refresh"):
raise JWTDecodeError("Invalid token type: {}".format(decoded_token["type"]))

if "fresh" not in decoded_token:
decoded_token["fresh"] = False

Expand Down
41 changes: 20 additions & 21 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=
Defaults to ``False``.

:param refresh:
If ``True``, require a refresh JWT to be verified. If ``False`` require an access
JWT to be verified. Defaults to ``False``.
If ``True``, require a refresh JWT to be verified.

:param locations:
A list of locations to look for the JWT in this request, for example:
Expand All @@ -61,9 +60,11 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=

try:
if refresh:
jwt_data, jwt_header = _decode_jwt_from_request("refresh", locations, fresh)
jwt_data, jwt_header = _decode_jwt_from_request(
locations, fresh, refresh=True
)
else:
jwt_data, jwt_header = _decode_jwt_from_request("access", locations, fresh)
jwt_data, jwt_header = _decode_jwt_from_request(locations, fresh)
except (NoAuthorizationError, InvalidHeaderError):
if not optional:
raise
Expand Down Expand Up @@ -170,15 +171,15 @@ def _decode_jwt_from_headers():
return encoded_token, None


def _decode_jwt_from_cookies(token_type):
if token_type == "access":
cookie_key = config.access_cookie_name
csrf_header_key = config.access_csrf_header_name
csrf_field_key = config.access_csrf_field_name
else:
def _decode_jwt_from_cookies(refresh):
if refresh:
cookie_key = config.refresh_cookie_name
csrf_header_key = config.refresh_csrf_header_name
csrf_field_key = config.refresh_csrf_field_name
else:
cookie_key = config.access_cookie_name
csrf_header_key = config.access_csrf_header_name
csrf_field_key = config.access_csrf_field_name

encoded_token = request.cookies.get(cookie_key)
if not encoded_token:
Expand All @@ -205,15 +206,15 @@ def _decode_jwt_from_query_string():
return encoded_token, None


def _decode_jwt_from_json(token_type):
def _decode_jwt_from_json(refresh):
content_type = request.content_type or ""
if not content_type.startswith("application/json"):
raise NoAuthorizationError("Invalid content-type. Must be application/json.")

if token_type == "access":
token_key = config.json_key
else:
if refresh:
token_key = config.refresh_json_key
else:
token_key = config.json_key

try:
encoded_token = request.json.get(token_key, None)
Expand All @@ -225,7 +226,7 @@ def _decode_jwt_from_json(token_type):
return encoded_token, None


def _decode_jwt_from_request(token_type, locations, fresh):
def _decode_jwt_from_request(locations, fresh, refresh=False):
# All the places we can get a JWT from in this request
get_encoded_token_functions = []

Expand All @@ -238,16 +239,14 @@ def _decode_jwt_from_request(token_type, locations, fresh):
for location in locations:
if location == "cookies":
get_encoded_token_functions.append(
lambda: _decode_jwt_from_cookies(token_type)
lambda: _decode_jwt_from_cookies(refresh)
)
if location == "query_string":
get_encoded_token_functions.append(_decode_jwt_from_query_string)
if location == "headers":
get_encoded_token_functions.append(_decode_jwt_from_headers)
if location == "json":
get_encoded_token_functions.append(
lambda: _decode_jwt_from_json(token_type)
)
get_encoded_token_functions.append(lambda: _decode_jwt_from_json(refresh))

# Try to find the token from one of these locations. It only needs to exist
# in one place to be valid (not every location).
Expand Down Expand Up @@ -277,10 +276,10 @@ def _decode_jwt_from_request(token_type, locations, fresh):
raise NoAuthorizationError(errors[0])

# Additional verifications provided by this extension
verify_token_type(decoded_token, expected_type=token_type)
verify_token_type(decoded_token, refresh)
if fresh:
_verify_token_is_fresh(jwt_header, decoded_token)
verify_token_not_blocklisted(jwt_header, decoded_token, token_type)
verify_token_not_blocklisted(jwt_header, decoded_token)
custom_verification_for_token(jwt_header, decoded_token)

return decoded_token, jwt_header
12 changes: 6 additions & 6 deletions tests/test_decode_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ def test_default_decode_token_values(app, default_access_token):
assert decoded["fresh"] is False


def test_bad_token_type(app, default_access_token):
default_access_token["type"] = "banana"
bad_type_token = encode_token(app, default_access_token)
def test_supports_decoding_other_token_types(app, default_access_token):
default_access_token["type"] = "app"
other_token = encode_token(app, default_access_token)

with pytest.raises(JWTDecodeError):
with app.test_request_context():
decode_token(bad_type_token)
with app.test_request_context():
decoded = decode_token(other_token)
assert decoded["type"] == "app"


def test_encode_decode_audience(app):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_jwt_required(app):
# Test refresh token access to jwt_required
response = test_client.get(url, headers=make_headers(refresh_token))
assert response.status_code == 422
assert response.get_json() == {"msg": "Only access tokens are allowed"}
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"}


def test_fresh_jwt_required(app):
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_fresh_jwt_required(app):

response = test_client.get(url, headers=make_headers(refresh_token))
assert response.status_code == 422
assert response.get_json() == {"msg": "Only access tokens are allowed"}
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"}

# Test with custom response
@jwtM.needs_fresh_token_loader
Expand Down Expand Up @@ -176,7 +176,7 @@ def test_jwt_optional(app, delta_func):

response = test_client.get(url, headers=make_headers(refresh_token))
assert response.status_code == 422
assert response.get_json() == {"msg": "Only access tokens are allowed"}
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"}

response = test_client.get(url, headers=None)
assert response.status_code == 200
Expand Down