Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert appservice, group server, profile and more databases to async (
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Aug 12, 2020
1 parent 9d1e494 commit a3a59ba
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 116 deletions.
1 change: 1 addition & 0 deletions changelog.d/8066.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
34 changes: 13 additions & 21 deletions synapse/storage/databases/main/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from canonicaljson import json

from twisted.internet import defer

from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.storage._base import SQLBaseStore, db_to_json
Expand Down Expand Up @@ -124,17 +122,15 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
class ApplicationServiceTransactionWorkerStore(
ApplicationServiceWorkerStore, EventsWorkerStore
):
@defer.inlineCallbacks
def get_appservices_by_state(self, state):
async def get_appservices_by_state(self, state):
"""Get a list of application services based on their state.
Args:
state(ApplicationServiceState): The state to filter on.
Returns:
A Deferred which resolves to a list of ApplicationServices, which
may be empty.
A list of ApplicationServices, which may be empty.
"""
results = yield self.db_pool.simple_select_list(
results = await self.db_pool.simple_select_list(
"application_services_state", {"state": state}, ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
Expand All @@ -147,16 +143,15 @@ def get_appservices_by_state(self, state):
services.append(service)
return services

@defer.inlineCallbacks
def get_appservice_state(self, service):
async def get_appservice_state(self, service):
"""Get the application service state.
Args:
service(ApplicationService): The service whose state to set.
Returns:
A Deferred which resolves to ApplicationServiceState.
An ApplicationServiceState.
"""
result = yield self.db_pool.simple_select_one(
result = await self.db_pool.simple_select_one(
"application_services_state",
{"as_id": service.id},
["state"],
Expand Down Expand Up @@ -270,16 +265,14 @@ def _complete_appservice_txn(txn):
"complete_appservice_txn", _complete_appservice_txn
)

@defer.inlineCallbacks
def get_oldest_unsent_txn(self, service):
async def get_oldest_unsent_txn(self, service):
"""Get the oldest transaction which has not been sent for this
service.
Args:
service(ApplicationService): The app service to get the oldest txn.
Returns:
A Deferred which resolves to an AppServiceTransaction or
None.
An AppServiceTransaction or None.
"""

def _get_oldest_unsent_txn(txn):
Expand All @@ -298,7 +291,7 @@ def _get_oldest_unsent_txn(txn):

return entry

entry = yield self.db_pool.runInteraction(
entry = await self.db_pool.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
)

Expand All @@ -307,7 +300,7 @@ def _get_oldest_unsent_txn(txn):

event_ids = db_to_json(entry["event_ids"])

events = yield self.get_events_as_list(event_ids)
events = await self.get_events_as_list(event_ids)

return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)

Expand All @@ -332,8 +325,7 @@ def set_appservice_last_pos_txn(txn):
"set_appservice_last_pos", set_appservice_last_pos_txn
)

@defer.inlineCallbacks
def get_new_events_for_appservice(self, current_id, limit):
async def get_new_events_for_appservice(self, current_id, limit):
"""Get all new evnets"""

def get_new_events_for_appservice_txn(txn):
Expand All @@ -357,11 +349,11 @@ def get_new_events_for_appservice_txn(txn):

return upper_bound, [row[1] for row in rows]

upper_bound, event_ids = yield self.db_pool.runInteraction(
upper_bound, event_ids = await self.db_pool.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)

events = yield self.get_events_as_list(event_ids)
events = await self.get_events_as_list(event_ids)

return upper_bound, events

Expand Down
8 changes: 4 additions & 4 deletions synapse/storage/databases/main/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@

from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.descriptors import cached


class FilteringStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2)
def get_user_filter(self, user_localpart, filter_id):
@cached(num_args=2)
async def get_user_filter(self, user_localpart, filter_id):
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
int(filter_id)
except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)

