Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Clean up caching/locking of OIDC metadata load
Browse files Browse the repository at this point in the history
Ensure that we lock correctly to prevent multiple concurrent metadata load
requests, and generally clean up the way we construct the metadata cache.
  • Loading branch information
richvdh committed Feb 10, 2021
1 parent 9eedff8 commit 7da0f08
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 48 deletions.
59 changes: 37 additions & 22 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def __init__(

self._token_generator = token_generator

self._config = provider
self._callback_url = hs.config.oidc_callback_url # type: str

self._scopes = provider.scopes
Expand All @@ -254,14 +255,11 @@ def __init__(
provider.client_id, provider.client_secret, provider.client_auth_method,
) # type: ClientAuth
self._client_auth_method = provider.client_auth_method
self._provider_metadata = OpenIDProviderMetadata(
issuer=provider.issuer,
authorization_endpoint=provider.authorization_endpoint,
token_endpoint=provider.token_endpoint,
userinfo_endpoint=provider.userinfo_endpoint,
jwks_uri=provider.jwks_uri,
) # type: OpenIDProviderMetadata
self._provider_needs_discovery = provider.discover

# cache of metadata for the identity provider (endpoint uris, mostly). This is
# loaded on-demand from the discovery endpoint (if discovery is enabled), with
# possible overrides from the config. Access via `load_metadata`.
self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)

# cache of JWKs used by the identity provider to sign tokens. Loaded on demand
# from the IdP's jwks_uri, if required.
Expand Down Expand Up @@ -292,7 +290,7 @@ def __init__(

self._sso_handler.register_identity_provider(self)

def _validate_metadata(self):
def _validate_metadata(self, m: OpenIDProviderMetadata) -> None:
"""Verifies the provider metadata.
This checks the validity of the currently loaded provider. Not
Expand All @@ -311,7 +309,6 @@ def _validate_metadata(self):
if self._skip_verification is True:
return

m = self._provider_metadata
m.validate_issuer()
m.validate_authorization_endpoint()
m.validate_token_endpoint()
Expand Down Expand Up @@ -363,30 +360,48 @@ def _uses_userinfo(self) -> bool:
or self._user_profile_method == "userinfo_endpoint"
)

async def load_metadata(self) -> OpenIDProviderMetadata:
"""Load and validate the provider metadata.
async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
"""Return the provider metadata.
The values metadatas are discovered if ``oidc_config.discovery`` is
``True`` and then cached.
If this is the first call, the metadata is built from the config and from the
metadata discovery endpoint (if enabled), and then validated. If the metadata
is successfully validated, it is then cached for future use.
Args:
force: If true, any cached metadata is discarded to force a reload.
Raises:
ValueError: if something in the provider is not valid
Returns:
The provider's metadata.
"""
# If we are using the OpenID Discovery documents, it needs to be loaded once
# FIXME: should there be a lock here?
if self._provider_needs_discovery:
url = get_well_known_url(self._provider_metadata["issuer"], external=True)
if force:
# reset the cached call to ensure we get a new result
self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)

return await self._provider_metadata.get()

async def _load_metadata(self) -> OpenIDProviderMetadata:
# init the metadata from our config
metadata = OpenIDProviderMetadata(
issuer=self._config.issuer,
authorization_endpoint=self._config.authorization_endpoint,
token_endpoint=self._config.token_endpoint,
userinfo_endpoint=self._config.userinfo_endpoint,
jwks_uri=self._config.jwks_uri,
) # type: OpenIDProviderMetadata

# load any data from the discovery endpoint, if enabled
if self._config.discover:
url = get_well_known_url(self._config.issuer, external=True)
metadata_response = await self._http_client.get_json(url)
# TODO: maybe update the other way around to let user override some values?
self._provider_metadata.update(metadata_response)
self._provider_needs_discovery = False
metadata.update(metadata_response)

self._validate_metadata()
self._validate_metadata(metadata)

return self._provider_metadata
return metadata

