Skip to content

Commit

Permalink
fix(build-download): support regular HF download not just cloud cache
Browse files Browse the repository at this point in the history
  • Loading branch information
gadicc committed Jan 5, 2023
1 parent 8248ba0 commit 52edf6b
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@


torch.set_grad_enabled(False)
always_normalize_model_id = None


class DummySafetyChecker:
Expand Down Expand Up @@ -71,7 +72,6 @@ def init():
last_model_id = None

if not RUNTIME_DOWNLOADS:
# Uh doesn't this break non-cached images? TODO... IMAGE_CACHE
normalized_model_id = normalize_model_id(MODEL_ID, MODEL_REVISION)
model_dir = os.path.join(MODELS_DIR, normalized_model_id)
if os.path.isdir(model_dir):
Expand All @@ -80,7 +80,7 @@ def init():
normalized_model_id = MODEL_ID

model = loadModel(
model_id = model_dir,
model_id=always_normalize_model_id or MODEL_ID,
load=True,
precision=MODEL_PRECISION,
revision=MODEL_REVISION,
Expand Down Expand Up @@ -187,7 +187,9 @@ def inference(all_inputs: dict) -> dict:
clearPipelines()
if model:
model.to("cpu") # Necessary to avoid a memory leak
model = loadModel(model_id=normalized_model_id, load=True, precision=model_precision)
model = loadModel(
model_id=normalized_model_id, load=True, precision=model_precision
)
last_model_id = normalized_model_id
else:
if always_normalize_model_id:
Expand Down

0 comments on commit 52edf6b

Please sign in to comment.