diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index cb79e3b15602..b350f57ccb4a 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -22,7 +22,7 @@ IncorrectDatabaseSetup, IsolationLevel, ) -from synapse.storage.types import Cursor, DBAPI2Module +from synapse.storage.types import Cursor if TYPE_CHECKING: from synapse.storage.database import LoggingDatabaseConnection @@ -35,9 +35,7 @@ class PostgresEngine( BaseDatabaseEngine[psycopg2.extensions.connection, psycopg2.extensions.cursor] ): def __init__(self, database_config: Mapping[str, Any]): - # Cast: mypy 1.0.0 doesn't seem to think that the module implements the protocol. - # AFAICS this is a false positive. - super().__init__(cast(DBAPI2Module, psycopg2), database_config) + super().__init__(psycopg2, database_config) psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) # Disables passing `bytes` to txn.execute, c.f. #6186. If you do diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 6df470dc41be..28751e89a5a5 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -15,10 +15,10 @@ import sqlite3 import struct import threading -from typing import TYPE_CHECKING, Any, List, Mapping, Optional, cast +from typing import TYPE_CHECKING, Any, List, Mapping, Optional from synapse.storage.engines import BaseDatabaseEngine -from synapse.storage.types import Cursor, DBAPI2Module +from synapse.storage.types import Cursor if TYPE_CHECKING: from synapse.storage.database import LoggingDatabaseConnection @@ -26,9 +26,7 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]): def __init__(self, database_config: Mapping[str, Any]): - # Cast: mypy 1.0.0 doesn't seem to think that the module implements the protocol. - # AFAICS this is a false positive. - super().__init__(cast(DBAPI2Module, sqlite3), database_config) + super().__init__(sqlite3, database_config) database = database_config.get("args", {}).get("database") self._is_in_memory = database in ( diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 0031df1e0649..77adaa485775 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -12,7 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. from types import TracebackType -from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import ( + Any, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, + Callable, +) from typing_extensions import Protocol @@ -112,15 +123,35 @@ class DBAPI2Module(Protocol): # extends from this hierarchy. See # https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#exceptions # https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE - Warning: Type[Exception] - Error: Type[Exception] + # + # Note: rather than + # x: T + # we write + # @property + # def x(self) -> T: ... + # which expresses that the protocol attribute `x` is read-only. The mypy docs + # https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected + # explain why this is necessary for safety. TL;DR: we shouldn't be able to write + # to `x`, only read from it. See also https://github.com/python/mypy/issues/6002 . + @property + def Warning(self) -> Type[Exception]: + ... + + @property + def Error(self) -> Type[Exception]: + ... # Errors are divided into `InterfaceError`s (something went wrong in the database # driver) and `DatabaseError`s (something went wrong in the database). These are # both subclasses of `Error`, but we can't currently express this in type # annotations due to https://github.com/python/mypy/issues/8397 - InterfaceError: Type[Exception] - DatabaseError: Type[Exception] + @property + def InterfaceError(self) -> Type[Exception]: + ... + + @property + def DatabaseError(self) -> Type[Exception]: + ... # Everything below is a subclass of `DatabaseError`. @@ -128,7 +159,9 @@ class DBAPI2Module(Protocol): # - An integer was too big for its data type. # - An invalid date time was provided. # - A string contained a null code point. - DataError: Type[Exception] + @property + def DataError(self) -> Type[Exception]: + ... # Roughly: something went wrong in the database, but it's not within the application # programmer's control. Examples: @@ -138,28 +171,45 @@ class DBAPI2Module(Protocol): # - A serialisation failure occurred. # - The database ran out of resources, such as storage, memory, connections, etc. # - The database encountered an error from the operating system. - OperationalError: Type[Exception] + @property + def OperationalError(self) -> Type[Exception]: + ... # Roughly: we've given the database data which breaks a rule we asked it to enforce. # Examples: # - Stop, criminal scum! You violated the foreign key constraint # - Also check constraints, non-null constraints, etc. - IntegrityError: Type[Exception] + @property + def IntegrityError(self) -> Type[Exception]: + ... # Roughly: something went wrong within the database server itself. - InternalError: Type[Exception] + @property + def InternalError(self) -> Type[Exception]: + ... # Roughly: the application did something silly that needs to be fixed. Examples: # - We don't have permissions to do something. # - We tried to create a table with duplicate column names. # - We tried to use a reserved name. # - We referred to a column that doesn't exist. - ProgrammingError: Type[Exception] + @property + def ProgrammingError(self) -> Type[Exception]: + ... # Roughly: we've tried to do something that this database doesn't support. - NotSupportedError: Type[Exception] + @property + def NotSupportedError(self) -> Type[Exception]: + ... - def connect(self, **parameters: object) -> Connection: + # We originally wrote + # def connect(self, *args, **kwargs) -> Connection: ... + # But mypy doesn't seem to like that because sqlite3.connect takes a mandatory + # positional argument. We can't make that part of the signature though, because + # psycopg2.connect doesn't have a mandatory positional argument. Instead, we use + # the following slightly unusual workaround. + @property + def connect(self) -> Callable[..., Connection]: ...