Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Continue converting store to async/await #8042

Merged
merged 11 commits into from
Aug 7, 2020
1 change: 1 addition & 0 deletions changelog.d/8042.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
12 changes: 7 additions & 5 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
master_key_by_user = {}
self_signing_key_by_user = {}
for user in users:
cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
cross_signing_key = yield defer.ensureDeferred(
self.get_e2e_cross_signing_key(user, "master")
)
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
Expand All @@ -150,8 +152,8 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
"device_id": verify_key.version,
}

cross_signing_key = yield self.get_e2e_cross_signing_key(
user, "self_signing"
cross_signing_key = yield defer.ensureDeferred(
self.get_e2e_cross_signing_key(user, "self_signing")
)
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
Expand Down Expand Up @@ -247,7 +249,7 @@ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_m
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
user_id/device_id to update stream_id and the relevent json-encoded
user_id/device_id to update stream_id and the relevant json-encoded
opentracing context

Returns:
Expand Down Expand Up @@ -600,7 +602,7 @@ async def get_all_device_list_changes_for_remotes(
between the requested tokens due to the limit.

The token returned can be used in a subsequent call to this
function to get further updatees.
function to get further updates.

The updates are a list of 2-tuples of stream ID and the row data
"""
Expand Down
49 changes: 24 additions & 25 deletions synapse/storage/databases/main/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,29 @@
# limitations under the License.

from collections import namedtuple
from typing import Optional

from twisted.internet import defer
from typing import Iterable, Optional

from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
from synapse.types import RoomAlias
from synapse.util.caches.descriptors import cached

RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))


class DirectoryWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias):
async def get_association_from_room_alias(
self, room_alias: RoomAlias
) -> RoomAliasMapping:
clokep marked this conversation as resolved.
Show resolved Hide resolved
""" Get's the room_id and server list for a given room_alias
clokep marked this conversation as resolved.
Show resolved Hide resolved

Args:
room_alias (RoomAlias)
room_alias: The alias to traverse to an ID.
clokep marked this conversation as resolved.
Show resolved Hide resolved

Returns:
Deferred: results in namedtuple with keys "room_id" and
"servers" or None if no association can be found
The room alias mapping or None if no association can be found.
"""
room_id = yield self.db_pool.simple_select_one_onecol(
room_id = await self.db_pool.simple_select_one_onecol(
"room_aliases",
{"room_alias": room_alias.to_string()},
"room_id",
Expand All @@ -48,7 +47,7 @@ def get_association_from_room_alias(self, room_alias):
if not room_id:
return None

servers = yield self.db_pool.simple_select_onecol(
servers = await self.db_pool.simple_select_onecol(
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
Expand Down Expand Up @@ -79,18 +78,20 @@ def get_aliases_for_room(self, room_id):


class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks
def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
async def create_room_alias_association(
self,
room_alias: RoomAlias,
room_id: str,
servers: Iterable[str],
creator: Optional[str] = None,
) -> None:
""" Creates an association between a room alias and room_id/servers

Args:
room_alias (RoomAlias)
room_id (str)
servers (list)
creator (str): Optional user_id of creator.

Returns:
Deferred
room_alias: The alias to create.
room_id: The target of the alias.
servers:
clokep marked this conversation as resolved.
Show resolved Hide resolved
creator: Optional user_id of creator.
"""

def alias_txn(txn):
Expand Down Expand Up @@ -118,24 +119,22 @@ def alias_txn(txn):
)

try:
ret = yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"create_room_alias_association", alias_txn
)
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
)
return ret

@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
room_id = yield self.db_pool.runInteraction(
async def delete_room_alias(self, room_alias: RoomAlias) -> str:
room_id = await self.db_pool.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
)

return room_id

def _delete_room_alias_txn(self, txn, room_alias):
def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str:
txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),),
Expand Down
30 changes: 14 additions & 16 deletions synapse/storage/databases/main/e2e_room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@

from canonicaljson import json

from twisted.internet import defer

from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json


class EndToEndRoomKeyStore(SQLBaseStore):
@defer.inlineCallbacks
def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
async def update_e2e_room_key(
self, user_id, version, room_id, session_id, room_key
):
"""Replaces the encrypted E2E room key for a given session in a given backup

Args:
Expand All @@ -38,7 +37,7 @@ def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
StoreError
"""

yield self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
Expand All @@ -55,8 +54,7 @@ def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
desc="update_e2e_room_key",
)

@defer.inlineCallbacks
def add_e2e_room_keys(self, user_id, version, room_keys):
async def add_e2e_room_keys(self, user_id, version, room_keys):
"""Bulk add room keys to a given backup.

Args:
Expand Down Expand Up @@ -89,13 +87,12 @@ def add_e2e_room_keys(self, user_id, version, room_keys):
}
)

yield self.db_pool.simple_insert_many(
await self.db_pool.simple_insert_many(
table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
)

@trace
@defer.inlineCallbacks
def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.

Expand All @@ -110,7 +107,7 @@ def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
the backup (or for the specified room)

Returns:
A deferred list of dicts giving the session_data and message metadata for
A list of dicts giving the session_data and message metadata for
these room keys.
"""

Expand All @@ -125,7 +122,7 @@ def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
if session_id:
keyvalues["session_id"] = session_id

rows = yield self.db_pool.simple_select_list(
rows = await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
Expand Down Expand Up @@ -243,8 +240,9 @@ def count_e2e_room_keys(self, user_id, version):
)

@trace
@defer.inlineCallbacks
def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
async def delete_e2e_room_keys(
self, user_id, version, room_id=None, session_id=None
):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.

Expand All @@ -259,7 +257,7 @@ def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
the backup (or for the specified room)

Returns:
A deferred of the deletion transaction
The deletion transaction
"""

keyvalues = {"user_id": user_id, "version": int(version)}
Expand All @@ -268,7 +266,7 @@ def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
if session_id:
keyvalues["session_id"] = session_id

yield self.db_pool.simple_delete(
await self.db_pool.simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)

Expand Down
Loading