Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Factor out an AsyncEventContext #6298

Merged
merged 6 commits into from
Nov 1, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/6298.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor EventContext for clarity.
107 changes: 42 additions & 65 deletions synapse/events/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ class EventContext:
delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
(type, state_key) -> event_id. ``None`` for an outlier.

prev_state_events (?): XXX: is this ever set to anything other than
the empty list?

app_service: FIXME

_current_state_ids (dict[(str, str), str]|None):
Expand All @@ -51,36 +48,16 @@ class EventContext:
The current state map excluding the current event. None if outlier
or we haven't fetched the state from DB yet.
(type, state_key) -> event_id

_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
been calculated. None if we haven't started calculating yet

_event_type (str): The type of the event the context is associated with.
Only set when state has not been fetched yet.

_event_state_key (str|None): The state_key of the event the context is
associated with. Only set when state has not been fetched yet.

_prev_state_id (str|None): If the event associated with the context is
a state event, then `_prev_state_id` is the event_id of the state
that was replaced.
Only set when state has not been fetched yet.
"""

state_group = attr.ib(default=None)
rejected = attr.ib(default=False)
prev_group = attr.ib(default=None)
delta_ids = attr.ib(default=None)
prev_state_events = attr.ib(default=attr.Factory(list))
app_service = attr.ib(default=None)

_current_state_ids = attr.ib(default=None)
_prev_state_ids = attr.ib(default=None)
_prev_state_id = attr.ib(default=None)

_event_type = attr.ib(default=None)
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)
_current_state_ids = attr.ib(default=None)

@staticmethod
def with_state(
Expand All @@ -90,7 +67,6 @@ def with_state(
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
state_group=state_group,
fetching_state_deferred=defer.succeed(None),
prev_group=prev_group,
delta_ids=delta_ids,
)
Expand Down Expand Up @@ -125,7 +101,6 @@ def serialize(self, event, store):
"rejected": self.rejected,
"prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids),
"prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None,
}

Expand All @@ -141,7 +116,7 @@ def deserialize(store, input):
Returns:
EventContext
"""
context = EventContext(
context = AsyncEventContext(
richvdh marked this conversation as resolved.
Show resolved Hide resolved
# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
prev_state_id=input["prev_state_id"],
Expand All @@ -151,7 +126,6 @@ def deserialize(store, input):
prev_group=input["prev_group"],
delta_ids=_decode_state_dict(input["delta_ids"]),
rejected=input["rejected"],
prev_state_events=input["prev_state_events"],
)

app_service_id = input["app_service_id"]
Expand All @@ -170,14 +144,7 @@ def get_current_state_ids(self, store):
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""

if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store
)

yield make_deferred_yieldable(self._fetching_state_deferred)

yield self._ensure_fetched(store)
return self._current_state_ids

@defer.inlineCallbacks
Expand All @@ -190,14 +157,7 @@ def get_prev_state_ids(self, store):
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""

if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store
)

yield make_deferred_yieldable(self._fetching_state_deferred)

yield self._ensure_fetched(store)
return self._prev_state_ids

def get_cached_current_state_ids(self):
Expand All @@ -211,6 +171,44 @@ def get_cached_current_state_ids(self):

return self._current_state_ids

def _ensure_fetched(self, store):
return defer.succeed(None)


@attr.s(slots=True)
class AsyncEventContext(EventContext):
"""
A version of EventContext which fetches _current_state_ids and _prev_state_ids
from the database on demand.

Attributes:

_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
been calculated. None if we haven't started calculating yet

_event_type (str): The type of the event the context is associated with.

_event_state_key (str): The state_key of the event the context is
associated with.

_prev_state_id (str|None): If the event associated with the context is
a state event, then `_prev_state_id` is the event_id of the state
that was replaced.
"""

_prev_state_id = attr.ib(default=None)
_event_type = attr.ib(default=None)
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)

def _ensure_fetched(self, store):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store
)

return make_deferred_yieldable(self._fetching_state_deferred)

