Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge pull request #295 from stevenhammerton/sh-cas-auth
Browse files Browse the repository at this point in the history
Provide ability to login using CAS
  • Loading branch information
erikjohnston committed Oct 10, 2015
2 parents ce19fc0 + 95f7661 commit 347aa3c
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 21 deletions.
39 changes: 39 additions & 0 deletions synapse/config/cas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ._base import Config


class CasConfig(Config):
"""Cas Configuration
cas_server_url: URL of CAS server
"""

def read_config(self, config):
cas_config = config.get("cas_config", None)
if cas_config:
self.cas_enabled = True
self.cas_server_url = cas_config["server_url"]
else:
self.cas_enabled = False
self.cas_server_url = None

def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable CAS for registration and login.
#cas_config:
# server_url: "https://cas-server.com"
"""
3 changes: 2 additions & 1 deletion synapse/config/homeserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
from .appservice import AppServiceConfig
from .key import KeyConfig
from .saml2 import SAML2Config
from .cas import CasConfig


class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig,
AppServiceConfig, KeyConfig, SAML2Config, ):
AppServiceConfig, KeyConfig, SAML2Config, CasConfig):
pass


Expand Down
32 changes: 32 additions & 0 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,38 @@ def login_with_password(self, user_id, password):
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))

@defer.inlineCallbacks
def login_with_cas_user_id(self, user_id):
"""
Authenticates the user with the given user ID,
intended to have been captured from a CAS response
Args:
user_id (str): User ID
Returns:
A tuple of:
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)

logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))

@defer.inlineCallbacks
def does_user_exist(self, user_id):
try:
yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(True)
except LoginError:
defer.returnValue(False)

@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
"""Checks to see if a user with the given id exists. Will check case
Expand Down
55 changes: 36 additions & 19 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,16 +160,40 @@ def get_json(self, uri, args={}):
On a non-2xx HTTP response. The response body will be used as the
error message.
"""
body = yield self.get_raw(uri, args)
defer.returnValue(json.loads(body))

@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}):
""" Puts some json to the given URI.
Args:
uri (str): The URI to request, not including query parameters
json_body (dict): The JSON to put in the HTTP body,
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Raises:
On a non-2xx HTTP response.
"""
if len(args):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)

json_str = encode_canonical_json(json_body)

response = yield self.request(
"GET",
"PUT",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
})
b"User-Agent": [self.version_string],
"Content-Type": ["application/json"]
}),
bodyProducer=FileBodyProducer(StringIO(json_str))
)

body = yield preserve_context_over_fn(readBody, response)
Expand All @@ -183,46 +207,39 @@ def get_json(self, uri, args={}):
raise CodeMessageException(response.code, body)

@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}):
""" Puts some json to the given URI.
def get_raw(self, uri, args={}):
""" Gets raw text from the given URI.
Args:
uri (str): The URI to request, not including query parameters
json_body (dict): The JSON to put in the HTTP body,
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
HTTP body at text.
Raises:
On a non-2xx HTTP response.
On a non-2xx HTTP response. The response body will be used as the
error message.
"""
if len(args):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)

json_str = encode_canonical_json(json_body)

response = yield self.request(
"PUT",
"GET",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
"Content-Type": ["application/json"]
}),
bodyProducer=FileBodyProducer(StringIO(json_str))
b"User-Agent": [self.version_string],
})
)

body = yield preserve_context_over_fn(readBody, response)

if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
defer.returnValue(body)
else:
# NB: This is explicitly not json.loads(body)'d because the contract
# of CodeMessageException is a *string* message. Callers can always
# load it into JSON if they want.
raise CodeMessageException(response.code, body)


Expand Down
76 changes: 75 additions & 1 deletion synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

from twisted.internet import defer

from synapse.api.errors import SynapseError
from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern

Expand All @@ -27,6 +28,8 @@
from saml2 import config
from saml2.client import Saml2Client

import xml.etree.ElementTree as ET


logger = logging.getLogger(__name__)

Expand All @@ -35,16 +38,23 @@ class LoginRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login$")
PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"

def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled

self.cas_server_url = hs.config.cas_server_url
self.servername = hs.config.server_name

def on_GET(self, request):
flows = [{"type": LoginRestServlet.PASS_TYPE}]
if self.saml2_enabled:
flows.append({"type": LoginRestServlet.SAML2_TYPE})
if self.cas_enabled:
flows.append({"type": LoginRestServlet.CAS_TYPE})
return (200, {"flows": flows})

def on_OPTIONS(self, request):
Expand All @@ -67,6 +77,19 @@ def on_POST(self, request):
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
}
defer.returnValue((200, result))
elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE):
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
uri = "%s/proxyValidate" % (self.cas_server_url,)
args = {
"ticket": login_submission["ticket"],
"service": login_submission["service"]
}
body = yield http_client.get_raw(uri, args)
result = yield self.do_cas_login(body)
defer.returnValue(result)
else:
raise SynapseError(400, "Bad login type.")
except KeyError:
Expand Down Expand Up @@ -100,6 +123,44 @@ def do_password_login(self, login_submission):

defer.returnValue((200, result))

@defer.inlineCallbacks
def do_cas_login(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.login_with_cas_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}

else:
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}

defer.returnValue((200, result))

raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)


class LoginFallbackRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/fallback$")
Expand Down Expand Up @@ -174,6 +235,17 @@ def on_POST(self, request):
defer.returnValue((200, {"status": "not_authenticated"}))


class CasRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas")

def __init__(self, hs):
super(CasRestServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url

def on_GET(self, request):
return (200, {"serverUrl": self.cas_server_url})


def _parse_json(request):
try:
content = json.loads(request.content.read())
Expand All @@ -188,4 +260,6 @@ def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.saml2_enabled:
SAML2RestServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server)

0 comments on commit 347aa3c

Please sign in to comment.