Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge pull request #2727 from matrix-org/rav/refactor_ui_auth_return
Browse files Browse the repository at this point in the history
Refactor UI auth implementation
  • Loading branch information
richvdh authored Dec 5, 2017
2 parents 58ebdb0 + d5f9fb0 commit aa6ecf0
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 48 deletions.
16 changes: 16 additions & 0 deletions synapse/api/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,22 @@ class RegistrationError(SynapseError):
pass


class InteractiveAuthIncompleteError(Exception):
"""An error raised when UI auth is not yet complete
(This indicates we should return a 401 with 'result' as the body)
Attributes:
result (dict): the server response to the request, which should be
passed back to the client
"""
def __init__(self, result):
super(InteractiveAuthIncompleteError, self).__init__(
"Interactive auth not yet complete",
)
self.result = result


class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs):
Expand Down
46 changes: 29 additions & 17 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.api.errors import (
AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
SynapseError,
)
from synapse.module_api import ModuleApi
from synapse.types import UserID
from synapse.util.async import run_on_reactor
Expand Down Expand Up @@ -95,26 +98,36 @@ def check_auth(self, flows, clientdict, clientip):
session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool).
If no auth flows have been completed successfully, raises an
InteractiveAuthIncompleteError. To handle this, you can use
synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
decorator.
Args:
flows (list): A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client.
Returns:
A tuple of (authed, dict, dict, session_id) where authed is true if
the client has successfully completed an auth flow. If it is true
the first dict contains the authenticated credentials of each stage.
defer.Deferred[dict, dict, str]: a deferred tuple of
(creds, params, session_id).
If authed is false, the first dictionary is the server response to
the login request and should be passed back to the client.
'creds' contains the authenticated credentials of each stage.
In either case, the second dict contains the parameters for this
request (which may have been given only in a previous call).
'params' contains the parameters for this request (which may
have been given only in a previous call).
session_id is the ID of this session, either passed in by the client
or assigned by the call to check_auth
'session_id' is the ID of this session, either passed in by the
client or assigned by this call
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
all the stages in any of the permitted flows.
"""

authdict = None
Expand Down Expand Up @@ -142,11 +155,8 @@ def check_auth(self, flows, clientdict, clientip):
clientdict = session['clientdict']

if not authdict:
defer.returnValue(
(
False, self._auth_dict_for_flows(flows, session),
clientdict, session['id']
)
raise InteractiveAuthIncompleteError(
self._auth_dict_for_flows(flows, session),
)

if 'creds' not in session:
Expand Down Expand Up @@ -190,12 +200,14 @@ def check_auth(self, flows, clientdict, clientip):
"Auth completed with creds: %r. Client dict has keys: %r",
creds, clientdict.keys()
)
defer.returnValue((True, creds, clientdict, session['id']))
defer.returnValue((creds, clientdict, session['id']))

ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys()
ret.update(errordict)
defer.returnValue((False, ret, clientdict, session['id']))
raise InteractiveAuthIncompleteError(
ret,
)

@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
Expand Down
41 changes: 38 additions & 3 deletions synapse/rest/client/v2_alpha/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@

"""This module contains base REST classes for constructing client v1 servlets.
"""

from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
import logging
import re

import logging
from twisted.internet import defer

from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
filter_json['room']['timeline']["limit"] = min(
filter_json['room']['timeline']['limit'],
filter_timeline_limit)


def interactive_auth_handler(orig):
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
Takes a on_POST method which returns a deferred (errcode, body) response
and adds exception handling to turn a InteractiveAuthIncompleteError into
a 401 response.
Normal usage is:
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
# ...
yield self.auth_handler.check_auth
"""
def wrapped(*args, **kwargs):
res = defer.maybeDeferred(orig, *args, **kwargs)
res.addErrback(_catch_incomplete_interactive_auth)
return res
return wrapped


def _catch_incomplete_interactive_auth(f):
"""helper for interactive_auth_handler
Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
Args:
f (failure.Failure):
"""
f.trap(InteractiveAuthIncompleteError)
return 401, f.value.result
14 changes: 5 additions & 9 deletions synapse/rest/client/v2_alpha/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns
from ._base import client_v2_patterns, interactive_auth_handler

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -100,21 +100,19 @@ def __init__(self, hs):
self.datastore = self.hs.get_datastore()
self._set_password_handler = hs.get_set_password_handler()

