This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use StateResolutionHandler to resolve state in persist_events #2864

Merged
merged 5 commits on Feb 13, 2018
34 changes: 30 additions & 4 deletions synapse/state.py
@@ -308,7 +308,7 @@ def resolve_state_groups_for_events(self, room_id, event_ids):
))

result = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups_ids, self._state_map_factory,
room_id, state_groups_ids, None, self._state_map_factory,
)
defer.returnValue(result)

@@ -371,7 +371,9 @@ def start_caching(self):

@defer.inlineCallbacks
@log_function
def resolve_state_groups(self, room_id, state_groups_ids, state_map_factory):
def resolve_state_groups(
self, room_id, state_groups_ids, event_map, state_map_factory,
):
"""Resolves conflicts between a set of state groups

Always generates a new state group (unless we hit the cache), so should
@@ -383,6 +385,14 @@ def resolve_state_groups(self, room_id, state_groups_ids, state_map_factory):
map from state group id to the state in that state group
(where 'state' is a map from state key to event id)

event_map(dict[str,FrozenEvent]|None):
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point for finding the state we need; any missing
events will be requested via state_map_factory.

If None, all events will be fetched via state_map_factory.

Returns:
Deferred[_StateCacheEntry]: resolved state
"""
@@ -423,6 +433,7 @@ def resolve_state_groups(self, room_id, state_groups_ids, state_map_factory):
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory(
state_groups_ids.values(),
event_map=event_map,
state_map_factory=state_map_factory,
)
else:
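
For orientation, here is a minimal sketch of the argument shapes described in the docstring above; the ids, types and values are invented for illustration and are not taken from the PR:

# Illustrative shapes for the arguments to resolve_state_groups (values invented).

# state_groups_ids: {state_group_id: {(event_type, state_key): event_id}}
state_groups_ids = {
    101: {("m.room.member", "@alice:example.org"): "$old_membership"},
    102: {("m.room.member", "@alice:example.org"): "$new_membership"},
}

# event_map: events we already hold in memory (e.g. ones currently being
# persisted), keyed by event id; None means every needed event is fetched
# via state_map_factory instead.
event_map = {
    "$new_membership": object(),  # stand-in for a FrozenEvent
}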
@@ -555,11 +566,20 @@ def _seperate(state_sets):


@defer.inlineCallbacks
def resolve_events_with_factory(state_sets, state_map_factory):
def resolve_events_with_factory(state_sets, event_map, state_map_factory):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.

event_map(dict[str,FrozenEvent]|None):
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point for finding the state we need; any missing
events will be requested via state_map_factory.

If None, all events will be fetched via state_map_factory.

state_map_factory(func): will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event.
@@ -580,12 +600,16 @@ def resolve_events_with_factory(state_sets, state_map_factory):
for event_ids in conflicted_state.itervalues()
for event_id in event_ids
)
if event_map is not None:
needed_events -= set(event_map.iterkeys())

logger.info("Asking for %d conflicted events", len(needed_events))

# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict.
# the state events which are in conflict (and those in event_map)
state_map = yield state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)

# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
@@ -597,6 +621,8 @@ def resolve_events_with_factory(state_sets, state_map_factory):

new_needed_events = set(auth_events.itervalues())
new_needed_events -= needed_events
if event_map is not None:
new_needed_events -= set(event_map.iterkeys())

logger.info("Asking for %d auth events", len(new_needed_events))

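The event_map handling added above amounts to a "use what is already in memory, fetch only the rest" pattern. A simplified, synchronous sketch of that idea follows; the helper name and the synchronous factory are assumptions for illustration, not Synapse code:

def fetch_conflicted_events(needed_event_ids, event_map, state_map_factory):
    """Sketch: events already in flight come from event_map; only the
    remainder is requested from state_map_factory (e.g. a database fetch)."""
    needed = set(needed_event_ids)
    if event_map is not None:
        needed -= set(event_map)

    state_map = dict(state_map_factory(needed))
    if event_map is not None:
        state_map.update(event_map)
    return state_map

# Example: "$a" is mid-persist and never hits the factory; "$b" is fetched.
result = fetch_conflicted_events(
    ["$a", "$b"],
    {"$a": "<in-flight event>"},
    lambda ids: {ev_id: "<event from db>" for ev_id in ids},
)
assert result == {"$a": "<in-flight event>", "$b": "<event from db>"}
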
96 changes: 40 additions & 56 deletions synapse/storage/events.py
@@ -27,7 +27,6 @@
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.state import resolve_events_with_factory
from synapse.util.caches.descriptors import cached
from synapse.types import get_domain_from_id

@@ -237,6 +236,8 @@ def __init__(self, db_conn, hs):

self._event_persist_queue = _EventPeristenceQueue()

self._state_resolution_handler = hs.get_state_resolution_handler()

def persist_events(self, events_and_contexts, backfilled=False):
"""
Write events to the database
@@ -402,6 +403,7 @@ def _persist_events(self, events_and_contexts, backfilled=False,
"Calculating state delta for room %s", room_id,
)
current_state = yield self._get_new_state_after_events(
room_id,
ev_ctx_rm, new_latest_event_ids,
)
if current_state is not None:
@@ -487,11 +489,14 @@ def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
defer.returnValue(new_latest_event_ids)

@defer.inlineCallbacks
def _get_new_state_after_events(self, events_context, new_latest_event_ids):
def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids):
"""Calculate the current state dict after adding some new events to
a room

Args:
room_id (str):
room to which the events are being added. Used for logging etc

events_context (list[(EventBase, EventContext)]):
events and contexts which are being added to the room

@@ -503,8 +508,12 @@ def _get_new_state_after_events(self, events_context, new_latest_event_ids):
None if there are no changes to the room state, or
a dict of (type, state_key) -> event_id].
"""
state_sets = []
state_groups = set()

if not new_latest_event_ids:
defer.returnValue({})

# map from state_group to ((type, key) -> event_id) state map
state_groups = {}
missing_event_ids = []
was_updated = False
for event_id in new_latest_event_ids:
@@ -515,82 +524,57 @@ def _get_new_state_after_events(self, events_context, new_latest_event_ids):
if ctx.current_state_ids is None:
raise Exception("Unknown current state")

if ctx.state_group is None:
# I don't think this can happen, but let's double-check
raise Exception(
"Context for new extremity event %s has no state "
"group" % (event_id, ),
)

# If we've already seen the state group don't bother adding
# it to the state sets again
if ctx.state_group not in state_groups:
state_sets.append(ctx.current_state_ids)
state_groups[ctx.state_group] = ctx.current_state_ids
if ctx.delta_ids or hasattr(ev, "state_key"):
was_updated = True
if ctx.state_group:
# Add this as a seen state group (if it has a state
# group)
state_groups.add(ctx.state_group)
break
else:
# If we couldn't find it, then we'll need to pull
# the state from the database
was_updated = True
missing_event_ids.append(event_id)

if not was_updated:
return

if missing_event_ids:
# Now pull out the state for any missing events from DB
event_to_groups = yield self._get_state_group_for_events(
missing_event_ids,
)

groups = set(event_to_groups.itervalues()) - state_groups
groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys())

if groups:
group_to_state = yield self._get_state_for_groups(groups)
state_sets.extend(group_to_state.itervalues())
state_groups.update(group_to_state)

if not new_latest_event_ids:
defer.returnValue({})
elif was_updated:
if len(state_sets) == 1:
# If there is only one state set, then we know what the current
# state is.
defer.returnValue(state_sets[0])
else:
# We work out the current state by passing the state sets to the
# state resolution algorithm. It may ask for some events, including
# the events we have yet to persist, so we need a slightly more
# complicated event lookup function than simply looking the events
# up in the db.

logger.info(
"Resolving state with %i state sets", len(state_sets),
)
if len(state_groups) == 1:
# If there is only one state group, then we know what the current
# state is.
defer.returnValue(state_groups.values()[0])

events_map = {ev.event_id: ev for ev, _ in events_context}

@defer.inlineCallbacks
def get_events(ev_ids):
# We get the events by first looking at the list of events we
# are trying to persist, and then fetching the rest from the DB.
db = []
to_return = {}
for ev_id in ev_ids:
ev = events_map.get(ev_id, None)
if ev:
to_return[ev_id] = ev
else:
db.append(ev_id)

if db:
evs = yield self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
to_return.update(evs)
defer.returnValue(to_return)
def get_events(ev_ids):
return self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
events_map = {ev.event_id: ev for ev, _ in events_context}
logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups, events_map, get_events
)

current_state = yield resolve_events_with_factory(
state_sets,
state_map_factory=get_events,
)
defer.returnValue(current_state)
else:
return
defer.returnValue(res.state)

@defer.inlineCallbacks
def _calculate_state_delta(self, room_id, current_state):
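
Taken together, the rewritten _get_new_state_after_events collects one state map per distinct state group and hands anything non-trivial to the shared StateResolutionHandler. A simplified, synchronous sketch of that flow, with function and parameter names invented for illustration rather than taken from Synapse:

def new_state_after_events(new_latest_event_ids, contexts_by_event_id,
                           lookup_state_for_event, resolve_state_groups):
    """Sketch of the control flow above.

    contexts_by_event_id:   {event_id: (state_group, state_map)} for the
                            events currently being persisted
    lookup_state_for_event: fallback for extremities we are not persisting,
                            returning (state_group, state_map) from storage
    resolve_state_groups:   stand-in for the StateResolutionHandler call
    """
    if not new_latest_event_ids:
        return {}

    # One entry per distinct state group, mirroring the state_groups dict above.
    state_groups = {}
    for event_id in new_latest_event_ids:
        if event_id in contexts_by_event_id:
            group, state = contexts_by_event_id[event_id]
        else:
            group, state = lookup_state_for_event(event_id)
        state_groups[group] = state

    if len(state_groups) == 1:
        # A single state group: its state is already the room's current state.
        return next(iter(state_groups.values()))

    # Otherwise defer to the shared state-resolution code path.
    return resolve_state_groups(state_groups)

Routing through StateResolutionHandler means the storage layer no longer calls resolve_events_with_factory directly (the import is removed above), so state resolution and its caching live in one place.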