@defer.inlineCallbacks
def _fill_out_state(self, store):
"""Called to populate the _current_state_ids and _prev_state_ids
Expand All @@ -228,27 +226,6 @@ def _fill_out_state(self, store):
else:
self._prev_state_ids = self._current_state_ids

@defer.inlineCallbacks
def update_state(
self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
):
"""Replace the state in the context
"""

# We need to make sure we wait for any ongoing fetching of state
# to complete so that the updated state doesn't get clobbered
if self._fetching_state_deferred:
yield make_deferred_yieldable(self._fetching_state_deferred)

self.state_group = state_group
self._prev_state_ids = prev_state_ids
self.prev_group = prev_group
self._current_state_ids = current_state_ids
self.delta_ids = delta_ids

# We need to ensure that that we've marked as having fetched the state
self._fetching_state_deferred = defer.succeed(None)


def _encode_state_dict(state_dict):
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
Expand Down
38 changes: 19 additions & 19 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import auth_types_for_event
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.logging.context import (
make_deferred_yieldable,
Expand Down Expand Up @@ -1846,14 +1847,7 @@ def _prep_event(self, origin, event, state, auth_events, backfilled):
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c

try:
yield self.do_auth(origin, event, context, auth_events=auth_events)
except AuthError as e:
logger.warning(
"[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg
)

context.rejected = RejectedReason.AUTH_ERROR
context = yield self.do_auth(origin, event, context, auth_events=auth_events)

if not context.rejected:
yield self._check_for_soft_fail(event, state, backfilled)
Expand Down Expand Up @@ -2022,12 +2016,12 @@ def do_auth(self, origin, event, context, auth_events):

Also NB that this function adds entries to it.
Returns:
defer.Deferred[None]
defer.Deferred[EventContext]: updated context object
"""
room_version = yield self.store.get_room_version(event.room_id)

try:
yield self._update_auth_events_and_context_for_auth(
context = yield self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events
)
except Exception:
Expand All @@ -2045,7 +2039,9 @@ def do_auth(self, origin, event, context, auth_events):
event_auth.check(room_version, event, auth_events=auth_events)
except AuthError as e:
logger.warning("Failed auth resolution for %r because %s", event, e)
raise e
context.rejected = RejectedReason.AUTH_ERROR

return context

@defer.inlineCallbacks
def _update_auth_events_and_context_for_auth(
Expand All @@ -2069,7 +2065,7 @@ def _update_auth_events_and_context_for_auth(
auth_events (dict[(str, str)->synapse.events.EventBase]):

Returns:
defer.Deferred[None]
defer.Deferred[EventContext]: updated context
"""
event_auth_events = set(event.auth_event_ids())

Expand Down Expand Up @@ -2108,7 +2104,7 @@ def _update_auth_events_and_context_for_auth(
# The other side isn't around or doesn't implement the
# endpoint, so lets just bail out.
logger.info("Failed to get event auth from remote: %s", e)
return
return context

seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in remote_auth_chain]
Expand Down Expand Up @@ -2149,7 +2145,7 @@ def _update_auth_events_and_context_for_auth(

if event.internal_metadata.is_outlier():
logger.info("Skipping auth_event fetch for outlier")
return
return context

# FIXME: Assumes we have and stored all the state for all the
# prev_events
Expand All @@ -2158,7 +2154,7 @@ def _update_auth_events_and_context_for_auth(
)

if not different_auth:
return
return context

logger.info(
"auth_events refers to events which are not in our calculated auth "
Expand Down Expand Up @@ -2205,10 +2201,12 @@ def _update_auth_events_and_context_for_auth(

auth_events.update(new_state)

yield self._update_context_for_auth_events(
context = yield self._update_context_for_auth_events(
event, context, auth_events, event_key
)

return context

@defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events, event_key):
"""Update the state_ids in an event context after auth event resolution,
Expand All @@ -2217,14 +2215,16 @@ def _update_context_for_auth_events(self, event, context, auth_events, event_key
Args:
event (Event): The event we're handling the context for

context (synapse.events.snapshot.EventContext): event context
to be updated
context (synapse.events.snapshot.EventContext): initial event context

auth_events (dict[(str, str)->str]): Events to update in the event
context.

event_key ((str, str)): (type, state_key) for the current event.
this will not be included in the current_state in the context.

Returns:
Deferred[EventContext]: new event context
"""
state_updates = {
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
Expand All @@ -2249,7 +2249,7 @@ def _update_context_for_auth_events(self, event, context, auth_events, event_key
current_state_ids=current_state_ids,
)

yield context.update_state(
return EventContext.with_state(
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
Expand Down
4 changes: 3 additions & 1 deletion tests/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def setUp(self):
)

self.handler = self.homeserver.get_handlers().federation_handler
self.handler.do_auth = lambda *a, **b: succeed(True)
self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
context
)
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
pdus
Expand Down