Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add cache for get_membership_from_event_ids (#12272)
Browse files Browse the repository at this point in the history
This should speed up push rule calculations for rooms with large numbers of local users when the main push rule cache fails.

Co-authored-by: reivilibre <[email protected]>
  • Loading branch information
erikjohnston and reivilibre authored Mar 25, 2022
1 parent 38adf14 commit 7ca8ee6
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 22 deletions.
1 change: 1 addition & 0 deletions changelog.d/12272.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a new cache `_get_membership_from_event_id` to speed up push rule calculations in large rooms.
30 changes: 16 additions & 14 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache
Expand Down Expand Up @@ -292,7 +293,7 @@ def _condition_checker(
return True


MemberMap = Dict[str, Tuple[str, str]]
MemberMap = Dict[str, Optional[EventIdMembership]]
Rule = Dict[str, dict]
RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int]
Expand All @@ -306,7 +307,7 @@ class RulesForRoomData:
*only* include data, and not references to e.g. the data stores.
"""

# event_id -> (user_id, state)
# event_id -> EventIdMembership
member_map: MemberMap = attr.Factory(dict)
# user_id -> rules
rules_by_user: RulesByUser = attr.Factory(dict)
Expand Down Expand Up @@ -447,11 +448,10 @@ async def get_rules(

res = self.data.member_map.get(event_id, None)
if res:
user_id, state = res
if state == Membership.JOIN:
rules = self.data.rules_by_user.get(user_id, None)
if res.membership == Membership.JOIN:
rules = self.data.rules_by_user.get(res.user_id, None)
if rules:
ret_rules_by_user[user_id] = rules
ret_rules_by_user[res.user_id] = rules
continue

# If a user has left a room we remove their push rule. If they
Expand Down Expand Up @@ -502,24 +502,26 @@ async def _update_rules_with_member_event_ids(
"""
sequence = self.data.sequence

rows = await self.store.get_membership_from_event_ids(member_event_ids.values())

members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
members = await self.store.get_membership_from_event_ids(
member_event_ids.values()
)

# If the event is a join event then it will be in current state evnts
# If the event is a join event then it will be in current state events
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
for event_id in member_event_ids.values():
if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership)
members[event_id] = EventIdMembership(
user_id=event.state_key, membership=event.membership
)

if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())

joined_user_ids = {
user_id
for user_id, membership in members.values()
if membership == Membership.JOIN
entry.user_id
for entry in members.values()
if entry and entry.membership == Membership.JOIN
}

logger.debug("Joined: %r", joined_user_ids)
Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ def _invalidate_caches_for_event(

self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))

# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.
self._get_membership_from_event_id.invalidate((event_id,))

if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)

Expand Down
7 changes: 7 additions & 0 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,6 +1745,13 @@ def non_null_str_or_none(val: Any) -> Optional[str]:
(event.state_key,),
)

# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.
txn.call_after(
self.store._get_membership_from_event_id.invalidate,
(event.event_id,),
)

# We update the local_current_membership table only if the event is
# "current", i.e., its something that has just happened.
#
Expand Down
37 changes: 33 additions & 4 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"


@attr.s(frozen=True, slots=True, auto_attribs=True)
class EventIdMembership:
"""Returned by `get_membership_from_event_ids`"""

user_id: str
membership: str


class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(
self,
Expand Down Expand Up @@ -772,7 +780,7 @@ async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_membership_from_event_ids",
desc="_get_joined_profiles_from_event_ids",
)

return {
Expand Down Expand Up @@ -1000,12 +1008,26 @@ async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:

return set(room_ids)

@cached(max_entries=5000)
async def _get_membership_from_event_id(
self, member_event_id: str
) -> Optional[EventIdMembership]:
raise NotImplementedError()

@cachedList(
cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
)
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs."""
) -> Dict[str, Optional[EventIdMembership]]:
"""Get user_id and membership of a set of event IDs.
Returns:
Mapping from event ID to `EventIdMembership` if the event is a
membership event, otherwise the value is None.
"""

return await self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
Expand All @@ -1015,6 +1037,13 @@ async def get_membership_from_event_ids(
desc="get_membership_from_event_ids",
)

return {
row["event_id"]: EventIdMembership(
membership=row["membership"], user_id=row["user_id"]
)
for row in rows
}

async def is_local_host_in_room_ignoring_users(
self, room_id: str, ignore_users: Collection[str]
) -> bool:
Expand Down
15 changes: 11 additions & 4 deletions synapse/storage/persist_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,8 +1023,13 @@ async def _is_server_still_joined(

# Check if any of the changes that we don't have events for are joins.
if events_to_check:
rows = await self.main_store.get_membership_from_event_ids(events_to_check)
is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
members = await self.main_store.get_membership_from_event_ids(
events_to_check
)
is_still_joined = any(
member and member.membership == Membership.JOIN
for member in members.values()
)
if is_still_joined:
return True

Expand Down Expand Up @@ -1060,9 +1065,11 @@ async def _is_server_still_joined(
), event_id in current_state.items()
if typ == EventTypes.Member and not self.is_mine_id(state_key)
]
rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
potentially_left_users.update(
row["user_id"] for row in rows if row["membership"] == Membership.JOIN
member.user_id
for member in members.values()
if member and member.membership == Membership.JOIN
)

return False
Expand Down

0 comments on commit 7ca8ee6

Please sign in to comment.