From a95ab837139fca7a4c1dac26862cd7749021f881 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 11 Jun 2022 17:47:05 +0100 Subject: [PATCH 01/14] Cast to postgres types when handling postgres db --- tests/utils.py | 61 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 3059c453d595..1d06398b4803 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,6 +15,7 @@ import atexit import os +from typing import TYPE_CHECKING, cast from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions @@ -22,7 +23,7 @@ from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.logging.context import current_context, set_current_context from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.engines import create_engine +from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.prepare_database import prepare_database # set this to True to run the tests against postgres instead of sqlite. @@ -48,21 +49,27 @@ # the dbname we will connect to in order to create the base database. POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres" +if TYPE_CHECKING: + import psycopg2.extensions -def setupdb(): +def setupdb() -> None: # If we're using PostgreSQL, set up the db once if USE_POSTGRES_FOR_TESTS: # create a PostgresEngine - db_engine = create_engine({"name": "psycopg2", "args": {}}) - + db_engine = cast( + PostgresEngine, create_engine({"name": "psycopg2", "args": {}}) + ) # connect to postgres to create the base database. - db_conn = db_engine.module.connect( - user=POSTGRES_USER, - host=POSTGRES_HOST, - port=POSTGRES_PORT, - password=POSTGRES_PASSWORD, - dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, + db_conn = cast( + "psycopg2.extensions.connection", + db_engine.module.connect( + user=POSTGRES_USER, + host=POSTGRES_HOST, + port=POSTGRES_PORT, + password=POSTGRES_PASSWORD, + dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, + ), ) db_conn.autocommit = True cur = db_conn.cursor() @@ -75,24 +82,30 @@ def setupdb(): db_conn.close() # Set up in the db - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - port=POSTGRES_PORT, - password=POSTGRES_PASSWORD, - ) - db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") - prepare_database(db_conn, db_engine, None) - db_conn.close() - - def _cleanup(): - db_conn = db_engine.module.connect( + db_conn = cast( + "psycopg2.extensions.connection", + db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER, host=POSTGRES_HOST, port=POSTGRES_PORT, password=POSTGRES_PASSWORD, - dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, + ), + ) + logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") + prepare_database(logging_conn, db_engine, None) + logging_conn.close() + + def _cleanup() -> None: + db_conn = cast( + "psycopg2.extensions.connection", + db_engine.module.connect( + user=POSTGRES_USER, + host=POSTGRES_HOST, + port=POSTGRES_PORT, + password=POSTGRES_PASSWORD, + dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, + ), ) db_conn.autocommit = True cur = db_conn.cursor() From 797b2bd9c75e7f58a78bd1fb4143360e801d1316 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 11 Jun 2022 17:57:44 +0100 Subject: [PATCH 02/14] Remove unused method --- tests/utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 1d06398b4803..8f65edea29ec 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -269,10 +269,6 @@ def advance_time(self, secs): def advance_time_msec(self, ms): self.advance_time(ms / 1000.0) - def time_bound_deferred(self, d, *args, **kwargs): - # We don't bother timing things out for now. - return d - async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room""" From 29f721f4b2b7adfe37f8f1bfe6170aa73fb0b4bd Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 11 Jun 2022 17:57:53 +0100 Subject: [PATCH 03/14] Easy annotations --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index 8f65edea29ec..066645886222 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -206,7 +206,7 @@ def getRawHeaders(name, default=None): class MockClock: now = 1000 - def __init__(self): + def __init__(self) -> None: # list of lists of [absolute_time, callback, expired] in no particular # order self.timers = [] From 575e1c4309d60ef2c2146f81029265731d2ced0b Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 11 Jun 2022 18:00:24 +0100 Subject: [PATCH 04/14] Annotate create_room --- tests/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index 066645886222..c93fc8990b69 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -22,6 +22,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.logging.context import current_context, set_current_context +from synapse.server import HomeServer from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.prepare_database import prepare_database @@ -270,10 +271,11 @@ def advance_time_msec(self, ms): self.advance_time(ms / 1000.0) -async def create_room(hs, room_id: str, creator_id: str): +async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None: """Creates and persist a creation event for the given room""" persistence_store = hs.get_storage_controllers().persistence + assert persistence_store is not None store = hs.get_datastores().main event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() From fb5809aca473ad10817d81d49439498a5b8fb70e Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 11 Jun 2022 18:04:56 +0100 Subject: [PATCH 05/14] Use `ParamSpec` to annotate looping_call --- synapse/util/__init__.py | 6 +++++- synapse/util/caches/lrucache.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 6323d452e78b..a90f08dd4c56 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -20,6 +20,7 @@ import attr from frozendict import frozendict from matrix_common.versionstring import get_distribution_version_string +from typing_extensions import ParamSpec from twisted.internet import defer, task from twisted.internet.defer import Deferred @@ -82,6 +83,9 @@ def unwrapFirstError(failure: Failure) -> Failure: return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations +P = ParamSpec("P") + + @attr.s(slots=True) class Clock: """ @@ -110,7 +114,7 @@ def time_msec(self) -> int: return int(self.time() * 1000) def looping_call( - self, f: Callable, msec: float, *args: Any, **kwargs: Any + self, f: Callable[P, object], msec: float, *args: P.args, **kwargs: P.kwargs ) -> LoopingCall: """Call a function repeatedly. diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index a3b60578e3c3..8ed5325c5d07 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -109,7 +109,7 @@ def update_last_access(self, clock: Clock) -> None: @wrap_as_background_process("LruCache._expire_old_entries") async def _expire_old_entries( - clock: Clock, expiry_seconds: int, autotune_config: Optional[dict] + clock: Clock, expiry_seconds: float, autotune_config: Optional[dict] ) -> None: """Walks the global cache list to find cache entries that haven't been accessed in the given number of seconds, or if a given memory threshold has been breached. From 39d547bf745574ca296ae52af51c76f52929ced6 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 11 Jun 2022 18:11:57 +0100 Subject: [PATCH 06/14] Annotate `default_config` --- tests/utils.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index c93fc8990b69..6a66a938a7a2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,7 +15,9 @@ import atexit import os -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Dict, Union, cast, overload + +from typing_extensions import Literal from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions @@ -117,7 +119,19 @@ def _cleanup() -> None: atexit.register(_cleanup) -def default_config(name, parse=False): +@overload +def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]: + ... + + +@overload +def default_config(name: str, parse: Literal[True]) -> HomeServerConfig: + ... + + +def default_config( + name: str, parse: bool = False +) -> Union[Dict[str, object], HomeServerConfig]: """ Create a reasonable test config. """ From a07c7335cc1023e2872a5fb1049385b52cc9581f Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 11 Jun 2022 18:33:07 +0100 Subject: [PATCH 07/14] Track `now` as a float `time_ms` returns an int like the proper Synapse `Clock` --- tests/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 6a66a938a7a2..cae851a0ebb5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -219,7 +219,7 @@ def getRawHeaders(name, default=None): class MockClock: - now = 1000 + now = 1000.0 def __init__(self) -> None: # list of lists of [absolute_time, callback, expired] in no particular @@ -227,11 +227,11 @@ def __init__(self) -> None: self.timers = [] self.loopers = [] - def time(self): + def time(self) -> float: return self.now - def time_msec(self): - return self.time() * 1000 + def time_msec(self) -> int: + return int(self.time() * 1000) def call_later(self, delay, callback, *args, **kwargs): ctx = current_context() @@ -257,7 +257,7 @@ def cancel_call_later(self, timer, ignore_errs=False): self.timers = [t for t in self.timers if t != timer] # For unit testing - def advance_time(self, secs): + def advance_time(self, secs: float) -> None: self.now += secs timers = self.timers @@ -281,7 +281,7 @@ def advance_time(self, secs): func(*args, **kwargs) looped[2] = self.now - def advance_time_msec(self, ms): + def advance_time_msec(self, ms: float) -> None: self.advance_time(ms / 1000.0) From d6ff8bdf96770edd6b65775e70d85110a1e8d67e Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 11 Jun 2022 18:35:02 +0100 Subject: [PATCH 08/14] Introduce a `Timer` dataclass --- tests/utils.py | 54 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index cae851a0ebb5..a4df72921d3b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,8 +15,9 @@ import atexit import os -from typing import TYPE_CHECKING, Dict, Union, cast, overload +from typing import TYPE_CHECKING, Callable, Dict, List, ParamSpec, Union, cast, overload +import attr from typing_extensions import Literal from synapse.api.constants import EventTypes @@ -218,13 +219,22 @@ def getRawHeaders(name, default=None): return getRawHeaders +P = ParamSpec("P") + + +@attr.s(slots=True, auto_attribs=True) +class Timer: + absolute_time: float + callback: Callable[[], None] + expired: bool + + class MockClock: now = 1000.0 def __init__(self) -> None: - # list of lists of [absolute_time, callback, expired] in no particular - # order - self.timers = [] + # Timers in no particular order + self.timers: List[Timer] = [] self.loopers = [] def time(self) -> float: @@ -233,27 +243,39 @@ def time(self) -> float: def time_msec(self) -> int: return int(self.time() * 1000) - def call_later(self, delay, callback, *args, **kwargs): + def call_later( + self, + delay: float, + callback: Callable[P, object], + *args: P.args, + **kwargs: P.kwargs, + ) -> Timer: ctx = current_context() - def wrapped_callback(): + def wrapped_callback() -> None: set_current_context(ctx) callback(*args, **kwargs) - t = [self.now + delay, wrapped_callback, False] + t = Timer(self.now + delay, wrapped_callback, False) self.timers.append(t) return t - def looping_call(self, function, interval, *args, **kwargs): + def looping_call( + self, + function: Callable[P, object], + interval: float, + *args: P.args, + **kwargs: P.kwargs, + ) -> None: self.loopers.append([function, interval / 1000.0, self.now, args, kwargs]) - def cancel_call_later(self, timer, ignore_errs=False): - if timer[2]: + def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None: + if timer.expired: if not ignore_errs: raise Exception("Cannot cancel an expired timer") - timer[2] = True + timer.expired = True self.timers = [t for t in self.timers if t != timer] # For unit testing @@ -264,14 +286,12 @@ def advance_time(self, secs: float) -> None: self.timers = [] for t in timers: - time, callback, expired = t - - if expired: + if t.expired: raise Exception("Timer already expired") - if self.now >= time: - t[2] = True - callback() + if self.now >= t.absolute_time: + t.expired = True + t.callback() else: self.timers.append(t) From 3338364dbe7c04c7d762c69459696951b0dd3af5 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 11 Jun 2022 19:32:29 +0100 Subject: [PATCH 09/14] Introduce a Looper type --- tests/utils.py | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index a4df72921d3b..d3da559200e1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,7 +15,18 @@ import atexit import os -from typing import TYPE_CHECKING, Callable, Dict, List, ParamSpec, Union, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + ParamSpec, + Tuple, + Union, + cast, + overload, +) import attr from typing_extensions import Literal @@ -229,13 +240,23 @@ class Timer: expired: bool +# TODO: Make this generic over a ParamSpec? +@attr.s(slots=True, auto_attribs=True) +class Looper: + func: Callable[..., Any] + interval: float # seconds + last: float + args: Tuple[object, ...] + kwargs: Dict[str, object] + + class MockClock: now = 1000.0 def __init__(self) -> None: # Timers in no particular order self.timers: List[Timer] = [] - self.loopers = [] + self.loopers: List[Looper] = [] def time(self) -> float: return self.now @@ -268,7 +289,9 @@ def looping_call( *args: P.args, **kwargs: P.kwargs, ) -> None: - self.loopers.append([function, interval / 1000.0, self.now, args, kwargs]) + # This type-ignore should be redundant once we use a mypy release with + # https://github.com/python/mypy/pull/12668. + self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) # type: ignore[arg-type] def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None: if timer.expired: @@ -296,10 +319,9 @@ def advance_time(self, secs: float) -> None: self.timers.append(t) for looped in self.loopers: - func, interval, last, args, kwargs = looped - if last + interval < self.now: - func(*args, **kwargs) - looped[2] = self.now + if looped.last + looped.interval < self.now: + looped.func(*looped.args, **looped.kwargs) + looped.last = self.now def advance_time_msec(self, ms: float) -> None: self.advance_time(ms / 1000.0) From 395d624e120643d51fd4693346bafce40fe6a48a Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 11 Jun 2022 19:45:35 +0100 Subject: [PATCH 10/14] Suppress checking of a mock --- tests/utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index d3da559200e1..abdc0eaaddb8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -221,10 +221,15 @@ def default_config( return config_dict -def mock_getRawHeaders(headers=None): +def mock_getRawHeaders(headers=None): # type: ignore[no-untyped-def] headers = headers if headers is not None else {} - def getRawHeaders(name, default=None): + def getRawHeaders(name, default=None): # type: ignore[no-untyped-def] + # If the requested header is present, the real twisted function returns + # List[str] if name is a str and List[bytes] if name is a bytes. + # This mock doesn't support that behaviour. + # Fortunately, none of the current callers of mock_getRawHeaders() provide a + # headers dict, so we don't encounter this discrepancy in practice. return headers.get(name, default) return getRawHeaders From 97d72ef11a938d6cbd4fc44677f516d893a9ec19 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 11 Jun 2022 19:45:49 +0100 Subject: [PATCH 11/14] tests.utils is typed --- mypy.ini | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mypy.ini b/mypy.ini index 7973f2ac01d1..6590f0c25f63 100644 --- a/mypy.ini +++ b/mypy.ini @@ -75,7 +75,6 @@ exclude = (?x) |tests/util/test_lrucache.py |tests/util/test_rwlock.py |tests/util/test_wheel_timer.py - |tests/utils.py )$ [mypy-synapse.federation.transport.client] @@ -129,6 +128,9 @@ disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] disallow_untyped_defs = True +[mypy-tests.utils] +disallow_untyped_defs = True + ;; Dependencies without annotations ;; Before ignoring a module, check to see if type stubs are available. From 92af1306dbd83aab229f44d3e20a8388912896a5 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 11 Jun 2022 19:58:14 +0100 Subject: [PATCH 12/14] Changelog --- changelog.d/13028.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/13028.misc diff --git a/changelog.d/13028.misc b/changelog.d/13028.misc new file mode 100644 index 000000000000..4e5f3d8f9104 --- /dev/null +++ b/changelog.d/13028.misc @@ -0,0 +1 @@ +Add type annotations to `tests.utils`. From 61076113c5dc6c434a66b59ee257b5638016d63e Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sun, 12 Jun 2022 19:05:39 +0100 Subject: [PATCH 13/14] Whoops, import ParamSpec from typing_extensions --- tests/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index abdc0eaaddb8..7ede75c22ed0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,7 +21,6 @@ Callable, Dict, List, - ParamSpec, Tuple, Union, cast, @@ -29,7 +28,7 @@ ) import attr -from typing_extensions import Literal +from typing_extensions import Literal, ParamSpec from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions From cea7b0a9f45ca8213710bb4b37f33a40383d5bc1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 4 Jul 2022 18:33:00 +0100 Subject: [PATCH 14/14] ditch the psycopg2 casts --- tests/utils.py | 65 ++++++++++++++++---------------------------------- 1 file changed, 21 insertions(+), 44 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index b2177eac7a4e..424cc4c2a0d6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,17 +15,7 @@ import atexit import os -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Tuple, - Union, - cast, - overload, -) +from typing import Any, Callable, Dict, List, Tuple, Union, overload import attr from typing_extensions import Literal, ParamSpec @@ -37,7 +27,7 @@ from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.engines import PostgresEngine, create_engine +from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database # set this to True to run the tests against postgres instead of sqlite. @@ -63,27 +53,20 @@ # the dbname we will connect to in order to create the base database. POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres" -if TYPE_CHECKING: - import psycopg2.extensions def setupdb() -> None: # If we're using PostgreSQL, set up the db once if USE_POSTGRES_FOR_TESTS: # create a PostgresEngine - db_engine = cast( - PostgresEngine, create_engine({"name": "psycopg2", "args": {}}) - ) + db_engine = create_engine({"name": "psycopg2", "args": {}}) # connect to postgres to create the base database. - db_conn = cast( - "psycopg2.extensions.connection", - db_engine.module.connect( - user=POSTGRES_USER, - host=POSTGRES_HOST, - port=POSTGRES_PORT, - password=POSTGRES_PASSWORD, - dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, - ), + db_conn = db_engine.module.connect( + user=POSTGRES_USER, + host=POSTGRES_HOST, + port=POSTGRES_PORT, + password=POSTGRES_PASSWORD, + dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) db_engine.attempt_to_set_autocommit(db_conn, autocommit=True) cur = db_conn.cursor() @@ -96,30 +79,24 @@ def setupdb() -> None: db_conn.close() # Set up in the db - db_conn = cast( - "psycopg2.extensions.connection", - db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - port=POSTGRES_PORT, - password=POSTGRES_PASSWORD, - ), + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + port=POSTGRES_PORT, + password=POSTGRES_PASSWORD, ) logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") prepare_database(logging_conn, db_engine, None) logging_conn.close() def _cleanup() -> None: - db_conn = cast( - "psycopg2.extensions.connection", - db_engine.module.connect( - user=POSTGRES_USER, - host=POSTGRES_HOST, - port=POSTGRES_PORT, - password=POSTGRES_PASSWORD, - dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, - ), + db_conn = db_engine.module.connect( + user=POSTGRES_USER, + host=POSTGRES_HOST, + port=POSTGRES_PORT, + password=POSTGRES_PASSWORD, + dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) db_engine.attempt_to_set_autocommit(db_conn, autocommit=True) cur = db_conn.cursor()