Skip to content

Commit

Permalink
Compatibility changes for upcoming flask version 2.3 (#493)
Browse files Browse the repository at this point in the history
* Switched from _request_ctx_stack.top to flask.g

* Handle JSONEncoder changes
  • Loading branch information
jrast authored Aug 15, 2022
1 parent 37634ed commit 88a628e
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 32 deletions.
5 changes: 3 additions & 2 deletions flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from json import JSONEncoder
from typing import Iterable
from typing import List
from typing import Optional
Expand All @@ -9,9 +10,9 @@
from typing import Union

from flask import current_app
from flask.json import JSONEncoder
from jwt.algorithms import requires_cryptography

from flask_jwt_extended.internal_utils import get_json_encoder
from flask_jwt_extended.typing import ExpiresDelta


Expand Down Expand Up @@ -284,7 +285,7 @@ def error_msg_key(self) -> str:

@property
def json_encoder(self) -> Type[JSONEncoder]:
return current_app.json_encoder
return get_json_encoder(current_app)

@property
def decode_audience(self) -> Union[str, Iterable[str]]:
Expand Down
45 changes: 45 additions & 0 deletions flask_jwt_extended/internal_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
import json
from typing import Any
from typing import Type
from typing import TYPE_CHECKING

from flask import current_app
from flask import Flask

from flask_jwt_extended.exceptions import RevokedTokenError
from flask_jwt_extended.exceptions import UserClaimsVerificationError
from flask_jwt_extended.exceptions import WrongTokenError

try:
from flask.json.provider import DefaultJSONProvider

HAS_JSON_PROVIDER = True
except ModuleNotFoundError: # pragma: no cover
# The flask.json.provider module was added in Flask 2.2.
# Further details are handled in get_json_encoder.
HAS_JSON_PROVIDER = False


if TYPE_CHECKING: # pragma: no cover
from flask_jwt_extended import JWTManager

Expand Down Expand Up @@ -51,3 +64,35 @@ def custom_verification_for_token(jwt_header: dict, jwt_data: dict) -> None:
if not jwt_manager._token_verification_callback(jwt_header, jwt_data):
error_msg = "User claims verification failed"
raise UserClaimsVerificationError(error_msg, jwt_header, jwt_data)


class JSONEncoder(json.JSONEncoder):
"""A JSON encoder which uses the app.json_provider_class for the default"""

def default(self, o: Any) -> Any:
# If the registered JSON provider does not implement a default classmethod
# use the method defined by the DefaultJSONProvider
default = getattr(
current_app.json_provider_class, "default", DefaultJSONProvider.default
)
return default(o)


def get_json_encoder(app: Flask) -> Type[json.JSONEncoder]:
"""Get the JSON Encoder for the provided flask app
Starting with flask version 2.2 the flask application provides a
interface to register a custom JSON Encoder/Decoder under the json_provider_class.
As this interface is not compatible with the standard JSONEncoder, the `default`
method of the class is wrapped.
Lookup Order:
- app.json_encoder - For Flask < 2.2
- app.json_provider_class.default
- flask.json.provider.DefaultJSONProvider.default
"""
if not HAS_JSON_PROVIDER: # pragma: no cover
return app.json_encoder

return JSONEncoder
2 changes: 1 addition & 1 deletion flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from datetime import timedelta
from datetime import timezone
from hmac import compare_digest
from json import JSONEncoder
from typing import Any
from typing import Iterable
from typing import List
from typing import Type
from typing import Union

import jwt
from flask.json import JSONEncoder

from flask_jwt_extended.exceptions import CSRFError
from flask_jwt_extended.exceptions import JWTDecodeError
Expand Down
10 changes: 5 additions & 5 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional

import jwt
from flask import _request_ctx_stack
from flask import g
from flask import Response
from werkzeug.local import LocalProxy

Expand All @@ -23,7 +23,7 @@ def get_jwt() -> dict:
:return:
The payload (claims) of the JWT in the current request
"""
decoded_jwt = getattr(_request_ctx_stack.top, "jwt", None)
decoded_jwt = g.get("_jwt_extended_jwt", None)
if decoded_jwt is None:
raise RuntimeError(
"You must call `@jwt_required()` or `verify_jwt_in_request()` "
Expand All @@ -41,7 +41,7 @@ def get_jwt_header() -> dict:
:return:
The headers of the JWT in the current request
"""
decoded_header = getattr(_request_ctx_stack.top, "jwt_header", None)
decoded_header = g.get("_jwt_extended_jwt_header", None)
if decoded_header is None:
raise RuntimeError(
"You must call `@jwt_required()` or `verify_jwt_in_request()` "
Expand Down Expand Up @@ -73,7 +73,7 @@ def get_jwt_request_location() -> Optional[str]:
The location of the JWT in the current request; e.g., "cookies",
"query-string", "headers", or "json"
"""
return getattr(_request_ctx_stack.top, "jwt_location", None)
return g.get("_jwt_extended_jwt_location", None)


def get_current_user() -> Any:
Expand All @@ -91,7 +91,7 @@ def get_current_user() -> Any:
The current user object for the JWT in the current request
"""
get_jwt() # Raise an error if not in a decorated context
jwt_user_dict = getattr(_request_ctx_stack.top, "jwt_user", None)
jwt_user_dict = g.get("_jwt_extended_jwt_user", None)
if jwt_user_dict is None:
raise RuntimeError(
"You must provide a `@jwt.user_lookup_loader` callback to use "
Expand Down
22 changes: 9 additions & 13 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from typing import Tuple
from typing import Union

from flask import _request_ctx_stack
from flask import current_app
from flask import g
from flask import request
from werkzeug.exceptions import BadRequest

Expand Down Expand Up @@ -85,10 +85,6 @@ def verify_jwt_in_request(
if request.method in config.exempt_methods:
return None

# Should be impossible to hit, this makes mypy checks happy
if not _request_ctx_stack.top: # pragma: no cover
raise RuntimeError("No _request_ctx_stack.top present, aborting")

try:
jwt_data, jwt_header, jwt_location = _decode_jwt_from_request(
locations, fresh, refresh=refresh, verify_type=verify_type
Expand All @@ -97,18 +93,18 @@ def verify_jwt_in_request(
except NoAuthorizationError:
if not optional:
raise
_request_ctx_stack.top.jwt = {}
_request_ctx_stack.top.jwt_header = {}
_request_ctx_stack.top.jwt_user = {"loaded_user": None}
_request_ctx_stack.top.jwt_location = None
g._jwt_extended_jwt = {}
g._jwt_extended_jwt_header = {}
g._jwt_extended_jwt_user = {"loaded_user": None}
g._jwt_extended_jwt_location = None
return None

# Save these at the very end so that they are only saved in the requet
# context if the token is valid and all callbacks succeed
_request_ctx_stack.top.jwt_user = _load_user(jwt_header, jwt_data)
_request_ctx_stack.top.jwt_header = jwt_header
_request_ctx_stack.top.jwt = jwt_data
_request_ctx_stack.top.jwt_location = jwt_location
g._jwt_extended_jwt_user = _load_user(jwt_header, jwt_data)
g._jwt_extended_jwt_header = jwt_header
g._jwt_extended_jwt = jwt_data
g._jwt_extended_jwt_location = jwt_location

return jwt_header, jwt_data

Expand Down
37 changes: 27 additions & 10 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import json
from datetime import date
from datetime import timedelta

import pytest
from dateutil.relativedelta import relativedelta
from flask import __version__ as flask_version
from flask import Flask
from flask.json import JSONEncoder

from flask_jwt_extended import JWTManager
from flask_jwt_extended.config import config
from flask_jwt_extended.internal_utils import JSONEncoder


flask_version_tuple = tuple(map(int, flask_version.split(".")))


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -65,8 +71,6 @@ def test_default_configs(app):

assert config.identity_claim_key == "sub"

assert config.json_encoder is app.json_encoder

assert config.error_msg_key == "msg"


Expand Down Expand Up @@ -112,11 +116,6 @@ def test_override_configs(app, delta_func):

app.config["JWT_ERROR_MESSAGE_KEY"] = "message"

class CustomJSONEncoder(JSONEncoder):
pass

app.json_encoder = CustomJSONEncoder

with app.test_request_context():
assert config.token_location == ["cookies", "query_string", "json"]
assert config.jwt_in_query_string is True
Expand Down Expand Up @@ -162,11 +161,29 @@ class CustomJSONEncoder(JSONEncoder):

assert config.identity_claim_key == "foo"

assert config.json_encoder is CustomJSONEncoder

assert config.error_msg_key == "message"


@pytest.mark.skipif(
flask_version_tuple >= (2, 2, 0), reason="Only applies to Flask <= 2.2.0"
)
def test_config_json_encoder_flask21(app):
with app.test_request_context():
assert config.json_encoder == app.json_encoder
dump = json.dumps({"d": date(2022, 8, 12)}, cls=config.json_encoder)
assert dump == '{"d": "Fri, 12 Aug 2022 00:00:00 GMT"}'


@pytest.mark.skipif(
flask_version_tuple < (2, 2, 0), reason="Only applies to Flask > 2.2.0"
)
def test_config_json_encoder_flask(app):
with app.test_request_context():
assert config.json_encoder == JSONEncoder
dump = json.dumps({"d": date(2022, 8, 12)}, cls=config.json_encoder)
assert dump == '{"d": "Fri, 12 Aug 2022 00:00:00 GMT"}'


def test_tokens_never_expire(app):
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = False
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = False
Expand Down
4 changes: 3 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# and then run "tox" from this directory.

[tox]
envlist = py37,py38,py39,py310,pypy3.9,mypy,coverage,style,docs
envlist = py{37,38,39,310}-{flask21,flask},pypy3.9,mypy,coverage,style,docs

[testenv]
commands =
Expand All @@ -13,6 +13,8 @@ deps =
pytest
cryptography
python-dateutil
flask21: Flask>=2.1,<2.2
flask: Flask>=2.2

[testenv:mypy]
commands =
Expand Down

0 comments on commit 88a628e

Please sign in to comment.