diff --git a/changelog.d/16441.misc b/changelog.d/16441.misc new file mode 100644 index 000000000000..32264a62b20a --- /dev/null +++ b/changelog.d/16441.misc @@ -0,0 +1 @@ +Improve rate limiting logic. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 90343c230604..1b50495af182 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -382,8 +382,10 @@ async def _local_membership_update( and persist a new event for the new membership change. Args: - requester: - target: + requester: User requesting the membership change, i.e. the sender of the + desired membership event. + target: Use whose membership should change, i.e. the state_key of the + desired membership event. room_id: membership: @@ -415,7 +417,6 @@ async def _local_membership_update( Returns: Tuple of event ID and stream ordering position """ - user_id = target.to_string() if content is None: @@ -475,21 +476,6 @@ async def _local_membership_update( (EventTypes.Member, user_id), None ) - if event.membership == Membership.JOIN: - newly_joined = True - if prev_member_event_id: - prev_member_event = await self.store.get_event( - prev_member_event_id - ) - newly_joined = prev_member_event.membership != Membership.JOIN - - # Only rate-limit if the user actually joined the room, otherwise we'll end - # up blocking profile updates. - if newly_joined and ratelimit: - await self._join_rate_limiter_local.ratelimit(requester) - await self._join_rate_per_room_limiter.ratelimit( - requester, key=room_id, update=False - ) with opentracing.start_active_span("handle_new_client_event"): result_event = ( await self.event_creation_handler.handle_new_client_event( @@ -618,6 +604,25 @@ async def update_membership( Raises: ShadowBanError if a shadow-banned requester attempts to send an invite. """ + if ratelimit: + if action == Membership.JOIN: + # Only rate-limit if the user isn't already joined to the room, otherwise + # we'll end up blocking profile updates. + ( + current_membership, + _, + ) = await self.store.get_local_current_membership_for_user_in_room( + requester.user.to_string(), + room_id, + ) + if current_membership != Membership.JOIN: + await self._join_rate_limiter_local.ratelimit(requester) + await self._join_rate_per_room_limiter.ratelimit( + requester, key=room_id, update=False + ) + elif action == Membership.INVITE: + await self.ratelimit_invite(requester, room_id, target.to_string()) + if action == Membership.INVITE and requester.shadow_banned: # We randomly sleep a bit just to annoy the requester. await self.clock.sleep(random.randint(1, 10)) @@ -794,8 +799,6 @@ async def update_membership_locked( if effective_membership_state == Membership.INVITE: target_id = target.to_string() - if ratelimit: - await self.ratelimit_invite(requester, room_id, target_id) # block any attempts to invite the server notices mxid if target_id == self._server_notices_mxid: diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 7627823d3fd3..aaa4f3bba049 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -1444,6 +1444,30 @@ def test_join_local_ratelimit(self) -> None: room_ids[3], joiner_user_id, expect_code=HTTPStatus.TOO_MANY_REQUESTS ) + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_attempts_local_ratelimit(self) -> None: + """Tests that unsuccessful joins that end up being denied are rate-limited.""" + # Create 4 rooms + room_ids = [ + self.helper.create_room_as(self.user_id, is_public=True) for _ in range(4) + ] + # Pre-emptively ban the user who will attempt to join. + joiner_user_id = self.register_user("joiner", "secret") + for room_id in room_ids: + self.helper.ban(room_id, self.user_id, joiner_user_id) + + # Now make a new user try to join some of them. + # The user can make 3 requests, each of which should be denied. + for room_id in room_ids[0:3]: + self.helper.join(room_id, joiner_user_id, expect_code=HTTPStatus.FORBIDDEN) + + # The fourth attempt should be rate limited. + self.helper.join( + room_ids[3], joiner_user_id, expect_code=HTTPStatus.TOO_MANY_REQUESTS + ) + @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} )