Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Factor out an AsyncEventContext #6298

Merged
merged 6 commits into from
Nov 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/6298.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor EventContext for clarity.
107 changes: 42 additions & 65 deletions synapse/events/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ class EventContext:
delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
(type, state_key) -> event_id. ``None`` for an outlier.

prev_state_events (?): XXX: is this ever set to anything other than
the empty list?

app_service: FIXME

_current_state_ids (dict[(str, str), str]|None):
Expand All @@ -51,36 +48,16 @@ class EventContext:
The current state map excluding the current event. None if outlier
or we haven't fetched the state from DB yet.
(type, state_key) -> event_id

_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
been calculated. None if we haven't started calculating yet

_event_type (str): The type of the event the context is associated with.
Only set when state has not been fetched yet.

_event_state_key (str|None): The state_key of the event the context is
associated with. Only set when state has not been fetched yet.

_prev_state_id (str|None): If the event associated with the context is
a state event, then `_prev_state_id` is the event_id of the state
that was replaced.
Only set when state has not been fetched yet.
"""

state_group = attr.ib(default=None)
rejected = attr.ib(default=False)
prev_group = attr.ib(default=None)
delta_ids = attr.ib(default=None)
prev_state_events = attr.ib(default=attr.Factory(list))
app_service = attr.ib(default=None)

_current_state_ids = attr.ib(default=None)
_prev_state_ids = attr.ib(default=None)
_prev_state_id = attr.ib(default=None)

_event_type = attr.ib(default=None)
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)
_current_state_ids = attr.ib(default=None)

@staticmethod
def with_state(
Expand All @@ -90,7 +67,6 @@ def with_state(
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
state_group=state_group,
fetching_state_deferred=defer.succeed(None),
prev_group=prev_group,
delta_ids=delta_ids,
)
Expand Down Expand Up @@ -125,7 +101,6 @@ def serialize(self, event, store):
"rejected": self.rejected,
"prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids),
"prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None,
}

Expand All @@ -141,7 +116,7 @@ def deserialize(store, input):
Returns:
EventContext
"""
context = EventContext(
context = _AsyncEventContextImpl(
# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
prev_state_id=input["prev_state_id"],
Expand All @@ -151,7 +126,6 @@ def deserialize(store, input):
prev_group=input["prev_group"],
delta_ids=_decode_state_dict(input["delta_ids"]),
rejected=input["rejected"],
prev_state_events=input["prev_state_events"],
)

app_service_id = input["app_service_id"]
Expand All @@ -170,14 +144,7 @@ def get_current_state_ids(self, store):
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""

if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store
)

yield make_deferred_yieldable(self._fetching_state_deferred)

yield self._ensure_fetched(store)
return self._current_state_ids

@defer.inlineCallbacks
Expand All @@ -190,14 +157,7 @@ def get_prev_state_ids(self, store):
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""

if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store
)

yield make_deferred_yieldable(self._fetching_state_deferred)

yield self._ensure_fetched(store)
return self._prev_state_ids

def get_cached_current_state_ids(self):
Expand All @@ -211,6 +171,44 @@ def get_cached_current_state_ids(self):

return self._current_state_ids

def _ensure_fetched(self, store):
return defer.succeed(None)


@attr.s(slots=True)
class _AsyncEventContextImpl(EventContext):
"""
An implementation of EventContext which fetches _current_state_ids and
_prev_state_ids from the database on demand.

Attributes:

_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
been calculated. None if we haven't started calculating yet

_event_type (str): The type of the event the context is associated with.

_event_state_key (str): The state_key of the event the context is
associated with.

_prev_state_id (str|None): If the event associated with the context is
a state event, then `_prev_state_id` is the event_id of the state
that was replaced.
"""

_prev_state_id = attr.ib(default=None)
_event_type = attr.ib(default=None)
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)

def _ensure_fetched(self, store):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store
)

return make_deferred_yieldable(self._fetching_state_deferred)

