Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Type tests.utils #13028

Merged
merged 15 commits into from
Jul 5, 2022
1 change: 1 addition & 0 deletions changelog.d/13028.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type annotations to `tests.utils`.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client]
disallow_untyped_defs = True

[mypy-tests.utils]
disallow_untyped_defs = True


;; Dependencies without annotations
;; Before ignoring a module, check to see if type stubs are available.
Expand Down
6 changes: 5 additions & 1 deletion synapse/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import attr
from frozendict import frozendict
from matrix_common.versionstring import get_distribution_version_string
from typing_extensions import ParamSpec

from twisted.internet import defer, task
from twisted.internet.defer import Deferred
Expand Down Expand Up @@ -82,6 +83,9 @@ def unwrapFirstError(failure: Failure) -> Failure:
return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations


P = ParamSpec("P")


@attr.s(slots=True)
class Clock:
"""
Expand Down Expand Up @@ -110,7 +114,7 @@ def time_msec(self) -> int:
return int(self.time() * 1000)

def looping_call(
self, f: Callable, msec: float, *args: Any, **kwargs: Any
self, f: Callable[P, object], msec: float, *args: P.args, **kwargs: P.kwargs
) -> LoopingCall:
"""Call a function repeatedly.

Expand Down
2 changes: 1 addition & 1 deletion synapse/util/caches/lrucache.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def update_last_access(self, clock: Clock) -> None:

@wrap_as_background_process("LruCache._expire_old_entries")
async def _expire_old_entries(
clock: Clock, expiry_seconds: int, autotune_config: Optional[dict]
clock: Clock, expiry_seconds: float, autotune_config: Optional[dict]
) -> None:
"""Walks the global cache list to find cache entries that haven't been
accessed in the given number of seconds, or if a given memory threshold has been breached.
Expand Down
136 changes: 92 additions & 44 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@

import atexit
import os
from typing import Any, Callable, Dict, List, Tuple, Union, overload

import attr
from typing_extensions import Literal, ParamSpec

from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
Expand Down Expand Up @@ -50,12 +55,11 @@
POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"


def setupdb():
def setupdb() -> None:
# If we're using PostgreSQL, set up the db once
if USE_POSTGRES_FOR_TESTS:
# create a PostgresEngine
db_engine = create_engine({"name": "psycopg2", "args": {}})

# connect to postgres to create the base database.
db_conn = db_engine.module.connect(
user=POSTGRES_USER,
Expand All @@ -82,11 +86,11 @@ def setupdb():
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
)
db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
prepare_database(db_conn, db_engine, None)
db_conn.close()
logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
prepare_database(logging_conn, db_engine, None)
logging_conn.close()

def _cleanup():
def _cleanup() -> None:
db_conn = db_engine.module.connect(
user=POSTGRES_USER,
host=POSTGRES_HOST,
Expand All @@ -103,7 +107,19 @@ def _cleanup():
atexit.register(_cleanup)


def default_config(name, parse=False):
@overload
def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]:
...


@overload
def default_config(name: str, parse: Literal[True]) -> HomeServerConfig:
...


def default_config(
name: str, parse: bool = False
) -> Union[Dict[str, object], HomeServerConfig]:
"""
Create a reasonable test config.
"""
Expand Down Expand Up @@ -181,90 +197,122 @@ def default_config(name, parse=False):
return config_dict


def mock_getRawHeaders(headers=None):
def mock_getRawHeaders(headers=None): # type: ignore[no-untyped-def]
headers = headers if headers is not None else {}

def getRawHeaders(name, default=None):
def getRawHeaders(name, default=None): # type: ignore[no-untyped-def]
# If the requested header is present, the real twisted function returns
# List[str] if name is a str and List[bytes] if name is a bytes.
# This mock doesn't support that behaviour.
# Fortunately, none of the current callers of mock_getRawHeaders() provide a
# headers dict, so we don't encounter this discrepancy in practice.
return headers.get(name, default)

return getRawHeaders


P = ParamSpec("P")


@attr.s(slots=True, auto_attribs=True)
class Timer:
absolute_time: float
callback: Callable[[], None]
expired: bool


# TODO: Make this generic over a ParamSpec?
@attr.s(slots=True, auto_attribs=True)
class Looper:
func: Callable[..., Any]
interval: float # seconds
last: float
args: Tuple[object, ...]
kwargs: Dict[str, object]


class MockClock:
now = 1000
now = 1000.0

def __init__(self):
# list of lists of [absolute_time, callback, expired] in no particular
# order
self.timers = []
self.loopers = []
def __init__(self) -> None:
# Timers in no particular order
self.timers: List[Timer] = []
self.loopers: List[Looper] = []

def time(self):
def time(self) -> float:
return self.now

def time_msec(self):
return self.time() * 1000
def time_msec(self) -> int:
return int(self.time() * 1000)

def call_later(self, delay, callback, *args, **kwargs):
def call_later(
self,
delay: float,
callback: Callable[P, object],
*args: P.args,
**kwargs: P.kwargs,
) -> Timer:
ctx = current_context()

def wrapped_callback():
def wrapped_callback() -> None:
set_current_context(ctx)
callback(*args, **kwargs)

t = [self.now + delay, wrapped_callback, False]
t = Timer(self.now + delay, wrapped_callback, False)
self.timers.append(t)

return t

def looping_call(self, function, interval, *args, **kwargs):
self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])

def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
def looping_call(
self,
function: Callable[P, object],
interval: float,
*args: P.args,
**kwargs: P.kwargs,
) -> None:
# This type-ignore should be redundant once we use a mypy release with
# https://github.com/python/mypy/pull/12668.
self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) # type: ignore[arg-type]

def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None:
if timer.expired:
if not ignore_errs:
raise Exception("Cannot cancel an expired timer")

timer[2] = True
timer.expired = True
self.timers = [t for t in self.timers if t != timer]

# For unit testing
def advance_time(self, secs):
def advance_time(self, secs: float) -> None:
self.now += secs

timers = self.timers
self.timers = []

for t in timers:
time, callback, expired = t

if expired:
if t.expired:
raise Exception("Timer already expired")

if self.now >= time:
t[2] = True
callback()
if self.now >= t.absolute_time:
t.expired = True
t.callback()
else:
self.timers.append(t)

for looped in self.loopers:
func, interval, last, args, kwargs = looped
if last + interval < self.now:
func(*args, **kwargs)
looped[2] = self.now
if looped.last + looped.interval < self.now:
looped.func(*looped.args, **looped.kwargs)
looped.last = self.now

def advance_time_msec(self, ms):
def advance_time_msec(self, ms: float) -> None:
self.advance_time(ms / 1000.0)

def time_bound_deferred(self, d, *args, **kwargs):
# We don't bother timing things out for now.
return d


async def create_room(hs, room_id: str, creator_id: str):
async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
"""Creates and persist a creation event for the given room"""

persistence_store = hs.get_storage_controllers().persistence
assert persistence_store is not None
store = hs.get_datastores().main
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
Expand Down