This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve datastore type hints #14678

Open · wants to merge 10 commits into base: develop
1 change: 1 addition & 0 deletions changelog.d/14678.misc
@@ -0,0 +1 @@
Improve type hints in datastores.
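
For context: the PR swaps mixin inheritance of RelationsWorkerStore for an explicitly typed attribute, so mypy can resolve relations methods statically instead of through dynamic attribute lookup on a large mixin stack. A minimal sketch of the pattern, using simplified hypothetical stand-ins rather than the real Synapse classes and signatures:

# Sketch only: hypothetical stand-ins for the real Synapse store classes.
class RelationsWorkerStore:
    async def get_thread_id(self, event_id: str) -> str:
        return "main"  # placeholder result

# Before: relations methods arrived via mixin inheritance, which is hard
# for mypy to pin down across the main/worker/admin store variants.
class StoreBefore(RelationsWorkerStore):
    pass

# After: the sub-store is a plainly typed attribute set in __init__,
# mirroring the "self.relations = RelationsWorkerStore(...)" lines below.
class StoreAfter:
    def __init__(self) -> None:
        # This is a bit repetitive, but avoids dynamically setting attributes.
        self.relations = RelationsWorkerStore()

# Callers then reach the methods through the attribute:
#     thread_id = await store.relations.get_thread_id(event_id)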
23 changes: 21 additions & 2 deletions synapse/_scripts/synapse_port_db.py
@@ -24,6 +24,7 @@
import traceback
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
@@ -53,7 +54,12 @@
run_in_background,
)
from synapse.notifier import ReplicationNotifier
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_conn,
)
from synapse.storage.databases.main import FilteringWorkerStore, PushRuleStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
@@ -94,6 +100,9 @@
from synapse.types import ISynapseReactor
from synapse.util import SYNAPSE_VERSION, Clock

if TYPE_CHECKING:
from synapse.server import HomeServer

# Cast safety: Twisted does some naughty magic which replaces the
# twisted.internet.reactor module with a Reactor instance at runtime.
reactor = cast(ISynapseReactor, reactor_)
@@ -238,8 +247,18 @@ class Store(
PusherBackgroundUpdatesStore,
PresenceBackgroundUpdateStore,
ReceiptsBackgroundUpdateStore,
RelationsWorkerStore,
):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

# This is a bit repetitive, but avoids dynamically setting attributes.
self.relations = RelationsWorkerStore(database, db_conn, hs)

def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)

2 changes: 1 addition & 1 deletion synapse/api/filtering.py
@@ -507,7 +507,7 @@ async def _check_event_relations(
# The event IDs to check, mypy doesn't understand the isinstance check.
event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
event_ids_to_keep = set(
await self._store.events_have_relations(
await self._store.relations.events_have_relations(
event_ids, self.related_by_senders, self.related_by_rel_types
)
)
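The same mechanical call-site migration repeats through the files below. A hedged sketch of the pattern, wrapped in a hypothetical helper (filter_related and its parameters are illustrative, not part of the PR):

from typing import Any, Collection, Optional, Set

async def filter_related(
    store: Any,
    event_ids: Collection[str],
    senders: Optional[Collection[str]],
    rel_types: Optional[Collection[str]],
) -> Set[str]:
    # Before this PR, the method was inherited into the combined store:
    #     await store.events_have_relations(event_ids, senders, rel_types)
    # After, the call goes through the explicitly typed sub-store, so mypy
    # can check it against RelationsWorkerStore's signature:
    return set(
        await store.relations.events_have_relations(event_ids, senders, rel_types)
    )
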
4 changes: 3 additions & 1 deletion synapse/app/admin_cmd.py
@@ -75,7 +75,6 @@ class AdminCmdStore(
ApplicationServiceTransactionWorkerStore,
ApplicationServiceWorkerStore,
RoomMemberWorkerStore,
RelationsWorkerStore,
EventFederationWorkerStore,
EventPushActionsWorkerStore,
StateGroupWorkerStore,
@@ -101,6 +100,9 @@ def __init__(
# should refactor it to take a `Clock` directly.
self.clock = hs.get_clock()

# This is a bit repetitive, but avoids dynamically setting attributes.
self.relations = RelationsWorkerStore(database, db_conn, hs)


class AdminCmdServer(HomeServer):
DATASTORE_CLASS = AdminCmdStore # type: ignore
13 changes: 12 additions & 1 deletion synapse/app/generic_worker.py
@@ -51,6 +51,7 @@
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.rest.well_known import well_known_resource
from synapse.server import HomeServer
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.appservice import (
ApplicationServiceTransactionWorkerStore,
@@ -132,7 +133,6 @@ class GenericWorkerStore(
ServerMetricsStore,
PusherWorkerStore,
RoomMemberWorkerStore,
RelationsWorkerStore,
EventFederationWorkerStore,
EventPushActionsWorkerStore,
StateGroupWorkerStore,
@@ -152,6 +152,17 @@
server_name: str
config: HomeServerConfig

def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

# This is a bit repetitive, but avoids dynamically setting attributes.
self.relations = RelationsWorkerStore(database, db_conn, hs)


class GenericWorkerServer(HomeServer):
DATASTORE_CLASS = GenericWorkerStore # type: ignore
8 changes: 5 additions & 3 deletions synapse/handlers/message.py
@@ -1361,7 +1361,9 @@ async def _validate_event_relation(self, event: EventBase) -> None:
else:
# There must be some reason that the client knows the event exists,
# see if there are existing relations. If so, assume everything is fine.
if not await self.store.event_is_target_of_relation(relation.parent_id):
if not await self.store.relations.event_is_target_of_relation(
relation.parent_id
):
# Otherwise, the client can't know about the parent event!
raise SynapseError(400, "Can't send relation to unknown event")

@@ -1377,7 +1379,7 @@ async def _validate_event_relation(self, event: EventBase) -> None:
if len(aggregation_key) > 500:
raise SynapseError(400, "Aggregation key is too long")

already_exists = await self.store.has_user_annotated_event(
already_exists = await self.store.relations.has_user_annotated_event(
relation.parent_id, event.type, aggregation_key, event.sender
)
if already_exists:
@@ -1389,7 +1391,7 @@ async def _validate_event_relation(self, event: EventBase) -> None:

# Don't attempt to start a thread if the parent event is a relation.
elif relation.rel_type == RelationTypes.THREAD:
if await self.store.event_includes_relation(relation.parent_id):
if await self.store.relations.event_includes_relation(relation.parent_id):
raise SynapseError(
400, "Cannot start threads from an event with a relation"
)
32 changes: 21 additions & 11 deletions synapse/handlers/relations.py
@@ -124,7 +124,10 @@ async def get_relations(
# Note that ignored users are not passed into get_relations_for_event
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
related_events, next_token = await self._main_store.get_relations_for_event(
(
related_events,
next_token,
) = await self._main_store.relations.get_relations_for_event(
event_id=event_id,
event=event,
room_id=room_id,
@@ -211,7 +214,7 @@ async def redact_events_related_to(
ShadowBanError if the requester is shadow-banned
"""
related_event_ids = (
await self._main_store.get_all_relations_for_event_with_types(
await self._main_store.relations.get_all_relations_for_event_with_types(
event_id, relation_types
)
)
@@ -250,7 +253,9 @@ async def get_references_for_events(
A map of event IDs to a list of related events.
"""

related_events = await self._main_store.get_references_for_events(event_ids)
related_events = await self._main_store.relations.get_references_for_events(
event_ids
)

# Avoid additional logic if there are no ignored users.
if not ignored_users:
@@ -304,7 +309,7 @@ async def _get_threads_for_events(
event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]

# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(event_ids)
summaries = await self._main_store.relations.get_thread_summaries(event_ids)

# Limit fetching whether the requester has participated in a thread to
# events which are thread roots.
@@ -320,7 +325,7 @@
# For events the requester did not send, check the database for whether
# the requester sent a threaded reply.
participated.update(
await self._main_store.get_threads_participated(
await self._main_store.relations.get_threads_participated(
[
event_id
for event_id in thread_event_ids
@@ -331,8 +336,10 @@
)

# Then subtract off the results for any ignored users.
ignored_results = await self._main_store.get_threaded_messages_per_user(
thread_event_ids, ignored_users
ignored_results = (
await self._main_store.relations.get_threaded_messages_per_user(
thread_event_ids, ignored_users
)
)

# A map of event ID to the thread aggregation.
@@ -361,7 +368,10 @@ async def _get_threads_for_events(
continue

# Attempt to find another event to use as the latest event.
potential_events, _ = await self._main_store.get_relations_for_event(
(
potential_events,
_,
) = await self._main_store.relations.get_relations_for_event(
event_id,
event,
room_id,
@@ -498,7 +508,7 @@ async def _fetch_edits() -> None:
Note that there is no use in limiting edits by ignored users since the
parent event should be ignored in the first place if the user is ignored.
"""
edits = await self._main_store.get_applicable_edits(
edits = await self._main_store.relations.get_applicable_edits(
[
event_id
for event_id, event in events_by_id.items()
@@ -553,7 +563,7 @@ async def get_threads(
# Note that ignored users are not passed into get_threads
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
thread_roots, next_batch = await self._main_store.get_threads(
thread_roots, next_batch = await self._main_store.relations.get_threads(
room_id=room_id, limit=limit, from_token=from_token
)

@@ -565,7 +575,7 @@
# For events the requester did not send, check the database for whether
# the requester sent a threaded reply.
participated.update(
await self._main_store.get_threads_participated(
await self._main_store.relations.get_threads_participated(
[eid for eid, p in participated.items() if not p],
user_id,
)
2 changes: 1 addition & 1 deletion synapse/push/bulk_push_rule_evaluator.py
@@ -368,7 +368,7 @@ async def _action_for_event_by_user(
else:
# Since the event has not yet been persisted we check whether
# the parent is part of a thread.
thread_id = await self.store.get_thread_id(relation.parent_id)
thread_id = await self.store.relations.get_thread_id(relation.parent_id)

related_events = await self._related_events(event)

8 changes: 6 additions & 2 deletions synapse/rest/client/receipts.py
@@ -147,11 +147,15 @@ async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool:
# If the receipt is on the main timeline, it is enough to check whether
# the event is directly related to a thread.
if thread_id == MAIN_TIMELINE:
return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id)
return MAIN_TIMELINE == await self._main_store.relations.get_thread_id(
event_id
)

# Otherwise, check if the event is directly part of a thread, or is the
# root message (or related to the root message) of a thread.
return thread_id == await self._main_store.get_thread_id_for_receipts(event_id)
return thread_id == await self._main_store.relations.get_thread_id_for_receipts(
event_id
)


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
18 changes: 16 additions & 2 deletions synapse/storage/_base.py
@@ -118,7 +118,10 @@ def _invalidate_state_caches(
self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))

def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
self,
cache_name: str,
key: Optional[Collection[Any]],
store_name: Optional[str] = None,
) -> bool:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
@@ -132,10 +135,21 @@ def _attempt_to_invalidate_cache(
cache_name
key: Entry to invalidate. If None then invalidates the entire
cache.
store_name: The name of the store, leave as None for stores which
have not yet been split out.
"""

# First get the store.
store = self
if store_name is not None:
try:
store = getattr(self, store_name)
except AttributeError:
pass
Comment on lines +145 to +148 (Contributor): Should we be ignoring store_names that refer to missing fields here?

# Then attempt to find the cache on that store.
try:
cache = getattr(self, cache_name)
cache = getattr(store, cache_name)
except AttributeError:
# Check if an externally defined module cache has been registered
cache = self.external_cached_functions.get(cache_name)
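Roughly, the new store_name parameter makes _attempt_to_invalidate_cache resolve the owning sub-store first and then look the cache up there. A standalone sketch of that resolution logic, with hypothetical helper names and a simplified cache API (not the real Synapse signatures):

from typing import Any, Collection, Optional

def attempt_to_invalidate_cache(
    base_store: Any,
    cache_name: str,
    key: Optional[Collection[Any]],
    store_name: Optional[str] = None,
) -> bool:
    # Resolve the owning store: the base store itself, or a split-out
    # sub-store such as base_store.relations when store_name="relations".
    store = base_store
    if store_name is not None:
        try:
            store = getattr(base_store, store_name)
        except AttributeError:
            # Per the review question above: a missing sub-store is
            # currently ignored, falling back to the base store.
            pass

    # Then attempt to find the cache on that store.
    try:
        cache = getattr(store, cache_name)
    except AttributeError:
        return False  # no such cache; nothing to invalidate

    if key is None:
        cache.invalidate_all()
    else:
        cache.invalidate(tuple(key))
    return True
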
4 changes: 3 additions & 1 deletion synapse/storage/databases/main/__init__.py
@@ -121,7 +121,6 @@ class DataStore(
UserErasureStore,
MonthlyActiveUsersWorkerStore,
StatsStore,
RelationsStore,
CensorEventsStore,
UIAuthStore,
EventForwardExtremitiesStore,
@@ -141,6 +140,9 @@ def __init__(

super().__init__(database, db_conn, hs)

# This is a bit repetitive, but avoids dynamically setting attributes.
self.relations = RelationsStore(database, db_conn, hs)

async def get_users(self) -> List[JsonDict]:
"""Function to retrieve a list of users in users table.

36 changes: 26 additions & 10 deletions synapse/storage/databases/main/cache.py
@@ -248,10 +248,16 @@ def _invalidate_caches_for_event(
self._invalidate_local_get_event_cache(redacts) # type: ignore[attr-defined]
# Caches which might leak edits must be invalidated for the event being
# redacted.
self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
self._attempt_to_invalidate_cache(
"get_relations_for_event", (redacts,), "relations"
)
self._attempt_to_invalidate_cache(
"get_applicable_edit", (redacts,), "relations"
)
self._attempt_to_invalidate_cache("get_thread_id", (redacts,), "relations")
self._attempt_to_invalidate_cache(
"get_thread_id_for_receipts", (redacts,), "relations"
)

if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) # type: ignore[attr-defined]
@@ -264,12 +270,22 @@
self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,))

if relates_to:
self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
self._attempt_to_invalidate_cache("get_threads", (room_id,))
self._attempt_to_invalidate_cache(
"get_relations_for_event", (relates_to,), "relations"
)
self._attempt_to_invalidate_cache(
"get_references_for_event", (relates_to,), "relations"
)
self._attempt_to_invalidate_cache(
"get_applicable_edit", (relates_to,), "relations"
)
self._attempt_to_invalidate_cache(
"get_thread_summary", (relates_to,), "relations"
)
self._attempt_to_invalidate_cache(
"get_thread_participated", (relates_to,), "relations"
)
self._attempt_to_invalidate_cache("get_threads", (room_id,), "relations")

async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]