diff --git a/changelog.d/10624.misc b/changelog.d/10624.misc new file mode 100644 index 000000000000..9a765435dbe4 --- /dev/null +++ b/changelog.d/10624.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 529d025c39b5..de86918b7bc0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -285,12 +285,12 @@ async def on_receive_pdu( # - Fetching any missing prev events to fill in gaps in the graph # - Fetching state if we have a hole in the graph if not pdu.internal_metadata.is_outlier(): - prevs = set(pdu.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen + if sent_to_us_directly: + prevs = set(pdu.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen - if missing_prevs: - if sent_to_us_directly: + if missing_prevs: # We only backfill backwards to the min depth. min_depth = await self.get_min_depth_for_context(pdu.room_id) logger.debug("min_depth: %d", min_depth) @@ -351,106 +351,8 @@ async def on_receive_pdu( affected=pdu.event_id, ) - else: - # We don't have all of the prev_events for this event. - # - # In this case, we need to fall back to asking another server in the - # federation for the state at this event. That's ok provided we then - # resolve the state against other bits of the DAG before using it (which - # will ensure that you can't just take over a room by sending an event, - # withholding its prev_events, and declaring yourself to be an admin in - # the subsequent state request). - # - # Since we're pulling this event as a missing prev_event, then clearly - # this event is not going to become the only forward-extremity and we are - # guaranteed to resolve its state against our existing forward - # extremities, so that should be fine. - # - # XXX this really feels like it could/should be merged with the above, - # but there is an interaction with min_depth that I'm not really - # following. - logger.info( - "Event %s is missing prev_events %s: calculating state for a " - "backwards extremity", - event_id, - shortstr(missing_prevs), - ) - - # Calculate the state after each of the previous events, and - # resolve them to find the correct state at the current event. - event_map = {event_id: pdu} - try: - # Get the state of the events we know about - ours = await self.state_store.get_state_groups_ids( - room_id, seen - ) - - # state_maps is a list of mappings from (type, state_key) to event_id - state_maps: List[StateMap[str]] = list(ours.values()) - - # we don't need this any more, let's delete it. - del ours - - # Ask the remote server for the states we don't - # know about - for p in missing_prevs: - logger.info( - "Requesting state after missing prev_event %s", p - ) - - with nested_logging_context(p): - # note that if any of the missing prevs share missing state or - # auth events, the requests to fetch those events are deduped - # by the get_pdu_cache in federation_client. - remote_state = ( - await self._get_state_after_missing_prev_event( - origin, room_id, p - ) - ) - - remote_state_map = { - (x.type, x.state_key): x.event_id - for x in remote_state - } - state_maps.append(remote_state_map) - - for x in remote_state: - event_map[x.event_id] = x - - room_version = await self.store.get_room_version_id(room_id) - state_map = await self._state_resolution_handler.resolve_events_with_store( - room_id, - room_version, - state_maps, - event_map, - state_res_store=StateResolutionStore(self.store), - ) - - # We need to give _process_received_pdu the actual state events - # rather than event ids, so generate that now. - - # First though we need to fetch all the events that are in - # state_map, so we can build up the state below. - evs = await self.store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.AS_IS, - ) - event_map.update(evs) - - state = [event_map[e] for e in state_map.values()] - except Exception: - logger.warning( - "Error attempting to resolve state at missing " - "prev_events", - exc_info=True, - ) - raise FederationError( - "ERROR", - 403, - "We can't get valid state history.", - affected=event_id, - ) + else: + state = await self._resolve_state_at_missing_prevs(origin, pdu) # A second round of checks for all events. Check that the event passes auth # based on `auth_events`, this allows us to assert that the event would @@ -1493,6 +1395,123 @@ async def get_event(event_id: str): event_infos, ) + async def _resolve_state_at_missing_prevs( + self, dest: str, event: EventBase + ) -> Optional[Iterable[EventBase]]: + """Calculate the state at an event with missing prev_events. + + This is used when we have pulled a batch of events from a remote server, and + still don't have all the prev_events. + + If we already have all the prev_events for `event`, this method does nothing. + + Otherwise, the missing prevs become new backwards extremities, and we fall back + to asking the remote server for the state after each missing `prev_event`, + and resolving across them. + + That's ok provided we then resolve the state against other bits of the DAG + before using it - in other words, that the received event `event` is not going + to become the only forwards_extremity in the room (which will ensure that you + can't just take over a room by sending an event, withholding its prev_events, + and declaring yourself to be an admin in the subsequent state request). + + In other words: we should only call this method if `event` has been *pulled* + as part of a batch of missing prev events, or similar. + + Params: + dest: the remote server to ask for state at the missing prevs. Typically, + this will be the server we got `event` from. + event: an event to check for missing prevs. + + Returns: + if we already had all the prev events, `None`. Otherwise, returns a list of + the events in the state at `event`. + """ + room_id = event.room_id + event_id = event.event_id + + prevs = set(event.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if not missing_prevs: + return None + + logger.info( + "Event %s is missing prev_events %s: calculating state for a " + "backwards extremity", + event_id, + shortstr(missing_prevs), + ) + # Calculate the state after each of the previous events, and + # resolve them to find the correct state at the current event. + event_map = {event_id: event} + try: + # Get the state of the events we know about + ours = await self.state_store.get_state_groups_ids(room_id, seen) + + # state_maps is a list of mappings from (type, state_key) to event_id + state_maps: List[StateMap[str]] = list(ours.values()) + + # we don't need this any more, let's delete it. + del ours + + # Ask the remote server for the states we don't + # know about + for p in missing_prevs: + logger.info("Requesting state after missing prev_event %s", p) + + with nested_logging_context(p): + # note that if any of the missing prevs share missing state or + # auth events, the requests to fetch those events are deduped + # by the get_pdu_cache in federation_client. + remote_state = await self._get_state_after_missing_prev_event( + dest, room_id, p + ) + + remote_state_map = { + (x.type, x.state_key): x.event_id for x in remote_state + } + state_maps.append(remote_state_map) + + for x in remote_state: + event_map[x.event_id] = x + + room_version = await self.store.get_room_version_id(room_id) + state_map = await self._state_resolution_handler.resolve_events_with_store( + room_id, + room_version, + state_maps, + event_map, + state_res_store=StateResolutionStore(self.store), + ) + + # We need to give _process_received_pdu the actual state events + # rather than event ids, so generate that now. + + # First though we need to fetch all the events that are in + # state_map, so we can build up the state below. + evs = await self.store.get_events( + list(state_map.values()), + get_prev_content=False, + redact_behaviour=EventRedactBehaviour.AS_IS, + ) + event_map.update(evs) + + state = [event_map[e] for e in state_map.values()] + except Exception: + logger.warning( + "Error attempting to resolve state at missing prev_events", + exc_info=True, + ) + raise FederationError( + "ERROR", + 403, + "We can't get valid state history.", + affected=event_id, + ) + return state + def _sanity_check_event(self, ev: EventBase) -> None: """ Do some early sanity checks of a received event