Skip to content

Commit

Permalink
feat(send): set / override SEND_URL, SIGN_KEY via callInputs
Browse files Browse the repository at this point in the history
  • Loading branch information
gadicc committed Feb 13, 2023
1 parent 2279de1 commit 74b4c53
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 34 deletions.
27 changes: 21 additions & 6 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ def inference(all_inputs: dict) -> dict:
call_inputs = all_inputs.get("callInputs", None)
result = {"$meta": {}}

send_opts = {}
if call_inputs.get("SEND_URL", None):
send_opts.update({"SEND_URL": call_inputs.get("SEND_URL")})
if call_inputs.get("SIGN_KEY", None):
send_opts.update({"SIGN_KEY": call_inputs.get("SIGN_KEY")})

if model_inputs == None or call_inputs == None:
return {
"$error": {
Expand Down Expand Up @@ -185,13 +191,17 @@ def inference(all_inputs: dict) -> dict:
checkpoint_config_url=checkpoint_config_url,
hf_model_id=hf_model_id,
model_precision=model_precision,
send_opts=send_opts,
)
# downloaded_models.update({normalized_model_id: True})
clearPipelines()
if model:
model.to("cpu") # Necessary to avoid a memory leak
model = loadModel(
model_id=normalized_model_id, load=True, precision=model_precision
model_id=normalized_model_id,
load=True,
precision=model_precision,
send_opts=send_opts,
)
last_model_id = normalized_model_id
else:
Expand All @@ -207,7 +217,7 @@ def inference(all_inputs: dict) -> dict:
if MODEL_ID == "ALL":
if last_model_id != normalized_model_id:
clearPipelines()
model = loadModel(normalized_model_id)
model = loadModel(normalized_model_id, send_opts=send_opts)
last_model_id = normalized_model_id
else:
if model_id != MODEL_ID and not RUNTIME_DOWNLOADS:
Expand Down Expand Up @@ -295,7 +305,7 @@ def inference(all_inputs: dict) -> dict:
)
)

send("inference", "start", {"startRequestId": startRequestId})
send("inference", "start", {"startRequestId": startRequestId}, send_opts)

# Run patchmatch for inpainting
if call_inputs.get("FILL_MODE", None) == "patchmatch":
Expand Down Expand Up @@ -349,10 +359,14 @@ def inference(all_inputs: dict) -> dict:

torch.set_grad_enabled(True)
result = result | TrainDreamBooth(
normalized_model_id, pipeline, model_inputs, call_inputs
normalized_model_id,
pipeline,
model_inputs,
call_inputs,
send_opts=send_opts,
)
torch.set_grad_enabled(False)
send("inference", "done", {"startRequestId": startRequestId})
send("inference", "done", {"startRequestId": startRequestId}, send_opts)
result.update({"$timings": getTimings()})
return result

Expand All @@ -375,6 +389,7 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
"inference",
"progress",
{"startRequestId": startRequestId, "step": step},
send_opts,
)

with torch.inference_mode():
Expand Down Expand Up @@ -407,7 +422,7 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
image.save(buffered, format="PNG")
images_base64.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))

send("inference", "done", {"startRequestId": startRequestId})
send("inference", "done", {"startRequestId": startRequestId}, send_opts)

# Return the results as a dictionary
if len(images_base64) > 1:
Expand Down
17 changes: 9 additions & 8 deletions api/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
Path(MODELS_DIR).mkdir(parents=True, exist_ok=True)

# i.e. don't run during build
def send(type: str, status: str, payload: dict = {}):
def send(type: str, status: str, payload: dict = {}, send_opts: dict = {}):
    """Forward a status event to ``send.send``, but only at runtime.

    Thin wrapper used by download.py so it can emit progress events
    ("download"/"compress"/"upload" start/done) without importing the
    ``send`` module at build time — per the comment above this hunk,
    "i.e. don't run during build".

    Args:
        type: Event category, e.g. "download", "compress", "upload".
        status: Event phase, e.g. "start" or "done".
        payload: Extra event data, passed through unchanged.
        send_opts: Per-call overrides (this commit threads SEND_URL /
            SIGN_KEY through here); empty dict means "use defaults".

    NOTE(review): both defaults are shared mutable dicts — harmless as
    long as neither is mutated downstream; confirm in send.send.
    """
    # RUNTIME_DOWNLOADS is presumably a module-level flag defined earlier
    # in download.py (not visible in this hunk); when falsy, events are
    # silently dropped.
    if RUNTIME_DOWNLOADS:
        # Deferred import keeps send.py (which reads SEND_URL/SIGN_KEY
        # from the environment at import time) out of the build phase.
        from send import send as _send

        _send(type, status, payload, send_opts)


