From 27aafa1e11cf837807786ed8138d3447c1637b14 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 17 Jul 2020 15:12:04 -0400 Subject: [PATCH] Update the auth providers to be async. --- docs/password_auth_providers.md | 16 ++++++------- synapse/handlers/ui_auth/checkers.py | 35 ++++++++++++++-------------- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md index 5d9ae670413d..efc7a3119a26 100644 --- a/docs/password_auth_providers.md +++ b/docs/password_auth_providers.md @@ -68,15 +68,15 @@ methods. > will be called for each login attempt where the login type matches one > of the keys returned by `get_supported_login_types`. > -> It is passed the (possibly UNqualified) `user` provided by the client, +> It is passed the (possibly unqualified) `user` provided by the client, > the login type, and a dictionary of login secrets passed by the > client. > -> The method should return a Twisted `Deferred` object, which resolves +> The method should return an `Awaitable` object, which resolves > to the canonical `@localpart:domain` user id if authentication is > successful, and `None` if not. > -> Alternatively, the `Deferred` can resolve to a `(str, func)` tuple, in +> Alternatively, the `Awaitable` can resolve to a `(str, func)` tuple, in > which case the second field is a callback which will be called with > the result from the `/login` call (including `access_token`, > `device_id`, etc.) @@ -88,11 +88,11 @@ methods. > passed the medium (ex. "email"), an address (ex. > "") and the user's password. > -> The method should return a Twisted `Deferred` object, which resolves +> The method should return an `Awaitable` object, which resolves > to a `str` containing the user's (canonical) User ID if > authentication was successful, and `None` if not. > -> As with `check_auth`, the `Deferred` may alternatively resolve to a +> As with `check_auth`, the `Awaitable` may alternatively resolve to a > `(user_id, callback)` tuple. `someprovider.check_password`(*user_id*, *password*) @@ -102,11 +102,11 @@ methods. > providers that just want to provide a mechanism for validating > `m.login.password` logins. > -> Iif implemented, it will be called to check logins with an +> If implemented, it will be called to check logins with an > `m.login.password` login type. It is passed a qualified > `@localpart:domain` user id, and the password provided by the user. > -> The method should return a Twisted `Deferred` object, which resolves +> The method should return an `Awaitable` object, which resolves > to `True` if authentication is successful, and `False` if not. `someprovider.on_logged_out`(*user_id*, *device_id*, *access_token*) @@ -116,5 +116,5 @@ methods. > any: access tokens are occasionally created without an associated > device ID), and the (now deactivated) access token. > -> It may return a Twisted `Deferred` object; the logout request will +> It may return an `Awaitable` object; the logout request will > wait for the deferred to complete but the result is ignored. diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index a140e9391ea9..a011e9fe2980 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -14,10 +14,10 @@ # limitations under the License. import logging +from typing import Any from canonicaljson import json -from twisted.internet import defer from twisted.web.client import PartialDownloadError from synapse.api.constants import LoginType @@ -33,25 +33,25 @@ class UserInteractiveAuthChecker: def __init__(self, hs): pass - def is_enabled(self): + def is_enabled(self) -> bool: """Check if the configuration of the homeserver allows this checker to work Returns: - bool: True if this login type is enabled. + True if this login type is enabled. """ - def check_auth(self, authdict, clientip): + async def check_auth(self, authdict: dict, clientip: str) -> Any: """Given the authentication dict from the client, attempt to check this step Args: - authdict (dict): authentication dictionary from the client - clientip (str): The IP address of the client. + authdict: authentication dictionary from the client + clientip: The IP address of the client. Raises: SynapseError if authentication failed Returns: - Deferred: the result of authentication (to pass back to the client?) + The result of authentication (to pass back to the client?) """ raise NotImplementedError() @@ -62,8 +62,8 @@ class DummyAuthChecker(UserInteractiveAuthChecker): def is_enabled(self): return True - def check_auth(self, authdict, clientip): - return defer.succeed(True) + async def check_auth(self, authdict, clientip): + return True class TermsAuthChecker(UserInteractiveAuthChecker): @@ -72,8 +72,8 @@ class TermsAuthChecker(UserInteractiveAuthChecker): def is_enabled(self): return True - def check_auth(self, authdict, clientip): - return defer.succeed(True) + async def check_auth(self, authdict, clientip): + return True class RecaptchaAuthChecker(UserInteractiveAuthChecker): @@ -89,8 +89,7 @@ def __init__(self, hs): def is_enabled(self): return self._enabled - @defer.inlineCallbacks - def check_auth(self, authdict, clientip): + async def check_auth(self, authdict, clientip): try: user_response = authdict["response"] except KeyError: @@ -107,7 +106,7 @@ def check_auth(self, authdict, clientip): # TODO: get this from the homeserver rather than creating a new one for # each request try: - resp_body = yield self._http_client.post_urlencoded_get_json( + resp_body = await self._http_client.post_urlencoded_get_json( self._url, args={ "secret": self._secret, @@ -219,8 +218,8 @@ def is_enabled(self): ThreepidBehaviour.LOCAL, ) - def check_auth(self, authdict, clientip): - return defer.ensureDeferred(self._check_threepid("email", authdict)) + async def check_auth(self, authdict, clientip): + return await self._check_threepid("email", authdict) class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): @@ -233,8 +232,8 @@ def __init__(self, hs): def is_enabled(self): return bool(self.hs.config.account_threepid_delegate_msisdn) - def check_auth(self, authdict, clientip): - return defer.ensureDeferred(self._check_threepid("msisdn", authdict)) + async def check_auth(self, authdict, clientip): + return await self._check_threepid("msisdn", authdict) INTERACTIVE_AUTH_CHECKERS = [