diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index ba2410998a96..26025504393e 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -18,6 +18,7 @@ import logging from typing import ( Any, + Awaitable, Callable, Dict, Generic, @@ -354,7 +355,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): def __init__( self, - orig: Callable[..., Any], + orig: Callable[..., Awaitable[Dict]], cached_method_name: str, list_name: str, num_args: Optional[int] = None, @@ -385,13 +386,13 @@ def __init__( def __get__( self, obj: Optional[Any], objtype: Optional[Type] = None - ) -> Callable[..., Any]: + ) -> Callable[..., defer.Deferred[Dict[Hashable, Any]]]: cached_method = getattr(obj, self.cached_method_name) cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args @functools.wraps(self.orig) - def wrapped(*args: Any, **kwargs: Any) -> Any: + def wrapped(*args: Any, **kwargs: Any) -> defer.Deferred[Dict]: # If we're passed a cache_context then we'll want to call its # invalidate() whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None)