Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Factor SSO success handling out of CAS login
Browse files Browse the repository at this point in the history
This is mostly factoring out the post-CAS-login code to somewhere we can reuse
it for other SSO flows, but it also fixes the userid mapping while we're at it.
  • Loading branch information
richvdh committed Dec 5, 2018
1 parent e8d9846 commit c79afcb
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 32 deletions.
1 change: 1 addition & 0 deletions changelog.d/4264.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix CAS login when username is not valid in an MXID
13 changes: 11 additions & 2 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,10 @@ def check_user_exists(self, user_id):
insensitively, but return None if there are multiple inexact matches.
Args:
(str) user_id: complete @user:id
(unicode|bytes) user_id: complete @user:id
Returns:
defer.Deferred: (str) canonical_user_id, or None if zero or
defer.Deferred: (unicode) canonical_user_id, or None if zero or
multiple matches
"""
res = yield self._find_user_id_and_pwd_hash(user_id)
Expand Down Expand Up @@ -954,6 +954,15 @@ def generate_access_token(self, user_id, extra_caveats=None):
return macaroon.serialize()

def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
"""
Args:
user_id (unicode):
duration_in_ms (int):
Returns:
unicode
"""
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
Expand Down
105 changes: 76 additions & 29 deletions synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@

from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
parse_string,
)
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn

from .base import ClientV1RestServlet, client_path_patterns
Expand Down Expand Up @@ -421,17 +425,15 @@ def __init__(self, hs):
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
self.macaroon_gen = hs.get_macaroon_generator()
self._sso_auth_handler = SSOAuthHandler(hs)

@defer.inlineCallbacks
def on_GET(self, request):
client_redirect_url = request.args[b"redirectUrl"][0]
client_redirect_url = parse_string(request, "redirectUrl", required=True)
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": request.args[b"ticket"][0].decode('ascii'),
"ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url
}
try:
Expand All @@ -443,7 +445,6 @@ def on_GET(self, request):
result = yield self.handle_cas_response(request, body, client_redirect_url)
defer.returnValue(result)

@defer.inlineCallbacks
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)

Expand All @@ -459,28 +460,9 @@ def handle_cas_response(self, request, cas_response_body, client_redirect_url):
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

user_id = UserID(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id, _ = (
yield self.handlers.registration_handler.register(localpart=user)
)

login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
return self._sso_auth_handler.on_successful_auth(
user, request, client_redirect_url,
)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token)
request.redirect(redirect_url)
finish_request(request)

def add_login_token_to_redirect_url(self, url, token):
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
return urllib.parse.urlunparse(url_parts)

def parse_cas_response(self, cas_response_body):
user = None
Expand Down Expand Up @@ -515,6 +497,71 @@ def parse_cas_response(self, cas_response_body):
return user, attributes


class SSOAuthHandler(object):
"""
Utility class for Resources and Servlets which handle the response from a SSO
service
Args:
hs (synapse.server.HomeServer)
"""
def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_handlers().registration_handler
self._macaroon_gen = hs.get_macaroon_generator()

@defer.inlineCallbacks
def on_successful_auth(
self, username, request, client_redirect_url,
):
"""Called once the user has successfully authenticated with the SSO.
Registers the user if necessary, and then returns a redirect (with
a login token) to the client.
Args:
username (unicode|bytes): the remote user id. We'll map this onto
something sane for a MXID localpath.
request (SynapseRequest): the incoming request from the browser. We'll
respond to it with a redirect.
client_redirect_url (unicode): the redirect_url the client gave us when
it first started the process.
Returns:
Deferred[none]: Completes once we have handled the request.
"""
localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = yield self._auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id, _ = (
yield self._registration_handler.register(
localpart=localpart,
generate_token=False,
)
)

login_token = self._macaroon_gen.generate_short_term_login_token(
registered_user_id
)
redirect_url = self._add_login_token_to_redirect_url(
client_redirect_url, login_token
)
request.redirect(redirect_url)
finish_request(request)

@staticmethod
def _add_login_token_to_redirect_url(url, token):
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)


def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.saml2_enabled:
Expand Down
66 changes: 66 additions & 0 deletions synapse/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import string
from collections import namedtuple

Expand Down Expand Up @@ -228,6 +229,71 @@ def contains_invalid_mxid_characters(localpart):
return any(c not in mxid_localpart_allowed_characters for c in localpart)


UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")

# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
# localpart.
#
# It works by:
# * building a string containing the allowed characters (excluding '=')
# * escaping every special character with a backslash (to stop '-' being interpreted as a
# range operator)
# * wrapping it in a '[^...]' regex
# * converting the whole lot to a 'bytes' sequence, so that we can use it to match
# bytes rather than strings
#
NON_MXID_CHARACTER_PATTERN = re.compile(
("[^%s]" % (
re.escape("".join(mxid_localpart_allowed_characters - {"="}),),
)).encode("ascii"),
)


def map_username_to_mxid_localpart(username, case_sensitive=False):
"""Map a username onto a string suitable for a MXID
This follows the algorithm laid out at
https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.
Args:
username (unicode|bytes): username to be mapped
case_sensitive (bool): true if TEST and test should be mapped
onto different mxids
Returns:
unicode: string suitable for a mxid localpart
"""
if not isinstance(username, bytes):
username = username.encode('utf-8')

# first we sort out upper-case characters
if case_sensitive:
def f1(m):
return b"_" + m.group().lower()

username = UPPER_CASE_PATTERN.sub(f1, username)
else:
username = username.lower()

# then we sort out non-ascii characters
def f2(m):
g = m.group()[0]
if isinstance(g, str):
# on python 2, we need to do a ord(). On python 3, the
# byte itself will do.
g = ord(g)
return b"=%02x" % (g,)

username = NON_MXID_CHARACTER_PATTERN.sub(f2, username)

# we also do the =-escaping to mxids starting with an underscore.
username = re.sub(b'^_', b'=5f', username)

# we should now only have ascii bytes left, so can decode back to a
# unicode.
return username.decode('ascii')


class StreamToken(
namedtuple("Token", (
"room_key",
Expand Down
31 changes: 30 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomAlias, UserID
from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart

from tests import unittest
from tests.utils import TestHomeServer
Expand Down Expand Up @@ -79,3 +79,32 @@ def test_validate(self):
except SynapseError as exc:
self.assertEqual(400, exc.code)
self.assertEqual("M_UNKNOWN", exc.errcode)


class MapUsernameTestCase(unittest.TestCase):
def testPassThrough(self):
self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")

def testUpperCase(self):
self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
self.assertEqual(
map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
"t_e_s_t__1234",
)

def testSymbols(self):
self.assertEqual(
map_username_to_mxid_localpart("test=$?_1234"),
"test=3d=24=3f_1234",
)

def testLeadingUnderscore(self):
self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")

def testNonAscii(self):
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast")
self.assertEqual(
map_username_to_mxid_localpart(u'têst'.encode('utf-8')),
"t=c3=aast",
)

0 comments on commit c79afcb

Please sign in to comment.