From ff8f94d36cf6692c19e8ca8309e661e37d95a47b Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 13 Aug 2021 19:47:02 +0100 Subject: [PATCH 01/38] fairly mechanical changes --- synapse/util/__init__.py | 2 +- synapse/util/batching_queue.py | 2 +- synapse/util/file_consumer.py | 16 ++++++++++------ synapse/util/ratelimitutils.py | 21 ++++++++++++++------- synapse/util/templates.py | 5 +++-- 5 files changed, 29 insertions(+), 17 deletions(-) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index b69f562ca586..8ba901fabb3c 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -79,7 +79,7 @@ class Clock: @defer.inlineCallbacks def sleep(self, seconds): - d = defer.Deferred() + d: defer.Deferred = defer.Deferred() with context.PreserveLoggingContext(): self._reactor.callLater(seconds, d.callback, seconds) res = yield d diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py index 274cea7eb709..6d0e2a43f0a7 100644 --- a/synapse/util/batching_queue.py +++ b/synapse/util/batching_queue.py @@ -122,7 +122,7 @@ async def add_to_queue(self, value: V, key: Hashable = ()) -> R: # First we create a defer and add it and the value to the list of # pending items. - d = defer.Deferred() + d: defer.Deferred = defer.Deferred() self._next_values.setdefault(key, []).append((value, d)) # If we're not currently processing the key fire off a background diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index e946189f9a72..ecda6b0eda39 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -13,6 +13,7 @@ # limitations under the License. import queue +from typing import Optional from twisted.internet import threads @@ -51,7 +52,7 @@ def __init__(self, file_obj, reactor): # Queue of slices of bytes to be written. When producer calls # unregister a final None is sent. - self._bytes_queue = queue.Queue() + self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue() # Deferred that is resolved when finished writing self._finished_deferred = None @@ -59,7 +60,7 @@ def __init__(self, file_obj, reactor): # If the _writer thread throws an exception it gets stored here. self._write_exception = None - def registerProducer(self, producer, streaming): + def registerProducer(self, producer, streaming) -> None: """Part of IConsumer interface Args: @@ -81,17 +82,19 @@ def registerProducer(self, producer, streaming): if not streaming: self._producer.resumeProducing() - def unregisterProducer(self): + def unregisterProducer(self) -> None: """Part of IProducer interface""" self._producer = None + assert self._finished_deferred is not None if not self._finished_deferred.called: self._bytes_queue.put_nowait(None) - def write(self, bytes): + def write(self, bytes) -> None: """Part of IProducer interface""" if self._write_exception: raise self._write_exception + assert self._finished_deferred is not None if self._finished_deferred.called: raise Exception("consumer has closed") @@ -101,9 +104,10 @@ def write(self, bytes): # then we pause the producer. if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: self._paused_producer = True + assert self._producer is not None self._producer.pauseProducing() - def _writer(self): + def _writer(self) -> None: """This is run in a background thread to write to the file.""" try: while self._producer or not self._bytes_queue.empty(): @@ -134,7 +138,7 @@ def wait(self): """Returns a deferred that resolves when finished writing to file""" return make_deferred_yieldable(self._finished_deferred) - def _resume_paused_producer(self): + def _resume_paused_producer(self) -> None: """Gets called if we should resume producing after being paused""" if self._paused_producer and self._producer: self._paused_producer = False diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index a654c6968492..4df8625b522f 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -15,33 +15,38 @@ import collections import contextlib import logging +from typing import DefaultDict from twisted.internet import defer from synapse.api.errors import LimitExceededError +from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, run_in_background, ) +from synapse.util import Clock logger = logging.getLogger(__name__) class FederationRateLimiter: - def __init__(self, clock, config): + def __init__(self, clock: Clock, config: FederationRateLimitConfig): """ Args: - clock (Clock) - config (FederationRateLimitConfig) + clock + config """ def new_limiter(): return _PerHostRatelimiter(clock=clock, config=config) - self.ratelimiters = collections.defaultdict(new_limiter) + self.ratelimiters: DefaultDict[ + str, "_PerHostRatelimiter" + ] = collections.defaultdict(new_limiter) - def ratelimit(self, host): + def ratelimit(self, host: str): """Used to ratelimit an incoming request from a given host Example usage: @@ -79,7 +84,9 @@ def __init__(self, clock, config): # map from request_id object to Deferred for requests which are ready # for processing but have been queued - self.ready_request_queue = collections.OrderedDict() + self.ready_request_queue: collections.OrderedDict[ + object, defer.Deferred + ] = collections.OrderedDict() # request id objects for requests which are in progress self.current_processing = set() @@ -122,7 +129,7 @@ def _on_enter(self, request_id): def queue_request(): if len(self.current_processing) >= self.concurrent_requests: - queue_defer = defer.Deferred() + queue_defer: defer.Deferred = defer.Deferred() self.ready_request_queue[request_id] = queue_defer logger.info( "Ratelimiter: queueing request (queue now %i items)", diff --git a/synapse/util/templates.py b/synapse/util/templates.py index 38543dd1ea19..70eaf71ca7b9 100644 --- a/synapse/util/templates.py +++ b/synapse/util/templates.py @@ -27,7 +27,7 @@ def build_jinja_env( template_search_directories: Iterable[str], config: "HomeServerConfig", - autoescape: Union[bool, Callable[[str], bool], None] = None, + autoescape: Union[bool, Callable[[Optional[str]], bool], None] = None, ) -> jinja2.Environment: """Set up a Jinja2 environment to load templates from the given search path @@ -56,7 +56,8 @@ def build_jinja_env( if autoescape is None: autoescape = jinja2.select_autoescape() - loader = jinja2.FileSystemLoader(template_search_directories) + # the type signature of this is wrong + loader = jinja2.FileSystemLoader(template_search_directories) # type: ignore[arg-type] env = jinja2.Environment(loader=loader, autoescape=autoescape) # Update the environment with our custom filters From 6deee28978e4c728973b2738a13b709b951091a4 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 13 Aug 2021 19:47:11 +0100 Subject: [PATCH 02/38] stranger changes (REVIEW) --- synapse/util/manhole.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index da24ba0470b6..c35a567348f1 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -79,7 +79,10 @@ def manhole(username, password, globals): checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password}) rlm = manhole_ssh.TerminalRealm() - rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( + # mypy ignored here because: + # - can't deduce types of lambdas + # - variable is Type[ServerProtocol], expr is Callable[[], ServerProtocol] + rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( # type: ignore[misc,assignment] SynapseManhole, dict(globals, __name__="__console__") ) @@ -110,6 +113,7 @@ def showsyntaxerror(self, filename=None): any syntax errors to be sent to the terminal, rather than sentry. """ type, value, tb = sys.exc_info() + assert value is not None sys.last_type = type sys.last_value = value sys.last_traceback = tb @@ -135,9 +139,8 @@ def showtraceback(self): """ sys.last_type, sys.last_value, last_tb = ei = sys.exc_info() sys.last_traceback = last_tb - try: - # We remove the first stack item because it is our own code. - lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) - self.write("".join(lines)) - finally: - last_tb = ei = None + assert last_tb is not None + + # We remove the first stack item because it is our own code. + lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) + self.write("".join(lines)) From 3c2b4dd9ad06db7954fe37413db00c1a8c16af16 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 13 Aug 2021 19:47:29 +0100 Subject: [PATCH 03/38] Newsfile & mypy.ini --- changelog.d/10601.misc | 1 + mypy.ini | 12 +----------- 2 files changed, 2 insertions(+), 11 deletions(-) create mode 100644 changelog.d/10601.misc diff --git a/changelog.d/10601.misc b/changelog.d/10601.misc new file mode 100644 index 000000000000..8b573ab49a95 --- /dev/null +++ b/changelog.d/10601.misc @@ -0,0 +1 @@ +Add type annotations to complete the synapse.util package. diff --git a/mypy.ini b/mypy.ini index 5d6cd557bca2..92387ccb17b4 100644 --- a/mypy.ini +++ b/mypy.ini @@ -70,17 +70,7 @@ files = synapse/storage/util, synapse/streams, synapse/types.py, - synapse/util/async_helpers.py, - synapse/util/caches, - synapse/util/daemonize.py, - synapse/util/hash.py, - synapse/util/iterutils.py, - synapse/util/linked_list.py, - synapse/util/metrics.py, - synapse/util/macaroons.py, - synapse/util/module_loader.py, - synapse/util/msisdn.py, - synapse/util/stringutils.py, + synapse/util, synapse/visibility.py, tests/replication, tests/test_event_auth.py, From 1d0b435a648a35f95628a131703922cc513d76a2 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 16 Aug 2021 14:45:57 +0100 Subject: [PATCH 04/38] Put the switch back in to the 'more magic' position --- synapse/util/manhole.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index c35a567348f1..c36c131f8409 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -141,6 +141,13 @@ def showtraceback(self): sys.last_traceback = last_tb assert last_tb is not None - # We remove the first stack item because it is our own code. - lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) - self.write("".join(lines)) + try: + # We remove the first stack item because it is our own code. + lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) + self.write("".join(lines)) + finally: + # On the line below, last_tb and ei appear to be dead. + # It's unclear whether there is a reason behind this line. + # It conceivably could be because an exception raised in this block + # will keep the local frame (containing these local variables) around. + last_tb = ei = None # type: ignore From 22df193bc349782c19d2ed17f1d9080aa27ee09d Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 16 Aug 2021 15:04:00 +0100 Subject: [PATCH 05/38] Fix up some more types --- synapse/util/ratelimitutils.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 4df8625b522f..457ffa19c0e5 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -15,7 +15,8 @@ import collections import contextlib import logging -from typing import DefaultDict +import typing +from typing import DefaultDict, Iterator from twisted.internet import defer @@ -28,17 +29,14 @@ ) from synapse.util import Clock +if typing.TYPE_CHECKING: + from contextlib import _GeneratorContextManager + logger = logging.getLogger(__name__) class FederationRateLimiter: def __init__(self, clock: Clock, config: FederationRateLimitConfig): - """ - Args: - clock - config - """ - def new_limiter(): return _PerHostRatelimiter(clock=clock, config=config) @@ -46,7 +44,7 @@ def new_limiter(): str, "_PerHostRatelimiter" ] = collections.defaultdict(new_limiter) - def ratelimit(self, host: str): + def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred]": """Used to ratelimit an incoming request from a given host Example usage: @@ -96,7 +94,7 @@ def __init__(self, clock, config): self.request_times = [] @contextlib.contextmanager - def ratelimit(self): + def ratelimit(self) -> Iterator[defer.Deferred]: # `contextlib.contextmanager` takes a generator and turns it into a # context manager. The generator should only yield once with a value # to be returned by manager. @@ -109,7 +107,7 @@ def ratelimit(self): finally: self._on_exit(request_id) - def _on_enter(self, request_id): + def _on_enter(self, request_id) -> defer.Deferred: time_now = self.clock.time_msec() # remove any entries from request_times which aren't within the window @@ -184,7 +182,7 @@ def on_both(r): ret_defer.addBoth(on_both) return make_deferred_yieldable(ret_defer) - def _on_exit(self, request_id): + def _on_exit(self, request_id) -> None: logger.debug("Ratelimit [%s]: Processed req", id(request_id)) self.current_processing.discard(request_id) try: From e36db3fce6446a1ab7745fe3a3e8d928f5f2019f Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 18 Aug 2021 14:51:07 +0100 Subject: [PATCH 06/38] Update annotations in util --- synapse/util/__init__.py | 20 ++++++++++--------- synapse/util/distributor.py | 21 ++++++++++---------- synapse/util/frozenutils.py | 5 +++-- synapse/util/httpresourcetree.py | 27 +++++++++++++------------- synapse/util/patch_inline_callbacks.py | 4 ++-- synapse/util/ratelimitutils.py | 2 +- synapse/util/retryutils.py | 6 ++++-- synapse/util/rlimit.py | 2 +- synapse/util/templates.py | 2 +- synapse/util/threepids.py | 10 ++++++---- synapse/util/versionstring.py | 2 +- synapse/util/wheel_timer.py | 16 +++++++++------ 12 files changed, 65 insertions(+), 52 deletions(-) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 8ba901fabb3c..6bc7d4a68b13 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -15,12 +15,14 @@ import json import logging import re -from typing import Pattern +from typing import Any, Dict, Pattern import attr from frozendict import frozendict from twisted.internet import defer, task +from twisted.internet.interfaces import IDelayedCall +from twisted.internet.task import LoopingCall from synapse.logging import context @@ -30,12 +32,12 @@ _WILDCARD_RUN = re.compile(r"([\?\*]+)") -def _reject_invalid_json(val): +def _reject_invalid_json(val) -> None: """Do not allow Infinity, -Infinity, or NaN values in JSON.""" raise ValueError("Invalid JSON value: '%s'" % val) -def _handle_frozendict(obj): +def _handle_frozendict(obj: Any) -> Dict[Any, Any]: """Helper for json_encoder. Makes frozendicts serializable by returning the underlying dict """ @@ -78,22 +80,22 @@ class Clock: _reactor = attr.ib() @defer.inlineCallbacks - def sleep(self, seconds): + def sleep(self, seconds: float): d: defer.Deferred = defer.Deferred() with context.PreserveLoggingContext(): self._reactor.callLater(seconds, d.callback, seconds) res = yield d return res - def time(self): + def time(self) -> float: """Returns the current system time in seconds since epoch.""" return self._reactor.seconds() - def time_msec(self): + def time_msec(self) -> int: """Returns the current system time in milliseconds since epoch.""" return int(self.time() * 1000) - def looping_call(self, f, msec, *args, **kwargs): + def looping_call(self, f, msec, *args, **kwargs) -> LoopingCall: """Call a function repeatedly. Waits `msec` initially before calling `f` for the first time. @@ -113,7 +115,7 @@ def looping_call(self, f, msec, *args, **kwargs): d.addErrback(log_failure, "Looping call died", consumeErrors=False) return call - def call_later(self, delay, callback, *args, **kwargs): + def call_later(self, delay, callback, *args, **kwargs) -> IDelayedCall: """Call something later Note that the function will be called with no logcontext, so if it is anything @@ -133,7 +135,7 @@ def wrapped_callback(*args, **kwargs): with context.PreserveLoggingContext(): return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) - def cancel_call_later(self, timer, ignore_errs=False): + def cancel_call_later(self, timer: IDelayedCall, ignore_errs=False) -> None: try: timer.cancel() except Exception: diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 1f803aef6d1b..32e5f1047b57 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Callable, Dict, List from twisted.internet import defer @@ -38,10 +39,10 @@ class Distributor: """ def __init__(self): - self.signals = {} - self.pre_registration = {} + self.signals: Dict[str, Signal] = {} + self.pre_registration: Dict[str, List[Callable]] = {} - def declare(self, name): + def declare(self, name: str) -> None: if name in self.signals: raise KeyError("%r already has a signal named %s" % (self, name)) @@ -52,7 +53,7 @@ def declare(self, name): for observer in self.pre_registration[name]: signal.observe(observer) - def observe(self, name, observer): + def observe(self, name: str, observer: Callable) -> None: if name in self.signals: self.signals[name].observe(observer) else: @@ -62,7 +63,7 @@ def observe(self, name, observer): self.pre_registration[name] = [] self.pre_registration[name].append(observer) - def fire(self, name, *args, **kwargs): + def fire(self, name: str, *args, **kwargs) -> None: """Dispatches the given signal to the registered observers. Runs the observers as a background process. Does not return a deferred. @@ -83,18 +84,18 @@ class Signal: method into all of the observers. """ - def __init__(self, name): - self.name = name - self.observers = [] + def __init__(self, name: str): + self.name: str = name + self.observers: List[Callable] = [] - def observe(self, observer): + def observe(self, observer: Callable) -> None: """Adds a new callable to the observer list which will be invoked by the 'fire' method. Each observer callable may return a Deferred.""" self.observers.append(observer) - def fire(self, *args, **kwargs): + def fire(self, *args, **kwargs) -> defer.Deferred: """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is not an error to fire a signal with no observers. diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 2ac7c2913cdc..9c405eb4d763 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from frozendict import frozendict -def freeze(o): +def freeze(o: Any) -> Any: if isinstance(o, dict): return frozendict({k: freeze(v) for k, v in o.items()}) @@ -33,7 +34,7 @@ def freeze(o): return o -def unfreeze(o): +def unfreeze(o: Any) -> Any: if isinstance(o, (dict, frozendict)): return {k: unfreeze(v) for k, v in o.items()} diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index 3c0e8469f3ed..b163643ca333 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -13,42 +13,43 @@ # limitations under the License. import logging +from typing import Dict -from twisted.web.resource import NoResource +from twisted.web.resource import NoResource, Resource logger = logging.getLogger(__name__) -def create_resource_tree(desired_tree, root_resource): +def create_resource_tree( + desired_tree: Dict[str, Resource], root_resource: Resource +) -> Resource: """Create the resource tree for this homeserver. This in unduly complicated because Twisted does not support putting child resources more than 1 level deep at a time. Args: - web_client (bool): True to enable the web client. - root_resource (twisted.web.resource.Resource): The root - resource to add the tree to. + desired_tree: Dict from desired paths to desired resources. + root_resource: The root resource to add the tree to. Returns: - twisted.web.resource.Resource: the ``root_resource`` with a tree of - child resources added to it. + The ``root_resource`` with a tree of child resources added to it. """ # ideally we'd just use getChild and putChild but getChild doesn't work # unless you give it a Request object IN ADDITION to the name :/ So # instead, we'll store a copy of this mapping so we can actually add # extra resources to existing nodes. See self._resource_id for the key. - resource_mappings = {} - for full_path, res in desired_tree.items(): + resource_mappings: Dict[str, Resource] = {} + for full_path_str, res in desired_tree.items(): # twisted requires all resources to be bytes - full_path = full_path.encode("utf-8") + full_path = full_path_str.encode("utf-8") logger.info("Attaching %s to path %s", res, full_path) last_resource = root_resource for path_seg in full_path.split(b"/")[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" - child_resource = NoResource() + child_resource: Resource = NoResource() last_resource.putChild(path_seg, child_resource) res_id = _resource_id(last_resource, path_seg) resource_mappings[res_id] = child_resource @@ -83,7 +84,7 @@ def create_resource_tree(desired_tree, root_resource): return root_resource -def _resource_id(resource, path_seg): +def _resource_id(resource: Resource, path_seg: bytes) -> str: """Construct an arbitrary resource ID so you can retrieve the mapping later. @@ -96,4 +97,4 @@ def _resource_id(resource, path_seg): Returns: str: A unique string which can be a key to the child Resource. """ - return "%s-%s" % (resource, path_seg) + return "%s-%r" % (resource, path_seg) diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 99f01e325cf6..9dd010af3b0e 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -24,7 +24,7 @@ _already_patched = False -def do_patch(): +def do_patch() -> None: """ Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit """ @@ -107,7 +107,7 @@ def check_ctx(r): _already_patched = True -def _check_yield_points(f: Callable, changes: List[str]): +def _check_yield_points(f: Callable, changes: List[str]) -> Callable: """Wraps a generator that is about to be passed to defer.inlineCallbacks checking that after every yield the log contexts are correct. diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 457ffa19c0e5..ce2e9230c044 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -37,7 +37,7 @@ class FederationRateLimiter: def __init__(self, clock: Clock, config: FederationRateLimitConfig): - def new_limiter(): + def new_limiter() -> "_PerHostRatelimiter": return _PerHostRatelimiter(clock=clock, config=config) self.ratelimiters: DefaultDict[ diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 129b47cd4994..bc65062b62d0 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -51,7 +51,9 @@ def __init__(self, retry_last_ts, retry_interval, destination): self.destination = destination -async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs): +async def get_retry_limiter( + destination, clock, store, ignore_backoff=False, **kwargs +) -> "RetryDestinationLimiter": """For a given destination check if we have previously failed to send a request there and are waiting before retrying the destination. If we are not ready to retry the destination, this will raise a @@ -216,7 +218,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): if self.failure_ts is None: self.failure_ts = retry_last_ts - async def store_retry_timings(): + async def store_retry_timings() -> None: try: await self.store.set_destination_retry_timings( self.destination, diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py index bf812ab5166a..06651e956d1b 100644 --- a/synapse/util/rlimit.py +++ b/synapse/util/rlimit.py @@ -18,7 +18,7 @@ logger = logging.getLogger("synapse.app.homeserver") -def change_resource_limit(soft_file_no): +def change_resource_limit(soft_file_no: int) -> None: try: soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) diff --git a/synapse/util/templates.py b/synapse/util/templates.py index 70eaf71ca7b9..66f8fbb7758a 100644 --- a/synapse/util/templates.py +++ b/synapse/util/templates.py @@ -111,5 +111,5 @@ def mxc_to_http_filter( return mxc_to_http_filter -def _format_ts_filter(value: int, format: str): +def _format_ts_filter(value: int, format: str) -> str: return time.strftime(format, time.localtime(value / 1000)) diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index a1cf1960b08f..2841f0f47ed8 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -15,6 +15,8 @@ import logging import re +from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -28,13 +30,13 @@ MAX_EMAIL_ADDRESS_LENGTH = 500 -def check_3pid_allowed(hs, medium, address): +def check_3pid_allowed(hs: HomeServer, medium: str, address: str) -> bool: """Checks whether a given format of 3PID is allowed to be used on this HS Args: - hs (synapse.server.HomeServer): server - medium (str): 3pid medium - e.g. email, msisdn - address (str): address within that medium (e.g. "wotan@matrix.org") + hs: server + medium: 3pid medium - e.g. email, msisdn + address: address within that medium (e.g. "wotan@matrix.org") msisdns need to first have been canonicalised Returns: bool: whether the 3PID medium/address is allowed to be added to this HS diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index cb08af7385eb..1c20b24bbe68 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -def get_version_string(module): +def get_version_string(module) -> str: """Given a module calculate a git-aware version string for it. If called on a module not in a git checkout will return `__verison__`. diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 61814aff241f..12ebf2a3aeed 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Generic, List, TypeVar class _Entry: @@ -21,7 +22,10 @@ def __init__(self, end_key): self.queue = [] -class WheelTimer: +T = TypeVar("T") + + +class WheelTimer(Generic[T]): """Stores arbitrary objects that will be returned after their timers have expired. """ @@ -36,13 +40,13 @@ def __init__(self, bucket_size=5000): self.entries = [] self.current_tick = 0 - def insert(self, now, obj, then): + def insert(self, now: int, obj: T, then: int) -> None: """Inserts object into timer. Args: - now (int): Current time in msec - obj (object): Object to be inserted - then (int): When to return the object strictly after. + now: Current time in msec + obj: Object to be inserted + then: When to return the object strictly after. """ then_key = int(then / self.bucket_size) + 1 @@ -70,7 +74,7 @@ def insert(self, now, obj, then): self.entries[-1].queue.append(obj) - def fetch(self, now): + def fetch(self, now: int) -> List[T]: """Fetch any objects that have timed out Args: From db57064c90aeda751393ac949d32bbb0b44d0153 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 18 Aug 2021 14:51:33 +0100 Subject: [PATCH 07/38] Fix fallout (related annotations and assertions around codebase) --- synapse/api/ratelimiting.py | 8 ++++---- synapse/federation/sender/__init__.py | 6 ++++-- synapse/handlers/account_validity.py | 1 + synapse/handlers/appservice.py | 3 +++ synapse/handlers/presence.py | 6 +++--- synapse/handlers/typing.py | 2 +- synapse/storage/databases/main/registration.py | 1 + 7 files changed, 17 insertions(+), 10 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 3e3d09bbd244..cbdd74025b35 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -46,7 +46,7 @@ def __init__( # * How many times an action has occurred since a point in time # * The point in time # * The rate_hz of this particular entry. This can vary per request - self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict() + self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict() async def can_do_action( self, @@ -56,7 +56,7 @@ async def can_do_action( burst_count: Optional[int] = None, update: bool = True, n_actions: int = 1, - _time_now_s: Optional[int] = None, + _time_now_s: Optional[float] = None, ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? @@ -160,7 +160,7 @@ async def can_do_action( return allowed, time_allowed - def _prune_message_counts(self, time_now_s: int): + def _prune_message_counts(self, time_now_s: float): """Remove message count entries that have not exceeded their defined rate_hz limit @@ -188,7 +188,7 @@ async def ratelimit( burst_count: Optional[int] = None, update: bool = True, n_actions: int = 1, - _time_now_s: Optional[int] = None, + _time_now_s: Optional[float] = None, ): """Checks if an action can be performed. If not, raises a LimitExceededError diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index d980e0d9866a..52704c6f41b1 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -19,6 +19,7 @@ import attr from prometheus_client import Counter +from twisted.internet.interfaces import IDelayedCall from typing_extensions import Literal from twisted.internet import defer @@ -284,7 +285,7 @@ def __init__(self, hs: "HomeServer"): ) # wake up destinations that have outstanding PDUs to be caught up - self._catchup_after_startup_timer = self.clock.call_later( + self._catchup_after_startup_timer: Optional[IDelayedCall] = self.clock.call_later( CATCH_UP_STARTUP_DELAY_SEC, run_as_background_process, "wake_destinations_needing_catchup", @@ -406,7 +407,7 @@ async def handle_event(event: EventBase) -> None: now = self.clock.time_msec() ts = await self.store.get_received_ts(event.event_id) - + assert ts is not None synapse.metrics.event_processing_lag_by_event.labels( "federation_sender" ).observe((now - ts) / 1000) @@ -435,6 +436,7 @@ async def handle_room_events(events: Iterable[EventBase]) -> None: if events: now = self.clock.time_msec() ts = await self.store.get_received_ts(events[-1].event_id) + assert ts is not None synapse.metrics.event_processing_lag.labels( "federation_sender" diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 078accd634f1..dd6975505317 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -398,6 +398,7 @@ async def renew_account_for_user( """ now = self.clock.time_msec() if expiration_ts is None: + assert self._account_validity_period is not None expiration_ts = now + self._account_validity_period await self.store.set_account_validity_for_user( diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 4ab4046650b8..a7b5a4e9c94f 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -131,6 +131,8 @@ async def start_scheduler(): now = self.clock.time_msec() ts = await self.store.get_received_ts(event.event_id) + assert ts is not None + synapse.metrics.event_processing_lag_by_event.labels( "appservice_sender" ).observe((now - ts) / 1000) @@ -166,6 +168,7 @@ async def handle_room_events(events): if events: now = self.clock.time_msec() ts = await self.store.get_received_ts(events[-1].event_id) + assert ts is not None synapse.metrics.event_processing_lag.labels( "appservice_sender" diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 7ca14e1d8473..8c0f7fc98934 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -37,7 +37,7 @@ Optional, Set, Tuple, - Union, + Union, Any, ) from prometheus_client import Counter @@ -610,7 +610,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs self.server_name = hs.hostname - self.wheel_timer = WheelTimer() + self.wheel_timer: WheelTimer[str] = WheelTimer() self.notifier = hs.get_notifier() self._presence_enabled = hs.config.use_presence @@ -919,7 +919,7 @@ async def bump_presence_active_time(self, user: UserID) -> None: prev_state = await self.current_state_for_user(user_id) - new_fields = {"last_active_ts": self.clock.time_msec()} + new_fields: Dict[str, Any] = {"last_active_ts": self.clock.time_msec()} if prev_state.state == PresenceState.UNAVAILABLE: new_fields["state"] = PresenceState.ONLINE diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index a97c448595e9..b84bd5e49ab6 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -73,7 +73,7 @@ def __init__(self, hs: "HomeServer"): self._room_typing: Dict[str, Set[str]] = {} self._member_last_federation_poke: Dict[RoomMember, int] = {} - self.wheel_timer = WheelTimer(bucket_size=5000) + self.wheel_timer: WheelTimer[RoomMember] = WheelTimer(bucket_size=5000) self._latest_room_serial = 0 self.clock.looping_call(self._handle_timeouts, 5000) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 14670c28817d..f011c9e4a9fe 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1058,6 +1058,7 @@ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): delta equal to 10% of the validity period. """ now_ms = self._clock.time_msec() + assert self._account_validity_period is not None expiration_ts = now_ms + self._account_validity_period if use_delta: From 348f9ff622ddf48bf5d18b4ff761ec2b47195b98 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 18 Aug 2021 15:09:19 +0100 Subject: [PATCH 08/38] antilint --- synapse/federation/sender/__init__.py | 6 ++++-- synapse/handlers/presence.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 52704c6f41b1..4c0a3eb5af64 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -19,10 +19,10 @@ import attr from prometheus_client import Counter -from twisted.internet.interfaces import IDelayedCall from typing_extensions import Literal from twisted.internet import defer +from twisted.internet.interfaces import IDelayedCall import synapse.metrics from synapse.api.presence import UserPresenceState @@ -285,7 +285,9 @@ def __init__(self, hs: "HomeServer"): ) # wake up destinations that have outstanding PDUs to be caught up - self._catchup_after_startup_timer: Optional[IDelayedCall] = self.clock.call_later( + self._catchup_after_startup_timer: Optional[ + IDelayedCall + ] = self.clock.call_later( CATCH_UP_STARTUP_DELAY_SEC, run_as_background_process, "wake_destinations_needing_catchup", diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 8c0f7fc98934..3b1be6096903 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -28,6 +28,7 @@ from contextlib import contextmanager from typing import ( TYPE_CHECKING, + Any, Callable, Collection, Dict, @@ -37,7 +38,7 @@ Optional, Set, Tuple, - Union, Any, + Union, ) from prometheus_client import Counter From d081c8306f2db65d9042662f2c00d5d3202eebed Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 18 Aug 2021 15:09:27 +0100 Subject: [PATCH 09/38] add type parameters for Deferreds --- synapse/util/__init__.py | 2 +- synapse/util/batching_queue.py | 2 +- synapse/util/distributor.py | 4 ++-- synapse/util/ratelimitutils.py | 12 ++++++------ 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 6bc7d4a68b13..a915e9d22b2e 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -81,7 +81,7 @@ class Clock: @defer.inlineCallbacks def sleep(self, seconds: float): - d: defer.Deferred = defer.Deferred() + d: defer.Deferred[float] = defer.Deferred() with context.PreserveLoggingContext(): self._reactor.callLater(seconds, d.callback, seconds) res = yield d diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py index 6d0e2a43f0a7..2a903004a91b 100644 --- a/synapse/util/batching_queue.py +++ b/synapse/util/batching_queue.py @@ -122,7 +122,7 @@ async def add_to_queue(self, value: V, key: Hashable = ()) -> R: # First we create a defer and add it and the value to the list of # pending items. - d: defer.Deferred = defer.Deferred() + d: defer.Deferred[R] = defer.Deferred() self._next_values.setdefault(key, []).append((value, d)) # If we're not currently processing the key fire off a background diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 32e5f1047b57..b741aa4558e1 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Callable, Dict, List +from typing import Any, Callable, Dict, List from twisted.internet import defer @@ -95,7 +95,7 @@ def observe(self, observer: Callable) -> None: Each observer callable may return a Deferred.""" self.observers.append(observer) - def fire(self, *args, **kwargs) -> defer.Deferred: + def fire(self, *args, **kwargs) -> defer.Deferred[List[Any]]: """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is not an error to fire a signal with no observers. diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index ce2e9230c044..30fc07d6f253 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -44,7 +44,7 @@ def new_limiter() -> "_PerHostRatelimiter": str, "_PerHostRatelimiter" ] = collections.defaultdict(new_limiter) - def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred]": + def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]": """Used to ratelimit an incoming request from a given host Example usage: @@ -83,7 +83,7 @@ def __init__(self, clock, config): # map from request_id object to Deferred for requests which are ready # for processing but have been queued self.ready_request_queue: collections.OrderedDict[ - object, defer.Deferred + object, defer.Deferred[None] ] = collections.OrderedDict() # request id objects for requests which are in progress @@ -94,7 +94,7 @@ def __init__(self, clock, config): self.request_times = [] @contextlib.contextmanager - def ratelimit(self) -> Iterator[defer.Deferred]: + def ratelimit(self) -> Iterator[defer.Deferred[None]]: # `contextlib.contextmanager` takes a generator and turns it into a # context manager. The generator should only yield once with a value # to be returned by manager. @@ -107,7 +107,7 @@ def ratelimit(self) -> Iterator[defer.Deferred]: finally: self._on_exit(request_id) - def _on_enter(self, request_id) -> defer.Deferred: + def _on_enter(self, request_id) -> defer.Deferred[None]: time_now = self.clock.time_msec() # remove any entries from request_times which aren't within the window @@ -125,9 +125,9 @@ def _on_enter(self, request_id) -> defer.Deferred: self.request_times.append(time_now) - def queue_request(): + def queue_request() -> "defer.Deferred[None]": if len(self.current_processing) >= self.concurrent_requests: - queue_defer: defer.Deferred = defer.Deferred() + queue_defer: defer.Deferred[None] = defer.Deferred() self.ready_request_queue[request_id] = queue_defer logger.info( "Ratelimiter: queueing request (queue now %i items)", From 76c3b6b3bd9f58e01d1b2f95c39db1b4f13c7509 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 23 Aug 2021 15:47:43 +0100 Subject: [PATCH 10/38] Fix circular import of HomeServer --- synapse/util/threepids.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 2841f0f47ed8..baa9190a9af2 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -14,8 +14,10 @@ import logging import re +import typing -from synapse.server import HomeServer +if typing.TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -30,7 +32,7 @@ MAX_EMAIL_ADDRESS_LENGTH = 500 -def check_3pid_allowed(hs: HomeServer, medium: str, address: str) -> bool: +def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool: """Checks whether a given format of 3PID is allowed to be used on this HS Args: From 30ffee4028a776013dbc207a7024ce09e8b2ca7a Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 23 Aug 2021 17:34:30 +0100 Subject: [PATCH 11/38] Quote deferreds in method signatures --- synapse/util/distributor.py | 2 +- synapse/util/ratelimitutils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index b741aa4558e1..a380c03e8db3 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -95,7 +95,7 @@ def observe(self, observer: Callable) -> None: Each observer callable may return a Deferred.""" self.observers.append(observer) - def fire(self, *args, **kwargs) -> defer.Deferred[List[Any]]: + def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]": """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is not an error to fire a signal with no observers. diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 30fc07d6f253..aba9b5e6ccd9 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -94,7 +94,7 @@ def __init__(self, clock, config): self.request_times = [] @contextlib.contextmanager - def ratelimit(self) -> Iterator[defer.Deferred[None]]: + def ratelimit(self) -> "Iterator[defer.Deferred[None]]": # `contextlib.contextmanager` takes a generator and turns it into a # context manager. The generator should only yield once with a value # to be returned by manager. @@ -107,7 +107,7 @@ def ratelimit(self) -> Iterator[defer.Deferred[None]]: finally: self._on_exit(request_id) - def _on_enter(self, request_id) -> defer.Deferred[None]: + def _on_enter(self, request_id) -> "defer.Deferred[None]": time_now = self.clock.time_msec() # remove any entries from request_times which aren't within the window From 10bd84fa4ff156c0b043ed6ac360617f3025138b Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 1 Sep 2021 17:05:36 +0100 Subject: [PATCH 12/38] Annotate more types --- synapse/util/__init__.py | 23 +++++++++------ synapse/util/ratelimitutils.py | 10 +++---- synapse/util/retryutils.py | 54 ++++++++++++++++++---------------- 3 files changed, 48 insertions(+), 39 deletions(-) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index a915e9d22b2e..9ff771fe01df 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -15,16 +15,21 @@ import json import logging import re -from typing import Any, Dict, Pattern +import typing +from typing import Any, Dict, Pattern, Callable import attr from frozendict import frozendict + from twisted.internet import defer, task from twisted.internet.interfaces import IDelayedCall from twisted.internet.task import LoopingCall +from twisted.python.failure import Failure from synapse.logging import context +if typing.TYPE_CHECKING: + from twisted.application.reactors import Reactor logger = logging.getLogger(__name__) @@ -32,7 +37,7 @@ _WILDCARD_RUN = re.compile(r"([\?\*]+)") -def _reject_invalid_json(val) -> None: +def _reject_invalid_json(val: Any) -> None: """Do not allow Infinity, -Infinity, or NaN values in JSON.""" raise ValueError("Invalid JSON value: '%s'" % val) @@ -62,7 +67,7 @@ def _handle_frozendict(obj: Any) -> Dict[Any, Any]: json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) -def unwrapFirstError(failure): +def unwrapFirstError(failure: Failure) -> Failure: # defer.gatherResults and DeferredLists wrap failures. failure.trap(defer.FirstError) return failure.value.subFailure @@ -77,10 +82,10 @@ class Clock: reactor: The Twisted reactor to use. """ - _reactor = attr.ib() + _reactor: Reactor = attr.ib() @defer.inlineCallbacks - def sleep(self, seconds: float): + def sleep(self, seconds: float) -> float: d: defer.Deferred[float] = defer.Deferred() with context.PreserveLoggingContext(): self._reactor.callLater(seconds, d.callback, seconds) @@ -95,7 +100,7 @@ def time_msec(self) -> int: """Returns the current system time in milliseconds since epoch.""" return int(self.time() * 1000) - def looping_call(self, f, msec, *args, **kwargs) -> LoopingCall: + def looping_call(self, f: Callable, msec: float, *args, **kwargs) -> LoopingCall: """Call a function repeatedly. Waits `msec` initially before calling `f` for the first time. @@ -104,8 +109,8 @@ def looping_call(self, f, msec, *args, **kwargs) -> LoopingCall: other than trivial, you probably want to wrap it in run_as_background_process. Args: - f(function): The function to call repeatedly. - msec(float): How long to wait between calls in milliseconds. + f: The function to call repeatedly. + msec: How long to wait between calls in milliseconds. *args: Postional arguments to pass to function. **kwargs: Key arguments to pass to function. """ @@ -135,7 +140,7 @@ def wrapped_callback(*args, **kwargs): with context.PreserveLoggingContext(): return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) - def cancel_call_later(self, timer: IDelayedCall, ignore_errs=False) -> None: + def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool=False) -> None: try: timer.cancel() except Exception: diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index aba9b5e6ccd9..d7236853f8a8 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -63,11 +63,11 @@ def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None] class _PerHostRatelimiter: - def __init__(self, clock, config): + def __init__(self, clock: Clock, config: FederationRateLimitConfig): """ Args: - clock (Clock) - config (FederationRateLimitConfig) + clock + config """ self.clock = clock @@ -107,7 +107,7 @@ def ratelimit(self) -> "Iterator[defer.Deferred[None]]": finally: self._on_exit(request_id) - def _on_enter(self, request_id) -> "defer.Deferred[None]": + def _on_enter(self, request_id: object) -> "defer.Deferred[None]": time_now = self.clock.time_msec() # remove any entries from request_times which aren't within the window @@ -182,7 +182,7 @@ def on_both(r): ret_defer.addBoth(on_both) return make_deferred_yieldable(ret_defer) - def _on_exit(self, request_id) -> None: + def _on_exit(self, request_id: object) -> None: logger.debug("Ratelimit [%s]: Processed req", id(request_id)) self.current_processing.discard(request_id) try: diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index bc65062b62d0..bc10aa0d8462 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -13,9 +13,13 @@ # limitations under the License. import logging import random +from types import TracebackType +from typing import Any, Optional, Type, TypeVar import synapse.logging.context from synapse.api.errors import CodeMessageException +from synapse.storage import DataStore +from synapse.util import Clock logger = logging.getLogger(__name__) @@ -30,17 +34,17 @@ class NotRetryingDestination(Exception): - def __init__(self, retry_last_ts, retry_interval, destination): + def __init__(self, retry_last_ts: int, retry_interval: int, destination: str): """Raised by the limiter (and federation client) to indicate that we are are deliberately not attempting to contact a given server. Args: - retry_last_ts (int): the unix ts in milliseconds of our last attempt + retry_last_ts: the unix ts in milliseconds of our last attempt to contact the server. 0 indicates that the last attempt was successful or that we've never actually attempted to connect. - retry_interval (int): the time in milliseconds to wait until the next + retry_interval: the time in milliseconds to wait until the next attempt. - destination (str): the domain in question + destination: the domain in question """ msg = "Not retrying server %s." % (destination,) @@ -52,7 +56,7 @@ def __init__(self, retry_last_ts, retry_interval, destination): async def get_retry_limiter( - destination, clock, store, ignore_backoff=False, **kwargs + destination: str, clock: Clock, store: DataStore, ignore_backoff: bool = False, **kwargs: Any ) -> "RetryDestinationLimiter": """For a given destination check if we have previously failed to send a request there and are waiting before retrying the destination. @@ -62,10 +66,10 @@ async def get_retry_limiter( CodeMessageException with code < 500) Args: - destination (str): name of homeserver - clock (synapse.util.clock): timing source - store (synapse.storage.transactions.TransactionStore): datastore - ignore_backoff (bool): true to ignore the historical backoff data and + destination: name of homeserver + clock: timing source + store: datastore + ignore_backoff: true to ignore the historical backoff data and try the request anyway. We will still reset the retry_interval on success. Example usage: @@ -116,13 +120,13 @@ async def get_retry_limiter( class RetryDestinationLimiter: def __init__( self, - destination, - clock, - store, - failure_ts, - retry_interval, - backoff_on_404=False, - backoff_on_failure=True, + destination: str, + clock: Clock, + store: DataStore, + failure_ts: Optional[int], + retry_interval: int, + backoff_on_404: bool = False, + backoff_on_failure: bool = True, ): """Marks the destination as "down" if an exception is thrown in the context, except for CodeMessageException with code < 500. @@ -130,17 +134,17 @@ def __init__( If no exception is raised, marks the destination as "up". Args: - destination (str) - clock (Clock) - store (DataStore) - failure_ts (int|None): when this destination started failing (in ms since + destination + clock + store + failure_ts: when this destination started failing (in ms since the epoch), or zero if the last request was successful - retry_interval (int): The next retry interval taken from the + retry_interval: The next retry interval taken from the database in milliseconds, or zero if the last request was successful. - backoff_on_404 (bool): Back off if we get a 404 + backoff_on_404: Back off if we get a 404 - backoff_on_failure (bool): set to False if we should not increase the + backoff_on_failure: set to False if we should not increase the retry interval on a failure. """ self.clock = clock @@ -152,10 +156,10 @@ def __init__( self.backoff_on_404 = backoff_on_404 self.backoff_on_failure = backoff_on_failure - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb: TracebackType) -> None: valid_err_code = False if exc_type is None: valid_err_code = True From 0c26b7f3e0d546a040232c89b339c7674bbb2541 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 1 Sep 2021 18:01:15 +0100 Subject: [PATCH 13/38] Use attrs class and fix ignored fields [WANTS REVIEW] --- synapse/config/ratelimiting.py | 19 ++++++++----------- synapse/rest/client/register.py | 4 ++-- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 7a8d5851c40b..62edd8695dc6 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -14,6 +14,8 @@ from typing import Dict, Optional +import attr + from ._base import Config @@ -29,18 +31,13 @@ def __init__( self.burst_count = int(config.get("burst_count", defaults["burst_count"])) +@attr.s(auto_attribs=True) class FederationRateLimitConfig: - _items_and_default = { - "window_size": 1000, - "sleep_limit": 10, - "sleep_delay": 500, - "reject_limit": 50, - "concurrent": 3, - } - - def __init__(self, **kwargs): - for i in self._items_and_default.keys(): - setattr(self, i, kwargs.get(i) or self._items_and_default[i]) + window_size: int = 1000 + sleep_limit: int = 10 + sleep_delay: int = 500 + reject_limit: int = 50 + concurrent: int = 3 class RatelimitConfig(Config): diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 58b8e8f2614f..4f53172bef28 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -354,11 +354,11 @@ def __init__(self, hs): # Artificially delay requests if rate > sleep_limit/window_size sleep_limit=1, # Amount of artificial delay to apply - sleep_msec=1000, + sleep_delay=1000, # Error with 429 if more than reject_limit requests are queued reject_limit=1, # Allow 1 request at a time - concurrent_requests=1, + concurrent=1, ), ) From 715bfdc53289402fab2cb28f05434fef292652ed Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 1 Sep 2021 18:01:30 +0100 Subject: [PATCH 14/38] Ignore import issues [WANTS REVIEW] --- synapse/util/async_helpers.py | 2 +- synapse/util/caches/lrucache.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index a3b65aee27b5..fa20ee65536e 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -268,7 +268,7 @@ def __init__( if not clock: from twisted.internet import reactor - clock = Clock(reactor) + clock = Clock(reactor) # type: ignore[arg-type] self._clock = clock self.max_count = max_count diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 5c65d187b6da..00ac9b1d3e43 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -341,7 +341,7 @@ def __init__( # Default `clock` to something sensible. Note that we rename it to # `real_clock` so that mypy doesn't think its still `Optional`. if clock is None: - real_clock = Clock(reactor) + real_clock = Clock(reactor) # type: ignore[arg-type] else: real_clock = clock From 1e4632f08cf6f17da1e123166e15b54dc4997a2b Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 1 Sep 2021 18:01:51 +0100 Subject: [PATCH 15/38] Annotate more types 3 to go! --- synapse/util/__init__.py | 15 ++++++++------- synapse/util/ratelimitutils.py | 8 ++++---- synapse/util/retryutils.py | 8 ++++++-- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 9ff771fe01df..d33f55338cc8 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -16,20 +16,21 @@ import logging import re import typing -from typing import Any, Dict, Pattern, Callable +from typing import Any, Callable, Dict, Pattern import attr from frozendict import frozendict - from twisted.internet import defer, task -from twisted.internet.interfaces import IDelayedCall +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IDelayedCall, IReactorTime from twisted.internet.task import LoopingCall from twisted.python.failure import Failure from synapse.logging import context + if typing.TYPE_CHECKING: - from twisted.application.reactors import Reactor + pass logger = logging.getLogger(__name__) @@ -82,10 +83,10 @@ class Clock: reactor: The Twisted reactor to use. """ - _reactor: Reactor = attr.ib() + _reactor: IReactorTime = attr.ib() @defer.inlineCallbacks - def sleep(self, seconds: float) -> float: + def sleep(self, seconds: float) -> typing.Iterable[Deferred[float]]: d: defer.Deferred[float] = defer.Deferred() with context.PreserveLoggingContext(): self._reactor.callLater(seconds, d.callback, seconds) @@ -140,7 +141,7 @@ def wrapped_callback(*args, **kwargs): with context.PreserveLoggingContext(): return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) - def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool=False) -> None: + def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None: try: timer.cancel() except Exception: diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index d7236853f8a8..5b6701bc667f 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -16,7 +16,7 @@ import contextlib import logging import typing -from typing import DefaultDict, Iterator +from typing import DefaultDict, Iterator, List, Set from twisted.internet import defer @@ -78,7 +78,7 @@ def __init__(self, clock: Clock, config: FederationRateLimitConfig): self.concurrent_requests = config.concurrent # request_id objects for requests which have been slept - self.sleeping_requests = set() + self.sleeping_requests: Set[object] = set() # map from request_id object to Deferred for requests which are ready # for processing but have been queued @@ -87,11 +87,11 @@ def __init__(self, clock: Clock, config: FederationRateLimitConfig): ] = collections.OrderedDict() # request id objects for requests which are in progress - self.current_processing = set() + self.current_processing: Set[object] = set() # times at which we have recently (within the last window_size ms) # received requests. - self.request_times = [] + self.request_times: List[int] = [] @contextlib.contextmanager def ratelimit(self) -> "Iterator[defer.Deferred[None]]": diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index bc10aa0d8462..3b76e0b75a16 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -14,7 +14,7 @@ import logging import random from types import TracebackType -from typing import Any, Optional, Type, TypeVar +from typing import Any, Optional import synapse.logging.context from synapse.api.errors import CodeMessageException @@ -56,7 +56,11 @@ def __init__(self, retry_last_ts: int, retry_interval: int, destination: str): async def get_retry_limiter( - destination: str, clock: Clock, store: DataStore, ignore_backoff: bool = False, **kwargs: Any + destination: str, + clock: Clock, + store: DataStore, + ignore_backoff: bool = False, + **kwargs: Any, ) -> "RetryDestinationLimiter": """For a given destination check if we have previously failed to send a request there and are waiting before retrying the destination. From 05cc10cabea9ea35ac8b9795485a9367f87cca5f Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Thu, 2 Sep 2021 10:22:04 +0100 Subject: [PATCH 16/38] Annotate more types --- synapse/util/async_helpers.py | 8 +++---- synapse/util/caches/dictionary_cache.py | 30 ++++++++++++++----------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index fa20ee65536e..aecdb715cc9e 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -551,18 +551,18 @@ def failure_cb(val): @attr.s(slots=True, frozen=True) -class DoneAwaitable: +class DoneAwaitable(Generic[R]): """Simple awaitable that returns the provided value.""" - value = attr.ib() + value = attr.ib(type="R") def __await__(self): return self - def __iter__(self): + def __iter__(self) -> "DoneAwaitable[R]": return self - def __next__(self): + def __next__(self) -> None: raise StopIteration(self.value) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 3f852edd7fcf..225591ca7ac6 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -14,7 +14,7 @@ import enum import logging import threading -from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar +from typing import Dict, Generic, Iterable, Optional, Set, TypeVar import attr @@ -27,10 +27,12 @@ KT = TypeVar("KT") # The type of the dictionary keys. DKT = TypeVar("DKT") +# The type of the dictionary values. +DV = TypeVar("DV") @attr.s(slots=True) -class DictionaryEntry: +class DictionaryEntry(Generic[DKT, DV]): """Returned when getting an entry from the cache Attributes: @@ -43,10 +45,10 @@ class DictionaryEntry: """ full = attr.ib(type=bool) - known_absent = attr.ib() - value = attr.ib() + known_absent = attr.ib(type=Set[DKT]) + value = attr.ib(type=Dict[DKT, DV]) - def __len__(self): + def __len__(self) -> int: return len(self.value) @@ -56,13 +58,13 @@ class _Sentinel(enum.Enum): sentinel = object() -class DictionaryCache(Generic[KT, DKT]): +class DictionaryCache(Generic[KT, DKT, DV]): """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. fetching a subset of dictionary keys for a particular key. """ def __init__(self, name: str, max_entries: int = 1000): - self.cache: LruCache[KT, DictionaryEntry] = LruCache( + self.cache: LruCache[KT, DictionaryEntry[DKT, DV]] = LruCache( max_size=max_entries, cache_name=name, size_callback=len ) @@ -82,12 +84,12 @@ def check_thread(self) -> None: def get( self, key: KT, dict_keys: Optional[Iterable[DKT]] = None - ) -> DictionaryEntry: + ) -> DictionaryEntry[DKT, DV]: """Fetch an entry out of the cache Args: key - dict_key: If given a set of keys then return only those keys + dict_keys: If given a set of keys then return only those keys that exist in the cache. Returns: @@ -125,7 +127,7 @@ def update( self, sequence: int, key: KT, - value: Dict[DKT, Any], + value: Dict[DKT, DV], fetched_keys: Optional[Set[DKT]] = None, ) -> None: """Updates the entry in the cache @@ -151,15 +153,17 @@ def update( self._update_or_insert(key, value, fetched_keys) def _update_or_insert( - self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT] + self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT] ) -> None: # We pop and reinsert as we need to tell the cache the size may have # changed - entry = self.cache.pop(key, DictionaryEntry(False, set(), {})) + entry: DictionaryEntry[DKT, DV] = self.cache.pop( + key, DictionaryEntry(False, set(), {}) + ) entry.value.update(value) entry.known_absent.update(known_absent) self.cache[key] = entry - def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None: + def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None: self.cache[key] = DictionaryEntry(True, known_absent, value) From 1c6704cabd3b8c2b5028281f37f022f1adff137c Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Thu, 2 Sep 2021 10:22:19 +0100 Subject: [PATCH 17/38] Annotate types and ignore Twisted issues [WANTS REVIEW] --- synapse/util/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index d33f55338cc8..b59c1e3d80cb 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -16,7 +16,7 @@ import logging import re import typing -from typing import Any, Callable, Dict, Pattern +from typing import Any, Callable, Dict, Generator, Pattern import attr from frozendict import frozendict @@ -71,7 +71,7 @@ def _handle_frozendict(obj: Any) -> Dict[Any, Any]: def unwrapFirstError(failure: Failure) -> Failure: # defer.gatherResults and DeferredLists wrap failures. failure.trap(defer.FirstError) - return failure.value.subFailure + return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations @attr.s(slots=True) @@ -85,8 +85,8 @@ class Clock: _reactor: IReactorTime = attr.ib() - @defer.inlineCallbacks - def sleep(self, seconds: float) -> typing.Iterable[Deferred[float]]: + @defer.inlineCallbacks # type: ignore[arg-type] # Issue in Twisted's type annotations + def sleep(self, seconds: float) -> Generator[Deferred[float], Any, Any]: d: defer.Deferred[float] = defer.Deferred() with context.PreserveLoggingContext(): self._reactor.callLater(seconds, d.callback, seconds) From c38437316f8497b80a35757e99b29d8bd60a0e49 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Thu, 2 Sep 2021 16:40:08 +0100 Subject: [PATCH 18/38] Add IReactorThreads as parent of ISynapseReactor --- synapse/types.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/synapse/types.py b/synapse/types.py index 80fa903c4bae..d4759b2dfd45 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -38,6 +38,7 @@ IReactorCore, IReactorPluggableNameResolver, IReactorTCP, + IReactorThreads, IReactorTime, ) @@ -63,7 +64,12 @@ # Note that this seems to require inheriting *directly* from Interface in order # for mypy-zope to realize it is an interface. class ISynapseReactor( - IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface + IReactorTCP, + IReactorPluggableNameResolver, + IReactorTime, + IReactorCore, + IReactorThreads, + Interface, ): """The interfaces necessary for Synapse to function.""" From 884a8b6b71ea98facc7e945be4f9f1d0d81aebdd Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Thu, 2 Sep 2021 16:40:48 +0100 Subject: [PATCH 19/38] Annotate more types --- stubs/txredisapi.pyi | 2 +- synapse/util/async_helpers.py | 4 +-- synapse/util/caches/__init__.py | 14 ++++----- synapse/util/caches/deferred_cache.py | 14 ++++----- synapse/util/caches/lrucache.py | 2 +- synapse/util/caches/stream_change_cache.py | 2 +- synapse/util/caches/treecache.py | 16 +++++----- synapse/util/daemonize.py | 2 +- synapse/util/distributor.py | 2 +- synapse/util/file_consumer.py | 35 +++++++++++++--------- synapse/util/linked_list.py | 8 ++--- synapse/util/macaroons.py | 2 +- synapse/util/manhole.py | 4 +-- synapse/util/ratelimitutils.py | 8 ++--- synapse/util/wheel_timer.py | 13 ++++---- 15 files changed, 67 insertions(+), 61 deletions(-) diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index c1a06ae022f6..4ff3c6de5feb 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -73,4 +73,4 @@ class RedisFactory(protocol.ReconnectingClientFactory): def buildProtocol(self, addr) -> RedisProtocol: ... class SubscriberFactory(RedisFactory): - def __init__(self): ... + def __init__(self) -> None: ... diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index aecdb715cc9e..080fc004036c 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -411,7 +411,7 @@ class ReadWriteLock: # writers and readers have been resolved. The new writer replaces the latest # writer. - def __init__(self): + def __init__(self) -> None: # Latest readers queued self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {} @@ -503,7 +503,7 @@ def timeout_deferred( timed_out = [False] - def time_it_out(): + def time_it_out() -> None: timed_out[0] = True try: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 9012034b7aa8..d3d2a1b1f79a 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -64,32 +64,32 @@ class CacheMetric: evicted_size = attr.ib(default=0) memory_usage = attr.ib(default=None) - def inc_hits(self): + def inc_hits(self) -> None: self.hits += 1 - def inc_misses(self): + def inc_misses(self) -> None: self.misses += 1 - def inc_evictions(self, size=1): + def inc_evictions(self, size=1) -> None: self.evicted_size += size - def inc_memory_usage(self, memory: int): + def inc_memory_usage(self, memory: int) -> None: if self.memory_usage is None: self.memory_usage = 0 self.memory_usage += memory - def dec_memory_usage(self, memory: int): + def dec_memory_usage(self, memory: int) -> None: self.memory_usage -= memory - def clear_memory_usage(self): + def clear_memory_usage(self) -> None: if self.memory_usage is not None: self.memory_usage = 0 def describe(self): return [] - def collect(self): + def collect(self) -> None: try: if self._cache_type == "response_cache": response_cache_size.labels(self._cache_name).set(len(self._cache)) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index b6456392cd92..f05590da0d54 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -93,7 +93,7 @@ def __init__( TreeCache, "MutableMapping[KT, CacheEntry]" ] = cache_type() - def metrics_cb(): + def metrics_cb() -> None: cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) # cache is used for completed results and maps to the result itself, rather than @@ -113,7 +113,7 @@ def metrics_cb(): def max_entries(self): return self.cache.max_size - def check_thread(self): + def check_thread(self) -> None: expected_thread = self.thread if expected_thread is None: self.thread = threading.current_thread() @@ -235,7 +235,7 @@ def set( self._pending_deferred_cache[key] = entry - def compare_and_pop(): + def compare_and_pop() -> bool: """Check if our entry is still the one in _pending_deferred_cache, and if so, pop it. @@ -256,7 +256,7 @@ def compare_and_pop(): return False - def cb(result): + def cb(result) -> None: if compare_and_pop(): self.cache.set(key, result, entry.callbacks) else: @@ -268,7 +268,7 @@ def cb(result): # not have been. Either way, let's double-check now. entry.invalidate() - def eb(_fail): + def eb(_fail) -> None: compare_and_pop() entry.invalidate() @@ -314,7 +314,7 @@ def invalidate(self, key): for entry in iterate_tree_cache_entry(entry): entry.invalidate() - def invalidate_all(self): + def invalidate_all(self) -> None: self.check_thread() self.cache.clear() for entry in self._pending_deferred_cache.values(): @@ -332,7 +332,7 @@ def __init__( self.callbacks = set(callbacks) self.invalidated = False - def invalidate(self): + def invalidate(self) -> None: if not self.invalidated: self.invalidated = True for callback in self.callbacks: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 00ac9b1d3e43..c8ebb61cda09 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -384,7 +384,7 @@ def __init__( lock = threading.Lock() - def evict(): + def evict() -> None: while cache_len() > self.max_size: # Get the last node in the list (i.e. the oldest node). todelete = list_root.prev_node diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 3a41a8baa603..27b1da235ef3 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -195,7 +195,7 @@ def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None: for entity in r: del self._entity_to_key[entity] - def _evict(self): + def _evict(self) -> None: while len(self._cache) > self._max_size: k, r = self._cache.popitem(0) self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 4138931e7bc1..563845f86769 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -35,17 +35,17 @@ class TreeCache: root = {key_1: {key_2: _value}} """ - def __init__(self): - self.size = 0 + def __init__(self) -> None: + self.size: int = 0 self.root = TreeCacheNode() - def __setitem__(self, key, value): - return self.set(key, value) + def __setitem__(self, key, value) -> None: + self.set(key, value) - def __contains__(self, key): + def __contains__(self, key) -> bool: return self.get(key, SENTINEL) is not SENTINEL - def set(self, key, value): + def set(self, key, value) -> None: if isinstance(value, TreeCacheNode): # this would mean we couldn't tell where our tree ended and the value # started. @@ -73,7 +73,7 @@ def get(self, key, default=None): return default return node.get(key[-1], default) - def clear(self): + def clear(self) -> None: self.size = 0 self.root = TreeCacheNode() @@ -128,7 +128,7 @@ def pop(self, key, default=None): def values(self): return iterate_tree_cache_entry(self.root) - def __len__(self): + def __len__(self) -> int: return self.size diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py index d8532411c2c5..f1a351cfd4a6 100644 --- a/synapse/util/daemonize.py +++ b/synapse/util/daemonize.py @@ -126,7 +126,7 @@ def sigterm(signum, frame): signal.signal(signal.SIGTERM, sigterm) # Cleanup pid file at exit. - def exit(): + def exit() -> None: logger.warning("Stopping daemon.") os.remove(pid_file) sys.exit(0) diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index a380c03e8db3..31097d64398e 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -38,7 +38,7 @@ class Distributor: model will do for today. """ - def __init__(self): + def __init__(self) -> None: self.signals: Dict[str, Signal] = {} self.pre_registration: Dict[str, List[Callable]] = {} diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index ecda6b0eda39..54fe4fc2a0d4 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -13,11 +13,14 @@ # limitations under the License. import queue -from typing import Optional +from typing import BinaryIO, Optional, Union from twisted.internet import threads +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IPullProducer, IPushProducer from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.types import ISynapseReactor class BackgroundFileConsumer: @@ -25,9 +28,9 @@ class BackgroundFileConsumer: and pull producers Args: - file_obj (file): The file like object to write to. Closed when + file_obj: The file like object to write to. Closed when finished. - reactor (twisted.internet.reactor): the Twisted reactor to use + reactor: the Twisted reactor to use """ # For PushProducers pause if we have this many unwritten slices @@ -35,13 +38,13 @@ class BackgroundFileConsumer: # And resume once the size of the queue is less than this _RESUME_ON_QUEUE_SIZE = 2 - def __init__(self, file_obj, reactor): - self._file_obj = file_obj + def __init__(self, file_obj: BinaryIO, reactor: ISynapseReactor) -> None: + self._file_obj: BinaryIO = file_obj - self._reactor = reactor + self._reactor: ISynapseReactor = reactor # Producer we're registered with - self._producer = None + self._producer: Optional[Union[IPushProducer, IPullProducer]] = None # True if PushProducer, false if PullProducer self.streaming = False @@ -55,17 +58,19 @@ def __init__(self, file_obj, reactor): self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue() # Deferred that is resolved when finished writing - self._finished_deferred = None + self._finished_deferred: Optional[Deferred[int]] = None # TODO # If the _writer thread throws an exception it gets stored here. - self._write_exception = None + self._write_exception: Optional[Exception] = None - def registerProducer(self, producer, streaming) -> None: + def registerProducer( + self, producer: Union[IPushProducer, IPullProducer], streaming: bool + ) -> None: """Part of IConsumer interface Args: - producer (IProducer) - streaming (bool): True if push based producer, False if pull + producer + streaming: True if push based producer, False if pull based. """ if self._producer: @@ -80,6 +85,7 @@ def registerProducer(self, producer, streaming) -> None: self._writer, ) if not streaming: + assert isinstance(self._producer, IPullProducer) self._producer.resumeProducing() def unregisterProducer(self) -> None: @@ -89,7 +95,7 @@ def unregisterProducer(self) -> None: if not self._finished_deferred.called: self._bytes_queue.put_nowait(None) - def write(self, bytes) -> None: + def write(self, write_bytes: bytes) -> None: """Part of IProducer interface""" if self._write_exception: raise self._write_exception @@ -98,11 +104,12 @@ def write(self, bytes) -> None: if self._finished_deferred.called: raise Exception("consumer has closed") - self._bytes_queue.put_nowait(bytes) + self._bytes_queue.put_nowait(write_bytes) # If this is a PushProducer and the queue is getting behind # then we pause the producer. if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: + assert isinstance(self._producer, IPushProducer) self._paused_producer = True assert self._producer is not None self._producer.pauseProducing() diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py index a456b136f06c..9f4be757baa5 100644 --- a/synapse/util/linked_list.py +++ b/synapse/util/linked_list.py @@ -74,7 +74,7 @@ def insert_after( new_node._refs_insert_after(node) return new_node - def remove_from_list(self): + def remove_from_list(self) -> None: """Remove this node from the list.""" with self._LOCK: self._refs_remove_node_from_list() @@ -84,7 +84,7 @@ def remove_from_list(self): # immediately rather than at the next GC. self.cache_entry = None - def move_after(self, node: "ListNode"): + def move_after(self, node: "ListNode") -> None: """Move this node from its current location in the list to after the given node. """ @@ -103,7 +103,7 @@ def move_after(self, node: "ListNode"): # Insert self back into the list, after target node self._refs_insert_after(node) - def _refs_remove_node_from_list(self): + def _refs_remove_node_from_list(self) -> None: """Internal method to *just* remove the node from the list, without e.g. clearing out the cache entry. """ @@ -122,7 +122,7 @@ def _refs_remove_node_from_list(self): self.prev_node = None self.next_node = None - def _refs_insert_after(self, node: "ListNode"): + def _refs_insert_after(self, node: "ListNode") -> None: """Internal method to insert the node after the given node.""" # This method should only be called when we're not already in the list. diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py index d1f76e3dc54f..84e4f6ff55f7 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py @@ -77,7 +77,7 @@ def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> N should be considered expired. Normally the current time. """ - def verify_expiry_caveat(caveat: str): + def verify_expiry_caveat(caveat: str) -> bool: time_msec = get_time_ms() prefix = "time < " if not caveat.startswith(prefix): diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index c4826373c5c7..6382f0df431d 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -98,7 +98,7 @@ def manhole(username, password, globals): class SynapseManhole(ColoredManhole): """Overrides connectionMade to create our own ManholeInterpreter""" - def connectionMade(self): + def connectionMade(self) -> None: super().connectionMade() # replace the manhole interpreter with our own impl @@ -133,7 +133,7 @@ def showsyntaxerror(self, filename=None): lines = traceback.format_exception_only(type, value) self.write("".join(lines)) - def showtraceback(self): + def showtraceback(self) -> None: """Display the exception that just occurred. Overrides the base implementation, ignoring sys.excepthook. We always want diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 5b6701bc667f..66e7f2e2a78b 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -150,7 +150,7 @@ def queue_request() -> "defer.Deferred[None]": self.sleeping_requests.add(request_id) - def on_wait_finished(_): + def on_wait_finished(_) -> "defer.Deferred[None]": logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id)) self.sleeping_requests.discard(request_id) queue_defer = queue_request() @@ -160,19 +160,19 @@ def on_wait_finished(_): else: ret_defer = queue_request() - def on_start(r): + def on_start(r: object) -> object: logger.debug("Ratelimit [%s]: Processing req", id(request_id)) self.current_processing.add(request_id) return r - def on_err(r): + def on_err(r: object) -> object: # XXX: why is this necessary? this is called before we start # processing the request so why would the request be in # current_processing? self.current_processing.discard(request_id) return r - def on_both(r): + def on_both(r: object) -> object: # Ensure that we've properly cleaned up. self.sleeping_requests.discard(request_id) self.ready_request_queue.pop(request_id, None) diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 12ebf2a3aeed..9540022701e7 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -13,16 +13,15 @@ # limitations under the License. from typing import Generic, List, TypeVar +T = TypeVar("T") + -class _Entry: +class _Entry(Generic[T]): __slots__ = ["end_key", "queue"] - def __init__(self, end_key): + def __init__(self, end_key) -> None: self.end_key = end_key - self.queue = [] - - -T = TypeVar("T") + self.queue: List[T] = [] class WheelTimer(Generic[T]): @@ -37,7 +36,7 @@ def __init__(self, bucket_size=5000): accuracy of the timer. """ self.bucket_size = bucket_size - self.entries = [] + self.entries: List[_Entry[T]] = [] self.current_tick = 0 def insert(self, now: int, obj: T, then: int) -> None: From a22f4c0c24ccc12a58fbaebfc02258842ae3afcf Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Thu, 2 Sep 2021 16:48:42 +0100 Subject: [PATCH 20/38] Add type annotation fixes to fix CI --- synapse/rest/client/register.py | 4 ++-- synapse/util/file_consumer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 4f53172bef28..c8dbd67c19a9 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -15,7 +15,7 @@ import hmac import logging import random -from typing import List, Union +from typing import Dict, List, Union import synapse import synapse.api.auth @@ -777,7 +777,7 @@ async def _do_guest_registration(self, params, address=None): user_id, device_id, initial_display_name, is_guest=True ) - result = { + result: Dict[str, Union[int, str]] = { "user_id": user_id, "device_id": device_id, "access_token": access_token, diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 54fe4fc2a0d4..e1aa421743e5 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -58,7 +58,7 @@ def __init__(self, file_obj: BinaryIO, reactor: ISynapseReactor) -> None: self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue() # Deferred that is resolved when finished writing - self._finished_deferred: Optional[Deferred[int]] = None # TODO + self._finished_deferred: Optional[Deferred[None]] = None # If the _writer thread throws an exception it gets stored here. self._write_exception: Optional[Exception] = None @@ -141,7 +141,7 @@ def _writer(self) -> None: finally: self._file_obj.close() - def wait(self): + def wait(self) -> "Deferred[None]": """Returns a deferred that resolves when finished writing to file""" return make_deferred_yieldable(self._finished_deferred) From 9444ca171ab7a3c280e8049af49a5dd068584a42 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Thu, 2 Sep 2021 16:59:59 +0100 Subject: [PATCH 21/38] Resolve type issue that arose from merge --- synapse/rest/client/register.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index a5fb49bf9f3e..f4c8b43622b6 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -763,7 +763,10 @@ async def _create_registration_details( Returns: dictionary for response from /register """ - result = {"user_id": user_id, "home_server": self.hs.hostname} + result: Dict[str, Union[str, int]] = { + "user_id": user_id, + "home_server": self.hs.hostname, + } if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") From a0aef0bca182c95992119466470c7367a5983eb3 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Thu, 2 Sep 2021 17:13:57 +0100 Subject: [PATCH 22/38] Back out of generics due to python-attrs/attrs#313 --- synapse/util/async_helpers.py | 8 +++++--- synapse/util/caches/dictionary_cache.py | 18 +++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 080fc004036c..152ba40d8aea 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -550,16 +550,18 @@ def failure_cb(val): return new_d +# This class can't be generic because it uses slots with attrs. +# See: https://github.com/python-attrs/attrs/issues/313 @attr.s(slots=True, frozen=True) -class DoneAwaitable(Generic[R]): +class DoneAwaitable: # should be: Generic[R] """Simple awaitable that returns the provided value.""" - value = attr.ib(type="R") + value = attr.ib(type=Any) # should be: R def __await__(self): return self - def __iter__(self) -> "DoneAwaitable[R]": + def __iter__(self) -> "DoneAwaitable": return self def __next__(self) -> None: diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 225591ca7ac6..ade088aae2ef 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -14,7 +14,7 @@ import enum import logging import threading -from typing import Dict, Generic, Iterable, Optional, Set, TypeVar +from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar import attr @@ -31,8 +31,10 @@ DV = TypeVar("DV") +# This class can't be generic because it uses slots with attrs. +# See: https://github.com/python-attrs/attrs/issues/313 @attr.s(slots=True) -class DictionaryEntry(Generic[DKT, DV]): +class DictionaryEntry: # should be: Generic[DKT, DV]. """Returned when getting an entry from the cache Attributes: @@ -45,8 +47,8 @@ class DictionaryEntry(Generic[DKT, DV]): """ full = attr.ib(type=bool) - known_absent = attr.ib(type=Set[DKT]) - value = attr.ib(type=Dict[DKT, DV]) + known_absent = attr.ib(type=Set[Any]) # should be: Set[DKT] + value = attr.ib(type=Dict[Any, Any]) # should be: Dict[DKT, DV] def __len__(self) -> int: return len(self.value) @@ -64,7 +66,7 @@ class DictionaryCache(Generic[KT, DKT, DV]): """ def __init__(self, name: str, max_entries: int = 1000): - self.cache: LruCache[KT, DictionaryEntry[DKT, DV]] = LruCache( + self.cache: LruCache[KT, DictionaryEntry] = LruCache( max_size=max_entries, cache_name=name, size_callback=len ) @@ -84,7 +86,7 @@ def check_thread(self) -> None: def get( self, key: KT, dict_keys: Optional[Iterable[DKT]] = None - ) -> DictionaryEntry[DKT, DV]: + ) -> DictionaryEntry: """Fetch an entry out of the cache Args: @@ -158,9 +160,7 @@ def _update_or_insert( # We pop and reinsert as we need to tell the cache the size may have # changed - entry: DictionaryEntry[DKT, DV] = self.cache.pop( - key, DictionaryEntry(False, set(), {}) - ) + entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {})) entry.value.update(value) entry.known_absent.update(known_absent) self.cache[key] = entry From 289df40d81676be3d242719e82d3d411a92d8d06 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 3 Sep 2021 10:34:17 +0100 Subject: [PATCH 23/38] Quote return types with Deferreds --- synapse/util/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index b59c1e3d80cb..bd234549bd85 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -86,7 +86,7 @@ class Clock: _reactor: IReactorTime = attr.ib() @defer.inlineCallbacks # type: ignore[arg-type] # Issue in Twisted's type annotations - def sleep(self, seconds: float) -> Generator[Deferred[float], Any, Any]: + def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]": d: defer.Deferred[float] = defer.Deferred() with context.PreserveLoggingContext(): self._reactor.callLater(seconds, d.callback, seconds) From 8e719ed89b0f278b0283c5d5ec87b65b15a5b1fb Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 6 Sep 2021 15:45:40 +0100 Subject: [PATCH 24/38] Fix use of None as default --- synapse/config/ratelimiting.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index b6bba4942670..36636ab07e40 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -66,11 +66,15 @@ def read_config(self, config, **kwargs): else: self.rc_federation = FederationRateLimitConfig( **{ - "window_size": config.get("federation_rc_window_size"), - "sleep_limit": config.get("federation_rc_sleep_limit"), - "sleep_delay": config.get("federation_rc_sleep_delay"), - "reject_limit": config.get("federation_rc_reject_limit"), - "concurrent": config.get("federation_rc_concurrent"), + k: v + for k, v in { + "window_size": config.get("federation_rc_window_size"), + "sleep_limit": config.get("federation_rc_sleep_limit"), + "sleep_delay": config.get("federation_rc_sleep_delay"), + "reject_limit": config.get("federation_rc_reject_limit"), + "concurrent": config.get("federation_rc_concurrent"), + }.items() + if v is not None } ) From 34e327dcc8d5ac6e99bb1426a3331a0c8cecc9e1 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 6 Sep 2021 16:21:28 +0100 Subject: [PATCH 25/38] Use a cast to work around Mocks not working with isinstance --- synapse/util/file_consumer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index e1aa421743e5..de2adacd70dc 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -13,7 +13,7 @@ # limitations under the License. import queue -from typing import BinaryIO, Optional, Union +from typing import BinaryIO, Optional, Union, cast from twisted.internet import threads from twisted.internet.defer import Deferred @@ -85,7 +85,6 @@ def registerProducer( self._writer, ) if not streaming: - assert isinstance(self._producer, IPullProducer) self._producer.resumeProducing() def unregisterProducer(self) -> None: @@ -109,10 +108,10 @@ def write(self, write_bytes: bytes) -> None: # If this is a PushProducer and the queue is getting behind # then we pause the producer. if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: - assert isinstance(self._producer, IPushProducer) self._paused_producer = True assert self._producer is not None - self._producer.pauseProducing() + # cast safe because `streaming` means this is an IPushProducer + cast(IPushProducer, self._producer).pauseProducing() def _writer(self) -> None: """This is run in a background thread to write to the file.""" From cd9a68de5a842e922cf34313e0a51aac17a9b804 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 6 Sep 2021 17:19:18 +0100 Subject: [PATCH 26/38] Fix up parameters which were previously silently ignored --- tests/unittest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittest.py b/tests/unittest.py index f2c90cc47b53..7a6f5954d06c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -734,9 +734,9 @@ def authenticate_request(self, request, content): FederationRateLimitConfig( window_size=1, sleep_limit=1, - sleep_msec=1, + sleep_delay=1, reject_limit=1000, - concurrent_requests=1000, + concurrent=1000, ), ) From b4cded14b20b9fd50b3aec7d87a1ae8582f45a22 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 8 Sep 2021 07:48:46 +0100 Subject: [PATCH 27/38] Apply suggestions --- synapse/rest/client/register.py | 6 +++--- synapse/util/caches/__init__.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index f4c8b43622b6..abe4d7e20512 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import random -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple from twisted.web.server import Request @@ -763,7 +763,7 @@ async def _create_registration_details( Returns: dictionary for response from /register """ - result: Dict[str, Union[str, int]] = { + result: JsonDict = { "user_id": user_id, "home_server": self.hs.hostname, } @@ -817,7 +817,7 @@ async def _do_guest_registration( user_id, device_id, initial_display_name, is_guest=True ) - result: Dict[str, Union[int, str]] = { + result: JsonDict = { "user_id": user_id, "device_id": device_id, "access_token": access_token, diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index d3d2a1b1f79a..cab1bf0c1537 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -70,7 +70,7 @@ def inc_hits(self) -> None: def inc_misses(self) -> None: self.misses += 1 - def inc_evictions(self, size=1) -> None: + def inc_evictions(self, size: int = 1) -> None: self.evicted_size += size def inc_memory_usage(self, memory: int) -> None: From 6f7fac009aea4545fc9bc9c53d2c7ca915bf5b30 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 8 Sep 2021 07:49:55 +0100 Subject: [PATCH 28/38] Use `cast` to IReactorTime [WANTS REVIEW] `twisted.internet.reactor` is typed to `module` by default. --- synapse/util/async_helpers.py | 3 ++- synapse/util/caches/lrucache.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 152ba40d8aea..d08402cf04d2 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -31,6 +31,7 @@ Set, TypeVar, Union, + cast, ) import attr @@ -268,7 +269,7 @@ def __init__( if not clock: from twisted.internet import reactor - clock = Clock(reactor) # type: ignore[arg-type] + clock = Clock(cast(IReactorTime, reactor)) self._clock = clock self.max_count = max_count diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index c8ebb61cda09..39dce9dd4166 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -35,6 +35,7 @@ from typing_extensions import Literal from twisted.internet import reactor +from twisted.internet.interfaces import IReactorTime from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -341,7 +342,7 @@ def __init__( # Default `clock` to something sensible. Note that we rename it to # `real_clock` so that mypy doesn't think its still `Optional`. if clock is None: - real_clock = Clock(reactor) # type: ignore[arg-type] + real_clock = Clock(cast(IReactorTime, reactor)) else: real_clock = clock From d4afbca0e9da30fde0a50d18714c8aff4ca12ec8 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 8 Sep 2021 07:53:37 +0100 Subject: [PATCH 29/38] Add types and casts to `__exit__` [REVIEW] --- synapse/util/retryutils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 3b76e0b75a16..987a752fd352 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -14,7 +14,7 @@ import logging import random from types import TracebackType -from typing import Any, Optional +from typing import Any, Optional, Type, TypeVar, cast import synapse.logging.context from synapse.api.errors import CodeMessageException @@ -33,6 +33,9 @@ MAX_RETRY_INTERVAL = 2 ** 62 +T = TypeVar("T") + + class NotRetryingDestination(Exception): def __init__(self, retry_last_ts: int, retry_interval: int, destination: str): """Raised by the limiter (and federation client) to indicate that we are @@ -163,7 +166,9 @@ def __init__( def __enter__(self) -> None: pass - def __exit__(self, exc_type, exc_val, exc_tb: TracebackType) -> None: + def __exit__( + self, exc_type: Optional[Type[T]], exc_val: T, exc_tb: TracebackType + ) -> None: valid_err_code = False if exc_type is None: valid_err_code = True @@ -172,6 +177,7 @@ def __exit__(self, exc_type, exc_val, exc_tb: TracebackType) -> None: # failures; this is mostly so as not to catch defer._DefGen. valid_err_code = True elif issubclass(exc_type, CodeMessageException): + exc_val_cme = cast(CodeMessageException, exc_val) # Some error codes are perfectly fine for some APIs, whereas other # APIs may expect to never received e.g. a 404. It's important to # handle 404 as some remote servers will return a 404 when the HS @@ -180,11 +186,11 @@ def __exit__(self, exc_type, exc_val, exc_tb: TracebackType) -> None: # won't accept our requests for at least a while. # 429 is us being aggressively rate limited, so lets rate limit # ourselves. - if exc_val.code == 404 and self.backoff_on_404: + if exc_val_cme.code == 404 and self.backoff_on_404: valid_err_code = False - elif exc_val.code in (401, 429): + elif exc_val_cme.code in (401, 429): valid_err_code = False - elif exc_val.code < 500: + elif exc_val_cme.code < 500: valid_err_code = True else: valid_err_code = False From f5cee54e4fcdecb84e75aea19d500b2b99d36f05 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 8 Sep 2021 09:12:46 +0100 Subject: [PATCH 30/38] Fix adherence to Jinja2's interface [REVIEW] --- synapse/rest/synapse/client/new_user_consent.py | 2 +- synapse/rest/synapse/client/pick_username.py | 2 +- synapse/util/templates.py | 7 +++---- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py index fc62a09b7f07..edabf9621aca 100644 --- a/synapse/rest/synapse/client/new_user_consent.py +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -52,7 +52,7 @@ def template_search_dirs(): yield hs.config.sso.sso_template_dir yield hs.config.sso.default_template_dir - self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) + self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config) async def _async_render_GET(self, request: Request) -> None: try: diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index c15b83c387c2..d30b478b9825 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -80,7 +80,7 @@ def template_search_dirs(): yield hs.config.sso.sso_template_dir yield hs.config.sso.default_template_dir - self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) + self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config) async def _async_render_GET(self, request: Request) -> None: try: diff --git a/synapse/util/templates.py b/synapse/util/templates.py index 66f8fbb7758a..eb3c8c93705e 100644 --- a/synapse/util/templates.py +++ b/synapse/util/templates.py @@ -16,7 +16,7 @@ import time import urllib.parse -from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union import jinja2 @@ -25,7 +25,7 @@ def build_jinja_env( - template_search_directories: Iterable[str], + template_search_directories: Sequence[str], config: "HomeServerConfig", autoescape: Union[bool, Callable[[Optional[str]], bool], None] = None, ) -> jinja2.Environment: @@ -56,8 +56,7 @@ def build_jinja_env( if autoescape is None: autoescape = jinja2.select_autoescape() - # the type signature of this is wrong - loader = jinja2.FileSystemLoader(template_search_directories) # type: ignore[arg-type] + loader = jinja2.FileSystemLoader(template_search_directories) env = jinja2.Environment(loader=loader, autoescape=autoescape) # Update the environment with our custom filters From 12cfb9a069ac7408a38eb0dfbe1569bf2bed8397 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 8 Sep 2021 09:13:56 +0100 Subject: [PATCH 31/38] Annotate `WheelTimer`, notably `bucket_size` --- synapse/util/wheel_timer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 9540022701e7..e108adc4604f 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -19,8 +19,8 @@ class _Entry(Generic[T]): __slots__ = ["end_key", "queue"] - def __init__(self, end_key) -> None: - self.end_key = end_key + def __init__(self, end_key: int) -> None: + self.end_key: int = end_key self.queue: List[T] = [] @@ -29,15 +29,15 @@ class WheelTimer(Generic[T]): expired. """ - def __init__(self, bucket_size=5000): + def __init__(self, bucket_size: int = 5000) -> None: """ Args: - bucket_size (int): Size of buckets in ms. Corresponds roughly to the + bucket_size: Size of buckets in ms. Corresponds roughly to the accuracy of the timer. """ - self.bucket_size = bucket_size + self.bucket_size: int = bucket_size self.entries: List[_Entry[T]] = [] - self.current_tick = 0 + self.current_tick: int = 0 def insert(self, now: int, obj: T, then: int) -> None: """Inserts object into timer. @@ -90,5 +90,5 @@ def fetch(self, now: int) -> List[T]: return ret - def __len__(self): + def __len__(self) -> int: return sum(len(entry.queue) for entry in self.entries) From e69a3d67357ea360686d7c8ca015a2491eff1889 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 8 Sep 2021 09:15:20 +0100 Subject: [PATCH 32/38] Update Newsfile --- changelog.d/10601.misc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog.d/10601.misc b/changelog.d/10601.misc index 8b573ab49a95..1227113ff38e 100644 --- a/changelog.d/10601.misc +++ b/changelog.d/10601.misc @@ -1 +1 @@ -Add type annotations to complete the synapse.util package. +Add type annotations to the synapse.util package. From 9f301aefdad5306fc1087ea049d3873bef2d8b7f Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 8 Sep 2021 09:33:55 +0100 Subject: [PATCH 33/38] Note that code was lifted from CPython --- synapse/util/manhole.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 6382f0df431d..1a47c0832b36 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -152,6 +152,8 @@ def showtraceback(self) -> None: # It's unclear whether there is a reason behind this line. # It conceivably could be because an exception raised in this block # will keep the local frame (containing these local variables) around. + # This was adapted taken from CPython's Lib/code.py; see here: + # https://github.com/python/cpython/blob/4dc4300c686f543d504ab6fa9fe600eaf11bb695/Lib/code.py#L131-L150 last_tb = ei = None # type: ignore def displayhook(self, obj): From e6618d73c9fcc8244f9e5778ea9936c2dcbe44bd Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 8 Sep 2021 11:42:00 +0100 Subject: [PATCH 34/38] Add more type annotations --- synapse/util/manhole.py | 16 +++++++++------- synapse/util/ratelimitutils.py | 4 ++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 1a47c0832b36..2068b8399208 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -15,6 +15,7 @@ import inspect import sys import traceback +from typing import Any, Dict, Optional, Union from twisted.conch import manhole_ssh from twisted.conch.insults import insults @@ -22,6 +23,7 @@ from twisted.conch.ssh.keys import Key from twisted.cred import checkers, portal from twisted.internet import defer +from twisted.internet.protocol import Factory PUBLIC_KEY = ( "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5" @@ -61,19 +63,19 @@ -----END RSA PRIVATE KEY-----""" -def manhole(username, password, globals): +def manhole(username: str, password: Union[str, bytes], globals: Dict) -> Factory: """Starts a ssh listener with password authentication using the given username and password. Clients connecting to the ssh listener will find themselves in a colored python shell with the supplied globals. Args: - username(str): The username ssh clients should auth with. - password(str): The password ssh clients should auth with. - globals(dict): The variables to expose in the shell. + username: The username ssh clients should auth with. + password: The password ssh clients should auth with. + globals: The variables to expose in the shell. Returns: - twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` + A factory to pass to ``listenTCP`` """ if not isinstance(password, bytes): password = password.encode("ascii") @@ -108,7 +110,7 @@ def connectionMade(self) -> None: class SynapseManholeInterpreter(ManholeInterpreter): - def showsyntaxerror(self, filename=None): + def showsyntaxerror(self, filename: Optional[str] = None) -> None: """Display the syntax error that just occurred. Overrides the base implementation, ignoring sys.excepthook. We always want @@ -156,7 +158,7 @@ def showtraceback(self) -> None: # https://github.com/python/cpython/blob/4dc4300c686f543d504ab6fa9fe600eaf11bb695/Lib/code.py#L131-L150 last_tb = ei = None # type: ignore - def displayhook(self, obj): + def displayhook(self, obj: Any) -> None: """ We override the displayhook so that we automatically convert coroutines into Deferreds. (Our superclass' displayhook will take care of the rest, diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 66e7f2e2a78b..dfe628c97e96 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -16,7 +16,7 @@ import contextlib import logging import typing -from typing import DefaultDict, Iterator, List, Set +from typing import Any, DefaultDict, Iterator, List, Set from twisted.internet import defer @@ -150,7 +150,7 @@ def queue_request() -> "defer.Deferred[None]": self.sleeping_requests.add(request_id) - def on_wait_finished(_) -> "defer.Deferred[None]": + def on_wait_finished(_: Any) -> "defer.Deferred[None]": logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id)) self.sleeping_requests.discard(request_id) queue_defer = queue_request() From b1b4f1bdbd099cda53a2dce176820b2b404c686a Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 8 Sep 2021 11:42:16 +0100 Subject: [PATCH 35/38] Enable stricter checking on applicable modules --- mypy.ini | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/mypy.ini b/mypy.ini index cd266265b600..e75b0a2a010d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -88,6 +88,69 @@ files = tests/util/test_itertools.py, tests/util/test_stream_change_cache.py +[mypy-synapse.util.batching_queue] +disallow_untyped_defs = True + +[mypy-synapse.util.caches.dictionary_cache] +disallow_untyped_defs = True + +[mypy-synapse.util.file_consumer] +disallow_untyped_defs = True + +[mypy-synapse.util.frozenutils] +disallow_untyped_defs = True + +[mypy-synapse.util.hash] +disallow_untyped_defs = True + +[mypy-synapse.util.httpresourcetree] +disallow_untyped_defs = True + +[mypy-synapse.util.iterutils] +disallow_untyped_defs = True + +[mypy-synapse.util.linked_list] +disallow_untyped_defs = True + +[mypy-synapse.util.logcontext] +disallow_untyped_defs = True + +[mypy-synapse.util.logformatter] +disallow_untyped_defs = True + +[mypy-synapse.util.macaroons] +disallow_untyped_defs = True + +[mypy-synapse.util.manhole] +disallow_untyped_defs = True + +[mypy-synapse.util.module_loader] +disallow_untyped_defs = True + +[mypy-synapse.util.msisdn] +disallow_untyped_defs = True + +[mypy-synapse.util.ratelimitutils] +disallow_untyped_defs = True + +[mypy-synapse.util.retryutils] +disallow_untyped_defs = True + +[mypy-synapse.util.rlimit] +disallow_untyped_defs = True + +[mypy-synapse.util.stringutils] +disallow_untyped_defs = True + +[mypy-synapse.util.templates] +disallow_untyped_defs = True + +[mypy-synapse.util.threepids] +disallow_untyped_defs = True + +[mypy-synapse.util.wheel_timer] +disallow_untyped_defs = True + [mypy-pymacaroons.*] ignore_missing_imports = True From 8871674bd9916afaea9f0be19dc183945694b2ed Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 8 Sep 2021 15:39:18 +0100 Subject: [PATCH 36/38] Correct types used in `__exit__` --- synapse/util/retryutils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 987a752fd352..e71e456c93f6 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -14,7 +14,7 @@ import logging import random from types import TracebackType -from typing import Any, Optional, Type, TypeVar, cast +from typing import Any, Optional, Type, cast import synapse.logging.context from synapse.api.errors import CodeMessageException @@ -33,9 +33,6 @@ MAX_RETRY_INTERVAL = 2 ** 62 -T = TypeVar("T") - - class NotRetryingDestination(Exception): def __init__(self, retry_last_ts: int, retry_interval: int, destination: str): """Raised by the limiter (and federation client) to indicate that we are @@ -167,7 +164,10 @@ def __enter__(self) -> None: pass def __exit__( - self, exc_type: Optional[Type[T]], exc_val: T, exc_tb: TracebackType + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], ) -> None: valid_err_code = False if exc_type is None: From 20d63a012d1d3930a9ab8fc9b3a510efb645dd75 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 8 Sep 2021 15:57:38 +0100 Subject: [PATCH 37/38] Fix up manhole types after merge [REVIEW, SEE DESC] REVIEW: slight change here in that - if not isinstance(password, bytes): - password = password.encode("ascii") has been removed and made unconditional (since it's defined to be str) --- synapse/util/manhole.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index a5355c878252..f8b2d7bea9b4 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -15,7 +15,7 @@ import inspect import sys import traceback -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional from twisted.conch import manhole_ssh from twisted.conch.insults import insults @@ -25,6 +25,8 @@ from twisted.internet import defer from twisted.internet.protocol import Factory +from synapse.config.server import ManholeConfig + PUBLIC_KEY = ( "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5" "XALqeK+7385NlLja3DE/DO9mGhnd9+bAy39EKT3sTV6+WXQ4yD0TvEEyUEMtjWkSEm6U32+C" @@ -63,7 +65,7 @@ -----END RSA PRIVATE KEY-----""" -def manhole(settings, globals): +def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory: """Starts a ssh listener with password authentication using the given username and password. Clients connecting to the ssh listener will find themselves in a colored python shell with @@ -78,7 +80,7 @@ def manhole(settings, globals): A factory to pass to ``listenTCP`` """ username = settings.username - password = settings.password + password = settings.password.encode("ascii") priv_key = settings.priv_key if priv_key is None: priv_key = Key.fromString(PRIVATE_KEY) @@ -86,9 +88,6 @@ def manhole(settings, globals): if pub_key is None: pub_key = Key.fromString(PUBLIC_KEY) - if not isinstance(password, bytes): - password = password.encode("ascii") - checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password}) rlm = manhole_ssh.TerminalRealm() @@ -100,8 +99,11 @@ def manhole(settings, globals): ) factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) - factory.privateKeys[b"ssh-rsa"] = priv_key - factory.publicKeys[b"ssh-rsa"] = pub_key + + # conch has the wrong type on these dicts (says bytes to bytes, + # should be bytes to Keys judging by how it's used). + factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment] + factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment] return factory From 19a602e81f1ecdbfb16fd0687390907406addaa0 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 10 Sep 2021 09:33:30 +0100 Subject: [PATCH 38/38] Avoid using evil typecasts --- synapse/util/async_helpers.py | 5 +++-- synapse/util/retryutils.py | 11 +++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index d08402cf04d2..82d918a05fd0 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -31,13 +31,13 @@ Set, TypeVar, Union, - cast, ) import attr from typing_extensions import ContextManager from twisted.internet import defer +from twisted.internet.base import ReactorBase from twisted.internet.defer import CancelledError from twisted.internet.interfaces import IReactorTime from twisted.python import failure @@ -269,7 +269,8 @@ def __init__( if not clock: from twisted.internet import reactor - clock = Clock(cast(IReactorTime, reactor)) + assert isinstance(reactor, ReactorBase) + clock = Clock(reactor) self._clock = clock self.max_count = max_count diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index e71e456c93f6..648d9a95a7eb 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -14,7 +14,7 @@ import logging import random from types import TracebackType -from typing import Any, Optional, Type, cast +from typing import Any, Optional, Type import synapse.logging.context from synapse.api.errors import CodeMessageException @@ -176,8 +176,7 @@ def __exit__( # avoid treating exceptions which don't derive from Exception as # failures; this is mostly so as not to catch defer._DefGen. valid_err_code = True - elif issubclass(exc_type, CodeMessageException): - exc_val_cme = cast(CodeMessageException, exc_val) + elif isinstance(exc_val, CodeMessageException): # Some error codes are perfectly fine for some APIs, whereas other # APIs may expect to never received e.g. a 404. It's important to # handle 404 as some remote servers will return a 404 when the HS @@ -186,11 +185,11 @@ def __exit__( # won't accept our requests for at least a while. # 429 is us being aggressively rate limited, so lets rate limit # ourselves. - if exc_val_cme.code == 404 and self.backoff_on_404: + if exc_val.code == 404 and self.backoff_on_404: valid_err_code = False - elif exc_val_cme.code in (401, 429): + elif exc_val.code in (401, 429): valid_err_code = False - elif exc_val_cme.code < 500: + elif exc_val.code < 500: valid_err_code = True else: valid_err_code = False