From d40442b732628784e8e8fc6671742226513c390a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 19 Apr 2021 14:43:10 +0100 Subject: [PATCH 1/8] Only store data in cache, not smart objects --- synapse/storage/databases/main/roommember.py | 139 ++++++++++--------- 1 file changed, 72 insertions(+), 67 deletions(-) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index fd525dce65c3..4446e2889c9a 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -13,7 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) + +import attr from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase @@ -33,7 +45,12 @@ ProfileInfo, RoomsForUser, ) -from synapse.types import Collection, PersistedEventPosition, get_domain_from_id +from synapse.types import ( + Collection, + PersistedEventPosition, + StateMap, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -53,6 +70,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) + self._joined_host_linearizer = Linearizer("_JoinedHostsCache") + # Is the current_state_events.membership up to date? Or is the # background update still running? self._current_state_events_membership_up_to_date = False @@ -730,19 +749,62 @@ async def get_joined_hosts(self, room_id: str, state_entry): @cached(num_args=2, max_entries=10000, iterable=True) async def _get_joined_hosts( - self, room_id, state_group, current_state_ids, state_entry - ): + self, + room_id: str, + state_group: int, + current_state_ids: StateMap[str], + state_entry: "_StateCacheEntry", + ) -> FrozenSet[str]: # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. assert state_group is not None cache = await self._get_joined_hosts_cache(room_id) - return await cache.get_destinations(state_entry) + + if state_entry.state_group == cache.state_group: + return frozenset(cache.hosts_to_joined_users) + + with (await self._joined_host_linearizer.queue(room_id)): + if state_entry.state_group == cache.state_group: + pass + elif state_entry.prev_group == cache.state_group: + for (typ, state_key), event_id in state_entry.delta_ids.items(): + if typ != EventTypes.Member: + continue + + host = intern_string(get_domain_from_id(state_key)) + user_id = state_key + known_joins = cache.hosts_to_joined_users.setdefault(host, set()) + + event = await self.get_event(event_id) + if event.membership == Membership.JOIN: + known_joins.add(user_id) + else: + known_joins.discard(user_id) + + if not known_joins: + cache.hosts_to_joined_users.pop(host, None) + else: + joined_users = await self.get_joined_users_from_state( + room_id, state_entry + ) + + cache.hosts_to_joined_users = {} + for user_id in joined_users: + host = intern_string(get_domain_from_id(user_id)) + cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) + + if state_entry.state_group: + cache.state_group = state_entry.state_group + else: + cache.state_group = object() + + return frozenset(cache.hosts_to_joined_users) @cached(max_entries=10000) def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": - return _JoinedHostsCache(self, room_id) + return _JoinedHostsCache() @cached(num_args=2) async def did_forget(self, user_id: str, room_id: str) -> bool: @@ -1052,71 +1114,14 @@ def f(txn): await self.db_pool.runInteraction("forget_membership", f) +@attr.s(slots=True) class _JoinedHostsCache: """Cache for joined hosts in a room that is optimised to handle updates via state deltas. """ - def __init__(self, store, room_id): - self.store = store - self.room_id = room_id - - self.hosts_to_joined_users = {} - - self.state_group = object() - - self.linearizer = Linearizer("_JoinedHostsCache") - - self._len = 0 - - async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]: - """Get set of destinations for a state entry - - Args: - state_entry - - Returns: - The destinations as a set. - """ - if state_entry.state_group == self.state_group: - return frozenset(self.hosts_to_joined_users) - - with (await self.linearizer.queue(())): - if state_entry.state_group == self.state_group: - pass - elif state_entry.prev_group == self.state_group: - for (typ, state_key), event_id in state_entry.delta_ids.items(): - if typ != EventTypes.Member: - continue - - host = intern_string(get_domain_from_id(state_key)) - user_id = state_key - known_joins = self.hosts_to_joined_users.setdefault(host, set()) - - event = await self.store.get_event(event_id) - if event.membership == Membership.JOIN: - known_joins.add(user_id) - else: - known_joins.discard(user_id) - - if not known_joins: - self.hosts_to_joined_users.pop(host, None) - else: - joined_users = await self.store.get_joined_users_from_state( - self.room_id, state_entry - ) - - self.hosts_to_joined_users = {} - for user_id in joined_users: - host = intern_string(get_domain_from_id(user_id)) - self.hosts_to_joined_users.setdefault(host, set()).add(user_id) - - if state_entry.state_group: - self.state_group = state_entry.state_group - else: - self.state_group = object() - self._len = sum(len(v) for v in self.hosts_to_joined_users.values()) - return frozenset(self.hosts_to_joined_users) + hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict) + state_group = attr.ib(type=Union[object, int], factory=object) def __len__(self): - return self._len + return sum(len(v) for v in self.hosts_to_joined_users.values()) From 7ade200d7a89853d7a9a0e7565b575c59b748b9b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 19 Apr 2021 14:53:06 +0100 Subject: [PATCH 2/8] Add some comments to _get_joined_hosts --- synapse/storage/databases/main/roommember.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4446e2889c9a..eac79c108fd9 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -760,15 +760,26 @@ async def _get_joined_hosts( # with a state_group of None are likely to be different. assert state_group is not None + # We use a secondary cache of previous work to allow us to build up the + # joined hosts for the given state group based on previous state groups. + # + # We cache one object per room containing the results of the last state + # group we got joined hosts for, with the idea being that generally + # `get_joined_hosts` with the "current" state group for the room. cache = await self._get_joined_hosts_cache(room_id) + # If the state group in the cache matches then its a no-op. if state_entry.state_group == cache.state_group: return frozenset(cache.hosts_to_joined_users) + # Since we'll mutate the cache we need to lock. with (await self._joined_host_linearizer.queue(room_id)): if state_entry.state_group == cache.state_group: + # Same state group, so nothing to do pass elif state_entry.prev_group == cache.state_group: + # The cache work is for the previous state group, so we work out + # the delta. for (typ, state_key), event_id in state_entry.delta_ids.items(): if typ != EventTypes.Member: continue @@ -786,6 +797,8 @@ async def _get_joined_hosts( if not known_joins: cache.hosts_to_joined_users.pop(host, None) else: + # The cache doesn't match the state group or prev state group, + # so we calculate the result from first principles. joined_users = await self.get_joined_users_from_state( room_id, state_entry ) @@ -1120,6 +1133,7 @@ class _JoinedHostsCache: via state deltas. """ + # Dict of host to the set of their users in the room at the state group. hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict) state_group = attr.ib(type=Union[object, int], factory=object) From 7bb13fed3415d676aefcc500499c1089373981f2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 19 Apr 2021 15:34:54 +0100 Subject: [PATCH 3/8] Only cache the data for push rules cache --- synapse/push/bulk_push_rule_evaluator.py | 136 +++++++++++++---------- 1 file changed, 77 insertions(+), 59 deletions(-) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 50b470c310b1..97a5acd51378 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -106,6 +106,8 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() + self._rules_linearizer = Linearizer(name="rules_for_room") + self.room_push_rule_cache_metrics = register_cache( "cache", "room_push_rule_cache", @@ -123,7 +125,16 @@ async def _get_rules_for_event( dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = self._get_rules_for_room(room_id) + + rules_for_room_data = self._get_rules_for_room(room_id) + rules_for_room = RulesForRoom( + hs=self.hs, + room_id=room_id, + rules_for_room_cache=self._get_rules_for_room.cache, + room_push_rule_cache_metrics=self.room_push_rule_cache_metrics, + linearizer=self._rules_linearizer, + cached_data=rules_for_room_data, + ) rules_by_user = await rules_for_room.get_rules(event, context) @@ -142,17 +153,12 @@ async def _get_rules_for_event( return rules_by_user @lru_cache() - def _get_rules_for_room(self, room_id: str) -> "RulesForRoom": + def _get_rules_for_room(self, room_id: str) -> "RulesForRoomData": """Get the current RulesForRoom object for the given room id""" # It's important that RulesForRoom gets added to self._get_rules_for_room.cache # before any lookup methods get called on it as otherwise there may be # a race if invalidate_all gets called (which assumes its in the cache) - return RulesForRoom( - self.hs, - room_id, - self._get_rules_for_room.cache, - self.room_push_rule_cache_metrics, - ) + return RulesForRoomData() async def _get_power_levels_and_sender_level( self, event: EventBase, context: EventContext @@ -282,6 +288,34 @@ def _condition_checker( return True +@attr.s(slots=True) +class RulesForRoomData: + # event_id -> (user_id, state) + member_map = attr.ib(type=Dict[str, Tuple[str, str]], factory=dict) + # user_id -> rules + rules_by_user = attr.ib(type=Dict[str, List[Dict[str, dict]]], factory=dict) + + # The last state group we updated the caches for. If the state_group of + # a new event comes along, we know that we can just return the cached + # result. + # On invalidation of the rules themselves (if the user changes them), + # we invalidate everything and set state_group to `object()` + state_group = attr.ib(type=Union[object, int], factory=object) + + # A sequence number to keep track of when we're allowed to update the + # cache. We bump the sequence number when we invalidate the cache. If + # the sequence number changes while we're calculating stuff we should + # not update the cache with it. + sequence = attr.ib(type=int, default=0) + + # A cache of user_ids that we *know* aren't interesting, e.g. user_ids + # owned by AS's, or remote users, etc. (I.e. users we will never need to + # calculate push for) + # These never need to be invalidated as we will never set up push for + # them. + uninteresting_user_set = attr.ib(type=Set[str], factory=set) + + class RulesForRoom: """Caches push rules for users in a room. @@ -295,6 +329,8 @@ def __init__( room_id: str, rules_for_room_cache: LruCache, room_push_rule_cache_metrics: CacheMetric, + linearizer: Linearizer, + cached_data: RulesForRoomData, ): """ Args: @@ -303,38 +339,16 @@ def __init__( rules_for_room_cache: The cache object that caches these RoomsForUser objects. room_push_rule_cache_metrics: The metrics object + data """ self.room_id = room_id self.is_mine_id = hs.is_mine_id self.store = hs.get_datastore() self.room_push_rule_cache_metrics = room_push_rule_cache_metrics - self.linearizer = Linearizer(name="rules_for_room") - - # event_id -> (user_id, state) - self.member_map = {} # type: Dict[str, Tuple[str, str]] - # user_id -> rules - self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]] - - # The last state group we updated the caches for. If the state_group of - # a new event comes along, we know that we can just return the cached - # result. - # On invalidation of the rules themselves (if the user changes them), - # we invalidate everything and set state_group to `object()` - self.state_group = object() - - # A sequence number to keep track of when we're allowed to update the - # cache. We bump the sequence number when we invalidate the cache. If - # the sequence number changes while we're calculating stuff we should - # not update the cache with it. - self.sequence = 0 - - # A cache of user_ids that we *know* aren't interesting, e.g. user_ids - # owned by AS's, or remote users, etc. (I.e. users we will never need to - # calculate push for) - # These never need to be invalidated as we will never set up push for - # them. - self.uninteresting_user_set = set() # type: Set[str] + self.linearizer = linearizer + + self.data = cached_data # We need to be clever on the invalidating caches callbacks, as # otherwise the invalidation callback holds a reference to the object, @@ -352,25 +366,25 @@ async def get_rules( """ state_group = context.state_group - if state_group and self.state_group == state_group: + if state_group and self.data.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() - return self.rules_by_user + return self.data.rules_by_user - with (await self.linearizer.queue(())): - if state_group and self.state_group == state_group: + with (await self.linearizer.queue(event.room_id)): + if state_group and self.data.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() - return self.rules_by_user + return self.data.rules_by_user self.room_push_rule_cache_metrics.inc_misses() ret_rules_by_user = {} missing_member_event_ids = {} - if state_group and self.state_group == context.prev_group: + if state_group and self.data.state_group == context.prev_group: # If we have a simple delta then we can reuse most of the previous # results. - ret_rules_by_user = self.rules_by_user + ret_rules_by_user = self.data.rules_by_user current_state_ids = context.delta_ids push_rules_delta_state_cache_metric.inc_hits() @@ -393,24 +407,24 @@ async def get_rules( if typ != EventTypes.Member: continue - if user_id in self.uninteresting_user_set: + if user_id in self.data.uninteresting_user_set: continue if not self.is_mine_id(user_id): - self.uninteresting_user_set.add(user_id) + self.data.uninteresting_user_set.add(user_id) continue if self.store.get_if_app_services_interested_in_user(user_id): - self.uninteresting_user_set.add(user_id) + self.data.uninteresting_user_set.add(user_id) continue event_id = current_state_ids[key] - res = self.member_map.get(event_id, None) + res = self.data.member_map.get(event_id, None) if res: user_id, state = res if state == Membership.JOIN: - rules = self.rules_by_user.get(user_id, None) + rules = self.data.rules_by_user.get(user_id, None) if rules: ret_rules_by_user[user_id] = rules continue @@ -430,7 +444,7 @@ async def get_rules( else: # The push rules didn't change but lets update the cache anyway self.update_cache( - self.sequence, + self.data.sequence, members={}, # There were no membership changes rules_by_user=ret_rules_by_user, state_group=state_group, @@ -461,7 +475,7 @@ async def _update_rules_with_member_event_ids( for. Used when updating the cache. event: The event we are currently computing push rules for. """ - sequence = self.sequence + sequence = self.data.sequence rows = await self.store.get_membership_from_event_ids(member_event_ids.values()) @@ -507,17 +521,17 @@ def invalidate_all(self) -> None: # GC'd if it gets dropped from the rules_to_user cache. Instead use # `self.invalidate_all_cb` logger.debug("Invalidating RulesForRoom for %r", self.room_id) - self.sequence += 1 - self.state_group = object() - self.member_map = {} - self.rules_by_user = {} + self.data.sequence += 1 + self.data.state_group = object() + self.data.member_map = {} + self.data.rules_by_user = {} push_rules_invalidation_counter.inc() def update_cache(self, sequence, members, rules_by_user, state_group) -> None: - if sequence == self.sequence: - self.member_map.update(members) - self.rules_by_user = rules_by_user - self.state_group = state_group + if sequence == self.data.sequence: + self.data.member_map.update(members) + self.data.rules_by_user = rules_by_user + self.data.state_group = state_group @attr.attrs(slots=True, frozen=True) @@ -535,6 +549,10 @@ class _Invalidation: room_id = attr.ib(type=str) def __call__(self) -> None: - rules = self.cache.get(self.room_id, None, update_metrics=False) - if rules: - rules.invalidate_all() + rules_data = self.cache.get(self.room_id, None, update_metrics=False) + if rules_data: + rules_data.sequence += 1 + rules_data.state_group = object() + rules_data.member_map = {} + rules_data.rules_by_user = {} + push_rules_invalidation_counter.inc() From 4acb2a7f750dec8628b64ba1f2ef0551f59f6120 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 19 Apr 2021 15:39:46 +0100 Subject: [PATCH 4/8] Newsfile --- changelog.d/9845.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/9845.misc diff --git a/changelog.d/9845.misc b/changelog.d/9845.misc new file mode 100644 index 000000000000..875dd6d13156 --- /dev/null +++ b/changelog.d/9845.misc @@ -0,0 +1 @@ +Only store the raw data in the in-memory caches, rather than objects that include references to e.g. the data stores. From 0c159093cf020e9e2757e2ac3a34f8a1f6bf042a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 21 Apr 2021 13:52:27 +0100 Subject: [PATCH 5/8] Apply suggestions from code review Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- synapse/push/bulk_push_rule_evaluator.py | 9 +++++---- synapse/storage/databases/main/roommember.py | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 97a5acd51378..bd618af9c518 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -154,8 +154,8 @@ async def _get_rules_for_event( @lru_cache() def _get_rules_for_room(self, room_id: str) -> "RulesForRoomData": - """Get the current RulesForRoom object for the given room id""" - # It's important that RulesForRoom gets added to self._get_rules_for_room.cache + """Get the current RulesForRoomData object for the given room id""" + # It's important that the RulesForRoomData object gets added to self._get_rules_for_room.cache # before any lookup methods get called on it as otherwise there may be # a race if invalidate_all gets called (which assumes its in the cache) return RulesForRoomData() @@ -339,7 +339,8 @@ def __init__( rules_for_room_cache: The cache object that caches these RoomsForUser objects. room_push_rule_cache_metrics: The metrics object - data + linearizer: + cached_data: """ self.room_id = room_id self.is_mine_id = hs.is_mine_id @@ -371,7 +372,7 @@ async def get_rules( self.room_push_rule_cache_metrics.inc_hits() return self.data.rules_by_user - with (await self.linearizer.queue(event.room_id)): + with (await self.linearizer.queue(self.room_id)): if state_group and self.data.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index eac79c108fd9..0a025194f14e 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -768,7 +768,7 @@ async def _get_joined_hosts( # `get_joined_hosts` with the "current" state group for the room. cache = await self._get_joined_hosts_cache(room_id) - # If the state group in the cache matches then its a no-op. + # If the state group in the cache matches, we already have the data we need. if state_entry.state_group == cache.state_group: return frozenset(cache.hosts_to_joined_users) @@ -778,7 +778,7 @@ async def _get_joined_hosts( # Same state group, so nothing to do pass elif state_entry.prev_group == cache.state_group: - # The cache work is for the previous state group, so we work out + # The cached work is for the previous state group, so we work out # the delta. for (typ, state_key), event_id in state_entry.delta_ids.items(): if typ != EventTypes.Member: From 9fb443d6c8aa1c3429f937bdca36d822e295b385 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 22 Apr 2021 10:35:44 +0100 Subject: [PATCH 6/8] Review comments for joined hosts cache --- synapse/storage/databases/main/roommember.py | 21 ++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 0a025194f14e..c0eb4a23df5a 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -70,6 +70,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) + # Used by `_get_joined_hosts` to ensure only one thing mutates the cache + # at a time. Keyed by room_id. self._joined_host_linearizer = Linearizer("_JoinedHostsCache") # Is the current_state_events.membership up to date? Or is the @@ -759,13 +761,16 @@ async def _get_joined_hosts( # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. assert state_group is not None + assert state_entry.state_group is None or state_entry.state_group == state_group # We use a secondary cache of previous work to allow us to build up the # joined hosts for the given state group based on previous state groups. # # We cache one object per room containing the results of the last state - # group we got joined hosts for, with the idea being that generally - # `get_joined_hosts` with the "current" state group for the room. + # group we got joined hosts for. The idea being that generally + # `get_joined_hosts` is called with the "current" state group for the + # room, and so consecutive calls will be for consecutive state groups + # which point to the previous state group. cache = await self._get_joined_hosts_cache(room_id) # If the state group in the cache matches, we already have the data we need. @@ -775,7 +780,9 @@ async def _get_joined_hosts( # Since we'll mutate the cache we need to lock. with (await self._joined_host_linearizer.queue(room_id)): if state_entry.state_group == cache.state_group: - # Same state group, so nothing to do + # Same state group, so nothing to do. We've already checked for + # this above, but the cache may have changed while waiting on + # the lock. pass elif state_entry.prev_group == cache.state_group: # The cached work is for the previous state group, so we work out @@ -1129,12 +1136,14 @@ def f(txn): @attr.s(slots=True) class _JoinedHostsCache: - """Cache for joined hosts in a room that is optimised to handle updates - via state deltas. - """ + """The cached data used by the `_get_joined_hosts_cache`.""" # Dict of host to the set of their users in the room at the state group. hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict) + + # The state group `hosts_to_joined_users` is derived from. Will be an object + # if the class is newly created or if the state is not based on a state + # group. state_group = attr.ib(type=Union[object, int], factory=object) def __len__(self): From 2e6fd580e8356ca410a9216746f63b85768284e6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 22 Apr 2021 11:12:30 +0100 Subject: [PATCH 7/8] Review comments for bulk push rules cache --- synapse/push/bulk_push_rule_evaluator.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index bd618af9c518..617bf2a7d3f5 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -106,6 +106,8 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() + # Used by `RulesForRoom` to ensure only one thing mutates the cache at a + # time. Keyed off room_id. self._rules_linearizer = Linearizer(name="rules_for_room") self.room_push_rule_cache_metrics = register_cache( @@ -347,6 +349,8 @@ def __init__( self.store = hs.get_datastore() self.room_push_rule_cache_metrics = room_push_rule_cache_metrics + # Used to ensure only one thing mutates the cache at a time. Keyed off + # room_id. self.linearizer = linearizer self.data = cached_data @@ -516,18 +520,6 @@ async def _update_rules_with_member_event_ids( self.update_cache(sequence, members, ret_rules_by_user, state_group) - def invalidate_all(self) -> None: - # Note: Don't hand this function directly to an invalidation callback - # as it keeps a reference to self and will stop this instance from being - # GC'd if it gets dropped from the rules_to_user cache. Instead use - # `self.invalidate_all_cb` - logger.debug("Invalidating RulesForRoom for %r", self.room_id) - self.data.sequence += 1 - self.data.state_group = object() - self.data.member_map = {} - self.data.rules_by_user = {} - push_rules_invalidation_counter.inc() - def update_cache(self, sequence, members, rules_by_user, state_group) -> None: if sequence == self.data.sequence: self.data.member_map.update(members) From b2686a50af5f9beb88b7f6678dbe6e2009d7211d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 23 Apr 2021 10:25:11 +0100 Subject: [PATCH 8/8] Words --- synapse/push/bulk_push_rule_evaluator.py | 16 ++++++++++++++-- synapse/storage/databases/main/roommember.py | 15 +++++++++------ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 617bf2a7d3f5..350646f45888 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -292,6 +292,12 @@ def _condition_checker( @attr.s(slots=True) class RulesForRoomData: + """The data stored in the cache by `RulesForRoom`. + + We don't store `RulesForRoom` directly in the cache as we want our caches to + *only* include data, and not references to e.g. the data stores. + """ + # event_id -> (user_id, state) member_map = attr.ib(type=Dict[str, Tuple[str, str]], factory=dict) # user_id -> rules @@ -323,6 +329,10 @@ class RulesForRoom: This efficiently handles users joining/leaving the room by not invalidating the entire cache for the room. + + A new instance is constructed for each call to + `BulkPushRuleEvaluator._get_rules_for_event`, with the cached data from + previous calls passed in. """ def __init__( @@ -341,8 +351,10 @@ def __init__( rules_for_room_cache: The cache object that caches these RoomsForUser objects. room_push_rule_cache_metrics: The metrics object - linearizer: - cached_data: + linearizer: The linearizer used to ensure only one thing mutates + the cache at a time. Keyed off room_id + cached_data: Cached data from previous calls to `self.get_rules`, + can be mutated. """ self.room_id = room_id self.is_mine_id = hs.is_mine_id diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index c0eb4a23df5a..32579d07e960 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -757,9 +757,11 @@ async def _get_joined_hosts( current_state_ids: StateMap[str], state_entry: "_StateCacheEntry", ) -> FrozenSet[str]: - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. + # We don't use `state_group`, its there so that we can cache based on + # it. However, its important that its never None, since two + # current_state's with a state_group of None are likely to be different. + # + # The `state_group` must match the `state_entry.state_group` (if not None). assert state_group is not None assert state_entry.state_group is None or state_entry.state_group == state_group @@ -767,7 +769,7 @@ async def _get_joined_hosts( # joined hosts for the given state group based on previous state groups. # # We cache one object per room containing the results of the last state - # group we got joined hosts for. The idea being that generally + # group we got joined hosts for. The idea is that generally # `get_joined_hosts` is called with the "current" state group for the # room, and so consecutive calls will be for consecutive state groups # which point to the previous state group. @@ -1142,8 +1144,9 @@ class _JoinedHostsCache: hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict) # The state group `hosts_to_joined_users` is derived from. Will be an object - # if the class is newly created or if the state is not based on a state - # group. + # if the instance is newly created or if the state is not based on a state + # group. (An object is used as a sentinel value to ensure that it never is + # equal to anything else). state_group = attr.ib(type=Union[object, int], factory=object) def __len__(self):