Skip to content

Commit

Permalink
Merge pull request #219 from andrewheberle/patch-2
Browse files Browse the repository at this point in the history
Add option to retrieve username from HTTP header
  • Loading branch information
bugy authored May 22, 2019
2 parents 5172394 + 92b3d96 commit 409f766
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 19 deletions.
7 changes: 6 additions & 1 deletion src/auth/identification.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ class IpBasedIdentification(Identification):
COOKIE_KEY = 'client_id_token'
EMPTY_TOKEN = (None, None)

def __init__(self, trusted_ips) -> None:
def __init__(self, trusted_ips, user_header_name) -> None:
self._trusted_ips = set(trusted_ips)
self._user_header_name = user_header_name

def identify(self, request_handler):
remote_ip = request_handler.request.remote_ip
Expand All @@ -43,6 +44,10 @@ def identify(self, request_handler):
if new_trusted:
if request_handler.get_cookie(self.COOKIE_KEY):
request_handler.clear_cookie(self.COOKIE_KEY)
if self._user_header_name:
user_header = request_handler.request.headers.get(self._user_header_name, None)
if user_header:
return user_header
return self._resolve_ip(request_handler)

(client_id, days_remaining) = self._read_client_token(request_handler)
Expand Down
4 changes: 4 additions & 0 deletions src/model/server_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(self) -> None:
self.admin_users = []
self.max_request_size_mb = None
self.callbacks_config = None
self.user_header_name = None

def get_port(self):
return self.port
Expand Down Expand Up @@ -88,9 +89,11 @@ def from_json(conf_path, temp_folder):
if access_config:
allowed_users = access_config.get('allowed_users')
user_groups = model_helper.read_dict(access_config, 'groups')
user_header_name = access_config.get('user_header_name')
else:
allowed_users = None
user_groups = {}
user_header_name = None

auth_config = json_object.get('auth')
if auth_config:
Expand Down Expand Up @@ -119,6 +122,7 @@ def from_json(conf_path, temp_folder):
config.logging_config = parse_logging_config(json_object)
config.user_groups = user_groups
config.admin_users = admin_users
config.user_header_name = user_header_name

config.max_request_size_mb = read_int_from_config('max_request_size', json_object, default=10)

Expand Down
37 changes: 20 additions & 17 deletions src/tests/ip_idenfication_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
COOKIE_KEY = 'client_id_token'


def mock_request_handler(ip=None, x_forwarded_for=None, x_real_ip=None, saved_token=None):
def mock_request_handler(ip=None, x_forwarded_for=None, x_real_ip=None, saved_token=None, user_header_name=None, user_header_name_value=None):
handler_mock = mock_object()

handler_mock.application = mock_object()
Expand All @@ -24,6 +24,9 @@ def mock_request_handler(ip=None, x_forwarded_for=None, x_real_ip=None, saved_to

if x_real_ip:
handler_mock.request.headers['X-Real-IP'] = x_real_ip

if user_header_name and user_header_name_value:
handler_mock.request.headers[user_header_name] = user_header_name_value

cookies = {COOKIE_KEY: saved_token}

Expand Down Expand Up @@ -51,22 +54,22 @@ def clear_cookie(key):
class IpIdentificationTest(unittest.TestCase):

def test_localhost_ip_trusted_identification(self):
identification = IpBasedIdentification(['127.0.0.1'])
identification = IpBasedIdentification(['127.0.0.1'], None)
id = identification.identify(mock_request_handler(ip='127.0.0.1'))
self.assertEqual('127.0.0.1', id)

def test_some_ip_trusted_identification(self):
identification = IpBasedIdentification(['192.168.21.13'])
identification = IpBasedIdentification(['192.168.21.13'], None)
id = identification.identify(mock_request_handler(ip='192.168.21.13'))
self.assertEqual('192.168.21.13', id)

def test_ip_untrusted_identification(self):
identification = IpBasedIdentification([])
identification = IpBasedIdentification([], None)
id = identification.identify(mock_request_handler(ip='192.168.21.13'))
self.assertNotEqual('192.168.21.13', id)

def test_ip_untrusted_identification_for_different_connections(self):
identification = IpBasedIdentification([])
identification = IpBasedIdentification([], None)

ids = set()
for _ in range(0, 100):
Expand All @@ -75,22 +78,22 @@ def test_ip_untrusted_identification_for_different_connections(self):
self.assertEqual(100, len(ids))

def test_ip_untrusted_identification_same_connection(self):
identification = IpBasedIdentification([])
identification = IpBasedIdentification([], None)

request_handler = mock_request_handler(ip='192.168.21.13')
id1 = identification.identify(request_handler)
id2 = identification.identify(request_handler)
self.assertEqual(id1, id2)

def test_proxied_ip_behind_trusted(self):
identification = IpBasedIdentification(['127.0.0.1'])
identification = IpBasedIdentification(['127.0.0.1'], None)

request_handler = mock_request_handler(ip='127.0.0.1', x_forwarded_for='192.168.21.13')
id = identification.identify(request_handler)
self.assertEqual('192.168.21.13', id)

def test_proxied_ip_behind_untrusted(self):
identification = IpBasedIdentification([])
identification = IpBasedIdentification([], None)

request_handler = mock_request_handler(ip='127.0.0.1', x_forwarded_for='192.168.21.13')
id = identification.identify(request_handler)
Expand All @@ -100,9 +103,9 @@ def test_proxied_ip_behind_untrusted(self):
def test_change_to_trusted(self):
request_handler = mock_request_handler(ip='192.168.21.13')

old_id = IpBasedIdentification([]).identify(request_handler)
old_id = IpBasedIdentification([], None).identify(request_handler)

trusted_identification = IpBasedIdentification(['192.168.21.13'])
trusted_identification = IpBasedIdentification(['192.168.21.13'], None)
new_id = trusted_identification.identify(request_handler)

self.assertNotEqual(old_id, new_id)
Expand All @@ -112,10 +115,10 @@ def test_change_to_trusted(self):
def test_change_to_untrusted(self):
request_handler = mock_request_handler(ip='192.168.21.13')

trusted_identification = IpBasedIdentification(['192.168.21.13'])
trusted_identification = IpBasedIdentification(['192.168.21.13'], None)
old_id = trusted_identification.identify(request_handler)

new_id = IpBasedIdentification([]).identify(request_handler)
new_id = IpBasedIdentification([], None).identify(request_handler)

self.assertNotEqual(old_id, new_id)
self.assertNotEqual(new_id, '192.168.21.13')
Expand All @@ -124,7 +127,7 @@ def test_change_to_untrusted(self):
def test_no_cookie_change_for_same_user(self):
request_handler = mock_request_handler(ip='192.168.21.13')

identification = IpBasedIdentification([])
identification = IpBasedIdentification([], None)

identification.identify(request_handler)
cookie1 = request_handler.get_cookie(COOKIE_KEY)
Expand All @@ -136,7 +139,7 @@ def test_no_cookie_change_for_same_user(self):
def test_refresh_old_cookie_with_same_id(self):
request_handler = mock_request_handler(ip='192.168.21.13')

identification = IpBasedIdentification([])
identification = IpBasedIdentification([], None)

id = '1234567'
token_expiry = str(date_utils.get_current_millis() + date_utils.days_to_ms(2))
Expand All @@ -153,7 +156,7 @@ def test_broken_token_structure(self):
request_handler = mock_request_handler(ip='192.168.21.13')
request_handler.set_secure_cookie(COOKIE_KEY, 'something')

IpBasedIdentification([]).identify(request_handler)
IpBasedIdentification([], None).identify(request_handler)

new_token = request_handler.get_cookie(COOKIE_KEY)

Expand All @@ -163,7 +166,7 @@ def test_broken_token_timestamp(self):
request_handler = mock_request_handler(ip='192.168.21.13')
request_handler.set_secure_cookie(COOKIE_KEY, 'something&hello')

id = IpBasedIdentification([]).identify(request_handler)
id = IpBasedIdentification([], None).identify(request_handler)

new_token = request_handler.get_cookie(COOKIE_KEY)

Expand All @@ -174,7 +177,7 @@ def test_old_token_timestamp(self):
request_handler = mock_request_handler(ip='192.168.21.13')
request_handler.set_secure_cookie(COOKIE_KEY, 'something&100000')

id = IpBasedIdentification([]).identify(request_handler)
id = IpBasedIdentification([], None).identify(request_handler)

new_token = request_handler.get_cookie(COOKIE_KEY)

Expand Down
2 changes: 1 addition & 1 deletion src/web/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ def init(server_config: ServerConfig,
if auth.is_enabled():
identification = AuthBasedIdentification(auth)
else:
identification = IpBasedIdentification(server_config.trusted_ips)
identification = IpBasedIdentification(server_config.trusted_ips, server_config.user_header_name)

downloads_folder = file_download_feature.get_result_files_folder()

Expand Down

0 comments on commit 409f766

Please sign in to comment.