Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support stream api #479

Merged
merged 12 commits into from
Aug 29, 2024
18 changes: 7 additions & 11 deletions jupyter_server_proxy/config.py
Original file line number Diff line number Diff line change
@@ -54,7 +54,11 @@ def _make_proxy_handler(sp: ServerProcess):
Create an appropriate handler with given parameters
"""
if sp.command:
cls = SuperviseAndRawSocketHandler if sp.raw_socket_proxy else SuperviseAndProxyHandler
cls = (
SuperviseAndRawSocketHandler
if sp.raw_socket_proxy
else SuperviseAndProxyHandler
)
args = dict(state={})
elif not (sp.port or isinstance(sp.unix_socket, str)):
warn(
@@ -122,13 +126,7 @@ def make_handlers(base_url, server_processes):
handler = _make_proxy_handler(sp)
if not handler:
continue
handlers.append(
(
ujoin(base_url, sp.name, r"(.*)"),
handler,
handler.kwargs
)
)
handlers.append((ujoin(base_url, sp.name, r"(.*)"), handler, handler.kwargs))
handlers.append((ujoin(base_url, sp.name), AddSlashHandler))
return handlers

@@ -159,9 +157,7 @@ def make_server_process(name, server_process_config, serverproxy_config):
"rewrite_response",
tuple(),
),
update_last_activity=server_process_config.get(
"update_last_activity", True
),
update_last_activity=server_process_config.get("update_last_activity", True),
raw_socket_proxy=server_process_config.get("raw_socket_proxy", False),
)

96 changes: 93 additions & 3 deletions jupyter_server_proxy/handlers.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
"""

import os
import re
import socket
from asyncio import Lock
from copy import copy
@@ -287,7 +288,7 @@ def get_client_uri(self, protocol, host, port, proxied_path):

return client_uri

def _build_proxy_request(self, host, port, proxied_path, body):
def _build_proxy_request(self, host, port, proxied_path, body, **extra_opts):
headers = self.proxy_request_headers()

client_uri = self.get_client_uri("http", host, port, proxied_path)
@@ -307,6 +308,7 @@ def _build_proxy_request(self, host, port, proxied_path, body):
decompress_response=False,
headers=headers,
**self.proxy_request_options(),
**extra_opts,
)
return req

@@ -365,7 +367,6 @@ async def proxy(self, host, port, proxied_path):
body = b""
else:
body = None

if self.unix_socket is not None:
# Port points to a Unix domain socket
self.log.debug("Making client for Unix socket %r", self.unix_socket)
@@ -374,8 +375,97 @@ async def proxy(self, host, port, proxied_path):
force_instance=True, resolver=UnixResolver(self.unix_socket)
)
else:
client = httpclient.AsyncHTTPClient()
client = httpclient.AsyncHTTPClient(force_instance=True)
# check if the request is stream request
accept_header = self.request.headers.get("Accept")
if accept_header == "text/event-stream":
return await self._proxy_progressive(host, port, proxied_path, body, client)
else:
return await self._proxy_buffered(host, port, proxied_path, body, client)

async def _proxy_progressive(self, host, port, proxied_path, body, client):
# Proxy in progressive flush mode, whenever chunks are received. Potentially slower but get results quicker for voila
# Set up handlers so we can progressively flush result

headers_raw = []

def dump_headers(headers_raw):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's move this function outside _proxy_progressive. headers_raw as used inside this function is an arg, while there's also a local variable with the same name. Moving it outside cleans this up a little.

for line in headers_raw:
r = re.match("^([a-zA-Z0-9\\-_]+)\\s*\\:\\s*([^\r\n]+)[\r\n]*$", line)
if r:
k, v = r.groups([1, 2])
if k not in (
"Content-Length",
"Transfer-Encoding",
"Content-Encoding",
"Connection",
):
# some header appear multiple times, eg 'Set-Cookie'
self.set_header(k, v)
else:
r = re.match(r"^HTTP[^\s]* ([0-9]+)", line)
if r:
status_code = r.group(1)
self.set_status(int(status_code))
headers_raw.clear()

# clear tornado default header
self._headers = httputil.HTTPHeaders()

def header_callback(line):
headers_raw.append(line)

def streaming_callback(chunk):
# record activity at start and end of requests
self._record_activity()
# Do this here, not in header_callback so we can be sure headers are out of the way first
dump_headers(
headers_raw
) # array will be empty if this was already called before
self.write(chunk)
self.flush()

# Now make the request

req = self._build_proxy_request(
host,
port,
proxied_path,
body,
streaming_callback=streaming_callback,
header_callback=header_callback,
)

# no timeout for stream api
req.request_timeout = 7200
req.connect_timeout = 600

try:
response = await client.fetch(req, raise_error=False)
except httpclient.HTTPError as err:
if err.code == 599:
self._record_activity()
self.set_status(599)
self.write(str(err))
return
else:
raise

# For all non http errors...
if response.error and type(response.error) is not httpclient.HTTPError:
self.set_status(500)
self.write(str(response.error))
else:
self.set_status(
response.code, response.reason
) # Should already have been set

dump_headers(headers_raw) # Should already have been emptied

if response.body: # Likewise, should already be chunked out and flushed
self.write(response.body)

async def _proxy_buffered(self, host, port, proxied_path, body, client):
req = self._build_proxy_request(host, port, proxied_path, body)

