From 7201b039075bbd62b83bb7a90964026f6fe34ce6 Mon Sep 17 00:00:00 2001
From: Patrice Vignola
Date: Wed, 12 Jul 2023 19:35:54 -0700
Subject: [PATCH] Add Stable Diffusion V2 support to DirectML examples (#405)

---
 examples/directml/stable_diffusion/README.md |  8 ++-
 examples/directml/stable_diffusion/config.py |  6 ++
 .../stable_diffusion/stable_diffusion.py     | 59 +++++++++++++++----
 .../directml/stable_diffusion/user_script.py |  7 ++-
 4 files changed, 62 insertions(+), 18 deletions(-)
 create mode 100644 examples/directml/stable_diffusion/config.py

diff --git a/examples/directml/stable_diffusion/README.md b/examples/directml/stable_diffusion/README.md
index e2efb1de8..efeeb8fa5 100644
--- a/examples/directml/stable_diffusion/README.md
+++ b/examples/directml/stable_diffusion/README.md
@@ -1,6 +1,6 @@
 # Stable Diffusion Optimization with DirectML
 
-This sample shows how to optimize [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) to run with ONNX Runtime and DirectML.
+This sample shows how to optimize [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4), [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5), or [Stable Diffusion v2](https://huggingface.co/stabilityai/stable-diffusion-2) to run with ONNX Runtime and DirectML.
 
 Stable Diffusion comprises multiple PyTorch models tied together into a *pipeline*. This Olive sample will convert each PyTorch model to ONNX, and then run the converted ONNX models through the `OrtTransformersOptimization` pass. The transformer optimization pass performs several time-consuming graph transformations that make the models more efficient for inference at runtime. Output models are only guaranteed to be compatible with onnxruntime-directml 1.15.0 or newer.
@@ -45,8 +45,8 @@ The above command will enumerate the `config_.json` files and optimi
 The stable diffusion models are large, and the optimization process is resource intensive. It is recommended to run optimization on a system with a minimum of 16GB of memory (preferably 32GB). Expect optimization to take several minutes (especially the U-Net model).
 
 Once the script successfully completes:
-- The optimized ONNX pipeline will be stored under `models/optimized/runwayml/stable-diffusion-v1-5`.
-- The unoptimized ONNX pipeline (models converted to ONNX, but not run through the transformer optimization pass) will be stored under `models/unoptimized/runwayml/stable-diffusion-v1-5`.
+- The optimized ONNX pipeline will be stored under `models/optimized/[model_id]` (for example `models/optimized/runwayml/stable-diffusion-v1-5`).
+- The unoptimized ONNX pipeline (models converted to ONNX, but not run through the transformer optimization pass) will be stored under `models/unoptimized/[model_id]` (for example `models/unoptimized/runwayml/stable-diffusion-v1-5`).
 
 Re-running the script with `--optimize` will delete the output models, but it will *not* delete the Olive cache. Subsequent runs will complete much faster since they simply copy previously optimized models; you may use the `--clean_cache` option to start from scratch (not typically needed unless you are modifying the scripts).
@@ -77,6 +77,8 @@ Run `python stable_diffusion.py --help` for additional options. A few particular
 - `--model_id ` : name of a stable diffusion model ID hosted by huggingface.co. This script has been tested with the following:
   - `CompVis/stable-diffusion-v1-4`
   - `runwayml/stable-diffusion-v1-5` (default)
+  - `sayakpaul/sd-model-finetuned-lora-t4`
+  - `stabilityai/stable-diffusion-2`
   - LoRA variants of the above base models may work as well. See [LoRA Models (Experimental)](#lora-models-experimental).
 - `--num_inference_steps ` : the number of sampling steps per inference. The default value is 50. A lower value (e.g. 20) will speed up inference at the expense of quality, and a higher value (e.g. 100) may produce higher quality images.
 - `--num_images ` : the number of images to generate per script invocation (non-interactive UI) or per click of the generate button (interactive UI). The default value is 1.
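
The README changes above settle the output layout for multiple base models. The optimized output remains a standard diffusers ONNX pipeline directory, so it loads directly; the sketch below is illustrative (the model directory, prompt, and resolution are assumptions, and it presumes `--optimize` has already been run with onnxruntime-directml installed):

    # Hypothetical smoke test; not part of the patch.
    from diffusers import OnnxStableDiffusionPipeline

    # Directory written by `python stable_diffusion.py --optimize --model_id ...`
    model_dir = "models/optimized/stabilityai/stable-diffusion-2"
    pipeline = OnnxStableDiffusionPipeline.from_pretrained(
        model_dir, provider="DmlExecutionProvider"
    )
    # SD 2 checkpoints are trained at 768x768; v1-x checkpoints at 512x512.
    image = pipeline("a photo of an astronaut", height=768, width=768).images[0]
    image.save("smoke_test.png")
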
diff --git a/examples/directml/stable_diffusion/config.py b/examples/directml/stable_diffusion/config.py
new file mode 100644
index 000000000..c548b1c82
--- /dev/null
+++ b/examples/directml/stable_diffusion/config.py
@@ -0,0 +1,6 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+image_size = 512
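
The new `config.py` is a one-line module on purpose: Python caches imported modules in `sys.modules`, so every file that does `import config` shares the same module object, and the `image_size` that `stable_diffusion.py` assigns at startup is exactly what `user_script.py` reads later when building dummy inputs. A self-contained sketch of that mechanism (the module is synthesized in memory here instead of living on disk):

    import sys
    import types

    # Stand-in for config.py; a real file behaves identically.
    cfg = types.ModuleType("config")
    cfg.image_size = 512
    sys.modules["config"] = cfg

    import config  # served from the sys.modules cache

    config.image_size = 768  # what stable_diffusion.py does for SD 2 models
    assert sys.modules["config"].image_size == 768  # seen by every importer
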
diff --git a/examples/directml/stable_diffusion/stable_diffusion.py b/examples/directml/stable_diffusion/stable_diffusion.py
index b59c508f5..5f3b6a0ad 100644
--- a/examples/directml/stable_diffusion/stable_diffusion.py
+++ b/examples/directml/stable_diffusion/stable_diffusion.py
@@ -11,6 +11,7 @@
 import warnings
 from pathlib import Path
 
+import config
 import onnxruntime as ort
 import torch
 from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline
@@ -23,7 +24,7 @@
 def run_inference_loop(
-    pipeline, prompt, num_images, batch_size, num_inference_steps, image_callback=None, step_callback=None
+    pipeline, prompt, num_images, batch_size, image_size, num_inference_steps, image_callback=None, step_callback=None
 ):
     images_saved = 0
@@ -37,11 +38,13 @@ def update_steps(step, timestep, latents):
         [prompt] * batch_size,
         num_inference_steps=num_inference_steps,
         callback=update_steps if step_callback else None,
+        height=image_size,
+        width=image_size,
     )
 
     passed_safety_checker = 0
     for image_index in range(batch_size):
-        if not result.nsfw_content_detected[image_index]:
+        if result.nsfw_content_detected is None or not result.nsfw_content_detected[image_index]:
             passed_safety_checker += 1
             if images_saved < num_images:
                 output_path = f"result_{images_saved}.png"
@@ -54,7 +57,7 @@ def update_steps(step, timestep, latents):
     print(f"Inference Batch End ({passed_safety_checker}/{batch_size} images passed the safety checker).")
 
 
-def run_inference_gui(pipeline, prompt, num_images, batch_size, num_inference_steps):
+def run_inference_gui(pipeline, prompt, num_images, batch_size, image_size, num_inference_steps):
     def update_progress_bar(total_steps_completed):
         progress_bar["value"] = total_steps_completed
@@ -76,6 +79,7 @@ def on_generate_click():
             prompt_textbox.get(),
             num_images,
             batch_size,
+            image_size,
             num_inference_steps,
             image_completed,
             update_progress_bar,
@@ -86,7 +90,6 @@ def on_generate_click():
         print("WARNING: interactive UI only supports displaying up to 9 images")
         num_images = 9
 
-    image_size = 512
     image_rows = 1 + (num_images - 1) // 3
     image_cols = 2 if num_images == 4 else min(num_images, 3)
     min_batches_required = 1 + (num_images - 1) // batch_size
@@ -127,7 +130,9 @@ def on_generate_click():
     window.mainloop()
 
 
-def run_inference(optimized_model_dir, prompt, num_images, batch_size, num_inference_steps, static_dims, interactive):
+def run_inference(
+    optimized_model_dir, prompt, num_images, batch_size, image_size, num_inference_steps, static_dims, interactive
+):
     ort.set_default_logger_severity(3)
 
     print("Loading models into ORT session...")
@@ -140,8 +145,8 @@ def run_inference(optimized_model_dir, prompt, num_images, batch_size, num_infer
         # https://github.com/huggingface/diffusers/blob/46c52f9b9607e6ecb29c782c052aea313e6487b7/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L672
         sess_options.add_free_dimension_override_by_name("unet_sample_batch", batch_size * 2)
         sess_options.add_free_dimension_override_by_name("unet_sample_channels", 4)
-        sess_options.add_free_dimension_override_by_name("unet_sample_height", 64)
-        sess_options.add_free_dimension_override_by_name("unet_sample_width", 64)
+        sess_options.add_free_dimension_override_by_name("unet_sample_height", image_size // 8)
+        sess_options.add_free_dimension_override_by_name("unet_sample_width", image_size // 8)
         sess_options.add_free_dimension_override_by_name("unet_time_batch", 1)
         sess_options.add_free_dimension_override_by_name("unet_hidden_batch", batch_size * 2)
         sess_options.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
@@ -151,9 +156,9 @@
     )
 
     if interactive:
-        run_inference_gui(pipeline, prompt, num_images, batch_size, num_inference_steps)
+        run_inference_gui(pipeline, prompt, num_images, batch_size, image_size, num_inference_steps)
     else:
-        run_inference_loop(pipeline, prompt, num_images, batch_size, num_inference_steps)
+        run_inference_loop(pipeline, prompt, num_images, batch_size, image_size, num_inference_steps)
 
 
 def optimize(
@@ -188,7 +193,12 @@ def optimize(
 
     model_info = dict()
 
-    for submodel_name in ("text_encoder", "vae_encoder", "vae_decoder", "safety_checker", "unet"):
+    submodel_names = ["text_encoder", "vae_encoder", "vae_decoder", "unet"]
+
+    if pipeline.safety_checker is not None:
+        submodel_names.append("safety_checker")
+
+    for submodel_name in submodel_names:
         print(f"\nOptimizing {submodel_name}")
 
         olive_config = None
@@ -239,6 +249,12 @@ def optimize(
     # Save the unoptimized models in a directory structure that the diffusers library can load and run.
     # This is optional, and the optimized models can be used directly in a custom pipeline if desired.
     print("\nCreating ONNX pipeline...")
+
+    if pipeline.safety_checker is not None:
+        safety_checker = OnnxRuntimeModel.from_pretrained(model_info["safety_checker"]["unoptimized"]["path"].parent)
+    else:
+        safety_checker = None
+
     onnx_pipeline = OnnxStableDiffusionPipeline(
         vae_encoder=OnnxRuntimeModel.from_pretrained(model_info["vae_encoder"]["unoptimized"]["path"].parent),
         vae_decoder=OnnxRuntimeModel.from_pretrained(model_info["vae_decoder"]["unoptimized"]["path"].parent),
@@ -246,7 +262,7 @@ def optimize(
         tokenizer=pipeline.tokenizer,
         unet=OnnxRuntimeModel.from_pretrained(model_info["unet"]["unoptimized"]["path"].parent),
         scheduler=pipeline.scheduler,
-        safety_checker=OnnxRuntimeModel.from_pretrained(model_info["safety_checker"]["unoptimized"]["path"].parent),
+        safety_checker=safety_checker,
         feature_extractor=pipeline.feature_extractor,
         requires_safety_checker=True,
     )
@@ -257,7 +273,7 @@ def optimize(
     # Create a copy of the unoptimized model directory, then overwrite with optimized models from the olive cache.
     print("Copying optimized models...")
     shutil.copytree(unoptimized_model_dir, optimized_model_dir, ignore=shutil.ignore_patterns("weights.pb"))
-    for submodel_name in ("text_encoder", "vae_encoder", "vae_decoder", "safety_checker", "unet"):
+    for submodel_name in submodel_names:
         src_path = model_info[submodel_name]["optimized"]["path"]
         dst_path = optimized_model_dir / submodel_name / "model.onnx"
         shutil.copyfile(src_path, dst_path)
@@ -295,6 +311,22 @@
         "Use --dynamic_dims to disable static shape optimization."
     )
 
+    model_to_image_size = {
+        "CompVis/stable-diffusion-v1-4": 512,
+        "runwayml/stable-diffusion-v1-5": 512,
+        "sayakpaul/sd-model-finetuned-lora-t4": 512,
+        "stabilityai/stable-diffusion-2": 768,
+        "stabilityai/stable-diffusion-2-base": 768,
+        "stabilityai/stable-diffusion-2-1": 768,
+        "stabilityai/stable-diffusion-2-1-base": 768,
+    }
+
+    if args.model_id not in list(model_to_image_size.keys()):
+        print(
+            f"WARNING: {args.model_id} is not an officially supported model for this example and may not work as "
+            + "expected."
+        )
+
     if version.parse(ort.__version__) < version.parse("1.15.0"):
         print("This script requires onnxruntime-directml 1.15.0 or newer")
         exit(1)
@@ -306,6 +338,8 @@
     if args.clean_cache:
         shutil.rmtree(script_dir / "cache", ignore_errors=True)
 
+    config.image_size = model_to_image_size.get(args.model_id, 512)
+
     if args.optimize or not optimized_model_dir.exists():
         # TODO: clean up warning filter (mostly during conversion from torch to ONNX)
         with warnings.catch_warnings():
@@ -323,6 +357,7 @@
             args.prompt,
            args.num_images,
            args.batch_size,
+           config.image_size,
            args.num_inference_steps,
            use_static_dims,
            args.interactive,
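
The free-dimension overrides are the heart of the change: the VAE downsamples images by a factor of 8, so the UNet sees 96x96 latents for a 768x768 Stable Diffusion 2 image instead of the previously hardcoded 64x64, and the batch dimension is doubled because classifier-free guidance runs a conditional and an unconditional pass together. The same overrides in isolation (the UNet path, provider, and batch size below are illustrative):

    import onnxruntime as ort

    batch_size = 1
    image_size = 768               # 512 for the v1-x models
    latent_size = image_size // 8  # 96 for SD 2, 64 for SD 1.x

    sess_options = ort.SessionOptions()
    # Conditional + unconditional passes for classifier-free guidance.
    sess_options.add_free_dimension_override_by_name("unet_sample_batch", batch_size * 2)
    sess_options.add_free_dimension_override_by_name("unet_sample_channels", 4)
    sess_options.add_free_dimension_override_by_name("unet_sample_height", latent_size)
    sess_options.add_free_dimension_override_by_name("unet_sample_width", latent_size)

    session = ort.InferenceSession(
        "models/optimized/stabilityai/stable-diffusion-2/unet/model.onnx",
        sess_options=sess_options,
        providers=["DmlExecutionProvider"],
    )
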
diff --git a/examples/directml/stable_diffusion/user_script.py b/examples/directml/stable_diffusion/user_script.py
index a013f0594..ce188a832 100644
--- a/examples/directml/stable_diffusion/user_script.py
+++ b/examples/directml/stable_diffusion/user_script.py
@@ -2,6 +2,7 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
+import config
 import torch
 from diffusers import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -135,7 +136,7 @@ def unet_inputs(batchsize, torch_dtype):
     return {
         "sample": torch.rand((batchsize, 4, 64, 64), dtype=torch_dtype),
         "timestep": torch.rand((batchsize,), dtype=torch_dtype),
-        "encoder_hidden_states": torch.rand((batchsize, 77, 768), dtype=torch_dtype),
+        "encoder_hidden_states": torch.rand((batchsize, 77, config.image_size + 256), dtype=torch_dtype),
         "return_dict": False,
     }
@@ -163,7 +164,7 @@ def unet_data_loader(data_dir, batchsize, *args, **kwargs):
 
 
 def vae_encoder_inputs(batchsize, torch_dtype):
     return {
-        "sample": torch.rand((batchsize, 3, 512, 512), dtype=torch_dtype),
+        "sample": torch.rand((batchsize, 3, config.image_size, config.image_size), dtype=torch_dtype),
         "return_dict": False,
     }
@@ -218,7 +219,7 @@ def vae_decoder_data_loader(data_dir, batchsize, *args, **kwargs):
 
 
 def safety_checker_inputs(batchsize, torch_dtype):
     return {
         "clip_input": torch.rand((batchsize, 3, 224, 224), dtype=torch_dtype),
-        "images": torch.rand((batchsize, 512, 512, 3), dtype=torch_dtype),
+        "images": torch.rand((batchsize, config.image_size, config.image_size, 3), dtype=torch_dtype),
     }
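
One subtlety in `unet_inputs` above: `config.image_size + 256` is a shortcut that happens to hold for every entry in `model_to_image_size`. The 512-pixel v1-x checkpoints use CLIP ViT-L/14, whose hidden size is 768 (512 + 256), while the 768-pixel v2 checkpoints use OpenCLIP ViT-H/14, whose hidden size is 1024 (768 + 256); an unsupported model with a different resolution/text-encoder pairing would need the hidden size set explicitly instead. A quick shape check (hypothetical snippet, run alongside the sample's `config.py`):

    import torch

    import config  # the sample's config.py

    config.image_size = 768               # as selected for stabilityai/stable-diffusion-2
    hidden_dim = config.image_size + 256  # 1024, the OpenCLIP ViT-H/14 hidden size

    encoder_hidden_states = torch.rand((1, 77, hidden_dim))
    vae_sample = torch.rand((1, 3, config.image_size, config.image_size))
    print(encoder_hidden_states.shape, vae_sample.shape)
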