@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()

body = parse_json_object_from_request(request)

authed, result, params, _ = yield self.auth_handler.check_auth([
result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY],
[LoginType.MSISDN],
], body, self.hs.get_ip_from_request(request))

if not authed:
defer.returnValue((401, result))

user_id = None
requester = None

Expand Down Expand Up @@ -168,6 +166,7 @@ def __init__(self, hs):
self.auth_handler = hs.get_auth_handler()
self._deactivate_account_handler = hs.get_deactivate_account_handler()

@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
Expand All @@ -186,13 +185,10 @@ def on_POST(self, request):
)
defer.returnValue((200, {}))

authed, result, params, _ = yield self.auth_handler.check_auth([
result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))

if not authed:
defer.returnValue((401, result))

if LoginType.PASSWORD in result:
user_id = result[LoginType.PASSWORD]
# if using password, they should also be logged in
Expand Down
14 changes: 5 additions & 9 deletions synapse/rest/client/v2_alpha/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from synapse.api import constants, errors
from synapse.http import servlet
from ._base import client_v2_patterns
from ._base import client_v2_patterns, interactive_auth_handler

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,6 +60,7 @@ def __init__(self, hs):
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()

@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
try:
Expand All @@ -77,13 +78,10 @@ def on_POST(self, request):
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
)

authed, result, params, _ = yield self.auth_handler.check_auth([
result, params, _ = yield self.auth_handler.check_auth([
[constants.LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))

if not authed:
defer.returnValue((401, result))

requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_devices(
requester.user.to_string(),
Expand Down Expand Up @@ -115,6 +113,7 @@ def on_GET(self, request, device_id):
)
defer.returnValue((200, device))

@interactive_auth_handler
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
Expand All @@ -130,13 +129,10 @@ def on_DELETE(self, request, device_id):
else:
raise

authed, result, params, _ = yield self.auth_handler.check_auth([
result, params, _ = yield self.auth_handler.check_auth([
[constants.LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))

if not authed:
defer.returnValue((401, result))

# check that the UI auth matched the access token
user_id = result[constants.LoginType.PASSWORD]
if user_id != requester.user.to_string():
Expand Down
9 changes: 3 additions & 6 deletions synapse/rest/client/v2_alpha/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from synapse.util.msisdn import phone_number_to_msisdn

from ._base import client_v2_patterns
from ._base import client_v2_patterns, interactive_auth_handler

import logging
import hmac
Expand Down Expand Up @@ -176,6 +176,7 @@ def __init__(self, hs):
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()

@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
Expand Down Expand Up @@ -325,14 +326,10 @@ def on_POST(self, request):
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
])

authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)

if not authed:
defer.returnValue((401, auth_result))
return

if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",
Expand Down
11 changes: 7 additions & 4 deletions tests/rest/client/v2_alpha/test_register.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from twisted.python import failure

from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.api.errors import SynapseError
from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
from twisted.internet import defer
from mock import Mock
from tests import unittest
Expand All @@ -24,7 +26,7 @@ def setUp(self):
side_effect=lambda x: self.appservice)
)

self.auth_result = (False, None, None, None)
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None)
Expand Down Expand Up @@ -86,6 +88,7 @@ def test_POST_appservice_registration_invalid(self):
self.request.args = {
"access_token": "i_am_an_app_service"
}

self.request_data = json.dumps({
"username": "kermit"
})
Expand Down Expand Up @@ -120,7 +123,7 @@ def test_POST_user_valid(self):
"device_id": device_id,
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
self.auth_result = (None, {
"username": "kermit",
"password": "monkey"
}, None)
Expand Down Expand Up @@ -150,7 +153,7 @@ def test_POST_disabled_registration(self):
"password": "monkey"
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
self.auth_result = (None, {
"username": "kermit",
"password": "monkey"
}, None)
Expand Down

0 comments on commit aa6ecf0

Please sign in to comment.