From 59ae3fdfdc77aab5022c7359c7e2e988ed221936 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 17 Oct 2023 21:30:04 -0700 Subject: [PATCH] [CUDA] StableDiffusion XL demo with CUDA EP (#17997) Add CUDA EP to the StableDiffusion XL Demo including: (1) Add fp16 VAE support for CUDA EP. (2) Configuration for each model separately (For example, some models can run with CUDA graph but some models cannot). Some remaining works will boost performance further later: (1) Enable CUDA Graph for Clip2 and UNet. Currently, some part of graph is partitioned to CPU, which blocks CUDA graph. (2) Update GroupNorm CUDA kernel for refiner. Currently, the cuda kernel only supports limited number of channels in refiner so we shall see some gain there if we remove the limitation. Some extra works that are nice to have (thus lower priority): (3) Support denoising_end to ensemble base and refiner. (4) Support classifier free guidance (The idea is from https://www.baseten.co/blog/sdxl-inference-in-under-2-seconds-the-ultimate-guide-to-stable-diffusion-optimiza/). #### Performance on A100-SXM4-80GB Example commands to test an engine built with static shape or dynamic shape: ``` engine_name=ORT_CUDA python demo_txt2img_xl.py --engine $engine_name "some prompt" python demo_txt2img_xl.py --engine $engine_name --disable-cuda-graph --build-dynamic-batch --build-dynamic-shape "some prompt" ``` Engine built with dynamic shape could support different batch size (1 to 4 for TRT; 1 to 16 for CUDA) and image size (256x256 to 1024x1024). Engine built with static shape could only support fixed batch size (1) and image size (1024x1024). The latency (ms) of generating an image of size 1024x1024 (sorted by total latency): Engine | Base (30 Steps)* | Refiner (9 Steps) | Total Latency (ms) -- | -- | -- | -- ORT_TRT (static shape) | 2467 | 1033 | 3501 TRT (static shape) | 2507 | 1048 | 3555 ORT_CUDA (static shape) | 2630 | 1015 | 3645 ORT_CUDA (dynamic shape) | 2639 | 1016 | 3654 TRT (dynamic shape) | 2777 | 1099 | 3876 ORT_TRT (dynamic shape) | 2890 | 1166 | 4057 \* VAE decoder is not used in Base since the output from base is latent, which is consumed by refiner to output image. We can see that ORT_CUDA is faster on dynamic shape, while slower in static shape (The cause is Clip2 and UNet cannot run with CUDA Graph right now, and we will address the issue later). ### Motivation and Context Follow up of https://github.com/microsoft/onnxruntime/pull/17536 --- .../models/stable_diffusion/benchmark.py | 8 +- .../models/stable_diffusion/demo_txt2img.py | 2 +- .../stable_diffusion/demo_txt2img_xl.py | 97 +++++++------ .../models/stable_diffusion/demo_utils.py | 16 +-- .../stable_diffusion/diffusion_models.py | 126 +++++++++++------ .../models/stable_diffusion/engine_builder.py | 19 ++- .../engine_builder_ort_cuda.py | 131 +++++++++++++++--- .../engine_builder_ort_trt.py | 2 +- .../models/stable_diffusion/ort_optimizer.py | 14 +- .../models/stable_diffusion/ort_utils.py | 14 +- .../stable_diffusion/pipeline_img2img_xl.py | 23 +-- .../pipeline_stable_diffusion.py | 82 +++++++---- .../stable_diffusion/pipeline_txt2img.py | 6 +- .../stable_diffusion/pipeline_txt2img_xl.py | 35 +++-- 14 files changed, 379 insertions(+), 196 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index f8fda13a35b93..1f1db914e274b 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -864,7 +864,7 @@ def init_pipeline(pipeline_class, pipeline_info): base_pipeline_info = PipelineInfo(version) demo_base = init_pipeline(Txt2ImgXLPipeline, base_pipeline_info) - refiner_pipeline_info = PipelineInfo(version, is_sd_xl_refiner=True) + refiner_pipeline_info = PipelineInfo(version, is_refiner=True) demo_refiner = init_pipeline(Img2ImgXLPipeline, refiner_pipeline_info) max_device_memory = max(demo_base.backend.max_device_memory(), demo_refiner.backend.max_device_memory()) @@ -887,7 +887,7 @@ def run_sd_xl_inference(prompt, negative_prompt, seed=None, warmup=False): guidance=5.0, warmup=warmup, seed=seed, - return_type="latents", + return_type="latent", ) images, time_refiner = demo_refiner.run( @@ -1037,7 +1037,7 @@ def init_pipeline(pipeline_class, pipeline_info): base_pipeline_info = PipelineInfo(version) demo_base = init_pipeline(Txt2ImgXLPipeline, base_pipeline_info) - refiner_pipeline_info = PipelineInfo(version, is_sd_xl_refiner=True) + refiner_pipeline_info = PipelineInfo(version, is_refiner=True) demo_refiner = init_pipeline(Img2ImgXLPipeline, refiner_pipeline_info) demo_base.load_resources(image_height, image_width, batch_size) @@ -1053,7 +1053,7 @@ def run_sd_xl_inference(prompt, negative_prompt, seed=None, warmup=False): guidance=5.0, warmup=warmup, seed=seed, - return_type="latents", + return_type="latent", ) images, time_refiner = demo_refiner.run( prompt, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py index f6e00063a6391..d6de5c45a5210 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -73,7 +73,7 @@ def run_inference(warmup=False): denoising_steps=args.denoising_steps, guidance=args.guidance, seed=args.seed, - return_type="images", + return_type="image", ) if not args.disable_cuda_graph: diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index c3a2e4e293cc8..efc87a207d130 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -28,11 +28,15 @@ from pipeline_img2img_xl import Img2ImgXLPipeline from pipeline_txt2img_xl import Txt2ImgXLPipeline -if __name__ == "__main__": - coloredlogs.install(fmt="%(funcName)20s: %(message)s") + +def run_demo(): + """Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image.""" + args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo") + prompt, negative_prompt = repeat_prompt(args) + # Recommend image size as one of those used in training (see Appendix I in https://arxiv.org/pdf/2307.01952.pdf). image_height = args.height image_width = args.width @@ -44,39 +48,32 @@ init_trt_plugins() max_batch_size = 16 - if args.build_dynamic_shape or image_height > 512 or image_width > 512: + if (engine_type in [EngineType.ORT_TRT, EngineType.TRT]) and ( + args.build_dynamic_shape or image_height > 512 or image_width > 512 + ): max_batch_size = 4 batch_size = len(prompt) if batch_size > max_batch_size: - raise ValueError( - f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4" - ) + raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.") - base_info = PipelineInfo(args.version, use_vae_in_xl_base=not args.enable_refiner) + # No VAE decoder in base when it outputs latent instead of image. + base_info = PipelineInfo(args.version, use_vae=False) base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size) - if args.enable_refiner: - refiner_info = PipelineInfo(args.version, is_sd_xl_refiner=True) - refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size) + refiner_info = PipelineInfo(args.version, is_refiner=True) + refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size) - if engine_type == EngineType.TRT: - max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory()) - _, shared_device_memory = cudart.cudaMalloc(max_device_memory) - base.backend.activate_engines(shared_device_memory) - refiner.backend.activate_engines(shared_device_memory) - - base.load_resources(image_height, image_width, batch_size) - refiner.load_resources(image_height, image_width, batch_size) - else: - if engine_type == EngineType.TRT: - max_device_memory = max(base.backend.max_device_memory(), base.backend.max_device_memory()) - _, shared_device_memory = cudart.cudaMalloc(max_device_memory) - base.backend.activate_engines(shared_device_memory) + if engine_type == EngineType.TRT: + max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + base.backend.activate_engines(shared_device_memory) + refiner.backend.activate_engines(shared_device_memory) - base.load_resources(image_height, image_width, batch_size) + base.load_resources(image_height, image_width, batch_size) + refiner.load_resources(image_height, image_width, batch_size) - def run_sd_xl_inference(enable_refiner: bool, warmup=False): + def run_base_and_refiner(warmup=False): images, time_base = base.run( prompt, negative_prompt, @@ -86,44 +83,46 @@ def run_sd_xl_inference(enable_refiner: bool, warmup=False): denoising_steps=args.denoising_steps, guidance=args.guidance, seed=args.seed, - return_type="latents" if enable_refiner else "images", + return_type="latent", + ) + + images, time_refiner = refiner.run( + prompt, + negative_prompt, + images, + image_height, + image_width, + warmup=warmup, + denoising_steps=args.denoising_steps, + guidance=args.guidance, + seed=args.seed, ) - if enable_refiner: - images, time_refiner = refiner.run( - prompt, - negative_prompt, - images, - image_height, - image_width, - warmup=warmup, - denoising_steps=args.denoising_steps, - guidance=args.guidance, - seed=args.seed, - ) - return images, time_base + time_refiner - else: - return images, time_base + return images, time_base + time_refiner if not args.disable_cuda_graph: # inference once to get cuda graph - images, _ = run_sd_xl_inference(args.enable_refiner, warmup=True) + _, _ = run_base_and_refiner(warmup=True) print("[I] Warming up ..") for _ in range(args.num_warmup_runs): - images, _ = run_sd_xl_inference(args.enable_refiner, warmup=True) + _, _ = run_base_and_refiner(warmup=True) print("[I] Running StableDiffusion XL pipeline") if args.nvtx_profile: cudart.cudaProfilerStart() - images, pipeline_time = run_sd_xl_inference(args.enable_refiner, warmup=False) + _, latency = run_base_and_refiner(warmup=False) if args.nvtx_profile: cudart.cudaProfilerStop() base.teardown() - if args.enable_refiner: - print("|------------|--------------|") - print("| {:^10} | {:>9.2f} ms |".format("e2e", pipeline_time)) - print("|------------|--------------|") - refiner.teardown() + print("|------------|--------------|") + print("| {:^10} | {:>9.2f} ms |".format("e2e", latency)) + print("|------------|--------------|") + refiner.teardown() + + +if __name__ == "__main__": + coloredlogs.install(fmt="%(funcName)20s: %(message)s") + run_demo() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index 796e83f70d6e4..3996c8c325be3 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -34,7 +34,7 @@ class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatte def parse_arguments(is_xl: bool, description: str): parser = argparse.ArgumentParser(description=description, formatter_class=RawTextArgumentDefaultsHelpFormatter) - engines = ["ORT_TRT", "TRT"] if is_xl else ["ORT_CUDA", "ORT_TRT", "TRT"] + engines = ["ORT_CUDA", "ORT_TRT", "TRT"] parser.add_argument( "--engine", @@ -95,7 +95,7 @@ def parse_arguments(is_xl: bool, description: str): "--denoising-steps", type=int, default=30 if is_xl else 50, - help="Number of denoising steps" + (" in each of base and refiner." if is_xl else "."), + help="Number of denoising steps" + (" in base." if is_xl else "."), ) parser.add_argument( @@ -158,12 +158,6 @@ def parse_arguments(is_xl: bool, description: str): "--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources." ) - # Pipeline options - if is_xl: - parser.add_argument( - "--enable-refiner", action="store_true", help="Enable refiner and run both base and refiner pipelines." - ) - args = parser.parse_args() if ( @@ -203,6 +197,7 @@ def repeat_prompt(args): raise ValueError( f"`--negative-prompt` must be of type `str` or `str` list, but is {type(args.negative_prompt)}" ) + if len(args.negative_prompt) == 1: negative_prompt = args.negative_prompt * len(prompt) else: @@ -236,16 +231,11 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si engine_dir=engine_dir, framework_model_dir=framework_model_dir, onnx_dir=onnx_dir, - onnx_opset=args.onnx_opset, opt_image_height=args.height, opt_image_width=args.height, opt_batch_size=batch_size, force_engine_rebuild=args.force_engine_build, device_id=torch.cuda.current_device(), - disable_cuda_graph_models=[ - "clip2", # TODO: Add ArgMax cuda kernel to enable cuda graph for clip2. - "unetxl", - ], ) elif engine_type == EngineType.ORT_TRT: # Build TensorRT EP engines and load pytorch modules diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py index 7726abb9f9e4d..dc777e26938e4 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -24,7 +24,7 @@ import logging import os import tempfile -from typing import List, Optional +from typing import Dict, List, Optional import onnx import onnx_graphsurgeon as gs @@ -82,43 +82,41 @@ def infer_shapes(self): class PipelineInfo: - def __init__( - self, version: str, is_inpaint: bool = False, is_sd_xl_refiner: bool = False, use_vae_in_xl_base=False - ): + def __init__(self, version: str, is_inpaint: bool = False, is_refiner: bool = False, use_vae=False): self.version = version self._is_inpaint = is_inpaint - self._is_sd_xl_refiner = is_sd_xl_refiner - self._use_vae_in_xl_base = use_vae_in_xl_base + self._is_refiner = is_refiner + self._use_vae = use_vae - if is_sd_xl_refiner: - assert self.is_sd_xl() + if is_refiner: + assert self.is_xl() def is_inpaint(self) -> bool: return self._is_inpaint - def is_sd_xl(self) -> bool: + def is_xl(self) -> bool: return "xl" in self.version - def is_sd_xl_base(self) -> bool: - return self.is_sd_xl() and not self._is_sd_xl_refiner + def is_xl_base(self) -> bool: + return self.is_xl() and not self._is_refiner - def is_sd_xl_refiner(self) -> bool: - return self.is_sd_xl() and self._is_sd_xl_refiner + def is_xl_refiner(self) -> bool: + return self.is_xl() and self._is_refiner def use_safetensors(self) -> bool: - return self.is_sd_xl() + return self.is_xl() def stages(self) -> List[str]: - if self.is_sd_xl_base(): - return ["clip", "clip2", "unetxl"] + (["vae"] if self._use_vae_in_xl_base else []) + if self.is_xl_base(): + return ["clip", "clip2", "unetxl"] + (["vae"] if self._use_vae else []) - if self.is_sd_xl_refiner(): + if self.is_xl_refiner(): return ["clip2", "unetxl", "vae"] return ["clip", "unet", "vae"] def vae_scaling_factor(self) -> float: - return 0.13025 if self.is_sd_xl() else 0.18215 + return 0.13025 if self.is_xl() else 0.18215 @staticmethod def supported_versions(is_xl: bool): @@ -150,7 +148,7 @@ def name(self) -> str: elif self.version == "2.1-base": return "stabilityai/stable-diffusion-2-1-base" elif self.version == "xl-1.0": - if self.is_sd_xl_refiner(): + if self.is_xl_refiner(): return "stabilityai/stable-diffusion-xl-refiner-1.0" else: return "stabilityai/stable-diffusion-xl-base-1.0" @@ -166,7 +164,7 @@ def clip_embedding_dim(self): return 768 elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base"): return 1024 - elif self.version in ("xl-1.0") and self.is_sd_xl_base(): + elif self.version in ("xl-1.0") and self.is_xl_base(): return 768 else: raise ValueError(f"Invalid version {self.version}") @@ -182,9 +180,9 @@ def unet_embedding_dim(self): return 768 elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base"): return 1024 - elif self.version in ("xl-1.0") and self.is_sd_xl_base(): + elif self.version in ("xl-1.0") and self.is_xl_base(): return 2048 - elif self.version in ("xl-1.0") and self.is_sd_xl_refiner(): + elif self.version in ("xl-1.0") and self.is_xl_refiner(): return 1280 else: raise ValueError(f"Invalid version {self.version}") @@ -254,16 +252,16 @@ def from_pretrained(self, model_class, framework_model_dir, hf_token, subfolder, def load_model(self, framework_model_dir: str, hf_token: str, subfolder: str): pass - def get_input_names(self): + def get_input_names(self) -> List[str]: pass - def get_output_names(self): + def get_output_names(self) -> List[str]: pass - def get_dynamic_axes(self): - return None + def get_dynamic_axes(self) -> Dict[str, Dict[int, str]]: + pass - def get_sample_input(self, batch_size, image_height, image_width): + def get_sample_input(self, batch_size, image_height, image_width) -> tuple: pass def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape): @@ -293,10 +291,10 @@ def get_profile_id(self, batch_size, image_height, image_width, static_batch, st def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): """For TensorRT""" - return None + pass def get_shape_dict(self, batch_size, image_height, image_width): - return None + pass def fp32_input_output_names(self) -> List[str]: """For CUDA EP, we export ONNX model with FP32 first, then convert it to mixed precision model. @@ -305,9 +303,16 @@ def fp32_input_output_names(self) -> List[str]: """ return [] - def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): + def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True, fp32_op_list=None, optimize_by_ort=True): optimizer = self.get_ort_optimizer() - optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16, keep_io_types=self.fp32_input_output_names()) + optimizer.optimize( + input_onnx_path, + optimized_onnx_path, + float16=to_fp16, + keep_io_types=self.fp32_input_output_names(), + fp32_op_list=fp32_op_list, + optimize_by_ort=optimize_by_ort, + ) def optimize_trt(self, input_onnx_path, optimized_onnx_path): onnx_graph = onnx.load(input_onnx_path) @@ -382,7 +387,7 @@ def __init__( max_batch_size=max_batch_size, embedding_dim=embedding_dim if embedding_dim > 0 else pipeline_info.clip_embedding_dim(), ) - self.output_hidden_state = pipeline_info.is_sd_xl() + self.output_hidden_state = pipeline_info.is_xl() # see https://github.com/huggingface/diffusers/pull/5057 for more information of clip_skip. # Clip_skip=1 means that the output of the pre-final layer will be used for computing the prompt embeddings. @@ -466,11 +471,18 @@ def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path, onnx_model.add_node(cast_node) onnx_model.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format) - def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): + def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True, fp32_op_list=None, optimize_by_ort=True): optimizer = self.get_ort_optimizer() + if not self.output_hidden_state: optimizer.optimize( - input_onnx_path, optimized_onnx_path, to_fp16, keep_io_types=[], keep_outputs=["text_embeddings"] + input_onnx_path, + optimized_onnx_path, + float16=to_fp16, + keep_io_types=[], + fp32_op_list=fp32_op_list, + keep_outputs=["text_embeddings"], + optimize_by_ort=optimize_by_ort, ) else: with tempfile.TemporaryDirectory() as tmp_dir: @@ -483,9 +495,11 @@ def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): optimizer.optimize( tmp_model_path, optimized_onnx_path, - to_fp16, + float16=to_fp16, keep_io_types=[], + fp32_op_list=fp32_op_list, keep_outputs=["text_embeddings", "hidden_states"], + optimize_by_ort=optimize_by_ort, ) def optimize_trt(self, input_onnx_path, optimized_onnx_path): @@ -741,27 +755,47 @@ def fp32_input_output_names(self) -> List[str]: # VAE Decoder class VAE(BaseModel): - def __init__(self, pipeline_info: PipelineInfo, model, device, max_batch_size): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + max_batch_size, + fp16: bool = False, + custom_fp16_vae: Optional[str] = None, + ): super().__init__( pipeline_info, model=model, device=device, + fp16=fp16, max_batch_size=max_batch_size, ) + # For SD XL, need custom trained fp16 model to speed up, and avoid overflow at the same time. + self.custom_fp16_vae = custom_fp16_vae + def load_model(self, framework_model_dir, hf_token: Optional[str] = None, subfolder: str = "vae_decoder"): - model_dir = os.path.join(framework_model_dir, self.pipeline_info.name(), subfolder) + model_name = self.custom_fp16_vae or self.pipeline_info.name() + + model_dir = os.path.join(framework_model_dir, model_name, subfolder) if not os.path.exists(model_dir): - vae = AutoencoderKL.from_pretrained( - self.pipeline_info.name(), - subfolder="vae", - use_safetensors=self.pipeline_info.use_safetensors(), - use_auth_token=hf_token, - ).to(self.device) + if self.custom_fp16_vae: + vae = AutoencoderKL.from_pretrained(self.custom_fp16_vae, torch_dtype=torch.float16).to(self.device) + else: + vae = AutoencoderKL.from_pretrained( + self.pipeline_info.name(), + subfolder="vae", + use_safetensors=self.pipeline_info.use_safetensors(), + use_auth_token=hf_token, + ).to(self.device) vae.save_pretrained(model_dir) else: print(f"Load {self.name} pytorch model from: {model_dir}") - vae = AutoencoderKL.from_pretrained(model_dir).to(self.device) + if self.custom_fp16_vae: + vae = AutoencoderKL.from_pretrained(model_dir, torch_dtype=torch.float16).to(self.device) + else: + vae = AutoencoderKL.from_pretrained(model_dir).to(self.device) vae.forward = vae.decode return vae @@ -809,7 +843,7 @@ def get_sample_input(self, batch_size, image_height, image_width): return (torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device),) def fp32_input_output_names(self) -> List[str]: - return ["latent", "images"] + return [] if self.fp16 else ["latent", "images"] def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, hf_token, subfolder="tokenizer"): @@ -819,7 +853,7 @@ def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, hf_token, su model = CLIPTokenizer.from_pretrained( pipeline_info.name(), subfolder=subfolder, - use_safetensors=pipeline_info.is_sd_xl(), + use_safetensors=pipeline_info.is_xl(), use_auth_token=hf_token, ) model.save_pretrained(tokenizer_dir) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index fdf05ffc799d9..029125c639c09 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -59,7 +59,16 @@ def __init__( self.device = torch.device(device) self.torch_device = torch.device(device, torch.cuda.current_device()) self.stages = pipeline_info.stages() - self.vae_torch_fallback = self.pipeline_info.is_sd_xl() + + # TODO: use custom fp16 for ORT_TRT, and no need to fallback to torch. + self.vae_torch_fallback = self.pipeline_info.is_xl() and engine_type != EngineType.ORT_CUDA + + # For SD XL, use an VAE that modified to run in fp16 precision without generating NaNs. + self.custom_fp16_vae = ( + "madebyollin/sdxl-vae-fp16-fix" + if self.pipeline_info.is_xl() and self.engine_type == EngineType.ORT_CUDA + else None + ) self.models = {} self.engines = {} @@ -130,7 +139,7 @@ def load_models(self, framework_model_dir: str): fp16=export_fp16_unet, max_batch_size=self.max_batch_size, unet_dim=4, - time_dim=(5 if self.pipeline_info.is_sd_xl_refiner() else 6), + time_dim=(5 if self.pipeline_info.is_xl_refiner() else 6), ) # VAE Decoder @@ -140,6 +149,7 @@ def load_models(self, framework_model_dir: str): None, # not loaded yet device=self.torch_device, max_batch_size=self.max_batch_size, + custom_fp16_vae=self.custom_fp16_vae, ) if self.vae_torch_fallback: @@ -156,8 +166,9 @@ def load_resources(self, image_height, image_width, batch_size): def vae_decode(self, latents): if self.vae_torch_fallback: - latents = latents.to(dtype=torch.float32) - self.torch_models["vae"] = self.torch_models["vae"].to(dtype=torch.float32) + if not self.custom_fp16_vae: + latents = latents.to(dtype=torch.float32) + self.torch_models["vae"] = self.torch_models["vae"].to(dtype=torch.float32) images = self.torch_models["vae"](latents)["sample"] else: images = self.run_engine("vae", {"latent": latents})["images"] diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py index 936d04e8a1c43..11a39b0decad6 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py @@ -7,24 +7,33 @@ import logging import os import shutil +from typing import List, Optional import torch from diffusion_models import PipelineInfo from engine_builder import EngineBuilder, EngineType +from ort_utils import CudaSession import onnxruntime as ort -from onnxruntime.transformers.io_binding_helper import CudaSession logger = logging.getLogger(__name__) class OrtCudaEngine(CudaSession): - def __init__(self, onnx_path, device_id: int = 0, enable_cuda_graph=False, disable_optimization=False): + def __init__( + self, + onnx_path, + device_id: int = 0, + enable_cuda_graph: bool = False, + disable_optimization: bool = False, + ): self.onnx_path = onnx_path self.provider = "CUDAExecutionProvider" self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph) + # self.provider_options["enable_skip_layer_norm_strict_mode"] = True session_options = ort.SessionOptions() + # When the model has been optimized by onnxruntime, we can disable optimization to save session creation time. if disable_optimization: session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL @@ -47,6 +56,28 @@ def allocate_buffers(self, shape_dict, device): super().allocate_buffers(shape_dict) +class _ModelConfig: + """ + Configuration of one model (like Clip, UNet etc) on ONNX export and optimization for CUDA provider. + For example, if you want to use fp32 in layer normalization, set the following: + force_fp32_ops=["SkipLayerNormalization", "LayerNormalization"] + """ + + def __init__( + self, + onnx_opset_version: int, + use_cuda_graph: bool, + fp16: bool = True, + force_fp32_ops: Optional[List[str]] = None, + optimize_by_ort: bool = True, + ): + self.onnx_opset_version = onnx_opset_version + self.use_cuda_graph = use_cuda_graph + self.fp16 = fp16 + self.force_fp32_ops = force_fp32_ops + self.optimize_by_ort = optimize_by_ort + + class OrtCudaEngineBuilder(EngineBuilder): def __init__( self, @@ -80,18 +111,59 @@ def __init__( use_cuda_graph=use_cuda_graph, ) + self.model_config = {} + + def _configure( + self, + model_name: str, + onnx_opset_version: int, + use_cuda_graph: bool, + fp16: bool = True, + force_fp32_ops: Optional[List[str]] = None, + optimize_by_ort: bool = True, + ): + self.model_config[model_name] = _ModelConfig( + onnx_opset_version, + use_cuda_graph, + fp16=fp16, + force_fp32_ops=force_fp32_ops, + optimize_by_ort=optimize_by_ort, + ) + + def configure_xl(self, onnx_opset_version: int): + self._configure( + "clip", + onnx_opset_version=onnx_opset_version, + use_cuda_graph=self.use_cuda_graph, + ) + self._configure( + "clip2", + onnx_opset_version=onnx_opset_version, # TODO: ArgMax-12 is not implemented in CUDA + use_cuda_graph=False, # TODO: fix Runtime Error with cuda graph + ) + self._configure( + "unetxl", + onnx_opset_version=onnx_opset_version, + use_cuda_graph=False, # TODO: fix Runtime Error with cuda graph + ) + + self._configure( + "vae", + onnx_opset_version=onnx_opset_version, + use_cuda_graph=self.use_cuda_graph, + ) + def build_engines( self, - engine_dir, - framework_model_dir, - onnx_dir, - onnx_opset, - opt_image_height=512, - opt_image_width=512, - opt_batch_size=1, - force_engine_rebuild=False, - device_id=0, - disable_cuda_graph_models=None, + engine_dir: str, + framework_model_dir: str, + onnx_dir: str, + onnx_opset_version: int = 17, + opt_image_height: int = 512, + opt_image_width: int = 512, + opt_batch_size: int = 1, + force_engine_rebuild: bool = False, + device_id: int = 0, ): self.torch_device = torch.device("cuda", device_id) self.load_models(framework_model_dir) @@ -110,6 +182,13 @@ def build_engines( if not os.path.isdir(onnx_dir): os.makedirs(onnx_dir) + # Add default configuration if missing + if self.pipeline_info.is_xl(): + self.configure_xl(onnx_opset_version) + for model_name in self.models: + if model_name not in self.model_config: + self.model_config[model_name] = _ModelConfig(onnx_opset_version, self.use_cuda_graph) + # Export models to ONNX for model_name, model_obj in self.models.items(): if model_name == "vae" and self.vae_torch_fallback: @@ -119,8 +198,12 @@ def build_engines( onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True) if not os.path.exists(onnx_opt_path): if not os.path.exists(onnx_path): + print("----") logger.info("Exporting model: %s", onnx_path) model = model_obj.load_model(framework_model_dir, self.hf_token) + if model_name == "vae": + model.to(torch.float32) + with torch.inference_mode(): # For CUDA EP, export FP32 onnx since some graph fusion only supports fp32 graph pattern. inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) @@ -130,7 +213,7 @@ def build_engines( inputs, onnx_path, export_params=True, - opset_version=onnx_opset, + opset_version=self.model_config[model_name].onnx_opset_version, do_constant_folding=True, input_names=model_obj.get_input_names(), output_names=model_obj.get_output_names(), @@ -144,8 +227,16 @@ def build_engines( # Run graph optimization and convert to mixed precision (computation in FP16) if not os.path.exists(onnx_opt_path): + print("------") logger.info("Generating optimized model: %s", onnx_opt_path) - model_obj.optimize_ort(onnx_path, onnx_opt_path, to_fp16=True) + + model_obj.optimize_ort( + onnx_path, + onnx_opt_path, + to_fp16=self.model_config[model_name].fp16, + fp32_op_list=self.model_config[model_name].force_fp32_ops, + optimize_by_ort=self.model_config[model_name].optimize_by_ort, + ) else: logger.info("Found cached optimized model: %s", onnx_opt_path) @@ -156,11 +247,15 @@ def build_engines( onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True) - use_cuda_graph = self.use_cuda_graph - if self.use_cuda_graph and disable_cuda_graph_models and model_name in disable_cuda_graph_models: - use_cuda_graph = False + use_cuda_graph = self.model_config[model_name].use_cuda_graph + + engine = OrtCudaEngine( + onnx_opt_path, + device_id=device_id, + enable_cuda_graph=use_cuda_graph, + disable_optimization=False, + ) - engine = OrtCudaEngine(onnx_opt_path, device_id=device_id, enable_cuda_graph=use_cuda_graph) logger.info("%s options for %s: %s", engine.provider, model_name, engine.provider_options) built_engines[model_name] = engine diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py index a6bbd4ee7eeb7..8a39dc2ed63fc 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py @@ -12,9 +12,9 @@ from cuda import cudart from diffusion_models import PipelineInfo from engine_builder import EngineBuilder, EngineType +from ort_utils import CudaSession import onnxruntime as ort -from onnxruntime.transformers.io_binding_helper import CudaSession logger = logging.getLogger(__name__) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py index 24570a6ef62da..2078f8d1a497c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py @@ -51,7 +51,16 @@ def optimize_by_ort(self, onnx_model, use_external_data_format=False): model = onnx.load(str(ort_optimized_model_path), load_external_data=True) return self.model_type_class_mapping[self.model_type](model) - def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True, keep_io_types=False, keep_outputs=None): + def optimize( + self, + input_fp32_onnx_path, + optimized_onnx_path, + float16=True, + keep_io_types=False, + fp32_op_list=None, + keep_outputs=None, + optimize_by_ort=True, + ): """Optimize onnx model using ONNX Runtime transformers optimizer""" logger.info(f"Optimize {input_fp32_onnx_path}...") fusion_options = FusionOptions(self.model_type) @@ -76,6 +85,7 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True, keep logger.info("Convert to float16 ...") m.convert_float_to_float16( keep_io_types=keep_io_types, + op_block_list=fp32_op_list, ) use_external_data_format = m.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF @@ -87,7 +97,7 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True, keep # to save session creation time. Another benefit is to inspect the final graph for developing purpose. from onnxruntime import __version__ as ort_version - if version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format: + if optimize_by_ort and (version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format): m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format) m.get_operator_statistics() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py index 5c2145845e757..0afa13a0f4dca 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py @@ -7,16 +7,28 @@ import logging import os import shutil +import sys from typing import Union import torch import onnxruntime as ort -from onnxruntime.transformers.io_binding_helper import CudaSession logger = logging.getLogger(__name__) +def add_transformers_dir_to_path(): + sys.path.append(os.path.dirname(__file__)) + + transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")) + if transformers_dir not in sys.path: + sys.path.append(transformers_dir) + + +add_transformers_dir_to_path() +from io_binding_helper import CudaSession # noqa: E402. Walk-around to test locally + + # ----------------------------------------------------------------------------------------------------- # Utilities for CUDA EP # ----------------------------------------------------------------------------------------------------- diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py index 0e2aeb6174666..faa3f8bfaabf1 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py @@ -29,7 +29,7 @@ class Img2ImgXLPipeline(StableDiffusionPipeline): """ - Stable Diffusion Img2Img XL pipeline using NVidia TensorRT. + Stable Diffusion Img2Img XL pipeline. """ def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs): @@ -40,7 +40,7 @@ def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs): pipeline_info (PipelineInfo): Version and Type of stable diffusion pipeline. """ - assert pipeline_info.is_sd_xl_refiner() + assert pipeline_info.is_xl_refiner() super().__init__(pipeline_info, *args, **kwargs) @@ -73,12 +73,12 @@ def _infer( warmup=False, return_type="image", ): - assert len(prompt) == len(negative_prompt) + assert negative_prompt is None or len(prompt) == len(negative_prompt) - # TODO(tianleiwu): Need we use image_height and image_width for the target size here? - original_size = (1024, 1024) + original_size = (image_height, image_width) crops_coords_top_left = (0, 0) - target_size = (1024, 1024) + target_size = (image_height, image_width) + strength = 0.3 aesthetic_score = 6.0 negative_aesthetic_score = 2.5 @@ -94,6 +94,7 @@ def _infer( # Initialize timesteps timesteps, t_start = self.initialize_timesteps(self.denoising_steps, strength) + latent_timestep = timesteps[:1].repeat(batch_size) # CLIP text encoder 2 @@ -146,10 +147,10 @@ def _infer( with torch.inference_mode(): # VAE decode latent - if return_type == "latents": - images = latents * self.vae_scaling_factor + if return_type == "latent": + images = latents else: - images = self.decode_latent(latents) + images = self.decode_latent(latents / self.vae_scaling_factor) torch.cuda.synchronize() e2e_toc = time.perf_counter() @@ -172,7 +173,7 @@ def run( guidance=5.0, seed=None, warmup=False, - return_type="images", + return_type="image", ): """ Run the diffusion pipeline. @@ -197,7 +198,7 @@ def run( warmup (bool): Indicate if this is a warmup run. return_type (str): - It can be "latents" or "images". + It can be "latent" or "image". """ if self.is_backend_tensorrt(): diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py index 87443c990450b..e28db2b77105a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py @@ -120,7 +120,7 @@ def __init__( self.stages = pipeline_info.stages() - self.vae_torch_fallback = self.pipeline_info.is_sd_xl() + self.vae_torch_fallback = self.pipeline_info.is_xl() self.use_cuda_graph = use_cuda_graph @@ -129,6 +129,7 @@ def __init__( self.generator = None self.denoising_steps = None + self.actual_steps = None # backend engine self.engine_type = engine_type @@ -142,12 +143,12 @@ def __init__( raise RuntimeError(f"Backend engine type {engine_type.name} is not supported") # Load text tokenizer - if not self.pipeline_info.is_sd_xl_refiner(): + if not self.pipeline_info.is_xl_refiner(): self.tokenizer = get_tokenizer( self.pipeline_info, self.framework_model_dir, self.hf_token, subfolder="tokenizer" ) - if self.pipeline_info.is_sd_xl(): + if self.pipeline_info.is_xl(): self.tokenizer2 = get_tokenizer( self.pipeline_info, self.framework_model_dir, self.hf_token, subfolder="tokenizer_2" ) @@ -219,7 +220,14 @@ def preprocess_images(self, batch_size, images=()): return tuple(init_images) def encode_prompt( - self, prompt, negative_prompt, encoder="clip", tokenizer=None, pooled_outputs=False, output_hidden_states=False + self, + prompt, + negative_prompt, + encoder="clip", + tokenizer=None, + pooled_outputs=False, + output_hidden_states=False, + force_zeros_for_empty_prompt=False, ): if tokenizer is None: tokenizer = self.tokenizer @@ -247,23 +255,32 @@ def encode_prompt( if output_hidden_states: hidden_states = outputs["hidden_states"].clone() - # Tokenize negative prompt - uncond_input_ids = ( - tokenizer( - negative_prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", + # Note: negative prompt embedding is not needed for SD XL when guidance < 1 + + # For SD XL base, handle force_zeros_for_empty_prompt + is_empty_negative_prompt = all([not i for i in negative_prompt]) + if force_zeros_for_empty_prompt and is_empty_negative_prompt: + uncond_embeddings = torch.zeros_like(text_embeddings) + if output_hidden_states: + uncond_hidden_states = torch.zeros_like(hidden_states) + else: + # Tokenize negative prompt + uncond_input_ids = ( + tokenizer( + negative_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) ) - .input_ids.type(torch.int32) - .to(self.device) - ) - outputs = self.run_engine(encoder, {"input_ids": uncond_input_ids}) - uncond_embeddings = outputs["text_embeddings"] - if output_hidden_states: - uncond_hidden_states = outputs["hidden_states"] + outputs = self.run_engine(encoder, {"input_ids": uncond_input_ids}) + uncond_embeddings = outputs["text_embeddings"] + if output_hidden_states: + uncond_hidden_states = outputs["hidden_states"] # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) @@ -292,21 +309,21 @@ def denoise_latent( mask=None, masked_image_latents=None, guidance=7.5, - image_guidance=1.5, add_kwargs=None, ): - assert guidance > 1.0, "Guidance has to be > 1.0" - assert image_guidance > 1.0, "Image guidance has to be > 1.0" + assert guidance > 1.0, "Guidance has to be > 1.0" # TODO: remove this constraint cudart.cudaEventRecord(self.events["denoise-start"], 0) if not isinstance(timesteps, torch.Tensor): timesteps = self.scheduler.timesteps + for step_index, timestep in enumerate(timesteps): if self.nvtx_profile: nvtx_latent_scale = nvtx.start_range(message="latent_scale", color="pink") # Expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input( latent_model_input, step_offset + step_index, timestep ) @@ -322,11 +339,11 @@ def denoise_latent( timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep - sample_inp = latent_model_input - timestep_inp = timestep_float - embeddings_inp = text_embeddings - - params = {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp} + params = { + "sample": latent_model_input, + "timestep": timestep_float, + "encoder_hidden_states": text_embeddings, + } if add_kwargs: params.update(add_kwargs) @@ -338,7 +355,7 @@ def denoise_latent( if self.nvtx_profile: nvtx_latent_step = nvtx.start_range(message="latent_step", color="pink") - # Perform guidance + # perform guidance noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) @@ -350,8 +367,11 @@ def denoise_latent( if self.nvtx_profile: nvtx.end_range(nvtx_latent_step) - latents = 1.0 / self.vae_scaling_factor * latents cudart.cudaEventRecord(self.events["denoise-stop"], 0) + + # The actual number of steps. It might be different from denoising_steps. + self.actual_steps = len(timesteps) + return latents def encode_image(self, init_image): @@ -394,7 +414,7 @@ def print_summary(self, tic, toc, batch_size, vae_enc=False): ) print( "| {:^10} | {:>9.2f} ms |".format( - "UNet x " + str(self.denoising_steps), + "UNet x " + str(self.actual_steps), cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1], ) ) @@ -403,6 +423,7 @@ def print_summary(self, tic, toc, batch_size, vae_enc=False): "VAE-Dec", cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1] ) ) + print("|------------|--------------|") print("| {:^10} | {:>9.2f} ms |".format("Pipeline", (toc - tic) * 1000.0)) print("|------------|--------------|") @@ -413,6 +434,7 @@ def to_pil_image(images): images = ( ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy() ) + from PIL import Image return [Image.fromarray(images[i]) for i in range(images.shape[0])] diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py index 82f73e8b3cc61..444b6d9a8ca14 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py @@ -52,7 +52,7 @@ def _infer( guidance=7.5, seed=None, warmup=False, - return_type="latents", + return_type="latent", ): assert len(prompt) == len(negative_prompt) batch_size = len(prompt) @@ -100,7 +100,7 @@ def run( guidance=7.5, seed=None, warmup=False, - return_type="images", + return_type="image", ): """ Run the diffusion pipeline. @@ -123,7 +123,7 @@ def run( warmup (bool): Indicate if this is a warmup run. return_type (str): - type of return. The value can be "latents" or "images". + type of return. The value can be "latent" or "image". """ if self.is_backend_tensorrt(): import tensorrt as trt diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py index d8f00ed619354..1b3be143e6ce7 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py @@ -40,7 +40,7 @@ def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs): pipeline_info (PipelineInfo): Version and Type of stable diffusion pipeline. """ - assert pipeline_info.is_sd_xl_base() + assert pipeline_info.is_xl_base() super().__init__(pipeline_info, *args, **kwargs) @@ -59,14 +59,13 @@ def _infer( guidance=5.0, seed=None, warmup=False, - return_type="images", + return_type="image", ): assert len(prompt) == len(negative_prompt) - # TODO(tianleiwu): Need we use image_height and image_width for the target size here? - original_size = (1024, 1024) + original_size = (image_height, image_width) crops_coords_top_left = (0, 0) - target_size = (1024, 1024) + target_size = (image_height, image_width) batch_size = len(prompt) self.set_denoising_steps(denoising_steps) @@ -86,7 +85,12 @@ def _infer( # CLIP text encoder text_embeddings = self.encode_prompt( - prompt, negative_prompt, encoder="clip", tokenizer=self.tokenizer, output_hidden_states=True + prompt, + negative_prompt, + encoder="clip", + tokenizer=self.tokenizer, + output_hidden_states=True, + force_zeros_for_empty_prompt=True, ) # CLIP text encoder 2 text_embeddings2, pooled_embeddings2 = self.encode_prompt( @@ -96,6 +100,7 @@ def _infer( tokenizer=self.tokenizer2, pooled_outputs=True, output_hidden_states=True, + force_zeros_for_empty_prompt=True, ) # Merged text embeddings @@ -112,14 +117,18 @@ def _infer( # UNet denoiser latents = self.denoise_latent( - latents, text_embeddings, denoiser="unetxl", guidance=guidance, add_kwargs=add_kwargs + latents, + text_embeddings, + denoiser="unetxl", + guidance=guidance, + add_kwargs=add_kwargs, ) # VAE decode latent - if return_type == "latents": - images = latents * self.vae_scaling_factor + if return_type == "latent": + images = latents else: - images = self.decode_latent(latents) + images = self.decode_latent(latents / self.vae_scaling_factor) torch.cuda.synchronize() e2e_toc = time.perf_counter() @@ -127,7 +136,7 @@ def _infer( if not warmup: print("SD-XL Base Pipeline") self.print_summary(e2e_tic, e2e_toc, batch_size) - if return_type == "images": + if return_type != "latent": self.save_images(images, "txt2img-xl", prompt) return images, (e2e_toc - e2e_tic) * 1000.0 @@ -142,7 +151,7 @@ def run( guidance=5.0, seed=None, warmup=False, - return_type="images", + return_type="image", ): """ Run the diffusion pipeline. @@ -165,7 +174,7 @@ def run( warmup (bool): Indicate if this is a warmup run. return_type (str): - It can be "latents" or "images". + It can be "latent" or "image". """ if self.is_backend_tensorrt():