From 9db3a90782350d203c3659647883ea43b1c0c0cb Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Tue, 15 Aug 2023 14:27:39 -0700 Subject: [PATCH 1/4] add an expiring cache to `_introspect_token` --- synapse/api/auth/msc3861_delegated.py | 112 +++++++++++++++++--------- 1 file changed, 73 insertions(+), 39 deletions(-) diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py index 9524102a3037..e85931c6ffda 100644 --- a/synapse/api/auth/msc3861_delegated.py +++ b/synapse/api/auth/msc3861_delegated.py @@ -39,6 +39,7 @@ from synapse.types import Requester, UserID, create_requester from synapse.util import json_decoder from synapse.util.caches.cached_call import RetryOnExceptionCachedCall +from synapse.util.caches.expiringcache import ExpiringCache if TYPE_CHECKING: from synapse.server import HomeServer @@ -106,6 +107,14 @@ def __init__(self, hs: "HomeServer"): self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata) + self._clock = hs.get_clock() + self._token_cache: ExpiringCache[str, IntrospectionToken] = ExpiringCache( + cache_name="introspection_token_cache", + clock=self._clock, + max_len=10000, + expiry_ms=5 * 60 * 1000, + ) + if isinstance(auth_method, PrivateKeyJWTWithKid): # Use the JWK as the client secret when using the private_key_jwt method assert self._config.jwk, "No JWK provided" @@ -144,50 +153,75 @@ async def _introspect_token(self, token: str) -> IntrospectionToken: Returns: The introspection response """ - metadata = await self._issuer_metadata.get() - introspection_endpoint = metadata.get("introspection_endpoint") - raw_headers: Dict[str, str] = { - "Content-Type": "application/x-www-form-urlencoded", - "User-Agent": str(self._http_client.user_agent, "utf-8"), - "Accept": "application/json", - } - - args = {"token": token, "token_type_hint": "access_token"} - body = urlencode(args, True) - - # Fill the body/headers with credentials - uri, raw_headers, body = self._client_auth.prepare( - method="POST", uri=introspection_endpoint, headers=raw_headers, body=body - ) - headers = Headers({k: [v] for (k, v) in raw_headers.items()}) - - # Do the actual request - # We're not using the SimpleHttpClient util methods as we don't want to - # check the HTTP status code, and we do the body encoding ourselves. - response = await self._http_client.request( - method="POST", - uri=uri, - data=body.encode("utf-8"), - headers=headers, - ) + # check the cache before doing a request + introspection_token = self._token_cache.get(token, None) + + expired = False + if introspection_token: + # check the expiration field of the token (if it exists) + exp = introspection_token.get("exp", None) + if exp: + time_now = self._clock.time_msec() + expired = time_now > exp + + if not introspection_token or expired: + metadata = await self._issuer_metadata.get() + introspection_endpoint = metadata.get("introspection_endpoint") + raw_headers: Dict[str, str] = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": str(self._http_client.user_agent, "utf-8"), + "Accept": "application/json", + } + + args = {"token": token, "token_type_hint": "access_token"} + body = urlencode(args, True) + + # Fill the body/headers with credentials + uri, raw_headers, body = self._client_auth.prepare( + method="POST", + uri=introspection_endpoint, + headers=raw_headers, + body=body, + ) + headers = Headers({k: [v] for (k, v) in raw_headers.items()}) + + # Do the actual request + # We're not using the SimpleHttpClient util methods as we don't want to + # check the HTTP status code, and we do the body encoding ourselves. + response = await self._http_client.request( + method="POST", + uri=uri, + data=body.encode("utf-8"), + headers=headers, + ) - resp_body = await make_deferred_yieldable(readBody(response)) + resp_body = await make_deferred_yieldable(readBody(response)) - if response.code < 200 or response.code >= 300: - raise HttpResponseException( - response.code, - response.phrase.decode("ascii", errors="replace"), - resp_body, - ) + if response.code < 200 or response.code >= 300: + raise HttpResponseException( + response.code, + response.phrase.decode("ascii", errors="replace"), + resp_body, + ) - resp = json_decoder.decode(resp_body.decode("utf-8")) + resp = json_decoder.decode(resp_body.decode("utf-8")) - if not isinstance(resp, dict): - raise ValueError( - "The introspection endpoint returned an invalid JSON response." - ) + if not isinstance(resp, dict): + raise ValueError( + "The introspection endpoint returned an invalid JSON response." + ) + + expiration = resp.get("exp", None) + if expiration: + if self._clock.time_msec() > expiration: + raise InvalidClientTokenError("Token is expired.") + + introspection_token = IntrospectionToken(**resp) + + # add token to cache + self._token_cache[token] = introspection_token - return IntrospectionToken(**resp) + return introspection_token async def is_server_admin(self, requester: Requester) -> bool: return "urn:synapse:admin:*" in requester.scope From e4dfba4425b0802c5827f1100da10dbfb1dc1c9f Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Tue, 15 Aug 2023 14:27:44 -0700 Subject: [PATCH 2/4] add some tests --- tests/handlers/test_oauth_delegation.py | 62 +++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index 6309d7b36e8a..c86af57d2291 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -491,6 +491,68 @@ def test_unavailable_introspection_endpoint(self) -> None: error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) self.assertEqual(error.value.code, 503) + def test_introspection_token_cache(self) -> None: + access_token = "open_sesame" + self.http_client.request = simple_async_mock( + return_value=FakeResponse.json( + code=200, + payload={"active": "true", "scope": "guest", "jti": access_token}, + ) + ) + + # first call should cache response - check cache + # Mpyp ignores below are due to mypy not understanding the dynamic substitution of msc3861 auth code + # for regular auth code via the config + self.get_success( + self.auth._introspect_token(access_token) # type: ignore[attr-defined] + ) + introspection_token = self.auth._token_cache.get(access_token) # type: ignore[attr-defined] + self.assertEqual(introspection_token["jti"], access_token) + # there's been one http request + self.http_client.request.assert_called_once() + + # second call should pull from cache, there should still be only one http request + token = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined] + self.http_client.request.assert_called_once() + self.assertEqual(token["jti"], access_token) + + # advance past five minutes and check that cache expired - there should be more than one http call now + self.reactor.advance(360) + token_2 = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined] + self.assertEqual(self.http_client.request.call_count, 2) + self.assertEqual(token_2["jti"], access_token) + + # test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a + # token with a soon-to-expire `exp` field to the cache + self.http_client.request = simple_async_mock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": "true", + "scope": "guest", + "jti": "stale", + "exp": self.clock.time_msec() + 100, + }, + ) + ) + self.get_success( + self.auth._introspect_token("stale") # type: ignore[attr-defined] + ) + introspection_token = self.auth._token_cache.get("stale") # type: ignore[attr-defined] + self.assertEqual(introspection_token["jti"], "stale") + self.assertEqual(self.http_client.request.call_count, 1) + + # advance the reactor past the token expiry but less than the cache expiry + self.reactor.advance(30) + self.assertEqual(self.auth._token_cache.get("stale"), introspection_token) # type: ignore[attr-defined] + + # check that the next call causes another http request (which will fail because the token is technically expired + # but the important thing is we discard the token from the cache and try the network) + self.get_failure( + self.auth._introspect_token("stale"), InvalidClientTokenError # type: ignore[attr-defined] + ) + self.assertEqual(self.http_client.request.call_count, 2) + def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: # We only generate a master key to simplify the test. master_signing_key = generate_signing_key(device_id) From d3841eb33708bae5ffe7df88bc22df86314aa7c8 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Tue, 15 Aug 2023 15:04:21 -0700 Subject: [PATCH 3/4] newsfragment --- changelog.d/16117.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/16117.misc diff --git a/changelog.d/16117.misc b/changelog.d/16117.misc new file mode 100644 index 000000000000..f33fa6dc1751 --- /dev/null +++ b/changelog.d/16117.misc @@ -0,0 +1 @@ +Cache token introspection response from OIDC provider. From 400c90d0a702d553175494f47ec9fe91b12cbea9 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 16 Aug 2023 09:23:11 -0700 Subject: [PATCH 4/4] requested changes --- synapse/api/auth/msc3861_delegated.py | 106 ++++++++++++------------ tests/handlers/test_oauth_delegation.py | 6 +- 2 files changed, 57 insertions(+), 55 deletions(-) diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py index e85931c6ffda..3a516093f54c 100644 --- a/synapse/api/auth/msc3861_delegated.py +++ b/synapse/api/auth/msc3861_delegated.py @@ -156,70 +156,72 @@ async def _introspect_token(self, token: str) -> IntrospectionToken: # check the cache before doing a request introspection_token = self._token_cache.get(token, None) - expired = False if introspection_token: # check the expiration field of the token (if it exists) exp = introspection_token.get("exp", None) if exp: - time_now = self._clock.time_msec() + time_now = self._clock.time() expired = time_now > exp + if not expired: + return introspection_token + else: + return introspection_token + + metadata = await self._issuer_metadata.get() + introspection_endpoint = metadata.get("introspection_endpoint") + raw_headers: Dict[str, str] = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": str(self._http_client.user_agent, "utf-8"), + "Accept": "application/json", + } + + args = {"token": token, "token_type_hint": "access_token"} + body = urlencode(args, True) + + # Fill the body/headers with credentials + uri, raw_headers, body = self._client_auth.prepare( + method="POST", + uri=introspection_endpoint, + headers=raw_headers, + body=body, + ) + headers = Headers({k: [v] for (k, v) in raw_headers.items()}) + + # Do the actual request + # We're not using the SimpleHttpClient util methods as we don't want to + # check the HTTP status code, and we do the body encoding ourselves. + response = await self._http_client.request( + method="POST", + uri=uri, + data=body.encode("utf-8"), + headers=headers, + ) - if not introspection_token or expired: - metadata = await self._issuer_metadata.get() - introspection_endpoint = metadata.get("introspection_endpoint") - raw_headers: Dict[str, str] = { - "Content-Type": "application/x-www-form-urlencoded", - "User-Agent": str(self._http_client.user_agent, "utf-8"), - "Accept": "application/json", - } - - args = {"token": token, "token_type_hint": "access_token"} - body = urlencode(args, True) - - # Fill the body/headers with credentials - uri, raw_headers, body = self._client_auth.prepare( - method="POST", - uri=introspection_endpoint, - headers=raw_headers, - body=body, - ) - headers = Headers({k: [v] for (k, v) in raw_headers.items()}) - - # Do the actual request - # We're not using the SimpleHttpClient util methods as we don't want to - # check the HTTP status code, and we do the body encoding ourselves. - response = await self._http_client.request( - method="POST", - uri=uri, - data=body.encode("utf-8"), - headers=headers, - ) - - resp_body = await make_deferred_yieldable(readBody(response)) + resp_body = await make_deferred_yieldable(readBody(response)) - if response.code < 200 or response.code >= 300: - raise HttpResponseException( - response.code, - response.phrase.decode("ascii", errors="replace"), - resp_body, - ) + if response.code < 200 or response.code >= 300: + raise HttpResponseException( + response.code, + response.phrase.decode("ascii", errors="replace"), + resp_body, + ) - resp = json_decoder.decode(resp_body.decode("utf-8")) + resp = json_decoder.decode(resp_body.decode("utf-8")) - if not isinstance(resp, dict): - raise ValueError( - "The introspection endpoint returned an invalid JSON response." - ) + if not isinstance(resp, dict): + raise ValueError( + "The introspection endpoint returned an invalid JSON response." + ) - expiration = resp.get("exp", None) - if expiration: - if self._clock.time_msec() > expiration: - raise InvalidClientTokenError("Token is expired.") + expiration = resp.get("exp", None) + if expiration: + if self._clock.time() > expiration: + raise InvalidClientTokenError("Token is expired.") - introspection_token = IntrospectionToken(**resp) + introspection_token = IntrospectionToken(**resp) - # add token to cache - self._token_cache[token] = introspection_token + # add token to cache + self._token_cache[token] = introspection_token return introspection_token diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index c86af57d2291..82c26e303f25 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -500,7 +500,7 @@ def test_introspection_token_cache(self) -> None: ) ) - # first call should cache response - check cache + # first call should cache response # Mpyp ignores below are due to mypy not understanding the dynamic substitution of msc3861 auth code # for regular auth code via the config self.get_success( @@ -531,7 +531,7 @@ def test_introspection_token_cache(self) -> None: "active": "true", "scope": "guest", "jti": "stale", - "exp": self.clock.time_msec() + 100, + "exp": self.clock.time() + 100, }, ) ) @@ -543,7 +543,7 @@ def test_introspection_token_cache(self) -> None: self.assertEqual(self.http_client.request.call_count, 1) # advance the reactor past the token expiry but less than the cache expiry - self.reactor.advance(30) + self.reactor.advance(120) self.assertEqual(self.auth._token_cache.get("stale"), introspection_token) # type: ignore[attr-defined] # check that the next call causes another http request (which will fail because the token is technically expired