Skip to content

Commit

Permalink
feat(downloads): allow separate MODEL_REVISION and MODEL_PRECISION
Browse files · Browse the repository at this point in the history
TODO: allow same for builds
  • Loading branch information
gadicc committed Jan 4, 2023
1 parent adaa7f6 commit 6edc821
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
4 changes: 3 additions & 1 deletion api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def inference(all_inputs: dict) -> dict:

if RUNTIME_DOWNLOADS:
hf_model_id = call_inputs.get("HF_MODEL_ID", None)
model_revision = call_inputs.get("MODEL_REVISION", None)
model_precision = call_inputs.get("MODEL_PRECISION", None)
checkpoint_url = call_inputs.get("CHECKPOINT_URL", None)
checkpoint_config_url = call_inputs.get("CHECKPOINT_CONFIG_URL", None)
Expand All @@ -171,10 +172,11 @@ def inference(all_inputs: dict) -> dict:
download_model(
model_id=model_id,
model_url=model_url,
model_revision=model_precision,
model_revision=model_revision or model_precision,
checkpoint_url=checkpoint_url,
checkpoint_config_url=checkpoint_config_url,
hf_model_id=hf_model_id,
model_precision=model_precision,
)
# downloaded_models.update({normalized_model_id: True})
clearPipelines()
Expand Down
11 changes: 6 additions & 5 deletions api/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def download_model(
checkpoint_url=None,
checkpoint_config_url=None,
hf_model_id=None,
model_precision=None,
):
print(
"download_model",
Expand All @@ -56,11 +57,11 @@ def download_model(
)
url = model_url or MODEL_URL
hf_model_id = hf_model_id or model_id
revision = model_revision or revision_from_precision()
model_revision = model_revision or revision_from_precision()
normalized_model_id = id

if url != "":
normalized_model_id = normalize_model_id(model_id, model_revision)
normalized_model_id = normalize_model_id(model_id, model_precision)
print({"normalized_model_id": normalized_model_id})
filename = url.split("/").pop()
if not filename:
Expand Down Expand Up @@ -97,16 +98,16 @@ def download_model(
)
else:
print("Does not exist, let's try find it on huggingface")
print("precision = ", {"model_revision": model_revision})
print({"model_precision": model_precision, "model_revision": model_revision})
# This would be quicker to just model.to("cuda") afterwards, but
# this conveniently logs all the timings (and doesn't happen often)
print("download")
send("download", "start", {})
model = loadModel(hf_model_id, False, precision=model_revision) # download
model = loadModel(hf_model_id, False, precision=model_precision, revision=model_revision) # download
send("download", "done", {})

print("load")
model = loadModel(hf_model_id, True, precision=model_revision) # load
model = loadModel(hf_model_id, True, precision=model_precision, revision=model_revision) # load
# dir = "models--" + model_id.replace("/", "--") + "--dda"
dir = os.path.join(MODELS_DIR, normalized_model_id)
model.save_pretrained(dir, safe_serialization=True)
Expand Down
6 changes: 3 additions & 3 deletions api/loadModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
]


def loadModel(model_id: str, load=True, precision=None):
print("loadModel", {"model_id": model_id, "load": load, "precision": precision})
revision = revision_from_precision(precision)
def loadModel(model_id: str, load=True, precision=None, revision=None):
revision = revision or revision_from_precision(precision)
torch_dtype = torch_dtype_from_precision(precision)
print("loadModel", {"model_id": model_id, "load": load, "precision": precision, "revision": revision})
print(
("Loading" if load else "Downloading")
+ " model: "
Expand Down

0 comments on commit 6edc821

Please sign in to comment.