Skip to content

Commit

Permalink
fix(addons): async TI download status, LoRA improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
gadicc committed Jul 30, 2023
1 parent 0c55bb8 commit de8cfdc
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 35 deletions.
65 changes: 36 additions & 29 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def sendStatus():
image_decoder = getFromUrl if is_url else decodeBase64Image

textual_inversions = call_inputs.get("textual_inversions", [])
handle_textual_inversions(textual_inversions, model)
await handle_textual_inversions(textual_inversions, model, status=status)

# Better to use new lora_weights in next section
attn_procs = call_inputs.get("attn_procs", None)
Expand Down Expand Up @@ -376,35 +376,42 @@ def sendStatus():
# and array too in anticipation of multi-LoRA support in diffusers
# tracked at https://github.com/huggingface/diffusers/issues/2613.
lora_weights = call_inputs.get("lora_weights", None)
if lora_weights is not last_lora_weights:
last_lora_weights = lora_weights
if lora_weights:
pipeline.unet.set_attn_processor(CrossAttnProcessor())
storage = Storage(lora_weights, no_raise=True)
if storage:
storage_query_fname = storage.query.get("fname")
if storage_query_fname:
fname = storage_query_fname[0]
lora_weights_joined = json.dumps(lora_weights)
if last_lora_weights != lora_weights_joined:
last_lora_weights = lora_weights_joined
print("Unloading previous LoRA weights")
pipeline.unet.set_attn_processor(CrossAttnProcessor())
# pipeline.unload_lora_weights()

if type(lora_weights) is not list:
lora_weights = [lora_weights]

if len(lora_weights) > 0:
for weights in lora_weights:
storage = Storage(weights, no_raise=True, status=status)
if storage:
storage_query_fname = storage.query.get("fname")
if storage_query_fname:
fname = storage_query_fname[0]
else:
hash = sha256(weights.encode("utf-8")).hexdigest()
fname = "url_" + hash[:7] + "--" + storage.url.split("/").pop()
cache_fname = "lora_weights--" + fname
path = os.path.join(MODELS_DIR, cache_fname)
if not os.path.exists(path):
await asyncio.to_thread(storage.download_file, path)
print("Load lora_weights `" + weights + "` from `" + path + "`")
pipeline.load_lora_weights(
MODELS_DIR, weight_name=cache_fname, local_files_only=True
)
else:
hash = sha256(lora_weights.encode("utf-8")).hexdigest()
fname = "url_" + hash[:7] + "--" + storage.url.split("/").pop()
cache_fname = "lora_weights--" + fname
path = os.path.join(MODELS_DIR, cache_fname)
if not os.path.exists(path):
storage.download_and_extract(path, status=status)
print("Load lora_weights `" + lora_weights + "` from `" + path + "`")
pipeline.load_lora_weights(
MODELS_DIR, weight_name=cache_fname, local_files_only=True
)
else:
print("Loading from huggingface not supported yet: " + lora_weights)
# maybe something like sayakpaul/civitai-light-shadow-lora#lora=l_a_s.s9s?
# lora_model_id = "sayakpaul/civitai-light-shadow-lora"
# lora_filename = "light_and_shadow.safetensors"
# pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
else:
print("Clearing attn procs")
pipeline.unet.set_attn_processor(CrossAttnProcessor())
print("Loading from huggingface not supported yet: " + weights)
# maybe something like sayakpaul/civitai-light-shadow-lora#lora=l_a_s.s9s?
# lora_model_id = "sayakpaul/civitai-light-shadow-lora"
# lora_filename = "light_and_shadow.safetensors"
# pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
else:
print("No changes to LoRAs since last call")

# TODO, generalize
cross_attention_kwargs = model_inputs.get("cross_attention_kwargs", None)
Expand Down
15 changes: 9 additions & 6 deletions api/lib/textual_inversions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import re
import os
import asyncio
from utils import Storage
from .vars import MODELS_DIR

Expand All @@ -15,7 +16,7 @@

def strMap(str: str):
match = re.search(tokenRe, str)
print(match)
# print(match)
if match:
return match.group("token") or match.group("fname")

Expand All @@ -24,17 +25,17 @@ def extract_tokens_from_list(textual_inversions: list):
return list(map(strMap, textual_inversions))


def handle_textual_inversions(textual_inversions: list, model):
async def handle_textual_inversions(textual_inversions: list, model, status):
global last_textual_inversions
global last_textual_inversion_model
global loaded_textual_inversion_tokens

textual_inversions_str = json.dumps(textual_inversions)
if (
textual_inversions_str is not last_textual_inversions
textual_inversions_str != last_textual_inversions
or model is not last_textual_inversion_model
):
if (model is not last_textual_inversion_model):
if model is not last_textual_inversion_model:
loaded_textual_inversion_tokens = []
last_textual_inversion_model = model
# print({"textual_inversions": textual_inversions})
Expand All @@ -53,7 +54,7 @@ def handle_textual_inversions(textual_inversions: list, model):

last_textual_inversions = textual_inversions_str
for textual_inversion in textual_inversions:
storage = Storage(textual_inversion, no_raise=True)
storage = Storage(textual_inversion, no_raise=True, status=status)
if storage:
storage_query_fname = storage.query.get("fname")
if storage_query_fname:
Expand All @@ -62,7 +63,7 @@ def handle_textual_inversions(textual_inversions: list, model):
fname = textual_inversion.split("/").pop()
path = os.path.join(MODELS_DIR, "textual_inversion--" + fname)
if not os.path.exists(path):
storage.download_file(path)
await asyncio.to_thread(storage.download_file, path)
print("Load textual inversion " + path)
token = storage.query.get("token", None)
if token not in loaded_textual_inversion_tokens:
Expand All @@ -73,3 +74,5 @@ def handle_textual_inversions(textual_inversions: list, model):
else:
print("Load textual inversion " + textual_inversion)
model.load_textual_inversion(textual_inversion)
else:
print("No changes to textual inversions since last call")

0 comments on commit de8cfdc

Please sign in to comment.