Skip to content

Commit

Permalink
fix(misc): fix failing tests, pipeline init in rare circumstances
Browse files Browse the repository at this point in the history
  • Loading branch information
gadicc committed Sep 4, 2023
1 parent 3f1f980 commit 9338648
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
13 changes: 9 additions & 4 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
import skimage
import skimage.measure
from getScheduler import getScheduler, SCHEDULERS
from getPipeline import getPipelineForModel, listAvailablePipelines, clearPipelines
from getPipeline import (
getPipelineClass,
getPipelineForModel,
listAvailablePipelines,
clearPipelines,
)
import re
import requests
from download import download_model, normalize_model_id
Expand Down Expand Up @@ -228,7 +233,7 @@ def sendStatus():
model_dir = os.path.join(MODELS_DIR, normalized_model_id)
pipeline_name = call_inputs.get("PIPELINE", None)
if pipeline_name:
pipeline_class = getattr(diffusers_pipelines, pipeline_name)
pipeline_class = getPipelineClass(pipeline_name)
if last_model_id != normalized_model_id:
# if not downloaded_models.get(normalized_model_id, None):
if not os.path.isdir(model_dir):
Expand All @@ -250,7 +255,7 @@ def sendStatus():
hf_model_id=hf_model_id,
model_precision=model_precision,
send_opts=send_opts,
pipeline_class=pipeline_class,
pipeline_class=pipeline_class if pipeline_name else None,
)
# downloaded_models.update({normalized_model_id: True})
clearPipelines()
Expand All @@ -267,7 +272,7 @@ def sendStatus():
precision=model_precision,
revision=model_revision,
send_opts=send_opts,
pipeline_class=pipeline_class,
pipeline_class=pipeline_class if pipeline_name else None,
)
await send(
"loadModel", "done", {"startRequestId": startRequestId}, send_opts
Expand Down
7 changes: 7 additions & 0 deletions api/getPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ def clearPipelines():
_pipelines = {}


def getPipelineClass(pipeline_name: str):
if hasattr(diffusers_pipelines, pipeline_name):
return getattr(diffusers_pipelines, pipeline_name)
elif pipeline_name in availableCommunityPipelines():
return DiffusionPipeline


def getPipelineForModel(
pipeline_name: str, model, model_id, model_revision, model_precision
):
Expand Down
13 changes: 9 additions & 4 deletions api/loadModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def loadModel(
precision=None,
revision=None,
send_opts={},
pipeline_class=AutoPipelineForText2Image,
pipeline_class=None,
):
torch_dtype = torch_dtype_from_precision(precision)
if revision == "":
Expand All @@ -44,18 +44,23 @@ def loadModel(
"load": load,
"precision": precision,
"revision": revision,
"pipeline_class": pipeline_class,
},
)

if not pipeline_class:
pipeline_class = AutoPipelineForText2Image

pipeline = pipeline_class if PIPELINE == "ALL" else getattr(_pipelines, PIPELINE)
print("pipeline", pipeline_class)

print(
("Loading" if load else "Downloading")
+ " model: "
+ model_id
+ (f" ({revision})" if revision else "")
)

pipeline = pipeline_class if PIPELINE == "ALL" else getattr(_pipelines, PIPELINE)
print("pipeline", pipeline_class)

scheduler = getScheduler(model_id, DEFAULT_SCHEDULER, not load)

model_dir = os.path.join(MODELS_DIR, model_id)
Expand Down

0 comments on commit 9338648

Please sign in to comment.