Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add concept of StatelessContext #3550

Closed
wants to merge 9 commits into from
1 change: 1 addition & 0 deletions changelog.d/3550.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Lazily load state on master process when using workers to reduce DB consumption
189 changes: 152 additions & 37 deletions synapse/events/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import abc

from frozendict import frozendict

from twisted.internet import defer

from synapse.util.logcontext import make_deferred_yieldable


class EventContext(object):
class StatelessContext(object):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this feels like a funny name, given it may have a state.

Indeed I wonder if having it separate from EventContext is worthwhile? maybe combine them and have DeserializedContext just override the accessors?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately we can't just override the accessors, since they would need to return Deferreds in the DeserializedContext

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, but that's also true for EventContext. I don't understand why that means we can't override the accessors.

"""
Attributes:
current_state_ids (dict[(str, str), str]):
The current state map including the current event.
(type, state_key) -> event_id

prev_state_ids (dict[(str, str), str]):
The current state map excluding the current event.
(type, state_key) -> event_id

state_group (int|None): state group id, if the state has been stored
as a state group. This is usually only None if e.g. the event is
an outlier.
rejected (bool|str): A rejection reason if the event was rejected, else
False

push_actions (list[(str, list[object])]): list of (user_id, actions)
tuples

prev_group (int): Previously persisted state group. ``None`` for an
outlier.
delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
Expand All @@ -47,9 +40,9 @@ class EventContext(object):
the empty list?
"""

__metaclass__ = abc.ABCMeta

__slots__ = [
"current_state_ids",
"prev_state_ids",
"state_group",
"rejected",
"prev_group",
Expand All @@ -59,10 +52,6 @@ class EventContext(object):
]

def __init__(self):
# The current state including the current event
self.current_state_ids = None
# The current state excluding the current event
self.prev_state_ids = None
self.state_group = None

self.rejected = False
Expand All @@ -76,9 +65,61 @@ def __init__(self):

self.app_service = None

# The current state including the current event
self.current_state_ids = None
# The current state excluding the current event
self.prev_state_ids = None

@abc.abstractmethod
def get_current_state_ids(self, store):
    """Fetch the state map that includes the event this context is for.

    Subclasses must provide the implementation; this base version only
    raises.

    Args:
        store: handed through to the implementation.

    Returns:
        Deferred[dict[(str, str), str]|None]: Returns None if state_group
        is None, which happens when the associated event is an outlier.
    """
    raise NotImplementedError

@abc.abstractmethod
def get_prev_state_ids(self, store):
    """Fetch the state map that excludes the event this context is for.

    Subclasses must provide the implementation; this base version only
    raises.

    Args:
        store: handed through to the implementation.

    Returns:
        Deferred[dict[(str, str), str]|None]: Returns None if state_group
        is None, which happens when the associated event is an outlier.
    """
    raise NotImplementedError


class EventContext(StatelessContext):
"""This is the same as StatelessContext, except that
current_state_ids and prev_state_ids are already calculated.

Attributes:
current_state_ids (dict[(str, str), str]|None):
The current state map including the current event.
(type, state_key) -> event_id
Is None if event is an outlier

prev_state_ids (dict[(str, str), str]|None):
The current state map excluding the current event.
(type, state_key) -> event_id
Is None if event is an outlier
"""
__slots__ = [
"current_state_ids",
"prev_state_ids",
]

def __init__(self):
super(EventContext, self).__init__()

self.current_state_ids = None
self.prev_state_ids = None

def serialize(self, event):
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
deserialized by `DeserializedContext.deserialize`

Args:
event (FrozenEvent): The event that this context relates to
Expand Down Expand Up @@ -108,46 +149,120 @@ def serialize(self, event):
"app_service_id": self.app_service.id if self.app_service else None
}

def get_current_state_ids(self, store):
    """Implements StatelessContext

    The state map is already held on this context, so `store` is not
    consulted; the stored value is handed back wrapped by `defer.succeed`.
    """
    precomputed = self.current_state_ids
    return defer.succeed(precomputed)

