
Add some type hints to datastore #12255

Merged 6 commits on Mar 28, 2022. Changes shown from 4 commits.
1 change: 1 addition & 0 deletions changelog.d/12255.misc
@@ -0,0 +1 @@
Add missing type hints for storage.
3 changes: 0 additions & 3 deletions mypy.ini
@@ -38,14 +38,11 @@ exclude = (?x)
|synapse/_scripts/update_synapse_database.py

|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/__init__.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py
|synapse/storage/schema/

2 changes: 1 addition & 1 deletion synapse/storage/databases/main/media_repository.py
@@ -156,7 +156,7 @@ def __init__(
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
self.server_name: str = hs.hostname
Contributor:

Somewhat surprised this is necessary. What does mypy think the type of hs.hostname is? (E.g. try adding reveal_type(hs.hostname) and running mypy.) Similarly for self.config: HomeServerConfig below.

I'm not actually sure what the best practice is here: restate the annotation, or let the type propagate? Maybe @matrix-org/synapse-core has opinions.
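
For example, a throwaway check along these lines (a sketch, not part of the PR):

class HomeServer:
    hostname: str


def show_inferred_type(hs: HomeServer) -> None:
    reveal_type(hs.hostname)  # mypy prints: Revealed type is "builtins.str"
    # reveal_type() is special-cased by mypy; delete it before actually
    # running the code (it's a NameError at runtime before Python 3.11).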

Member:

I was a bit surprised by this too -- I'd double check that server_name is properly typed on HomeServer first! 👍

Member:

(And that we're importing the proper HomeServer object?)
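
For reference, these store modules take hs as a forward-referenced "HomeServer" and import it only under TYPE_CHECKING; a minimal sketch of that pattern (SomeStore is a hypothetical name):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for type checking, avoiding a circular import at runtime.
    from synapse.server import HomeServer


class SomeStore:
    def __init__(self, hs: "HomeServer") -> None:
        self.server_name: str = hs.hostname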

Contributor Author:

If you do not set this here, you will get: synapse/storage/databases/main/__init__.py:84: error: Cannot determine type of "server_name" in base class "MediaRepositoryStore" [misc]

Contributor Author:

Is there a better solution, or should this just be # type: ignore[misc]?

Contributor:

Here's what I think is going on here:

  • mypy knows that HomeServer.hostname is a str
  • when mypy processes self.server_name = hs.hostname, it looks at the superclasses to see whether self.server_name already exists, and if so checks that str is compatible with whatever type super().server_name has
  • the parent classes are MediaRepositoryBackgroundUpdateStore and SQLBaseStore, neither of which has a server_name attribute. Mypy doesn't like this.
  • I think by writing self.server_name: str we are explicitly telling mypy that we want to create a new attribute on the child class?

All in all, I think what you've done (self.server_name: str) is the right choice.
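
A contrived sketch of that conclusion (stand-in class names; in Synapse the error actually surfaced in the many-base DataStore in synapse/storage/databases/main/__init__.py):

class HomeServer:
    hostname: str


class SQLBaseStore:
    # Stand-in for the real base classes: note there is no server_name here.
    def __init__(self, hs: HomeServer) -> None:
        self.hs = hs


class MediaRepositoryStore(SQLBaseStore):
    def __init__(self, hs: HomeServer) -> None:
        super().__init__(hs)
        # The explicit annotation declares the attribute on this class, so
        # mypy no longer has to infer it from the assignment, and subclasses
        # that combine several stores see a determined type for server_name.
        self.server_name: str = hs.hostname


class DataStore(MediaRepositoryStore):
    # Stand-in for the subclass where the "Cannot determine type" error
    # was reported.
    pass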


async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
37 changes: 25 additions & 12 deletions synapse/storage/databases/main/receipts.py
@@ -24,10 +24,9 @@
Optional,
Set,
Tuple,
cast,
)

from twisted.internet import defer

from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
@@ -38,7 +37,11 @@
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -58,6 +61,7 @@ def __init__(
hs: "HomeServer",
):
self._instance_name = hs.get_instance_name()
self._receipts_id_gen: AbstractStreamIdTracker

if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = (
@@ -161,7 +165,7 @@ def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
" AND user_id = ?"
)
txn.execute(sql, (user_id,))
return txn.fetchall()
return cast(List[Tuple[str, str, int, int]], txn.fetchall())
Contributor:

Looking at the schema, events.stream_ordering is nullable. Are we sure the return type annotation is correct here? Should the final int maybe be Optional[int] instead?

(It could just be a sloppy schema?)
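
If the column really is nullable, the defensive version would presumably widen the cast, along these lines (sketch only, with a hypothetical helper name; whether receipt rows can actually have a NULL stream_ordering isn't settled in this thread):

from typing import Any, List, Optional, Tuple, cast


def fetch_orderings(
    txn: Any, sql: str, user_id: str
) -> List[Tuple[str, str, int, Optional[int]]]:
    # Same shape as the helper above, but the last tuple element is widened
    # to Optional[int] because events.stream_ordering is nullable. Note that
    # cast() is unchecked at runtime, so the annotation must match the schema.
    txn.execute(sql, (user_id,))
    return cast(List[Tuple[str, str, int, Optional[int]]], txn.fetchall())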


rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f
@@ -257,7 +261,7 @@ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if not rows:
return []

content = {}
content: JsonDict = {}
for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"]
@@ -305,7 +309,7 @@ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
"_get_linearized_receipts_for_rooms", f
)

results = {}
results: JsonDict = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
@@ -370,7 +374,7 @@ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
"get_linearized_receipts_for_all_rooms", f
)

results = {}
results: JsonDict = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
@@ -399,7 +403,7 @@ async def get_users_sent_receipts_between(
"""

if last_id == current_id:
return defer.succeed([])
return []

def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """
@@ -453,7 +457,10 @@ def get_all_updated_receipts_txn(
"""
txn.execute(sql, (last_id, current_id, limit))

updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
updates = cast(
List[Tuple[int, list]],
[(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
)

limited = False
upper_bound = current_id
@@ -496,7 +503,13 @@ def invalidate_caches_for_receipt(
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
for row in rows:
@@ -584,7 +597,7 @@ def insert_linearized_receipt_txn(
)

if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
self._remove_old_push_actions_before_txn( # type: ignore[attr-defined]
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)

@@ -637,7 +650,7 @@ def graph_to_linear(txn: LoggingTransaction) -> str:
"insert_receipt_conv", graph_to_linear
)

async with self._receipts_id_gen.get_next() as stream_id:
async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
3 changes: 2 additions & 1 deletion synapse/storage/databases/main/registration.py
@@ -22,6 +22,7 @@

from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
DatabasePool,
@@ -123,7 +124,7 @@ def __init__(
):
super().__init__(database, db_conn, hs)

self.config = hs.config
self.config: HomeServerConfig = hs.config

# Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is
3 changes: 2 additions & 1 deletion synapse/storage/databases/main/room.py
@@ -34,6 +34,7 @@
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -98,7 +99,7 @@ def __init__(
):
super().__init__(database, db_conn, hs)

self.config = hs.config
self.config: HomeServerConfig = hs.config

async def store_room(
self,
28 changes: 14 additions & 14 deletions synapse/storage/databases/main/search.py
@@ -14,7 +14,7 @@

import logging
import re
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set

import attr

@@ -74,7 +74,7 @@ def store_search_entries_txn(
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)

args = (
args1 = (
(
entry.event_id,
entry.room_id,
@@ -86,14 +86,14 @@
for entry in entries
)

txn.execute_batch(sql, args)
txn.execute_batch(sql, args1)

elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
args = (
args2 = (
(
entry.event_id,
entry.room_id,
@@ -102,7 +102,7 @@
)
for entry in entries
)
txn.execute_batch(sql, args)
txn.execute_batch(sql, args2)

else:
# This should be unreachable.
@@ -274,7 +274,7 @@ def create_index(conn):
# have an event_search_fts_idx; unfortunately postgres 9.4
# doesn't support CREATE INDEX IF EXISTS so we just catch the
# exception and ignore it.
import psycopg2
import psycopg2 # type: ignore

try:
c.execute(
@@ -427,7 +427,7 @@ async def search_msgs(

search_query = _parse_query(self.database_engine, search_term)

args = []
args: List[Any] = []

# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
@@ -496,7 +496,7 @@ async def search_msgs(

# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list(
events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
@@ -530,7 +530,7 @@ async def search_rooms(
room_ids: Collection[str],
search_term: str,
keys: Iterable[str],
limit,
limit: int,
pagination_token: Optional[str] = None,
) -> JsonDict:
"""Performs a full text search over events with given keys.
@@ -549,7 +549,7 @@

search_query = _parse_query(self.database_engine, search_term)

args = []
args: List[Any] = []

# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
@@ -573,9 +573,9 @@

if pagination_token:
try:
origin_server_ts, stream = pagination_token.split(",")
origin_server_ts = int(origin_server_ts)
stream = int(stream)
origin_server_ts_str, stream_str = pagination_token.split(",")
origin_server_ts = int(origin_server_ts_str)
stream = int(stream_str)
except Exception:
raise SynapseError(400, "Invalid pagination token")

@@ -654,7 +654,7 @@ async def search_rooms(

# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list(
events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
24 changes: 15 additions & 9 deletions synapse/storage/databases/main/state.py
@@ -14,7 +14,7 @@
# limitations under the License.
import collections.abc
import logging
from typing import TYPE_CHECKING, Iterable, Optional, Set
from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple

from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -29,7 +29,7 @@
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.types import JsonDict, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList

@@ -241,7 +241,9 @@ async def get_filtered_current_state_ids(
# We delegate to the cached version
return await self.get_current_state_ids(room_id)

def _get_filtered_current_state_ids_txn(txn):
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
) -> StateMap[str]:
results = {}
sql = """
SELECT type, state_key, event_id FROM current_state_events
@@ -281,11 +283,11 @@ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:

event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id:
return
return None

event = await self.get_event(event_id, allow_none=True)
if not event:
return
return None

return event.content.get("canonical_alias")

@@ -304,7 +306,7 @@ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
list_name="event_ids",
num_args=1,
)
async def _get_state_group_for_events(self, event_ids):
async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
"""Returns mapping event_id -> state_group"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
@@ -355,7 +357,7 @@ def __init__(
):
super().__init__(database, db_conn, hs)

self.server_name = hs.hostname
self.server_name: str = hs.hostname

self.db_pool.updates.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
@@ -375,15 +377,19 @@ def __init__(
self._background_remove_left_rooms,
)

async def _background_remove_left_rooms(self, progress, batch_size):
async def _background_remove_left_rooms(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to delete rows from `current_state_events` and
`event_forward_extremities` tables of rooms that the server is no
longer joined to.
"""

last_room_id = progress.get("last_room_id", "")

def _background_remove_left_rooms_txn(txn):
def _background_remove_left_rooms_txn(
txn: LoggingTransaction,
) -> Tuple[bool, Set[str]]:
# get a batch of room ids to consider
sql = """
SELECT DISTINCT room_id FROM current_state_events
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/stats.py
@@ -108,7 +108,7 @@ def __init__(
):
super().__init__(database, db_conn, hs)

self.server_name = hs.hostname
self.server_name: str = hs.hostname
self.clock = self.hs.get_clock()
self.stats_enabled = hs.config.stats.stats_enabled

2 changes: 1 addition & 1 deletion synapse/storage/databases/main/user_directory.py
@@ -68,7 +68,7 @@
) -> None:
super().__init__(database, db_conn, hs)

self.server_name = hs.hostname
self.server_name: str = hs.hostname

self.db_pool.updates.register_background_update_handler(
"populate_user_directory_createtables",