Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix EventContext to use the correct store
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Dec 18, 2019
1 parent 10e03ca commit b7ae670
Show file tree
Hide file tree
Showing 12 changed files with 52 additions and 41 deletions.
3 changes: 2 additions & 1 deletion synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.storage = hs.get_storage()

self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
register_cache("cache", "token_cache", self.token_cache)
Expand All @@ -79,7 +80,7 @@ def __init__(self, hs):

@defer.inlineCallbacks
def check_from_context(self, room_version, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.storage)
auth_events_ids = yield self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
Expand Down
20 changes: 11 additions & 9 deletions synapse/events/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def state_group(self) -> Optional[int]:
return self._state_group

@defer.inlineCallbacks
def get_current_state_ids(self, store):
def get_current_state_ids(self, storage):
"""
Gets the room state map, including this event - ie, the state in ``state_group``
Expand All @@ -234,11 +234,11 @@ def get_current_state_ids(self, store):
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")

yield self._ensure_fetched(store)
yield self._ensure_fetched(storage)
return self._current_state_ids

@defer.inlineCallbacks
def get_prev_state_ids(self, store):
def get_prev_state_ids(self, storage):
"""
Gets the room state map, excluding this event.
Expand All @@ -250,7 +250,7 @@ def get_prev_state_ids(self, store):
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
yield self._ensure_fetched(store)
yield self._ensure_fetched(storage)
return self._prev_state_ids

def get_cached_current_state_ids(self):
Expand All @@ -270,7 +270,7 @@ def get_cached_current_state_ids(self):

return self._current_state_ids

def _ensure_fetched(self, store):
def _ensure_fetched(self, storage):
return defer.succeed(None)


Expand Down Expand Up @@ -300,23 +300,25 @@ class _AsyncEventContextImpl(EventContext):
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)

def _ensure_fetched(self, store):
def _ensure_fetched(self, storage):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store
self._fill_out_state, storage
)

return make_deferred_yieldable(self._fetching_state_deferred)

@defer.inlineCallbacks
def _fill_out_state(self, store):
def _fill_out_state(self, storage):
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
return

self._current_state_ids = yield store.get_state_ids_for_group(self.state_group)
self._current_state_ids = yield storage.state.get_state_ids_for_group(
self.state_group
)
if self._prev_state_id and self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)

Expand Down
3 changes: 2 additions & 1 deletion synapse/events/third_party_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, hs):
self.third_party_rules = None

self.store = hs.get_datastore()
self.storage = hs.get_storage()

module = None
config = None
Expand All @@ -53,7 +54,7 @@ def check_event_allowed(self, event, context):
if self.third_party_rules is None:
return True

prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.storage)

