Add Stable Diffusion V2 support to DirectML examples (#405)
PatriceVignola authored Jul 13, 2023
1 parent ca6dfb3 commit 7201b03
Showing 4 changed files with 62 additions and 18 deletions.
8 changes: 5 additions & 3 deletions examples/directml/stable_diffusion/README.md
@@ -1,6 +1,6 @@
# Stable Diffusion Optimization with DirectML <!-- omit in toc -->

-This sample shows how to optimize [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) to run with ONNX Runtime and DirectML.
+This sample shows how to optimize [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4), [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) or [Stable Diffusion v2](https://huggingface.co/stabilityai/stable-diffusion-2) to run with ONNX Runtime and DirectML.

Stable Diffusion comprises multiple PyTorch models tied together into a *pipeline*. This Olive sample will convert each PyTorch model to ONNX, and then run the converted ONNX models through the `OrtTransformersOptimization` pass. The transformer optimization pass performs several time-consuming graph transformations that make the models more efficient for inference at runtime. Output models are only guaranteed to be compatible with onnxruntime-directml 1.15.0 or newer.
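
For context, a minimal sketch of what consuming the optimized output looks like (the directory path matches the example output location described below; the prompt is illustrative):

```python
from diffusers import OnnxStableDiffusionPipeline

# Load the Olive-optimized ONNX pipeline and run it on the DirectML
# execution provider via onnxruntime-directml.
pipeline = OnnxStableDiffusionPipeline.from_pretrained(
    "models/optimized/runwayml/stable-diffusion-v1-5",
    provider="DmlExecutionProvider",
)
result = pipeline("a photo of an astronaut riding a horse on mars")
result.images[0].save("result.png")
```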

@@ -45,8 +45,8 @@ The above command will enumerate the `config_<model_name>.json` files and optimi
The stable diffusion models are large, and the optimization process is resource intensive. It is recommended to run optimization on a system with a minimum of 16GB of memory (preferably 32GB). Expect optimization to take several minutes (especially the U-Net model).

Once the script successfully completes:
-- The optimized ONNX pipeline will be stored under `models/optimized/runwayml/stable-diffusion-v1-5`.
-- The unoptimized ONNX pipeline (models converted to ONNX, but not run through transformer optimization pass) will be stored under `models/unoptimized/runwayml/stable-diffusion-v1-5`.
+- The optimized ONNX pipeline will be stored under `models/optimized/[model_id]` (for example `models/optimized/runwayml/stable-diffusion-v1-5`).
+- The unoptimized ONNX pipeline (models converted to ONNX, but not run through transformer optimization pass) will be stored under `models/unoptimized/[model_id]` (for example `models/unoptimized/runwayml/stable-diffusion-v1-5`).

Re-running the script with `--optimize` will delete the output models, but it will *not* delete the Olive cache. Subsequent runs will complete much faster since it will simply be copying previously optimized models; you may use the `--clean_cache` option to start from scratch (not typically used unless you are modifying the scripts, for example).

@@ -77,6 +77,8 @@ Run `python stable_diffusion.py --help` for additional options. A few particular
- `--model_id <string>` : name of a stable diffusion model ID hosted by huggingface.co. This script has been tested with the following:
- `CompVis/stable-diffusion-v1-4`
- `runwayml/stable-diffusion-v1-5` (default)
+- `sayakpaul/sd-model-finetuned-lora-t4`
+- `stabilityai/stable-diffusion-2`
- LoRA variants of the above base models may work as well. See [LoRA Models (Experimental)](#lora-models-experimental).
- `--num_inference_steps <int>` : the number of sampling steps per inference. The default value is 50. A lower value (e.g. 20) will speed up inference at the expense of quality, and a higher value (e.g. 100) may produce higher quality images.
- `--num_images <int>` : the number of images to generate per script invocation (non-interactive UI) or per click of the generate button (interactive UI). The default value is 1.
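
A worked example of how `--num_images` and `--batch_size` interact; the batching expression mirrors the one used in `stable_diffusion.py` in this commit:

```python
# Requesting 5 images with a batch size of 2 requires ceil(5 / 2) = 3 batches.
num_images, batch_size = 5, 2
min_batches_required = 1 + (num_images - 1) // batch_size
assert min_batches_required == 3  # batches of 2, 2, and 1 images
```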
6 changes: 6 additions & 0 deletions examples/directml/stable_diffusion/config.py
@@ -0,0 +1,6 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+image_size = 512
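
The reason for introducing a standalone `config.py` (rather than passing the value around) appears to be that the input-generation functions in `user_script.py` are invoked by Olive with fixed signatures, so the resolution chosen in `stable_diffusion.py` must be shared through an importable module. A minimal, single-file sketch of the pattern, with illustrative values:

```python
import config
import torch

# stable_diffusion.py (the driver) sets the value once, based on --model_id:
config.image_size = 768  # e.g. stabilityai/stable-diffusion-2

# user_script.py (invoked later by Olive) reads it when building dummy
# inputs for ONNX export, without needing an extra function argument:
def vae_encoder_inputs(batchsize, torch_dtype):
    return {
        "sample": torch.rand((batchsize, 3, config.image_size, config.image_size), dtype=torch_dtype),
        "return_dict": False,
    }
```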
59 changes: 47 additions & 12 deletions examples/directml/stable_diffusion/stable_diffusion.py
@@ -11,6 +11,7 @@
import warnings
from pathlib import Path

+import config
import onnxruntime as ort
import torch
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline
@@ -23,7 +24,7 @@


def run_inference_loop(
-pipeline, prompt, num_images, batch_size, num_inference_steps, image_callback=None, step_callback=None
+pipeline, prompt, num_images, batch_size, image_size, num_inference_steps, image_callback=None, step_callback=None
):
images_saved = 0

@@ -37,11 +38,13 @@ def update_steps(step, timestep, latents):
[prompt] * batch_size,
num_inference_steps=num_inference_steps,
callback=update_steps if step_callback else None,
+height=image_size,
+width=image_size,
)
passed_safety_checker = 0

for image_index in range(batch_size):
-if not result.nsfw_content_detected[image_index]:
+if result.nsfw_content_detected is None or not result.nsfw_content_detected[image_index]:
passed_safety_checker += 1
if images_saved < num_images:
output_path = f"result_{images_saved}.png"
@@ -54,7 +57,7 @@ def update_steps(step, timestep, latents):
print(f"Inference Batch End ({passed_safety_checker}/{batch_size} images passed the safety checker).")


-def run_inference_gui(pipeline, prompt, num_images, batch_size, num_inference_steps):
+def run_inference_gui(pipeline, prompt, num_images, batch_size, image_size, num_inference_steps):
def update_progress_bar(total_steps_completed):
progress_bar["value"] = total_steps_completed

@@ -76,6 +79,7 @@ def on_generate_click():
prompt_textbox.get(),
num_images,
batch_size,
+image_size,
num_inference_steps,
image_completed,
update_progress_bar,
@@ -86,7 +90,6 @@ def on_generate_click():
print("WARNING: interactive UI only supports displaying up to 9 images")
num_images = 9

-image_size = 512
image_rows = 1 + (num_images - 1) // 3
image_cols = 2 if num_images == 4 else min(num_images, 3)
min_batches_required = 1 + (num_images - 1) // batch_size
@@ -127,7 +130,9 @@ def on_generate_click():
window.mainloop()


-def run_inference(optimized_model_dir, prompt, num_images, batch_size, num_inference_steps, static_dims, interactive):
+def run_inference(
+    optimized_model_dir, prompt, num_images, batch_size, image_size, num_inference_steps, static_dims, interactive
+):
ort.set_default_logger_severity(3)

print("Loading models into ORT session...")
@@ -140,8 +145,8 @@ def run_inference(optimized_model_dir, prompt, num_images, batch_size, num_infer
# https://github.com/huggingface/diffusers/blob/46c52f9b9607e6ecb29c782c052aea313e6487b7/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L672
sess_options.add_free_dimension_override_by_name("unet_sample_batch", batch_size * 2)
sess_options.add_free_dimension_override_by_name("unet_sample_channels", 4)
sess_options.add_free_dimension_override_by_name("unet_sample_height", 64)
sess_options.add_free_dimension_override_by_name("unet_sample_width", 64)
sess_options.add_free_dimension_override_by_name("unet_sample_height", image_size // 8)
sess_options.add_free_dimension_override_by_name("unet_sample_width", image_size // 8)
sess_options.add_free_dimension_override_by_name("unet_time_batch", 1)
sess_options.add_free_dimension_override_by_name("unet_hidden_batch", batch_size * 2)
sess_options.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
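
For context, the override values above follow from standard Stable Diffusion architecture facts rather than anything DirectML-specific; a sketch of the arithmetic:

```python
# The VAE downsamples height and width by a factor of 8, so the UNet operates
# on latents of image_size // 8 (96x96 for 768px models, 64x64 for 512px).
image_size = 768
latent_hw = image_size // 8
assert latent_hw == 96

# Classifier-free guidance evaluates the conditional and unconditional prompts
# in one UNet call, which is why the batch dimensions use batch_size * 2.
batch_size = 1
unet_sample_batch = batch_size * 2

# CLIP tokenizers pad/truncate prompts to 77 tokens (unet_hidden_sequence).
text_sequence_length = 77
```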
@@ -151,9 +156,9 @@
)

if interactive:
-run_inference_gui(pipeline, prompt, num_images, batch_size, num_inference_steps)
+run_inference_gui(pipeline, prompt, num_images, batch_size, image_size, num_inference_steps)
else:
-run_inference_loop(pipeline, prompt, num_images, batch_size, num_inference_steps)
+run_inference_loop(pipeline, prompt, num_images, batch_size, image_size, num_inference_steps)


def optimize(
@@ -188,7 +193,12 @@ def optimize(

model_info = dict()

for submodel_name in ("text_encoder", "vae_encoder", "vae_decoder", "safety_checker", "unet"):
submodel_names = ["text_encoder", "vae_encoder", "vae_decoder", "unet"]

if pipeline.safety_checker is not None:
submodel_names.append("safety_checker")

for submodel_name in submodel_names:
print(f"\nOptimizing {submodel_name}")

olive_config = None
@@ -239,14 +249,20 @@ def optimize(
# Save the unoptimized models in a directory structure that the diffusers library can load and run.
# This is optional, and the optimized models can be used directly in a custom pipeline if desired.
print("\nCreating ONNX pipeline...")

+if pipeline.safety_checker is not None:
+    safety_checker = OnnxRuntimeModel.from_pretrained(model_info["safety_checker"]["unoptimized"]["path"].parent)
+else:
+    safety_checker = None

onnx_pipeline = OnnxStableDiffusionPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(model_info["vae_encoder"]["unoptimized"]["path"].parent),
vae_decoder=OnnxRuntimeModel.from_pretrained(model_info["vae_decoder"]["unoptimized"]["path"].parent),
text_encoder=OnnxRuntimeModel.from_pretrained(model_info["text_encoder"]["unoptimized"]["path"].parent),
tokenizer=pipeline.tokenizer,
unet=OnnxRuntimeModel.from_pretrained(model_info["unet"]["unoptimized"]["path"].parent),
scheduler=pipeline.scheduler,
-safety_checker=OnnxRuntimeModel.from_pretrained(model_info["safety_checker"]["unoptimized"]["path"].parent),
+safety_checker=safety_checker,
feature_extractor=pipeline.feature_extractor,
requires_safety_checker=True,
)
@@ -257,7 +273,7 @@ def optimize(
# Create a copy of the unoptimized model directory, then overwrite with optimized models from the olive cache.
print("Copying optimized models...")
shutil.copytree(unoptimized_model_dir, optimized_model_dir, ignore=shutil.ignore_patterns("weights.pb"))
for submodel_name in ("text_encoder", "vae_encoder", "vae_decoder", "safety_checker", "unet"):
for submodel_name in submodel_names:
src_path = model_info[submodel_name]["optimized"]["path"]
dst_path = optimized_model_dir / submodel_name / "model.onnx"
shutil.copyfile(src_path, dst_path)
@@ -295,6 +311,22 @@ def optimize(
"Use --dynamic_dims to disable static shape optimization."
)

+model_to_image_size = {
+    "CompVis/stable-diffusion-v1-4": 512,
+    "runwayml/stable-diffusion-v1-5": 512,
+    "sayakpaul/sd-model-finetuned-lora-t4": 512,
+    "stabilityai/stable-diffusion-2": 768,
+    "stabilityai/stable-diffusion-2-base": 768,
+    "stabilityai/stable-diffusion-2-1": 768,
+    "stabilityai/stable-diffusion-2-1-base": 768,
+}
+
+if args.model_id not in list(model_to_image_size.keys()):
+    print(
+        f"WARNING: {args.model_id} is not an officially supported model for this example and may not work as "
+        + "expected."
+    )

if version.parse(ort.__version__) < version.parse("1.15.0"):
print("This script requires onnxruntime-directml 1.15.0 or newer")
exit(1)
@@ -306,6 +338,8 @@ def optimize(
if args.clean_cache:
shutil.rmtree(script_dir / "cache", ignore_errors=True)

+config.image_size = model_to_image_size.get(args.model_id, 512)
+
if args.optimize or not optimized_model_dir.exists():
# TODO: clean up warning filter (mostly during conversion from torch to ONNX)
with warnings.catch_warnings():
@@ -323,6 +357,7 @@ def optimize(
args.prompt,
args.num_images,
args.batch_size,
+config.image_size,
args.num_inference_steps,
use_static_dims,
args.interactive,
7 changes: 4 additions & 3 deletions examples/directml/stable_diffusion/user_script.py
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
+import config
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -135,7 +136,7 @@ def unet_inputs(batchsize, torch_dtype):
return {
"sample": torch.rand((batchsize, 4, 64, 64), dtype=torch_dtype),
"timestep": torch.rand((batchsize,), dtype=torch_dtype),
"encoder_hidden_states": torch.rand((batchsize, 77, 768), dtype=torch_dtype),
"encoder_hidden_states": torch.rand((batchsize, 77, config.image_size + 256), dtype=torch_dtype),
"return_dict": False,
}
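
A note on the `config.image_size + 256` expression above: it reads as a terse way of selecting the text encoder's hidden size, which happens to sit exactly 256 above the native resolution for both model families this example targets (an observation about these specific checkpoints, not a general rule):

```python
# SD 1.x: 512px native, CLIP ViT-L/14 text embeddings of width  768 = 512 + 256
# SD 2.x: 768px native, OpenCLIP ViT-H  text embeddings of width 1024 = 768 + 256
for image_size, text_hidden_dim in ((512, 768), (768, 1024)):
    assert image_size + 256 == text_hidden_dim
```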

@@ -163,7 +164,7 @@ def unet_data_loader(data_dir, batchsize, *args, **kwargs):

def vae_encoder_inputs(batchsize, torch_dtype):
return {
"sample": torch.rand((batchsize, 3, 512, 512), dtype=torch_dtype),
"sample": torch.rand((batchsize, 3, config.image_size, config.image_size), dtype=torch_dtype),
"return_dict": False,
}

@@ -218,7 +219,7 @@ def vae_decoder_data_loader(data_dir, batchsize, *args, **kwargs):
def safety_checker_inputs(batchsize, torch_dtype):
return {
"clip_input": torch.rand((batchsize, 3, 224, 224), dtype=torch_dtype),
"images": torch.rand((batchsize, 512, 512, 3), dtype=torch_dtype),
"images": torch.rand((batchsize, config.image_size, config.image_size, 3), dtype=torch_dtype),
}


