feat(ml): composable ml (#9973)
* modularize model classes

* various fixes

* expose port

* change response

* round coordinates

* simplify preload

* update server

* simplify interface

simplify

* update tests

* composable endpoint

* cleanup

fixes

remove unnecessary interface

support text input, cleanup

* ew camelcase

* update server

server fixes

fix typing

* ml fixes

update locustfile

fixes

* cleaner response

* better repo response

* update tests

formatting and typing

rename

* undo compose change

* linting

fix type

actually fix typing

* stricter typing

fix detection-only response

no need for defaultdict

* update spec file

update api

linting

* update e2e

* unnecessary dimension

* remove commented code

* remove duplicate code

* remove unused imports

* add batch dim
mertalev authored Jun 7, 2024
1 parent 7a46f80 commit 2b1b43a
Showing 39 changed files with 986 additions and 1,003 deletions.
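For orientation before the diffs: after this change, a single multipart request to /predict carries an entries form field describing a whole pipeline as task -> type -> model entry. A minimal client sketch under assumptions — the task/type string values ("clip", "facial-recognition", "detection", "recognition", "visual"), the model names, and the default ML port 3003 are illustrative, not confirmed by the diff:

import json

import requests  # any HTTP client works; requests is used here for brevity

entries = {
    "clip": {"visual": {"modelName": "ViT-B-32__openai"}},
    "facial-recognition": {
        "detection": {"modelName": "buffalo_l", "options": {"minScore": 0.7}},
        "recognition": {"modelName": "buffalo_l"},
    },
}

with open("photo.jpg", "rb") as f:
    response = requests.post(
        "http://localhost:3003/predict",
        data={"entries": json.dumps(entries)},
        files={"image": f},
    )

# One output per task, plus image dimensions for image inputs, e.g.
# {"clip": [...], "facial-recognition": [...], "imageHeight": ..., "imageWidth": ...}
# per run_inference in main.py below.
print(response.json())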
12 changes: 1 addition & 11 deletions machine-learning/app/config.py
@@ -12,16 +12,14 @@
 from uvicorn import Server
 from uvicorn.workers import UvicornWorker
 
-from .schemas import ModelType
-
 
 class PreloadModelData(BaseModel):
     clip: str | None
     facial_recognition: str | None
 
 
 class Settings(BaseSettings):
-    cache_folder: str = "/cache"
+    cache_folder: Path = Path("/cache")
     model_ttl: int = 300
     model_ttl_poll_s: int = 10
     host: str = "0.0.0.0"
@@ -55,14 +53,6 @@ def clean_name(model_name: str) -> str:
     return model_name.split("/")[-1].translate(_clean_name)
 
 
-def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
-    return Path(settings.cache_folder) / model_type.value / clean_name(model_name)
-
-
-def get_hf_model_name(model_name: str) -> str:
-    return f"immich-app/{clean_name(model_name)}"
-
-
 LOG_LEVELS: dict[str, int] = {
     "critical": logging.ERROR,
     "error": logging.ERROR,
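A side note on the type change above: with cache_folder now a Path, per-model cache locations can be composed with the / operator directly instead of wrapping in Path() as the removed get_cache_dir did. A minimal sketch; the directory layout shown is illustrative only:

from pathlib import Path

cache_folder = Path("/cache")
model_dir = cache_folder / "clip" / "ViT-B-32__openai"  # hypothetical layout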
117 changes: 90 additions & 27 deletions machine-learning/app/main.py
@@ -6,22 +6,34 @@
 import time
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager
+from functools import partial
 from typing import Any, AsyncGenerator, Callable, Iterator
 from zipfile import BadZipFile
 
 import orjson
-from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile
+from fastapi import Depends, FastAPI, File, Form, HTTPException
 from fastapi.responses import ORJSONResponse
 from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
+from PIL.Image import Image
+from pydantic import ValidationError
 from starlette.formparsers import MultiPartParser
 
+from app.models import get_model_deps
 from app.models.base import InferenceModel
+from app.models.transforms import decode_pil
 
 from .config import PreloadModelData, log, settings
 from .models.cache import ModelCache
 from .schemas import (
+    InferenceEntries,
+    InferenceEntry,
+    InferenceResponse,
     MessageResponse,
+    ModelIdentity,
+    ModelTask,
     ModelType,
+    PipelineRequest,
+    T,
     TextResponse,
 )
@@ -63,12 +75,21 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
         gc.collect()
 
 
-async def preload_models(preload_models: PreloadModelData) -> None:
-    log.info(f"Preloading models: {preload_models}")
-    if preload_models.clip is not None:
-        await load(await model_cache.get(preload_models.clip, ModelType.CLIP))
-    if preload_models.facial_recognition is not None:
-        await load(await model_cache.get(preload_models.facial_recognition, ModelType.FACIAL_RECOGNITION))
+async def preload_models(preload: PreloadModelData) -> None:
+    log.info(f"Preloading models: {preload}")
+    if preload.clip is not None:
+        model = await model_cache.get(preload.clip, ModelType.TEXTUAL, ModelTask.SEARCH)
+        await load(model)
+
+        model = await model_cache.get(preload.clip, ModelType.VISUAL, ModelTask.SEARCH)
+        await load(model)
+
+    if preload.facial_recognition is not None:
+        model = await model_cache.get(preload.facial_recognition, ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)
+        await load(model)
+
+        model = await model_cache.get(preload.facial_recognition, ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
+        await load(model)
 
 
 def update_state() -> Iterator[None]:
@@ -81,6 +102,27 @@ def update_state() -> Iterator[None]:
         active_requests -= 1
 
 
+def get_entries(entries: str = Form()) -> InferenceEntries:
+    try:
+        request: PipelineRequest = orjson.loads(entries)
+        without_deps: list[InferenceEntry] = []
+        with_deps: list[InferenceEntry] = []
+        for task, types in request.items():
+            for type, entry in types.items():
+                parsed: InferenceEntry = {
+                    "name": entry["modelName"],
+                    "task": task,
+                    "type": type,
+                    "options": entry.get("options", {}),
+                }
+                dep = get_model_deps(parsed["name"], type, task)
+                (with_deps if dep else without_deps).append(parsed)
+        return without_deps, with_deps
+    except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
+        log.error(f"Invalid request format: {e}")
+        raise HTTPException(422, "Invalid request format.")
+
+
 app = FastAPI(lifespan=lifespan)
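To make the split in get_entries concrete: it parses the entries form field and partitions entries by whether their model class declares dependencies. A sketch with assumed enum string values and a hypothetical model name:

import orjson  # the same parser the endpoint uses

entries_json = orjson.dumps(
    {
        "facial-recognition": {
            "detection": {"modelName": "buffalo_l", "options": {"minScore": 0.7}},
            "recognition": {"modelName": "buffalo_l"},
        }
    }
).decode()

# get_entries would return ([detection_entry], [recognition_entry]): the
# detector has no dependencies and runs in the first wave, while the
# recognizer depends on the detector's output and is deferred to the second.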
@@ -96,42 +138,63 @@ def ping() -> str:
 
 @app.post("/predict", dependencies=[Depends(update_state)])
 async def predict(
-    model_name: str = Form(alias="modelName"),
-    model_type: ModelType = Form(alias="modelType"),
-    options: str = Form(default="{}"),
+    entries: InferenceEntries = Depends(get_entries),
+    image: bytes | None = File(default=None),
     text: str | None = Form(default=None),
-    image: UploadFile | None = None,
 ) -> Any:
     if image is not None:
-        inputs: str | bytes = await image.read()
+        inputs: Image | str = await run(lambda: decode_pil(image))
     elif text is not None:
         inputs = text
     else:
         raise HTTPException(400, "Either image or text must be provided")
-    try:
-        kwargs = orjson.loads(options)
-    except orjson.JSONDecodeError:
-        raise HTTPException(400, f"Invalid options JSON: {options}")
-
-    model = await load(await model_cache.get(model_name, model_type, ttl=settings.model_ttl, **kwargs))
-    model.configure(**kwargs)
-    outputs = await run(model.predict, inputs)
-    return ORJSONResponse(outputs)
-
-
-async def run(func: Callable[..., Any], inputs: Any) -> Any:
+    response = await run_inference(inputs, entries)
+    return ORJSONResponse(response)
+
+
+async def run_inference(payload: Image | str, entries: InferenceEntries) -> InferenceResponse:
+    outputs: dict[ModelIdentity, Any] = {}
+    response: InferenceResponse = {}
+
+    async def _run_inference(entry: InferenceEntry) -> None:
+        model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl)
+        inputs = [payload]
+        for dep in model.depends:
+            try:
+                inputs.append(outputs[dep])
+            except KeyError:
+                message = f"Task {entry['task']} of type {entry['type']} depends on output of {dep}"
+                raise HTTPException(400, message)
+        model = await load(model)
+        output = await run(model.predict, *inputs, **entry["options"])
+        outputs[model.identity] = output
+        response[entry["task"]] = output
+
+    without_deps, with_deps = entries
+    await asyncio.gather(*[_run_inference(entry) for entry in without_deps])
+    if with_deps:
+        await asyncio.gather(*[_run_inference(entry) for entry in with_deps])
+    if isinstance(payload, Image):
+        response["imageHeight"], response["imageWidth"] = payload.height, payload.width
+
+    return response
+
+
+async def run(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
     if thread_pool is None:
-        return func(inputs)
-    return await asyncio.get_running_loop().run_in_executor(thread_pool, func, inputs)
+        return func(*args, **kwargs)
+    partial_func = partial(func, *args, **kwargs)
+    return await asyncio.get_running_loop().run_in_executor(thread_pool, partial_func)
 
 
 async def load(model: InferenceModel) -> InferenceModel:
     if model.loaded:
         return model
 
-    def _load(model: InferenceModel) -> None:
+    def _load(model: InferenceModel) -> InferenceModel:
         with lock:
             model.load()
+        return model
 
     try:
         await run(_load, model)
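The two asyncio.gather calls in run_inference give a two-wave schedule: models without dependencies run concurrently first, then dependent models consume the earlier outputs (keyed by model identity). A self-contained toy of that pattern with the model calls stubbed out — the names here are illustrative, not immich's API:

import asyncio
from typing import Any


async def detect(image: str) -> list[str]:
    return [f"face@{image}"]  # stand-in for the face detector


async def recognize(image: str, faces: list[str]) -> list[dict[str, Any]]:
    return [{"face": face, "embedding": [0.1, 0.2]} for face in faces]


async def pipeline(image: str) -> dict[str, Any]:
    outputs: dict[str, Any] = {}

    async def run_detection() -> None:
        outputs["detection"] = await detect(image)

    async def run_recognition() -> None:
        # consumes the first wave's output, mirroring the model.depends lookup
        outputs["recognition"] = await recognize(image, outputs["detection"])

    await asyncio.gather(run_detection())    # wave 1: no dependencies
    await asyncio.gather(run_recognition())  # wave 2: dependents
    return outputs


print(asyncio.run(pipeline("photo.jpg")))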
56 changes: 36 additions & 20 deletions machine-learning/app/models/__init__.py
@@ -1,24 +1,40 @@
 from typing import Any
 
-from app.schemas import ModelType
-
-from .base import InferenceModel
-from .clip import MCLIPEncoder, OpenCLIPEncoder
-from .constants import is_insightface, is_mclip, is_openclip
-from .facial_recognition import FaceRecognizer
-
-
-def from_model_type(model_type: ModelType, model_name: str, **model_kwargs: Any) -> InferenceModel:
-    match model_type:
-        case ModelType.CLIP:
-            if is_openclip(model_name):
-                return OpenCLIPEncoder(model_name, **model_kwargs)
-            elif is_mclip(model_name):
-                return MCLIPEncoder(model_name, **model_kwargs)
-        case ModelType.FACIAL_RECOGNITION:
-            if is_insightface(model_name):
-                return FaceRecognizer(model_name, **model_kwargs)
+from app.models.base import InferenceModel
+from app.models.clip.textual import MClipTextualEncoder, OpenClipTextualEncoder
+from app.models.clip.visual import OpenClipVisualEncoder
+from app.schemas import ModelSource, ModelTask, ModelType
+
+from .constants import get_model_source
+from .facial_recognition.detection import FaceDetector
+from .facial_recognition.recognition import FaceRecognizer
+
+
+def get_model_class(model_name: str, model_type: ModelType, model_task: ModelTask) -> type[InferenceModel]:
+    source = get_model_source(model_name)
+    match source, model_type, model_task:
+        case ModelSource.OPENCLIP | ModelSource.MCLIP, ModelType.VISUAL, ModelTask.SEARCH:
+            return OpenClipVisualEncoder
+
+        case ModelSource.OPENCLIP, ModelType.TEXTUAL, ModelTask.SEARCH:
+            return OpenClipTextualEncoder
+
+        case ModelSource.MCLIP, ModelType.TEXTUAL, ModelTask.SEARCH:
+            return MClipTextualEncoder
+
+        case ModelSource.INSIGHTFACE, ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION:
+            return FaceDetector
+
+        case ModelSource.INSIGHTFACE, ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION:
+            return FaceRecognizer
+
         case _:
-            raise ValueError(f"Unknown model type {model_type}")
-
-    raise ValueError(f"Unknown {model_type} model {model_name}")
+            raise ValueError(f"Unknown model combination: {source}, {model_type}, {model_task}")
+
+
+def from_model_type(model_name: str, model_type: ModelType, model_task: ModelTask, **kwargs: Any) -> InferenceModel:
+    return get_model_class(model_name, model_type, model_task)(model_name, **kwargs)
+
+
+def get_model_deps(model_name: str, model_type: ModelType, model_task: ModelTask) -> list[tuple[ModelType, ModelTask]]:
+    return get_model_class(model_name, model_type, model_task).depends
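For context, a sketch of how the reworked factory might be exercised directly — "buffalo_l" is assumed here to resolve to ModelSource.INSIGHTFACE, and cache/loading behavior is not shown:

from app.models import from_model_type, get_model_deps
from app.schemas import ModelTask, ModelType

# Resolved to FaceDetector by get_model_class's (source, type, task) match.
detector = from_model_type("buffalo_l", ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)

# Non-empty for the recognizer: it declares the detector's (type, task)
# identity, so the server schedules it in the second inference wave.
deps = get_model_deps("buffalo_l", ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)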