@defer.inlineCallbacks
def _fill_out_state(self, store):
"""Called to populate the _current_state_ids and _prev_state_ids
Expand All @@ -228,27 +226,6 @@ def _fill_out_state(self, store):
else:
self._prev_state_ids = self._current_state_ids

@defer.inlineCallbacks
def update_state(
self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
):
"""Replace the state in the context
"""

# We need to make sure we wait for any ongoing fetching of state
# to complete so that the updated state doesn't get clobbered
if self._fetching_state_deferred:
yield make_deferred_yieldable(self._fetching_state_deferred)

self.state_group = state_group
self._prev_state_ids = prev_state_ids
self.prev_group = prev_group
self._current_state_ids = current_state_ids
self.delta_ids = delta_ids

# We need to ensure that that we've marked as having fetched the state
self._fetching_state_deferred = defer.succeed(None)


def _encode_state_dict(state_dict):
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
Expand Down
38 changes: 19 additions & 19 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import auth_types_for_event
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.logging.context import (
make_deferred_yieldable,
Expand Down Expand Up @@ -1846,14 +1847,7 @@ def _prep_event(self, origin, event, state, auth_events, backfilled):
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c

try:
yield self.do_auth(origin, event, context, auth_events=auth_events)
except AuthError as e:
logger.warning(
"[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg
)

context.rejected = RejectedReason.AUTH_ERROR
context = yield self.do_auth(origin, event, context, auth_events=auth_events)

if not context.rejected:
yield self._check_for_soft_fail(event, state, backfilled)
Expand Down Expand Up @@ -2022,12 +2016,12 @@ def do_auth(self, origin, event, context, auth_events):

Also NB that this function adds entries to it.
Returns:
defer.Deferred[None]
defer.Deferred[EventContext]: updated context object
"""
room_version = yield self.store.get_room_version(event.room_id)

try:
yield self._update_auth_events_and_context_for_auth(
context = yield self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events
)
except Exception:
Expand All @@ -2045,7 +2039,9 @@ def do_auth(self, origin, event, context, auth_events):
event_auth.check(room_version, event, auth_events=auth_events)
except AuthError as e:
logger.warning("Failed auth resolution for %r because %s", event, e)
raise e
context.rejected = RejectedReason.AUTH_ERROR

return context

@defer.inlineCallbacks
def _update_auth_events_and_context_for_auth(
Expand All @@ -2069,7 +2065,7 @@ def _update_auth_events_and_context_for_auth(
auth_events (dict[(str, str)->synapse.events.EventBase]):

Returns:
defer.Deferred[None]
defer.Deferred[EventContext]: updated context
"""
event_auth_events = set(event.auth_event_ids())

Expand Down Expand Up @@ -2108,7 +2104,7 @@ def _update_auth_events_and_context_for_auth(
# The other side isn't around or doesn't implement the
# endpoint, so lets just bail out.
logger.info("Failed to get event auth from remote: %s", e)
return
return context

seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in remote_auth_chain]
Expand Down Expand Up @@ -2149,7 +2145,7 @@ def _update_auth_events_and_context_for_auth(

if event.internal_metadata.is_outlier():
logger.info("Skipping auth_event fetch for outlier")
return
return context

# FIXME: Assumes we have and stored all the state for all the
# prev_events
Expand All @@ -2158,7 +2154,7 @@ def _update_auth_events_and_context_for_auth(
)

if not different_auth:
return
return context

logger.info(
"auth_events refers to events which are not in our calculated auth "
Expand Down Expand Up @@ -2205,10 +2201,12 @@ def _update_auth_events_and_context_for_auth(

auth_events.update(new_state)

yield self._update_context_for_auth_events(
context = yield self._update_context_for_auth_events(
event, context, auth_events, event_key
)

return context

@defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events, event_key):
"""Update the state_ids in an event context after auth event resolution,
Expand All @@ -2217,14 +2215,16 @@ def _update_context_for_auth_events(self, event, context, auth_events, event_key
Args:
event (Event): The event we're handling the context for

context (synapse.events.snapshot.EventContext): event context
to be updated
context (synapse.events.snapshot.EventContext): initial event context

auth_events (dict[(str, str)->str]): Events to update in the event
context.

event_key ((str, str)): (type, state_key) for the current event.
this will not be included in the current_state in the context.

Returns:
Deferred[EventContext]: new event context
"""
state_updates = {
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
Expand All @@ -2249,7 +2249,7 @@ def _update_context_for_auth_events(self, event, context, auth_events, event_key
current_state_ids=current_state_ids,
)

yield context.update_state(
return EventContext.with_state(
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
Expand Down
4 changes: 3 additions & 1 deletion tests/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def setUp(self):
)

self.handler = self.homeserver.get_handlers().federation_handler
self.handler.do_auth = lambda *a, **b: succeed(True)
self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
context
)
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
pdus
Expand Down