Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Implement a CachedCall to handle boilerplate of caching results #9353

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
from synapse.util.caches.cached_call import CachedCall

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -261,6 +262,9 @@ def __init__(
jwks_uri=provider.jwks_uri,
) # type: OpenIDProviderMetadata
self._provider_needs_discovery = provider.discover

self._jwks = CachedCall(self._load_jwks)

self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config
)
Expand Down Expand Up @@ -414,27 +418,27 @@ async def load_jwks(self, force: bool = False) -> JWKS:
]
}
"""
if force:
# reset the cached call to ensure we get a new result
self._jwks = CachedCall(self._load_jwks)
return await self._jwks.get()

async def _load_jwks(self) -> JWKS:
if self._uses_userinfo:
# We're not using jwt signing, return an empty jwk set
return {"keys": []}

# First check if the JWKS are loaded in the provider metadata.
# It can happen either if the provider gives its JWKS in the discovery
# document directly or if it was already loaded once.
metadata = await self.load_metadata()
jwk_set = metadata.get("jwks")
if jwk_set is not None and not force:
return jwk_set

# Loading the JWKS using the `jwks_uri` metadata
# Load the JWKS using the `jwks_uri` metadata.
uri = metadata.get("jwks_uri")
if not uri:
# this should be unreachable: load_metadata validates that
# there is a jwks_uri in the metadata if _uses_userinfo is unset
raise RuntimeError('Missing "jwks_uri" in metadata')

jwk_set = await self._http_client.get_json(uri)

# Caching the JWKS in the provider's metadata
self._provider_metadata["jwks"] = jwk_set
return jwk_set

async def _exchange_code(self, code: str) -> Token:
Expand Down
97 changes: 97 additions & 0 deletions synapse/util/caches/cached_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union

from twisted.internet.defer import Deferred
from twisted.python.failure import Failure

from synapse.logging.context import make_deferred_yieldable, run_in_background

TV = TypeVar("TV")


class CachedCall(Generic[TV]):
"""A wrapper for asynchronous calls whose results should be shared

This is useful for wrapping asynchronous functions, where there might be multiple
callers, but we only want to call the underlying function once (and have the result
returned to all callers).

Similar results can be achieved via a lock of some form, but that typically requires
more boilerplate (and ends up being less efficient).

Correctly handles Synapse logcontexts (logs and resource usage for the underlying
function are logged against the logcontext which is active when get() is first
called).

Example usage:

_cached_val = CachedCall(_load_prop)

async def handle_request() -> X:
# We can call this multiple times, but it will result in a single call to
# _load_prop().
return await _cached_val.get()

async def _load_prop() -> X:
await difficult_operation()

"""

__slots__ = ["_callable", "_deferred", "_result"]

def __init__(self, f: Callable[[], Awaitable[TV]]):
"""
Args:
f: The underlying function. Only one call to this function will be alive
at once (per instance of CachedCall)
"""
self._callable = f # type: Optional[Callable[[], Awaitable[TV]]]
self._deferred = None # type: Optional[Deferred]
self._result = None # type: Union[None, Failure, TV]

async def get(self) -> TV:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe slighter nicer ergonomics if we used __call__ instead?

This would let you do something like:

async def handle_request() -> X:
    # We can call this multiple times, but it will result in a single call to
    # _load_prop().
    return await _cached_val()

This feels a bit nicer since you've wrapped a function and then it returns a function-like thing and might let you use it as a decorator?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hrm, maybe. I always feel like relying on __call__ is a bit magical, and tend to prefer that the interactions are made explicit, which is why I did it this way. I could be persuaded though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thought is that it makes it more like functools.cached_property, but perhaps that isn't the goal. 😄

"""Kick off the call if necessary, and return the result"""

# Fire off the callable now if this is our first time
if not self._deferred:
self._deferred = run_in_background(self._callable)

# we will never need the callable again, so make sure it can be GCed
self._callable = None

# once the deferred completes, store the result. We cannot simply leave the
# result in the deferred, since if it's a Failure, GCing the deferred
# would then log a critical error about unhandled Failures.
def got_result(r):
self._result = r

self._deferred.addBoth(got_result)

# TODO: consider cancellation semantics. Currently, if the call to get()
# is cancelled, the underlying call will continue (and any future calls
# will get the result/exception), which I think is *probably* ok, modulo
# the fact the underlying call may be logged to a cancelled logcontext,
# and any eventual exception may not be caught.

# we can now await the deferred, and once it completes, return the result.
await make_deferred_yieldable(self._deferred)

# I *think* this is the easiest way to correctly raise a Failure without having
# to gut-wrench into the implementation of Deferred.
d = Deferred()
d.callback(self._result)
return await d
Comment on lines +93 to +97
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this will work properly for errors, do we need to call d.errback if it is an instance of Failure? (I don't think Deferred does anything special if you callback a Failure.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does work, because errback(x) and callback(x) are both just thin wrappers around _startRunCallbacks(x). You could argue that this isn't using Deferred's API as it's intended, which might be fair.

2 changes: 1 addition & 1 deletion tests/handlers/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def test_load_jwks(self):

# Throw if the JWKS uri is missing
with self.metadata_edit({"jwks_uri": None}):
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
self.get_failure(self.provider.load_jwks(force=True), ValueError)

# Return empty key set if JWKS are not used
self.provider._scopes = [] # not asking the openid scope
Expand Down