From 9eedff87beed2b370bc30bebaf93f080a0c4979f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 9 Feb 2021 18:05:17 +0000 Subject: [PATCH] Stop caching the JWKS in provider_metadata AFAICT from the RFCs (https://tools.ietf.org/html/rfc8414), `jwks` is not a field that can be returned from the metadata endpoint, so this is entirely an internal cache. The object is opaque enough without us inventing extra things to store in it, so let's move it out. --- synapse/handlers/oidc_handler.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index d3cfa27d06bb..001a76f38605 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -41,6 +41,7 @@ from synapse.logging.context import make_deferred_yieldable from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.util import json_decoder +from synapse.util.caches.cached_call import RetryOnExceptionCachedCall if TYPE_CHECKING: from synapse.server import HomeServer @@ -261,6 +262,11 @@ def __init__( jwks_uri=provider.jwks_uri, ) # type: OpenIDProviderMetadata self._provider_needs_discovery = provider.discover + + # cache of JWKs used by the identity provider to sign tokens. Loaded on demand + # from the IdP's jwks_uri, if required. + self._jwks = RetryOnExceptionCachedCall(self._load_jwks) + self._user_mapping_provider = provider.user_mapping_provider_class( provider.user_mapping_provider_config ) @@ -340,8 +346,7 @@ def _validate_metadata(self): ) else: # If we're not using userinfo, we need a valid jwks to validate the ID token - if m.get("jwks") is None: - m.validate_jwks_uri() + m.validate_jwks_uri() @property def _uses_userinfo(self) -> bool: @@ -411,27 +416,27 @@ async def load_jwks(self, force: bool = False) -> JWKS: ] } """ + if force: + # reset the cached call to ensure we get a new result + self._jwks = RetryOnExceptionCachedCall(self._load_jwks) + return await self._jwks.get() + + async def _load_jwks(self) -> JWKS: if self._uses_userinfo: # We're not using jwt signing, return an empty jwk set return {"keys": []} - # First check if the JWKS are loaded in the provider metadata. - # It can happen either if the provider gives its JWKS in the discovery - # document directly or if it was already loaded once. metadata = await self.load_metadata() - jwk_set = metadata.get("jwks") - if jwk_set is not None and not force: - return jwk_set - # Loading the JWKS using the `jwks_uri` metadata + # Load the JWKS using the `jwks_uri` metadata. uri = metadata.get("jwks_uri") if not uri: + # this should be unreachable: load_metadata validates that + # there is a jwks_uri in the metadata if _uses_userinfo is unset raise RuntimeError('Missing "jwks_uri" in metadata') jwk_set = await self._http_client.get_json(uri) - # Caching the JWKS in the provider's metadata - self._provider_metadata["jwks"] = jwk_set return jwk_set async def _exchange_code(self, code: str) -> Token: