diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 41f566b6487a..5ddb58a8a2ca 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -31,6 +31,7 @@ List, Optional, Tuple, + Type, TypeVar, cast, overload, @@ -41,6 +42,7 @@ from typing_extensions import Concatenate, Literal, ParamSpec from twisted.enterprise import adbapi +from twisted.internet.interfaces import IReactorCore from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig @@ -92,7 +94,9 @@ def make_pool( - reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + reactor: IReactorCore, + db_config: DatabaseConnectionConfig, + engine: BaseDatabaseEngine, ) -> adbapi.ConnectionPool: """Get the connection pool for the database.""" @@ -101,7 +105,7 @@ def make_pool( db_args = dict(db_config.config.get("args", {})) db_args.setdefault("cp_reconnect", True) - def _on_new_connection(conn): + def _on_new_connection(conn: Connection) -> None: # Ensure we have a logging context so we can correctly track queries, # etc. with LoggingContext("db.on_new_connection"): @@ -157,7 +161,11 @@ class LoggingDatabaseConnection: default_txn_name: str def cursor( - self, *, txn_name=None, after_callbacks=None, exception_callbacks=None + self, + *, + txn_name: Optional[str] = None, + after_callbacks: Optional[List["_CallbackListEntry"]] = None, + exception_callbacks: Optional[List["_CallbackListEntry"]] = None, ) -> "LoggingTransaction": if not txn_name: txn_name = self.default_txn_name @@ -183,11 +191,16 @@ def __enter__(self) -> "LoggingDatabaseConnection": self.conn.__enter__() return self - def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> Optional[bool]: return self.conn.__exit__(exc_type, exc_value, traceback) # Proxy through any unknown lookups to the DB conn class. - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: return getattr(self.conn, name) @@ -391,17 +404,22 @@ def close(self) -> None: def __enter__(self) -> "LoggingTransaction": return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> None: self.close() class PerformanceCounters: - def __init__(self): - self.current_counters = {} - self.previous_counters = {} + def __init__(self) -> None: + self.current_counters: Dict[str, Tuple[int, float]] = {} + self.previous_counters: Dict[str, Tuple[int, float]] = {} def update(self, key: str, duration_secs: float) -> None: - count, cum_time = self.current_counters.get(key, (0, 0)) + count, cum_time = self.current_counters.get(key, (0, 0.0)) count += 1 cum_time += duration_secs self.current_counters[key] = (count, cum_time) @@ -527,7 +545,7 @@ async def _check_safe_to_upsert(self) -> None: def start_profiling(self) -> None: self._previous_loop_ts = monotonic_time() - def loop(): + def loop() -> None: curr = self._current_txn_total_time prev = self._previous_txn_total_time self._previous_txn_total_time = curr @@ -1186,7 +1204,7 @@ def simple_upsert_txn_emulated( if lock: self.engine.lock_table(txn, table) - def _getwhere(key): + def _getwhere(key: str) -> str: # If the value we're passing in is None (aka NULL), we need to use # IS, not =, as NULL = NULL equals NULL (False). if keyvalues[key] is None: @@ -2258,7 +2276,7 @@ async def simple_search_list( term: Optional[str], col: str, retcols: Collection[str], - desc="simple_search_list", + desc: str = "simple_search_list", ) -> Optional[List[Dict[str, Any]]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts.