diff --git a/src/tests/web/server_test.py b/src/tests/web/server_test.py index f0f1ffe7..9a8c39de 100644 --- a/src/tests/web/server_test.py +++ b/src/tests/web/server_test.py @@ -6,6 +6,7 @@ from unittest.mock import patch import requests +from parameterized import parameterized from tornado.ioloop import IOLoop from auth.authorization import Authorizer, ANY_USER, EmptyGroupProvider @@ -88,14 +89,16 @@ def test_get_scripts(self): {'name': 's3', 'group': None}], response['scripts']) - # Disabled for now - # def test_redirect_honors_protocol_header(self): - # self.start_server(12345, '127.0.0.1') - # - # response = requests.get('http://127.0.0.1:12345/', - # allow_redirects=False, - # headers={'X-Forwarded-Proto': 'https'}) - # self.assertRegex(response.headers['Location'], '^https') + @parameterized.expand([ + ('X-Forwarded-Proto',), + ('X-Scheme',)]) + def test_redirect_honors_protocol_header(self, header): + self.start_server(12345, '127.0.0.1') + + response = requests.get('http://127.0.0.1:12345/', + allow_redirects=False, + headers={header: 'https'}) + self.assertRegex(response.headers['Location'], '^https') def request(self, method, url): response = requests.request(method, url) diff --git a/src/web/server.py b/src/web/server.py index ae00f87c..1ae5f9c5 100755 --- a/src/web/server.py +++ b/src/web/server.py @@ -13,6 +13,7 @@ import tornado.escape import tornado.httpserver as httpserver import tornado.ioloop +import tornado.routing import tornado.web import tornado.websocket @@ -42,6 +43,7 @@ from web.streaming_form_reader import StreamingFormReader from web.web_auth_utils import check_authorization from web.web_utils import wrap_to_server_event, identify_user, inject_user, get_user +from web.xheader_app_wrapper import autoapply_xheaders BYTES_IN_MB = 1024 * 1024 @@ -788,6 +790,7 @@ def init(server_config: ServerConfig, } application = tornado.web.Application(handlers, **settings) + autoapply_xheaders(application) application.auth = auth diff --git a/src/web/xheader_app_wrapper.py b/src/web/xheader_app_wrapper.py new file mode 100644 index 00000000..ffd9a140 --- /dev/null +++ b/src/web/xheader_app_wrapper.py @@ -0,0 +1,37 @@ +import types + + +def _start_request_decorator(func): + def wrapper(self, *args, **kwargs): + delegate = func(*args, **kwargs) + + _decorate(delegate, 'headers_received', _headers_received_decorator) + + return delegate + + return wrapper + + +def _headers_received_decorator(func): + def wrapper(self, start_line, headers, *args, **kwargs): + proto_header = headers.get('X-Scheme', headers.get('X-Forwarded-Proto')) + + if proto_header: + # use only the last proto entry if there is more than one + proto_header = proto_header.split(',')[-1].strip() + if proto_header in ('http', 'https'): + self.request_conn.context.protocol = proto_header + + return func(start_line, headers, *args, **kwargs) + + return wrapper + + +def _decorate(obj, method_name, decorator): + original_method = getattr(obj, method_name) + new_method = types.MethodType(decorator(original_method), obj) + setattr(obj, method_name, new_method) + + +def autoapply_xheaders(application): + _decorate(application, 'start_request', _start_request_decorator)