Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert auth handler to async/await #7261

Merged
merged 5 commits into from
Apr 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7261.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert auth handler to async/await.
173 changes: 81 additions & 92 deletions synapse/handlers/auth.py

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,10 @@ def delete_device(self, user_id, device_id):
else:
raise

yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
yield defer.ensureDeferred(
self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
)
)

yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
Expand Down Expand Up @@ -391,8 +393,10 @@ def delete_devices(self, user_id, device_ids):
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
yield defer.ensureDeferred(
self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
)
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
Expand Down
28 changes: 21 additions & 7 deletions synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def register_user(
yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None
if password:
password_hash = yield self._auth_handler.hash(password)
password_hash = yield defer.ensureDeferred(
self._auth_handler.hash(password)
)

if localpart is not None:
yield self.check_username(localpart, guest_access_token=guest_access_token)
Expand Down Expand Up @@ -540,8 +542,10 @@ def register_device(self, user_id, device_id, initial_display_name, is_guest=Fal
user_id, ["guest = true"]
)
else:
access_token = yield self._auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id, valid_until_ms=valid_until_ms
access_token = yield defer.ensureDeferred(
self._auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id, valid_until_ms=valid_until_ms
)
)

return (device_id, access_token)
Expand Down Expand Up @@ -617,8 +621,13 @@ def _register_email_threepid(self, user_id, threepid, token):
logger.info("Can't add incomplete 3pid")
return

yield self._auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
yield defer.ensureDeferred(
self._auth_handler.add_threepid(
user_id,
threepid["medium"],
threepid["address"],
threepid["validated_at"],
)
)

# And we add an email pusher for them by default, but only
Expand Down Expand Up @@ -670,6 +679,11 @@ def _register_msisdn_threepid(self, user_id, threepid):
return None
raise

yield self._auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
yield defer.ensureDeferred(
self._auth_handler.add_threepid(
user_id,
threepid["medium"],
threepid["address"],
threepid["validated_at"],
)
)
13 changes: 5 additions & 8 deletions synapse/handlers/set_password.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import logging
from typing import Optional

from twisted.internet import defer

from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester

Expand All @@ -34,8 +32,7 @@ def __init__(self, hs):
self._device_handler = hs.get_device_handler()
self._password_policy_handler = hs.get_password_policy_handler()

@defer.inlineCallbacks
def set_password(
async def set_password(
self,
user_id: str,
new_password: str,
Expand All @@ -46,10 +43,10 @@ def set_password(
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)

self._password_policy_handler.validate_password(new_password)
password_hash = yield self._auth_handler.hash(new_password)
password_hash = await self._auth_handler.hash(new_password)

try:
yield self.store.user_set_password_hash(user_id, password_hash)
await self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
Expand All @@ -61,12 +58,12 @@ def set_password(
except_access_token_id = requester.access_token_id if requester else None

# First delete all of their other devices.
yield self._device_handler.delete_all_devices_for_user(
await self._device_handler.delete_all_devices_for_user(
user_id, except_device_id=except_device_id
)

# and now delete any access tokens which weren't associated with
# devices (or were associated with this device).
yield self._auth_handler.delete_access_tokens_for_user(
await self._auth_handler.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id
)
6 changes: 4 additions & 2 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def check_user_exists(self, user_id):
Deferred[str|None]: Canonical (case-corrected) user_id, or None
if the user is not registered.
"""
return self._auth_handler.check_user_exists(user_id)
return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id))

@defer.inlineCallbacks
def register(self, localpart, displayname=None, emails=[]):
Expand Down Expand Up @@ -196,7 +196,9 @@ def invalidate_access_token(self, access_token):
yield self._hs.get_device_handler().delete_device(user_id, device_id)
else:
# no associated device. Just delete the access token.
yield self._auth_handler.delete_access_token(access_token)
yield defer.ensureDeferred(
self._auth_handler.delete_access_token(access_token)
)

def run_db_interaction(self, desc, func, *args, **kwargs):
"""Run a function with a database connection
Expand Down
64 changes: 40 additions & 24 deletions tests/api/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_get_user_by_req_user_valid_token(self):
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)

def test_get_user_by_req_user_bad_token(self):
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_get_user_by_req_appservice_valid_token(self):
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)

@defer.inlineCallbacks
Expand All @@ -125,7 +125,7 @@ def test_get_user_by_req_appservice_valid_token_good_ip(self):
request.getClientIP.return_value = "192.168.10.10"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)

def test_get_user_by_req_appservice_valid_token_bad_ip(self):
Expand Down Expand Up @@ -188,7 +188,7 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(
requester.user.to_string(), masquerading_user_id.decode("utf8")
)
Expand Down Expand Up @@ -225,7 +225,9 @@ def test_get_user_from_macaroon(self):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
user_info = yield self.auth.get_user_by_access_token(macaroon.serialize())
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(macaroon.serialize())
)
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)

