Skip to content

Commit

Permalink
Fix error messaging for edge case with current_user
Browse files Browse the repository at this point in the history
  • Loading branch information
Landon Gilbert-Bland committed May 2, 2021
1 parent ef3da3c commit a0f206e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
5 changes: 3 additions & 2 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_jwt():
decoded_jwt = getattr(_request_ctx_stack.top, "jwt", None)
if decoded_jwt is None:
raise RuntimeError(
"You must call `@jwt_required()` or `verify_jwt_in_request` "
"You must call `@jwt_required()` or `verify_jwt_in_request()` "
"before using this method"
)
return decoded_jwt
Expand All @@ -40,7 +40,7 @@ def get_jwt_header():
decoded_header = getattr(_request_ctx_stack.top, "jwt_header", None)
if decoded_header is None:
raise RuntimeError(
"You must call `@jwt_required()` or `verify_jwt_in_request` "
"You must call `@jwt_required()` or `verify_jwt_in_request()` "
"before using this method"
)
return decoded_header
Expand Down Expand Up @@ -87,6 +87,7 @@ def get_current_user():
:return:
The current user object for the JWT in the current request
"""
get_jwt() # Raise an error if not in a decorated context
jwt_user_dict = getattr(_request_ctx_stack.top, "jwt_user", None)
if jwt_user_dict is None:
raise RuntimeError(
Expand Down
20 changes: 18 additions & 2 deletions tests/test_user_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,32 @@ def app():
@app.route("/get_user1", methods=["GET"])
@jwt_required()
def get_user1():
return jsonify(foo=get_current_user()["username"])
try:
return jsonify(foo=get_current_user()["username"])
except RuntimeError as e:
return jsonify(error=str(e))

@app.route("/get_user2", methods=["GET"])
@jwt_required()
def get_user2():
return jsonify(foo=current_user["username"])
try:
return jsonify(foo=current_user["username"])
except RuntimeError as e:
return jsonify(error=str(e))

return app


@pytest.mark.parametrize("url", ["/get_user1", "/get_user2"])
def test_no_user_lookup_loader_specified(app, url):
test_client = app.test_client()
with app.test_request_context():
access_token = create_access_token("username")

response = test_client.get(url, headers=make_headers(access_token))
assert "@jwt.user_lookup_loader" in response.get_json()["error"]


@pytest.mark.parametrize("url", ["/get_user1", "/get_user2"])
def test_load_valid_user(app, url):
jwt = get_jwt_manager(app)
Expand Down

0 comments on commit a0f206e

Please sign in to comment.