
Commit

fix(app): async fixes for download, train_dreambooth
gadicc committed Apr 25, 2023
1 parent d1cd39e commit 0dcbd16
Showing 3 changed files with 41 additions and 30 deletions.
api/app.py (27 changes: 15 additions & 12 deletions)

@@ -62,15 +62,17 @@ def init():
     global dummy_safety_checker
     global always_normalize_model_id
 
-    send(
-        "init",
-        "start",
-        {
-            "device": device_name,
-            "hostname": os.getenv("HOSTNAME"),
-            "model_id": MODEL_ID,
-            "diffusers": __version__,
-        },
+    asyncio.run(
+        send(
+            "init",
+            "start",
+            {
+                "device": device_name,
+                "hostname": os.getenv("HOSTNAME"),
+                "model_id": MODEL_ID,
+                "diffusers": __version__,
+            },
+        )
     )
 
     dummy_safety_checker = DummySafetyChecker()
@@ -96,7 +98,7 @@ def init():
     else:
         model = None
 
-    send("init", "done")
+    asyncio.run(send("init", "done"))
 
 
 def decodeBase64Image(imageStr: str, name: str) -> PIL.Image:
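
Both init() hunks follow the same pattern: the send() status helper is now a coroutine, but init() itself stays synchronous, so each call site drives the coroutine to completion with asyncio.run(). A minimal sketch of the bridge, with a hypothetical emit() coroutine standing in for the real send():

import asyncio

async def emit(event: str, status: str) -> None:
    # Stand-in for the now-async send(); imagine it posts a status webhook.
    await asyncio.sleep(0)
    print(f"{event}: {status}")

def init() -> None:
    # Synchronous caller: asyncio.run() creates an event loop, runs the
    # coroutine to completion, then closes the loop again.
    asyncio.run(emit("init", "start"))
    # ... load pipelines, safety checker, etc. ...
    asyncio.run(emit("init", "done"))

init()

Note that asyncio.run() raises RuntimeError if the calling thread already has a running event loop, so this bridge only suits genuinely synchronous entry points such as init().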
@@ -213,7 +215,7 @@ def sendStatus():
     # }
     # }
     normalized_model_id = hf_model_id or model_id
-    download_model(
+    await download_model(
         model_id=model_id,
         model_url=model_url,
         model_revision=model_revision,
@@ -426,7 +428,8 @@ def sendStatus():
     normalized_model_id = model_dir
 
     torch.set_grad_enabled(True)
-    result = result | TrainDreamBooth(
+    result = result | await asyncio.to_thread(
+        TrainDreamBooth,
         normalized_model_id,
         pipeline,
         model_inputs,
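
The last two hunks change the request path: the handler now awaits download_model() (made a coroutine in api/download.py below), and it hands the long, blocking TrainDreamBooth() call to asyncio.to_thread(), which runs it in a worker thread so the event loop stays free while training runs. A rough sketch of the to_thread pattern, with a made-up train() function and timings standing in for TrainDreamBooth:

import asyncio
import time

def train(model_id: str) -> dict:
    # Stand-in for TrainDreamBooth: blocking, long-running work.
    time.sleep(2)
    return {"trained": model_id}

async def heartbeat() -> None:
    # Keeps running while training happens on the worker thread; it would
    # be starved if train() were called directly on the event loop.
    for _ in range(4):
        await asyncio.sleep(0.5)
        print("still responsive")

async def handler() -> None:
    result = {"ok": True}
    trained, _ = await asyncio.gather(
        asyncio.to_thread(train, "example/model-id"),
        heartbeat(),
    )
    result = result | trained  # same dict-merge style as the hunk above
    print(result)

asyncio.run(handler())

asyncio.to_thread() (Python 3.9+) is shorthand for loop.run_in_executor() with the default thread pool; it returns a coroutine, which is why the diff can await it inside the existing `result | ...` expression.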
api/download.py (37 changes: 20 additions & 17 deletions)

@@ -12,6 +12,7 @@
 from convert_to_diffusers import main as convert_to_diffusers
 from download_checkpoint import main as download_checkpoint
 from status import status
+import asyncio
 
 USE_DREAMBOOTH = os.environ.get("USE_DREAMBOOTH")
 HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
@@ -23,11 +24,11 @@
 
 
 # i.e. don't run during build
-def send(type: str, status: str, payload: dict = {}, send_opts: dict = {}):
+async def send(type: str, status: str, payload: dict = {}, send_opts: dict = {}):
     if RUNTIME_DOWNLOADS:
         from send import send as _send
 
-        _send(type, status, payload, send_opts)
+        await _send(type, status, payload, send_opts)
 
 
 def normalize_model_id(model_id: str, model_revision):
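
With the wrapper now declared async, every call to send() in this file has to gain an await, which is exactly what the remaining hunks below do. Calling a coroutine function without awaiting it only creates a coroutine object; its body never runs. A small illustration of that failure mode (names are illustrative, not repository code):

import asyncio

async def send(type: str, status: str) -> None:
    print(f"{type}: {status}")

async def download_model() -> None:
    send("download", "start")       # bug: coroutine created but never run
    await send("download", "done")  # correct: actually executes

asyncio.run(download_model())
# Prints only "download: done", plus a
# "RuntimeWarning: coroutine 'send' was never awaited" for the first call.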
@@ -37,7 +38,7 @@ def normalize_model_id(model_id: str, model_revision):
     return normalized_model_id
 
 
-def download_model(
+async def download_model(
     model_url=None,
     model_id=None,
     model_revision=None,
@@ -109,14 +110,14 @@ def download_model(
     # This would be quicker to just model.to(device) afterwards, but
     # this conveniently logs all the timings (and doesn't happen often)
     print("download")
-    send("download", "start", {}, send_opts)
+    await send("download", "start", {}, send_opts)
     model = loadModel(
         hf_model_id,
         False,
         precision=model_precision,
         revision=model_revision,
     ) # download
-    send("download", "done", {}, send_opts)
+    await send("download", "done", {}, send_opts)
 
     print("load")
     model = loadModel(
@@ -127,19 +128,19 @@
     model.save_pretrained(dir, safe_serialization=True)
 
     # This is all duped from train_dreambooth, need to refactor TODO XXX
-    send("compress", "start", {}, send_opts)
+    await send("compress", "start", {}, send_opts)
     subprocess.run(
         f"tar cvf - -C {dir} . | zstd -o {model_file}",
         shell=True,
         check=True, # TODO, rather don't raise and return an error in JSON
     )
 
-    send("compress", "done", {}, send_opts)
+    await send("compress", "done", {}, send_opts)
     subprocess.run(["ls", "-l", model_file])
 
-    send("upload", "start", {}, send_opts)
+    await send("upload", "start", {}, send_opts)
     upload_result = storage.upload_file(model_file, filename)
-    send("upload", "done", {}, send_opts)
+    await send("upload", "done", {}, send_opts)
     print(upload_result)
     os.remove(model_file)
 
@@ -185,12 +186,14 @@ def download_model(
 
 
 if __name__ == "__main__":
-    download_model(
-        model_url=os.environ.get("MODEL_URL"),
-        model_id=os.environ.get("MODEL_ID"),
-        hf_model_id=os.environ.get("HF_MODEL_ID"),
-        model_revision=os.environ.get("MODEL_REVISION"),
-        model_precision=os.environ.get("MODEL_PRECISION"),
-        checkpoint_url=os.environ.get("CHECKPOINT_URL"),
-        checkpoint_config_url=os.environ.get("CHECKPOINT_CONFIG_URL"),
+    asyncio.run(
+        download_model(
+            model_url=os.environ.get("MODEL_URL"),
+            model_id=os.environ.get("MODEL_ID"),
+            hf_model_id=os.environ.get("HF_MODEL_ID"),
+            model_revision=os.environ.get("MODEL_REVISION"),
+            model_precision=os.environ.get("MODEL_PRECISION"),
+            checkpoint_url=os.environ.get("CHECKPOINT_URL"),
+            checkpoint_config_url=os.environ.get("CHECKPOINT_CONFIG_URL"),
+        )
     )
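
The __main__ block mirrors the change in app.py: download_model() is now a coroutine, so the async request handler awaits it directly, while running this module as a script supplies its own event loop through asyncio.run(). A compact sketch of one coroutine serving both kinds of caller (fetch() and handler() are illustrative names):

import asyncio
import os

async def fetch(model_id: str) -> str:
    # Stand-in for download_model(): awaitable work.
    await asyncio.sleep(0)
    return f"downloaded {model_id}"

async def handler() -> None:
    # Async caller, like app.py's inference path: just await it.
    print(await fetch("some/model"))

if __name__ == "__main__":
    # Script caller, like `python download.py`: asyncio.run() owns the loop,
    # with the model selection read from the environment as in the diff.
    asyncio.run(fetch(os.environ.get("MODEL_ID", "some/model")))
    asyncio.run(handler())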
api/train_dreambooth.py (7 changes: 6 additions & 1 deletion)

@@ -58,17 +58,22 @@
 from transformers import AutoTokenizer, PretrainedConfig
 
 # DDA
-from send import send, get_now
+from send import send as _send
 from utils import Storage
 import subprocess
 import re
 import shutil
+import asyncio
 
 # Our original code in docker-diffusers-api:
 
 HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
 
 
+def send(type: str, status: str, payload: dict = {}, send_opts: dict = {}):
+    asyncio.run((_send(type, status, payload, send_opts)))
+
+
 def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs, send_opts):
     # required inputs: instance_images instance_prompt
 
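train_dreambooth.py gets the inverse bridge. TrainDreamBooth() is ordinary synchronous code, and after the app.py change it runs inside asyncio.to_thread(), i.e. on a worker thread with no event loop of its own, so its local send() shim stays synchronous and drives the send module's async _send with asyncio.run() each time it reports progress. A rough sketch of the arrangement (worker() stands in for TrainDreamBooth, emit() for _send; both names are illustrative):

import asyncio

async def emit(event: str, status: str) -> None:
    # Stand-in for the async _send() from the send module.
    await asyncio.sleep(0)
    print(f"{event}: {status}")

def send(event: str, status: str) -> None:
    # Sync shim callable from plain threaded code: safe because the worker
    # thread has no running loop, so asyncio.run() is allowed to create one.
    asyncio.run(emit(event, status))

def worker() -> None:
    # Stand-in for TrainDreamBooth: blocking code that reports progress.
    send("training", "start")
    send("training", "done")

async def handler() -> None:
    await asyncio.to_thread(worker)

asyncio.run(handler())

Each shim call creates and tears down a short-lived event loop, which costs little here since progress messages are rare relative to the training work itself.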
