Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Detail typing of Connection and Cursor to further match PEP 249 #9299

Merged
merged 9 commits into from
Feb 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9299.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update the `Cursor` type hints to better match PEP 249.
14 changes: 9 additions & 5 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def close(self) -> None:
def commit(self) -> None:
self.conn.commit()

def rollback(self, *args, **kwargs) -> None:
self.conn.rollback(*args, **kwargs)
def rollback(self) -> None:
self.conn.rollback()

def __enter__(self) -> "Connection":
self.conn.__enter__()
Expand Down Expand Up @@ -244,12 +244,15 @@ def call_on_exception(
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))

def fetchone(self) -> Optional[Tuple]:
return self.txn.fetchone()

def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
return self.txn.fetchmany(size=size)

def fetchall(self) -> List[Tuple]:
return self.txn.fetchall()

def fetchone(self) -> Tuple:
return self.txn.fetchone()

def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()

Expand Down Expand Up @@ -754,6 +757,7 @@ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
Returns:
A list of dicts where the key is the column header.
"""
assert cursor.description is not None, "cursor.description was None"
col_headers = [intern(str(column[0])) for column in cursor.description]
results = [dict(zip(col_headers, row)) for row in cursor]
return results
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/prepare_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,9 +619,9 @@ def _get_or_create_schema_state(

txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
current_version = int(row[0]) if row else None

if current_version:
if row is not None:
current_version = int(row[0])
txn.execute(
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
(current_version,),
Expand Down
37 changes: 29 additions & 8 deletions synapse/storage/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,52 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable, Iterator, List, Optional, Tuple
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union

from typing_extensions import Protocol

"""
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""

_Parameters = Union[Sequence[Any], Mapping[str, Any]]


class Cursor(Protocol):
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
...

def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
...

def fetchall(self) -> List[Tuple]:
def fetchone(self) -> Optional[Tuple]:
...

def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
...

def fetchone(self) -> Tuple:
def fetchall(self) -> List[Tuple]:
...

@property
def description(self) -> Any:
return None
def description(
self,
) -> Optional[
Sequence[
# Note that this is an approximate typing based on sqlite3 and other
# drivers, and may not be entirely accurate.
Tuple[
str,
Optional[Any],
Optional[int],
Optional[int],
Optional[int],
Optional[int],
Optional[int],
]
]
]:
...

@property
def rowcount(self) -> int:
Expand All @@ -59,7 +80,7 @@ def close(self) -> None:
def commit(self) -> None:
...

def rollback(self, *args, **kwargs) -> None:
def rollback(self) -> None:
...

def __enter__(self) -> "Connection":
Expand Down
8 changes: 6 additions & 2 deletions synapse/storage/util/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ def __init__(self, sequence_name: str):

def get_next_id_txn(self, txn: Cursor) -> int:
txn.execute("SELECT nextval(?)", (self._sequence_name,))
return txn.fetchone()[0]
fetch_res = txn.fetchone()
assert fetch_res is not None
return fetch_res[0]

def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
txn.execute(
Expand Down Expand Up @@ -147,7 +149,9 @@ def check_consistency(
txn.execute(
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
)
last_value, is_called = txn.fetchone()
fetch_res = txn.fetchone()
assert fetch_res is not None
last_value, is_called = fetch_res

# If we have an associated stream check the stream_positions table.
max_in_stream_positions = None
Expand Down