Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Refactor get_user_devices_from_cache to avoid mutating cached values. (
Browse files Browse the repository at this point in the history
…#15040)

The previous version of the code could mutate a cached value,
but only if the input requested all devices of a user *and* a specific
device.

To avoid this nonsensical situation we no longer fetch a specific
device ID if all of a user's devices are returned.
  • Loading branch information
clokep authored Feb 10, 2023
1 parent fd296b7 commit a481fb9
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 18 deletions.
1 change: 1 addition & 0 deletions changelog.d/15040.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid mutating a cached value in `get_user_devices_from_cache`.
11 changes: 7 additions & 4 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,22 @@ async def query_devices(
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
if remote_queries:
query_list: List[Tuple[str, Optional[str]]] = []
user_ids = set()
user_and_device_ids: List[Tuple[str, str]] = []
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend(
user_and_device_ids.extend(
(user_id, device_id) for device_id in device_ids
)
else:
query_list.append((user_id, None))
user_ids.add(user_id)

(
user_ids_not_in_cache,
remote_results,
) = await self.store.get_user_devices_from_cache(query_list)
) = await self.store.get_user_devices_from_cache(
user_ids, user_and_device_ids
)

# Check that the homeserver still shares a room with all cached users.
# Note that this check may be slightly racy when a remote user leaves a
Expand Down
31 changes: 17 additions & 14 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,42 +746,45 @@ def _add_user_signature_change_txn(
@trace
@cancellable
async def get_user_devices_from_cache(
self, query_list: List[Tuple[str, Optional[str]]]
self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.
Args:
query_list: List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
user_ids: users which should have all device IDs returned
user_and_device_ids: List of (user_id, device_ids)
Returns:
A tuple of (user_ids_not_in_cache, results_map), where
user_ids_not_in_cache is a set of user_ids and results_map is a
mapping of user_id -> device_id -> device_info.
"""
user_ids = {user_id for user_id, _ in query_list}
user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids}
user_map = await self.get_device_list_last_stream_id_for_remotes(
list(unique_user_ids)
)

# We go and check if any of the users need to have their device lists
# resynced. If they do then we remove them from the cached list.
users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
user_ids
unique_user_ids
)
user_ids_in_cache = {
user_id for user_id, stream_id in user_map.items() if stream_id
} - users_needing_resync
user_ids_not_in_cache = user_ids - user_ids_in_cache
user_ids_not_in_cache = unique_user_ids - user_ids_in_cache

# First fetch all the users which all devices are to be returned.
results: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue

if device_id:
for user_id in user_ids:
if user_id in user_ids_in_cache:
results[user_id] = await self.get_cached_devices_for_user(user_id)
# Then fetch all device-specific requests, but skip users we've already
# fetched all devices for.
for user_id, device_id in user_and_device_ids:
if user_id in user_ids_in_cache and user_id not in user_ids:
device = await self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
else:
results[user_id] = await self.get_cached_devices_for_user(user_id)

set_tag("in_cache", str(results))
set_tag("not_in_cache", str(user_ids_not_in_cache))
Expand Down

0 comments on commit a481fb9

Please sign in to comment.