self.log.debug(f"Proxying request to {req.url}")
27 changes: 20 additions & 7 deletions jupyter_server_proxy/rawsocket.py
Original file line number Diff line number Diff line change
@@ -9,15 +9,17 @@

import asyncio

from .handlers import NamedLocalProxyHandler, SuperviseAndProxyHandler
from tornado import web

from .handlers import NamedLocalProxyHandler, SuperviseAndProxyHandler


class RawSocketProtocol(asyncio.Protocol):
"""
A protocol handler for the proxied stream connection.
Sends any received blocks directly as websocket messages.
"""

def __init__(self, handler):
self.handler = handler

@@ -30,14 +32,18 @@ def data_received(self, data):

def connection_lost(self, exc):
"Close the websocket connection."
self.handler.log.info(f"Raw websocket {self.handler.name} connection lost: {exc}")
self.handler.log.info(
f"Raw websocket {self.handler.name} connection lost: {exc}"
)
self.handler.close()


class RawSocketHandler(NamedLocalProxyHandler):
"""
HTTP handler that proxies websocket connections into a backend stream.
All other HTTP requests return 405.
"""

def _create_ws_connection(self, proto: asyncio.BaseProtocol):
"Create the appropriate backend asyncio connection"
loop = asyncio.get_running_loop()
@@ -46,17 +52,21 @@ def _create_ws_connection(self, proto: asyncio.BaseProtocol):
return loop.create_unix_connection(proto, self.unix_socket)
else:
self.log.info(f"RawSocket {self.name} connecting to port {self.port}")
return loop.create_connection(proto, 'localhost', self.port)
return loop.create_connection(proto, "localhost", self.port)

async def proxy(self, port, path):
raise web.HTTPError(405, "this raw_socket_proxy backend only supports websocket connections")
raise web.HTTPError(
405, "this raw_socket_proxy backend only supports websocket connections"
)

async def proxy_open(self, host, port, proxied_path=""):
"""
Open the backend connection. host and port are ignored (as they are in
the parent for unix sockets) since they are always passed known values.
"""
transp, proto = await self._create_ws_connection(lambda: RawSocketProtocol(self))
transp, proto = await self._create_ws_connection(
lambda: RawSocketProtocol(self)
)
self.ws_transp = transp
self.ws_proto = proto
self._record_activity()
@@ -66,8 +76,10 @@ def on_message(self, message):
"Send websocket messages as stream writes, encoding if necessary."
self._record_activity()
if isinstance(message, str):
message = message.encode('utf-8')
self.ws_transp.write(message) # buffered non-blocking. should block (needs new enough tornado)
message = message.encode("utf-8")
self.ws_transp.write(
message
) # buffered non-blocking. should block (needs new enough tornado)

def on_ping(self, message):
"No-op"
@@ -79,6 +91,7 @@ def on_close(self):
if hasattr(self, "ws_transp"):
self.ws_transp.close()


class SuperviseAndRawSocketHandler(SuperviseAndProxyHandler, RawSocketHandler):
async def _http_ready_func(self, p):
# not really HTTP here, just try an empty connection
36 changes: 36 additions & 0 deletions tests/resources/eventstream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import asyncio

import tornado.escape
import tornado.ioloop
import tornado.options
import tornado.web
import tornado.websocket
from tornado.options import define, options


class Application(tornado.web.Application):
def __init__(self):
handlers = [
(r"/stream/(\d+)", StreamHandler),
]
super().__init__(handlers)


class StreamHandler(tornado.web.RequestHandler):
async def get(self, seconds):
for i in range(int(seconds)):
await asyncio.sleep(0.5)
self.write(f"data: {i}\n\n")
await self.flush()


def main():
define("port", default=8888, help="run on the given port", type=int)
options.parse_command_line()
app = Application()
app.listen(options.port)
tornado.ioloop.IOLoop.current().start()


if __name__ == "__main__":
main()
13 changes: 8 additions & 5 deletions tests/resources/jupyter_server_config.py
Original file line number Diff line number Diff line change
@@ -42,10 +42,10 @@ def cats_only(response, path):
response.code = 403
response.body = b"dogs not allowed"


def my_env():
return {
"MYVAR": "String with escaped {{var}}"
}
return {"MYVAR": "String with escaped {{var}}"}


c.ServerProxy.servers = {
"python-http": {
@@ -79,6 +79,9 @@ def my_env():
"X-Custom-Header": "pytest-23456",
},
},
"python-eventstream": {
"command": [sys.executable, "./tests/resources/eventstream.py", "--port={port}"]
},
"python-unix-socket-true": {
"command": [
sys.executable,
@@ -129,12 +132,12 @@ def my_env():
"python-proxyto54321-no-command": {"port": 54321},
"python-rawsocket-tcp": {
"command": [sys.executable, "./tests/resources/rawsocket.py", "{port}"],
"raw_socket_proxy": True
"raw_socket_proxy": True,
},
"python-rawsocket-unix": {
"command": [sys.executable, "./tests/resources/rawsocket.py", "{unix_socket}"],
"unix_socket": True,
"raw_socket_proxy": True
"raw_socket_proxy": True,
},
}

3 changes: 1 addition & 2 deletions tests/resources/rawsocket.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python

import os
import socket
import sys

@@ -11,7 +10,7 @@
try:
port = int(where)
family = socket.AF_INET
addr = ('localhost', port)
addr = ("localhost", port)
except ValueError:
family = socket.AF_UNIX
addr = where
Loading