diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index d3cfa27d06bb..001a76f38605 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -41,6 +41,7 @@ from synapse.logging.context import make_deferred_yieldable from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.util import json_decoder +from synapse.util.caches.cached_call import RetryOnExceptionCachedCall if TYPE_CHECKING: from synapse.server import HomeServer @@ -261,6 +262,11 @@ def __init__( jwks_uri=provider.jwks_uri, ) # type: OpenIDProviderMetadata self._provider_needs_discovery = provider.discover + + # cache of JWKs used by the identity provider to sign tokens. Loaded on demand + # from the IdP's jwks_uri, if required. + self._jwks = RetryOnExceptionCachedCall(self._load_jwks) + self._user_mapping_provider = provider.user_mapping_provider_class( provider.user_mapping_provider_config ) @@ -340,8 +346,7 @@ def _validate_metadata(self): ) else: # If we're not using userinfo, we need a valid jwks to validate the ID token - if m.get("jwks") is None: - m.validate_jwks_uri() + m.validate_jwks_uri() @property def _uses_userinfo(self) -> bool: @@ -411,27 +416,27 @@ async def load_jwks(self, force: bool = False) -> JWKS: ] } """ + if force: + # reset the cached call to ensure we get a new result + self._jwks = RetryOnExceptionCachedCall(self._load_jwks) + return await self._jwks.get() + + async def _load_jwks(self) -> JWKS: if self._uses_userinfo: # We're not using jwt signing, return an empty jwk set return {"keys": []} - # First check if the JWKS are loaded in the provider metadata. - # It can happen either if the provider gives its JWKS in the discovery - # document directly or if it was already loaded once. metadata = await self.load_metadata() - jwk_set = metadata.get("jwks") - if jwk_set is not None and not force: - return jwk_set - # Loading the JWKS using the `jwks_uri` metadata + # Load the JWKS using the `jwks_uri` metadata. uri = metadata.get("jwks_uri") if not uri: + # this should be unreachable: load_metadata validates that + # there is a jwks_uri in the metadata if _uses_userinfo is unset raise RuntimeError('Missing "jwks_uri" in metadata') jwk_set = await self._http_client.get_json(uri) - # Caching the JWKS in the provider's metadata - self._provider_metadata["jwks"] = jwk_set return jwk_set async def _exchange_code(self, code: str) -> Token: