Skip to content

Commit

Permalink
feat(socialaccount): provider+view in render_authentication_error()
Browse files Browse the repository at this point in the history
  • Loading branch information
pennersr committed Dec 15, 2023
1 parent 95eb678 commit 051f3be
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 54 deletions.
17 changes: 14 additions & 3 deletions allauth/socialaccount/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ def pre_social_login(self, request, sociallogin):
"""
pass

def authentication_error(
def on_authentication_error(
self,
request,
provider_id,
provider,
error=None,
exception=None,
extra_context=None,
Expand All @@ -69,7 +69,18 @@ def authentication_error(
You can use this hook to intervene, e.g. redirect to an
educational flow by raising an ImmediateHttpResponse.
"""
pass
if hasattr(self, "authentication_error"):
warnings.warn(
"adapter.authentication_error() is deprecated, use adapter.on_authentication_error()"
)

self.authentication_error(
request,
provider.id,
error=error,
exception=exception,
extra_context=extra_context,
)

def new_user(self, request, sociallogin):
"""
Expand Down
8 changes: 4 additions & 4 deletions allauth/socialaccount/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,17 @@ def _login_social_account(request, sociallogin):

def render_authentication_error(
request,
provider_id,
provider,
error=AuthError.UNKNOWN,
exception=None,
extra_context=None,
):
try:
if extra_context is None:
extra_context = {}
get_adapter().authentication_error(
get_adapter().on_authentication_error(
request,
provider_id,
provider,
error=error,
exception=exception,
extra_context=extra_context,
Expand All @@ -137,7 +137,7 @@ def render_authentication_error(
return HttpResponseRedirect(reverse("socialaccount_login_cancelled"))
context = {
"auth_error": {
"provider": provider_id,
"provider": provider,
"code": error,
"exception": exception,
}
Expand Down
21 changes: 8 additions & 13 deletions allauth/socialaccount/providers/draugiem/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,22 @@ def login(request):

@csrf_exempt
def callback(request):
adapter = get_adapter()
provider = adapter.get_provider(request, DraugiemProvider.id)

if "dr_auth_status" not in request.GET:
return render_authentication_error(
request, DraugiemProvider.id, error=AuthError.UNKNOWN
)
return render_authentication_error(request, provider, error=AuthError.UNKNOWN)

if request.GET["dr_auth_status"] != "ok":
return render_authentication_error(
request, DraugiemProvider.id, error=AuthError.DENIED
)
return render_authentication_error(request, provider, error=AuthError.DENIED)

if "dr_auth_code" not in request.GET:
return render_authentication_error(
request, DraugiemProvider.id, error=AuthError.UNKNOWN
)
return render_authentication_error(request, provider, error=AuthError.UNKNOWN)

ret = None
auth_exception = None
try:
app = get_adapter().get_app(request, DraugiemProvider.id)
app = provider.app
login = draugiem_complete_login(request, app, request.GET["dr_auth_code"])
login.state = SocialLogin.unstash_state(request)

Expand All @@ -67,9 +64,7 @@ def callback(request):
auth_exception = e

if not ret:
ret = render_authentication_error(
request, DraugiemProvider.id, exception=auth_exception
)
ret = render_authentication_error(request, provider, exception=auth_exception)

return ret

Expand Down
9 changes: 4 additions & 5 deletions allauth/socialaccount/providers/facebook/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,16 @@ def complete_login(self, request, app, access_token, **kwargs):

class LoginByTokenView(View):
def dispatch(self, request):
self.adapter = get_adapter()
self.provider = self.adapter.get_provider(request, FacebookProvider.id)
try:
return super().dispatch(request)
except (
requests.RequestException,
forms.ValidationError,
PermissionDenied,
) as exc:
return render_authentication_error(
request, FacebookProvider.id, exception=exc
)
return render_authentication_error(request, self.provider, exception=exc)

def get(self, request):
# If we leave out get().get() it will return a response with a 405, but
Expand All @@ -103,8 +103,7 @@ def post(self, request):
if not form.is_valid():
raise forms.ValidationError()

adapter = get_adapter()
provider = adapter.get_provider(request, FacebookProvider.id)
provider = self.provider
login_options = provider.get_fb_login_options(request)
app = provider.app
access_token = form.cleaned_data["access_token"]
Expand Down
19 changes: 9 additions & 10 deletions allauth/socialaccount/providers/oauth/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ def login(self, request, *args, **kwargs):
return client.get_redirect(auth_url, auth_params)
except OAuthError as e:
logger.error("OAuth authentication error", exc_info=True)
return render_authentication_error(
request, self.adapter.provider_id, exception=e
)
return render_authentication_error(request, provider, exception=e)


class OAuthCallbackView(OAuthView):
Expand All @@ -93,21 +91,24 @@ def dispatch(self, request):
View to handle final steps of OAuth based authentication where the user
gets redirected back to from the service provider
"""
provider = self.adapter.get_provider()
login_done_url = reverse(self.adapter.provider_id + "_callback")
client = self._get_client(request, login_done_url)
if not client.is_valid():
if "denied" in request.GET:
error = AuthError.CANCELLED
else:
error = AuthError.UNKNOWN
extra_context = dict(oauth_client=client)
return render_authentication_error(
request,
self.adapter.provider_id,
provider,
error=error,
extra_context=extra_context,
extra_context={
"oauth_client": client,
"callback_view": self,
},
)
app = self.adapter.get_provider().app
app = provider.app
try:
access_token = client.get_access_token()
token = SocialToken(
Expand All @@ -123,6 +124,4 @@ def dispatch(self, request):
login.state = SocialLogin.unstash_state(request)
return complete_social_login(request, login)
except OAuthError as e:
return render_authentication_error(
request, self.adapter.provider_id, exception=e
)
return render_authentication_error(request, provider, exception=e)
16 changes: 10 additions & 6 deletions allauth/socialaccount/providers/oauth2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,12 @@ def login(self, request, *args, **kwargs):
try:
return HttpResponseRedirect(client.get_redirect_url(auth_url, auth_params))
except OAuth2Error as e:
return render_authentication_error(request, provider.id, exception=e)
return render_authentication_error(request, provider, exception=e)


class OAuth2CallbackView(OAuth2View):
def dispatch(self, request, *args, **kwargs):
provider = self.adapter.get_provider()
if "error" in request.GET or "code" not in request.GET:
# Distinguish cancel from error
auth_error = request.GET.get("error", None)
Expand All @@ -138,9 +139,14 @@ def dispatch(self, request, *args, **kwargs):
else:
error = AuthError.UNKNOWN
return render_authentication_error(
request, self.adapter.provider_id, error=error
request,
provider,
error=error,
extra_context={
"callback_view": self,
},
)
app = self.adapter.get_provider().app
app = provider.app
client = self.get_client(self.request, app)

try:
Expand All @@ -166,6 +172,4 @@ def dispatch(self, request, *args, **kwargs):
RequestException,
ProviderException,
) as e:
return render_authentication_error(
request, self.adapter.provider_id, exception=e
)
return render_authentication_error(request, provider, exception=e)
17 changes: 10 additions & 7 deletions allauth/socialaccount/providers/openid/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ def _openid_consumer(request, provider, endpoint):
class OpenIDLoginView(View):
template_name = "openid/login.html"
form_class = LoginForm
provider = OpenIDProvider
provider_class = OpenIDProvider

