From f941830e35db49395af224bd7c696435394a1490 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 21 May 2021 14:24:57 +0100 Subject: [PATCH 01/16] Rewrite KeyRing --- synapse/crypto/keyring.py | 611 +++++++++------------ synapse/rest/key/v2/remote_key_resource.py | 9 +- tests/crypto/test_keyring.py | 75 +-- 3 files changed, 304 insertions(+), 391 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 6fc07129788f..b703a9ad4365 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -16,8 +16,7 @@ import abc import logging import urllib -from collections import defaultdict -from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple import attr from signedjson.key import ( @@ -44,17 +43,12 @@ from synapse.config.key import TrustedKeyServer from synapse.events import EventBase from synapse.events.utils import prune_event_dict -from synapse.logging.context import ( - PreserveLoggingContext, - make_deferred_yieldable, - preserve_fn, - run_in_background, -) +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict from synapse.util import unwrapFirstError from synapse.util.async_helpers import yieldable_gather_results -from synapse.util.metrics import Measure +from synapse.util.batching_queue import BatchingQueue from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: @@ -83,14 +77,6 @@ class VerifyJsonRequest: request_name: The name of the request. key_ids: The set of key_ids to that could be used to verify the JSON object - - key_ready (Deferred[str, str, nacl.signing.VerifyKey]): - A deferred (server_name, key_id, verify_key) tuple that resolves when - a verify key has been fetched. The deferreds' callbacks are run with no - logcontext. - - If we are unable to find a key which satisfies the request, the deferred - errbacks with an M_UNAUTHORIZED SynapseError. """ server_name = attr.ib(type=str) @@ -98,7 +84,6 @@ class VerifyJsonRequest: minimum_valid_until_ts = attr.ib(type=int) request_name = attr.ib(type=str) key_ids = attr.ib(type=List[str]) - key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred) @staticmethod def from_json_object( @@ -139,12 +124,35 @@ def from_event( key_ids=key_ids, ) + def to_fetch_key_request(self) -> "_FetchKeyRequest": + """Create a key fetch request for all keys needed to satisfy the + verification request. + """ + return _FetchKeyRequest( + server_name=self.server_name, + minimum_valid_until_ts=self.minimum_valid_until_ts, + key_ids=self.key_ids, + ) + class KeyLookupError(ValueError): pass +@attr.s(slots=True) +class _FetchKeyRequest: + """A request for keys for a given server.""" + + server_name = attr.ib(type=str) + minimum_valid_until_ts = attr.ib(type=int) + key_ids = attr.ib(type=List[str]) + + class Keyring: + """Handles verifying signed JSON objects and fetching the keys needed to do + so. + """ + def __init__( self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None ): @@ -158,12 +166,11 @@ def __init__( ) self._key_fetchers = key_fetchers - # map from server name to Deferred. Has an entry for each server with - # an ongoing key download; the Deferred completes once the download - # completes. - # - # These are regular, logcontext-agnostic Deferreds. 
- self.key_downloads = {} # type: Dict[str, defer.Deferred] + self._server_queue = BatchingQueue( + "keyring_server", + clock=hs.get_clock(), + process_batch_callback=self._inner_fetch_key_requests, + ) # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]] def verify_json_for_server( self, @@ -195,8 +202,7 @@ def verify_json_for_server( validity_time, request_name, ) - requests = (request,) - return make_deferred_yieldable(self._verify_objects(requests)[0]) + return defer.ensureDeferred(self.process_request(request)) def verify_json_objects_for_server( self, server_and_json: Iterable[Tuple[str, dict, int, str]] @@ -221,12 +227,20 @@ def verify_json_objects_for_server( server_name. The deferreds run their callbacks in the sentinel logcontext. """ - return self._verify_objects( - VerifyJsonRequest.from_json_object( - server_name, json_object, validity_time, request_name + return [ + defer.ensureDeferred( + run_in_background( + self.process_request, + VerifyJsonRequest.from_json_object( + server_name, + json_object, + validity_time, + request_name, + ), + ) ) for server_name, json_object, validity_time, request_name in server_and_json - ) + ] def verify_events_for_server( self, server_and_events: Iterable[Tuple[str, EventBase, int]] @@ -252,321 +266,225 @@ def verify_events_for_server( server_name. The deferreds run their callbacks in the sentinel logcontext. """ - return self._verify_objects( - VerifyJsonRequest.from_event(server_name, event, validity_time) + return [ + run_in_background( + self.process_request, + VerifyJsonRequest.from_event( + server_name, + event, + validity_time, + ), + ) for server_name, event, validity_time in server_and_events - ) - - def _verify_objects( - self, verify_requests: Iterable[VerifyJsonRequest] - ) -> List[defer.Deferred]: - """Does the work of verify_json_[objects_]for_server - - - Args: - verify_requests: Iterable of verification requests. + ] - Returns: - List: for each input item, a deferred indicating success - or failure to verify each json object's signature for the given - server_name. The deferreds run their callbacks in the sentinel - logcontext. + async def process_request(self, verify_request: VerifyJsonRequest) -> None: + """Processes the `VerifyJsonRequest`. Raises if the object is not signed + by the server, the signatures don't match or we failed to fetch the + necessary keys. """ - # a list of VerifyJsonRequests which are awaiting a key lookup - key_lookups = [] - handle = preserve_fn(_handle_key_deferred) - - def process(verify_request: VerifyJsonRequest) -> defer.Deferred: - """Process an entry in the request list - - Adds a key request to key_lookups, and returns a deferred which - will complete or fail (in the sentinel context) when verification completes. - """ - if not verify_request.key_ids: - return defer.fail( - SynapseError( - 400, - "Not signed by %s" % (verify_request.server_name,), - Codes.UNAUTHORIZED, - ) - ) - logger.debug( - "Verifying %s for %s with key_ids %s, min_validity %i", - verify_request.request_name, - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, + if not verify_request.key_ids: + raise SynapseError( + 400, + f"Not signed by {verify_request.server_name}", + Codes.UNAUTHORIZED, ) - # add the key request to the queue, but don't start it off yet. - key_lookups.append(verify_request) + # Add they keys we need to verify to the queue for retrieval. 
We queue
+        # up requests for the same server so we don't end up with many in flight
+        # requests for the same keys.
+        found_keys_by_server = await self._server_queue.add_to_queue(
+            verify_request.to_fetch_key_request(), key=verify_request.server_name
+        )
 
-            # now run _handle_key_deferred, which will wait for the key request
-            # to complete and then do the verification.
-            #
-            # We want _handle_key_request to log to the right context, so we
-            # wrap it with preserve_fn (aka run_in_background)
-            return handle(verify_request)
+        # Since we batch up requests the returned set of keys may contain keys
+        # from other servers, so we pull out only the ones we care about.
+        found_keys = found_keys_by_server.get(verify_request.server_name, {})
+
+        # For each signature to check we ensure we have fetched the necessary
+        # keys and the signature matches.
+        for key_id in verify_request.key_ids:
+            key_result = found_keys.get(key_id)
+            if not key_result:
+                raise SynapseError(
+                    401,
+                    f"Missing keys for {verify_request.server_name}: {key_id}",
+                    Codes.UNAUTHORIZED,
+                )
 
-        results = [process(r) for r in verify_requests]
+            if key_result.valid_until_ts < verify_request.minimum_valid_until_ts:
+                raise SynapseError(
+                    401,
+                    f"Failed to find key with recent enough `valid_until_ts` for {verify_request.server_name}: {key_id}",
+                    Codes.UNAUTHORIZED,
+                )
 
-        if key_lookups:
-            run_in_background(self._start_key_lookups, key_lookups)
+            verify_key = key_result.verify_key
+            try:
+                json_object = verify_request.get_json_object()
+                verify_signed_json(
+                    json_object,
+                    verify_request.server_name,
+                    verify_key,
+                )
+            except SignatureVerifyException as e:
+                logger.debug(
+                    "Error verifying signature for %s:%s:%s with key %s: %s",
+                    verify_request.server_name,
+                    verify_key.alg,
+                    verify_key.version,
+                    encode_verify_key_base64(verify_key),
+                    str(e),
+                )
 
-        return results
+                raise SynapseError(
+                    401,
+                    "Invalid signature for server %s with key %s:%s: %s"
+                    % (
+                        verify_request.server_name,
+                        verify_key.alg,
+                        verify_key.version,
+                        str(e),
+                    ),
+                    Codes.UNAUTHORIZED,
+                )
 
+    async def _inner_fetch_key_requests(
+        self, requests: List[_FetchKeyRequest]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
+        """Processing function for the queue of `_FetchKeyRequest`."""
+
+        logger.debug("Starting fetch for %s", requests)
+
+        # First we need to deduplicate requests for the same key. We do this by
+        # taking the *maximum* requested `minimum_valid_until_ts` for each pair
+        # of server name/key ID.
+        server_to_key_to_ts = {}  # type: Dict[str, Dict[str, int]]
+        for request in requests:
+            by_server = server_to_key_to_ts.setdefault(request.server_name, {})
+            for key_id in request.key_ids:
+                existing_ts = by_server.get(key_id)
+
+                if existing_ts:
+                    by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts)
+                else:
+                    by_server[key_id] = request.minimum_valid_until_ts
+
+        deduped_requests = [
+            _FetchKeyRequest(server_name, minimum_valid_ts, [key_id])
+            for server_name, by_server in server_to_key_to_ts.items()
+            for key_id, minimum_valid_ts in by_server.items()
+        ]
+
+        logger.debug("Deduplicated key requests to %s", deduped_requests)
+
+        # For each key we call `_inner_fetch_key_request` which will handle
+        # fetching each key. Note these shouldn't throw if we fail to contact
+        # other servers etc.
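+        #
+        # (`yieldable_gather_results` returns its results in the same order as
+        # `deduped_requests`; the `zip` in the result-collation step below
+        # relies on that ordering.)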
+ results_per_request = await yieldable_gather_results( + self._inner_fetch_key_request, + deduped_requests, + ) - async def _start_key_lookups( - self, verify_requests: List[VerifyJsonRequest] - ) -> None: - """Sets off the key fetches for each verify request + # We now convert the returned list of results into a map from server + # name to key ID to FetchKeyResult, to return. + to_return = {} # type: Dict[str, Dict[str, FetchKeyResult]] + for (request, results) in zip(deduped_requests, results_per_request): + to_return_by_server = to_return.setdefault(request.server_name, {}) + for key_id, key_result in results.items(): + existing = to_return_by_server.get(key_id) + if not existing or existing.valid_until_ts < key_result.valid_until_ts: + to_return_by_server[key_id] = key_result - Once each fetch completes, verify_request.key_ready will be resolved. + return to_return - Args: - verify_requests: + async def _inner_fetch_key_request( + self, verify_request: _FetchKeyRequest + ) -> Dict[str, FetchKeyResult]: + """Attempt to fetch the given key by calling each key fetcher one by + one. """ + logger.debug("Starting fetch for %s", verify_request) - try: - # map from server name to a set of outstanding request ids - server_to_request_ids = {} # type: Dict[str, Set[int]] + found_keys: Dict[str, FetchKeyResult] = {} + missing_key_ids = set(verify_request.key_ids) - for verify_request in verify_requests: - server_name = verify_request.server_name - request_id = id(verify_request) - server_to_request_ids.setdefault(server_name, set()).add(request_id) - - # Wait for any previous lookups to complete before proceeding. - await self.wait_for_previous_lookups(server_to_request_ids.keys()) - - # take out a lock on each of the servers by sticking a Deferred in - # key_downloads - for server_name in server_to_request_ids.keys(): - self.key_downloads[server_name] = defer.Deferred() - logger.debug("Got key lookup lock on %s", server_name) - - # When we've finished fetching all the keys for a given server_name, - # drop the lock by resolving the deferred in key_downloads. - def drop_server_lock(server_name): - d = self.key_downloads.pop(server_name) - d.callback(None) - - def lookup_done(res, verify_request): - server_name = verify_request.server_name - server_requests = server_to_request_ids[server_name] - server_requests.remove(id(verify_request)) - - # if there are no more requests for this server, we can drop the lock. - if not server_requests: - logger.debug("Releasing key lookup lock on %s", server_name) - drop_server_lock(server_name) - - return res - - for verify_request in verify_requests: - verify_request.key_ready.addBoth(lookup_done, verify_request) - - # Actually start fetching keys. - self._get_server_verify_keys(verify_requests) - except Exception: - logger.exception("Error starting key lookups") - - async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None: - """Waits for any previous key lookups for the given servers to finish. - - Args: - server_names: list of servers which we want to look up - - Returns: - Resolves once all key lookups for the given servers have - completed. Follows the synapse rules of logcontext preservation. 
- """ - loop_count = 1 - while True: - wait_on = [ - (server_name, self.key_downloads[server_name]) - for server_name in server_names - if server_name in self.key_downloads - ] - if not wait_on: + for fetcher in self._key_fetchers: + if not missing_key_ids: break - logger.info( - "Waiting for existing lookups for %s to complete [loop %i]", - [w[0] for w in wait_on], - loop_count, - ) - with PreserveLoggingContext(): - await defer.DeferredList((w[1] for w in wait_on)) - loop_count += 1 - - def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None: - """Tries to find at least one key for each verify request - - For each verify_request, verify_request.key_ready is called back with - params (server_name, key_id, VerifyKey) if a key is found, or errbacked - with a SynapseError if none of the keys are found. - - Args: - verify_requests: list of verify requests - """ - - remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} - - async def do_iterations(): - try: - with Measure(self.clock, "get_server_verify_keys"): - for f in self._key_fetchers: - if not remaining_requests: - return - await self._attempt_key_fetches_with_fetcher( - f, remaining_requests - ) - - # look for any requests which weren't satisfied - while remaining_requests: - verify_request = remaining_requests.pop() - rq_str = ( - "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)" - % ( - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, - ) - ) - - # If we run the errback immediately, it may cancel our - # loggingcontext while we are still in it, so instead we - # schedule it for the next time round the reactor. - # - # (this also ensures that we don't get a stack overflow if we - # has a massive queue of lookups waiting for this server). - self.clock.call_later( - 0, - verify_request.key_ready.errback, - SynapseError( - 401, - "Failed to find any key to satisfy %s" % (rq_str,), - Codes.UNAUTHORIZED, - ), - ) - except Exception as err: - # we don't really expect to get here, because any errors should already - # have been caught and logged. But if we do, let's log the error and make - # sure that all of the deferreds are resolved. - logger.error("Unexpected error in _get_server_verify_keys: %s", err) - with PreserveLoggingContext(): - for verify_request in remaining_requests: - if not verify_request.key_ready.called: - verify_request.key_ready.errback(err) - - run_in_background(do_iterations) - - async def _attempt_key_fetches_with_fetcher( - self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest] - ): - """Use a key fetcher to attempt to satisfy some key requests - - Args: - fetcher: fetcher to use to fetch the keys - remaining_requests: outstanding key requests. - Any successfully-completed requests will be removed from the list. - """ - # The keys to fetch. - # server_name -> key_id -> min_valid_ts - missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]] - - for verify_request in remaining_requests: - # any completed requests should already have been removed - assert not verify_request.key_ready.called - keys_for_server = missing_keys[verify_request.server_name] - - for key_id in verify_request.key_ids: - # If we have several requests for the same key, then we only need to - # request that key once, but we should do so with the greatest - # min_valid_until_ts of the requests, so that we can satisfy all of - # the requests. 
keys_for_server[key_id] = max(
-                    keys_for_server.get(key_id, -1),
-                    verify_request.minimum_valid_until_ts,
-                )
+            logger.debug("Getting keys from %s for %s", fetcher, verify_request)
+            keys = await fetcher.get_keys(
+                verify_request.server_name,
+                list(missing_key_ids),
+                verify_request.minimum_valid_until_ts,
+            )
 
-        results = await fetcher.get_keys(missing_keys)
+            for key_id, key in keys.items():
+                if not key:
+                    continue
 
-        completed = []
-        for verify_request in remaining_requests:
-            server_name = verify_request.server_name
+                # If we already have a result for the given key ID we keep the
+                # one with the highest `valid_until_ts`.
+                existing_key = found_keys.get(key_id)
+                if existing_key:
+                    if key.valid_until_ts <= existing_key.valid_until_ts:
+                        continue
 
-            # see if any of the keys we got this time are sufficient to
-            # complete this VerifyJsonRequest.
-            result_keys = results.get(server_name, {})
-            for key_id in verify_request.key_ids:
-                fetch_key_result = result_keys.get(key_id)
-                if not fetch_key_result:
-                    # we didn't get a result for this key
-                    continue
+                # We always store the returned key even if it doesn't meet the
+                # `minimum_valid_until_ts` requirement, as some verification
+                # requests may still be able to be satisfied by it.
+                #
+                # We still keep looking for the key from other fetchers in that
+                # case though.
+                found_keys[key_id] = key
 
-                if (
-                    fetch_key_result.valid_until_ts
-                    < verify_request.minimum_valid_until_ts
-                ):
-                    # key was not valid at this point
+                if key.valid_until_ts < verify_request.minimum_valid_until_ts:
                     continue
 
-                # we have a valid key for this request. If we run the callback
-                # immediately, it may cancel our loggingcontext while we are still in
-                # it, so instead we schedule it for the next time round the reactor.
-                #
-                # (this also ensures that we don't get a stack overflow if we had
-                # a massive queue of lookups waiting for this server).
-                logger.debug(
-                    "Found key %s:%s for %s",
-                    server_name,
-                    key_id,
-                    verify_request.request_name,
-                )
-                self.clock.call_later(
-                    0,
-                    verify_request.key_ready.callback,
-                    (server_name, key_id, fetch_key_result.verify_key),
-                )
-                completed.append(verify_request)
-                break
+                missing_key_ids.discard(key_id)
 
-        remaining_requests.difference_update(completed)
+        return found_keys
 
 
 class KeyFetcher(metaclass=abc.ABCMeta):
-    @abc.abstractmethod
+    def __init__(self, hs: "HomeServer"):
+        self._queue = BatchingQueue(
+            self.__class__.__name__, hs.get_clock(), self._fetch_keys
+        )
+
     async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
-    ) -> Dict[str, Dict[str, FetchKeyResult]]:
-        """
-        Args:
-            keys_to_fetch:
-                the keys to be fetched.
server_name -> key_id -> min_valid_ts + self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + results = await self._queue.add_to_queue( + _FetchKeyRequest( + server_name=server_name, + key_ids=key_ids, + minimum_valid_until_ts=minimum_valid_until_ts, + ) + ) + return results.get(server_name, {}) - Returns: - Map from server_name -> key_id -> FetchKeyResult - """ - raise NotImplementedError + @abc.abstractmethod + async def _fetch_keys( + self, keys_to_fetch: List[_FetchKeyRequest] + ) -> Dict[str, Dict[str, FetchKeyResult]]: + pass class StoreKeyFetcher(KeyFetcher): """KeyFetcher impl which fetches keys from our data store""" def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + super().__init__(hs) - async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] - ) -> Dict[str, Dict[str, FetchKeyResult]]: - """see KeyFetcher.get_keys""" + self.store = hs.get_datastore() + async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]): key_ids_to_fetch = ( - (server_name, key_id) - for server_name, keys_for_server in keys_to_fetch.items() - for key_id in keys_for_server.keys() + (queue_value.server_name, key_id) + for queue_value in keys_to_fetch + for key_id in queue_value.key_ids ) res = await self.store.get_server_verify_keys(key_ids_to_fetch) @@ -578,6 +496,8 @@ async def get_keys( class BaseV2KeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self.store = hs.get_datastore() self.config = hs.config @@ -685,10 +605,10 @@ def __init__(self, hs: "HomeServer"): self.client = hs.get_federation_http_client() self.key_servers = self.config.key_servers - async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] + async def _fetch_keys( + self, keys_to_fetch: List[_FetchKeyRequest] ) -> Dict[str, Dict[str, FetchKeyResult]]: - """see KeyFetcher.get_keys""" + """see KeyFetcher._fetch_keys""" async def get_key(key_server: TrustedKeyServer) -> Dict: try: @@ -724,12 +644,12 @@ async def get_key(key_server: TrustedKeyServer) -> Dict: return union_of_keys async def get_server_verify_key_v2_indirect( - self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer + self, keys_to_fetch: List[_FetchKeyRequest], key_server: TrustedKeyServer ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: keys_to_fetch: - the keys to be fetched. server_name -> key_id -> min_valid_ts + the keys to be fetched. 
key_server: notary server to query for the keys @@ -743,7 +663,7 @@ async def get_server_verify_key_v2_indirect( perspective_name = key_server.server_name logger.info( "Requesting keys %s from notary server %s", - keys_to_fetch.items(), + keys_to_fetch, perspective_name, ) @@ -753,11 +673,13 @@ async def get_server_verify_key_v2_indirect( path="/_matrix/key/v2/query", data={ "server_keys": { - server_name: { - key_id: {"minimum_valid_until_ts": min_valid_ts} - for key_id, min_valid_ts in server_keys.items() + queue_value.server_name: { + key_id: { + "minimum_valid_until_ts": queue_value.minimum_valid_until_ts, + } + for key_id in queue_value.key_ids } - for server_name, server_keys in keys_to_fetch.items() + for queue_value in keys_to_fetch } }, ) @@ -858,7 +780,20 @@ def __init__(self, hs: "HomeServer"): self.client = hs.get_federation_http_client() async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] + self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + results = await self._queue.add_to_queue( + _FetchKeyRequest( + server_name=server_name, + key_ids=key_ids, + minimum_valid_until_ts=minimum_valid_until_ts, + ), + key=server_name, + ) + return results.get(server_name, {}) + + async def _fetch_keys( + self, keys_to_fetch: List[_FetchKeyRequest] ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: @@ -871,8 +806,10 @@ async def get_keys( results = {} - async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None: - server_name, key_ids = key_to_fetch_item + async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None: + server_name = key_to_fetch_item.server_name + key_ids = key_to_fetch_item.key_ids + try: keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) results[server_name] = keys @@ -883,7 +820,7 @@ async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None: except Exception: logger.exception("Error getting keys %s from %s", key_ids, server_name) - await yieldable_gather_results(get_key, keys_to_fetch.items()) + await yieldable_gather_results(get_key, keys_to_fetch) return results async def get_server_verify_key_v2_direct( @@ -955,37 +892,3 @@ async def get_server_verify_key_v2_direct( keys.update(response_keys) return keys - - -async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None: - """Waits for the key to become available, and then performs a verification - - Args: - verify_request: - - Raises: - SynapseError if there was a problem performing the verification - """ - server_name = verify_request.server_name - with PreserveLoggingContext(): - _, key_id, verify_key = await verify_request.key_ready - - json_object = verify_request.get_json_object() - - try: - verify_signed_json(json_object, server_name, verify_key) - except SignatureVerifyException as e: - logger.debug( - "Error verifying signature for %s:%s:%s with key %s: %s", - server_name, - verify_key.alg, - verify_key.version, - encode_verify_key_base64(verify_key), - str(e), - ) - raise SynapseError( - 401, - "Invalid signature for server %s with key %s:%s: %s" - % (server_name, verify_key.alg, verify_key.version, str(e)), - Codes.UNAUTHORIZED, - ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index f648678b0903..24af3c28eab3 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -22,6 +22,7 @@ from synapse.http.server import DirectServeJsonResource, respond_with_json from 
synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.util import json_decoder +from synapse.util.async_helpers import yieldable_gather_results logger = logging.getLogger(__name__) @@ -213,7 +214,13 @@ async def query_keys(self, request, query, query_remote_on_cache_miss=False): # If there is a cache miss, request the missing keys, then recurse (and # ensure the result is sent). if cache_misses and query_remote_on_cache_miss: - await self.fetcher.get_keys(cache_misses) + await yieldable_gather_results( + self.fetcher.get_keys, + ( + (server_name, list(keys), 0) + for server_name, keys in cache_misses.items() + ), + ) await self.query_keys(request, query, query_remote_on_cache_miss=False) else: signed_keys = [] diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 2775dfd8807d..f23f32480c4a 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import time +from typing import Dict, List from unittest.mock import Mock import attr @@ -92,16 +93,16 @@ def test_verify_json_objects_for_server_awaits_previous_requests(self): # deferred completes. first_lookup_deferred = Deferred() - async def first_lookup_fetch(keys_to_fetch): - self.assertEquals(current_context().request.id, "context_11") - self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}}) + async def first_lookup_fetch( + server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + # self.assertEquals(current_context().request.id, "context_11") + self.assertEqual(server_name, "server10") + self.assertEqual(key_ids, [get_key_id(key1)]) + self.assertEqual(minimum_valid_until_ts, 0) await make_deferred_yieldable(first_lookup_deferred) - return { - "server10": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) - } - } + return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)} mock_fetcher.get_keys.side_effect = first_lookup_fetch @@ -126,18 +127,18 @@ async def first_lookup(): d0 = ensureDeferred(first_lookup()) + self.pump() + mock_fetcher.get_keys.assert_called_once() # a second request for a server with outstanding requests # should block rather than start a second call - async def second_lookup_fetch(keys_to_fetch): - self.assertEquals(current_context().request.id, "context_12") - return { - "server10": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) - } - } + async def second_lookup_fetch( + server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + # self.assertEquals(current_context().request.id, "context_12") + return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)} mock_fetcher.get_keys.reset_mock() mock_fetcher.get_keys.side_effect = second_lookup_fetch @@ -226,7 +227,7 @@ def test_verify_json_for_server_with_null_valid_until_ms(self): # We expect the keyring tried to refetch the key once. 
mock_fetcher.get_keys.assert_called_once_with( - {"server9": {get_key_id(key1): 500}} + "server9", [get_key_id(key1)], 500 ) # should succeed on a signed object with a 0 minimum_valid_until_ms @@ -239,15 +240,15 @@ def test_verify_json_dedupes_key_requests(self): """Two requests for the same key should be deduped.""" key1 = signedjson.key.generate_signing_key(1) - async def get_keys(keys_to_fetch): + async def get_keys( + server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: # there should only be one request object (with the max validity) - self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) + self.assertEqual(server_name, "server1") + self.assertEqual(key_ids, [get_key_id(key1)]) + self.assertEqual(minimum_valid_until_ts, 1500) - return { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) - } - } + return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)} mock_fetcher = Mock() mock_fetcher.get_keys = Mock(side_effect=get_keys) @@ -274,19 +275,21 @@ def test_verify_json_falls_back_to_other_fetchers(self): """If the first fetcher cannot provide a recent enough key, we fall back""" key1 = signedjson.key.generate_signing_key(1) - async def get_keys1(keys_to_fetch): - self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return { - "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)} - } - - async def get_keys2(keys_to_fetch): - self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) - } - } + async def get_keys1( + server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + self.assertEqual(server_name, "server1") + self.assertEqual(key_ids, [get_key_id(key1)]) + self.assertEqual(minimum_valid_until_ts, 1500) + return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)} + + async def get_keys2( + server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + self.assertEqual(server_name, "server1") + self.assertEqual(key_ids, [get_key_id(key1)]) + self.assertEqual(minimum_valid_until_ts, 1500) + return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)} mock_fetcher1 = Mock() mock_fetcher1.get_keys = Mock(side_effect=get_keys1) From 83d5280c80875aedb1255f38d9fdba4e9879765b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 4 May 2021 17:47:19 +0100 Subject: [PATCH 02/16] Fix remote resource --- synapse/rest/key/v2/remote_key_resource.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 24af3c28eab3..e19e9ef5c749 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -215,7 +215,7 @@ async def query_keys(self, request, query, query_remote_on_cache_miss=False): # ensure the result is sent). 
if cache_misses and query_remote_on_cache_miss: await yieldable_gather_results( - self.fetcher.get_keys, + lambda t: self.fetcher.get_keys(*t), ( (server_name, list(keys), 0) for server_name, keys in cache_misses.items() From 66a532fa0fc3262e26a6ec9334382c2fdc5bc35a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 19 May 2021 11:19:21 +0100 Subject: [PATCH 03/16] Remove request name --- synapse/crypto/keyring.py | 21 ++---------- synapse/federation/transport/server.py | 4 ++- synapse/groups/attestations.py | 4 ++- tests/crypto/test_keyring.py | 47 ++++++++++++++++++++------ 4 files changed, 45 insertions(+), 31 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index b703a9ad4365..72c401df840d 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -74,15 +74,12 @@ class VerifyJsonRequest: minimum_valid_until_ts: time at which we require the signing key to be valid. (0 implies we don't care) - request_name: The name of the request. - key_ids: The set of key_ids to that could be used to verify the JSON object """ server_name = attr.ib(type=str) get_json_object = attr.ib(type=Callable[[], JsonDict]) minimum_valid_until_ts = attr.ib(type=int) - request_name = attr.ib(type=str) key_ids = attr.ib(type=List[str]) @staticmethod @@ -90,7 +87,6 @@ def from_json_object( server_name: str, json_object: JsonDict, minimum_valid_until_ms: int, - request_name: str, ): """Create a VerifyJsonRequest to verify all signatures on a signed JSON object for the given server. @@ -100,7 +96,6 @@ def from_json_object( server_name, lambda: json_object, minimum_valid_until_ms, - request_name=request_name, key_ids=key_ids, ) @@ -120,7 +115,6 @@ def from_event( # memory than the Event object itself. lambda: prune_event_dict(event.room_version, event.get_pdu_json()), minimum_valid_until_ms, - request_name=event.event_id, key_ids=key_ids, ) @@ -177,7 +171,6 @@ def verify_json_for_server( server_name: str, json_object: JsonDict, validity_time: int, - request_name: str, ) -> defer.Deferred: """Verify that a JSON object has been signed by a given server @@ -189,9 +182,6 @@ def verify_json_for_server( validity_time: timestamp at which we require the signing key to be valid. (0 implies we don't care) - request_name: an identifier for this json object (eg, an event id) - for logging. - Returns: Deferred[None]: completes if the the object was correctly signed, otherwise errbacks with an error @@ -200,27 +190,23 @@ def verify_json_for_server( server_name, json_object, validity_time, - request_name, ) return defer.ensureDeferred(self.process_request(request)) def verify_json_objects_for_server( - self, server_and_json: Iterable[Tuple[str, dict, int, str]] + self, server_and_json: Iterable[Tuple[str, dict, int]] ) -> List[defer.Deferred]: """Bulk verifies signatures of json objects, bulk fetching keys as necessary. Args: server_and_json: - Iterable of (server_name, json_object, validity_time, request_name) + Iterable of (server_name, json_object, validity_time) tuples. validity_time is a timestamp at which the signing key must be valid. - request_name is an identifier for this json object (eg, an event id) - for logging. 
- Returns: List: for each input triplet, a deferred indicating success or failure to verify each json object's signature for the given @@ -235,11 +221,10 @@ def verify_json_objects_for_server( server_name, json_object, validity_time, - request_name, ), ) ) - for server_name, json_object, validity_time, request_name in server_and_json + for server_name, json_object, validity_time in server_and_json ] def verify_events_for_server( diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index c17a085a4f93..ae4718daf228 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -151,7 +151,9 @@ async def authenticate_request(self, request, content): ) await self.keyring.verify_json_for_server( - origin, json_request, now, "Incoming request" + origin, + json_request, + now, ) logger.debug("Request from %s", origin) diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index d2fc8be5f571..ff8372c4e9ad 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -108,7 +108,9 @@ async def verify_attestation( assert server_name is not None await self.keyring.verify_json_for_server( - server_name, attestation, now, "Group attestation" + server_name, + attestation, + now, ) def create_attestation(self, group_id: str, user_id: str) -> JsonDict: diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index f23f32480c4a..b409ce68aa86 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -109,7 +109,7 @@ async def first_lookup_fetch( async def first_lookup(): with LoggingContext("context_11", request=FakeRequest("context_11")): res_deferreds = kr.verify_json_objects_for_server( - [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")] + [("server10", json1, 0), ("server11", {}, 0)] ) # the unsigned json should be rejected pretty quickly @@ -147,7 +147,13 @@ async def second_lookup_fetch( async def second_lookup(): with LoggingContext("context_12", request=FakeRequest("context_12")): res_deferreds_2 = kr.verify_json_objects_for_server( - [("server10", json1, 0, "test")] + [ + ( + "server10", + json1, + 0, + ) + ] ) res_deferreds_2[0].addBoth(self.check_context, None) second_lookup_state[0] = 1 @@ -184,11 +190,11 @@ def test_verify_json_for_server(self): signedjson.sign.sign_json(json1, "server9", key1) # should fail immediately on an unsigned object - d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") + d = _verify_json_for_server(kr, "server9", {}, 0) self.get_failure(d, SynapseError) # should succeed on a signed object - d = _verify_json_for_server(kr, "server9", json1, 500, "test signed") + d = _verify_json_for_server(kr, "server9", json1, 500) # self.assertFalse(d.called) self.get_success(d) @@ -215,14 +221,12 @@ def test_verify_json_for_server_with_null_valid_until_ms(self): signedjson.sign.sign_json(json1, "server9", key1) # should fail immediately on an unsigned object - d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") + d = _verify_json_for_server(kr, "server9", {}, 0) self.get_failure(d, SynapseError) # should fail on a signed object with a non-zero minimum_valid_until_ms, # as it tries to refetch the keys and fails. - d = _verify_json_for_server( - kr, "server9", json1, 500, "test signed non-zero min" - ) + d = _verify_json_for_server(kr, "server9", json1, 500) self.get_failure(d, SynapseError) # We expect the keyring tried to refetch the key once. 
@@ -232,7 +236,10 @@ def test_verify_json_for_server_with_null_valid_until_ms(self): # should succeed on a signed object with a 0 minimum_valid_until_ms d = _verify_json_for_server( - kr, "server9", json1, 0, "test signed with zero min" + kr, + "server9", + json1, + 0, ) self.get_success(d) @@ -260,7 +267,14 @@ async def get_keys( # the first request should succeed; the second should fail because the key # has expired results = kr.verify_json_objects_for_server( - [("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")] + [ + ( + "server1", + json1, + 500, + ), + ("server1", json1, 1500), + ] ) self.assertEqual(len(results), 2) self.get_success(results[0]) @@ -301,7 +315,18 @@ async def get_keys2( signedjson.sign.sign_json(json1, "server1", key1) results = kr.verify_json_objects_for_server( - [("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")] + [ + ( + "server1", + json1, + 1200, + ), + ( + "server1", + json1, + 1500, + ), + ] ) self.assertEqual(len(results), 2) self.get_success(results[0]) From e1936118306f2dceffa8221b4e7e5bdb9a5a4989 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 21 May 2021 14:32:14 +0100 Subject: [PATCH 04/16] Newsfile --- changelog.d/10035.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/10035.feature diff --git a/changelog.d/10035.feature b/changelog.d/10035.feature new file mode 100644 index 000000000000..68052b5a7eaa --- /dev/null +++ b/changelog.d/10035.feature @@ -0,0 +1 @@ +Rewrite logic around verifying JSON object and fetching server keys to be more performant and use less memory. From dc19b1aeb0d11648bc3176e1579ed0e5ab243be9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 24 May 2021 10:31:11 +0100 Subject: [PATCH 05/16] Fix tests --- tests/crypto/test_keyring.py | 26 ++++++++----------- tests/rest/key/v2/test_remote_key_resource.py | 18 ++++++------- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index b409ce68aa86..a159dcd2f219 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -377,9 +377,8 @@ async def get_json(destination, path, **kwargs): self.http_client.get_json.side_effect = get_json - keys_to_fetch = {SERVER_NAME: {"key1": 0}} - keys = self.get_success(fetcher.get_keys(keys_to_fetch)) - k = keys[SERVER_NAME][testverifykey_id] + keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) + k = keys[testverifykey_id] self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.verify_key, testverifykey) self.assertEqual(k.verify_key.alg, "ed25519") @@ -406,7 +405,7 @@ async def get_json(destination, path, **kwargs): # change the server name: the result should be ignored response["server_name"] = "OTHER_SERVER" - keys = self.get_success(fetcher.get_keys(keys_to_fetch)) + keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) self.assertEqual(keys, {}) @@ -493,10 +492,9 @@ def test_get_keys_from_perspectives(self): self.expect_outgoing_key_query(SERVER_NAME, "key1", response) - keys_to_fetch = {SERVER_NAME: {"key1": 0}} - keys = self.get_success(fetcher.get_keys(keys_to_fetch)) - self.assertIn(SERVER_NAME, keys) - k = keys[SERVER_NAME][testverifykey_id] + keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) + self.assertIn(testverifykey_id, keys) + k = keys[testverifykey_id] self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.verify_key, testverifykey) self.assertEqual(k.verify_key.alg, "ed25519") @@ -543,10 
+541,9 @@ def test_get_perspectives_own_key(self): self.expect_outgoing_key_query(SERVER_NAME, "key1", response) - keys_to_fetch = {SERVER_NAME: {"key1": 0}} - keys = self.get_success(fetcher.get_keys(keys_to_fetch)) - self.assertIn(SERVER_NAME, keys) - k = keys[SERVER_NAME][testverifykey_id] + keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) + self.assertIn(testverifykey_id, keys) + k = keys[testverifykey_id] self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.verify_key, testverifykey) self.assertEqual(k.verify_key.alg, "ed25519") @@ -587,14 +584,13 @@ def build_response(): def get_key_from_perspectives(response): fetcher = PerspectivesKeyFetcher(self.hs) - keys_to_fetch = {SERVER_NAME: {"key1": 0}} self.expect_outgoing_key_query(SERVER_NAME, "key1", response) - return self.get_success(fetcher.get_keys(keys_to_fetch)) + return self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) # start with a valid response so we can check we are testing the right thing response = build_response() keys = get_key_from_perspectives(response) - k = keys[SERVER_NAME][testverifykey_id] + k = keys[testverifykey_id] self.assertEqual(k.verify_key, testverifykey) # remove the perspectives server's signature diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 3b275bc23b47..a75c0ea3f04a 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -208,10 +208,10 @@ def test_get_key(self): keyid = "ed25519:%s" % (testkey.version,) fetcher = PerspectivesKeyFetcher(self.hs2) - d = fetcher.get_keys({"targetserver": {keyid: 1000}}) + d = fetcher.get_keys("targetserver", [keyid], 1000) res = self.get_success(d) - self.assertIn("targetserver", res) - keyres = res["targetserver"][keyid] + self.assertIn(keyid, res) + keyres = res[keyid] assert isinstance(keyres, FetchKeyResult) self.assertEqual( signedjson.key.encode_verify_key_base64(keyres.verify_key), @@ -230,10 +230,10 @@ def test_get_notary_key(self): keyid = "ed25519:%s" % (testkey.version,) fetcher = PerspectivesKeyFetcher(self.hs2) - d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}}) + d = fetcher.get_keys(self.hs.hostname, [keyid], 1000) res = self.get_success(d) - self.assertIn(self.hs.hostname, res) - keyres = res[self.hs.hostname][keyid] + self.assertIn(keyid, res) + keyres = res[keyid] assert isinstance(keyres, FetchKeyResult) self.assertEqual( signedjson.key.encode_verify_key_base64(keyres.verify_key), @@ -247,10 +247,10 @@ def test_get_notary_keyserver_key(self): keyid = "ed25519:%s" % (self.hs_signing_key.version,) fetcher = PerspectivesKeyFetcher(self.hs2) - d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}}) + d = fetcher.get_keys(self.hs.hostname, [keyid], 1000) res = self.get_success(d) - self.assertIn(self.hs.hostname, res) - keyres = res[self.hs.hostname][keyid] + self.assertIn(keyid, res) + keyres = res[keyid] assert isinstance(keyres, FetchKeyResult) self.assertEqual( signedjson.key.encode_verify_key_base64(keyres.verify_key), From b45541591c621aaa0999c8372084ca2fa4268cf2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 26 May 2021 15:12:02 +0100 Subject: [PATCH 06/16] Make verify_json_for_server async --- synapse/crypto/keyring.py | 6 +++--- tests/crypto/test_keyring.py | 19 +++---------------- 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 72c401df840d..e3c4c2d78e49 100644 --- 
a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -166,12 +166,12 @@ def __init__( process_batch_callback=self._inner_fetch_key_requests, ) # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]] - def verify_json_for_server( + async def verify_json_for_server( self, server_name: str, json_object: JsonDict, validity_time: int, - ) -> defer.Deferred: + ): """Verify that a JSON object has been signed by a given server Args: @@ -191,7 +191,7 @@ def verify_json_for_server( json_object, validity_time, ) - return defer.ensureDeferred(self.process_request(request)) + return await self.process_request(request) def verify_json_objects_for_server( self, server_and_json: Iterable[Tuple[str, dict, int]] diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index a159dcd2f219..142695d5189e 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -22,7 +22,6 @@ from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key -from twisted.internet import defer from twisted.internet.defer import Deferred, ensureDeferred from synapse.api.errors import SynapseError @@ -611,21 +610,9 @@ def get_key_id(key): return "%s:%s" % (key.alg, key.version) -@defer.inlineCallbacks -def run_in_context(f, *args, **kwargs): - with LoggingContext("testctx"): - rv = yield f(*args, **kwargs) - return rv - - -def _verify_json_for_server(kr, *args): +async def _verify_json_for_server(kr, *args): """thin wrapper around verify_json_for_server which makes sure it is wrapped with the patched defer.inlineCallbacks. """ - - @defer.inlineCallbacks - def v(): - rv1 = yield kr.verify_json_for_server(*args) - return rv1 - - return run_in_context(v) + with LoggingContext("testctx"): + return await kr.verify_json_for_server(*args) From eec11d66ecd9e40a47d9ac2969b6b251d384fa38 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 26 May 2021 15:27:40 +0100 Subject: [PATCH 07/16] Remove superfluous ensureDeferred --- synapse/crypto/keyring.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index e3c4c2d78e49..9f26a4af11bc 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -214,15 +214,13 @@ def verify_json_objects_for_server( logcontext. 
""" return [ - defer.ensureDeferred( - run_in_background( - self.process_request, - VerifyJsonRequest.from_json_object( - server_name, - json_object, - validity_time, - ), - ) + run_in_background( + self.process_request, + VerifyJsonRequest.from_json_object( + server_name, + json_object, + validity_time, + ), ) for server_name, json_object, validity_time in server_and_json ] From 5db369d66187f278da9eda938e5884bd9687e3c6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 26 May 2021 15:27:52 +0100 Subject: [PATCH 08/16] Move get_json_object() out from try/except --- synapse/crypto/keyring.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 9f26a4af11bc..12ea32b5de59 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -304,8 +304,8 @@ async def process_request(self, verify_request: VerifyJsonRequest) -> None: ) verify_key = key_result.verify_key + json_object = verify_request.get_json_object() try: - json_object = verify_request.get_json_object() verify_signed_json( json_object, verify_request.server_name, From 116501022e7c72f2fbc5071e2338f32f6e2ae350 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 26 May 2021 15:30:23 +0100 Subject: [PATCH 09/16] Apply suggestions from code review Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- synapse/crypto/keyring.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 12ea32b5de59..9a442f3cbc49 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -274,7 +274,7 @@ async def process_request(self, verify_request: VerifyJsonRequest) -> None: Codes.UNAUTHORIZED, ) - # Add they keys we need to verify to the queue for retrieval. We queue + # Add the keys we need to verify to the queue for retrieval. We queue # up requests for the same server so we don't end up with many in flight # requests for the same keys. 
found_keys_by_server = await self._server_queue.add_to_queue( @@ -292,7 +292,7 @@ async def process_request(self, verify_request: VerifyJsonRequest) -> None: if not key_result: raise SynapseError( 401, - f"Missing keys for {verify_request.server_name}: {key_id}", + f"Failed to retrieve key {key_id} for {verify_request.server_name}", Codes.UNAUTHORIZED, ) @@ -346,12 +346,8 @@ async def _inner_fetch_key_requests( for request in requests: by_server = server_to_key_to_ts.setdefault(request.server_name, {}) for key_id in request.key_ids: - existing_ts = by_server.get(key_id) - - if existing_ts: - by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts) - else: - by_server[key_id] = request.minimum_valid_until_ts + existing_ts = by_server.get(key_id, 0) + by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts) deduped_requests = [ _FetchKeyRequest(server_name, minimum_valid_ts, [key_id]) From fbdabd4f2f484598df4db17522f58b33e30e0cd9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 26 May 2021 16:02:44 +0100 Subject: [PATCH 10/16] fmt --- synapse/crypto/keyring.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 9a442f3cbc49..7dab5bd243c7 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -346,8 +346,8 @@ async def _inner_fetch_key_requests( for request in requests: by_server = server_to_key_to_ts.setdefault(request.server_name, {}) for key_id in request.key_ids: - existing_ts = by_server.get(key_id, 0) - by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts) + existing_ts = by_server.get(key_id, 0) + by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts) deduped_requests = [ _FetchKeyRequest(server_name, minimum_valid_ts, [key_id]) From 3dc8841f657bd193cf74f92cd10fe6ab9188667a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 28 May 2021 14:45:24 +0100 Subject: [PATCH 11/16] Add docstring --- synapse/crypto/keyring.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 7dab5bd243c7..b19eb5775a17 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -135,7 +135,18 @@ class KeyLookupError(ValueError): @attr.s(slots=True) class _FetchKeyRequest: - """A request for keys for a given server.""" + """A request for keys for a given server. + + We will continue to try and fetch until we have all the keys listed under + `key_ids` (with an appropriate `valid_until_ts` property) or we run out of + places to fetch keys from. + + Attributes: + server_name: The name of the server that owns the keys. + minimum_valid_until_ts: The earliest timestamp at which we need the + keys to be valid at. 
+        key_ids: The IDs of the keys to attempt to fetch
+    """
 
     server_name = attr.ib(type=str)
     minimum_valid_until_ts = attr.ib(type=int)
     key_ids = attr.ib(type=List[str])


From dbde3554851ad1f206d37b53d59468d6ae342319 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Fri, 28 May 2021 14:46:06 +0100
Subject: [PATCH 12/16] Only require one signature to be verified per server

---
 synapse/crypto/keyring.py | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index b19eb5775a17..56f1e18837d7 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -288,31 +288,25 @@ async def process_request(self, verify_request: VerifyJsonRequest) -> None:
         # Add the keys we need to verify to the queue for retrieval. We queue
         # up requests for the same server so we don't end up with many in flight
         # requests for the same keys.
+        key_request = verify_request.to_fetch_key_request()
         found_keys_by_server = await self._server_queue.add_to_queue(
-            verify_request.to_fetch_key_request(), key=verify_request.server_name
+            key_request, key=verify_request.server_name
         )
 
         # Since we batch up requests the returned set of keys may contain keys
         # from other servers, so we pull out only the ones we care about.
         found_keys = found_keys_by_server.get(verify_request.server_name, {})
 
-        # For each signature to check we ensure we have fetched the necessary
-        # keys and the signature matches.
+        # Verify each signature we got valid keys for, raising if we can't
+        # verify any of them.
+        verified = False
         for key_id in verify_request.key_ids:
             key_result = found_keys.get(key_id)
             if not key_result:
-                raise SynapseError(
-                    401,
-                    f"Failed to retrieve key {key_id} for {verify_request.server_name}",
-                    Codes.UNAUTHORIZED,
-                )
+                continue
 
             if key_result.valid_until_ts < verify_request.minimum_valid_until_ts:
-                raise SynapseError(
-                    401,
-                    f"Failed to find key with recent enough `valid_until_ts` for {verify_request.server_name}: {key_id}",
-                    Codes.UNAUTHORIZED,
-                )
+                continue
 
             verify_key = key_result.verify_key
             json_object = verify_request.get_json_object()
@@ -322,6 +316,7 @@ async def process_request(self, verify_request: VerifyJsonRequest) -> None:
                     verify_request.server_name,
                     verify_key,
                 )
+                verified = True
             except SignatureVerifyException as e:
                 logger.debug(
                     "Error verifying signature for %s:%s:%s with key %s: %s",
@@ -343,6 +338,13 @@ async def process_request(self, verify_request: VerifyJsonRequest) -> None:
                     Codes.UNAUTHORIZED,
                 )
 
+        if not verified:
+            raise SynapseError(
+                401,
+                f"Failed to find any key to satisfy: {key_request}",
+                Codes.UNAUTHORIZED,
+            )
+
     async def _inner_fetch_key_requests(
         self, requests: List[_FetchKeyRequest]
     ) -> Dict[str, Dict[str, FetchKeyResult]]:

From a596fc0a36aa0897dc741af558a84ca5e8a41523 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 1 Jun 2021 09:13:56 +0100
Subject: [PATCH 13/16] Fix tests now that we use BatchingQueue

---
 tests/util/test_batching_queue.py | 39 ++++++++++++++-----------------
 1 file changed, 18 insertions(+), 21 deletions(-)

diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py
index edf29e5b96ae..44835d715cdc 100644
--- a/tests/util/test_batching_queue.py
+++ b/tests/util/test_batching_queue.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from prometheus_client.samples import Sample + from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable @@ -45,37 +47,32 @@ async def _process_queue(self, values): self._pending_calls.append((values, d)) return await make_deferred_yieldable(d) + def _get_sample_with_name(self, metric, name: str) -> Sample: + """For a prometheus metric get the sample that has a matching "name" + label. + """ + for sample in metric.collect()[0].samples: + if sample.labels.get("name") == name: + return sample + + self.fail("Found no matching sample") + def _assert_metrics(self, queued, keys, in_flight): """Assert that the metrics are correct""" - self.assertEqual(len(number_queued.collect()), 1) - self.assertEqual(len(number_queued.collect()[0].samples), 1) - self.assertEqual( - number_queued.collect()[0].samples[0].labels, - {"name": self.queue._name}, - ) + sample = self._get_sample_with_name(number_queued, self.queue._name) self.assertEqual( - number_queued.collect()[0].samples[0].value, + sample.value, queued, "number_queued", ) - self.assertEqual(len(number_of_keys.collect()), 1) - self.assertEqual(len(number_of_keys.collect()[0].samples), 1) - self.assertEqual( - number_queued.collect()[0].samples[0].labels, {"name": self.queue._name} - ) - self.assertEqual( - number_of_keys.collect()[0].samples[0].value, keys, "number_of_keys" - ) + sample = self._get_sample_with_name(number_of_keys, self.queue._name) + self.assertEqual(sample.value, keys, "number_of_keys") - self.assertEqual(len(number_in_flight.collect()), 1) - self.assertEqual(len(number_in_flight.collect()[0].samples), 1) - self.assertEqual( - number_queued.collect()[0].samples[0].labels, {"name": self.queue._name} - ) + sample = self._get_sample_with_name(number_in_flight, self.queue._name) self.assertEqual( - number_in_flight.collect()[0].samples[0].value, + sample.value, in_flight, "number_in_flight", ) From 4091155996f4a134a03a9a06e36659d6cb20976e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 1 Jun 2021 16:51:15 +0100 Subject: [PATCH 14/16] Fix tests when using old prometheus lib --- tests/util/test_batching_queue.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py index 44835d715cdc..07be57d72c6d 100644 --- a/tests/util/test_batching_queue.py +++ b/tests/util/test_batching_queue.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from prometheus_client.samples import Sample - from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable @@ -47,13 +45,13 @@ async def _process_queue(self, values): self._pending_calls.append((values, d)) return await make_deferred_yieldable(d) - def _get_sample_with_name(self, metric, name: str) -> Sample: - """For a prometheus metric get the sample that has a matching "name" - label. + def _get_sample_with_name(self, metric, name) -> int: + """For a prometheus metric get the value of the sample that has a + matching "name" label. 
""" for sample in metric.collect()[0].samples: if sample.labels.get("name") == name: - return sample + return sample.value self.fail("Found no matching sample") @@ -62,17 +60,17 @@ def _assert_metrics(self, queued, keys, in_flight): sample = self._get_sample_with_name(number_queued, self.queue._name) self.assertEqual( - sample.value, + sample, queued, "number_queued", ) sample = self._get_sample_with_name(number_of_keys, self.queue._name) - self.assertEqual(sample.value, keys, "number_of_keys") + self.assertEqual(sample, keys, "number_of_keys") sample = self._get_sample_with_name(number_in_flight, self.queue._name) self.assertEqual( - sample.value, + sample, in_flight, "number_in_flight", ) From 9ab4cbf9cce9e91b7835c189520af7a07ec5c31d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Jun 2021 16:07:34 +0100 Subject: [PATCH 15/16] Apply suggestions from code review Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- synapse/crypto/keyring.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 56f1e18837d7..08befd508b46 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -143,8 +143,7 @@ class _FetchKeyRequest: Attributes: server_name: The name of the server that owns the keys. - minimum_valid_until_ts: The earliest timestamp at which we need the - keys to be valid at. + minimum_valid_until_ts: The timestamp which the keys must be valid until. key_ids: The IDs of the keys to attempt to fetch """ @@ -182,7 +181,7 @@ async def verify_json_for_server( server_name: str, json_object: JsonDict, validity_time: int, - ): + ) -> None: """Verify that a JSON object has been signed by a given server Args: From 6fcd8fd9757cf01f6d0b0e70922d66dcdcba14b9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Jun 2021 16:10:47 +0100 Subject: [PATCH 16/16] Review suggestions --- synapse/crypto/keyring.py | 6 ++---- tests/crypto/test_keyring.py | 19 +++++-------------- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 08befd508b46..c840ffca7141 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -184,6 +184,8 @@ async def verify_json_for_server( ) -> None: """Verify that a JSON object has been signed by a given server + Completes if the the object was correctly signed, otherwise raises. + Args: server_name: name of the server which must have signed this object @@ -191,10 +193,6 @@ async def verify_json_for_server( validity_time: timestamp at which we require the signing key to be valid. 
(0 implies we don't care) - - Returns: - Deferred[None]: completes if the the object was correctly signed, otherwise - errbacks with an error """ request = VerifyJsonRequest.from_json_object( server_name, diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 142695d5189e..745c295d3ba1 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -189,11 +189,11 @@ def test_verify_json_for_server(self): signedjson.sign.sign_json(json1, "server9", key1) # should fail immediately on an unsigned object - d = _verify_json_for_server(kr, "server9", {}, 0) + d = kr.verify_json_for_server("server9", {}, 0) self.get_failure(d, SynapseError) # should succeed on a signed object - d = _verify_json_for_server(kr, "server9", json1, 500) + d = kr.verify_json_for_server("server9", json1, 500) # self.assertFalse(d.called) self.get_success(d) @@ -220,12 +220,12 @@ def test_verify_json_for_server_with_null_valid_until_ms(self): signedjson.sign.sign_json(json1, "server9", key1) # should fail immediately on an unsigned object - d = _verify_json_for_server(kr, "server9", {}, 0) + d = kr.verify_json_for_server("server9", {}, 0) self.get_failure(d, SynapseError) # should fail on a signed object with a non-zero minimum_valid_until_ms, # as it tries to refetch the keys and fails. - d = _verify_json_for_server(kr, "server9", json1, 500) + d = kr.verify_json_for_server("server9", json1, 500) self.get_failure(d, SynapseError) # We expect the keyring tried to refetch the key once. @@ -234,8 +234,7 @@ def test_verify_json_for_server_with_null_valid_until_ms(self): ) # should succeed on a signed object with a 0 minimum_valid_until_ms - d = _verify_json_for_server( - kr, + d = kr.verify_json_for_server( "server9", json1, 0, @@ -608,11 +607,3 @@ def get_key_from_perspectives(response): def get_key_id(key): """Get the matrix ID tag for a given SigningKey or VerifyKey""" return "%s:%s" % (key.alg, key.version) - - -async def _verify_json_for_server(kr, *args): - """thin wrapper around verify_json_for_server which makes sure it is wrapped - with the patched defer.inlineCallbacks. - """ - with LoggingContext("testctx"): - return await kr.verify_json_for_server(*args)
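
To summarize the verification rule that PATCH 12/16 introduces and the later patches keep: a key that could not be fetched, or whose valid_until_ts is too old, is now skipped rather than treated as fatal; a signature that is present but wrong still fails immediately; and a 401 is raised only if no key verified at all. Below is a minimal standalone sketch of that rule, not the Synapse implementation: FoundKey, found_keys, and verify_at_least_one are illustrative stand-ins for FetchKeyResult and the Keyring plumbing, and the closing ValueError stands in for the SynapseError(401, ..., Codes.UNAUTHORIZED) the patch raises. Only sign_json, verify_signed_json, generate_signing_key, and get_verify_key are real signedjson APIs.

from typing import Dict, List, NamedTuple

import signedjson.key
import signedjson.sign
from nacl.signing import VerifyKey


class FoundKey(NamedTuple):
    # Stand-in for synapse.storage.keys.FetchKeyResult.
    verify_key: VerifyKey
    valid_until_ts: int


def verify_at_least_one(
    server_name: str,
    json_object: dict,
    key_ids: List[str],
    found_keys: Dict[str, FoundKey],
    minimum_valid_until_ts: int,
) -> None:
    """Raise unless at least one of the requested keys verifies the object."""
    verified = False
    for key_id in key_ids:
        key_result = found_keys.get(key_id)
        if key_result is None:
            # A key we failed to fetch used to raise a 401 here; it is now
            # only fatal if no other key verifies.
            continue
        if key_result.valid_until_ts < minimum_valid_until_ts:
            # Likewise for a key that does not stay valid for long enough.
            continue

        # A signature that is present but wrong is still immediately fatal:
        # verify_signed_json raises SignatureVerifyException, which the patch
        # converts into a 401 SynapseError. Here we let it propagate as-is.
        signedjson.sign.verify_signed_json(
            json_object, server_name, key_result.verify_key
        )
        verified = True

    if not verified:
        # The patch raises SynapseError(401, ..., Codes.UNAUTHORIZED) here.
        raise ValueError(
            f"Failed to find any key to satisfy: {server_name} {key_ids}"
        )


# Usage: the unknown "ed25519:missing" key id is skipped instead of causing
# a failure, because the other listed key verifies the object.
signing_key = signedjson.key.generate_signing_key("1")
signed = signedjson.sign.sign_json({"foo": "bar"}, "server9", signing_key)
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
verify_at_least_one(
    server_name="server9",
    json_object=signed,
    key_ids=[key_id, "ed25519:missing"],
    found_keys={key_id: FoundKey(signedjson.key.get_verify_key(signing_key), 2 ** 63)},
    minimum_valid_until_ts=0,
)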