Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing in a single string location to the locations kwarg #402

Merged
merged 3 commits into from
Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/add_custom_data_claims.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Storing Data in Access Tokens
=============================
Storing Additional Data in JWTs
===============================
You may want to store additional information in the access token which you could
later access in the protected views. This can be done using the ``additional_claims``
argument with the :func:`~flask_jwt_extended.create_access_token` or
Expand Down
2 changes: 1 addition & 1 deletion examples/additional_data_in_access_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def login():
# In a protected view, get the claims you added to the jwt with the
# get_jwt() method
@app.route("/protected", methods=["GET"])
@jwt_required
@jwt_required()
def protected():
claims = get_jwt()
return jsonify(foo=claims["foo"])
Expand Down
2 changes: 1 addition & 1 deletion examples/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def login():


@app.route("/protected", methods=["GET"])
@jwt_required
@jwt_required()
def protected():
return jsonify(hello="world")

Expand Down
34 changes: 18 additions & 16 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=
If ``True``, require a refresh JWT to be verified.

:param locations:
A list of locations to look for the JWT in this request, for example:
``['headers', 'cookies']``. Defaluts to ``None`` which indicates that JWTs
will be looked for in the locations defined by the ``JWT_TOKEN_LOCATION``
configuration option.
A location or list of locations to look for the JWT in this request, for
example ``'headers'`` or ``['headers', 'cookies']``. Defaluts to ``None``
which indicates that JWTs will be looked for in the locations defined by the
``JWT_TOKEN_LOCATION`` configuration option.
"""
if request.method in config.exempt_methods:
return
Expand Down Expand Up @@ -103,10 +103,10 @@ def jwt_required(optional=False, fresh=False, refresh=False, locations=None):
requires an access JWT to access this endpoint. Defaults to ``False``.

:param locations:
A list of locations to look for the JWT in this request, for example:
``['headers', 'cookies']``. Defaluts to ``None`` which indicates that JWTs
will be looked for in the locations defined by the ``JWT_TOKEN_LOCATION``
configuration option.
A location or list of locations to look for the JWT in this request, for
example ``'headers'`` or ``['headers', 'cookies']``. Defaluts to ``None``
which indicates that JWTs will be looked for in the locations defined by the
``JWT_TOKEN_LOCATION`` configuration option.
"""

def wrapper(fn):
Expand Down Expand Up @@ -227,26 +227,28 @@ def _decode_jwt_from_json(refresh):


def _decode_jwt_from_request(locations, fresh, refresh=False):
# All the places we can get a JWT from in this request
get_encoded_token_functions = []
# Figure out what locations to look for the JWT in this request
if isinstance(locations, str):
locations = [locations]

# Get locations in the order specified by the decorator or JWT_TOKEN_LOCATION
# configuration.
if not locations:
locations = config.token_location

# Add the functions in the order specified by locations.
# Get the decode functions in the order specified by locations.
get_encoded_token_functions = []
for location in locations:
if location == "cookies":
get_encoded_token_functions.append(
lambda: _decode_jwt_from_cookies(refresh)
)
if location == "query_string":
elif location == "query_string":
get_encoded_token_functions.append(_decode_jwt_from_query_string)
if location == "headers":
elif location == "headers":
get_encoded_token_functions.append(_decode_jwt_from_headers)
if location == "json":
elif location == "json":
get_encoded_token_functions.append(lambda: _decode_jwt_from_json(refresh))
else:
raise RuntimeError(f"'{location}' is not a valid location")

# Try to find the token from one of these locations. It only needs to exist
# in one place to be valid (not every location).
Expand Down
32 changes: 32 additions & 0 deletions tests/test_view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,38 @@ def test_jwt_optional(app, delta_func):
assert response.get_json() == {"msg": "Token has expired"}


def test_override_jwt_location(app):
app.config["JWT_TOKEN_LOCATION"] = ["cookies"]

@app.route("/protected_other")
@jwt_required(locations="headers")
def protected_other():
return jsonify(foo="bar")

@app.route("/protected_invalid")
@jwt_required(locations="INVALID_LOCATION")
def protected_invalid():
return jsonify(foo="bar")

test_client = app.test_client()
with app.test_request_context():
access_token = create_access_token("username")

url = "/protected_other"
response = test_client.get(url, headers=make_headers(access_token))
assert response.get_json() == {"foo": "bar"}
assert response.status_code == 200

url = "/protected"
response = test_client.get(url, headers=make_headers(access_token))
assert response.status_code == 401
assert response.get_json() == {"msg": 'Missing cookie "access_token_cookie"'}

url = "/protected_invalid"
response = test_client.get(url, headers=make_headers(access_token))
assert response.status_code == 500


def test_invalid_jwt(app):
url = "/protected"
jwtM = get_jwt_manager(app)
Expand Down