Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor worker class #2651

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
40 changes: 20 additions & 20 deletions backend/src/server_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from events import EventQueue
from gpu import get_nvidia_helper
from server_config import ServerConfig
from server_process_helper import ExecutorServer
from server_process_helper import WorkerServer


class AppContext:
Expand Down Expand Up @@ -51,7 +51,7 @@ def filter(self, record): # noqa: ANN001
)


executor_server: ExecutorServer = ExecutorServer()
worker: WorkerServer = WorkerServer()

setup_task = None

Expand All @@ -60,39 +60,39 @@ def filter(self, record): # noqa: ANN001

@app.route("/nodes")
async def nodes(request: Request):
resp = await executor_server.proxy_request(request)
resp = await worker.proxy_request(request)
return resp


@app.route("/run", methods=["POST"])
async def run(request: Request):
return await executor_server.proxy_request(request, timeout=None)
return await worker.proxy_request(request, timeout=None)


@app.route("/run/individual", methods=["POST"])
async def run_individual(request: Request):
logger.info("Running individual")
return await executor_server.proxy_request(request)
return await worker.proxy_request(request)


@app.route("/clear-cache/individual", methods=["POST"])
async def clear_cache_individual(request: Request):
return await executor_server.proxy_request(request)
return await worker.proxy_request(request)


@app.route("/pause", methods=["POST"])
async def pause(request: Request):
return await executor_server.proxy_request(request)
return await worker.proxy_request(request)


@app.route("/resume", methods=["POST"])
async def resume(request: Request):
return await executor_server.proxy_request(request, timeout=None)
return await worker.proxy_request(request, timeout=None)


@app.route("/kill", methods=["POST"])
async def kill(request: Request):
return await executor_server.proxy_request(request)
return await worker.proxy_request(request)


@app.route("/python-info", methods=["GET"])
Expand Down Expand Up @@ -131,13 +131,13 @@ async def system_usage(_request: Request):

@app.route("/packages", methods=["GET"])
async def get_packages(request: Request):
return await executor_server.proxy_request(request)
return await worker.proxy_request(request)


@app.route("/installed-dependencies", methods=["GET"])
async def get_installed_dependencies(request: Request):
installed_deps: dict[str, str] = {}
packages = await executor_server.get_packages()
packages = await worker.get_packages()
for package in packages:
for pkg_dep in package.dependencies:
installed_version = installed_packages.get(pkg_dep.pypi_name, None)
Expand All @@ -149,7 +149,7 @@ async def get_installed_dependencies(request: Request):

@app.route("/features")
async def get_features(request: Request):
return await executor_server.proxy_request(request)
return await worker.proxy_request(request)


@app.get("/sse")
Expand All @@ -158,7 +158,7 @@ async def sse(request: Request):
response = await request.respond(headers=headers, content_type="text/event-stream")
while True:
try:
async for data in executor_server.get_sse(request):
async for data in worker.get_sse(request):
if response is not None:
await response.send(data)
except Exception:
Expand Down Expand Up @@ -196,7 +196,7 @@ async def install_deps(dependencies: list[api.Dependency]):
]
await install_dependencies(dep_info, update_progress_cb, logger)

packages = await executor_server.get_packages()
packages = await worker.get_packages()

logger.info("Checking dependencies...")

Expand Down Expand Up @@ -228,7 +228,7 @@ async def install_deps(dependencies: list[api.Dependency]):
if config.close_after_start:
flags.append("--close-after-start")

