diff --git a/tests/test_auth_plugins.py b/tests/test_auth_plugins.py new file mode 100644 index 00000000..4b3bfb5c --- /dev/null +++ b/tests/test_auth_plugins.py @@ -0,0 +1,28 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +""" Unit tests for Authentication plugins""" + +from websockify.auth_plugins import BasicHTTPAuth, AuthenticationError +import unittest + + +class BasicHTTPAuthTestCase(unittest.TestCase): + + def setUp(self): + self.plugin = BasicHTTPAuth('Aladdin:open sesame') + + def test_no_auth(self): + headers = {} + self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234') + + def test_invalid_password(self): + headers = {'Authorization': 'Basic QWxhZGRpbjpzZXNhbWUgc3RyZWV0'} + self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234') + + def test_valid_password(self): + headers = {'Authorization': 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='} + self.plugin.authenticate(headers, 'localhost', '1234') + + def test_garbage_auth(self): + headers = {'Authorization': 'Basic xxxxxxxxxxxxxxxxxxxxxxxxxxxx'} + self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234') diff --git a/websockify/auth_plugins.py b/websockify/auth_plugins.py index 924d5de2..93f13855 100644 --- a/websockify/auth_plugins.py +++ b/websockify/auth_plugins.py @@ -30,12 +30,13 @@ def __init__(self, expected, actual): class BasicHTTPAuth(object): + """Verifies Basic Auth headers. Specify src as username:password""" + def __init__(self, src=None): self.src = src def authenticate(self, headers, target_host, target_port): import base64 - auth_header = headers.get('Authorization') if auth_header: if not auth_header.startswith('Basic '): @@ -46,18 +47,24 @@ def authenticate(self, headers, target_host, target_port): except TypeError: raise AuthenticationError(response_code=403) - user_pass = user_pass_raw.split(':', 1) + try: + # http://stackoverflow.com/questions/7242316/what-encoding-should-i-use-for-http-basic-authentication + user_pass_as_text = user_pass_raw.decode('ISO-8859-1') + except UnicodeDecodeError: + raise AuthenticationError(response_code=403) + + user_pass = user_pass_as_text.split(':', 1) if len(user_pass) != 2: raise AuthenticationError(response_code=403) - if not self.validate_creds: + if not self.validate_creds(*user_pass): raise AuthenticationError(response_code=403) else: raise AuthenticationError(response_code=401, response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'}) - def validate_creds(username, password): + def validate_creds(self, username, password): if '%s:%s' % (username, password) == self.src: return True else: