diff --git a/flask_jwt_extended/config.py b/flask_jwt_extended/config.py index 05322d8d..9089b91f 100644 --- a/flask_jwt_extended/config.py +++ b/flask_jwt_extended/config.py @@ -1,6 +1,7 @@ from datetime import datetime from datetime import timedelta from datetime import timezone +from json import JSONEncoder from typing import Iterable from typing import List from typing import Optional @@ -9,9 +10,9 @@ from typing import Union from flask import current_app -from flask.json import JSONEncoder from jwt.algorithms import requires_cryptography +from flask_jwt_extended.internal_utils import get_json_encoder from flask_jwt_extended.typing import ExpiresDelta @@ -284,7 +285,7 @@ def error_msg_key(self) -> str: @property def json_encoder(self) -> Type[JSONEncoder]: - return current_app.json_encoder + return get_json_encoder(current_app) @property def decode_audience(self) -> Union[str, Iterable[str]]: diff --git a/flask_jwt_extended/internal_utils.py b/flask_jwt_extended/internal_utils.py index 3220ff1c..d3b7fb05 100644 --- a/flask_jwt_extended/internal_utils.py +++ b/flask_jwt_extended/internal_utils.py @@ -1,12 +1,25 @@ +import json from typing import Any +from typing import Type from typing import TYPE_CHECKING from flask import current_app +from flask import Flask from flask_jwt_extended.exceptions import RevokedTokenError from flask_jwt_extended.exceptions import UserClaimsVerificationError from flask_jwt_extended.exceptions import WrongTokenError +try: + from flask.json.provider import DefaultJSONProvider + + HAS_JSON_PROVIDER = True +except ModuleNotFoundError: # pragma: no cover + # The flask.json.provider module was added in Flask 2.2. + # Further details are handled in get_json_encoder. + HAS_JSON_PROVIDER = False + + if TYPE_CHECKING: # pragma: no cover from flask_jwt_extended import JWTManager @@ -51,3 +64,35 @@ def custom_verification_for_token(jwt_header: dict, jwt_data: dict) -> None: if not jwt_manager._token_verification_callback(jwt_header, jwt_data): error_msg = "User claims verification failed" raise UserClaimsVerificationError(error_msg, jwt_header, jwt_data) + + +class JSONEncoder(json.JSONEncoder): + """A JSON encoder which uses the app.json_provider_class for the default""" + + def default(self, o: Any) -> Any: + # If the registered JSON provider does not implement a default classmethod + # use the method defined by the DefaultJSONProvider + default = getattr( + current_app.json_provider_class, "default", DefaultJSONProvider.default + ) + return default(o) + + +def get_json_encoder(app: Flask) -> Type[json.JSONEncoder]: + """Get the JSON Encoder for the provided flask app + + Starting with flask version 2.2 the flask application provides a + interface to register a custom JSON Encoder/Decoder under the json_provider_class. + As this interface is not compatible with the standard JSONEncoder, the `default` + method of the class is wrapped. + + Lookup Order: + - app.json_encoder - For Flask < 2.2 + - app.json_provider_class.default + - flask.json.provider.DefaultJSONProvider.default + + """ + if not HAS_JSON_PROVIDER: # pragma: no cover + return app.json_encoder + + return JSONEncoder diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index 260d665b..9a600a8e 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -3,6 +3,7 @@ from datetime import timedelta from datetime import timezone from hmac import compare_digest +from json import JSONEncoder from typing import Any from typing import Iterable from typing import List @@ -10,7 +11,6 @@ from typing import Union import jwt -from flask.json import JSONEncoder from flask_jwt_extended.exceptions import CSRFError from flask_jwt_extended.exceptions import JWTDecodeError diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 4f1004a2..da01f5b2 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -3,7 +3,7 @@ from typing import Optional import jwt -from flask import _request_ctx_stack +from flask import g from flask import Response from werkzeug.local import LocalProxy @@ -23,7 +23,7 @@ def get_jwt() -> dict: :return: The payload (claims) of the JWT in the current request """ - decoded_jwt = getattr(_request_ctx_stack.top, "jwt", None) + decoded_jwt = g.get("_jwt_extended_jwt", None) if decoded_jwt is None: raise RuntimeError( "You must call `@jwt_required()` or `verify_jwt_in_request()` " @@ -41,7 +41,7 @@ def get_jwt_header() -> dict: :return: The headers of the JWT in the current request """ - decoded_header = getattr(_request_ctx_stack.top, "jwt_header", None) + decoded_header = g.get("_jwt_extended_jwt_header", None) if decoded_header is None: raise RuntimeError( "You must call `@jwt_required()` or `verify_jwt_in_request()` " @@ -73,7 +73,7 @@ def get_jwt_request_location() -> Optional[str]: The location of the JWT in the current request; e.g., "cookies", "query-string", "headers", or "json" """ - return getattr(_request_ctx_stack.top, "jwt_location", None) + return g.get("_jwt_extended_jwt_location", None) def get_current_user() -> Any: @@ -91,7 +91,7 @@ def get_current_user() -> Any: The current user object for the JWT in the current request """ get_jwt() # Raise an error if not in a decorated context - jwt_user_dict = getattr(_request_ctx_stack.top, "jwt_user", None) + jwt_user_dict = g.get("_jwt_extended_jwt_user", None) if jwt_user_dict is None: raise RuntimeError( "You must provide a `@jwt.user_lookup_loader` callback to use " diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 2c2bf826..c87b01a9 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -8,8 +8,8 @@ from typing import Tuple from typing import Union -from flask import _request_ctx_stack from flask import current_app +from flask import g from flask import request from werkzeug.exceptions import BadRequest @@ -85,10 +85,6 @@ def verify_jwt_in_request( if request.method in config.exempt_methods: return None - # Should be impossible to hit, this makes mypy checks happy - if not _request_ctx_stack.top: # pragma: no cover - raise RuntimeError("No _request_ctx_stack.top present, aborting") - try: jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( locations, fresh, refresh=refresh, verify_type=verify_type @@ -97,18 +93,18 @@ def verify_jwt_in_request( except NoAuthorizationError: if not optional: raise - _request_ctx_stack.top.jwt = {} - _request_ctx_stack.top.jwt_header = {} - _request_ctx_stack.top.jwt_user = {"loaded_user": None} - _request_ctx_stack.top.jwt_location = None + g._jwt_extended_jwt = {} + g._jwt_extended_jwt_header = {} + g._jwt_extended_jwt_user = {"loaded_user": None} + g._jwt_extended_jwt_location = None return None # Save these at the very end so that they are only saved in the requet # context if the token is valid and all callbacks succeed - _request_ctx_stack.top.jwt_user = _load_user(jwt_header, jwt_data) - _request_ctx_stack.top.jwt_header = jwt_header - _request_ctx_stack.top.jwt = jwt_data - _request_ctx_stack.top.jwt_location = jwt_location + g._jwt_extended_jwt_user = _load_user(jwt_header, jwt_data) + g._jwt_extended_jwt_header = jwt_header + g._jwt_extended_jwt = jwt_data + g._jwt_extended_jwt_location = jwt_location return jwt_header, jwt_data diff --git a/tests/test_config.py b/tests/test_config.py index 343d1e66..514d619e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,12 +1,18 @@ +import json +from datetime import date from datetime import timedelta import pytest from dateutil.relativedelta import relativedelta +from flask import __version__ as flask_version from flask import Flask -from flask.json import JSONEncoder from flask_jwt_extended import JWTManager from flask_jwt_extended.config import config +from flask_jwt_extended.internal_utils import JSONEncoder + + +flask_version_tuple = tuple(map(int, flask_version.split("."))) @pytest.fixture(scope="function") @@ -65,8 +71,6 @@ def test_default_configs(app): assert config.identity_claim_key == "sub" - assert config.json_encoder is app.json_encoder - assert config.error_msg_key == "msg" @@ -112,11 +116,6 @@ def test_override_configs(app, delta_func): app.config["JWT_ERROR_MESSAGE_KEY"] = "message" - class CustomJSONEncoder(JSONEncoder): - pass - - app.json_encoder = CustomJSONEncoder - with app.test_request_context(): assert config.token_location == ["cookies", "query_string", "json"] assert config.jwt_in_query_string is True @@ -162,11 +161,29 @@ class CustomJSONEncoder(JSONEncoder): assert config.identity_claim_key == "foo" - assert config.json_encoder is CustomJSONEncoder - assert config.error_msg_key == "message" +@pytest.mark.skipif( + flask_version_tuple >= (2, 2, 0), reason="Only applies to Flask <= 2.2.0" +) +def test_config_json_encoder_flask21(app): + with app.test_request_context(): + assert config.json_encoder == app.json_encoder + dump = json.dumps({"d": date(2022, 8, 12)}, cls=config.json_encoder) + assert dump == '{"d": "Fri, 12 Aug 2022 00:00:00 GMT"}' + + +@pytest.mark.skipif( + flask_version_tuple < (2, 2, 0), reason="Only applies to Flask > 2.2.0" +) +def test_config_json_encoder_flask(app): + with app.test_request_context(): + assert config.json_encoder == JSONEncoder + dump = json.dumps({"d": date(2022, 8, 12)}, cls=config.json_encoder) + assert dump == '{"d": "Fri, 12 Aug 2022 00:00:00 GMT"}' + + def test_tokens_never_expire(app): app.config["JWT_ACCESS_TOKEN_EXPIRES"] = False app.config["JWT_REFRESH_TOKEN_EXPIRES"] = False diff --git a/tox.ini b/tox.ini index 32c97fc8..50c26c36 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ # and then run "tox" from this directory. [tox] -envlist = py37,py38,py39,py310,pypy3.9,mypy,coverage,style,docs +envlist = py{37,38,39,310}-{flask21,flask},pypy3.9,mypy,coverage,style,docs [testenv] commands = @@ -13,6 +13,8 @@ deps = pytest cryptography python-dateutil + flask21: Flask>=2.1,<2.2 + flask: Flask>=2.2 [testenv:mypy] commands =