From e7bf940253c16da6509852cb023b0b7ce215e3ef Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 30 Oct 2019 16:07:15 +0000 Subject: [PATCH] Create new EventContexts rather than updating existing ones --- synapse/events/snapshot.py | 21 -------------------- synapse/handlers/federation.py | 36 ++++++++++++++++++---------------- tests/test_federation.py | 4 +++- 3 files changed, 22 insertions(+), 39 deletions(-) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 6421c65f6657..6e379059562a 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -222,27 +222,6 @@ def _fill_out_state(self, store): else: self._prev_state_ids = self._current_state_ids - @defer.inlineCallbacks - def update_state( - self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids - ): - """Replace the state in the context - """ - - # We need to make sure we wait for any ongoing fetching of state - # to complete so that the updated state doesn't get clobbered - if self._fetching_state_deferred: - yield make_deferred_yieldable(self._fetching_state_deferred) - - self.state_group = state_group - self._prev_state_ids = prev_state_ids - self.prev_group = prev_group - self._current_state_ids = current_state_ids - self.delta_ids = delta_ids - - # We need to ensure that that we've marked as having fetched the state - self._fetching_state_deferred = defer.succeed(None) - def _encode_state_dict(state_dict): """Since dicts of (type, state_key) -> event_id cannot be serialized in diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 08276fdebf46..2caddd552efa 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -45,6 +45,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event +from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import ( make_deferred_yieldable, @@ -1842,12 +1843,7 @@ def _prep_event(self, origin, event, state, auth_events, backfilled): if c and c.type == EventTypes.Create: auth_events[(c.type, c.state_key)] = c - try: - yield self.do_auth(origin, event, context, auth_events=auth_events) - except AuthError as e: - logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg) - - context.rejected = RejectedReason.AUTH_ERROR + context = yield self.do_auth(origin, event, context, auth_events=auth_events) if not context.rejected: yield self._check_for_soft_fail(event, state, backfilled) @@ -2016,12 +2012,12 @@ def do_auth(self, origin, event, context, auth_events): Also NB that this function adds entries to it. Returns: - defer.Deferred[None] + defer.Deferred[EventContext]: updated context object """ room_version = yield self.store.get_room_version(event.room_id) try: - yield self._update_auth_events_and_context_for_auth( + context = yield self._update_auth_events_and_context_for_auth( origin, event, context, auth_events ) except Exception: @@ -2039,7 +2035,9 @@ def do_auth(self, origin, event, context, auth_events): event_auth.check(room_version, event, auth_events=auth_events) except AuthError as e: logger.warn("Failed auth resolution for %r because %s", event, e) - raise e + context.rejected = RejectedReason.AUTH_ERROR + + return context @defer.inlineCallbacks def _update_auth_events_and_context_for_auth( @@ -2063,7 +2061,7 @@ def _update_auth_events_and_context_for_auth( auth_events (dict[(str, str)->synapse.events.EventBase]): Returns: - defer.Deferred[None] + defer.Deferred[EventContext]: updated context """ event_auth_events = set(event.auth_event_ids()) @@ -2102,7 +2100,7 @@ def _update_auth_events_and_context_for_auth( # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e) - return + return context seen_remotes = yield self.store.have_seen_events( [e.event_id for e in remote_auth_chain] @@ -2143,7 +2141,7 @@ def _update_auth_events_and_context_for_auth( if event.internal_metadata.is_outlier(): logger.info("Skipping auth_event fetch for outlier") - return + return context # FIXME: Assumes we have and stored all the state for all the # prev_events @@ -2152,7 +2150,7 @@ def _update_auth_events_and_context_for_auth( ) if not different_auth: - return + return context logger.info( "auth_events refers to events which are not in our calculated auth " @@ -2199,10 +2197,12 @@ def _update_auth_events_and_context_for_auth( auth_events.update(new_state) - yield self._update_context_for_auth_events( + context = yield self._update_context_for_auth_events( event, context, auth_events, event_key ) + return context + @defer.inlineCallbacks def _update_context_for_auth_events(self, event, context, auth_events, event_key): """Update the state_ids in an event context after auth event resolution, @@ -2211,14 +2211,16 @@ def _update_context_for_auth_events(self, event, context, auth_events, event_key Args: event (Event): The event we're handling the context for - context (synapse.events.snapshot.EventContext): event context - to be updated + context (synapse.events.snapshot.EventContext): initial event context auth_events (dict[(str, str)->str]): Events to update in the event context. event_key ((str, str)): (type, state_key) for the current event. this will not be included in the current_state in the context. + + Returns: + Deferred[EventContext]: new event context """ state_updates = { k: a.event_id for k, a in iteritems(auth_events) if k != event_key @@ -2243,7 +2245,7 @@ def _update_context_for_auth_events(self, event, context, auth_events, event_key current_state_ids=current_state_ids, ) - yield context.update_state( + return EventContext.with_state( state_group=state_group, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, diff --git a/tests/test_federation.py b/tests/test_federation.py index d1acb16f3007..7d82b584664d 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -59,7 +59,9 @@ def setUp(self): ) self.handler = self.homeserver.get_handlers().federation_handler - self.handler.do_auth = lambda *a, **b: succeed(True) + self.handler.do_auth = lambda origin, event, context, auth_events: succeed( + context + ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( pdus