def normalize_model_id(model_id: str, model_revision):
Expand All @@ -43,6 +43,7 @@ def download_model(
checkpoint_config_url=None,
hf_model_id=None,
model_precision=None,
send_opts={},
):
print(
"download_model",
Expand Down Expand Up @@ -104,14 +105,14 @@ def download_model(
# This would be quicker to just model.to(device) afterwards, but
# this conveniently logs all the timings (and doesn't happen often)
print("download")
send("download", "start", {})
send("download", "start", {}, send_opts)
model = loadModel(
hf_model_id,
False,
precision=model_precision,
revision=model_revision,
) # download
send("download", "done", {})
send("download", "done", {}, send_opts)

print("load")
model = loadModel(
Expand All @@ -122,19 +123,19 @@ def download_model(
model.save_pretrained(dir, safe_serialization=True)

# This is all duped from train_dreambooth, need to refactor TODO XXX
send("compress", "start", {})
send("compress", "start", {}, send_opts)
subprocess.run(
f"tar cvf - -C {dir} . | zstd -o {model_file}",
shell=True,
check=True, # TODO, rather don't raise and return an error in JSON
)

send("compress", "done")
send("compress", "done", {}, send_opts)
subprocess.run(["ls", "-l", model_file])

send("upload", "start", {})
send("upload", "start", {}, send_opts)
upload_result = storage.upload_file(model_file, filename)
send("upload", "done")
send("upload", "done", {}, send_opts)
print(upload_result)
os.remove(model_file)

Expand Down
2 changes: 1 addition & 1 deletion api/loadModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
]


def loadModel(model_id: str, load=True, precision=None, revision=None):
def loadModel(model_id: str, load=True, precision=None, revision=None, send_opts={}):
torch_dtype = torch_dtype_from_precision(precision)
if revision == "":
revision = None
Expand Down
18 changes: 10 additions & 8 deletions api/send.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ def get_now():
return round(time.time() * 1000)


send_url = os.getenv("SEND_URL")
if send_url == "":
send_url = None
SEND_URL = os.getenv("SEND_URL")
if SEND_URL == "":
SEND_URL = None

sign_key = os.getenv("SIGN_KEY")
if sign_key == "":
sign_key = None
SIGN_KEY = os.getenv("SIGN_KEY", "")
if SIGN_KEY == "":
SIGN_KEY = None

futureSession = FuturesSession()

Expand Down Expand Up @@ -70,8 +70,10 @@ def getTimings():
return timings


def send(type: str, status: str, payload: dict = {}):
def send(type: str, status: str, payload: dict = {}, opts: dict = {}):
now = get_now()
send_url = opts.get("SEND_URL", SEND_URL)
sign_key = opts.get("SIGN_KEY", SIGN_KEY)

if status == "start":
session.update({type: {"start": now, "last_time": now}})
Expand All @@ -90,7 +92,7 @@ def send(type: str, status: str, payload: dict = {}):
"payload": payload,
}

if send_url:
if send_url and sign_key:
input = json.dumps(data, separators=(",", ":")) + sign_key
sig = hashlib.md5(input.encode("utf-8")).hexdigest()
data["sig"] = sig
Expand Down
22 changes: 11 additions & 11 deletions api/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")


def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs):
def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs, send_opts):
# required inputs: instance_images instance_prompt

params = {
Expand Down Expand Up @@ -166,7 +166,7 @@ def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs):

subprocess.run(["ls", "-l", args.instance_data_dir])

result = result | main(args, pipeline)
result = result | main(args, pipeline, send_opts=send_opts)

dest_url = call_inputs.get("dest_url")
if dest_url:
Expand All @@ -179,7 +179,7 @@ def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs):
print(filename)

# fp16 model timings: zip 1m20s, tar+zstd 4s and a tiny bit smaller!
send("compress", "start", {})
send("compress", "start", {}, send_opts)

# TODO, steaming upload (turns out docker disk write is super slow)
subprocess.run(
Expand All @@ -188,12 +188,12 @@ def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs):
check=True, # TODO, rather don't raise and return an error in JSON
)

send("compress", "done")
send("compress", "done", {}, send_opts)
subprocess.run(["ls", "-l", filename])

send("upload", "start", {})
send("upload", "start", {}, send_opts)
upload_result = storage.upload_file(filename, filename)
send("upload", "done")
send("upload", "done", {}, send_opts)
print(upload_result)
os.remove(filename)

Expand Down Expand Up @@ -379,7 +379,7 @@ def get_full_repo_name(
return f"{organization}/{model_id}"


def main(args, init_pipeline):
def main(args, init_pipeline, send_opts):
logging_dir = Path(args.output_dir, args.logging_dir)

accelerator = Accelerator(
Expand Down Expand Up @@ -777,7 +777,7 @@ def main(args, init_pipeline):
progress_bar.set_description("Steps")

# DDA
send("training", "start", {})
send("training", "start", {}, send_opts)

for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
Expand Down Expand Up @@ -892,7 +892,7 @@ def main(args, init_pipeline):

# Create the pipeline using using the trained modules and save it.
accelerator.wait_for_everyone()
send("training", "done") # DDA
send("training", "done", {}, send_opts) # DDA
if accelerator.is_main_process:
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
Expand All @@ -905,7 +905,7 @@ def main(args, init_pipeline):

if args.push_to_hub:
# DDA
send("upload", "start", {})
send("upload", "start", {}, send_opts)

repo.push_to_hub(
commit_message="End of training",
Expand All @@ -917,7 +917,7 @@ def main(args, init_pipeline):
)

# DDA
send("upload", "done")
send("upload", "done", {}, send_opts)

accelerator.end_training()

Expand Down

0 comments on commit 74b4c53

Please sign in to comment.