
Commit

feat(attn_procs): initial URL work (see notes)
TODO:

* test S3 for single files (should work, http works)
* test archives (totally untested, all archive code is new)
gadicc committed Feb 20, 2023
1 parent ee2d835 commit 6348836
Showing 2 changed files with 58 additions and 11 deletions.
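
As context for the diffs below, a hedged caller-side sketch of what this commit enables: attn_procs may now be given as a URL rather than only a Hugging Face repo id. The HTTP single-file form is exercised by the new test at the bottom of this page; the S3 and archive values are hypothetical stand-ins for the still-untested TODO items.

# Sketch only: possible attn_procs call-input values after this commit.
# The S3 and archive URLs below are hypothetical examples, not tested paths.
attn_procs_repo = "patrickvonplaten/lora_dreambooth_dog_example"  # Hugging Face repo id (existing behaviour)
attn_procs_http = (
    "https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example"
    "/resolve/main/pytorch_lora_weights.bin"
)  # single weights file over HTTP, covered by the new test below
attn_procs_s3 = "s3://my-bucket/pytorch_lora_weights.bin"  # single file on S3: should work, untested (TODO)
attn_procs_archive = "https://example.com/my_lora_attn_procs.tar.gz"  # archive: new code, untested (TODO)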
19 changes: 17 additions & 2 deletions api/app.py
@@ -21,7 +21,10 @@
import traceback
from precision import MODEL_REVISION, MODEL_PRECISION
from device import device, device_id, device_name
from diffusers.models.cross_attention import CrossAttnProcessor, LoRACrossAttnProcessor
from diffusers.models.cross_attention import CrossAttnProcessor
from utils import Storage
from hashlib import sha256


RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1"
USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1"
@@ -278,6 +281,18 @@ def inference(all_inputs: dict) -> dict:
if attn_procs is not last_attn_procs:
last_attn_procs = attn_procs
if attn_procs:
storage = Storage(attn_procs, no_raise=True)
if storage:
fname = storage.url.split("/").pop()
hash = sha256(attn_procs.encode("utf-8")).hexdigest()
if True:
# TODO, way to specify explicit name
path = os.path.join(
MODELS_DIR, "attn_proc--url_" + hash[:7] + "--" + fname
)
attn_procs = path
if not os.path.exists(path):
storage.download_and_extract(path)
print("Load attn_procs " + attn_procs)
pipeline.unet.load_attn_procs(attn_procs)
else:
@@ -286,7 +301,7 @@ def inference(all_inputs: dict) -> dict:

# TODO, generalize
cross_attention_kwargs = model_inputs.get("cross_attention_kwargs", None)
if cross_attention_kwargs:
if isinstance(cross_attention_kwargs, str):
model_inputs["cross_attention_kwargs"] = json.loads(cross_attention_kwargs)

# Parse out your arguments
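
As a standalone illustration of the caching logic the app.py hunk above adds, here is a minimal sketch assuming the Storage helper imported at the top exposes a download_and_extract(path) method (as used in the diff); MODELS_DIR is a stand-in for the app's configured models directory.

import os
from hashlib import sha256

MODELS_DIR = "/models"  # assumption: stand-in for the app's real constant

def attn_procs_cache_path(attn_procs_url: str, storage) -> str:
    """Map an attn_procs URL to a local cache path, downloading it once."""
    # Basename of the URL becomes the human-readable part of the cache name.
    fname = attn_procs_url.split("/").pop()
    # A short sha256 prefix of the full URL keeps two different URLs that end
    # in the same filename from colliding in the cache directory.
    url_hash = sha256(attn_procs_url.encode("utf-8")).hexdigest()
    path = os.path.join(MODELS_DIR, "attn_proc--url_" + url_hash[:7] + "--" + fname)
    # Download (and, for archives, extract) only on first use.
    if not os.path.exists(path):
        storage.download_and_extract(path)
    return path

The returned path is what the hunk assigns back to attn_procs before calling pipeline.unet.load_attn_procs.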
50 changes: 41 additions & 9 deletions tests/integration/test_attn_procs.py
@@ -5,23 +5,31 @@


class TestAttnProcs:
def test_hf_download(self):
"""
Make sure when switching models we release VRAM afterwards.
"""
dda = getDDA(
def setup_class(self):
print("setup_class")
# self.minio = minio = getMinio("global")

self.dda = dda = getDDA(
# minio=minio
stream_logs=True,
)
print(dda)

TEST_ARGS = {"test_url": dda.url}
self.TEST_ARGS = {"test_url": dda.url}

mem_usage = list()
def teardown_class(self):
print("teardown_class")
# self.minio.stop() - leave global up
self.dda.stop()

def test_hf_download(self):
"""
Download user/repo from HuggingFace.
"""
# fp32 model is obviously bigger
result = runTest(
"txt2img",
TEST_ARGS,
self.TEST_ARGS,
{
"MODEL_ID": "runwayml/stable-diffusion-v1-5",
"MODEL_REVISION": "fp16",
@@ -37,4 +45,28 @@ def test_hf_download(self):
)

assert result["image_base64"]
dda.stop()

def test_http_download_diffusers_archive(self):
"""
Download a single attn_procs weights file over HTTP.
"""

# fp32 model is obviously bigger
result = runTest(
"txt2img",
self.TEST_ARGS,
{
"MODEL_ID": "runwayml/stable-diffusion-v1-5",
"MODEL_REVISION": "fp16",
"MODEL_PRECISION": "fp16",
"attn_procs": "https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/resolve/main/pytorch_lora_weights.bin",
},
{
"num_inference_steps": 1,
"prompt": "A picture of a sks dog in a bucket",
"seed": 1,
"cross_attention_kwargs": {"scale": 0.5},
},
)

assert result["image_base64"]
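
For the remaining "test S3 for single files" TODO, a hedged sketch of how such a test might slot into TestAttnProcs alongside the methods above; the bucket and key in the S3 URL are hypothetical, and per the commit message this path is still unverified.

def test_s3_download_single_file(self):
    """
    Hypothetical sketch (untested per the commit TODO): fetch a single
    attn_procs weights file from S3-compatible storage instead of HTTP.
    """
    result = runTest(
        "txt2img",
        self.TEST_ARGS,
        {
            "MODEL_ID": "runwayml/stable-diffusion-v1-5",
            "MODEL_REVISION": "fp16",
            "MODEL_PRECISION": "fp16",
            # assumed bucket/key; any URL form the Storage helper accepts
            "attn_procs": "s3://my-bucket/pytorch_lora_weights.bin",
        },
        {
            "num_inference_steps": 1,
            "prompt": "A picture of a sks dog in a bucket",
            "seed": 1,
            "cross_attention_kwargs": {"scale": 0.5},
        },
    )

    assert result["image_base64"]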