await executor_server.restart(flags)
await worker.restart(flags)
except Exception as ex:
logger.error(f"Error installing dependencies: {ex}", exc_info=True)
if config.close_after_start:
Expand Down Expand Up @@ -277,7 +277,7 @@ async def update_progress(
await update_progress("Loading Nodes...", 1.0, None)

# Wait to send backend-ready until nodes are loaded
await executor_server.wait_for_backend_ready()
await worker.wait_for_ready()

await setup_queue.put_and_wait(
{
Expand Down Expand Up @@ -305,26 +305,26 @@ async def close_server(sanic_app: Sanic):
except Exception as ex:
logger.error(f"Error waiting for server to start: {ex}")

await executor_server.stop()
await worker.stop()
sanic_app.stop()


@app.after_server_stop
async def after_server_stop(_sanic_app: Sanic, _loop: asyncio.AbstractEventLoop):
await executor_server.stop()
await worker.stop()
logger.info("Server closed.")


@app.after_server_start
async def after_server_start(sanic_app: Sanic, loop: asyncio.AbstractEventLoop):
global setup_task
await executor_server.start()
await worker.start()

# initialize the queues
ctx = AppContext.get(sanic_app)
ctx.setup_queue = EventQueue()

await executor_server.wait_for_server_start()
await worker.wait_for_ready()

# start the setup task
setup_task = loop.create_task(setup(sanic_app, loop))
Expand Down
132 changes: 65 additions & 67 deletions backend/src/server_process_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import subprocess
import sys
import threading
import time
from typing import Iterable

import aiohttp
from sanic import HTTPResponse, Request
Expand All @@ -14,104 +16,102 @@
from api import Package


def find_free_port():
def _find_free_port():
with socket.socket() as s:
s.bind(("", 0)) # Bind to a free port provided by the host.
return s.getsockname()[1] # Return the port number assigned.


class ExecutorServerWorker:
def __init__(self, port: int, flags: list[str] | None = None):
self.process = None
self.stop_event = threading.Event()
self.finished_starting = False
def _port_in_use(port: int):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("127.0.0.1", port)) == 0

self.port = port
self.flags = flags or []

def start_process(self):
class _WorkerProcess:
def __init__(self, flags: list[str]):
server_file = os.path.join(os.path.dirname(__file__), "server.py")
python_location = sys.executable
self.process = subprocess.Popen(
[python_location, server_file, str(self.port), *self.flags],

self._process = subprocess.Popen(
[python_location, server_file, *flags],
shell=False,
stdin=None,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
self._stop_event = threading.Event()

# Create a separate thread to read and print the output of the subprocess
threading.Thread(
target=self._read_output, daemon=True, name="output reader"
target=self._read_output,
daemon=True,
name="output reader",
).start()

def stop_process(self):
if self.process:
self.stop_event.set()
self.process.terminate()
self.process.kill()
def close(self):
self._stop_event.set()
self._process.terminate()
self._process.kill()

def _read_output(self):
if self.process is None or self.process.stdout is None:
if self._process.stdout is None:
return
for line in self.process.stdout:
if self.stop_event.is_set():
for line in self._process.stdout:
if self._stop_event.is_set():
break
if not self.finished_starting:
if "Starting worker" in line.decode():
self.finished_starting = True
print(line.decode().strip())


class ExecutorServer:
def __init__(self, flags: list[str] | None = None):
self.flags = flags

self.server_process = None
class WorkerServer:
def __init__(self):
self._process = None

self.port = find_free_port()
self.base_url = f"http://127.0.0.1:{self.port}"
self.session = None
self._port = _find_free_port()
self._base_url = f"http://127.0.0.1:{self._port}"
self._session = None

self.backend_ready = False

async def start(self, flags: list[str] | None = None):
del self.server_process
self.server_process = ExecutorServerWorker(self.port, flags or self.flags)
self.server_process.start_process()
self.session = aiohttp.ClientSession(base_url=self.base_url)
await self.wait_for_server_start()
await self.session.get("/nodes", timeout=None)
self.backend_ready = True
return self
async def start(self, flags: Iterable[str] = []):
logger.info("Starting worker process...")
self._process = _WorkerProcess([str(self._port), *flags])
self._session = aiohttp.ClientSession(base_url=self._base_url)
await self.wait_for_ready()
logger.info("Worker process started")

async def stop(self):
if self.server_process:
self.server_process.stop_process()
if self.session:
await self.session.close()
if self._process:
self._process.close()
if self._session:
await self._session.close()
logger.info("Worker process stopped")

async def restart(self, flags: list[str] | None = None):
async def restart(self, flags: Iterable[str] = []):
await self.stop()
await self.start(flags)

async def wait_for_server_start(self):
while (
self.server_process is None
or self.server_process.finished_starting is False
):
await asyncio.sleep(0.1)
async def wait_for_ready(self, timeout: float = 300):
start = time.time()
while time.time() - start < timeout:
if (
self._process is not None
and self._session is not None
and _port_in_use(self._port)
):
try:
await self._session.get("/nodes", timeout=5)
return
except Exception:
pass

async def wait_for_backend_ready(self):
while not self.backend_ready:
await asyncio.sleep(0.1)

raise TimeoutError("Server did not start in time")

async def proxy_request(self, request: Request, timeout: int | None = 300):
assert self.session is not None
await self.wait_for_server_start()
await self.wait_for_backend_ready()
await self.wait_for_ready()
assert self._session is not None
if request.route is None:
raise ValueError("Route not found")
async with self.session.request(
async with self._session.request(
request.method,
f"/{request.route.path}",
headers=request.headers,
Expand All @@ -129,10 +129,9 @@ async def proxy_request(self, request: Request, timeout: int | None = 300):
)

async def get_sse(self, request: Request):
assert self.session is not None
await self.wait_for_server_start()
await self.wait_for_backend_ready()
async with self.session.request(
await self.wait_for_ready()
assert self._session is not None
async with self._session.request(
request.method,
"/sse",
headers=request.headers,
Expand All @@ -143,11 +142,10 @@ async def get_sse(self, request: Request):
yield data

async def get_packages(self):
await self.wait_for_server_start()
await self.wait_for_backend_ready()
assert self.session is not None
await self.wait_for_ready()
assert self._session is not None
logger.debug("Fetching packages...")
packages_resp = await self.session.get(
packages_resp = await self._session.get(
"/packages", params={"hideInternal": "false"}
)
packages_json = await packages_resp.json()
Expand Down
Loading