Skip to content

Commit

Permalink
Improve websocket subprotocol request to backend server
Browse files Browse the repository at this point in the history
  • Loading branch information
consideRatio committed Feb 21, 2024
1 parent 141dbf0 commit 4d9cd5b
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 9 deletions.
33 changes: 26 additions & 7 deletions jupyter_server_proxy/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def __init__(self, *args, **kwargs):
"rewrite_response",
tuple(),
)
self.subprotocols = None
super().__init__(*args, **kwargs)

# Support/use jupyter_server config arguments allow_origin and allow_origin_pat
Expand Down Expand Up @@ -489,11 +488,14 @@ async def start_websocket_connection():
self.log.info(f"Trying to establish websocket connection to {client_uri}")
self._record_activity()
request = httpclient.HTTPRequest(url=client_uri, headers=headers)
subprotocols = (
[self.selected_subprotocol] if self.selected_subprotocol else None
)
self.ws = await pingable_ws_connect(
request=request,
on_message_callback=message_cb,
on_ping_callback=ping_cb,
subprotocols=self.subprotocols,
subprotocols=subprotocols,
resolver=resolver,
)
self._record_activity()
Expand Down Expand Up @@ -531,12 +533,29 @@ def check_xsrf_cookie(self):
"""

def select_subprotocol(self, subprotocols):
"""Select a single Sec-WebSocket-Protocol during handshake."""
self.subprotocols = subprotocols
if isinstance(subprotocols, list) and subprotocols:
self.log.debug(f"Client sent subprotocols: {subprotocols}")
"""
Select a single Sec-WebSocket-Protocol during handshake.
Note that this subprotocol selection should really be delegated to the
server we proxy to, but we don't! For this to happen, we would need to
delay accepting the handshake with the client until we have successfully
handshaked with the server.
Overrides `tornado.websocket.WebSocketHandler.select_subprotocol` that
includes an informative docstring:
https://github.com/tornadoweb/tornado/blob/v6.4.0/tornado/websocket.py#L337-L360.
"""
if subprotocols:
# Tornado 5.0 doesn't pass an empty list, but a list with a an empty
# string element.
if subprotocols[0] == "":
return None
self.log.debug(
f"Client sent subprotocols: {subprotocols}, selecting the first"
)
# TODO: warn if we select one out of multiple!
return subprotocols[0]
return super().select_subprotocol(subprotocols)
return None


class LocalProxyHandler(ProxyHandler):
Expand Down
6 changes: 6 additions & 0 deletions tests/resources/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,22 @@ def get(self):


class EchoWebSocket(tornado.websocket.WebSocketHandler):
"""Echoes back received messages."""

def on_message(self, message):
self.write_message(message)


class HeadersWebSocket(tornado.websocket.WebSocketHandler):
"""Echoes back incoming request headers."""

def on_message(self, message):
self.write_message(json.dumps(dict(self.request.headers)))


class SubprotocolWebSocket(tornado.websocket.WebSocketHandler):
"""Echoes back incoming requested subprotocols."""

def __init__(self, *args, **kwargs):
self._subprotocols = None
super().__init__(*args, **kwargs)
Expand Down
6 changes: 4 additions & 2 deletions tests/test_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ async def _websocket_subprotocols(a_server_port_and_token: Tuple[int, str]) -> N
conn = await websocket_connect(url, subprotocols=["protocol_1", "protocol_2"])
await conn.write_message("Hello, world!")
msg = await conn.read_message()
assert json.loads(msg) == ["protocol_1", "protocol_2"]
assert json.loads(msg) == ["protocol_1"]


def test_server_proxy_websocket_subprotocols(
Expand Down Expand Up @@ -410,7 +410,9 @@ def test_bad_server_proxy_url(
assert "X-ProxyContextPath" not in r.headers


def test_callable_environment_formatting(a_server_port_and_token: Tuple[int, str]) -> None:
def test_callable_environment_formatting(
a_server_port_and_token: Tuple[int, str]
) -> None:
PORT, TOKEN = a_server_port_and_token
r = request_get(PORT, "/python-http-callable-env/test", TOKEN)
assert r.code == 200

0 comments on commit 4d9cd5b

Please sign in to comment.