Skip to content

Commit

Permalink
feat: streaming updates python client [MD-246] (#8778)
Browse files Browse the repository at this point in the history
This includes a new code generator.

Our existing code generation is proto -> openapi -> tweaks -> python,
with an alternate proto -> go path for the server code.

This is more direct: it's go -> python. The Go code is the source of
truth, and we can write the generator in Go to use Go's ast package.
  • Loading branch information
rb-determined-ai authored Mar 1, 2024
1 parent 2dfc4f2 commit 592a566
Show file tree
Hide file tree
Showing 14 changed files with 1,551 additions and 92 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ proto/pkg/**/* -diff -merge linguist-generated=true
master/pkg/schemas/expconf/zgen_* -diff -merge linguist-generated=true
webui/react/src/services/api-ts-sdk/**/* -diff -merge linguist-generated=true
harness/determined/common/api/bindings.py -diff -merge linguist-generated=true
harness/determined/common/streams/wire.py -diff -merge linguist-generated=true
docs/swagger-ui/swagger-ui*js* -diff -merge
docs/swagger-ui/swagger-ui-main* diff merge
1 change: 1 addition & 0 deletions harness/.flake8
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ exclude =
build,
dist,
_gen.py,
wire.py,
tests/experiment/fixtures/ancient-checkpoints/

# We ignore F401 in __init__.py because it is expected for there to be
Expand Down
83 changes: 5 additions & 78 deletions harness/determined/cli/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,15 @@
import os
import socket
import socketserver
import ssl
import sys
import threading
import time
import urllib.request
from dataclasses import dataclass
from typing import Iterator, List, Optional
from urllib import parse

import lomond

from determined.common import api
from determined.common import api, detlomond
from determined.common.api import bindings, certs


Expand All @@ -27,32 +24,6 @@ class ListenerConfig:
local_addr: str = "0.0.0.0"


class CustomSSLWebsocketSession(lomond.session.WebsocketSession):  # type: ignore
    """
    A lomond websocket session whose TLS verification is configurable.

    The SSL context is built from an optional certs.Cert:
      - no cert: the default SSL context is used unchanged.
      - cert.bundle is False: hostname checking and certificate verification are disabled.
      - cert.bundle is a string: it is loaded as an additional CA file.
    If the cert carries a name, that name is used as the TLS server hostname instead of the
    host being connected to.
    """

    def __init__(self, socket: lomond.WebSocket, cert: Optional[certs.Cert]) -> None:
        super().__init__(socket)
        self.ctx = ssl.create_default_context()

        if not cert:
            # Nothing to customize; keep the default context and the real hostname.
            self.cert_name = None
            return

        self.cert_name = cert.name
        ca_bundle = cert.bundle

        if ca_bundle is False:
            # Verification explicitly turned off.  check_hostname must be cleared before
            # verify_mode can be set to CERT_NONE.
            self.ctx.check_hostname = False
            self.ctx.verify_mode = ssl.CERT_NONE
        elif ca_bundle is not None:
            assert isinstance(ca_bundle, str)
            self.ctx.load_verify_locations(cafile=ca_bundle)

    def _wrap_socket(self, sock: socket.SocketType, host: str) -> socket.SocketType:
        # Prefer the configured certificate name over the actual host for SNI/verification.
        server_hostname = self.cert_name or host
        return self.ctx.wrap_socket(sock, server_hostname=server_hostname)


def copy_to_websocket(
ws: lomond.WebSocket, f: io.RawIOBase, ready_sem: threading.Semaphore
) -> None:
Expand Down Expand Up @@ -92,10 +63,7 @@ def copy_from_websocket(
cert: Optional[certs.Cert],
) -> None:
try:
for event in ws.connect(
ping_rate=0,
session_class=lambda socket: CustomSSLWebsocketSession(socket, cert),
):
for event in ws.connect(ping_rate=0):
if isinstance(event, lomond.events.Binary):
f.write(event.data)
elif isinstance(event, lomond.events.Ready):
Expand All @@ -118,10 +86,7 @@ def copy_from_websocket2(
cert: Optional[certs.Cert],
) -> None:
try:
for event in ws.connect(
ping_rate=0,
session_class=lambda socket: CustomSSLWebsocketSession(socket, cert),
):
for event in ws.connect(ping_rate=0):
if isinstance(event, lomond.events.Binary):
f.send(event.data)
elif isinstance(event, lomond.events.Ready):
Expand All @@ -139,36 +104,8 @@ def copy_from_websocket2(
f.close()


def maybe_upgrade_ws_scheme(master_address: str) -> str:
    """
    Convert an http(s) URL into the matching ws(s) URL.

    "https" becomes "wss" and "http" becomes "ws"; an address with any other scheme is
    returned exactly as it was given.
    """
    ws_scheme_for = {"https": "wss", "http": "ws"}

    url_parts = parse.urlparse(master_address)
    ws_scheme = ws_scheme_for.get(url_parts.scheme)
    if ws_scheme is None:
        return master_address
    return url_parts._replace(scheme=ws_scheme).geturl()


def http_connect_tunnel(sess: api.BaseSession, service: str) -> None:
parsed_master = parse.urlparse(sess.master)
assert parsed_master.hostname is not None, f"Failed to parse master address: {sess.master}"

# The "lomond.WebSocket()" function does not honor the "no_proxy" or
# "NO_PROXY" environment variables. To work around that, we check if
# the hostname is in the "no_proxy" or "NO_PROXY" environment variables
# ourselves using the "proxy_bypass()" function, which checks the "no_proxy"
# and "NO_PROXY" environment variables, and returns True if the host does
# not require a proxy server. The "lomond.WebSocket()" function will disable
# the proxy if the "proxies" parameter is an empty dictionary. Otherwise,
# if the "proxies" parameter is "None", it will honor the "HTTP_PROXY" and
# "HTTPS_PROXY" environment variables. When the "proxies" parameter is not
# specified, the default value is "None".
proxies = {} if urllib.request.proxy_bypass(parsed_master.hostname) else None # type: ignore

url = f"{sess.master}/proxy/{service}/"
ws = lomond.WebSocket(maybe_upgrade_ws_scheme(url), proxies=proxies)
if isinstance(sess, api.Session):
ws.add_header(b"Authorization", f"Bearer {sess.token}".encode())
ws = detlomond.WebSocket(sess, f"proxy/{service}/")

# We can't send data to the WebSocket before the connection becomes ready, which takes a bit of
# time; this semaphore lets the sending thread wait for that to happen.
Expand Down Expand Up @@ -197,20 +134,10 @@ def _http_tunnel_listener(
sess: api.BaseSession,
tunnel: ListenerConfig,
) -> socketserver.ThreadingTCPServer:
parsed_master = parse.urlparse(sess.master)
assert parsed_master.hostname is not None, f"Failed to parse master address: {sess.master}"

url = f"{sess.master}/proxy/{tunnel.service_id}/"

class TunnelHandler(socketserver.BaseRequestHandler):
def handle(self) -> None:
proxies = (
{} if urllib.request.proxy_bypass(parsed_master.hostname) else None # type: ignore
)
ws = detlomond.WebSocket(sess, f"proxy/{tunnel.service_id}/")

ws = lomond.WebSocket(maybe_upgrade_ws_scheme(url), proxies=proxies)
if isinstance(sess, api.Session):
ws.add_header(b"Authorization", f"Bearer {sess.token}".encode())
# We can't send data to the WebSocket before the connection becomes ready,
# which takes a bit of time; this semaphore lets the sending thread
# wait for that to happen.
Expand Down
84 changes: 84 additions & 0 deletions harness/determined/common/detlomond.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
detlomond contains helpers for using the lomond websocket library with Determined.
"""

import socket
import ssl
from typing import Any, Optional
from urllib import parse, request

import lomond

from determined.common import api
from determined.common.api import certs


class CustomSSLWebsocketSession(lomond.session.WebsocketSession):  # type: ignore
    """
    A lomond websocket session whose TLS verification is configurable.

    The SSL context is built from an optional certs.Cert:
      - no cert: the default SSL context is used unchanged.
      - cert.bundle is False: hostname checking and certificate verification are disabled.
      - cert.bundle is a string: it is loaded as an additional CA file.
    If the cert carries a name, that name is used as the TLS server hostname instead of the
    host being connected to.
    """

    def __init__(self, socket: lomond.WebSocket, cert: Optional[certs.Cert]) -> None:
        super().__init__(socket)
        self.ctx = ssl.create_default_context()

        if not cert:
            # Nothing to customize; keep the default context and the real hostname.
            self.cert_name = None
            return

        self.cert_name = cert.name
        ca_bundle = cert.bundle

        if ca_bundle is False:
            # Verification explicitly turned off.  check_hostname must be cleared before
            # verify_mode can be set to CERT_NONE.
            self.ctx.check_hostname = False
            self.ctx.verify_mode = ssl.CERT_NONE
        elif ca_bundle is not None:
            assert isinstance(ca_bundle, str)
            self.ctx.load_verify_locations(cafile=ca_bundle)

    def _wrap_socket(self, sock: socket.SocketType, host: str) -> socket.SocketType:
        # Prefer the configured certificate name over the actual host for SNI/verification.
        server_hostname = self.cert_name or host
        return self.ctx.wrap_socket(sock, server_hostname=server_hostname)


class WebSocket(lomond.WebSocket):  # type: ignore
    """
    WebSocket extends lomond.WebSocket with Determined-specific features:
      - support for NO_PROXY
      - our custom TLS verification
      - automatic authentication
    """

    def __init__(self, sess: api.BaseSession, path: str, **kwargs: Any):
        """
        Build a websocket pointed at ``{sess.master}/{path}`` with http(s) swapped for ws(s).

        Args:
            sess: the session supplying the master address and cert; when it is an
                api.Session, its token is sent as a Bearer Authorization header.
            path: path appended to the master address after a single "/".
            **kwargs: forwarded to lomond.WebSocket.  "proxies" is always supplied here, so
                passing it in kwargs raises a TypeError.

        Raises:
            AssertionError: if sess.master does not start with "http" or has no hostname.
        """
        # Validate the master address up front, so a bad address fails with a clear message
        # rather than an obscure error from the proxy lookup below.
        assert sess.master.startswith("http"), f"unable to convert non-http url ({sess.master})"
        parsed = parse.urlparse(sess.master)
        assert parsed.hostname is not None, f"Failed to parse master address: {sess.master}"

        # The "lomond.WebSocket()" function does not honor the "no_proxy" or
        # "NO_PROXY" environment variables. To work around that, we check if
        # the hostname is in the "no_proxy" or "NO_PROXY" environment variables
        # ourselves using the "proxy_bypass()" function, which checks the "no_proxy"
        # and "NO_PROXY" environment variables, and returns True if the host does
        # not require a proxy server. The "lomond.WebSocket()" function will disable
        # the proxy if the "proxies" parameter is an empty dictionary. Otherwise,
        # if the "proxies" parameter is "None", it will honor the "HTTP_PROXY" and
        # "HTTPS_PROXY" environment variables. When the "proxies" parameter is not
        # specified, the default value is "None".
        proxies = {} if request.proxy_bypass(parsed.hostname) else None  # type: ignore

        # Prepare a session_class for the eventual .connect() method.
        self._default_session_class = lambda ws: CustomSSLWebsocketSession(ws, sess.cert)

        # Replace http with ws for a ws:// or wss:// url ("http://x" -> "ws://x",
        # "https://x" -> "wss://x").
        baseurl = sess.master[4:]
        super().__init__(f"ws{baseurl}/{path}", proxies=proxies, **kwargs)

        # Possibly include authorization headers.
        if isinstance(sess, api.Session):
            self.add_header(b"Authorization", f"Bearer {sess.token}".encode())

    def connect(
        self,
        session_class: Optional[lomond.session.WebsocketSession] = None,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """
        Connect, using the TLS-aware session prepared in __init__ unless one is given.

        Args:
            session_class: called with this websocket to build the session; when falsy, the
                default session factory built from the Determined session's cert is used.
            *args: forwarded to lomond.WebSocket.connect().
            **kwargs: forwarded to lomond.WebSocket.connect().

        Returns:
            Whatever lomond.WebSocket.connect() returns (callers in this codebase iterate
            over it to receive events).
        """
        if not session_class:
            session_class = self._default_session_class
        return super().connect(session_class, *args, **kwargs)
9 changes: 9 additions & 0 deletions harness/determined/common/streams/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from determined.common.streams import wire
from determined.common.streams._util import range_encoded_keys
from determined.common.streams._client import (
StreamWebSocket,
LomondStreamWebSocket,
Stream,
Sync,
ProjectSpec,
)
Loading

0 comments on commit 592a566

Please sign in to comment.