From 051f3beb3cae6af5d6fb6d641a77d9553ab79265 Mon Sep 17 00:00:00 2001 From: Raymond Penners Date: Fri, 15 Dec 2023 14:44:16 +0100 Subject: [PATCH] feat(socialaccount): provider+view in render_authentication_error() --- allauth/socialaccount/adapter.py | 17 ++++++++++++--- allauth/socialaccount/helpers.py | 8 +++---- .../socialaccount/providers/draugiem/views.py | 21 +++++++------------ .../socialaccount/providers/facebook/views.py | 9 ++++---- .../socialaccount/providers/oauth/views.py | 19 ++++++++--------- .../socialaccount/providers/oauth2/views.py | 16 ++++++++------ .../socialaccount/providers/openid/views.py | 17 ++++++++------- allauth/socialaccount/providers/saml/views.py | 6 +++--- .../socialaccount/providers/steam/views.py | 4 ++-- .../socialaccount/providers/telegram/views.py | 2 +- 10 files changed, 65 insertions(+), 54 deletions(-) diff --git a/allauth/socialaccount/adapter.py b/allauth/socialaccount/adapter.py index fa2d4ee647..f3d6e849b7 100644 --- a/allauth/socialaccount/adapter.py +++ b/allauth/socialaccount/adapter.py @@ -54,10 +54,10 @@ def pre_social_login(self, request, sociallogin): """ pass - def authentication_error( + def on_authentication_error( self, request, - provider_id, + provider, error=None, exception=None, extra_context=None, @@ -69,7 +69,18 @@ def authentication_error( You can use this hook to intervene, e.g. redirect to an educational flow by raising an ImmediateHttpResponse. """ - pass + if hasattr(self, "authentication_error"): + warnings.warn( + "adapter.authentication_error() is deprecated, use adapter.on_authentication_error()" + ) + + self.authentication_error( + request, + provider.id, + error=error, + exception=exception, + extra_context=extra_context, + ) def new_user(self, request, sociallogin): """ diff --git a/allauth/socialaccount/helpers.py b/allauth/socialaccount/helpers.py index eec6233270..a2749b2005 100644 --- a/allauth/socialaccount/helpers.py +++ b/allauth/socialaccount/helpers.py @@ -116,7 +116,7 @@ def _login_social_account(request, sociallogin): def render_authentication_error( request, - provider_id, + provider, error=AuthError.UNKNOWN, exception=None, extra_context=None, @@ -124,9 +124,9 @@ def render_authentication_error( try: if extra_context is None: extra_context = {} - get_adapter().authentication_error( + get_adapter().on_authentication_error( request, - provider_id, + provider, error=error, exception=exception, extra_context=extra_context, @@ -137,7 +137,7 @@ def render_authentication_error( return HttpResponseRedirect(reverse("socialaccount_login_cancelled")) context = { "auth_error": { - "provider": provider_id, + "provider": provider, "code": error, "exception": exception, } diff --git a/allauth/socialaccount/providers/draugiem/views.py b/allauth/socialaccount/providers/draugiem/views.py index fd4b9e5b8e..e2f85cf52a 100644 --- a/allauth/socialaccount/providers/draugiem/views.py +++ b/allauth/socialaccount/providers/draugiem/views.py @@ -40,25 +40,22 @@ def login(request): @csrf_exempt def callback(request): + adapter = get_adapter() + provider = adapter.get_provider(request, DraugiemProvider.id) + if "dr_auth_status" not in request.GET: - return render_authentication_error( - request, DraugiemProvider.id, error=AuthError.UNKNOWN - ) + return render_authentication_error(request, provider, error=AuthError.UNKNOWN) if request.GET["dr_auth_status"] != "ok": - return render_authentication_error( - request, DraugiemProvider.id, error=AuthError.DENIED - ) + return render_authentication_error(request, provider, error=AuthError.DENIED) if "dr_auth_code" not in request.GET: - return render_authentication_error( - request, DraugiemProvider.id, error=AuthError.UNKNOWN - ) + return render_authentication_error(request, provider, error=AuthError.UNKNOWN) ret = None auth_exception = None try: - app = get_adapter().get_app(request, DraugiemProvider.id) + app = provider.app login = draugiem_complete_login(request, app, request.GET["dr_auth_code"]) login.state = SocialLogin.unstash_state(request) @@ -67,9 +64,7 @@ def callback(request): auth_exception = e if not ret: - ret = render_authentication_error( - request, DraugiemProvider.id, exception=auth_exception - ) + ret = render_authentication_error(request, provider, exception=auth_exception) return ret diff --git a/allauth/socialaccount/providers/facebook/views.py b/allauth/socialaccount/providers/facebook/views.py index 45f8deaaea..f2375336fb 100644 --- a/allauth/socialaccount/providers/facebook/views.py +++ b/allauth/socialaccount/providers/facebook/views.py @@ -82,6 +82,8 @@ def complete_login(self, request, app, access_token, **kwargs): class LoginByTokenView(View): def dispatch(self, request): + self.adapter = get_adapter() + self.provider = self.adapter.get_provider(request, FacebookProvider.id) try: return super().dispatch(request) except ( @@ -89,9 +91,7 @@ def dispatch(self, request): forms.ValidationError, PermissionDenied, ) as exc: - return render_authentication_error( - request, FacebookProvider.id, exception=exc - ) + return render_authentication_error(request, self.provider, exception=exc) def get(self, request): # If we leave out get().get() it will return a response with a 405, but @@ -103,8 +103,7 @@ def post(self, request): if not form.is_valid(): raise forms.ValidationError() - adapter = get_adapter() - provider = adapter.get_provider(request, FacebookProvider.id) + provider = self.provider login_options = provider.get_fb_login_options(request) app = provider.app access_token = form.cleaned_data["access_token"] diff --git a/allauth/socialaccount/providers/oauth/views.py b/allauth/socialaccount/providers/oauth/views.py index eee0583b51..86fa88b248 100644 --- a/allauth/socialaccount/providers/oauth/views.py +++ b/allauth/socialaccount/providers/oauth/views.py @@ -82,9 +82,7 @@ def login(self, request, *args, **kwargs): return client.get_redirect(auth_url, auth_params) except OAuthError as e: logger.error("OAuth authentication error", exc_info=True) - return render_authentication_error( - request, self.adapter.provider_id, exception=e - ) + return render_authentication_error(request, provider, exception=e) class OAuthCallbackView(OAuthView): @@ -93,6 +91,7 @@ def dispatch(self, request): View to handle final steps of OAuth based authentication where the user gets redirected back to from the service provider """ + provider = self.adapter.get_provider() login_done_url = reverse(self.adapter.provider_id + "_callback") client = self._get_client(request, login_done_url) if not client.is_valid(): @@ -100,14 +99,16 @@ def dispatch(self, request): error = AuthError.CANCELLED else: error = AuthError.UNKNOWN - extra_context = dict(oauth_client=client) return render_authentication_error( request, - self.adapter.provider_id, + provider, error=error, - extra_context=extra_context, + extra_context={ + "oauth_client": client, + "callback_view": self, + }, ) - app = self.adapter.get_provider().app + app = provider.app try: access_token = client.get_access_token() token = SocialToken( @@ -123,6 +124,4 @@ def dispatch(self, request): login.state = SocialLogin.unstash_state(request) return complete_social_login(request, login) except OAuthError as e: - return render_authentication_error( - request, self.adapter.provider_id, exception=e - ) + return render_authentication_error(request, provider, exception=e) diff --git a/allauth/socialaccount/providers/oauth2/views.py b/allauth/socialaccount/providers/oauth2/views.py index 1cbfd3598d..ac78c0ae7c 100644 --- a/allauth/socialaccount/providers/oauth2/views.py +++ b/allauth/socialaccount/providers/oauth2/views.py @@ -125,11 +125,12 @@ def login(self, request, *args, **kwargs): try: return HttpResponseRedirect(client.get_redirect_url(auth_url, auth_params)) except OAuth2Error as e: - return render_authentication_error(request, provider.id, exception=e) + return render_authentication_error(request, provider, exception=e) class OAuth2CallbackView(OAuth2View): def dispatch(self, request, *args, **kwargs): + provider = self.adapter.get_provider() if "error" in request.GET or "code" not in request.GET: # Distinguish cancel from error auth_error = request.GET.get("error", None) @@ -138,9 +139,14 @@ def dispatch(self, request, *args, **kwargs): else: error = AuthError.UNKNOWN return render_authentication_error( - request, self.adapter.provider_id, error=error + request, + provider, + error=error, + extra_context={ + "callback_view": self, + }, ) - app = self.adapter.get_provider().app + app = provider.app client = self.get_client(self.request, app) try: @@ -166,6 +172,4 @@ def dispatch(self, request, *args, **kwargs): RequestException, ProviderException, ) as e: - return render_authentication_error( - request, self.adapter.provider_id, exception=e - ) + return render_authentication_error(request, provider, exception=e) diff --git a/allauth/socialaccount/providers/openid/views.py b/allauth/socialaccount/providers/openid/views.py index 8b4fa2e659..f9db6f1a17 100644 --- a/allauth/socialaccount/providers/openid/views.py +++ b/allauth/socialaccount/providers/openid/views.py @@ -33,7 +33,11 @@ def _openid_consumer(request, provider, endpoint): class OpenIDLoginView(View): template_name = "openid/login.html" form_class = LoginForm - provider = OpenIDProvider + provider_class = OpenIDProvider + + def dispatch(self, request, *args, **kwargs): + self.provider = self.provider_class(request) + return super().dispatch(request, *args, **kwargs) def get(self, request): form = self.get_form() @@ -44,7 +48,7 @@ def get(self, request): return self.perform_openid_auth(form) except (UnicodeDecodeError, DiscoveryFailure) as e: # UnicodeDecodeError: necaris/python3-openid#1 - return render_authentication_error(request, self.provider.id, exception=e) + return render_authentication_error(request, self.provider, exception=e) def post(self, request): form = self.get_form() @@ -85,7 +89,7 @@ def perform_openid_auth(self, form): return form request = self.request - provider = self.provider(request) + provider = self.provider endpoint = form.cleaned_data["openid"] client = self.get_client(provider, endpoint) realm = self.get_realm(provider) @@ -99,7 +103,6 @@ def perform_openid_auth(self, form): ax = FetchRequest() for name in AXAttributes: ax.add(AttrInfo(name, required=True)) - provider = OpenIDProvider(request) server_settings = provider.get_server_settings(request.GET.get("openid")) extra_attributes = server_settings.get("extra_attributes", []) for _, name, required in extra_attributes: @@ -121,10 +124,10 @@ def perform_openid_auth(self, form): class OpenIDCallbackView(View): - provider = OpenIDProvider + provider_class = OpenIDProvider def get(self, request): - provider = self.provider(request) + provider = self.provider = self.provider_class(request) endpoint = request.GET.get("openid.op_endpoint", "") client = self.get_client(provider, endpoint) response = self.get_openid_response(client) @@ -146,7 +149,7 @@ def complete_login(self, login): return complete_social_login(self.request, login) def render_error(self, error): - return render_authentication_error(self.request, self.provider.id, error=error) + return render_authentication_error(self.request, self.provider, error=error) def get_client(self, provider, endpoint): return _openid_consumer(self.request, provider, endpoint) diff --git a/allauth/socialaccount/providers/saml/views.py b/allauth/socialaccount/providers/saml/views.py index 80596d6d1b..7adb359498 100644 --- a/allauth/socialaccount/providers/saml/views.py +++ b/allauth/socialaccount/providers/saml/views.py @@ -68,7 +68,7 @@ def dispatch(self, request, organization_slug): ) return render_authentication_error( request, - provider.id, + provider, extra_context={ "saml_errors": errors, "saml_last_error_reason": auth.get_last_error_reason(), @@ -76,7 +76,7 @@ def dispatch(self, request, organization_slug): ) if not auth.is_authenticated(): return render_authentication_error( - request, provider.id, error=AuthError.CANCELLED + request, provider, error=AuthError.CANCELLED ) relay_state = decode_relay_state(request.POST.get("RelayState")) @@ -106,7 +106,7 @@ def dispatch(self, request, organization_slug): serialized_login = acs_session.store.get("login") if not serialized_login: logger.error("Unable to finish login, SAML ACS session missing") - return render_authentication_error(request, provider.id) + return render_authentication_error(request, provider) acs_session.delete() login = SocialLogin.deserialize(serialized_login) return complete_social_login(request, login) diff --git a/allauth/socialaccount/providers/steam/views.py b/allauth/socialaccount/providers/steam/views.py index 03b2a5ab18..2c3ed26d4e 100644 --- a/allauth/socialaccount/providers/steam/views.py +++ b/allauth/socialaccount/providers/steam/views.py @@ -26,7 +26,7 @@ class SteamOpenIDLoginView(OpenIDLoginView): - provider = SteamOpenIDProvider + provider_class = SteamOpenIDProvider def get_form(self): items = dict(list(self.request.GET.items()) + list(self.request.POST.items())) @@ -38,7 +38,7 @@ def get_callback_url(self): class SteamOpenIDCallbackView(OpenIDCallbackView): - provider = SteamOpenIDProvider + provider_class = SteamOpenIDProvider steam_login = SteamOpenIDLoginView.as_view() diff --git a/allauth/socialaccount/providers/telegram/views.py b/allauth/socialaccount/providers/telegram/views.py index 90fa6fae4e..2f1ab5ce35 100644 --- a/allauth/socialaccount/providers/telegram/views.py +++ b/allauth/socialaccount/providers/telegram/views.py @@ -65,7 +65,7 @@ def post(self, request): auth_date_validity = provider.get_auth_date_validity() if hash != expected_hash or time.time() - auth_date > auth_date_validity: return render_authentication_error( - request, provider_id=provider.id, extra_context={"response": data} + request, provider=provider, extra_context={"response": data} ) login = provider.sociallogin_from_response(request, data)