feat(build): separate MODEL_REVISION, MODEL_PRECISION, HF_MODEL_ID
The above already worked for runtime downloads, but they can now also
be passed as build-args to download the model at build time and
include it in your image.

SOME VERY IMPORTANT NOTES:

1) MODEL_REVISION no longer defaults to MODEL_PRECISION; you need to
specify it separately (it still defaults to "fp16" in the Dockerfile,
however). You'll get a warning if you specify MODEL_PRECISION without
MODEL_REVISION, to help catch the most common case of the old behaviour.

2) The build-arg PRECISION still works but has been deprecated in
favour of MODEL_PRECISION (which was already the call-arg name).

3) normalized_model_id now uses MODEL_REVISION, so for the early birds
already using the "cloud cache" on S3, your filenames might no
longer match in some cases (see the sketch under api/download.py below).
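
For example, a runtime call under the new scheme sets both keys
explicitly. A minimal sketch (the model id is illustrative; the key
names follow this commit's call-args):

    call_inputs = {
        "MODEL_ID": "runwayml/stable-diffusion-v1-5",  # illustrative
        "MODEL_PRECISION": "fp16",
        "MODEL_REVISION": "fp16",  # must now be given; no longer inferred from precision
    }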
gadicc committed Jan 4, 2023
1 parent 0f37a4e commit fa9dd16
Showing 6 changed files with 79 additions and 40 deletions.
17 changes: 11 additions & 6 deletions api/app.py
@@ -19,7 +19,7 @@
import requests
from download import download_model, normalize_model_id
import traceback
from precision import PRECISION
from precision import MODEL_REVISION, MODEL_PRECISION

RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1"
USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1"
@@ -72,14 +72,19 @@ def init():

if not RUNTIME_DOWNLOADS:
# Uh doesn't this break non-cached images? TODO... IMAGE_CACHE
normalized_model_id = normalize_model_id(MODEL_ID, PRECISION)
normalized_model_id = normalize_model_id(MODEL_ID, MODEL_REVISION)
model_dir = os.path.join(MODELS_DIR, normalized_model_id)
if os.path.isdir(model_dir):
always_normalize_model_id = model_dir
else:
normalized_model_id = MODEL_ID

model = loadModel(model_dir, True, PRECISION)
model = loadModel(
model_id=model_dir,
load=True,
precision=MODEL_PRECISION,
revision=MODEL_REVISION,
)
else:
model = None

@@ -156,7 +161,7 @@ def inference(all_inputs: dict) -> dict:
model_precision = call_inputs.get("MODEL_PRECISION", None)
checkpoint_url = call_inputs.get("CHECKPOINT_URL", None)
checkpoint_config_url = call_inputs.get("CHECKPOINT_CONFIG_URL", None)
normalized_model_id = normalize_model_id(model_id, model_precision)
normalized_model_id = normalize_model_id(model_id, model_revision)
model_dir = os.path.join(MODELS_DIR, normalized_model_id)
if last_model_id != normalized_model_id:
# if not downloaded_models.get(normalized_model_id, None):
@@ -172,7 +177,7 @@ def inference(all_inputs: dict) -> dict:
download_model(
model_id=model_id,
model_url=model_url,
model_revision=model_revision or model_precision,
model_revision=model_revision,
checkpoint_url=checkpoint_url,
checkpoint_config_url=checkpoint_config_url,
hf_model_id=hf_model_id,
@@ -182,7 +187,7 @@ def inference(all_inputs: dict) -> dict:
clearPipelines()
if model:
model.to("cpu") # Necessary to avoid a memory leak
model = loadModel(normalized_model_id, True, model_precision)
model = loadModel(model_id=normalized_model_id, load=True, precision=model_precision)
last_model_id = normalized_model_id
else:
if always_normalize_model_id:
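Since loadModel no longer derives the revision from the precision,
callers now spell out both. A minimal sketch of the new call shape
(the model id is illustrative):

    from loadModel import loadModel

    model = loadModel(
        model_id="runwayml/stable-diffusion-v1-5",  # illustrative id
        load=True,          # False = download weights only
        precision="fp16",   # maps to torch.float16 via torch_dtype_from_precision
        revision="fp16",    # now passed explicitly, no longer inferred from precision
    )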
54 changes: 34 additions & 20 deletions api/download.py
@@ -5,19 +5,17 @@
from loadModel import loadModel, MODEL_IDS
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from precision import PRECISION, revision_from_precision, torch_dtype_from_precision
from utils import Storage
import subprocess
from pathlib import Path
import shutil
from convert_to_diffusers import main as convert_to_diffusers
from download_checkpoint import main as download_checkpoint

MODEL_ID = os.environ.get("MODEL_ID")
MODEL_URL = os.environ.get("MODEL_URL")
USE_DREAMBOOTH = os.environ.get("USE_DREAMBOOTH")
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
RUNTIME_DOWNLOADS = os.environ.get("RUNTIME_DOWNLOADS")

HOME = os.path.expanduser("~")
MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
Path(MODELS_DIR).mkdir(parents=True, exist_ok=True)
@@ -55,19 +53,17 @@ def download_model(
"hf_model_id": hf_model_id,
},
)
url = model_url or MODEL_URL
hf_model_id = hf_model_id or model_id
model_revision = model_revision or revision_from_precision()
normalized_model_id = id
normalized_model_id = model_id

if url != "":
normalized_model_id = normalize_model_id(model_id, model_precision)
if model_url != "":
normalized_model_id = normalize_model_id(model_id, model_revision)
print({"normalized_model_id": normalized_model_id})
filename = url.split("/").pop()
filename = model_url.split("/").pop()
if not filename:
filename = normalized_model_id + ".tar.zst"
model_file = os.path.join(MODELS_DIR, filename)
storage = Storage(url, default_path=normalized_model_id + ".tar.zst")
storage = Storage(model_url, default_path=normalized_model_id + ".tar.zst")
exists = storage.file_exists()
if exists:
storage.download_file(model_file)
@@ -98,16 +94,28 @@ def download_model(
)
else:
print("Does not exist, let's try find it on huggingface")
print({"model_precision": model_precision, "model_revision": model_revision})
print(
{
"model_precision": model_precision,
"model_revision": model_revision,
}
)
# This would be quicker to just model.to("cuda") afterwards, but
# this conveniently logs all the timings (and doesn't happen often)
print("download")
send("download", "start", {})
model = loadModel(hf_model_id, False, precision=model_precision, revision=model_revision) # download
model = loadModel(
hf_model_id,
False,
precision=model_precision,
revision=model_revision,
) # download
send("download", "done", {})

print("load")
model = loadModel(hf_model_id, True, precision=model_precision, revision=model_revision) # load
model = loadModel(
hf_model_id, True, precision=model_precision, revision=model_revision
) # load
# dir = "models--" + model_id.replace("/", "--") + "--dda"
dir = os.path.join(MODELS_DIR, normalized_model_id)
model.save_pretrained(dir, safe_serialization=True)
@@ -137,12 +145,12 @@ def download_model(
return

# do a dry run of loading the huggingface model, which will download weights at build time
# For local dev & preview deploys, download all the models (terrible for serverless deploys)
if MODEL_ID == "ALL":
for MODEL_I in MODEL_IDS:
loadModel(MODEL_I, False, precision=model_revision)
else:
loadModel(normalized_model_id, False, precision=model_revision)
loadModel(
model_id=normalized_model_id,
load=False,
precision=model_precision,
revision=model_revision,
)

# if USE_DREAMBOOTH:
# Actually we can re-use these from the above loaded model
@@ -164,4 +172,10 @@


if __name__ == "__main__":
download_model("", MODEL_ID, PRECISION)
download_model(
model_url=os.environ.get("MODEL_URL"),
model_id=os.environ.get("MODEL_ID"),
hf_model_id=os.environ.get("HF_MODEL_ID"),
model_revision=os.environ.get("MODEL_REVISION"),
model_precision=os.environ.get("MODEL_PRECISION"),
)
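
To make note 3 of the commit message concrete, here is a hypothetical
sketch of what a revision-aware normalize_model_id might look like.
The real implementation lives in api/download.py and may differ; the
"models--" prefix is borrowed from the commented-out line above.

    def normalize_model_id(model_id: str, model_revision: str = None) -> str:
        # Flatten the HF id and append the revision, e.g.
        # "runwayml/stable-diffusion-v1-5" + "fp16" ->
        # "models--runwayml--stable-diffusion-v1-5--fp16".
        normalized_model_id = "models--" + model_id.replace("/", "--")
        if model_revision:
            normalized_model_id += "--" + model_revision
        return normalized_model_id

Since this commit, the suffix comes from MODEL_REVISION rather than
MODEL_PRECISION, so cached .tar.zst filenames keyed off precision may
no longer match whenever the two values differ.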
5 changes: 2 additions & 3 deletions api/getPipeline.py
@@ -4,7 +4,6 @@
DiffusionPipeline,
pipelines as diffusers_pipelines,
)
from precision import revision, torch_dtype

HOME = os.path.expanduser("~")
MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
@@ -83,8 +82,8 @@ def getPipelineForModel(pipeline_name: str, model, model_id):

pipeline = DiffusionPipeline.from_pretrained(
model_dir or model_id,
revision=revision,
torch_dtype=torch_dtype,
# revision=revision,
# torch_dtype=torch_dtype,
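# (not needed here: dtype and revision come from the already-loaded model via **model.components)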
custom_pipeline="./diffusers/examples/community/" + pipeline_name + ".py",
local_files_only=True,
**model.components,
13 changes: 10 additions & 3 deletions api/loadModel.py
@@ -2,7 +2,7 @@
import os
from diffusers import pipelines as _pipelines, StableDiffusionPipeline
from getScheduler import getScheduler, DEFAULT_SCHEDULER
from precision import revision_from_precision, torch_dtype_from_precision
from precision import torch_dtype_from_precision
import time

HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
@@ -25,9 +25,16 @@


def loadModel(model_id: str, load=True, precision=None, revision=None):
revision = revision or revision_from_precision(precision)
torch_dtype = torch_dtype_from_precision(precision)
print("loadModel", {"model_id": model_id, "load": load, "precision": precision, "revision": revision})
print(
"loadModel",
{
"model_id": model_id,
"load": load,
"precision": precision,
"revision": revision,
},
)
print(
("Loading" if load else "Downloading")
+ " model: "
27 changes: 21 additions & 6 deletions api/precision.py
@@ -1,17 +1,32 @@
import os
import torch

PRECISION = os.getenv("PRECISION")
DEPRECATED_PRECISION = os.getenv("PRECISION")
MODEL_PRECISION = os.getenv("MODEL_PRECISION") or DEPRECATED_PRECISION
MODEL_REVISION = os.getenv("MODEL_REVISION")

revision = None if PRECISION == "" else PRECISION
torch_dtype = None if PRECISION == "" else torch.float16
if DEPRECATED_PRECISION:
print("Warning: PRECISION variable been deprecated and renamed MODEL_PRECISION")
print("Your setup still works but in a future release, this will throw an error")

if MODEL_PRECISION and not MODEL_REVISION:
print("Warning: we no longer default to MODEL_REVISION=MODEL_PRECISION, please")
print(f'explicitly set MODEL_REVISION="{MODEL_PRECISION}" if that\'s what you')
print("want.")

def revision_from_precision(precision=PRECISION):
return precision if precision else None

def revision_from_precision(precision=MODEL_PRECISION):
# return precision if precision else None
raise Exception("revision_from_precision no longer supported")

def torch_dtype_from_precision(precision=PRECISION):

def torch_dtype_from_precision(precision=MODEL_PRECISION):
if precision == "fp16":
return torch.float16
return None


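
As a quick check of the retained helper: torch_dtype_from_precision
maps the "fp16" string to a torch dtype and anything else to None
(i.e. let diffusers pick the default dtype). A minimal usage sketch:

    import torch
    from precision import torch_dtype_from_precision

    assert torch_dtype_from_precision("fp16") is torch.float16
    assert torch_dtype_from_precision(None) is None  # default/full precision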
3 changes: 1 addition & 2 deletions api/train_dreambooth.py
@@ -37,7 +37,6 @@
from transformers import CLIPTextModel, CLIPTokenizer

# DDA
from precision import revision, torch_dtype
from send import send, get_now
from utils import Storage
import subprocess
@@ -55,7 +54,7 @@ def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs):
params = {
# Defaults
"pretrained_model_name_or_path": model_id, # DDA, TODO
"revision": revision, # DDA, was: None
"revision": None,
"tokenizer_name": None,
"instance_data_dir": "instance_data_dir", # DDA TODO
"class_data_dir": "class_data_dir", # DDA, was: None,