async def load_jwks(self, force: bool = False) -> JWKS:
"""Load the JSON Web Key Set used to sign ID tokens.
Expand Down
71 changes: 45 additions & 26 deletions tests/handlers/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from synapse.server import HomeServer
from synapse.types import UserID

from tests.test_utils import FakeResponse, simple_async_mock
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config

try:
Expand Down Expand Up @@ -131,7 +131,6 @@ def default_config(self):
return config

def make_homeserver(self, reactor, clock):

self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = "Synapse Test"
Expand All @@ -151,7 +150,15 @@ def make_homeserver(self, reactor, clock):
return hs

def metadata_edit(self, values):
return patch.dict(self.provider._provider_metadata, values)
"""Modify the result that will be returned by the well-known query"""

async def patched_get_json(uri):
res = await get_json(uri)
if uri == WELL_KNOWN:
res.update(values)
return res

return patch.object(self.http_client, "get_json", patched_get_json)

def assertRenderedError(self, error, error_description=None):
self.render_error.assert_called_once()
Expand Down Expand Up @@ -212,7 +219,14 @@ def test_load_jwks(self):
self.http_client.get_json.assert_called_once_with(JWKS_URI)

# Throw if the JWKS uri is missing
with self.metadata_edit({"jwks_uri": None}):
original = self.provider.load_metadata

async def patched_load_metadata():
m = (await original()).copy()
m.update({"jwks_uri": None})
return m

with patch.object(self.provider, "load_metadata", patched_load_metadata):
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)

# Return empty key set if JWKS are not used
Expand All @@ -222,63 +236,68 @@ def test_load_jwks(self):
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})

@override_config({"oidc_config": COMMON_CONFIG})
def test_validate_config(self):
"""Provider metadatas are extensively validated."""
h = self.provider

def force_load_metadata():
async def force_load():
return await h.load_metadata(force=True)

return get_awaitable_result(force_load())

# Default test config does not throw
h._validate_metadata()
force_load_metadata()

with self.metadata_edit({"issuer": None}):
self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)

with self.metadata_edit({"issuer": "http://insecure/"}):
self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)

with self.metadata_edit({"issuer": "https://invalid/?because=query"}):
self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)

with self.metadata_edit({"authorization_endpoint": None}):
self.assertRaisesRegex(
ValueError, "authorization_endpoint", h._validate_metadata
ValueError, "authorization_endpoint", force_load_metadata
)

with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}):
self.assertRaisesRegex(
ValueError, "authorization_endpoint", h._validate_metadata
ValueError, "authorization_endpoint", force_load_metadata
)

with self.metadata_edit({"token_endpoint": None}):
self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)

with self.metadata_edit({"token_endpoint": "http://insecure/token"}):
self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)

with self.metadata_edit({"jwks_uri": None}):
self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)

with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}):
self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)

with self.metadata_edit({"response_types_supported": ["id_token"]}):
self.assertRaisesRegex(
ValueError, "response_types_supported", h._validate_metadata
ValueError, "response_types_supported", force_load_metadata
)

with self.metadata_edit(
{"token_endpoint_auth_methods_supported": ["client_secret_basic"]}
):
# should not throw, as client_secret_basic is the default auth method
h._validate_metadata()
force_load_metadata()

with self.metadata_edit(
{"token_endpoint_auth_methods_supported": ["client_secret_post"]}
):
self.assertRaisesRegex(
ValueError,
"token_endpoint_auth_methods_supported",
h._validate_metadata,
force_load_metadata,
)

# Tests for configs that require the userinfo endpoint
Expand All @@ -287,24 +306,24 @@ def test_validate_config(self):
h._user_profile_method = "userinfo_endpoint"
self.assertTrue(h._uses_userinfo)

# Revert the profile method and do not request the "openid" scope.
# Revert the profile method and do not request the "openid" scope: this should
# mean that we check for a userinfo endpoint
h._user_profile_method = "auto"
h._scopes = []
self.assertTrue(h._uses_userinfo)
self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
with self.metadata_edit({"userinfo_endpoint": None}):
self.assertRaisesRegex(ValueError, "userinfo_endpoint", force_load_metadata)

with self.metadata_edit(
{"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None}
):
# Shouldn't raise with a valid userinfo, even without
h._validate_metadata()
with self.metadata_edit({"jwks_uri": None}):
# Shouldn't raise with a valid userinfo, even without jwks
force_load_metadata()

@override_config({"oidc_config": {"skip_verification": True}})
def test_skip_verification(self):
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
self.provider._validate_metadata()
get_awaitable_result(self.provider.load_metadata())

def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
Expand Down

0 comments on commit 7da0f08

Please sign in to comment.