diff --git a/supertokens_python/post_init_callbacks.py b/supertokens_python/post_init_callbacks.py index 982e1b74..ddbf0afa 100644 --- a/supertokens_python/post_init_callbacks.py +++ b/supertokens_python/post_init_callbacks.py @@ -29,3 +29,7 @@ def run_post_init_callbacks() -> None: for cb in PostSTInitCallbacks.post_init_callbacks: cb() PostSTInitCallbacks.post_init_callbacks = [] + + @staticmethod + def reset(): + PostSTInitCallbacks.post_init_callbacks = [] diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 09cc9d24..bd08f167 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -5,6 +5,7 @@ from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig from supertokens_python.ingredients.smsdelivery.types import SMSDeliveryConfig +from supertokens_python.post_init_callbacks import PostSTInitCallbacks from supertokens_python.recipe import ( accountlinking, dashboard, @@ -202,6 +203,7 @@ async def default_func( # pylint: disable=unused-argument def st_reset(): + PostSTInitCallbacks.reset() override_logging.reset_override_logs() reset_override_params() ProcessState.get_instance().reset() diff --git a/tests/test-server/override_logging.py b/tests/test-server/override_logging.py index b29cf1b4..03a84328 100644 --- a/tests/test-server/override_logging.py +++ b/tests/test-server/override_logging.py @@ -1,3 +1,4 @@ +import json from typing import Any, Callable, Coroutine, Dict, List, Set, Union import time @@ -151,5 +152,15 @@ def transform_logged_data(data: Any, visited: Union[Set[Any], None] = None) -> A return data.to_json() if isinstance(data, IsVerifiedSCV): return "IsVerifiedSCV" + if is_jsonable(data): + return data - return data + return "Some custom object" + + +def is_jsonable(x: Any) -> bool: + try: + json.dumps(x) + return True + except (TypeError, OverflowError): + return False diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index df95a83a..52fc98d9 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -1,6 +1,7 @@ from typing import Callable, List, Union from typing import Dict, Any, Optional from supertokens_python.asyncio import list_users_by_account_info +from supertokens_python.auth_utils import LinkingToSessionUserFailedError from supertokens_python.recipe.accountlinking import ( RecipeLevelUser, ShouldAutomaticallyLink, @@ -35,6 +36,7 @@ from supertokens_python.recipe.session.claims import PrimitiveClaim from supertokens_python.recipe.thirdparty.interfaces import ( SignInUpNotAllowed, + SignInUpOkResult, SignInUpPostNoEmailGivenByProviderResponse, SignInUpPostOkResult, ) @@ -309,6 +311,54 @@ def func1( return func1 + if eval_str.startswith("thirdparty.init.override.functions"): + if "setIsVerifiedInSignInUp" in eval_str: + from supertokens_python.recipe.thirdparty.interfaces import ( + RecipeInterface as ThirdPartyRecipeInterface, + ) + + def custom_override( + original_implementation: ThirdPartyRecipeInterface, + ) -> ThirdPartyRecipeInterface: + og_sign_in_up = original_implementation.sign_in_up + + async def sign_in_up( + third_party_id: str, + third_party_user_id: str, + email: str, + is_verified: bool, + oauth_tokens: Dict[str, Any], + raw_user_info_from_provider: RawUserInfoFromProvider, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[ + SignInUpOkResult, + SignInUpNotAllowed, + LinkingToSessionUserFailedError, + ]: + user_context[ + "isVerified" + ] = is_verified # this information comes from the third party provider + return await og_sign_in_up( + third_party_id, + third_party_user_id, + email, + is_verified, + oauth_tokens, + raw_user_info_from_provider, + session, + should_try_linking_with_session_user, + tenant_id, + user_context, + ) + + original_implementation.sign_in_up = sign_in_up + return original_implementation + + return custom_override + elif eval_str.startswith("passwordless.init.smsDelivery.service.sendSms"): def func2(