improve shutdown
mertalev committed Dec 13, 2023
1 parent 679b22f commit ce803e7
Showing 4 changed files with 81 additions and 48 deletions.
29 changes: 25 additions & 4 deletions machine-learning/app/config.py
@@ -1,13 +1,18 @@
 import logging
 import os
 from pathlib import Path
+from socket import socket
+import sys
 
-import gunicorn
 import starlette
 from pydantic import BaseSettings
 from rich.console import Console
 from rich.logging import RichHandler
 
+from gunicorn.arbiter import Arbiter
+from uvicorn import Server
+from uvicorn.workers import UvicornWorker
+
 from .schemas import ModelType


@@ -69,10 +74,26 @@ def get_hf_model_name(model_name: str) -> str:
 class CustomRichHandler(RichHandler):
     def __init__(self) -> None:
         console = Console(color_system="standard", no_color=log_settings.no_color)
-        super().__init__(
-            show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[gunicorn, starlette]
-        )
+        super().__init__(show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[starlette])
 
 
 log = logging.getLogger("gunicorn.access")
 log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO))
+
+
+# patches this issue https://github.com/encode/uvicorn/discussions/1803
+class CustomUvicornServer(Server):
+    async def shutdown(self, sockets: list[socket] | None = None) -> None:
+        for sock in sockets or []:
+            sock.close()
+        await super().shutdown()
+
+
+class CustomUvicornWorker(UvicornWorker):
+    async def _serve(self) -> None:
+        self.config.app = self.wsgi
+        server = CustomUvicornServer(config=self.config)
+        self._install_sigquit_handler()
+        await server.serve(sockets=self.sockets)
+        if not server.started:
+            sys.exit(Arbiter.WORKER_BOOT_ERROR)
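
The override above closes the listening sockets gunicorn handed to the worker before running uvicorn's normal shutdown, so a stopping worker no longer accepts new connections while in-flight requests drain. Below is a standalone sketch of the same close-the-listener-first pattern in plain asyncio; it involves neither uvicorn nor gunicorn, and the port number is arbitrary.

# Sketch only: plain asyncio, not this project's uvicorn/gunicorn setup.
import asyncio


async def echo(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
    writer.write(await reader.read(1024))
    await writer.drain()
    writer.close()
    await writer.wait_closed()


async def main() -> None:
    server = await asyncio.start_server(echo, "127.0.0.1", 8888)
    await asyncio.sleep(1.0)  # stand-in for the serving period
    server.close()  # close the listening sockets first: no new connections
    await server.wait_closed()  # then let already-accepted connections finish


asyncio.run(main())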
9 changes: 4 additions & 5 deletions machine-learning/app/conftest.py
@@ -1,5 +1,4 @@
 import json
-from pathlib import Path
 from typing import Any, Iterator
 from unittest import mock

@@ -8,7 +7,7 @@
 from fastapi.testclient import TestClient
 from PIL import Image
 
-from .main import app, init_state
+from .main import app
 from .schemas import ndarray_f32
 

@@ -29,9 +28,9 @@ def mock_get_model() -> Iterator[mock.Mock]:


 @pytest.fixture(scope="session")
-def deployed_app() -> TestClient:
-    init_state()
-    return TestClient(app)
+def deployed_app() -> Iterator[TestClient]:
+    with TestClient(app) as client:
+        yield client
 
 
 @pytest.fixture(scope="session")
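The fixture now relies on TestClient's context-manager behavior: entering the with block runs the app's startup handlers and leaving it runs the shutdown handlers, which is why the manual init_state() call could be dropped. A minimal sketch with a toy app (not this repo's) showing that lifecycle:

# Sketch only: a toy app demonstrating TestClient's lifespan handling.
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()
events: list[str] = []


@app.on_event("startup")
def on_startup() -> None:
    events.append("startup")


@app.on_event("shutdown")
def on_shutdown() -> None:
    events.append("shutdown")


with TestClient(app) as client:
    assert events == ["startup"]  # startup ran when the block was entered
assert events == ["startup", "shutdown"]  # shutdown ran on exit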
85 changes: 48 additions & 37 deletions machine-learning/app/main.py
@@ -1,15 +1,16 @@
 import asyncio
 import gc
 import os
+import signal
 import sys
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any
+from typing import Any, Iterator
 from zipfile import BadZipFile
 
 import orjson
-from fastapi import FastAPI, Form, HTTPException, UploadFile
+from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile
 from fastapi.responses import ORJSONResponse
 from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
 from starlette.formparsers import MultiPartParser
@@ -27,27 +28,47 @@
 MultiPartParser.max_file_size = 2**26  # spools to disk if payload is 64 MiB or larger
 app = FastAPI()
 
+model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+thread_pool: ThreadPoolExecutor | None = None
+lock = threading.Lock()
+active_requests = 0
+last_called: float | None = None
 
-def init_state() -> None:
-    app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+
+@app.on_event("startup")
+def startup() -> None:
+    global thread_pool
     log.info(
         (
             "Created in-memory cache with unloading "
             f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
         )
     )
     # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
-    app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
-    app.state.lock = threading.Lock()
-    app.state.last_called = None
+    thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
     if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
         asyncio.ensure_future(idle_shutdown_task())
     log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
 
 
-@app.on_event("startup")
-async def startup_event() -> None:
-    init_state()
+@app.on_event("shutdown")
+def shutdown() -> None:
+    log.handlers.clear()
+    for model in model_cache.cache._cache.values():
+        del model
+    if thread_pool is not None:
+        thread_pool.shutdown()
+    gc.collect()
+
+
+def update_state() -> Iterator[None]:
+    global active_requests, last_called
+    active_requests += 1
+    last_called = time.time()
+    try:
+        yield
+    finally:
+        active_requests -= 1
 
 
 @app.get("/", response_model=MessageResponse)
@@ -60,7 +81,7 @@ def ping() -> str:
     return "pong"
 
 
-@app.post("/predict")
+@app.post("/predict", dependencies=[Depends(update_state)])
 async def predict(
     model_name: str = Form(alias="modelName"),
     model_type: ModelType = Form(alias="modelType"),
@@ -79,33 +100,32 @@ async def predict(
     except orjson.JSONDecodeError:
         raise HTTPException(400, f"Invalid options JSON: {options}")
 
-    model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
+    model = await load(await model_cache.get(model_name, model_type, **kwargs))
     model.configure(**kwargs)
     outputs = await run(model, inputs)
     return ORJSONResponse(outputs)
 
 
 async def run(model: InferenceModel, inputs: Any) -> Any:
-    app.state.last_called = time.time()
-    if app.state.thread_pool is None:
+    if thread_pool is None:
         return model.predict(inputs)
-    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
+    return await asyncio.get_running_loop().run_in_executor(thread_pool, model.predict, inputs)
 
 
 async def load(model: InferenceModel) -> InferenceModel:
     if model.loaded:
         return model
 
     def _load() -> None:
-        with app.state.lock:
+        with lock:
             model.load()
 
     loop = asyncio.get_running_loop()
     try:
-        if app.state.thread_pool is None:
+        if thread_pool is None:
             model.load()
         else:
-            await loop.run_in_executor(app.state.thread_pool, _load)
+            await loop.run_in_executor(thread_pool, _load)
         return model
     except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
         log.warn(
@@ -115,32 +135,23 @@ def _load() -> None:
             )
         )
         model.clear_cache()
-        if app.state.thread_pool is None:
+        if thread_pool is None:
             model.load()
         else:
-            await loop.run_in_executor(app.state.thread_pool, _load)
+            await loop.run_in_executor(thread_pool, _load)
         return model
 
 
 async def idle_shutdown_task() -> None:
     while True:
         log.debug("Checking for inactivity...")
-        if app.state.last_called is not None and time.time() - app.state.last_called > settings.model_ttl:
+        if (
+            last_called is not None
+            and not active_requests
+            and not lock.locked()
+            and time.time() - last_called > settings.model_ttl
+        ):
             log.info("Shutting down due to inactivity.")
-            loop = asyncio.get_running_loop()
-            for task in asyncio.all_tasks(loop):
-                if task is not asyncio.current_task():
-                    try:
-                        task.cancel()
-                    except asyncio.CancelledError:
-                        pass
-            sys.stderr.close()
-            sys.stdout.close()
-            sys.stdout = sys.stderr = open(os.devnull, "w")
-            try:
-                await app.state.model_cache.cache.clear()
-                gc.collect()
-                loop.stop()
-            except asyncio.CancelledError:
-                pass
+            os.kill(os.getpid(), signal.SIGINT)
+            break
         await asyncio.sleep(settings.model_ttl_poll_s)
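
The new idle check is guarded by the update_state dependency wired onto /predict above: FastAPI runs the code before the yield as a request arrives and the finally block once the response has gone out, so active_requests counts exactly the in-flight requests the shutdown check must wait for. A minimal sketch of the same yield-style dependency on a hypothetical /work endpoint:

# Sketch only: hypothetical endpoint showing the yield-style dependency.
import time
from typing import Iterator

from fastapi import Depends, FastAPI

app = FastAPI()
active_requests = 0
last_called: float | None = None


def update_state() -> Iterator[None]:
    global active_requests, last_called
    active_requests += 1  # runs before the endpoint body
    last_called = time.time()
    try:
        yield
    finally:
        active_requests -= 1  # runs after the response is sent


@app.get("/work", dependencies=[Depends(update_state)])
def work() -> dict[str, int]:
    return {"in_flight": active_requests}  # includes this request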
6 changes: 4 additions & 2 deletions machine-learning/start.sh
@@ -1,15 +1,17 @@
 #!/usr/bin/env sh
 
 export LD_PRELOAD="/usr/lib/$(arch)-linux-gnu/libmimalloc.so.2"
 export LD_BIND_NOW=1
 
 : "${MACHINE_LEARNING_HOST:=0.0.0.0}"
 : "${MACHINE_LEARNING_PORT:=3003}"
 : "${MACHINE_LEARNING_WORKERS:=1}"
+: "${MACHINE_LEARNING_WORKER_TIMEOUT:=120}"
 
 gunicorn app.main:app \
-    -k uvicorn.workers.UvicornWorker \
+    -k app.config.CustomUvicornWorker \
     -w $MACHINE_LEARNING_WORKERS \
     -b $MACHINE_LEARNING_HOST:$MACHINE_LEARNING_PORT \
     -t $MACHINE_LEARNING_WORKER_TIMEOUT \
-    --log-config-json log_conf.json
+    --log-config-json log_conf.json \
+    --graceful-timeout 0
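
Replacing the hand-rolled loop teardown with os.kill(os.getpid(), signal.SIGINT) means an idle worker now exits through the server's ordinary interrupt path, and --graceful-timeout 0 sets gunicorn's grace period for stopping workers to zero, so they are restarted or killed immediately instead of lingering. A standalone sketch (plain Python, no gunicorn) of a process interrupting itself the way idle_shutdown_task() does:

# Sketch only: a process sending SIGINT to itself, as idle_shutdown_task() does.
import os
import signal
from types import FrameType


def on_sigint(signum: int, frame: FrameType | None) -> None:
    print("SIGINT received, shutting down")
    raise SystemExit(0)


signal.signal(signal.SIGINT, on_sigint)
os.kill(os.getpid(), signal.SIGINT)  # handler runs on the main thread
print("never reached")  # SystemExit from the handler ends the process first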
