diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 6e379059562a..7494920296bb 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -48,20 +48,6 @@ class EventContext: The current state map excluding the current event. None if outlier or we haven't fetched the state from DB yet. (type, state_key) -> event_id - - _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have - been calculated. None if we haven't started calculating yet - - _event_type (str): The type of the event the context is associated with. - Only set when state has not been fetched yet. - - _event_state_key (str|None): The state_key of the event the context is - associated with. Only set when state has not been fetched yet. - - _prev_state_id (str|None): If the event associated with the context is - a state event, then `_prev_state_id` is the event_id of the state - that was replaced. - Only set when state has not been fetched yet. """ state_group = attr.ib(default=None) @@ -70,13 +56,8 @@ class EventContext: delta_ids = attr.ib(default=None) app_service = attr.ib(default=None) - _current_state_ids = attr.ib(default=None) _prev_state_ids = attr.ib(default=None) - _prev_state_id = attr.ib(default=None) - - _event_type = attr.ib(default=None) - _event_state_key = attr.ib(default=None) - _fetching_state_deferred = attr.ib(default=None) + _current_state_ids = attr.ib(default=None) @staticmethod def with_state( @@ -86,7 +67,6 @@ def with_state( current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, state_group=state_group, - fetching_state_deferred=defer.succeed(None), prev_group=prev_group, delta_ids=delta_ids, ) @@ -136,7 +116,7 @@ def deserialize(store, input): Returns: EventContext """ - context = EventContext( + context = AsyncEventContext( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. prev_state_id=input["prev_state_id"], @@ -164,14 +144,7 @@ def get_current_state_ids(self, store): Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) - - yield make_deferred_yieldable(self._fetching_state_deferred) - + yield self._ensure_fetched(store) return self._current_state_ids @defer.inlineCallbacks @@ -184,14 +157,7 @@ def get_prev_state_ids(self, store): Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) - - yield make_deferred_yieldable(self._fetching_state_deferred) - + yield self._ensure_fetched(store) return self._prev_state_ids def get_cached_current_state_ids(self): @@ -205,6 +171,44 @@ def get_cached_current_state_ids(self): return self._current_state_ids + def _ensure_fetched(self, store): + return defer.succeed(None) + + +@attr.s(slots=True) +class AsyncEventContext(EventContext): + """ + A version of EventContext which fetches _current_state_ids and _prev_state_ids + from the database on demand. + + Attributes: + + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have + been calculated. None if we haven't started calculating yet + + _event_type (str): The type of the event the context is associated with. + + _event_state_key (str): The state_key of the event the context is + associated with. + + _prev_state_id (str|None): If the event associated with the context is + a state event, then `_prev_state_id` is the event_id of the state + that was replaced. + """ + + _prev_state_id = attr.ib(default=None) + _event_type = attr.ib(default=None) + _event_state_key = attr.ib(default=None) + _fetching_state_deferred = attr.ib(default=None) + + def _ensure_fetched(self, store): + if not self._fetching_state_deferred: + self._fetching_state_deferred = run_in_background( + self._fill_out_state, store + ) + + return make_deferred_yieldable(self._fetching_state_deferred) + @defer.inlineCallbacks def _fill_out_state(self, store): """Called to populate the _current_state_ids and _prev_state_ids