diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 29de84aeaa80..b1c2f1997227 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from collections import defaultdict from typing import ( TYPE_CHECKING, Collection, @@ -681,38 +682,27 @@ async def get_rooms_for_users( Returns: Map from user_id to set of rooms that is currently in. """ - return await self.db_pool.runInteraction( - "get_rooms_for_users", - self._get_rooms_for_users_txn, - user_ids, - ) - - def _get_rooms_for_users_txn( - self, txn: LoggingTransaction, user_ids: Collection[str] - ) -> Dict[str, FrozenSet[str]]: - clause, args = make_in_list_sql_clause( - self.database_engine, - "c.state_key", - user_ids, + rows = await self.db_pool.simple_select_many_batch( + table="current_state_events", + column="state_key", + iterable=user_ids, + retcols=( + "user_id", + "room_id", + ), + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.JOIN, + }, + desc="get_rooms_for_users", ) - sql = f""" - SELECT c.state_key, room_id - FROM current_state_events AS c - WHERE - c.type = 'm.room.member' - AND c.membership = ? - AND {clause} - """ - - txn.execute(sql, [Membership.JOIN] + args) - - result: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids} - for user_id, room_id in txn: - result[user_id].add(room_id) + user_rooms: Dict[str, Set[str]] = defaultdict(set) + for row in rows: + user_rooms[row["user_id"]].add(row["room_id"]) - return {user_id: frozenset(v) for user_id, v in result.items()} + return {key: frozenset(rooms) for key, rooms in user_rooms.items()} @cached(max_entries=10000) async def does_pair_of_users_share_a_room(