diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 833a7a4ff679..1384aa0d8196 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -280,25 +280,24 @@ def get_cross_signing_keys_from_cache(self, query, from_user_id): defer.Deferred[dict[str, dict[str, dict]]]: map from (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key """ + master_keys = {} + self_signing_keys = {} user_signing_keys = {} - users = list(query) + user_ids = list(query) - master_keys = yield self.store.get_e2e_cross_signing_keys_bulk( - users, "master", from_user_id - ) - self_signing_keys = yield self.store.get_e2e_cross_signing_keys_bulk( - users, "self_signing", from_user_id - ) + keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) - if from_user_id in users: + for user_id, user_info in keys.items(): + if "master" in user_info: + master_keys[user_id] = user_info["master"] + if "self_signing" in user_info: + self_signing_keys[user_id] = user_info["self_signing"] + + if from_user_id in keys and "user_signing" in keys[from_user_id]: # users can see other users' master and self-signing keys, but can # only see their own user-signing keys - key = yield self.store.get_e2e_cross_signing_key( - from_user_id, "user_signing", from_user_id - ) - if key: - user_signing_keys[from_user_id] = key + user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"] return { "master_keys": master_keys, diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index ad97b5118189..f4d56ffcf480 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable from six import iteritems @@ -25,7 +24,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList class EndToEndKeyWorkerStore(SQLBaseStore): @@ -337,108 +336,180 @@ def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): from_user_id, ) - def _get_e2e_cross_signing_keys_bulk_txn( - self, - txn: Connection, - users: Iterable[str], - key_type: str, - from_user_id: str = None, - ) -> dict: - """Returns the cross-signing keys for a set of users. + @cached(num_args=1) + def _get_bare_e2e_cross_signing_keys(self, user_id): + """Dummy function. Only used to make a cache for + _get_bare_e2e_cross_signing_keys_bulk. + """ + pass + + @cachedList( + cached_method_name="_get_bare_e2e_cross_signing_keys", + list_name="user_ids", + num_args=1, + ) + def _get_bare_e2e_cross_signing_keys_bulk(self, user_ids: list) -> dict: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. Args: txn (twisted.enterprise.adbapi.Connection): db connection - users (iterable[str]): the users whose keys are being requested - key_type (str): the type of keys that are being requested: either 'master' - for a master key, 'self_signing' for a self-signing key, or - 'user_signing' for a user-signing key - from_user_id (str): if specified, signatures made by this user on - the keys will be included in the result + user_ids (list[str]): the users whose keys are being requested Returns: - dict[str, dict]: mapping from user ID to key data. If a user's - cross-signing key was not found, their user ID will not be in the - dict. + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, their user + ID will not be in the dict. + """ + return self.db.runInteraction( + "get_bare_e2e_cross_signing_keys_bulk", + self._get_bare_e2e_cross_signing_keys_bulk_txn, + user_ids, + ) + + def _get_bare_e2e_cross_signing_keys_bulk_txn( + self, txn: Connection, user_ids: list, + ) -> dict: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. - # convert to a list if needed, so that we can slice it - if not isinstance(users, list): - users = list(users) + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, their user + ID will not be in the dict. + """ result = {} batch_size = 100 - chunks = [users[i : i + batch_size] for i in range(0, len(users), batch_size)] + chunks = [ + user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size) + ] for user_chunk in chunks: sql = """ - SELECT k.user_id, k.keydata, k.stream_id + SELECT k.user_id, k.keytype, k.keydata, k.stream_id FROM e2e_cross_signing_keys k - INNER JOIN (SELECT user_id, MAX(stream_id) AS stream_id + INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id FROM e2e_cross_signing_keys - WHERE keytype = ? - GROUP BY user_id) s - USING (user_id, stream_id) - WHERE k.user_id IN (%s) AND k.keytype = ? + GROUP BY user_id, keytype) s + USING (user_id, stream_id, keytype) + WHERE k.user_id IN (%s) """ % ( - ",".join("?" for u in user_chunk) + ",".join("?" for u in user_chunk), ) - query_params = [key_type] + query_params = [] query_params.extend(user_chunk) - query_params.append(key_type) txn.execute(sql, query_params) rows = self.db.cursor_to_dict(txn) - devices = {} for row in rows: user_id = row["user_id"] + key_type = row["keytype"] key = json.loads(row["keydata"]) - result[user_id] = key + user_info = result.setdefault(user_id, {}) + user_info[key_type] = key + + return result + + def _get_e2e_cross_signing_signatures_txn( + self, txn: Connection, keys: dict, from_user_id: str, + ) -> dict: + """Returns the cross-signing signatures made by a user on a set of keys. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + keys (dict[str, dict[str, dict]]): a map of user ID to key type to + key data. This dict will be modified to add signatures. + from_user_id (str): fetch the signatures made by this user + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. The return value will be the same as the keys argument, + with the modifications included. + """ + + # find out what cross-signing keys (a.k.a. devices) we need to get + # signatures for. This is a map of (user_id, device_id) to key type + # (device_id is the key's public part). + devices = {} + + for user_id, user_info in keys.items(): + for key_type, key in user_info.items(): + device_id = None for k in key["keys"].values(): - devices[user_id] = k - - if devices and from_user_id: - # if we're asked to get signatures, and we have any devices to get - # signatures for, fetch the signatures - sql = """ - SELECT target_user_id, key_id, signature - FROM e2e_cross_signing_signatures - WHERE user_id = ? - AND (%s) - """ % ( - " OR ".join( - "(target_user_id = ? AND target_device_id = ?)" for d in devices - ) + device_id = k + devices[(user_id, device_id)] = key_type + + device_list = list(devices) + + # split into batches + batch_size = 100 + chunks = [ + device_list[i : i + batch_size] + for i in range(0, len(device_list), batch_size) + ] + for user_chunk in chunks: + sql = """ + SELECT target_user_id, target_device_id, key_id, signature + FROM e2e_cross_signing_signatures + WHERE user_id = ? + AND (%s) + """ % ( + " OR ".join( + "(target_user_id = ? AND target_device_id = ?)" for d in devices ) - query_params = [from_user_id] - for item in devices.items(): - # item is a (user_id, device_id) tuple - query_params.extend(item) - - txn.execute(sql, query_params) - rows = self.db.cursor_to_dict(txn) - - # and add the signatures to the appropriate keys - for row in rows: - key_id = row["key_id"] - target_user_id = row["target_user_id"] - target_user_key = result[target_user_id] - signatures = target_user_key.setdefault("signatures", {}) - user_sigs = signatures.setdefault(from_user_id, {}) - user_sigs[key_id] = row["signature"] + ) + query_params = [from_user_id] + for item in devices: + # item is a (user_id, device_id) tuple + query_params.extend(item) - return result + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + # and add the signatures to the appropriate keys + for row in rows: + key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + key_type = devices[(target_user_id, target_device_id)] + # We need to copy everything, because the result may have come + # from the cache. dict.copy only does a shallow copy, so we + # need to recursively copy the dicts that will be modified. + user_info = keys[target_user_id] = keys[target_user_id].copy() + target_user_key = user_info[key_type] = user_info[key_type].copy() + if "signatures" in target_user_key: + signatures = target_user_key["signatures"] = target_user_key[ + "signatures" + ].copy() + if from_user_id in signatures: + user_sigs = signatures[from_user_id] = signatures[from_user_id] + user_sigs[key_id] = row["signature"] + else: + signatures[from_user_id] = {key_id: row["signature"]} + else: + target_user_key["signatures"] = { + from_user_id: {key_id: row["signature"]} + } + + return keys + + @defer.inlineCallbacks def get_e2e_cross_signing_keys_bulk( - self, users: Iterable[str], key_type: str, from_user_id: str = None + self, user_ids: list, from_user_id: str = None ) -> defer.Deferred: """Returns the cross-signing keys for a set of users. Args: - users (iterable[str]): the users whose keys are being requested - key_type (str): the type of keys that are being requested: either 'master' - for a master key, 'self_signing' for a self-signing key, or - 'user_signing' for a user-signing key + user_ids (list[str]): the users whose keys are being requested from_user_id (str): if specified, signatures made by this user on the self-signing keys will be included in the result @@ -447,13 +518,18 @@ def get_e2e_cross_signing_keys_bulk( cross-signing key was not found, their user ID will not be in the dict. """ - return self.db.runInteraction( - "get_e2e_cross_signing_key", - self._get_e2e_cross_signing_keys_bulk_txn, - users, - key_type, - from_user_id, - ) + + result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) + + if from_user_id: + result = yield self.db.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + result, + from_user_id, + ) + + return result def get_all_user_signature_changes_for_remotes(self, from_key, to_key): """Return a list of changes from the user signature stream to notify remotes. @@ -643,6 +719,10 @@ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): }, ) + self._invalidate_cache_and_stream( + txn, self._get_bare_e2e_cross_signing_keys, (user_id,) + ) + def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 84f5ae22c374..2e8f6543e5ae 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -271,7 +271,7 @@ def __init__( else: self.function_to_call = orig - arg_spec = inspect.getargspec(orig) + arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args if "cache_context" in all_args: