Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Cache token introspection response from OIDC provider #16117

Merged
merged 4 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16117.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cache token introspection response from OIDC provider.
40 changes: 38 additions & 2 deletions synapse/api/auth/msc3861_delegated.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from synapse.types import Requester, UserID, create_requester
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.caches.expiringcache import ExpiringCache

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -106,6 +107,14 @@ def __init__(self, hs: "HomeServer"):

self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata)

self._clock = hs.get_clock()
self._token_cache: ExpiringCache[str, IntrospectionToken] = ExpiringCache(
cache_name="introspection_token_cache",
clock=self._clock,
max_len=10000,
expiry_ms=5 * 60 * 1000,
)

if isinstance(auth_method, PrivateKeyJWTWithKid):
# Use the JWK as the client secret when using the private_key_jwt method
assert self._config.jwk, "No JWK provided"
Expand Down Expand Up @@ -144,6 +153,20 @@ async def _introspect_token(self, token: str) -> IntrospectionToken:
Returns:
The introspection response
"""
# check the cache before doing a request
introspection_token = self._token_cache.get(token, None)

if introspection_token:
# check the expiration field of the token (if it exists)
exp = introspection_token.get("exp", None)
if exp:
time_now = self._clock.time()
expired = time_now > exp
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
if not expired:
return introspection_token
else:
return introspection_token

metadata = await self._issuer_metadata.get()
introspection_endpoint = metadata.get("introspection_endpoint")
raw_headers: Dict[str, str] = {
Expand All @@ -157,7 +180,10 @@ async def _introspect_token(self, token: str) -> IntrospectionToken:

# Fill the body/headers with credentials
uri, raw_headers, body = self._client_auth.prepare(
method="POST", uri=introspection_endpoint, headers=raw_headers, body=body
method="POST",
uri=introspection_endpoint,
headers=raw_headers,
body=body,
)
headers = Headers({k: [v] for (k, v) in raw_headers.items()})

Expand Down Expand Up @@ -187,7 +213,17 @@ async def _introspect_token(self, token: str) -> IntrospectionToken:
"The introspection endpoint returned an invalid JSON response."
)

return IntrospectionToken(**resp)
expiration = resp.get("exp", None)
if expiration:
if self._clock.time() > expiration:
raise InvalidClientTokenError("Token is expired.")

introspection_token = IntrospectionToken(**resp)

# add token to cache
self._token_cache[token] = introspection_token

return introspection_token

async def is_server_admin(self, requester: Requester) -> bool:
return "urn:synapse:admin:*" in requester.scope
Expand Down
62 changes: 62 additions & 0 deletions tests/handlers/test_oauth_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,68 @@ def test_unavailable_introspection_endpoint(self) -> None:
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)

def test_introspection_token_cache(self) -> None:
access_token = "open_sesame"
self.http_client.request = simple_async_mock(
return_value=FakeResponse.json(
code=200,
payload={"active": "true", "scope": "guest", "jti": access_token},
)
)

# first call should cache response
# Mpyp ignores below are due to mypy not understanding the dynamic substitution of msc3861 auth code
# for regular auth code via the config
self.get_success(
self.auth._introspect_token(access_token) # type: ignore[attr-defined]
)
introspection_token = self.auth._token_cache.get(access_token) # type: ignore[attr-defined]
self.assertEqual(introspection_token["jti"], access_token)
# there's been one http request
self.http_client.request.assert_called_once()

# second call should pull from cache, there should still be only one http request
token = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
self.http_client.request.assert_called_once()
self.assertEqual(token["jti"], access_token)

# advance past five minutes and check that cache expired - there should be more than one http call now
self.reactor.advance(360)
token_2 = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
self.assertEqual(self.http_client.request.call_count, 2)
self.assertEqual(token_2["jti"], access_token)

# test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a
# token with a soon-to-expire `exp` field to the cache
self.http_client.request = simple_async_mock(
return_value=FakeResponse.json(
code=200,
payload={
"active": "true",
"scope": "guest",
"jti": "stale",
"exp": self.clock.time() + 100,
},
)
)
self.get_success(
self.auth._introspect_token("stale") # type: ignore[attr-defined]
)
introspection_token = self.auth._token_cache.get("stale") # type: ignore[attr-defined]
self.assertEqual(introspection_token["jti"], "stale")
self.assertEqual(self.http_client.request.call_count, 1)

# advance the reactor past the token expiry but less than the cache expiry
self.reactor.advance(120)
self.assertEqual(self.auth._token_cache.get("stale"), introspection_token) # type: ignore[attr-defined]

# check that the next call causes another http request (which will fail because the token is technically expired
# but the important thing is we discard the token from the cache and try the network)
self.get_failure(
self.auth._introspect_token("stale"), InvalidClientTokenError # type: ignore[attr-defined]
)
self.assertEqual(self.http_client.request.call_count, 2)

def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
# We only generate a master key to simplify the test.
master_signing_key = generate_signing_key(device_id)
Expand Down