From 92103cb2c8b8bff6b522a7bfa8a3a776b4821b11 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 14 Jun 2022 10:51:15 +0200 Subject: [PATCH] Decouple `synapse.api.auth_blocking.AuthBlocking` from `synapse.api.auth.Auth`. (#13021) --- changelog.d/13021.misc | 1 + synapse/api/auth.py | 14 ------- synapse/handlers/auth.py | 5 ++- synapse/handlers/message.py | 4 +- synapse/handlers/register.py | 3 +- synapse/handlers/room.py | 3 +- synapse/handlers/sync.py | 4 +- synapse/server.py | 5 +++ .../resource_limits_server_notices.py | 4 +- tests/api/test_auth.py | 42 ++++++++++++------- tests/handlers/test_auth.py | 2 +- tests/handlers/test_register.py | 2 +- tests/handlers/test_sync.py | 2 +- .../test_resource_limits_server_notices.py | 22 ++++++---- 14 files changed, 63 insertions(+), 50 deletions(-) create mode 100644 changelog.d/13021.misc diff --git a/changelog.d/13021.misc b/changelog.d/13021.misc new file mode 100644 index 000000000000..84c41cdf59f0 --- /dev/null +++ b/changelog.d/13021.misc @@ -0,0 +1 @@ +Decouple `synapse.api.auth_blocking.AuthBlocking` from `synapse.api.auth.Auth`. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 5a410f805a56..c037ccb98441 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -20,7 +20,6 @@ from twisted.web.server import Request from synapse import event_auth -from synapse.api.auth_blocking import AuthBlocking from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.errors import ( AuthError, @@ -67,8 +66,6 @@ def __init__(self, hs: "HomeServer"): 10000, "token_cache" ) - self._auth_blocking = AuthBlocking(self.hs) - self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips self._track_puppeted_user_ips = hs.config.api.track_puppeted_user_ips self._macaroon_secret_key = hs.config.key.macaroon_secret_key @@ -711,14 +708,3 @@ async def check_user_in_room_or_world_readable( "User %s not in room %s, and room previews are disabled" % (user_id, room_id), ) - - async def check_auth_blocking( - self, - user_id: Optional[str] = None, - threepid: Optional[dict] = None, - user_type: Optional[str] = None, - requester: Optional[Requester] = None, - ) -> None: - await self._auth_blocking.check_auth_blocking( - user_id=user_id, threepid=threepid, user_type=user_type, requester=requester - ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 6e15028b0a38..60d13040a2f3 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -199,6 +199,7 @@ class AuthHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.auth = hs.get_auth() + self.auth_blocking = hs.get_auth_blocking() self.clock = hs.get_clock() self.checkers: Dict[str, UserInteractiveAuthChecker] = {} for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: @@ -985,7 +986,7 @@ async def create_access_token_for_user_id( not is_appservice_ghost or self.hs.config.appservice.track_appservice_user_ips ): - await self.auth.check_auth_blocking(user_id) + await self.auth_blocking.check_auth_blocking(user_id) access_token = self.generate_access_token(target_user_id_obj) await self.store.add_access_token_to_user( @@ -1439,7 +1440,7 @@ async def validate_short_term_login_token( except Exception: raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN) - await self.auth.check_auth_blocking(res.user_id) + await self.auth_blocking.check_auth_blocking(res.user_id) return res async def delete_access_token(self, access_token: str) -> None: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ad87c4178272..189f52fe5a7a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -444,7 +444,7 @@ async def _expire_event(self, event_id: str) -> None: class EventCreationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.auth = hs.get_auth() + self.auth_blocking = hs.get_auth_blocking() self._event_auth_handler = hs.get_event_auth_handler() self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() @@ -605,7 +605,7 @@ async def create_event( Returns: Tuple of created event, Context """ - await self.auth.check_auth_blocking(requester=requester) + await self.auth_blocking.check_auth_blocking(requester=requester) if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": room_version_id = event_dict["content"]["room_version"] diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 338204287f66..c77d18172283 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -91,6 +91,7 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs self.auth = hs.get_auth() + self.auth_blocking = hs.get_auth_blocking() self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() @@ -276,7 +277,7 @@ async def register_user( # do not check_auth_blocking if the call is coming through the Admin API if not by_admin: - await self.auth.check_auth_blocking(threepid=threepid) + await self.auth_blocking.check_auth_blocking(threepid=threepid) if localpart is not None: await self.check_username(localpart, guest_access_token=guest_access_token) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 42aae4a215a8..75c0be8c361a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -110,6 +110,7 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() self.auth = hs.get_auth() + self.auth_blocking = hs.get_auth_blocking() self.clock = hs.get_clock() self.hs = hs self.spam_checker = hs.get_spam_checker() @@ -706,7 +707,7 @@ async def create_room( """ user_id = requester.user.to_string() - await self.auth.check_auth_blocking(requester=requester) + await self.auth_blocking.check_auth_blocking(requester=requester) if ( self._server_notices_mxid is not None diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b4ead79f9733..af19c513be1d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -237,7 +237,7 @@ def __init__(self, hs: "HomeServer"): self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() self.state = hs.get_state_handler() - self.auth = hs.get_auth() + self.auth_blocking = hs.get_auth_blocking() self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state @@ -280,7 +280,7 @@ async def wait_for_sync_for_user( # not been exceeded (if not part of the group by this point, almost certain # auth_blocking will occur) user_id = sync_config.user.to_string() - await self.auth.check_auth_blocking(requester=requester) + await self.auth_blocking.check_auth_blocking(requester=requester) res = await self.response_cache.wrap( sync_config.request_key, diff --git a/synapse/server.py b/synapse/server.py index a66ec228dbab..a6a415aeabfc 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -29,6 +29,7 @@ from twisted.web.resource import Resource from synapse.api.auth import Auth +from synapse.api.auth_blocking import AuthBlocking from synapse.api.filtering import Filtering from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter from synapse.appservice.api import ApplicationServiceApi @@ -379,6 +380,10 @@ def get_notifier(self) -> Notifier: def get_auth(self) -> Auth: return Auth(self) + @cache_in_self + def get_auth_blocking(self) -> AuthBlocking: + return AuthBlocking(self) + @cache_in_self def get_http_client_context_factory(self) -> IPolicyForHTTPS: if self.config.tls.use_insecure_ssl_client_just_for_testing_do_not_use: diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 686302077861..3134cd2d3d6c 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -37,7 +37,7 @@ def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() self._store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() - self._auth = hs.get_auth() + self._auth_blocking = hs.get_auth_blocking() self._config = hs.config self._resouce_limited = False self._account_data_handler = hs.get_account_data_handler() @@ -91,7 +91,7 @@ async def maybe_send_server_notice_to_user(self, user_id: str) -> None: # Normally should always pass in user_id to check_auth_blocking # if you have it, but in this case are checking what would happen # to other users if they were to arrive. - await self._auth.check_auth_blocking() + await self._auth_blocking.check_auth_blocking() except ResourceLimitError as e: limit_msg = e.msg limit_type = e.limit_type diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index bc75ddd3e954..54af9089e92c 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.auth import Auth +from synapse.api.auth_blocking import AuthBlocking from synapse.api.constants import UserTypes from synapse.api.errors import ( AuthError, @@ -49,7 +50,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' - self.auth_blocking = self.auth._auth_blocking + self.auth_blocking = AuthBlocking(hs) self.test_user = "@foo:bar" self.test_token = b"_test_token_" @@ -362,20 +363,22 @@ def test_blocking_mau(self): small_number_of_users = 1 # Ensure no error thrown - self.get_success(self.auth.check_auth_blocking()) + self.get_success(self.auth_blocking.check_auth_blocking()) self.auth_blocking._limit_usage_by_mau = True self.store.get_monthly_active_count = simple_async_mock(lots_of_users) - e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + e = self.get_failure( + self.auth_blocking.check_auth_blocking(), ResourceLimitError + ) self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.code, 403) # Ensure does not throw an error self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) - self.get_success(self.auth.check_auth_blocking()) + self.get_success(self.auth_blocking.check_auth_blocking()) def test_blocking_mau__depending_on_user_type(self): self.auth_blocking._max_mau_value = 50 @@ -383,15 +386,18 @@ def test_blocking_mau__depending_on_user_type(self): self.store.get_monthly_active_count = simple_async_mock(100) # Support users allowed - self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)) + self.get_success( + self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT) + ) self.store.get_monthly_active_count = simple_async_mock(100) # Bots not allowed self.get_failure( - self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError + self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT), + ResourceLimitError, ) self.store.get_monthly_active_count = simple_async_mock(100) # Real users not allowed - self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self): self.auth_blocking._max_mau_value = 50 @@ -419,7 +425,7 @@ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self): app_service=appservice, authenticated_entity="@appservice:server", ) - self.get_success(self.auth.check_auth_blocking(requester=requester)) + self.get_success(self.auth_blocking.check_auth_blocking(requester=requester)) def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self): self.auth_blocking._max_mau_value = 50 @@ -448,7 +454,8 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self): authenticated_entity="@appservice:server", ) self.get_failure( - self.auth.check_auth_blocking(requester=requester), ResourceLimitError + self.auth_blocking.check_auth_blocking(requester=requester), + ResourceLimitError, ) def test_reserved_threepid(self): @@ -459,18 +466,21 @@ def test_reserved_threepid(self): unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.auth_blocking._mau_limits_reserved_threepids = [threepid] - self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) self.get_failure( - self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError + self.auth_blocking.check_auth_blocking(threepid=unknown_threepid), + ResourceLimitError, ) - self.get_success(self.auth.check_auth_blocking(threepid=threepid)) + self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid)) def test_hs_disabled(self): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" - e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + e = self.get_failure( + self.auth_blocking.check_auth_blocking(), ResourceLimitError + ) self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.code, 403) @@ -485,7 +495,9 @@ def test_hs_disabled_no_server_notices_user(self): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" - e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + e = self.get_failure( + self.auth_blocking.check_auth_blocking(), ResourceLimitError + ) self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.code, 403) @@ -495,4 +507,4 @@ def test_server_notices_mxid_special_cased(self): user = "@user:server" self.auth_blocking._server_notices_mxid = user self.auth_blocking._hs_disabled_message = "Reason for being disabled" - self.get_success(self.auth.check_auth_blocking(user)) + self.get_success(self.auth_blocking.check_auth_blocking(user)) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 67a7829769b6..7106799d44e3 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -38,7 +38,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # MAU tests # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' - self.auth_blocking = hs.get_auth()._auth_blocking + self.auth_blocking = hs.get_auth_blocking() self.auth_blocking._max_mau_value = 50 self.small_number_of_users = 1 diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index b6ba19c73903..23f35d5bf53d 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -699,7 +699,7 @@ async def get_or_create_user( """ if localpart is None: raise SynapseError(400, "Request must include user id") - await self.hs.get_auth().check_auth_blocking() + await self.hs.get_auth_blocking().check_auth_blocking() need_register = True try: diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index db3302a4c78d..ecc7cc6461ee 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -45,7 +45,7 @@ def prepare(self, reactor, clock, hs: HomeServer): # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' - self.auth_blocking = self.hs.get_auth()._auth_blocking + self.auth_blocking = self.hs.get_auth_blocking() def test_wait_for_sync_for_user_auth_blocking(self): user_id1 = "@user1:test" diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 07e29788e5be..e07ae78fc4c0 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -96,7 +96,9 @@ def test_maybe_send_server_notice_to_user_flag_off(self): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) + self._rlsn._auth_blocking.check_auth_blocking = Mock( + return_value=make_awaitable(None) + ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) @@ -112,7 +114,7 @@ def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): """ Test when user has blocked notice, but notice ought to be there (NOOP) """ - self._rlsn._auth.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( return_value=make_awaitable(None), side_effect=ResourceLimitError(403, "foo"), ) @@ -132,7 +134,7 @@ def test_maybe_send_server_notice_to_user_add_blocked_notice(self): """ Test when user does not have blocked notice, but should have one """ - self._rlsn._auth.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( return_value=make_awaitable(None), side_effect=ResourceLimitError(403, "foo"), ) @@ -145,7 +147,9 @@ def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) + self._rlsn._auth_blocking.check_auth_blocking = Mock( + return_value=make_awaitable(None) + ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -156,7 +160,9 @@ def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) + self._rlsn._auth_blocking.check_auth_blocking = Mock( + return_value=make_awaitable(None) + ) self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=make_awaitable(None) ) @@ -170,7 +176,7 @@ def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self): Test that when server is over MAU limit and alerting is suppressed, then an alert message is not sent into the room """ - self._rlsn._auth.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER @@ -185,7 +191,7 @@ def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): """ Test that when a server is disabled, that MAU limit alerting is ignored. """ - self._rlsn._auth.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED @@ -202,7 +208,7 @@ def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self): When the room is already in a blocked state, test that when alerting is suppressed that the room is returned to an unblocked state. """ - self._rlsn._auth.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER