Skip to content

Commit

Permalink
Allow locations kwarg for jwt_required() to be a string
Browse files Browse the repository at this point in the history
Fixes #394
  • Loading branch information
Landon Gilbert-Bland committed Mar 9, 2021
1 parent ead8f60 commit c12eefc
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 16 deletions.
38 changes: 22 additions & 16 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=
If ``True``, require a refresh JWT to be verified.
:param locations:
A list of locations to look for the JWT in this request, for example:
``['headers', 'cookies']``. Defaluts to ``None`` which indicates that JWTs
will be looked for in the locations defined by the ``JWT_TOKEN_LOCATION``
configuration option.
A location or list of locations to look for the JWT in this request, for
example ``'headers'`` or ``['headers', 'cookies']``. Defaluts to ``None``
which indicates that JWTs will be looked for in the locations defined by the
``JWT_TOKEN_LOCATION`` configuration option.
"""
if request.method in config.exempt_methods:
return
Expand Down Expand Up @@ -103,10 +103,10 @@ def jwt_required(optional=False, fresh=False, refresh=False, locations=None):
requires an access JWT to access this endpoint. Defaults to ``False``.
:param locations:
A list of locations to look for the JWT in this request, for example:
``['headers', 'cookies']``. Defaluts to ``None`` which indicates that JWTs
will be looked for in the locations defined by the ``JWT_TOKEN_LOCATION``
configuration option.
A location or list of locations to look for the JWT in this request, for
example ``'headers'`` or ``['headers', 'cookies']``. Defaluts to ``None``
which indicates that JWTs will be looked for in the locations defined by the
``JWT_TOKEN_LOCATION`` configuration option.
"""

def wrapper(fn):
Expand Down Expand Up @@ -226,27 +226,33 @@ def _decode_jwt_from_json(refresh):
return encoded_token, None


def _invalid_location(location):
raise NoAuthorizationError(f"'{location}' is not a valid location")


def _decode_jwt_from_request(locations, fresh, refresh=False):
# All the places we can get a JWT from in this request
get_encoded_token_functions = []
# Figure out what locations to look for the JWT in this request
if isinstance(locations, str):
locations = [locations]

# Get locations in the order specified by the decorator or JWT_TOKEN_LOCATION
# configuration.
if not locations:
locations = config.token_location

# Add the functions in the order specified by locations.
# Get the decode functions in the order specified by locations.
get_encoded_token_functions = []
for location in locations:
if location == "cookies":
get_encoded_token_functions.append(
lambda: _decode_jwt_from_cookies(refresh)
)
if location == "query_string":
elif location == "query_string":
get_encoded_token_functions.append(_decode_jwt_from_query_string)
if location == "headers":
elif location == "headers":
get_encoded_token_functions.append(_decode_jwt_from_headers)
if location == "json":
elif location == "json":
get_encoded_token_functions.append(lambda: _decode_jwt_from_json(refresh))
else:
get_encoded_token_functions.append(lambda: _invalid_location(location))

# Try to find the token from one of these locations. It only needs to exist
# in one place to be valid (not every location).
Expand Down
33 changes: 33 additions & 0 deletions tests/test_view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,39 @@ def test_jwt_optional(app, delta_func):
assert response.get_json() == {"msg": "Token has expired"}


def test_override_jwt_location(app):
app.config["JWT_TOKEN_LOCATION"] = ["cookies"]

@app.route("/protected_other")
@jwt_required(locations="headers")
def protected_other():
return jsonify(foo="bar")

@app.route("/protected_invalid")
@jwt_required(locations="INVALID_LOCATION")
def protected_invalid():
return jsonify(foo="bar")

test_client = app.test_client()
with app.test_request_context():
access_token = create_access_token("username")

url = "/protected_other"
response = test_client.get(url, headers=make_headers(access_token))
assert response.get_json() == {"foo": "bar"}
assert response.status_code == 200

url = "/protected"
response = test_client.get(url, headers=make_headers(access_token))
assert response.status_code == 401
assert response.get_json() == {"msg": 'Missing cookie "access_token_cookie"'}

url = "/protected_invalid"
response = test_client.get(url, headers=make_headers(access_token))
assert response.status_code == 401
assert response.get_json() == {"msg": "'INVALID_LOCATION' is not a valid location"}


def test_invalid_jwt(app):
url = "/protected"
jwtM = get_jwt_manager(app)
Expand Down

0 comments on commit c12eefc

Please sign in to comment.