-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathmain.py
111 lines (80 loc) · 3.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import logging
import os
from contextlib import asynccontextmanager
from typing import Any

from app.routes import health
from fastapi import FastAPI
from fastapi.routing import APIRoute
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run application startup before the ``yield`` and shutdown after it.

    Startup, in order: configure logging, register the health router, read
    the ``PIPELINE`` and ``MODEL_ID`` environment variables, load the selected
    pipeline onto ``app.pipeline``, register that pipeline's router, and set
    every API route's operation id to its route name.

    Args:
        app: The FastAPI application being started.

    Raises:
        KeyError: If ``PIPELINE`` or ``MODEL_ID`` is missing from the
            environment.
    """
    config_logging()
    app.include_router(health.router)

    # Both variables are required; indexing (not .get) makes a missing one
    # fail startup immediately. PIPELINE is read first, then MODEL_ID.
    env = os.environ
    pipeline_name, model = env["PIPELINE"], env["MODEL_ID"]

    app.pipeline = load_pipeline(pipeline_name, model)
    pipeline_router = load_route(pipeline_name)
    app.include_router(pipeline_router)

    # Runs after both routers are included so their routes are covered.
    use_route_names_as_operation_ids(app)

    logger.info(f"Started up with pipeline {app.pipeline}")
    yield
    # Reached only when the context is resumed normally after the yield.
    logger.info("Shutting down")
def load_pipeline(pipeline: str, model_id: str) -> any:
match pipeline:
case "text-to-image":
from app.pipelines.text_to_image import TextToImagePipeline
return TextToImagePipeline(model_id)
case "image-to-image":
from app.pipelines.image_to_image import ImageToImagePipeline
return ImageToImagePipeline(model_id)
case "image-to-video":
from app.pipelines.image_to_video import ImageToVideoPipeline
return ImageToVideoPipeline(model_id)
case "audio-to-text":
from app.pipelines.audio_to_text import AudioToTextPipeline
return AudioToTextPipeline(model_id)
case "frame-interpolation":
raise NotImplementedError("frame-interpolation pipeline not implemented")
case "upscale":
from app.pipelines.upscale import UpscalePipeline
return UpscalePipeline(model_id)
case "segment-anything-2":
from app.pipelines.segment_anything_2 import SegmentAnything2Pipeline
return SegmentAnything2Pipeline(model_id)
case _:
raise EnvironmentError(
f"{pipeline} is not a valid pipeline for model {model_id}"
)
def load_route(pipeline: str) -> any:
match pipeline:
case "text-to-image":
from app.routes import text_to_image
return text_to_image.router
case "image-to-image":
from app.routes import image_to_image
return image_to_image.router
case "image-to-video":
from app.routes import image_to_video
return image_to_video.router
case "audio-to-text":
from app.routes import audio_to_text
return audio_to_text.router
case "frame-interpolation":
raise NotImplementedError("frame-interpolation pipeline not implemented")
case "upscale":
from app.routes import upscale
return upscale.router
case "segment-anything-2":
from app.routes import segment_anything_2
return segment_anything_2.router
case _:
raise EnvironmentError(f"{pipeline} is not a valid pipeline")
def config_logging():
    """Configure root logging at INFO level with a timestamped format.

    Each record is rendered as ``time - logger name - level - message``.
    ``force=True`` is passed so the configuration is applied even if the
    root logger already has handlers.
    """
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format, force=True)
def use_route_names_as_operation_ids(app: FastAPI) -> None:
    """Set each API route's ``operation_id`` to the route's ``name``.

    Only entries of ``app.routes`` that are ``APIRoute`` instances are
    modified; other route objects are skipped. Routes added to ``app`` after
    this call are not affected.

    Args:
        app: The application whose registered routes are updated in place.
    """
    api_routes = (entry for entry in app.routes if isinstance(entry, APIRoute))
    for api_route in api_routes:
        api_route.operation_id = api_route.name
# Application instance; startup and shutdown work is handled by `lifespan`.
app = FastAPI(lifespan=lifespan)