def get_prev_state_ids(self, store):
    """Implements StatelessContext

    The state map is already held on this context, so `store` is not
    consulted; the stored value is handed back wrapped by `defer.succeed`.
    """
    precomputed = self.prev_state_ids
    return defer.succeed(precomputed)


class DeserializedContext(StatelessContext):
"""A context that comes from a serialized version of a StatelessContext.

It does not necessarily have current_state_ids and prev_state_ids precomputed
(unlike EventContext), but does cache the results of
`get_current_state_ids` and `get_prev_state_ids`.

Attributes:
_have_fetched_state (Deferred|None): Resolves when *_state_ids have
been calculated. None if we haven't started calculating yet
_prev_state_id (str|None): If set, the ID of the event that the event
associated with this context replaced in the state map
_event_type (str): The type of the event the context is associated with
_event_state_key (str|None): The state_key of the event the context is
associated with
_current_state_ids (dict[(str, str), str]|None):
The current state map including the current event.
(type, state_key) -> event_id
_prev_state_ids (dict[(str, str), str]|None):
The current state map excluding the current event.
(type, state_key) -> event_id
"""

__slots__ = [
"_current_state_ids",
"_prev_state_ids",
"_have_fetched_state",
"_prev_state_id",
"_event_type",
"_event_state_key",
]

@staticmethod
@defer.inlineCallbacks
def deserialize(store, input):
"""Converts a dict that was produced by `serialize` back into a
EventContext.
StatelessContext.

Args:
store (DataStore): Used to convert AS ID to AS object
input (dict): A dict produced by `serialize`

Returns:
EventContext
DeserializedContext
"""
context = EventContext()
context = DeserializedContext()
context.state_group = input["state_group"]
context.rejected = input["rejected"]
context.prev_group = input["prev_group"]
context.delta_ids = _decode_state_dict(input["delta_ids"])
context.prev_state_events = input["prev_state_events"]

# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
prev_state_id = input["prev_state_id"]
event_type = input["event_type"]
event_state_key = input["event_state_key"]
context._prev_state_id = input["prev_state_id"]
context._event_type = input["event_type"]
context._event_state_key = input["event_state_key"]

context.current_state_ids = yield store.get_state_ids_for_group(
context.state_group,
)
if prev_state_id and event_state_key:
context.prev_state_ids = dict(context.current_state_ids)
context.prev_state_ids[(event_type, event_state_key)] = prev_state_id
else:
context.prev_state_ids = context.current_state_ids
context._have_fetched_state = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this is going to be a Deferred, then it needs a different name.

context._current_state_ids = None
context._prev_state_ids = None

app_service_id = input["app_service_id"]
if app_service_id:
context.app_service = store.get_app_service_by_id(app_service_id)

defer.returnValue(context)
return context

@defer.inlineCallbacks
def get_current_state_ids(self, store):
"""Implements StatelessContext"""

if not self._have_fetched_state:
self._have_fetched_state = self._fill_out_state(store)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need a run_in_background here (and in get_prev_state_ids)


yield make_deferred_yieldable(self._have_fetched_state)

defer.returnValue(self.current_state_ids)

@defer.inlineCallbacks
def get_prev_state_ids(self, store):
    """Implements StatelessContext

    Gets the state map excluding the event this context relates to,
    triggering `_fill_out_state` on first use and waiting for it to
    complete.

    Args:
        store: passed to `_fill_out_state`, which calls
            `get_state_ids_for_group` on it.

    Returns:
        Deferred[dict[(str, str), str]|None]: the prev state IDs as
        assigned by `_fill_out_state`.
    """
    # Only start the load once; subsequent callers wait on the same
    # deferred.
    if self._have_fetched_state is None:
        self._have_fetched_state = self._fill_out_state(store)

    yield make_deferred_yieldable(self._have_fetched_state)

    # Return the *prev* state map. This previously returned
    # current_state_ids, i.e. the map that includes the event itself.
    # NOTE(review): __slots__ declares `_prev_state_ids` whereas
    # `_fill_out_state` assigns `prev_state_ids` -- confirm which name is
    # intended.
    defer.returnValue(self.prev_state_ids)

