feat(loras): use load_lora_weights (works with A1111 files too)
UNFINISHED. The initial implementation works but needs more testing.
Also, let's support an array of LoRAs from the get-go (for when
diffusers allows multiple LoRAs in a future release).
gadicc committed Jun 16, 2023
1 parent 4fe13ef commit 7a64846
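
A minimal sketch of the array support the message anticipates (the `normalize_lora_weights` helper is hypothetical and not part of this commit; until diffusers supports multiple LoRAs, only a single entry could actually be loaded):

    # Hypothetical sketch: accept a string or a list for lora_weights and
    # normalize to a list, so multi-LoRA support later only needs a loop.
    from typing import List, Union

    def normalize_lora_weights(value: Union[str, List[str], None]) -> List[str]:
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        return list(value)

    # normalize_lora_weights(None) == []
    # normalize_lora_weights("a.safetensors") == ["a.safetensors"]
    # normalize_lora_weights(["a", "b"]) == ["a", "b"]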
Showing 2 changed files with 139 additions and 0 deletions.
43 changes: 43 additions & 0 deletions api/app.py
@@ -130,6 +130,7 @@ def truncateInputs(inputs: dict):

last_xformers_memory_efficient_attention = {}
last_attn_procs = None
last_lora_weights = None


# Inference is run for every server call
@@ -143,6 +144,7 @@ async def inference(all_inputs: dict, response) -> dict:
    global last_xformers_memory_efficient_attention
    global always_normalize_model_id
    global last_attn_procs
    global last_lora_weights

    clearSession()

@@ -244,6 +246,8 @@ def sendStatus():
                "loadModel", "done", {"startRequestId": startRequestId}, send_opts
            )
            last_model_id = normalized_model_id
            last_attn_procs = None
            last_lora_weights = None
        else:
            if always_normalize_model_id:
                normalized_model_id = always_normalize_model_id
@@ -312,8 +316,13 @@ def sendStatus():
    is_url = call_inputs.get("is_url", False)
    image_decoder = getFromUrl if is_url else decodeBase64Image

    # Better to use the new lora_weights in the next section
    attn_procs = call_inputs.get("attn_procs", None)
    if attn_procs is not last_attn_procs:
        print(
            "[DEPRECATED] Using `attn_procs` for LoRAs is deprecated. "
            + "Please use `lora_weights` instead."
        )
        last_attn_procs = attn_procs
        if attn_procs:
            storage = Storage(attn_procs, no_raise=True)
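
For callers, the deprecation means moving the LoRA reference from `attn_procs` to `lora_weights` in call_inputs; a hypothetical before/after, with values taken from the integration tests below:

    # Hypothetical request call_inputs; the keys are from this diff, the
    # values from the integration tests in this commit.
    old_call_inputs = {
        "MODEL_ID": "runwayml/stable-diffusion-v1-5",
        "attn_procs": "patrickvonplaten/lora_dreambooth_dog_example",  # deprecated
    }
    new_call_inputs = {
        "MODEL_ID": "runwayml/stable-diffusion-v1-5",
        "lora_weights": "https://civitai.com/api/download/models/6244#fname=makima_offset.safetensors",
    }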
@@ -344,6 +353,40 @@ def sendStatus():
            print("Clearing attn procs")
            pipeline.unet.set_attn_processor(CrossAttnProcessor())

    # Currently we only support a single string, but we should allow
    # an array too, in anticipation of multi-LoRA support in diffusers,
    # tracked at https://github.com/huggingface/diffusers/issues/2613.
    lora_weights = call_inputs.get("lora_weights", None)
    if lora_weights is not last_lora_weights:
        last_lora_weights = lora_weights
        if lora_weights:
            pipeline.unet.set_attn_processor(CrossAttnProcessor())
            storage = Storage(lora_weights, no_raise=True)
            if storage:
                storage_query_fname = storage.query.get("fname")
                if storage_query_fname:
                    fname = storage_query_fname[0]
                else:
                    hash = sha256(lora_weights.encode("utf-8")).hexdigest()
                    fname = "url_" + hash[:7] + "--" + storage.url.split("/").pop()
                cache_fname = "lora_weights--" + fname
                path = os.path.join(MODELS_DIR, cache_fname)
                if not os.path.exists(path):
                    storage.download_and_extract(path)
                print("Load lora_weights `" + lora_weights + "` from `" + path + "`")
                pipeline.load_lora_weights(
                    MODELS_DIR, weight_name=cache_fname, local_files_only=True
                )
            else:
                print("Loading from HuggingFace not supported yet: " + lora_weights)
                # maybe something like sayakpaul/civitai-light-shadow-lora#lora=l_a_s.s9s?
                # lora_model_id = "sayakpaul/civitai-light-shadow-lora"
                # lora_filename = "light_and_shadow.safetensors"
                # pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
        else:
            print("Clearing attn procs")
            pipeline.unet.set_attn_processor(CrossAttnProcessor())

    # TODO, generalize
    cross_attention_kwargs = model_inputs.get("cross_attention_kwargs", None)
    if isinstance(cross_attention_kwargs, str):
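For the not-yet-supported HuggingFace path, the commented-out lines above hint at a `repo#lora=filename` convention. A hedged sketch of how such a spec might be parsed (the `#lora=` syntax and the `parse_hf_lora_spec` helper are assumptions taken from those comments; `pipeline.load_lora_weights(repo_id, weight_name=...)` is the real diffusers call):

    # Hypothetical: split "user/repo#lora=file.safetensors" into the
    # (repo_id, weight_name) pair that load_lora_weights() expects.
    def parse_hf_lora_spec(spec: str):
        repo_id, _, fragment = spec.partition("#")
        weight_name = fragment[len("lora="):] if fragment.startswith("lora=") else None
        return repo_id, weight_name

    # repo_id, weight_name = parse_hf_lora_spec(
    #     "sayakpaul/civitai-light-shadow-lora#lora=light_and_shadow.safetensors"
    # )
    # pipeline.load_lora_weights(repo_id, weight_name=weight_name)
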
96 changes: 96 additions & 0 deletions tests/integration/test_loras.py
@@ -0,0 +1,96 @@
import sys
import os
from .lib import getMinio, getDDA
from test import runTest


class TestLoRAs:
    def setup_class(self):
        print("setup_class")
        # self.minio = minio = getMinio("global")

        self.dda = dda = getDDA(
            # minio=minio
            stream_logs=True,
        )
        print(dda)

        self.TEST_ARGS = {"test_url": dda.url}

    def teardown_class(self):
        print("teardown_class")
        # self.minio.stop() - leave global up
        self.dda.stop()

    if False:

        def test_lora_hf_download(self):
            """
            Download user/repo from HuggingFace.
            """
            # fp32 model is obviously bigger
            result = runTest(
                "txt2img",
                self.TEST_ARGS,
                {
                    "MODEL_ID": "runwayml/stable-diffusion-v1-5",
                    "MODEL_REVISION": "fp16",
                    "MODEL_PRECISION": "fp16",
                    "attn_procs": "patrickvonplaten/lora_dreambooth_dog_example",
                },
                {
                    "num_inference_steps": 1,
                    "prompt": "A picture of a sks dog in a bucket",
                    "seed": 1,
                    "cross_attention_kwargs": {"scale": 0.5},
                },
            )

            assert result["image_base64"]

    if False:

        def test_lora_http_download_pytorch_bin(self):
            """
            Download pytorch_lora_weights.bin directly.
            """
            result = runTest(
                "txt2img",
                self.TEST_ARGS,
                {
                    "MODEL_ID": "runwayml/stable-diffusion-v1-5",
                    "MODEL_REVISION": "fp16",
                    "MODEL_PRECISION": "fp16",
                    "attn_procs": "https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/resolve/main/pytorch_lora_weights.bin",
                },
                {
                    "num_inference_steps": 1,
                    "prompt": "A picture of a sks dog in a bucket",
                    "seed": 1,
                    "cross_attention_kwargs": {"scale": 0.5},
                },
            )

            assert result["image_base64"]

    # These formats are not supported by diffusers yet :(
    def test_lora_http_download_civitai_safetensors(self):
        result = runTest(
            "txt2img",
            self.TEST_ARGS,
            {
                "MODEL_ID": "runwayml/stable-diffusion-v1-5",
                "MODEL_REVISION": "fp16",
                "MODEL_PRECISION": "fp16",
                # https://civitai.com/models/5373/makima-chainsaw-man-lora
                "lora_weights": "https://civitai.com/api/download/models/6244#fname=makima_offset.safetensors",
                "safety_checker": False,
            },
            {
                "num_inference_steps": 1,
                "prompt": "masterpiece, (photorealistic:1.4), best quality, beautiful lighting, (ulzzang-6500:0.5), makima \(chainsaw man\), (red hair)+(long braided hair)+(bangs), yellow eyes, golden eyes, ((ringed eyes)), (white shirt), (necktie), RAW photo, 8k uhd, film grain",
                "seed": 1,
            },
        )

        assert result["image_base64"]
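
A note on the `#fname=` fragment in the CivitAI URL above: it supplies the cache file name. A standalone sketch of the naming logic from api/app.py using only the standard library (it assumes the project's `Storage` class parses the URL fragment into a query dict, as the app code suggests):

    # Sketch of the cache-filename derivation from api/app.py, stdlib only.
    from hashlib import sha256
    from urllib.parse import parse_qs, urlparse

    def lora_cache_fname(url: str) -> str:
        fname_values = parse_qs(urlparse(url).fragment).get("fname")
        if fname_values:
            fname = fname_values[0]
        else:
            digest = sha256(url.encode("utf-8")).hexdigest()
            fname = "url_" + digest[:7] + "--" + url.split("/")[-1]
        return "lora_weights--" + fname

    # lora_cache_fname(
    #     "https://civitai.com/api/download/models/6244#fname=makima_offset.safetensors"
    # )  # -> "lora_weights--makima_offset.safetensors"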
