Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Split registration store
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Mar 2, 2018
1 parent 1a6c7cd commit fafa3e7
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 72 deletions.
18 changes: 3 additions & 15 deletions synapse/replication/slave/storage/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,8 @@
# limitations under the License.

from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.registration import RegistrationStore
from synapse.storage.registration import RegistrationWorkerStore


class SlavedRegistrationStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedRegistrationStore, self).__init__(db_conn, hs)

# TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token"
]

_query_for_auth = DataStore._query_for_auth.__func__
get_user_by_id = RegistrationStore.__dict__[
"get_user_by_id"
]
class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
pass
118 changes: 61 additions & 57 deletions synapse/storage/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,70 @@

from synapse.api.errors import StoreError, Codes
from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks


class RegistrationStore(background_updates.BackgroundUpdateStore):
class RegistrationWorkerStore(SQLBaseStore):
@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
table="users",
keyvalues={
"name": user_id,
},
retcols=["name", "password_hash", "is_guest"],
allow_none=True,
desc="get_user_by_id",
)

@cached()
def get_user_by_access_token(self, token):
"""Get a user from the given access token.
Args:
token (str): The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
self._query_for_auth,
token
)

@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
allow_none=True,
desc="is_server_admin",
)

defer.returnValue(res if res else False)

def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
" access_tokens.device_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)

txn.execute(sql, (token,))
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]

return None


class RegistrationStore(RegistrationWorkerStore,
background_updates.BackgroundUpdateStore):

def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs)
Expand Down Expand Up @@ -187,18 +247,6 @@ def _register(
)
txn.call_after(self.is_guest.invalidate, (user_id,))

@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
table="users",
keyvalues={
"name": user_id,
},
retcols=["name", "password_hash", "is_guest"],
allow_none=True,
desc="get_user_by_id",
)

def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
Expand Down Expand Up @@ -304,34 +352,6 @@ def f(txn):

return self.runInteraction("delete_access_token", f)

@cached()
def get_user_by_access_token(self, token):
"""Get a user from the given access token.
Args:
token (str): The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
self._query_for_auth,
token
)

@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
allow_none=True,
desc="is_server_admin",
)

defer.returnValue(res if res else False)

@cachedInlineCallbacks()
def is_guest(self, user_id):
res = yield self._simple_select_one_onecol(
Expand All @@ -344,22 +364,6 @@ def is_guest(self, user_id):

defer.returnValue(res if res else False)

def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
" access_tokens.device_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)

txn.execute(sql, (token,))
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]

return None

@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self._simple_upsert("user_threepids", {
Expand Down

0 comments on commit fafa3e7

Please sign in to comment.