Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert account data, device inbox, and censor events databases to async/await #8063

Merged
merged 4 commits into from
Aug 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8063.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
77 changes: 43 additions & 34 deletions synapse/storage/databases/main/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@

import abc
import logging
from typing import List, Tuple
from typing import List, Optional, Tuple

from twisted.internet import defer

from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.stream_change_cache import StreamChangeCache

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -97,13 +98,15 @@ def get_account_data_for_user_txn(txn):
"get_account_data_for_user", get_account_data_for_user_txn
)

@cachedInlineCallbacks(num_args=2, max_entries=5000)
def get_global_account_data_by_type_for_user(self, data_type, user_id):
@cached(num_args=2, max_entries=5000)
async def get_global_account_data_by_type_for_user(
self, data_type: str, user_id: str
) -> Optional[JsonDict]:
"""
Returns:
Deferred: A dict
The account data.
"""
result = yield self.db_pool.simple_select_one_onecol(
result = await self.db_pool.simple_select_one_onecol(
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": data_type},
retcol="content",
Expand Down Expand Up @@ -280,9 +283,11 @@ def get_updated_account_data_for_user_txn(txn):
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)

@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
ignored_account_data = yield self.get_global_account_data_by_type_for_user(
@cached(num_args=2, cache_context=True, max_entries=5000)
async def is_ignored_by(
self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
) -> bool:
ignored_account_data = await self.get_global_account_data_by_type_for_user(
"m.ignored_user_list",
ignorer_user_id,
on_invalidate=cache_context.invalidate,
Expand All @@ -307,32 +312,35 @@ def __init__(self, database: DatabasePool, db_conn, hs):

super(AccountDataStore, self).__init__(database, db_conn, hs)

def get_max_account_data_stream_id(self):
def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream id for the private user data stream

Returns:
A deferred int.
The maximum stream ID.
"""
return self._account_data_id_gen.get_current_token()

@defer.inlineCallbacks
def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
"""Add some account_data to a room for a user.

Args:
user_id(str): The user to add a tag for.
room_id(str): The room to add a tag for.
account_data_type(str): The type of account_data to add.
content(dict): A json object to associate with the tag.
user_id: The user to add a tag for.
room_id: The room to add a tag for.
account_data_type: The type of account_data to add.
content: A json object to associate with the tag.

Returns:
A deferred that completes once the account_data has been added.
The maximum stream ID.
"""
content_json = json_encoder.encode(content)

with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
yield self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
desc="add_room_account_data",
table="room_account_data",
keyvalues={
Expand All @@ -350,7 +358,7 @@ def add_account_data_to_room(self, user_id, room_id, account_data_type, content)
# doesn't sound any worse than the whole update getting lost,
# which is what would happen if we combined the two into one
# transaction.
yield self._update_max_stream_id(next_id)
await self._update_max_stream_id(next_id)

self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
Expand All @@ -359,26 +367,28 @@ def add_account_data_to_room(self, user_id, room_id, account_data_type, content)
(user_id, room_id, account_data_type), content
)

result = self._account_data_id_gen.get_current_token()
return result
return self._account_data_id_gen.get_current_token()

@defer.inlineCallbacks
def add_account_data_for_user(self, user_id, account_data_type, content):
async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict
) -> int:
"""Add some account_data to a room for a user.

Args:
user_id(str): The user to add a tag for.
account_data_type(str): The type of account_data to add.
content(dict): A json object to associate with the tag.
user_id: The user to add a tag for.
account_data_type: The type of account_data to add.
content: A json object to associate with the tag.

Returns:
A deferred that completes once the account_data has been added.
The maximum stream ID.
"""
content_json = json_encoder.encode(content)

with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
yield self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
desc="add_user_account_data",
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
Expand All @@ -396,22 +406,21 @@ def add_account_data_for_user(self, user_id, account_data_type, content):
# Note: This is only here for backwards compat to allow admins to
# roll back to a previous Synapse version. Next time we update the
# database version we can remove this table.
yield self._update_max_stream_id(next_id)
await self._update_max_stream_id(next_id)

self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.invalidate(
(account_data_type, user_id)
)

result = self._account_data_id_gen.get_current_token()
return result
return self._account_data_id_gen.get_current_token()

def _update_max_stream_id(self, next_id):
def _update_max_stream_id(self, next_id: int):
"""Update the max stream_id

Args:
next_id(int): The the revision to advance to.
next_id: The the revision to advance to.
"""

# Note: This is only here for backwards compat to allow admins to
Expand Down
11 changes: 4 additions & 7 deletions synapse/storage/databases/main/censor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import logging
from typing import TYPE_CHECKING

from twisted.internet import defer

from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
Expand Down Expand Up @@ -148,17 +146,16 @@ def _censor_event_txn(self, txn, event_id, pruned_json):
updatevalues={"json": pruned_json},
)

@defer.inlineCallbacks
def expire_event(self, event_id):
async def expire_event(self, event_id: str) -> None:
"""Retrieve and expire an event that has expired, and delete its associated
expiry timestamp. If the event can't be retrieved, delete its associated
timestamp so we don't try to expire it again in the future.

Args:
event_id (str): The ID of the event to delete.
event_id: The ID of the event to delete.
"""
# Try to retrieve the event's content from the database or the event cache.
event = yield self.get_event(event_id)
event = await self.get_event(event_id)

def delete_expired_event_txn(txn):
# Delete the expiry timestamp associated with this event from the database.
Expand Down Expand Up @@ -193,7 +190,7 @@ def delete_expired_event_txn(txn):
txn, "_get_event_cache", (event.event_id,)
)

yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_expired_event", delete_expired_event_txn
)

Expand Down
Loading