diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index a15198e05d0c..003eaba893ad 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -187,6 +187,7 @@ def notify_new_events(self, current_id): prev_id for prev_id, _ in event.prev_events ], ) + destinations = set(destinations) if send_on_behalf_of is not None: # If we are sending the event on behalf of another server diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 3b7818af5c12..82dedbbc9945 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -89,7 +89,7 @@ def _handle_timeouts(self): until = self._member_typing_until.get(member, None) if not until or until <= now: logger.info("Timing out typing for: %s", member.user_id) - preserve_fn(self._stopped_typing)(member) + self._stopped_typing(member) continue # Check if we need to resend a keep alive over federation for this @@ -147,7 +147,7 @@ def started_typing(self, target_user, auth_user, room_id, timeout): # No point sending another notification defer.returnValue(None) - yield self._push_update( + self._push_update( member=member, typing=True, ) @@ -171,7 +171,7 @@ def stopped_typing(self, target_user, auth_user, room_id): member = RoomMember(room_id=room_id, user_id=target_user_id) - yield self._stopped_typing(member) + self._stopped_typing(member) @defer.inlineCallbacks def user_left_room(self, user, room_id): @@ -180,7 +180,6 @@ def user_left_room(self, user, room_id): member = RoomMember(room_id=room_id, user_id=user_id) yield self._stopped_typing(member) - @defer.inlineCallbacks def _stopped_typing(self, member): if member.user_id not in self._room_typing.get(member.room_id, set()): # No point @@ -189,16 +188,15 @@ def _stopped_typing(self, member): self._member_typing_until.pop(member, None) self._member_last_federation_poke.pop(member, None) - yield self._push_update( + self._push_update( member=member, typing=False, ) - @defer.inlineCallbacks def _push_update(self, member, typing): if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. - yield self._push_remote(member, typing) + preserve_fn(self._push_remote)(member, typing) self._push_update_local( member=member, diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index fcaf58b93b68..6cd3a843df7d 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -108,6 +108,8 @@ def __init__(self, db_conn, hs): get_current_state_ids = ( StateStore.__dict__["get_current_state_ids"] ) + get_state_group_delta = DataStore.get_state_group_delta.__func__ + _get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"] has_room_changed_since = DataStore.has_room_changed_since.__func__ get_unread_push_actions_for_user_in_range_for_http = ( diff --git a/synapse/state.py b/synapse/state.py index dffa79e4c933..a98145598fc8 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -170,9 +170,7 @@ def get_current_user_in_room(self, room_id, latest_event_ids=None): latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_user_in_room") entry = yield self.resolve_state_groups(room_id, latest_event_ids) - joined_users = yield self.store.get_joined_users_from_state( - room_id, entry.state_id, entry.state - ) + joined_users = yield self.store.get_joined_users_from_state(room_id, entry) defer.returnValue(joined_users) @defer.inlineCallbacks @@ -181,9 +179,7 @@ def get_current_hosts_in_room(self, room_id, latest_event_ids=None): latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_hosts_in_room") entry = yield self.resolve_state_groups(room_id, latest_event_ids) - joined_hosts = yield self.store.get_joined_hosts( - room_id, entry.state_id, entry.state - ) + joined_hosts = yield self.store.get_joined_hosts(room_id, entry) defer.returnValue(joined_hosts) @defer.inlineCallbacks @@ -206,12 +202,12 @@ def compute_event_context(self, event, old_state=None): Returns: synapse.events.snapshot.EventContext: """ - context = EventContext() if event.internal_metadata.is_outlier(): # If this is an outlier, then we know it shouldn't have any current # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. + context = EventContext() if old_state: context.prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state @@ -230,6 +226,7 @@ def compute_event_context(self, event, old_state=None): defer.returnValue(context) if old_state: + context = EventContext() context.prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } @@ -250,19 +247,13 @@ def compute_event_context(self, event, old_state=None): defer.returnValue(context) logger.debug("calling resolve_state_groups from compute_event_context") - if event.is_state(): - entry = yield self.resolve_state_groups( - event.room_id, [e for e, _ in event.prev_events], - event_type=event.type, - state_key=event.state_key, - ) - else: - entry = yield self.resolve_state_groups( - event.room_id, [e for e, _ in event.prev_events], - ) + entry = yield self.resolve_state_groups( + event.room_id, [e for e, _ in event.prev_events], + ) curr_state = entry.state + context = EventContext() context.prev_state_ids = curr_state if event.is_state(): context.state_group = self.store.get_next_state_group() @@ -275,11 +266,14 @@ def compute_event_context(self, event, old_state=None): context.current_state_ids = dict(context.prev_state_ids) context.current_state_ids[key] = event.event_id - context.prev_group = entry.prev_group - context.delta_ids = entry.delta_ids - if context.delta_ids is not None: - context.delta_ids = dict(context.delta_ids) - context.delta_ids[key] = event.event_id + if entry.state_group: + context.prev_group = entry.state_group + context.delta_ids = { + key: event.event_id + } + elif entry.prev_group: + context.prev_group = entry.prev_group + context.delta_ids = entry.delta_ids else: if entry.state_group is None: entry.state_group = self.store.get_next_state_group() @@ -295,7 +289,7 @@ def compute_event_context(self, event, old_state=None): @defer.inlineCallbacks @log_function - def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""): + def resolve_state_groups(self, room_id, event_ids): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -320,11 +314,13 @@ def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key="" if len(group_names) == 1: name, state_list = state_groups_ids.items().pop() + prev_group, delta_ids = yield self.store.get_state_group_delta(name) + defer.returnValue(_StateCacheEntry( state=state_list, state_group=name, - prev_group=name, - delta_ids={}, + prev_group=prev_group, + delta_ids=delta_ids, )) with (yield self.resolve_linearizer.queue(group_names)): @@ -377,11 +373,11 @@ def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key="" prev_group = None delta_ids = None - for old_group, old_ids in state_groups_ids.items(): - if not set(new_state.iterkeys()) - set(old_ids.iterkeys()): + for old_group, old_ids in state_groups_ids.iteritems(): + if not set(new_state) - set(old_ids): n_delta_ids = { k: v - for k, v in new_state.items() + for k, v in new_state.iteritems() if old_ids.get(k) != v } if not delta_ids or len(n_delta_ids) < len(delta_ids): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 0829ae5bee43..8656455f6e2f 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -18,6 +18,7 @@ from collections import namedtuple from ._base import SQLBaseStore +from synapse.util.async import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.stringutils import to_ascii @@ -392,7 +393,8 @@ def get_joined_users_from_context(self, event, context): context=context, ) - def get_joined_users_from_state(self, room_id, state_group, state_ids): + def get_joined_users_from_state(self, room_id, state_entry): + state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -401,7 +403,7 @@ def get_joined_users_from_state(self, room_id, state_group, state_ids): state_group = object() return self._get_joined_users_from_context( - room_id, state_group, state_ids, + room_id, state_group, state_entry.state, context=state_entry, ) @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, @@ -534,7 +536,8 @@ def _is_host_joined(self, room_id, host, state_group, current_state_ids): defer.returnValue(False) - def get_joined_hosts(self, room_id, state_group, state_ids): + def get_joined_hosts(self, room_id, state_entry): + state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -543,33 +546,20 @@ def get_joined_hosts(self, room_id, state_group, state_ids): state_group = object() return self._get_joined_hosts( - room_id, state_group, state_ids + room_id, state_group, state_entry.state, state_entry=state_entry, ) @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) - def _get_joined_hosts(self, room_id, state_group, current_state_ids): + # @defer.inlineCallbacks + def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None - joined_hosts = set() - for etype, state_key in current_state_ids: - if etype == EventTypes.Member: - try: - host = get_domain_from_id(state_key) - except: - logger.warn("state_key not user_id: %s", state_key) - continue - - if host in joined_hosts: - continue - - event_id = current_state_ids[(etype, state_key)] - event = yield self.get_event(event_id, allow_none=True) - if event and event.content["membership"] == Membership.JOIN: - joined_hosts.add(intern_string(host)) + cache = self._get_joined_hosts_cache(room_id) + joined_hosts = yield cache.get_destinations(state_entry) defer.returnValue(joined_hosts) @@ -647,3 +637,75 @@ def add_membership_profile_txn(txn): yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) defer.returnValue(result) + + @cached(max_entries=10000, iterable=True) + def _get_joined_hosts_cache(self, room_id): + return _JoinedHostsCache(self, room_id) + + +class _JoinedHostsCache(object): + """Cache for joined hosts in a room that is optimised to handle updates + via state deltas. + """ + + def __init__(self, store, room_id): + self.store = store + self.room_id = room_id + + self.hosts_to_joined_users = {} + + self.state_group = object() + + self.linearizer = Linearizer("_JoinedHostsCache") + + self._len = 0 + + @defer.inlineCallbacks + def get_destinations(self, state_entry): + """Get set of destinations for a state entry + + Args: + state_entry(synapse.state._StateCacheEntry) + """ + if state_entry.state_group == self.state_group: + defer.returnValue(frozenset(self.hosts_to_joined_users)) + + with (yield self.linearizer.queue(())): + if state_entry.state_group == self.state_group: + pass + elif state_entry.prev_group == self.state_group: + for (typ, state_key), event_id in state_entry.delta_ids.iteritems(): + if typ != EventTypes.Member: + continue + + host = intern_string(get_domain_from_id(state_key)) + user_id = state_key + known_joins = self.hosts_to_joined_users.setdefault(host, set()) + + event = yield self.store.get_event(event_id) + if event.membership == Membership.JOIN: + known_joins.add(user_id) + else: + known_joins.discard(user_id) + + if not known_joins: + self.hosts_to_joined_users.pop(host, None) + else: + joined_users = yield self.store.get_joined_users_from_state( + self.room_id, state_entry, + ) + + self.hosts_to_joined_users = {} + for user_id in joined_users: + host = intern_string(get_domain_from_id(user_id)) + self.hosts_to_joined_users.setdefault(host, set()).add(user_id) + + if state_entry.state_group: + self.state_group = state_entry.state_group + else: + self.state_group = object() + self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues()) + defer.returnValue(frozenset(self.hosts_to_joined_users)) + + def __len__(self): + return self._len diff --git a/synapse/storage/state.py b/synapse/storage/state.py index a7c3d401d405..c3eecbe8240d 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -98,6 +98,45 @@ def _get_current_state_ids_txn(txn): _get_current_state_ids_txn, ) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + def _get_state_group_delta_txn(txn): + prev_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={ + "state_group": state_group, + }, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return None, None + + delta_ids = self._simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={ + "state_group": state_group, + }, + retcols=("type", "state_key", "event_id",) + ) + + return prev_group, { + (row["type"], row["state_key"]): row["event_id"] + for row in delta_ids + } + return self.runInteraction( + "get_state_group_delta", + _get_state_group_delta_txn, + ) + @defer.inlineCallbacks def get_state_groups_ids(self, room_id, event_ids): if not event_ids: diff --git a/tests/test_state.py b/tests/test_state.py index 6454f994e325..feb84f3d48b8 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -143,6 +143,7 @@ def setUp(self): "add_event_hashes", "get_events", "get_next_state_group", + "get_state_group_delta", ] ) hs = Mock(spec_set=[ @@ -154,6 +155,7 @@ def setUp(self): hs.get_auth.return_value = Auth(hs) self.store.get_next_state_group.side_effect = Mock + self.store.get_state_group_delta.return_value = (None, None) self.state = StateHandler(hs) self.event_id = 0