From 4d9cd5b18f2aa79d296f6d40f53811fa212e229d Mon Sep 17 00:00:00 2001 From: Erik Sundell Date: Wed, 21 Feb 2024 19:58:43 +0100 Subject: [PATCH] Improve websocket subprotocol request to backend server --- jupyter_server_proxy/handlers.py | 33 +++++++++++++++++++++++++------- tests/resources/websocket.py | 6 ++++++ tests/test_proxies.py | 6 ++++-- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/jupyter_server_proxy/handlers.py b/jupyter_server_proxy/handlers.py index 2890dc3a..4f084339 100644 --- a/jupyter_server_proxy/handlers.py +++ b/jupyter_server_proxy/handlers.py @@ -116,7 +116,6 @@ def __init__(self, *args, **kwargs): "rewrite_response", tuple(), ) - self.subprotocols = None super().__init__(*args, **kwargs) # Support/use jupyter_server config arguments allow_origin and allow_origin_pat @@ -489,11 +488,14 @@ async def start_websocket_connection(): self.log.info(f"Trying to establish websocket connection to {client_uri}") self._record_activity() request = httpclient.HTTPRequest(url=client_uri, headers=headers) + subprotocols = ( + [self.selected_subprotocol] if self.selected_subprotocol else None + ) self.ws = await pingable_ws_connect( request=request, on_message_callback=message_cb, on_ping_callback=ping_cb, - subprotocols=self.subprotocols, + subprotocols=subprotocols, resolver=resolver, ) self._record_activity() @@ -531,12 +533,29 @@ def check_xsrf_cookie(self): """ def select_subprotocol(self, subprotocols): - """Select a single Sec-WebSocket-Protocol during handshake.""" - self.subprotocols = subprotocols - if isinstance(subprotocols, list) and subprotocols: - self.log.debug(f"Client sent subprotocols: {subprotocols}") + """ + Select a single Sec-WebSocket-Protocol during handshake. + + Note that this subprotocol selection should really be delegated to the + server we proxy to, but we don't! For this to happen, we would need to + delay accepting the handshake with the client until we have successfully + handshaked with the server. + + Overrides `tornado.websocket.WebSocketHandler.select_subprotocol` that + includes an informative docstring: + https://github.com/tornadoweb/tornado/blob/v6.4.0/tornado/websocket.py#L337-L360. + """ + if subprotocols: + # Tornado 5.0 doesn't pass an empty list, but a list with a an empty + # string element. + if subprotocols[0] == "": + return None + self.log.debug( + f"Client sent subprotocols: {subprotocols}, selecting the first" + ) + # TODO: warn if we select one out of multiple! return subprotocols[0] - return super().select_subprotocol(subprotocols) + return None class LocalProxyHandler(ProxyHandler): diff --git a/tests/resources/websocket.py b/tests/resources/websocket.py index dda24d7c..6d2eb413 100644 --- a/tests/resources/websocket.py +++ b/tests/resources/websocket.py @@ -54,16 +54,22 @@ def get(self): class EchoWebSocket(tornado.websocket.WebSocketHandler): + """Echoes back received messages.""" + def on_message(self, message): self.write_message(message) class HeadersWebSocket(tornado.websocket.WebSocketHandler): + """Echoes back incoming request headers.""" + def on_message(self, message): self.write_message(json.dumps(dict(self.request.headers))) class SubprotocolWebSocket(tornado.websocket.WebSocketHandler): + """Echoes back incoming requested subprotocols.""" + def __init__(self, *args, **kwargs): self._subprotocols = None super().__init__(*args, **kwargs) diff --git a/tests/test_proxies.py b/tests/test_proxies.py index 5605b4d1..8079d47d 100644 --- a/tests/test_proxies.py +++ b/tests/test_proxies.py @@ -378,7 +378,7 @@ async def _websocket_subprotocols(a_server_port_and_token: Tuple[int, str]) -> N conn = await websocket_connect(url, subprotocols=["protocol_1", "protocol_2"]) await conn.write_message("Hello, world!") msg = await conn.read_message() - assert json.loads(msg) == ["protocol_1", "protocol_2"] + assert json.loads(msg) == ["protocol_1"] def test_server_proxy_websocket_subprotocols( @@ -410,7 +410,9 @@ def test_bad_server_proxy_url( assert "X-ProxyContextPath" not in r.headers -def test_callable_environment_formatting(a_server_port_and_token: Tuple[int, str]) -> None: +def test_callable_environment_formatting( + a_server_port_and_token: Tuple[int, str] +) -> None: PORT, TOKEN = a_server_port_and_token r = request_get(PORT, "/python-http-callable-env/test", TOKEN) assert r.code == 200