Expand All @@ -250,7 +252,9 @@ def test_get_guest_user_from_macaroon(self):
macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize()

user_info = yield self.auth.get_user_by_access_token(serialized)
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(serialized)
)
user = user_info["user"]
is_guest = user_info["is_guest"]
self.assertEqual(UserID.from_string(user_id), user)
Expand All @@ -260,10 +264,13 @@ def test_get_guest_user_from_macaroon(self):
@defer.inlineCallbacks
def test_cannot_use_regular_token_as_guest(self):
USER_ID = "@percy:matrix.org"
self.store.add_access_token_to_user = Mock()
self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None))
self.store.get_device = Mock(return_value=defer.succeed(None))

token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id(
USER_ID, "DEVICE", valid_until_ms=None
token = yield defer.ensureDeferred(
self.hs.handlers.auth_handler.get_access_token_for_user_id(
USER_ID, "DEVICE", valid_until_ms=None
)
)
self.store.add_access_token_to_user.assert_called_with(
USER_ID, token, "DEVICE", None
Expand All @@ -286,7 +293,9 @@ def get_user(tok):
request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester = yield defer.ensureDeferred(
self.auth.get_user_by_req(request, allow_guest=True)
)
self.assertEqual(UserID.from_string(USER_ID), requester.user)
self.assertFalse(requester.is_guest)

Expand All @@ -301,7 +310,9 @@ def get_user(tok):
request.requestHeaders.getRawHeaders = mock_getRawHeaders()

with self.assertRaises(InvalidClientCredentialsError) as cm:
yield self.auth.get_user_by_req(request, allow_guest=True)
yield defer.ensureDeferred(
self.auth.get_user_by_req(request, allow_guest=True)
)

self.assertEqual(401, cm.exception.code)
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
Expand All @@ -316,7 +327,7 @@ def test_blocking_mau(self):
small_number_of_users = 1

# Ensure no error thrown
yield self.auth.check_auth_blocking()
yield defer.ensureDeferred(self.auth.check_auth_blocking())

self.hs.config.limit_usage_by_mau = True

Expand All @@ -325,7 +336,7 @@ def test_blocking_mau(self):
)

with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking()
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
Expand All @@ -334,7 +345,7 @@ def test_blocking_mau(self):
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(small_number_of_users)
)
yield self.auth.check_auth_blocking()
yield defer.ensureDeferred(self.auth.check_auth_blocking())

@defer.inlineCallbacks
def test_blocking_mau__depending_on_user_type(self):
Expand All @@ -343,15 +354,19 @@ def test_blocking_mau__depending_on_user_type(self):

self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Support users allowed
yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
yield defer.ensureDeferred(
self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
)
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Bots not allowed
with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking(user_type=UserTypes.BOT)
yield defer.ensureDeferred(
self.auth.check_auth_blocking(user_type=UserTypes.BOT)
)
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Real users not allowed
with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking()
yield defer.ensureDeferred(self.auth.check_auth_blocking())

@defer.inlineCallbacks
def test_reserved_threepid(self):
Expand All @@ -362,21 +377,22 @@ def test_reserved_threepid(self):
unknown_threepid = {"medium": "email", "address": "[email protected]"}
self.hs.config.mau_limits_reserved_threepids = [threepid]

yield self.store.register_user(user_id="user1", password_hash=None)
with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking()
yield defer.ensureDeferred(self.auth.check_auth_blocking())

with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking(threepid=unknown_threepid)
yield defer.ensureDeferred(
self.auth.check_auth_blocking(threepid=unknown_threepid)
)

yield self.auth.check_auth_blocking(threepid=threepid)
yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid))

@defer.inlineCallbacks
def test_hs_disabled(self):
self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking()
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
Expand All @@ -393,7 +409,7 @@ def test_hs_disabled_no_server_notices_user(self):
self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking()
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
Expand All @@ -404,4 +420,4 @@ def test_server_notices_mxid_special_cased(self):
user = "@user:server"
self.hs.config.server_notices_mxid = user
self.hs.config.hs_disabled_message = "Reason for being disabled"
yield self.auth.check_auth_blocking(user)
yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
Loading