Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Split registration store #2929

Merged
merged 1 commit into from
Mar 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 3 additions & 15 deletions synapse/replication/slave/storage/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,8 @@
# limitations under the License.

from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.registration import RegistrationStore
from synapse.storage.registration import RegistrationWorkerStore


class SlavedRegistrationStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedRegistrationStore, self).__init__(db_conn, hs)

# TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token"
]

_query_for_auth = DataStore._query_for_auth.__func__
get_user_by_id = RegistrationStore.__dict__[
"get_user_by_id"
]
class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
pass
118 changes: 61 additions & 57 deletions synapse/storage/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,70 @@

from synapse.api.errors import StoreError, Codes
from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks


class RegistrationStore(background_updates.BackgroundUpdateStore):
class RegistrationWorkerStore(SQLBaseStore):
@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
table="users",
keyvalues={
"name": user_id,
},
retcols=["name", "password_hash", "is_guest"],
allow_none=True,
desc="get_user_by_id",
)

@cached()
def get_user_by_access_token(self, token):
"""Get a user from the given access token.
Args:
token (str): The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
self._query_for_auth,
token
)

@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
allow_none=True,
desc="is_server_admin",
)

defer.returnValue(res if res else False)

def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
" access_tokens.device_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)

txn.execute(sql, (token,))
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]

return None


class RegistrationStore(RegistrationWorkerStore,
background_updates.BackgroundUpdateStore):

def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs)
Expand Down Expand Up @@ -187,18 +247,6 @@ def _register(
)
txn.call_after(self.is_guest.invalidate, (user_id,))

@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
table="users",
keyvalues={
"name": user_id,
},
retcols=["name", "password_hash", "is_guest"],
allow_none=True,
desc="get_user_by_id",
)

def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
Expand Down Expand Up @@ -304,34 +352,6 @@ def f(txn):

return self.runInteraction("delete_access_token", f)

@cached()
def get_user_by_access_token(self, token):
"""Get a user from the given access token.
Args:
token (str): The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
self._query_for_auth,
token
)

@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
allow_none=True,
desc="is_server_admin",
)

defer.returnValue(res if res else False)

@cachedInlineCallbacks()
def is_guest(self, user_id):
res = yield self._simple_select_one_onecol(
Expand All @@ -344,22 +364,6 @@ def is_guest(self, user_id):

defer.returnValue(res if res else False)

def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
" access_tokens.device_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)

txn.execute(sql, (token,))
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]

return None

@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self._simple_upsert("user_threepids", {
Expand Down