# Retrieve the state events from the database.
state_events = {}
Expand Down
5 changes: 4 additions & 1 deletion synapse/handlers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, hs):
hs (synapse.server.HomeServer):
"""
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler()
Expand Down Expand Up @@ -134,7 +135,9 @@ def maybe_kick_guest_users(self, event, context=None):
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
if context:
current_state_ids = yield context.get_current_state_ids(self.store)
current_state_ids = yield context.get_current_state_ids(
self.storage
)
current_state = yield self.store.get_events(
list(current_state_ids.values())
)
Expand Down
14 changes: 7 additions & 7 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ async def _process_received_pdu(
# changing their profile info.
newly_joined = True

prev_state_ids = await context.get_prev_state_ids(self.store)
prev_state_ids = await context.get_prev_state_ids(self.storage)

prev_state_id = prev_state_ids.get((event.type, event.state_key))
if prev_state_id:
Expand Down Expand Up @@ -1418,7 +1418,7 @@ def on_send_join_request(self, origin, pdu):
user = UserID.from_string(event.state_key)
yield self.user_joined_room(user, event.room_id)

prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.storage)

state_ids = list(prev_state_ids.values())
auth_chain = yield self.store.get_auth_chain(state_ids)
Expand Down Expand Up @@ -1927,7 +1927,7 @@ def _prep_event(
context = yield self.state_handler.compute_event_context(event, old_state=state)

if not auth_events:
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.storage)
auth_events_ids = yield self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
Expand Down Expand Up @@ -2336,12 +2336,12 @@ def _update_context_for_auth_events(self, event, context, auth_events):
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
}

current_state_ids = yield context.get_current_state_ids(self.store)
current_state_ids = yield context.get_current_state_ids(self.storage)
current_state_ids = dict(current_state_ids)

current_state_ids.update(state_updates)

prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.storage)
prev_state_ids = dict(prev_state_ids)

prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
Expand Down Expand Up @@ -2625,7 +2625,7 @@ def add_display_name_to_third_party_invite(
event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.storage)
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = yield self.store.get_event(
Expand Down Expand Up @@ -2673,7 +2673,7 @@ def _check_signature(self, event, context):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]

prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.storage)
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))

invite_event = None
Expand Down
10 changes: 5 additions & 5 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def create_event(
# federation as well as those created locally. As of room v3, aliases events
# can be created by users that are not in the room, therefore we have to
# tolerate them in event_auth.check().
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.storage)
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
prev_event = (
yield self.store.get_event(prev_event_id, allow_none=True)
Expand Down Expand Up @@ -665,7 +665,7 @@ def deduplicate_state_event(self, event, context):
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.storage)
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
return
Expand Down Expand Up @@ -914,7 +914,7 @@ def persist_and_notify_client_event(
def is_inviter_member_event(e):
return e.type == EventTypes.Member and e.sender == event.sender

current_state_ids = yield context.get_current_state_ids(self.store)
current_state_ids = yield context.get_current_state_ids(self.storage)

state_to_include_ids = [
e_id
Expand Down Expand Up @@ -967,7 +967,7 @@ def is_inviter_member_event(e):
if original_event.room_id != event.room_id:
raise SynapseError(400, "Cannot redact event from a different room")

prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.storage)
auth_events_ids = yield self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
Expand All @@ -989,7 +989,7 @@ def is_inviter_member_event(e):
event.internal_metadata.recheck_redaction = False

if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.storage)
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")

Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _upgrade_room(self, requester, old_room_id, new_version):
requester, tombstone_event, tombstone_context
)

old_room_state = yield tombstone_context.get_current_state_ids(self.store)
old_room_state = yield tombstone_context.get_current_state_ids(self.storage)

# update any aliases
yield self._move_aliases_to_new_room(
Expand Down
5 changes: 3 additions & 2 deletions synapse/handlers/room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self, hs):
"""
self.hs = hs
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.auth = hs.get_auth()
self.state_handler = hs.get_state_handler()
self.config = hs.config
Expand Down Expand Up @@ -193,7 +194,7 @@ def _local_membership_update(
requester, event, context, extra_users=[target], ratelimit=ratelimit
)

prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.storage)

prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)

Expand Down Expand Up @@ -601,7 +602,7 @@ def send_membership_event(self, requester, event, context, ratelimit=True):
if prev_event is not None:
return

prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.storage)
if event.membership == Membership.JOIN:
if requester.is_guest:
guest_can_join = yield self._can_guest_join(prev_state_ids)
Expand Down
6 changes: 4 additions & 2 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class BulkPushRuleEvaluator(object):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.auth = hs.get_auth()

self.room_push_rule_cache_metrics = register_cache(
Expand Down Expand Up @@ -116,7 +117,7 @@ def _get_rules_for_room(self, room_id):

@defer.inlineCallbacks
def _get_power_levels_and_sender_level(self, event, context):
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.storage)
pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
Expand Down Expand Up @@ -239,6 +240,7 @@ def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metri
self.room_id = room_id
self.is_mine_id = hs.is_mine_id
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.room_push_rule_cache_metrics = room_push_rule_cache_metrics

self.linearizer = Linearizer(name="rules_for_room")
Expand Down Expand Up @@ -304,7 +306,7 @@ def get_rules(self, event, context):

push_rules_delta_state_cache_metric.inc_hits()
else:
current_state_ids = yield context.get_current_state_ids(self.store)
current_state_ids = yield context.get_current_state_ids(self.storage)
push_rules_delta_state_cache_metric.inc_misses()

push_rules_state_size_counter.inc(len(current_state_ids))
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/data_stores/main/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def bulk_get_push_rules_for_room(self, event, context):
# To do this we set the state_group to a new object as object() != object()
state_group = object()

current_state_ids = yield context.get_current_state_ids(self)
current_state_ids = yield context.get_current_state_ids(self.hs.get_storage())
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/data_stores/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def get_joined_users_from_context(self, event, context):
# To do this we set the state_group to a new object as object() != object()
state_group = object()

current_state_ids = yield context.get_current_state_ids(self)
current_state_ids = yield context.get_current_state_ids(self.hs.get_storage())
result = yield self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
Expand Down
Loading

0 comments on commit b7ae670

Please sign in to comment.