Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add concept of StatelessContext, take 4. #3579

Merged
merged 7 commits into from
Jul 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/3579.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Lazily load state on master process when using workers to reduce DB consumption
6 changes: 4 additions & 2 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ def __init__(self, hs):

@defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.compute_auth_events(
event, context.prev_state_ids, for_verification=True,
event, prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
Expand Down Expand Up @@ -544,7 +545,8 @@ def is_server_admin(self, user):

@defer.inlineCallbacks
def add_auth_events(self, builder, context):
auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_ids = yield self.compute_auth_events(builder, prev_state_ids)

auth_events_entries = yield self.store.add_event_hashes(
auth_ids
Expand Down
175 changes: 134 additions & 41 deletions synapse/events/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,12 @@

from twisted.internet import defer

from synapse.util.logcontext import make_deferred_yieldable, run_in_background


class EventContext(object):
"""
Attributes:
current_state_ids (dict[(str, str), str]):
The current state map including the current event.
(type, state_key) -> event_id

prev_state_ids (dict[(str, str), str]):
The current state map excluding the current event.
(type, state_key) -> event_id

state_group (int|None): state group id, if the state has been stored
as a state group. This is usually only None if e.g. the event is
an outlier.
Expand All @@ -47,36 +41,74 @@ class EventContext(object):

prev_state_events (?): XXX: is this ever set to anything other than
the empty list?

_current_state_ids (dict[(str, str), str]|None):
The current state map including the current event. None if outlier
or we haven't fetched the state from DB yet.
(type, state_key) -> event_id

_prev_state_ids (dict[(str, str), str]|None):
The current state map excluding the current event. None if outlier
or we haven't fetched the state from DB yet.
(type, state_key) -> event_id

_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
been calculated. None if we haven't started calculating yet

_event_type (str): The type of the event the context is associated with.
Only set when state has not been fetched yet.

_event_state_key (str|None): The state_key of the event the context is
associated with. Only set when state has not been fetched yet.

_prev_state_id (str|None): If the event associated with the context is
a state event, then `_prev_state_id` is the event_id of the state
that was replaced.
Only set when state has not been fetched yet.
"""

__slots__ = [
"current_state_ids",
"prev_state_ids",
"state_group",
"rejected",
"prev_group",
"delta_ids",
"prev_state_events",
"app_service",
"_current_state_ids",
"_prev_state_ids",
"_prev_state_id",
"_event_type",
"_event_state_key",
"_fetching_state_deferred",
]

def __init__(self, state_group, current_state_ids, prev_state_ids,
prev_group=None, delta_ids=None):
@staticmethod
def with_state(state_group, current_state_ids, prev_state_ids,
prev_group=None, delta_ids=None):
context = EventContext()

# The current state including the current event
self.current_state_ids = current_state_ids
context._current_state_ids = current_state_ids
# The current state excluding the current event
self.prev_state_ids = prev_state_ids
self.state_group = state_group
context._prev_state_ids = prev_state_ids
context.state_group = state_group

context._prev_state_id = None
context._event_type = None
context._event_state_key = None
context._fetching_state_deferred = defer.succeed(None)

# A previously persisted state group and a delta between that
# and this state.
self.prev_group = prev_group
self.delta_ids = delta_ids
context.prev_group = prev_group
context.delta_ids = delta_ids

self.prev_state_events = []
context.prev_state_events = []

self.rejected = False
self.app_service = None
context.rejected = False
context.app_service = None

return context

def serialize(self, event):
"""Converts self to a type that can be serialized as JSON, and then
Expand Down Expand Up @@ -123,30 +155,17 @@ def deserialize(store, input):
Returns:
EventContext
"""
context = EventContext()

# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
prev_state_id = input["prev_state_id"]
event_type = input["event_type"]
event_state_key = input["event_state_key"]

state_group = input["state_group"]
context._prev_state_id = input["prev_state_id"]
context._event_type = input["event_type"]
context._event_state_key = input["event_state_key"]

current_state_ids = yield store.get_state_ids_for_group(
state_group,
)
if prev_state_id and event_state_key:
prev_state_ids = dict(current_state_ids)
prev_state_ids[(event_type, event_state_key)] = prev_state_id
else:
prev_state_ids = current_state_ids

context = EventContext(
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
prev_group=input["prev_group"],
delta_ids=_decode_state_dict(input["delta_ids"]),
)
context.state_group = input["state_group"]
context.prev_group = input["prev_group"]
context.delta_ids = _decode_state_dict(input["delta_ids"])

context.rejected = input["rejected"]
context.prev_state_events = input["prev_state_events"]
Expand All @@ -157,6 +176,80 @@ def deserialize(store, input):

defer.returnValue(context)

@defer.inlineCallbacks
def get_current_state_ids(self, store):
"""Gets the current state IDs

Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
is None, which happens when the associated event is an outlier.
"""

if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store,
)

yield make_deferred_yieldable(self._fetching_state_deferred)

defer.returnValue(self._current_state_ids)

@defer.inlineCallbacks
def get_prev_state_ids(self, store):
"""Gets the prev state IDs

Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
is None, which happens when the associated event is an outlier.
"""

if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store,
)

