From b7ae67094d7d1c175f3710aa17c09f2fd25ea169 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Dec 2019 13:52:33 +0000 Subject: [PATCH] Fix EventContext to use the correct store --- synapse/api/auth.py | 3 ++- synapse/events/snapshot.py | 20 ++++++++++-------- synapse/events/third_party_rules.py | 3 ++- synapse/handlers/_base.py | 5 ++++- synapse/handlers/federation.py | 14 ++++++------- synapse/handlers/message.py | 10 ++++----- synapse/handlers/room.py | 2 +- synapse/handlers/room_member.py | 5 +++-- synapse/push/bulk_push_rule_evaluator.py | 6 ++++-- synapse/storage/data_stores/main/push_rule.py | 2 +- .../storage/data_stores/main/roommember.py | 2 +- tests/test_state.py | 21 ++++++++++--------- 12 files changed, 52 insertions(+), 41 deletions(-) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9fd52a8c7745..501978977a97 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -71,6 +71,7 @@ def __init__(self, hs): self.clock = hs.get_clock() self.store = hs.get_datastore() self.state = hs.get_state_handler() + self.storage = hs.get_storage() self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) register_cache("cache", "token_cache", self.token_cache) @@ -79,7 +80,7 @@ def __init__(self, hs): @defer.inlineCallbacks def check_from_context(self, room_version, event, context, do_sig_check=True): - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) auth_events_ids = yield self.compute_auth_events( event, prev_state_ids, for_verification=True ) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 64e898f40c3b..8914a9439817 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -216,7 +216,7 @@ def state_group(self) -> Optional[int]: return self._state_group @defer.inlineCallbacks - def get_current_state_ids(self, store): + def get_current_state_ids(self, storage): """ Gets the room state map, including this event - ie, the state in ``state_group`` @@ -234,11 +234,11 @@ def get_current_state_ids(self, store): if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") - yield self._ensure_fetched(store) + yield self._ensure_fetched(storage) return self._current_state_ids @defer.inlineCallbacks - def get_prev_state_ids(self, store): + def get_prev_state_ids(self, storage): """ Gets the room state map, excluding this event. @@ -250,7 +250,7 @@ def get_prev_state_ids(self, store): Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - yield self._ensure_fetched(store) + yield self._ensure_fetched(storage) return self._prev_state_ids def get_cached_current_state_ids(self): @@ -270,7 +270,7 @@ def get_cached_current_state_ids(self): return self._current_state_ids - def _ensure_fetched(self, store): + def _ensure_fetched(self, storage): return defer.succeed(None) @@ -300,23 +300,25 @@ class _AsyncEventContextImpl(EventContext): _event_state_key = attr.ib(default=None) _fetching_state_deferred = attr.ib(default=None) - def _ensure_fetched(self, store): + def _ensure_fetched(self, storage): if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background( - self._fill_out_state, store + self._fill_out_state, storage ) return make_deferred_yieldable(self._fetching_state_deferred) @defer.inlineCallbacks - def _fill_out_state(self, store): + def _fill_out_state(self, storage): """Called to populate the _current_state_ids and _prev_state_ids attributes by loading from the database. """ if self.state_group is None: return - self._current_state_ids = yield store.get_state_ids_for_group(self.state_group) + self._current_state_ids = yield storage.state.get_state_ids_for_group( + self.state_group + ) if self._prev_state_id and self._event_state_key is not None: self._prev_state_ids = dict(self._current_state_ids) diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 714a9b1579ab..42b1b992ce47 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -28,6 +28,7 @@ def __init__(self, hs): self.third_party_rules = None self.store = hs.get_datastore() + self.storage = hs.get_storage() module = None config = None @@ -53,7 +54,7 @@ def check_event_allowed(self, event, context): if self.third_party_rules is None: return True - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) # Retrieve the state events from the database. state_events = {} diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d15c6282fb3a..968dc84a1980 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -40,6 +40,7 @@ def __init__(self, hs): hs (synapse.server.HomeServer): """ self.store = hs.get_datastore() + self.storage = hs.get_storage() self.auth = hs.get_auth() self.notifier = hs.get_notifier() self.state_handler = hs.get_state_handler() @@ -134,7 +135,9 @@ def maybe_kick_guest_users(self, event, context=None): guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids( + self.storage + ) current_state = yield self.store.get_events( list(current_state_ids.values()) ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 60bb00fc6ab4..8504a2c8b3d3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -718,7 +718,7 @@ async def _process_received_pdu( # changing their profile info. newly_joined = True - prev_state_ids = await context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids(self.storage) prev_state_id = prev_state_ids.get((event.type, event.state_key)) if prev_state_id: @@ -1418,7 +1418,7 @@ def on_send_join_request(self, origin, pdu): user = UserID.from_string(event.state_key) yield self.user_joined_room(user, event.room_id) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) state_ids = list(prev_state_ids.values()) auth_chain = yield self.store.get_auth_chain(state_ids) @@ -1927,7 +1927,7 @@ def _prep_event( context = yield self.state_handler.compute_event_context(event, old_state=state) if not auth_events: - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) auth_events_ids = yield self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) @@ -2336,12 +2336,12 @@ def _update_context_for_auth_events(self, event, context, auth_events): k: a.event_id for k, a in iteritems(auth_events) if k != event_key } - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids(self.storage) current_state_ids = dict(current_state_ids) current_state_ids.update(state_updates) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) prev_state_ids = dict(prev_state_ids) prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) @@ -2625,7 +2625,7 @@ def add_display_name_to_third_party_invite( event.content["third_party_invite"]["signed"]["token"], ) original_invite = None - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) original_invite_id = prev_state_ids.get(key) if original_invite_id: original_invite = yield self.store.get_event( @@ -2673,7 +2673,7 @@ def _check_signature(self, event, context): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token)) invite_event = None diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index bf9add7fe2e8..b169602b5618 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -515,7 +515,7 @@ def create_event( # federation as well as those created locally. As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( yield self.store.get_event(prev_event_id, allow_none=True) @@ -665,7 +665,7 @@ def deduplicate_state_event(self, event, context): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: return @@ -914,7 +914,7 @@ def persist_and_notify_client_event( def is_inviter_member_event(e): return e.type == EventTypes.Member and e.sender == event.sender - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids(self.storage) state_to_include_ids = [ e_id @@ -967,7 +967,7 @@ def is_inviter_member_event(e): if original_event.room_id != event.room_id: raise SynapseError(400, "Cannot redact event from a different room") - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) auth_events_ids = yield self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) @@ -989,7 +989,7 @@ def is_inviter_member_event(e): event.internal_metadata.recheck_redaction = False if event.type == EventTypes.Create: - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d3a1a7b4a6ba..ef5157fae55b 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -184,7 +184,7 @@ def _upgrade_room(self, requester, old_room_id, new_version): requester, tombstone_event, tombstone_context ) - old_room_state = yield tombstone_context.get_current_state_ids(self.store) + old_room_state = yield tombstone_context.get_current_state_ids(self.storage) # update any aliases yield self._move_aliases_to_new_room( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7b7270fc61c4..35fe396cf980 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -50,6 +50,7 @@ def __init__(self, hs): """ self.hs = hs self.store = hs.get_datastore() + self.storage = hs.get_storage() self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.config = hs.config @@ -193,7 +194,7 @@ def _local_membership_update( requester, event, context, extra_users=[target], ratelimit=ratelimit ) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) @@ -601,7 +602,7 @@ def send_membership_event(self, requester, event, context, ratelimit=True): if prev_event is not None: return - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) if event.membership == Membership.JOIN: if requester.is_guest: guest_can_join = yield self._can_guest_join(prev_state_ids) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 788178076032..b90398effc47 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -62,6 +62,7 @@ class BulkPushRuleEvaluator(object): def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() + self.storage = hs.get_storage() self.auth = hs.get_auth() self.room_push_rule_cache_metrics = register_cache( @@ -116,7 +117,7 @@ def _get_rules_for_room(self, room_id): @defer.inlineCallbacks def _get_power_levels_and_sender_level(self, event, context): - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and @@ -239,6 +240,7 @@ def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metri self.room_id = room_id self.is_mine_id = hs.is_mine_id self.store = hs.get_datastore() + self.storage = hs.get_storage() self.room_push_rule_cache_metrics = room_push_rule_cache_metrics self.linearizer = Linearizer(name="rules_for_room") @@ -304,7 +306,7 @@ def get_rules(self, event, context): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids(self.storage) push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index 5ba13aa973a0..ce6f6f712c97 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -244,7 +244,7 @@ def bulk_get_push_rules_for_room(self, event, context): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids(self) + current_state_ids = yield context.get_current_state_ids(self.hs.get_storage()) result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 92e3b9c512f7..a2583101afc1 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -477,7 +477,7 @@ def get_joined_users_from_context(self, event, context): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids(self) + current_state_ids = yield context.get_current_state_ids(self.hs.get_storage()) result = yield self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) diff --git a/tests/test_state.py b/tests/test_state.py index 176535947adc..159d8083c8f4 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -178,6 +178,7 @@ def setUp(self): hs.get_auth.return_value = Auth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) hs.get_storage.return_value = storage + self.storage = storage self.state = StateHandler(hs) self.event_id = 0 @@ -419,10 +420,10 @@ def test_annotate_with_old_message(self): context = yield self.state.compute_event_context(event, old_state=old_state) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids(self.storage) self.assertCountEqual( (e.event_id for e in old_state), current_state_ids.values() ) @@ -442,10 +443,10 @@ def test_annotate_with_old_state(self): context = yield self.state.compute_event_context(event, old_state=old_state) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids(self.storage) self.assertCountEqual( (e.event_id for e in old_state + [event]), current_state_ids.values() ) @@ -479,7 +480,7 @@ def test_trivial_annotate_message(self): context = yield self.state.compute_event_context(event) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids(self.storage) self.assertEqual( set([e.event_id for e in old_state]), set(current_state_ids.values()) @@ -511,7 +512,7 @@ def test_trivial_annotate_state(self): context = yield self.state.compute_event_context(event) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.storage) self.assertEqual( set([e.event_id for e in old_state]), set(prev_state_ids.values()) @@ -552,7 +553,7 @@ def test_resolve_message_conflict(self): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids(self.storage) self.assertEqual(len(current_state_ids), 6) @@ -594,7 +595,7 @@ def test_resolve_state_conflict(self): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids(self.storage) self.assertEqual(len(current_state_ids), 6) @@ -649,7 +650,7 @@ def test_standard_depth_conflict(self): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids(self.storage) self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) @@ -677,7 +678,7 @@ def test_standard_depth_conflict(self): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids(self.storage) self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])