@defer.inlineCallbacks
def _fill_out_state(self, store):
    """Load this context's state maps from the database.

    Assigns `current_state_ids` from the store's state for `state_group`,
    then derives `prev_state_ids` from it. Does nothing when there is no
    state group.

    Args:
        store: must provide `get_state_ids_for_group`.
    """
    group = self.state_group
    if group is None:
        return

    state_ids = yield store.get_state_ids_for_group(group)
    self.current_state_ids = state_ids

    replaced_state_event = (
        self._prev_state_id and self._event_state_key is not None
    )
    if replaced_state_event:
        # Copy the current map and put the recorded previous event ID
        # back at this event's (type, state_key).
        prev_ids = dict(self.current_state_ids)
        prev_ids[(self._event_type, self._event_state_key)] = (
            self._prev_state_id
        )
        self.prev_state_ids = prev_ids
    else:
        # Nothing was replaced, so both maps are the same object.
        self.prev_state_ids = self.current_state_ids


def _encode_state_dict(state_dict):
Expand Down
8 changes: 7 additions & 1 deletion synapse/handlers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,20 @@ def ratelimit(self, requester, update=True):

@defer.inlineCallbacks
def maybe_kick_guest_users(self, event, context=None):
"""
Args:
event (FrozenEvent)
context (StatelessContext)
"""
# Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
if context:
current_state_ids = yield context.get_current_state_ids(self.store)
current_state = yield self.store.get_events(
list(context.current_state_ids.values())
list(current_state_ids.values())
)
else:
current_state = yield self.state_handler.get_current_state(
Expand Down
26 changes: 19 additions & 7 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,13 @@ def persist_and_notify_client_event(
calculated the push actions for the event, and checked auth.

This should only be run on master.

Args:
requester (Requester)
event (FrozenEvent)
context (StatelessContext)
ratelimit (bool)
extra_users (list[UserID])
"""
assert not self.config.worker_app

Expand Down Expand Up @@ -884,9 +891,11 @@ def is_inviter_member_event(e):
e.sender == event.sender
)

current_state_ids = yield context.get_current_state_ids(self.store)

state_to_include_ids = [
e_id
for k, e_id in iteritems(context.current_state_ids)
for k, e_id in current_state_ids.iteritems()
if k[0] in self.hs.config.room_invite_state_types
or k == (EventTypes.Member, event.sender)
]
Expand Down Expand Up @@ -922,8 +931,9 @@ def is_inviter_member_event(e):
)

if event.type == EventTypes.Redaction:
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids, for_verification=True,
event, prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
Expand All @@ -943,11 +953,13 @@ def is_inviter_member_event(e):
"You don't have permission to redact events"
)

if event.type == EventTypes.Create and context.prev_state_ids:
raise AuthError(
403,
"Changing the room create event is forbidden",
)
if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids(self.store)
if prev_state_ids:
raise AuthError(
403,
"Changing the room create event is forbidden",
)

(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
Expand Down
6 changes: 4 additions & 2 deletions synapse/replication/http/send_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
SynapseError,
)
from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.events.snapshot import DeserializedContext
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import Requester, UserID
from synapse.util.caches.response_cache import ResponseCache
Expand Down Expand Up @@ -136,7 +136,9 @@ def _handle_request(self, request):
event = FrozenEvent(event_dict, internal_metadata, rejected_reason)

requester = Requester.deserialize(self.store, content["requester"])
context = yield EventContext.deserialize(self.store, content["context"])
context = yield DeserializedContext.deserialize(
self.store, content["context"],
)

ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
Expand Down
Loading