yield make_deferred_yieldable(self._fetching_state_deferred)

defer.returnValue(self._prev_state_ids)

@defer.inlineCallbacks
def _fill_out_state(self, store):
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
return

self._current_state_ids = yield store.get_state_ids_for_group(
self.state_group,
)
if self._prev_state_id and self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)

key = (self._event_type, self._event_state_key)
self._prev_state_ids[key] = self._prev_state_id
else:
self._prev_state_ids = self._current_state_ids

@defer.inlineCallbacks
def update_state(self, state_group, prev_state_ids, current_state_ids,
delta_ids):
"""Replace the state in the context
"""

# We need to make sure we wait for any ongoing fetching of state
# to complete so that the updated state doesn't get clobbered
if self._fetching_state_deferred:
yield make_deferred_yieldable(self._fetching_state_deferred)

self.state_group = state_group
self._prev_state_ids = prev_state_ids
self._current_state_ids = current_state_ids
self.delta_ids = delta_ids

# We need to ensure that that we've marked as having fetched the state
self._fetching_state_deferred = defer.succeed(None)


def _encode_state_dict(state_dict):
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,9 @@ def maybe_kick_guest_users(self, event, context=None):
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
if context:
current_state_ids = yield context.get_current_state_ids(self.store)
current_state = yield self.store.get_events(
list(context.current_state_ids.values())
list(current_state_ids.values())
)
else:
current_state = yield self.state_handler.get_current_state(
Expand Down
55 changes: 39 additions & 16 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,10 @@ def _process_received_pdu(self, origin, pdu, state, auth_chain):
# joined the room. Don't bother if the user is just
# changing their profile info.
newly_joined = True
prev_state_id = context.prev_state_ids.get(

prev_state_ids = yield context.get_prev_state_ids(self.store)

prev_state_id = prev_state_ids.get(
(event.type, event.state_key)
)
if prev_state_id:
Expand Down Expand Up @@ -1106,10 +1109,12 @@ def on_send_join_request(self, origin, pdu):
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)

state_ids = list(context.prev_state_ids.values())
prev_state_ids = yield context.get_prev_state_ids(self.store)

state_ids = list(prev_state_ids.values())
auth_chain = yield self.store.get_auth_chain(state_ids)

state = yield self.store.get_events(list(context.prev_state_ids.values()))
state = yield self.store.get_events(list(prev_state_ids.values()))

defer.returnValue({
"state": list(state.values()),
Expand Down Expand Up @@ -1635,8 +1640,9 @@ def _prep_event(self, origin, event, state=None, auth_events=None):
)

if not auth_events:
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids, for_verification=True,
event, prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
Expand Down Expand Up @@ -1876,9 +1882,10 @@ def do_auth(self, origin, event, context, auth_events):
break

if do_resolution:
prev_state_ids = yield context.get_prev_state_ids(self.store)
# 1. Get what we think is the auth chain.
auth_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids
event, prev_state_ids
)
local_auth_chain = yield self.store.get_auth_chain(
auth_ids, include_given=True
Expand Down Expand Up @@ -1968,21 +1975,35 @@ def _update_context_for_auth_events(self, event, context, auth_events,
k: a.event_id for k, a in iteritems(auth_events)
if k != event_key
}
context.current_state_ids = dict(context.current_state_ids)
context.current_state_ids.update(state_updates)
current_state_ids = yield context.get_current_state_ids(self.store)
current_state_ids = dict(current_state_ids)

current_state_ids.update(state_updates)

if context.delta_ids is not None:
context.delta_ids = dict(context.delta_ids)
context.delta_ids.update(state_updates)
context.prev_state_ids = dict(context.prev_state_ids)
context.prev_state_ids.update({
delta_ids = dict(context.delta_ids)
delta_ids.update(state_updates)

prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = dict(prev_state_ids)

prev_state_ids.update({
k: a.event_id for k, a in iteritems(auth_events)
})
context.state_group = yield self.store.store_state_group(

state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=context.prev_group,
delta_ids=context.delta_ids,
current_state_ids=context.current_state_ids,
delta_ids=delta_ids,
current_state_ids=current_state_ids,
)

yield context.update_state(
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
delta_ids=delta_ids,
)

@defer.inlineCallbacks
Expand Down Expand Up @@ -2222,7 +2243,8 @@ def add_display_name_to_third_party_invite(self, event_dict, event, context):
event.content["third_party_invite"]["signed"]["token"]
)
original_invite = None
original_invite_id = context.prev_state_ids.get(key)
prev_state_ids = yield context.get_prev_state_ids(self.store)
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = yield self.store.get_event(
original_invite_id, allow_none=True
Expand Down Expand Up @@ -2264,7 +2286,8 @@ def _check_signature(self, event, context):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]

invite_event_id = context.prev_state_ids.get(
prev_state_ids = yield context.get_prev_state_ids(self.store)
invite_event_id = prev_state_ids.get(
(EventTypes.ThirdPartyInvite, token,)
)

Expand Down
Loading