def_json = yield self.db_pool.simple_select_one_onecol(
def_json = await self.db_pool.simple_select_one_onecol(
table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
Expand Down
86 changes: 39 additions & 47 deletions synapse/storage/databases/main/group_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

from twisted.internet import defer
from typing import List, Optional, Tuple

from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.types import JsonDict
from synapse.util import json_encoder

# The category ID for the "default" category. We don't store as null in the
Expand Down Expand Up @@ -210,9 +209,8 @@ def _get_rooms_for_summary_txn(txn):
"get_rooms_for_summary", _get_rooms_for_summary_txn
)

@defer.inlineCallbacks
def get_group_categories(self, group_id):
rows = yield self.db_pool.simple_select_list(
async def get_group_categories(self, group_id):
rows = await self.db_pool.simple_select_list(
table="group_room_categories",
keyvalues={"group_id": group_id},
retcols=("category_id", "is_public", "profile"),
Expand All @@ -227,9 +225,8 @@ def get_group_categories(self, group_id):
for row in rows
}

@defer.inlineCallbacks
def get_group_category(self, group_id, category_id):
category = yield self.db_pool.simple_select_one(
async def get_group_category(self, group_id, category_id):
category = await self.db_pool.simple_select_one(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
retcols=("is_public", "profile"),
Expand All @@ -240,9 +237,8 @@ def get_group_category(self, group_id, category_id):

return category

@defer.inlineCallbacks
def get_group_roles(self, group_id):
rows = yield self.db_pool.simple_select_list(
async def get_group_roles(self, group_id):
rows = await self.db_pool.simple_select_list(
table="group_roles",
keyvalues={"group_id": group_id},
retcols=("role_id", "is_public", "profile"),
Expand All @@ -257,9 +253,8 @@ def get_group_roles(self, group_id):
for row in rows
}

@defer.inlineCallbacks
def get_group_role(self, group_id, role_id):
role = yield self.db_pool.simple_select_one(
async def get_group_role(self, group_id, role_id):
role = await self.db_pool.simple_select_one(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
retcols=("is_public", "profile"),
Expand Down Expand Up @@ -448,12 +443,11 @@ def _get_attestations_need_renewals_txn(txn):
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)

@defer.inlineCallbacks
def get_remote_attestation(self, group_id, user_id):
async def get_remote_attestation(self, group_id, user_id):
"""Get the attestation that proves the remote agrees that the user is
in the group.
"""
row = yield self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
retcols=("valid_until_ms", "attestation_json"),
Expand Down Expand Up @@ -499,13 +493,13 @@ def _get_all_groups_for_user_txn(txn):
"get_all_groups_for_user", _get_all_groups_for_user_txn
)

def get_groups_changes_for_user(self, user_id, from_token, to_token):
async def get_groups_changes_for_user(self, user_id, from_token, to_token):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_entity_changed(
user_id, from_token
)
if not has_changed:
return defer.succeed([])
return []

def _get_groups_changes_for_user_txn(txn):
sql = """
Expand All @@ -525,7 +519,7 @@ def _get_groups_changes_for_user_txn(txn):
for group_id, membership, gtype, content_json in txn
]

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_groups_changes_for_user", _get_groups_changes_for_user_txn
)

Expand Down Expand Up @@ -1087,31 +1081,31 @@ def update_group_publicity(self, group_id, user_id, publicise):
desc="update_group_publicity",
)

@defer.inlineCallbacks
def register_user_group_membership(
async def register_user_group_membership(
self,
group_id,
user_id,
membership,
is_admin=False,
content={},
local_attestation=None,
remote_attestation=None,
is_publicised=False,
):
group_id: str,
user_id: str,
membership: str,
is_admin: bool = False,
content: JsonDict = {},
local_attestation: Optional[dict] = None,
remote_attestation: Optional[dict] = None,
is_publicised: bool = False,
) -> int:
"""Registers that a local user is a member of a (local or remote) group.
Args:
group_id (str)
user_id (str)
membership (str)
is_admin (bool)
content (dict): Content of the membership, e.g. includes the inviter
group_id: The group the member is being added to.
user_id: THe user ID to add to the group.
membership: The type of group membership.
is_admin: Whether the user should be added as a group admin.
content: Content of the membership, e.g. includes the inviter
if the user has been invited.
local_attestation (dict): If remote group then store the fact that we
local_attestation: If remote group then store the fact that we
have given out an attestation, else None.
remote_attestation (dict): If remote group then store the remote
remote_attestation: If remote group then store the remote
attestation from the group, else None.
is_publicised: Whether this should be publicised.
"""

def _register_user_group_membership_txn(txn, next_id):
Expand Down Expand Up @@ -1188,18 +1182,17 @@ def _register_user_group_membership_txn(txn, next_id):
return next_id

with self._group_updates_id_gen.get_next() as next_id:
res = yield self.db_pool.runInteraction(
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
next_id,
)
return res

@defer.inlineCallbacks
def create_group(
async def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description
):
yield self.db_pool.simple_insert(
) -> None:
await self.db_pool.simple_insert(
table="groups",
values={
"group_id": group_id,
Expand All @@ -1212,9 +1205,8 @@ def create_group(
desc="create_group",
)

@defer.inlineCallbacks
def update_group_profile(self, group_id, profile):
yield self.db_pool.simple_update_one(
async def update_group_profile(self, group_id, profile):
await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues=profile,
Expand Down
7 changes: 2 additions & 5 deletions synapse/storage/databases/main/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,20 @@

from typing import List, Tuple

from twisted.internet import defer

from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter


class PresenceStore(SQLBaseStore):
@defer.inlineCallbacks
def update_presence(self, presence_states):
async def update_presence(self, presence_states):
stream_ordering_manager = self._presence_id_gen.get_next_mult(
len(presence_states)
)

with stream_ordering_manager as stream_orderings:
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"update_presence",
self._update_presence_txn,
stream_orderings,
Expand Down
Loading

0 comments on commit a3a59ba

Please sign in to comment.