Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge pull request #928 from matrix-org/rav/refactor_login
Browse files Browse the repository at this point in the history
Refactor login flow
  • Loading branch information
richvdh authored Jul 18, 2016
2 parents fca90b3 + dcfd71a commit 93efcb8
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 65 deletions.
106 changes: 59 additions & 47 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ def get_session_data(self, session_id, key, default=None):
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)

@defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
Expand All @@ -240,11 +239,7 @@ def _check_password_auth(self, authdict, _):
if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string()

if not (yield self._check_password(user_id, password)):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)

defer.returnValue(user_id)
return self._check_password(user_id, password)

@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
Expand Down Expand Up @@ -348,67 +343,66 @@ def _get_session_info(self, session_id):

return self.sessions[session_id]

@defer.inlineCallbacks
def login_with_password(self, user_id, password):
def validate_password_login(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.
Args:
user_id (str): User ID
user_id (str): complete @user:id
password (str): Password
Returns:
A tuple of:
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
defer.Deferred: (str) canonical user id
Raises:
StoreError if there was a problem storing the token.
StoreError if there was a problem accessing the database
LoginError if there was an authentication problem.
"""

if not (yield self._check_password(user_id, password)):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)

logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
return self._check_password(user_id, password)

@defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id):
"""
Gets login tuple for the user with the given user ID.
Creates a new access/refresh token for the user.
The user is assumed to have been authenticated by some other
machanism (e.g. CAS)
machanism (e.g. CAS), and the user_id converted to the canonical case.
Args:
user_id (str): User ID
user_id (str): canonical User ID
Returns:
A tuple of:
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)

logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
defer.returnValue((access_token, refresh_token))

@defer.inlineCallbacks
def does_user_exist(self, user_id):
def check_user_exists(self, user_id):
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.
Args:
(str) user_id: complete @user:id
Returns:
defer.Deferred: (str) canonical_user_id, or None if zero or
multiple matches
"""
try:
yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(True)
res = yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(res[0])
except LoginError:
defer.returnValue(False)
defer.returnValue(None)

@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
Expand Down Expand Up @@ -438,27 +432,45 @@ def _find_user_id_and_pwd_hash(self, user_id):

@defer.inlineCallbacks
def _check_password(self, user_id, password):
"""
"""Authenticate a user against the LDAP and local databases.
user_id is checked case insensitively against the local database, but
will throw if there are multiple inexact matches.
Args:
user_id (str): complete @user:id
Returns:
True if the user_id successfully authenticated
(str) the canonical_user_id
Raises:
LoginError if the password was incorrect
"""
valid_ldap = yield self._check_ldap_password(user_id, password)
if valid_ldap:
defer.returnValue(True)
defer.returnValue(user_id)

valid_local_password = yield self._check_local_password(user_id, password)
if valid_local_password:
defer.returnValue(True)

defer.returnValue(False)
result = yield self._check_local_password(user_id, password)
defer.returnValue(result)

@defer.inlineCallbacks
def _check_local_password(self, user_id, password):
try:
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(self.validate_hash(password, password_hash))
except LoginError:
defer.returnValue(False)
"""Authenticate a user against the local password database.
user_id is checked case insensitively, but will throw if there are
multiple inexact matches.
Args:
user_id (str): complete @user:id
Returns:
(str) the canonical_user_id
Raises:
LoginError if the password was incorrect
"""
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
result = self.validate_hash(password, password_hash)
if not result:
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)

@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
Expand Down Expand Up @@ -570,7 +582,7 @@ def _check_ldap_password(self, user_id, password):
)

# check for existing account, if none exists, create one
if not (yield self.does_user_exist(user_id)):
if not (yield self.check_user_exists(user_id)):
# query user metadata for account creation
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
Expand Down
41 changes: 23 additions & 18 deletions synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,13 @@ def do_password_login(self, login_submission):
).to_string()

auth_handler = self.auth_handler
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
user_id = yield auth_handler.validate_password_login(
user_id=user_id,
password=login_submission["password"])

password=login_submission["password"],
)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
Expand All @@ -165,7 +168,7 @@ def do_token_login(self, login_submission):
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
user_id, access_token, refresh_token = (
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
)
result = {
Expand Down Expand Up @@ -196,13 +199,15 @@ def do_cas_login(self, cas_response_body):

user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id:
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
registered_user_id
)
)
result = {
"user_id": user_id, # may have changed
"user_id": registered_user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
Expand Down Expand Up @@ -245,13 +250,13 @@ def do_jwt_login(self, login_submission):

user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id:
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(registered_user_id)
)
result = {
"user_id": user_id, # may have changed
"user_id": registered_user_id,
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
Expand Down Expand Up @@ -414,13 +419,13 @@ def handle_cas_response(self, request, cas_response_body, client_redirect_url):

user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if not user_exists:
user_id, _ = (
registered_user_id = yield auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id, _ = (
yield self.handlers.registration_handler.register(localpart=user)
)

login_token = auth_handler.generate_short_term_login_token(user_id)
login_token = auth_handler.generate_short_term_login_token(registered_user_id)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token)
request.redirect(redirect_url)
Expand Down

0 comments on commit 93efcb8

Please sign in to comment.