diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 12d18137e07..286efe275f8 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -50,7 +50,7 @@ class Membership: KNOCK: Final = "knock" LEAVE: Final = "leave" BAN: Final = "ban" - LIST: Final = {INVITE, JOIN, KNOCK, LEAVE, BAN} + LIST: Final = (INVITE, JOIN, KNOCK, LEAVE, BAN) class PresenceState: diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index bd3c87f5f4e..8849503f52b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -135,7 +135,7 @@ async def _snapshot_all_rooms( memberships.append(Membership.LEAVE) room_list = await self.store.get_rooms_for_local_user_where_membership_is( - user_id=user_id, membership_list=memberships + user_id=user_id, membership_list=tuple(memberships) ) user = UserID.from_string(user_id) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index a7d52fa6483..80661325270 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -278,7 +278,7 @@ async def _search( # TODO: Search through left rooms too rooms = await self.store.get_rooms_for_local_user_where_membership_is( requester.user.to_string(), - membership_list=[Membership.JOIN], + membership_list=(Membership.JOIN,), # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], ) room_ids = {r.room_id for r in rooms} diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 001a290e87e..e9cdc628d54 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -114,7 +114,7 @@ async def maybe_get_notice_room_for_user(self, user_id: str) -> Optional[str]: return None rooms = await self._store.get_rooms_for_local_user_where_membership_is( - user_id, [Membership.INVITE, Membership.JOIN] + user_id, (Membership.INVITE, Membership.JOIN) ) for room in rooms: # it's worth noting that there is an asymmetry here in that we @@ -262,7 +262,7 @@ async def maybe_invite_user_to_room(self, user_id: str, room_id: str) -> None: # Check whether the user has already joined or been invited to this room. If # that's the case, there is no need to re-invite them. joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is( - user_id, [Membership.INVITE, Membership.JOIN] + user_id, (Membership.INVITE, Membership.JOIN) ) for room in joined_rooms: if room.room_id == room_id: diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 881888fa93f..ac9d634d0d2 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -120,6 +120,9 @@ def _invalidate_state_caches( "get_user_in_room_with_profile", (room_id, user_id) ) self._attempt_to_invalidate_cache("get_rooms_for_user", (user_id,)) + self._attempt_to_invalidate_cache( + "get_rooms_for_local_user_where_membership_is", (user_id,) + ) # Purge other caches based on room state. self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) @@ -146,6 +149,9 @@ def _invalidate_state_caches_all(self, room_id: str) -> None: self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None) self._attempt_to_invalidate_cache("get_user_in_room_with_profile", None) self._attempt_to_invalidate_cache("get_rooms_for_user", None) + self._attempt_to_invalidate_cache( + "get_rooms_for_local_user_where_membership_is", None + ) self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) def _attempt_to_invalidate_cache( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index f62d9f705dc..7a5cd67a64d 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -385,7 +385,7 @@ async def get_invited_rooms_for_local_user( """ return await self.get_rooms_for_local_user_where_membership_is( - user_id, [Membership.INVITE] + user_id, (Membership.INVITE,) ) async def get_knocked_at_rooms_for_local_user( @@ -401,7 +401,7 @@ async def get_knocked_at_rooms_for_local_user( """ return await self.get_rooms_for_local_user_where_membership_is( - user_id, [Membership.KNOCK] + user_id, (Membership.KNOCK,) ) async def get_invite_for_local_user_in_room( @@ -422,12 +422,13 @@ async def get_invite_for_local_user_in_room( return invite return None + @cached(max_entries=1000, uncached_args=["excluded_rooms"], tree=True) async def get_rooms_for_local_user_where_membership_is( self, user_id: str, membership_list: Collection[str], excluded_rooms: StrCollection = (), - ) -> List[RoomsForUser]: + ) -> Sequence[RoomsForUser]: """Get all the rooms for this *local* user where the membership for this user matches one in the membership list. @@ -1320,6 +1321,9 @@ def f(txn: LoggingTransaction) -> None: self._invalidate_cache_and_stream( txn, self.get_forgotten_rooms_for_user, (user_id,) ) + self._invalidate_cache_and_stream( + txn, self.get_rooms_for_local_user_where_membership_is, (user_id,) + ) await self.db_pool.runInteraction("forget_membership", f) diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index a85ea994dec..77caab24898 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -535,7 +535,7 @@ def test_pending_invites(self) -> None: # Check that the membership of @invitee:test in the room is now "leave". memberships = self.get_success( store.get_rooms_for_local_user_where_membership_is( - invitee_id, [Membership.LEAVE] + invitee_id, (Membership.LEAVE,) ) ) self.assertEqual(len(memberships), 1, memberships) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 418b5561088..e2f19e25e30 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -68,7 +68,7 @@ def test_one_member(self) -> None: rooms_for_user = self.get_success( self.store.get_rooms_for_local_user_where_membership_is( - self.u_alice, [Membership.JOIN] + self.u_alice, (Membership.JOIN,) ) )