diff --git a/api/app.py b/api/app.py index 4673736..090d369 100644 --- a/api/app.py +++ b/api/app.py @@ -333,7 +333,7 @@ def sendStatus(): image_decoder = getFromUrl if is_url else decodeBase64Image textual_inversions = call_inputs.get("textual_inversions", []) - handle_textual_inversions(textual_inversions, model) + await handle_textual_inversions(textual_inversions, model, status=status) # Better to use new lora_weights in next section attn_procs = call_inputs.get("attn_procs", None) @@ -376,35 +376,42 @@ def sendStatus(): # and array too in anticipation of multi-LoRA support in diffusers # tracked at https://github.com/huggingface/diffusers/issues/2613. lora_weights = call_inputs.get("lora_weights", None) - if lora_weights is not last_lora_weights: - last_lora_weights = lora_weights - if lora_weights: - pipeline.unet.set_attn_processor(CrossAttnProcessor()) - storage = Storage(lora_weights, no_raise=True) - if storage: - storage_query_fname = storage.query.get("fname") - if storage_query_fname: - fname = storage_query_fname[0] + lora_weights_joined = json.dumps(lora_weights) + if last_lora_weights != lora_weights_joined: + last_lora_weights = lora_weights_joined + print("Unloading previous LoRA weights") + pipeline.unet.set_attn_processor(CrossAttnProcessor()) + # pipeline.unload_lora_weights() + + if type(lora_weights) is not list: + lora_weights = [lora_weights] + + if len(lora_weights) > 0: + for weights in lora_weights: + storage = Storage(weights, no_raise=True, status=status) + if storage: + storage_query_fname = storage.query.get("fname") + if storage_query_fname: + fname = storage_query_fname[0] + else: + hash = sha256(weights.encode("utf-8")).hexdigest() + fname = "url_" + hash[:7] + "--" + storage.url.split("/").pop() + cache_fname = "lora_weights--" + fname + path = os.path.join(MODELS_DIR, cache_fname) + if not os.path.exists(path): + await asyncio.to_thread(storage.download_file, path) + print("Load lora_weights `" + weights + "` from `" + path + "`") + pipeline.load_lora_weights( + MODELS_DIR, weight_name=cache_fname, local_files_only=True + ) else: - hash = sha256(lora_weights.encode("utf-8")).hexdigest() - fname = "url_" + hash[:7] + "--" + storage.url.split("/").pop() - cache_fname = "lora_weights--" + fname - path = os.path.join(MODELS_DIR, cache_fname) - if not os.path.exists(path): - storage.download_and_extract(path, status=status) - print("Load lora_weights `" + lora_weights + "` from `" + path + "`") - pipeline.load_lora_weights( - MODELS_DIR, weight_name=cache_fname, local_files_only=True - ) - else: - print("Loading from huggingface not supported yet: " + lora_weights) - # maybe something like sayakpaul/civitai-light-shadow-lora#lora=l_a_s.s9s? - # lora_model_id = "sayakpaul/civitai-light-shadow-lora" - # lora_filename = "light_and_shadow.safetensors" - # pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename) - else: - print("Clearing attn procs") - pipeline.unet.set_attn_processor(CrossAttnProcessor()) + print("Loading from huggingface not supported yet: " + weights) + # maybe something like sayakpaul/civitai-light-shadow-lora#lora=l_a_s.s9s? + # lora_model_id = "sayakpaul/civitai-light-shadow-lora" + # lora_filename = "light_and_shadow.safetensors" + # pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename) + else: + print("No changes to LoRAs since last call") # TODO, generalize cross_attention_kwargs = model_inputs.get("cross_attention_kwargs", None) diff --git a/api/lib/textual_inversions.py b/api/lib/textual_inversions.py index 2e2ef6a..3a06a0c 100644 --- a/api/lib/textual_inversions.py +++ b/api/lib/textual_inversions.py @@ -1,6 +1,7 @@ import json import re import os +import asyncio from utils import Storage from .vars import MODELS_DIR @@ -15,7 +16,7 @@ def strMap(str: str): match = re.search(tokenRe, str) - print(match) + # print(match) if match: return match.group("token") or match.group("fname") @@ -24,17 +25,17 @@ def extract_tokens_from_list(textual_inversions: list): return list(map(strMap, textual_inversions)) -def handle_textual_inversions(textual_inversions: list, model): +async def handle_textual_inversions(textual_inversions: list, model, status): global last_textual_inversions global last_textual_inversion_model global loaded_textual_inversion_tokens textual_inversions_str = json.dumps(textual_inversions) if ( - textual_inversions_str is not last_textual_inversions + textual_inversions_str != last_textual_inversions or model is not last_textual_inversion_model ): - if (model is not last_textual_inversion_model): + if model is not last_textual_inversion_model: loaded_textual_inversion_tokens = [] last_textual_inversion_model = model # print({"textual_inversions": textual_inversions}) @@ -53,7 +54,7 @@ def handle_textual_inversions(textual_inversions: list, model): last_textual_inversions = textual_inversions_str for textual_inversion in textual_inversions: - storage = Storage(textual_inversion, no_raise=True) + storage = Storage(textual_inversion, no_raise=True, status=status) if storage: storage_query_fname = storage.query.get("fname") if storage_query_fname: @@ -62,7 +63,7 @@ def handle_textual_inversions(textual_inversions: list, model): fname = textual_inversion.split("/").pop() path = os.path.join(MODELS_DIR, "textual_inversion--" + fname) if not os.path.exists(path): - storage.download_file(path) + await asyncio.to_thread(storage.download_file, path) print("Load textual inversion " + path) token = storage.query.get("token", None) if token not in loaded_textual_inversion_tokens: @@ -73,3 +74,5 @@ def handle_textual_inversions(textual_inversions: list, model): else: print("Load textual inversion " + textual_inversion) model.load_textual_inversion(textual_inversion) + else: + print("No changes to textual inversions since last call")