def dispatch(self, request, *args, **kwargs):
self.provider = self.provider_class(request)
return super().dispatch(request, *args, **kwargs)

def get(self, request):
form = self.get_form()
Expand All @@ -44,7 +48,7 @@ def get(self, request):
return self.perform_openid_auth(form)
except (UnicodeDecodeError, DiscoveryFailure) as e:
# UnicodeDecodeError: necaris/python3-openid#1
return render_authentication_error(request, self.provider.id, exception=e)
return render_authentication_error(request, self.provider, exception=e)

def post(self, request):
form = self.get_form()
Expand Down Expand Up @@ -85,7 +89,7 @@ def perform_openid_auth(self, form):
return form

request = self.request
provider = self.provider(request)
provider = self.provider
endpoint = form.cleaned_data["openid"]
client = self.get_client(provider, endpoint)
realm = self.get_realm(provider)
Expand All @@ -99,7 +103,6 @@ def perform_openid_auth(self, form):
ax = FetchRequest()
for name in AXAttributes:
ax.add(AttrInfo(name, required=True))
provider = OpenIDProvider(request)
server_settings = provider.get_server_settings(request.GET.get("openid"))
extra_attributes = server_settings.get("extra_attributes", [])
for _, name, required in extra_attributes:
Expand All @@ -121,10 +124,10 @@ def perform_openid_auth(self, form):


