diff --git a/changelog.d/13270.bugfix b/changelog.d/13270.bugfix new file mode 100644 index 000000000000..d023b25eea27 --- /dev/null +++ b/changelog.d/13270.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.40 where a user invited to a restricted room would be briefly unable to join. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 98c203ada03b..4caf6cbdee9d 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -120,7 +120,7 @@ async def build( The signed and hashed event. """ if auth_event_ids is None: - state_ids = await self._state.get_current_state_ids( + state_ids = await self._state.compute_state_after_events( self.room_id, prev_event_ids ) auth_event_ids = self._event_auth_handler.compute_auth_events( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 90e0b2160021..a5b9ac904e91 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -755,14 +755,14 @@ async def update_membership_locked( latest_event_ids = await self.store.get_prev_events_for_room(room_id) - current_state_ids = await self.state_handler.get_current_state_ids( - room_id, latest_event_ids=latest_event_ids + state_before_join = await self.state_handler.compute_state_after_events( + room_id, latest_event_ids ) # TODO: Refactor into dictionary of explicitly allowed transitions # between old and new state, with specific error messages for some # transitions and generic otherwise - old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) + old_state_id = state_before_join.get((EventTypes.Member, target.to_string())) if old_state_id: old_state = await self.store.get_event(old_state_id, allow_none=True) old_membership = old_state.content.get("membership") if old_state else None @@ -813,11 +813,11 @@ async def update_membership_locked( if action == "kick": raise AuthError(403, "The target user is not in the room") - is_host_in_room = await self._is_host_in_room(current_state_ids) + is_host_in_room = await self._is_host_in_room(state_before_join) if effective_membership_state == Membership.JOIN: if requester.is_guest: - guest_can_join = await self._can_guest_join(current_state_ids) + guest_can_join = await self._can_guest_join(state_before_join) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. @@ -855,7 +855,12 @@ async def update_membership_locked( # Check if a remote join should be performed. remote_join, remote_room_hosts = await self._should_perform_remote_join( - target.to_string(), room_id, remote_room_hosts, content, is_host_in_room + target.to_string(), + room_id, + remote_room_hosts, + content, + is_host_in_room, + state_before_join, ) if remote_join: if ratelimit: @@ -995,6 +1000,7 @@ async def _should_perform_remote_join( remote_room_hosts: List[str], content: JsonDict, is_host_in_room: bool, + state_before_join: StateMap[str], ) -> Tuple[bool, List[str]]: """ Check whether the server should do a remote join (as opposed to a local @@ -1014,6 +1020,8 @@ async def _should_perform_remote_join( content: The content to use as the event body of the join. This may be modified. is_host_in_room: True if the host is in the room. + state_before_join: The state before the join event (i.e. the resolution of + the states after its parent events). Returns: A tuple of: @@ -1030,20 +1038,17 @@ async def _should_perform_remote_join( # If the host is in the room, but not one of the authorised hosts # for restricted join rules, a remote join must be used. room_version = await self.store.get_room_version(room_id) - current_state_ids = await self._storage_controllers.state.get_current_state_ids( - room_id - ) # If restricted join rules are not being used, a local join can always # be used. if not await self.event_auth_handler.has_restricted_join_rules( - current_state_ids, room_version + state_before_join, room_version ): return False, [] # If the user is invited to the room or already joined, the join # event can always be issued locally. - prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None) + prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None) prev_member_event = None if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) @@ -1058,10 +1063,10 @@ async def _should_perform_remote_join( # # If not, generate a new list of remote hosts based on which # can issue invites. - event_map = await self.store.get_events(current_state_ids.values()) + event_map = await self.store.get_events(state_before_join.values()) current_state = { state_key: event_map[event_id] - for state_key, event_id in current_state_ids.items() + for state_key, event_id in state_before_join.items() } allowed_servers = get_servers_from_users( get_users_which_can_issue_invite(current_state) @@ -1075,7 +1080,7 @@ async def _should_perform_remote_join( # Ensure the member should be allowed access via membership in a room. await self.event_auth_handler.check_restricted_join_rules( - current_state_ids, room_version, user_id, prev_member_event + state_before_join, room_version, user_id, prev_member_event ) # If this is going to be a local join, additional information must @@ -1085,7 +1090,7 @@ async def _should_perform_remote_join( EventContentFields.AUTHORISING_USER ] = await self.event_auth_handler.get_user_which_could_invite( room_id, - current_state_ids, + state_before_join, ) return False, [] diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 9f0a36652c25..9155c46c51b7 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -151,22 +151,27 @@ def __init__(self, hs: "HomeServer"): ReplicationUpdateCurrentStateRestServlet.make_client(hs) ) - async def get_current_state_ids( + async def compute_state_after_events( self, room_id: str, - latest_event_ids: Collection[str], + event_ids: Collection[str], ) -> StateMap[str]: - """Get the current state, or the state at a set of events, for a room + """Fetch the state after each of the given event IDs. Resolve them and return. + + This is typically used where `event_ids` is a collection of forward extremities + in a room, intended to become the `prev_events` of a new event E. If so, the + return value of this function represents the state before E. Args: - room_id: - latest_event_ids: The forward extremities to resolve. + room_id: the room_id containing the given events. + event_ids: the events whose state should be fetched and resolved. Returns: - the state dict, mapping from (event_type, state_key) -> event_id + the state dict (a mapping from (event_type, state_key) -> event_id) which + holds the resolution of the states after the given event IDs. """ - logger.debug("calling resolve_state_groups from get_current_state_ids") - ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) + logger.debug("calling resolve_state_groups from compute_state_after_events") + ret = await self.resolve_state_groups_for_events(room_id, event_ids) return await ret.get_state(self._state_storage_controller, StateFilter.all()) async def get_current_users_in_room(