class OpenIDCallbackView(View):
provider = OpenIDProvider
provider_class = OpenIDProvider

def get(self, request):
provider = self.provider(request)
provider = self.provider = self.provider_class(request)
endpoint = request.GET.get("openid.op_endpoint", "")
client = self.get_client(provider, endpoint)
response = self.get_openid_response(client)
Expand All @@ -146,7 +149,7 @@ def complete_login(self, login):
return complete_social_login(self.request, login)

def render_error(self, error):
return render_authentication_error(self.request, self.provider.id, error=error)
return render_authentication_error(self.request, self.provider, error=error)

def get_client(self, provider, endpoint):
return _openid_consumer(self.request, provider, endpoint)
Expand Down
6 changes: 3 additions & 3 deletions allauth/socialaccount/providers/saml/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ def dispatch(self, request, organization_slug):
)
return render_authentication_error(
request,
provider.id,
provider,
extra_context={
"saml_errors": errors,
"saml_last_error_reason": auth.get_last_error_reason(),
},
)
if not auth.is_authenticated():
return render_authentication_error(
request, provider.id, error=AuthError.CANCELLED
request, provider, error=AuthError.CANCELLED
)

relay_state = decode_relay_state(request.POST.get("RelayState"))
Expand Down Expand Up @@ -106,7 +106,7 @@ def dispatch(self, request, organization_slug):
serialized_login = acs_session.store.get("login")
if not serialized_login:
logger.error("Unable to finish login, SAML ACS session missing")
return render_authentication_error(request, provider.id)
return render_authentication_error(request, provider)
acs_session.delete()
login = SocialLogin.deserialize(serialized_login)
return complete_social_login(request, login)
Expand Down
4 changes: 2 additions & 2 deletions allauth/socialaccount/providers/steam/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


class SteamOpenIDLoginView(OpenIDLoginView):
provider = SteamOpenIDProvider
provider_class = SteamOpenIDProvider

def get_form(self):
items = dict(list(self.request.GET.items()) + list(self.request.POST.items()))
Expand All @@ -38,7 +38,7 @@ def get_callback_url(self):


class SteamOpenIDCallbackView(OpenIDCallbackView):
provider = SteamOpenIDProvider
provider_class = SteamOpenIDProvider


steam_login = SteamOpenIDLoginView.as_view()
Expand Down
2 changes: 1 addition & 1 deletion allauth/socialaccount/providers/telegram/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def post(self, request):
auth_date_validity = provider.get_auth_date_validity()
if hash != expected_hash or time.time() - auth_date > auth_date_validity:
return render_authentication_error(
request, provider_id=provider.id, extra_context={"response": data}
request, provider=provider, extra_context={"response": data}
)

login = provider.sociallogin_from_response(request, data)
Expand Down

0 comments on commit 051f3be

Please sign in to comment.