From 6fb7369663509587129054f539ebdc12272a1db9 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 22 Nov 2024 13:59:23 -0800 Subject: [PATCH 01/26] initial --- .../models/stable_diffusion/README.md | 12 ++++++ .../models/stable_diffusion/benchmark.py | 21 ++++------ .../stable_diffusion/optimize_pipeline.py | 42 ++++++++++++++----- 3 files changed, 50 insertions(+), 25 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index edef0d3ee5453..8587a92c53025 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -212,6 +212,14 @@ pip install optimum diffusers onnx onnxruntime-gpu optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl ./sd_xl_base_onnx ``` +SD3 and Flux requires transformers >= 4.45, and optimum > 1.23.3: +``` +git clone https://github.com/huggingface/optimum +pip install -e . +optimum-cli export onnx --model stabilityai/stable-diffusion-3-medium-diffusers sd3_onnx_fp32 +optimum-cli export onnx --model stabilityai/stable-diffusion-3.5-medium-diffusers sd3.5_onnx_fp32 +``` + ### Optimize ONNX Pipeline Example to optimize the exported float32 ONNX models, and save to float16 models: @@ -230,6 +238,10 @@ For SDXL model, it is recommended to use a machine with 48 GB or more memory to python optimize_pipeline.py -i ./sd_xl_base_onnx -o ./sd_xl_base_fp16 --float16 ``` +For SD3 model: +``` +python optimize_pipeline.py -i sd3_onnx_fp32 -o sd3_onnx_fp16 --float16 +``` ### Run Benchmark The benchmark.py script will run a warm-up prompt twice, and measure the peak GPU memory usage in these two runs, then record them as first_run_memory_MB and second_run_memory_MB. Then it will run 5 runs to get average latency (in seconds), and output the results to benchmark_result.csv. diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 0708d57f040f8..af24e5c817062 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -22,6 +22,11 @@ "2.0": "stabilityai/stable-diffusion-2", "2.1": "stabilityai/stable-diffusion-2-1", "xl-1.0": "stabilityai/stable-diffusion-xl-refiner-1.0", + "3.0": "stabilityai/stable-diffusion-3-medium-diffusers", + # "3.5": "stabilityai/stable-diffusion-3.5-medium", + # "3.5-large": "stabilityai/stable-diffusion-3.5-large", + # "flux.1-schnell": "black-forest-labs/FLUX.1-schnell", + # "flux.1-dev": "black-forest-labs/FLUX.1-dev", } PROVIDERS = { @@ -322,22 +327,10 @@ def get_optimum_ort_pipeline( disable_safety_checker: bool = True, use_io_binding: bool = False, ): - from optimum.onnxruntime import ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline + from optimum.onnxruntime import ORTPipelineForText2Image, ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline if directory is not None and os.path.exists(directory): - if "xl" in model_name: - pipeline = ORTStableDiffusionXLPipeline.from_pretrained( - directory, - provider=provider, - session_options=None, - use_io_binding=False, # Not supported by Optimum version 1.17.1 at the time of verification. 
- ) - else: - pipeline = ORTStableDiffusionPipeline.from_pretrained( - directory, - provider=provider, - use_io_binding=use_io_binding, - ) + pipeline = ORTPipelineForText2Image.from_pretrained(directory, provider=provider, use_io_binding=use_io_binding) elif "xl" in model_name: pipeline = ORTStableDiffusionXLPipeline.from_pretrained( model_name, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index ffcfd6d9fd7e0..dabce6bab48b5 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -27,6 +27,7 @@ import coloredlogs import onnx from fusion_options import FusionOptions +from onnx_model_bert import BertOnnxModel from onnx_model_clip import ClipOnnxModel from onnx_model_unet import UnetOnnxModel from onnx_model_vae import VaeOnnxModel @@ -46,9 +47,20 @@ def has_external_data(onnx_model_path): return False +def _get_model_list(source_dir: Path): + is_xl = (source_dir / "text_encoder_2").exists() + is_sd3 = (source_dir / "text_encoder_3").exists() + model_list_sd3 = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae_encoder", "vae_decoder"] + model_list_sdxl = ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder"] + model_list_sd = ["text_encoder", "unet", "vae_encoder", "vae_decoder"] + model_list = model_list_sd3 if is_sd3 else (model_list_sdxl if is_xl else model_list_sd) + return model_list + + def _optimize_sd_pipeline( source_dir: Path, target_dir: Path, + model_list: List[str], use_external_data_format: Optional[bool], float16: bool, force_fp32_ops: List[str], @@ -60,6 +72,7 @@ def _optimize_sd_pipeline( Args: source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models. target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models. + model_list (List[str]): list of directory names with onnx model. use_external_data_format (Optional[bool]): use external data format. float16 (bool): use half precision force_fp32_ops(List[str]): operators that are forced to run in float32. @@ -70,18 +83,21 @@ def _optimize_sd_pipeline( RuntimeError: output onnx model path existed """ model_type_mapping = { + "transformer": "mmdit", "unet": "unet", "vae_encoder": "vae", "vae_decoder": "vae", "text_encoder": "clip", "text_encoder_2": "clip", "safety_checker": "unet", + "text_encoder_3": "clip", } model_type_class_mapping = { "unet": UnetOnnxModel, "vae": VaeOnnxModel, "clip": ClipOnnxModel, + "mmdit": BertOnnxModel, # TODO: have a new class for DiT } force_fp32_operators = { @@ -91,10 +107,10 @@ def _optimize_sd_pipeline( "text_encoder": [], "text_encoder_2": [], "safety_checker": [], + "text_encoder_3": [], + "transformer": [], } - is_xl = (source_dir / "text_encoder_2").exists() - if force_fp32_ops: for fp32_operator in force_fp32_ops: parts = fp32_operator.split(":") @@ -108,8 +124,8 @@ def _optimize_sd_pipeline( for name, model_type in model_type_mapping.items(): onnx_model_path = source_dir / name / "model.onnx" if not os.path.exists(onnx_model_path): - if name != "safety_checker": - logger.info("input onnx model does not exist: %s", onnx_model_path) + if name != "safety_checker" and name in model_list: + logger.warning("input onnx model does not exist: %s", onnx_model_path) # some model are optional so we do not raise error here. 
continue @@ -122,7 +138,7 @@ def _optimize_sd_pipeline( use_external_data_format = has_external_data(onnx_model_path) # Graph fusion before fp16 conversion, otherwise they cannot be fused later. - logger.info(f"Optimize {onnx_model_path}...") + logger.info("Optimize %s ...", onnx_model_path) args.model_type = model_type fusion_options = FusionOptions.parse(args) @@ -147,6 +163,7 @@ def _optimize_sd_pipeline( if float16: # For SD-XL, use FP16 in VAE decoder will cause NaN and black image so we keep it in FP32. + is_xl = (source_dir / "text_encoder_2").exists() if is_xl and name == "vae_decoder": logger.info("Skip converting %s to float16 to avoid NaN", name) else: @@ -181,17 +198,18 @@ def _optimize_sd_pipeline( logger.info("*" * 20) -def _copy_extra_directory(source_dir: Path, target_dir: Path): +def _copy_extra_directory(source_dir: Path, target_dir: Path, model_list: List[str]): """Copy extra directory that does not have onnx model Args: source_dir (Path): source directory target_dir (Path): target directory + model_list (List[str]): list of directory names with onnx model. Raises: RuntimeError: source path does not exist """ - extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "feature_extractor"] + extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "tokenizer_3", "feature_extractor"] for name in extra_dirs: source_path = source_dir / name @@ -213,8 +231,7 @@ def _copy_extra_directory(source_dir: Path, target_dir: Path): logger.info("%s => %s", source_path, target_path) # Some directory are optional - onnx_model_dirs = ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder", "safety_checker"] - for onnx_model_dir in onnx_model_dirs: + for onnx_model_dir in model_list: source_path = source_dir / onnx_model_dir / "config.json" target_path = target_dir / onnx_model_dir / "config.json" if source_path.exists(): @@ -236,17 +253,20 @@ def optimize_stable_diffusion_pipeline( if overwrite: shutil.rmtree(output_dir, ignore_errors=True) else: - raise RuntimeError("output directory existed:{output_dir}. Add --overwrite to empty the directory.") + raise RuntimeError(f"output directory existed:{output_dir}. 
Add --overwrite to empty the directory.") source_dir = Path(input_dir) target_dir = Path(output_dir) target_dir.mkdir(parents=True, exist_ok=True) - _copy_extra_directory(source_dir, target_dir) + model_list = _get_model_list(source_dir) + + _copy_extra_directory(source_dir, target_dir, model_list) _optimize_sd_pipeline( source_dir, target_dir, + model_list, use_external_data_format, float16, args.force_fp32_ops, From 9b2dcc0dadcd94344ee1743d1b092c9381e2ef8d Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 2 Dec 2024 23:29:59 +0000 Subject: [PATCH 02/26] sd3.x and flux --- .../models/stable_diffusion/README.md | 51 ++++-- .../models/stable_diffusion/benchmark.py | 28 +-- .../stable_diffusion/optimize_pipeline.py | 4 +- .../tools/transformers/onnx_model_mmdit.py | 169 ++++++++++++++++++ .../python/tools/transformers/optimizer.py | 3 + 5 files changed, 215 insertions(+), 40 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/onnx_model_mmdit.py diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 8587a92c53025..1230cbffa7815 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -203,47 +203,55 @@ This step will export stable diffusion 1.5 to ONNX model in float32 using script ``` curl https://raw.githubusercontent.com/huggingface/diffusers/v0.15.1/scripts/convert_stable_diffusion_checkpoint_to_onnx.py > convert_sd_onnx.py -python convert_sd_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ./sd_v1_5/fp32 +python convert_sd_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ./sd1.5_onnx/fp32 ``` For SDXL, use optimum to export the model: ``` pip install optimum diffusers onnx onnxruntime-gpu -optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl ./sd_xl_base_onnx +optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl ./sdxl_onnx/fp32 ``` -SD3 and Flux requires transformers >= 4.45, and optimum > 1.23.3: +#### Stable Diffusion 3.x and Flux 1.0 + +Stable Diffusion 3.x and Flux 1.0 requires transformers >= 4.45, and optimum > 1.23.3: ``` git clone https://github.com/huggingface/optimum +cd optimum pip install -e . 
-optimum-cli export onnx --model stabilityai/stable-diffusion-3-medium-diffusers sd3_onnx_fp32 -optimum-cli export onnx --model stabilityai/stable-diffusion-3.5-medium-diffusers sd3.5_onnx_fp32 + +optimum-cli export onnx --model stabilityai/stable-diffusion-3-medium-diffusers ./sd3_onnx/fp32 +optimum-cli export onnx --model stabilityai/stable-diffusion-3.5-medium ./sd3.5_medium_onnx/fp32 +optimum-cli export onnx --model stabilityai/stable-diffusion-3.5-large ./sd3.5_large_onnx/fp32 +optimum-cli export onnx --model black-forest-labs/FLUX.1-schnell ./flux1_schnell_onnx/fp32 +optimum-cli export onnx --model black-forest-labs/FLUX.1-dev ./flux1_dev_onnx/fp32 ``` ### Optimize ONNX Pipeline -Example to optimize the exported float32 ONNX models, and save to float16 models: +Example to optimize the exported float32 ONNX models, then save to float16 models: ``` -python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i ./sd_v1_5/fp32 -o ./sd_v1_5/fp16 --float16 +python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i ./sd1.5_onnx/fp32 -o ./sd1.5_onnx/fp16 --float16 ``` -In all examples below, we run the scripts in source code directory. You can get source code like the following: +You can also run the script in source code directory like the following: ``` git clone https://github.com/microsoft/onnxruntime cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion + +python optimize_pipeline.py -i ./sdxl_onnx/fp32 -o ./sdxl_onnx/fp16 --float16 +python optimize_pipeline.py -i ./sd3_onnx/fp32 -o ./sd3_onnx/fp16 --float16 +python optimize_pipeline.py -i ./sd3.5_medium_onnx/fp32 -o ./sd3.5_medium_onnx/fp16 --float16 +python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp16 --float16 +python optimize_pipeline.py -i ./flux1_dev_onnx/fp32 -o ./flux1_dev_onnx/fp16 --float16 ``` For SDXL model, it is recommended to use a machine with 48 GB or more memory to optimize. -``` -python optimize_pipeline.py -i ./sd_xl_base_onnx -o ./sd_xl_base_fp16 --float16 -``` -For SD3 model: -``` -python optimize_pipeline.py -i sd3_onnx_fp32 -o sd3_onnx_fp16 --float16 -``` ### Run Benchmark +#### Run Benchmark with Optimum + The benchmark.py script will run a warm-up prompt twice, and measure the peak GPU memory usage in these two runs, then record them as first_run_memory_MB and second_run_memory_MB. Then it will run 5 runs to get average latency (in seconds), and output the results to benchmark_result.csv. Note that the first run might need more time and memory: For example, cuDNN convolution algorithm search or model compile happens in the first run. @@ -257,15 +265,15 @@ Before running benchmark on PyTorch, you need to be logged in via `huggingface-c Example to benchmark the optimized pipeline of stable diffusion 1.5 with batch size 1 on CUDA EP: ``` -python benchmark.py -p ./sd_v1_5/fp16 -b 1 -v 1.5 +python benchmark.py -p ./sd1.5_onnx/fp16 -b 1 -v 1.5 python benchmark.py -b 1 -v 1.5 ``` For the first command, '-p' specifies a directory of optimized ONNX pipeline as generated by optimize_pipeline.py. -For the second command without '-p', we will use OnnxruntimeCudaStableDiffusionPipeline to export and optimize ONNX models for clip, unet and vae decoder. +For the second command without '-p', we will use ORTPipelineForText2Image to export and optimize ONNX models for clip, unet and vae decoder. 
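For reference, running benchmark.py without '-p' is roughly equivalent to the following optimum API usage (a sketch only; the model id and provider below are examples, and benchmark.py derives the actual values from the `-v` and `--provider` arguments):
```python
from optimum.onnxruntime import ORTPipelineForText2Image

# Export the PyTorch pipeline to ONNX on the fly and run it with the CUDA execution provider.
pipeline = ORTPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    export=True,
    provider="CUDAExecutionProvider",
)
image = pipeline("a photo of an astronaut riding a horse").images[0]
image.save("example.png")
```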
On ROCm EP, use the following command instead: ``` -python benchmark.py -p ./sd_v1_5/fp16 -b 1 --tuning --provider rocm -v 1.5 +python benchmark.py -p ./sd1.5_onnx/fp16 -b 1 --tuning --provider rocm -v 1.5 ``` For ROCm EP, you can substitute `python benchmark.py` with `python -m onnxruntime.transformers.models.stable_diffusion.benchmark` since @@ -275,6 +283,13 @@ For ROCm EP, the `--tuning` is mandatory because we heavily rely on tuning to fi The default parameters are stable diffusion version=1.5, height=512, width=512, steps=50, batch_count=5. Run `python benchmark.py --help` for more information. +#### Stable Diffusion 3.x and Flux 1.0 +Example of benchmark with optimum using CUDA provider on stable diffusion 3.5: +``` +python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v 3.5M -p sd3.5_medium_onnx/fp32 +python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v 3.5M -p sd3.5_medium_onnx/fp16 +``` + ### Run Benchmark with xFormers Run PyTorch 1.13.1+cu117 with xFormers like the following diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index af24e5c817062..617d1ee461851 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -22,11 +22,11 @@ "2.0": "stabilityai/stable-diffusion-2", "2.1": "stabilityai/stable-diffusion-2-1", "xl-1.0": "stabilityai/stable-diffusion-xl-refiner-1.0", - "3.0": "stabilityai/stable-diffusion-3-medium-diffusers", - # "3.5": "stabilityai/stable-diffusion-3.5-medium", - # "3.5-large": "stabilityai/stable-diffusion-3.5-large", - # "flux.1-schnell": "black-forest-labs/FLUX.1-schnell", - # "flux.1-dev": "black-forest-labs/FLUX.1-dev", + "3.0M": "stabilityai/stable-diffusion-3-medium-diffusers", + "3.5M": "stabilityai/stable-diffusion-3.5-medium", + "3.5L": "stabilityai/stable-diffusion-3.5-large", + "Flux.1S": "black-forest-labs/FLUX.1-schnell", + "Flux.1D": "black-forest-labs/FLUX.1-dev", } PROVIDERS = { @@ -327,21 +327,12 @@ def get_optimum_ort_pipeline( disable_safety_checker: bool = True, use_io_binding: bool = False, ): - from optimum.onnxruntime import ORTPipelineForText2Image, ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline + from optimum.onnxruntime import ORTPipelineForText2Image if directory is not None and os.path.exists(directory): pipeline = ORTPipelineForText2Image.from_pretrained(directory, provider=provider, use_io_binding=use_io_binding) - elif "xl" in model_name: - pipeline = ORTStableDiffusionXLPipeline.from_pretrained( - model_name, - export=True, - provider=provider, - session_options=None, - use_io_binding=False, # Not supported by Optimum version 1.17.1 at the time of verification. 
- ) - pipeline.save_pretrained(directory) else: - pipeline = ORTStableDiffusionPipeline.from_pretrained( + pipeline = ORTPipelineForText2Image.from_pretrained( model_name, export=True, provider=provider, @@ -369,10 +360,7 @@ def run_optimum_ort_pipeline( memory_monitor_type, use_num_images_per_prompt=False, ): - from optimum.onnxruntime import ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline - - assert isinstance(pipe, (ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline)) - + print("Pipeline type", type(pipe)) prompts, negative_prompt = example_prompts() def warmup(): diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index dabce6bab48b5..b31f1bd2c98b7 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -27,7 +27,7 @@ import coloredlogs import onnx from fusion_options import FusionOptions -from onnx_model_bert import BertOnnxModel +from onnx_model_mmdit import MmditOnnxModel from onnx_model_clip import ClipOnnxModel from onnx_model_unet import UnetOnnxModel from onnx_model_vae import VaeOnnxModel @@ -97,7 +97,7 @@ def _optimize_sd_pipeline( "unet": UnetOnnxModel, "vae": VaeOnnxModel, "clip": ClipOnnxModel, - "mmdit": BertOnnxModel, # TODO: have a new class for DiT + "mmdit": MmditOnnxModel, } force_fp32_operators = { diff --git a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py new file mode 100644 index 0000000000000..2f162ea9e7868 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py @@ -0,0 +1,169 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import logging +from typing import Optional + +from fusion_attention_unet import FusionAttentionUnet +from fusion_bias_add import FusionBiasAdd +from fusion_options import FusionOptions +from import_utils import is_installed +from onnx import ModelProto +from onnx_model_bert import BertOnnxModel + +logger = logging.getLogger(__name__) + + +class MmditOnnxModel(BertOnnxModel): + def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0): + """Initialize Multimodal Diffusion Transformer (MMDiT) ONNX Model. + + Args: + model (ModelProto): the ONNX model + num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically). + hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically). 
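
        Note:
            MMDiT is the diffusion transformer backbone used by Stable Diffusion 3.x and Flux pipelines.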
+ """ + assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0) + + super().__init__(model, num_heads=num_heads, hidden_size=hidden_size) + + def preprocess(self): + self.remove_useless_div() + + def postprocess(self): + self.prune_graph() + self.remove_unused_constant() + + def remove_useless_div(self): + """Remove Div by 1""" + div_nodes = [node for node in self.nodes() if node.op_type == "Div"] + + nodes_to_remove = [] + for div in div_nodes: + if self.find_constant_input(div, 1.0) == 1: + nodes_to_remove.append(div) + + for node in nodes_to_remove: + self.replace_input_of_all_nodes(node.output[0], node.input[0]) + + if nodes_to_remove: + self.remove_nodes(nodes_to_remove) + logger.info("Removed %d Div nodes", len(nodes_to_remove)) + + def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None): + # Self Attention + self_attention_fusion = FusionAttentionUnet( + self, + self.hidden_size, + self.num_heads, + is_cross_attention=False, + enable_packed_qkv=False, + enable_packed_kv=False, + ) + self_attention_fusion.apply() + + # Cross Attention + cross_attention_fusion = FusionAttentionUnet( + self, + self.hidden_size, + self.num_heads, + is_cross_attention=True, + enable_packed_qkv=False, + enable_packed_kv=False, + ) + cross_attention_fusion.apply() + + def fuse_bias_add(self): + fusion = FusionBiasAdd(self) + fusion.apply() + + def optimize(self, options: Optional[FusionOptions] = None): + if is_installed("tqdm"): + import tqdm + from tqdm.contrib.logging import logging_redirect_tqdm + + with logging_redirect_tqdm(): + steps = 18 + progress_bar = tqdm.tqdm(range(steps), initial=0, desc="fusion") + self._optimize(options, progress_bar) + else: + logger.info("tqdm is not installed. Run optimization without progress bar") + self._optimize(options, None) + + def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): + if (options is not None) and not options.enable_shape_inference: + self.disable_shape_inference() + + self.utils.remove_identity_nodes() + if progress_bar: + progress_bar.update(1) + + # Remove cast nodes that having same data type of input and output based on symbolic shape inference. + self.utils.remove_useless_cast_nodes() + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_layer_norm: + self.fuse_layer_norm() + if progress_bar: + progress_bar.update(1) + + self.preprocess() + if progress_bar: + progress_bar.update(1) + + self.fuse_reshape() + if progress_bar: + progress_bar.update(1) + + + if (options is None) or options.enable_attention: + self.fuse_multi_head_attention(options) + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_skip_layer_norm: + self.fuse_skip_layer_norm() + if progress_bar: + progress_bar.update(1) + + self.fuse_shape() + if progress_bar: + progress_bar.update(1) + + # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. + self.utils.remove_useless_reshape_nodes() + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_bias_skip_layer_norm: + # Fuse SkipLayerNormalization and Add Bias before it. + self.fuse_add_bias_skip_layer_norm() + if progress_bar: + progress_bar.update(1) + + self.postprocess() + if progress_bar: + progress_bar.update(1) + + logger.info(f"opset version: {self.get_opset_version()}") + + def get_fused_operator_statistics(self): + """ + Returns node count of fused operators. 
+ """ + op_count = {} + ops = [ + "MultiHeadAttention", + "LayerNormalization", + "SkipLayerNormalization", + ] + + for op in ops: + nodes = self.get_nodes_by_op_type(op) + op_count[op] = len(nodes) + + logger.info(f"Optimized operators:{op_count}") + return op_count diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 933bd785dc00d..4564015b9c665 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -41,6 +41,8 @@ from onnx_model_tnlr import TnlrOnnxModel from onnx_model_unet import UnetOnnxModel from onnx_model_vae import VaeOnnxModel +from onnx_model_mmdit import MmditOnnxModel + from onnx_utils import extract_raw_data_from_model, has_external_data import onnxruntime @@ -66,6 +68,7 @@ "unet": (UnetOnnxModel, "pytorch", 1), # UNet in Stable Diffusion "vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion "vit": (BertOnnxModel, "pytorch", 1), + "mmdit": (MmditOnnxModel, "pytorch", 1), } From 7f925cef2f2ba7a8a4e88509705bdccde95952b4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 5 Dec 2024 23:52:40 +0000 Subject: [PATCH 03/26] update FastGelu and RMSNorm fusions --- .../tools/transformers/fusion_fastgelu.py | 122 ++++++++++++++ .../tools/transformers/fusion_layernorm.py | 24 ++- .../fusion_simplified_layernorm.py | 152 +++++++----------- .../stable_diffusion/optimize_pipeline.py | 2 +- .../tools/transformers/onnx_model_mmdit.py | 9 +- .../python/tools/transformers/optimizer.py | 3 +- 6 files changed, 201 insertions(+), 111 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_fastgelu.py b/onnxruntime/python/tools/transformers/fusion_fastgelu.py index a9f46585faad7..e2bb8027c8608 100644 --- a/onnxruntime/python/tools/transformers/fusion_fastgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_fastgelu.py @@ -26,6 +26,9 @@ def fuse(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict): if self.fuse_3(tanh_node, input_name_to_nodes, output_name_to_node): return + if self.fuse_4(tanh_node, input_name_to_nodes, output_name_to_node): + return + def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optional[bool]: """ Fuse Gelu with tanh into one node: @@ -358,3 +361,122 @@ def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name return True + + def fuse_4(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]: + """ + This pattern is from stable diffusion 3.5 model. + Fuse Gelu with tanh into one node: + +-----------------+------------------+ + | | | + | v v + [root] ==> Mul --> Mul --> Mul -----> Add --> Mul --> Tanh --> Add -----> Mul --> Mul --> + | (A=0.0447) (A=0.7978) (A=1) ^ (A=0.5) + | | + +-------------------------------------------------------------------------+ + Note that constant input for Add and Mul could be first or second input. 
+ """ + if tanh_node.output[0] not in input_name_to_nodes: + return + + children = input_name_to_nodes[tanh_node.output[0]] + if len(children) != 1 or children[0].op_type != "Add": + return + add_after_tanh = children[0] + + if not self.model.has_constant_input(add_after_tanh, 1.0): + return + + if add_after_tanh.output[0] not in input_name_to_nodes: + return + children = input_name_to_nodes[add_after_tanh.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return + mul_after_tanh = children[0] + + if mul_after_tanh.output[0] not in input_name_to_nodes: + return + children = input_name_to_nodes[mul_after_tanh.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return + mul_half = children[0] + if not self.model.has_constant_input(mul_half, 0.5): + return + + root_input = mul_after_tanh.input[0 if mul_after_tanh.input[1] == add_after_tanh.output[0] else 1] + + mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node) + if mul_before_tanh is None: + return + + i = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001) + if i < 0: + return + + add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node) + if add_before_tanh is None: + return + + if add_before_tanh.input[0] == root_input: + another = 1 + elif add_before_tanh.input[1] == root_input: + another = 0 + else: + return + + mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", another, output_name_to_node) + if mul_after_pow is None: + return + + i = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001) + if i < 0: + return + + mul = self.model.match_parent(mul_after_pow, "Mul", 0 if i == 1 else 1, output_name_to_node) + if mul is None: + return + + if mul.input[0] == root_input: + another = 1 + elif mul.input[1] == root_input: + another = 0 + else: + return + + mul2 = self.model.match_parent(mul, "Mul", another, output_name_to_node) + if mul2 is None: + return + + if mul2.input[0] != root_input or mul2.input[1] != root_input: + return + + subgraph_nodes = [ + mul2, + mul, + mul_after_pow, + add_before_tanh, + mul_before_tanh, + tanh_node, + add_after_tanh, + mul_after_tanh, + mul_half, + ] + + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + [mul_half.output[0]], + input_name_to_nodes, + output_name_to_node, + ): + return + + self.nodes_to_remove.extend(subgraph_nodes) + fused_node = helper.make_node( + "FastGelu", + inputs=[root_input], + outputs=mul_half.output, + name=self.model.create_node_name("FastGelu"), + ) + fused_node.domain = "com.microsoft" + self.nodes_to_add.append(fused_node) + self.node_name_to_graph_name[fused_node.name] = self.this_graph_name + return True diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py index aac05a7f01325..fc49d6ab98752 100644 --- a/onnxruntime/python/tools/transformers/fusion_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py @@ -56,18 +56,20 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): for child in children: # Check if Sub --> Div exists div_node_1 = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False) - - # Check if Sub --> Cast --> Div - div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[]) - if div_node_1 is not None: div_node = div_node_1 - elif div_node_2 is not None: - div_node = div_node_2[-1] + break + else: + # Check if Sub --> Cast 
--> Div + div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[]) + if div_node_2 is not None: + div_node = div_node_2[-1] + break + if div_node is None: return - path_id, parent_nodes, _ = self.model.match_parent_paths( + _path_id, parent_nodes, _ = self.model.match_parent_paths( div_node, [ (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]), @@ -75,7 +77,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): ], output_name_to_node, ) - if path_id < 0: + if parent_nodes is None: return sub_node = parent_nodes[-1] @@ -92,10 +94,14 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): if self.model.find_constant_input(pow_node, 2.0) != 1: return + if div_node.output[0] not in input_name_to_nodes: + return temp_node = input_name_to_nodes[div_node.output[0]][0] if temp_node.op_type == "Cast": # Div --> Cast --> Mul subgraph_nodes.append(temp_node) # add Cast node to list of subgraph nodes + if temp_node.output[0] not in input_name_to_nodes: + return mul_node = input_name_to_nodes[temp_node.output[0]][0] else: # Div --> Mul @@ -103,6 +109,8 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): if mul_node.op_type != "Mul": return + if mul_node.output[0] not in input_name_to_nodes: + return last_add_node = input_name_to_nodes[mul_node.output[0]][0] if last_add_node.op_type != "Add": return diff --git a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py index a872b8c2075bc..8cb49734f10e0 100644 --- a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py @@ -18,134 +18,90 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): return sim_ln_nodes = None - # SimplifiedLayerNorm calculation (notation from https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary): - # DD = Pow(D, 2) + # RMSNorm calculation (notation from https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary): + # DD = Pow(D, 2) or DD = Mul(D, D) # Var = ReduceMean(DD) # VarEps = Add(Var, epsilon) # StdDev = Sqrt(VarEps) # InvStdDev = Div(1, StdDev) # Normalized = Mul(D, InvStdDev) # NormalizedScaled = Mul(Normalized, Scale) - - # SimplifiedLayerNorm - # +-------------------------------------------------------+ - # | | - # Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul - # | - # node - sim_ln_nodes_1 = self.model.match_parent_path( - node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"], - [1, 1, 1, 0, 0, 0, 0], - ) - # SimplifiedLayerNorm - # +-------------------------------------------------------+ - # | | - # Gather --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul - # | - # node - sim_ln_nodes_2 = self.model.match_parent_path( - node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"], - [1, 1, 1, 0, 0, 0, 0], - ) - - # For LLaMA from Microsoft custom export: - # sim_ln_nodes_3 uses a different start parent index than sim_ln_nodes_1 # - # SimplifiedLayerNorm - # +-------------------------------------------------------+ - # | | - # Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul - # | - # node - sim_ln_nodes_3 = self.model.match_parent_path( - node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"], - [0, 1, 1, 0, 0, 0, 0], - ) - - # sim_ln_nodes_4 starts with a graph input instead of an Add node like 
sim_ln_nodes_3 + # (root_input) ---------------------------------------+ + # | | + # v v + # Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul (node) + # (B=2) (A/B=eps) (A=1) (A/B=scale) # - # SimplifiedLayerNorm - # +-----------------------------------------------+ - # | | - # graph_input --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul - # | - # node - sim_ln_nodes_4 = self.model.match_parent_path( + # (root_input) ---------------------------------------+ + # | | | + # v v v + # Mul --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul (node) + # (B=2) (A/B=eps) (A=1) (A/B=scale) + + return_indice = [] + sim_ln_nodes = self.model.match_parent_path( node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow"], - [0, 1, 1, 0, 0, 0], + ["Mul", "Div", "Sqrt", "Add", "ReduceMean"], + [None, 1, 1, 0, None], + output_name_to_node=output_name_to_node, + return_indice=return_indice, ) - # For Gemma from Microsoft custom export, which has a Multiply after the Gather: - # - # SimplifiedLayerNorm - # +-------------------------------------------------------+ - # | | - # Mul --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul - # | - # node - sim_ln_nodes_5 = self.model.match_parent_path( - node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Mul"], - [1, 1, 1, 0, 0, 0, 0], - ) + if sim_ln_nodes is None: + return - add_node, pow_node = None, None - if sim_ln_nodes_1 is not None: - sim_ln_nodes = sim_ln_nodes_1 - add_node = sim_ln_nodes[3] - pow_node = sim_ln_nodes[-2] - elif sim_ln_nodes_2 is not None: - sim_ln_nodes = sim_ln_nodes_2 - add_node = sim_ln_nodes[3] - pow_node = sim_ln_nodes[-2] - elif sim_ln_nodes_3 is not None: - sim_ln_nodes = sim_ln_nodes_3 - add_node = sim_ln_nodes[3] - pow_node = sim_ln_nodes[-2] - elif sim_ln_nodes_4 is not None: - sim_ln_nodes = sim_ln_nodes_4 - add_node = sim_ln_nodes[3] - pow_node = sim_ln_nodes[-1] - # Verify that parent input to Pow node is graph_input - if pow_node.input[0] not in self.model.get_graphs_input_names(): + mul_node, div_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes + + pow_or_mul_node = self.model.get_parent(reduce_mean_node, 0, output_name_to_node) + if pow_or_mul_node is None or pow_or_mul_node.op_type not in ["Pow", "Mul"]: + return + + if pow_or_mul_node.op_type == "Pow": + if self.model.find_constant_input(pow_or_mul_node, 2.0) != 1: return - elif sim_ln_nodes_5 is not None: - sim_ln_nodes = sim_ln_nodes_5 - add_node = sim_ln_nodes[3] - pow_node = sim_ln_nodes[-2] else: + assert pow_or_mul_node.op_type == "Mul" + if pow_or_mul_node[0] != pow_or_mul_node[1]: + return + + root_input = pow_or_mul_node.input[0] + if root_input != mul_node.input[0]: return - layernorm_weight_index = 1 if sim_ln_nodes in (sim_ln_nodes_3, sim_ln_nodes_4) else 0 - starts_with_graph_input = sim_ln_nodes == sim_ln_nodes_4 + if not self.model.has_constant_input(div_node, 1.0): + return - if self.model.find_constant_input(pow_node, 2.0) != 1: + _i, epsilon = self.model.get_constant_input(add_node) + if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4: + logger.warning(f"epsilon value is not expected: {epsilon}") return - root_input = pow_node.input[0] - if root_input != sim_ln_nodes[0].input[0]: + # ReduceMean must have keepdims == 1 + keepdims = self.model.get_node_attribute(reduce_mean_node, "keepdims") + if not keepdims: return - i, add_weight = self.model.get_constant_input(add_node) - if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: - logger.warning(f"epsilon value is not expected: {add_weight}") 
+ # ReduceMean axes must refer only to the last dimension. + # Axes became an input in opset 18. Before then, axes was an attribute. + axes = self.model.get_node_attribute(reduce_mean_node, "axes") + if (not axes) and len(reduce_mean_node.input) > 1: + axes = self.model.get_constant_value(reduce_mean_node.input[1]) + # Make sure only one axis as required by SimplifiedLayerNormalization spec. + if not axes or len(axes) != 1: return - self.nodes_to_remove.extend(sim_ln_nodes[:-1] if not starts_with_graph_input else sim_ln_nodes) + self.nodes_to_remove.extend(sim_ln_nodes) self.nodes_to_remove.append(node) normalize_node = helper.make_node( "SimplifiedLayerNormalization", - inputs=[root_input, node.input[layernorm_weight_index]], + inputs=[root_input, node.input[1 - return_indice[0]]], outputs=[node.output[0]], name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"), ) - normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) - normalize_node.attribute.extend([helper.make_attribute("axis", -1)]) + normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))]) + normalize_node.attribute.extend([helper.make_attribute("axis", axes[0])]) normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)]) self.nodes_to_add.append(normalize_node) self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index b31f1bd2c98b7..313c3b304a258 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -27,8 +27,8 @@ import coloredlogs import onnx from fusion_options import FusionOptions -from onnx_model_mmdit import MmditOnnxModel from onnx_model_clip import ClipOnnxModel +from onnx_model_mmdit import MmditOnnxModel from onnx_model_unet import UnetOnnxModel from onnx_model_vae import VaeOnnxModel from optimizer import optimize_by_onnxruntime, optimize_model diff --git a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py index 2f162ea9e7868..dcf796b309822 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py +++ b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py @@ -85,7 +85,7 @@ def optimize(self, options: Optional[FusionOptions] = None): from tqdm.contrib.logging import logging_redirect_tqdm with logging_redirect_tqdm(): - steps = 18 + steps = 12 progress_bar = tqdm.tqdm(range(steps), initial=0, desc="fusion") self._optimize(options, progress_bar) else: @@ -107,6 +107,12 @@ def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): if (options is None) or options.enable_layer_norm: self.fuse_layer_norm() + self.fuse_simplified_layer_norm() + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_gelu: + self.fuse_gelu() if progress_bar: progress_bar.update(1) @@ -118,7 +124,6 @@ def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): if progress_bar: progress_bar.update(1) - if (options is None) or options.enable_attention: self.fuse_multi_head_attention(options) if progress_bar: diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 
4564015b9c665..4c9f84a7d1181 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -35,14 +35,13 @@ from onnx_model_clip import ClipOnnxModel from onnx_model_conformer import ConformerOnnxModel from onnx_model_gpt2 import Gpt2OnnxModel +from onnx_model_mmdit import MmditOnnxModel from onnx_model_phi import PhiOnnxModel from onnx_model_sam2 import Sam2OnnxModel from onnx_model_t5 import T5OnnxModel from onnx_model_tnlr import TnlrOnnxModel from onnx_model_unet import UnetOnnxModel from onnx_model_vae import VaeOnnxModel -from onnx_model_mmdit import MmditOnnxModel - from onnx_utils import extract_raw_data_from_model, has_external_data import onnxruntime From cf259e1b63ede68a11445795f10d621a288b711a Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 6 Dec 2024 00:30:21 +0000 Subject: [PATCH 04/26] support Reciprocal in RMSNorm fusion --- .../fusion_simplified_layernorm.py | 75 ++++++++++++------- 1 file changed, 49 insertions(+), 26 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py index 8cb49734f10e0..daf511981d08a 100644 --- a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py @@ -18,27 +18,27 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): return sim_ln_nodes = None - # RMSNorm calculation (notation from https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary): - # DD = Pow(D, 2) or DD = Mul(D, D) - # Var = ReduceMean(DD) - # VarEps = Add(Var, epsilon) - # StdDev = Sqrt(VarEps) - # InvStdDev = Div(1, StdDev) - # Normalized = Mul(D, InvStdDev) - # NormalizedScaled = Mul(Normalized, Scale) + # RMSNorm formula: + # S = Pow(X, 2) or S = Mul(X, X) + # MS = ReduceMean(S) + # MSEps = Add(MS, epsilon) + # RMS = Sqrt(MSEps) + # InvRMS = Div(1, RMS) or InvRMS = Reciprocal(RMS) + # Normalized = Mul(D, InvRMS) + # Y = Mul(Normalized, Scale) # - # (root_input) ---------------------------------------+ - # | | - # v v - # Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul (node) - # (B=2) (A/B=eps) (A=1) (A/B=scale) + # (root_input) ----------------------------------------+ + # | | + # v v + # Pow --> ReduceMean --> Add ---> Sqrt --> Div --> Mul --> Mul (node) + # (B=2) (A/B=eps) (A=1) (A/B=scale) + # + # (root_input) ----------------------------------------+ + # | | | + # v v v + # Mul --> ReduceMean --> Add ---> Sqrt --> Div --> Mul --> Mul (node) + # (B=2) (A/B=eps) (A=1) (A/B=scale) # - # (root_input) ---------------------------------------+ - # | | | - # v v v - # Mul --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul (node) - # (B=2) (A/B=eps) (A=1) (A/B=scale) - return_indice = [] sim_ln_nodes = self.model.match_parent_path( node, @@ -48,10 +48,35 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): return_indice=return_indice, ) - if sim_ln_nodes is None: - return - - mul_node, div_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes + if sim_ln_nodes: + mul_node, div_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes + if not self.model.has_constant_input(div_node, 1.0): + return + else: + # Div(1, RMS) can also be represented as Reciprocal(RMS) like + # + # (root_input) -----------------------------------------------+ + # | | + # v v + # Pow --> ReduceMean --> Add ---> Sqrt --> Reciprocal --> Mul --> Mul (node) + # 
(B=2) (A/B=eps) (A/B=scale) + # + # (root_input) -----------------------------------------------+ + # | | | + # v v v + # Mul --> ReduceMean --> Add ---> Sqrt --> Reciprocal --> Mul --> Mul (node) + # (B=2) (A/B=eps) (A/B=scale) + # + sim_ln_nodes = self.model.match_parent_path( + node, + ["Mul", "Reciprocal", "Sqrt", "Add", "ReduceMean"], + [None, 1, 0, 0, None], + output_name_to_node=output_name_to_node, + return_indice=return_indice, + ) + if sim_ln_nodes is None: + return + mul_node, _reciprocal_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes pow_or_mul_node = self.model.get_parent(reduce_mean_node, 0, output_name_to_node) if pow_or_mul_node is None or pow_or_mul_node.op_type not in ["Pow", "Mul"]: @@ -69,9 +94,6 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): if root_input != mul_node.input[0]: return - if not self.model.has_constant_input(div_node, 1.0): - return - _i, epsilon = self.model.get_constant_input(add_node) if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4: logger.warning(f"epsilon value is not expected: {epsilon}") @@ -92,6 +114,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): return self.nodes_to_remove.extend(sim_ln_nodes) + self.nodes_to_remove.append(pow_or_mul_node) self.nodes_to_remove.append(node) normalize_node = helper.make_node( From b38f12eb86ba832470bd3e63125ecb660a5027f8 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 13 Dec 2024 22:21:52 +0000 Subject: [PATCH 05/26] match_child_path interface change --- .../tools/transformers/fusion_layernorm.py | 33 +++++----- .../python/tools/transformers/onnx_model.py | 65 ++++++++++++------- 2 files changed, 59 insertions(+), 39 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py index fc49d6ab98752..8c99fe3e1d4c6 100644 --- a/onnxruntime/python/tools/transformers/fusion_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py @@ -13,8 +13,9 @@ class FusionLayerNormalization(Fusion): - def __init__(self, model: OnnxModel): + def __init__(self, model: OnnxModel, check_constant_and_dimension:bool=True): super().__init__(model, "LayerNormalization", "ReduceMean") + self.check_constant_and_dimension = check_constant_and_dimension def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): """ @@ -23,9 +24,9 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): | | | v [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add - (axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^ - | | - +-----------------------------------------------+ + (axis=2 or -1) | (Y=2) (axis=2 or -1) (B=E-6 or E-12) ^ + | | + +-------------------------------------------------+ It also handles cases of duplicated sub nodes exported from older version of PyTorch: +----------------------+ @@ -61,7 +62,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): break else: # Check if Sub --> Cast --> Div - div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[]) + div_node_2 = self.model.match_child_path(child, ["Cast", "Div"]) if div_node_2 is not None: div_node = div_node_2[-1] break @@ -84,10 +85,10 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): if sub_node not in children: return - second_add_node = parent_nodes[1] - i, add_weight = self.model.get_constant_input(second_add_node) - if add_weight is None 
or add_weight <= 0 or add_weight > 1.0e-4: - logger.debug(f"skip SkipLayerNormalization fusion since epsilon value is not expected: {add_weight}") + add_eps_node = parent_nodes[1] + i, epsilon = self.model.get_constant_input(add_eps_node) + if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4: + logger.debug(f"skip SkipLayerNormalization fusion since epsilon value is not expected: {epsilon}") return pow_node = parent_nodes[3] @@ -131,11 +132,11 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)] - if not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"): + if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"): return bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)] - if not self.model.is_constant_with_specified_dimension(bias_input, 1, "layernorm bias"): + if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(bias_input, 1, "layernorm bias"): return self.nodes_to_remove.extend(subgraph_nodes) @@ -146,7 +147,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): outputs=[last_add_node.output[0]], name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"), ) - normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) + normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))]) self.nodes_to_add.append(normalize_node) self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name @@ -226,9 +227,9 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): if sub != sub_node: return - i, add_weight = self.model.get_constant_input(second_add_node) - if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: - logger.debug(f"skip SkipLayerNormalization fusion since epsilon value is not expected: {add_weight}") + i, epsilon = self.model.get_constant_input(second_add_node) + if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4: + logger.debug(f"skip SkipLayerNormalization fusion since epsilon value is not expected: {epsilon}") return axes = OnnxModel.get_node_attribute(reduce_mean_node, "axes") @@ -294,7 +295,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): outputs=[layernorm_node_name + "_out_nhwc"], name=layernorm_node_name, ) - normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) + normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))]) self.nodes_to_add.append(transpose_input) self.nodes_to_add.append(normalize_node) diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index fe80a08829263..fa898a750178f 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -233,15 +233,23 @@ def get_nodes_by_op_type(self, op_type): nodes.append(node) return nodes - def get_children(self, node, input_name_to_nodes=None): + def get_children(self, node, input_name_to_nodes=None, output_index=None): if input_name_to_nodes is None: input_name_to_nodes = self.input_name_to_nodes() children = [] - for output in node.output: - if output in input_name_to_nodes: - 
for node in input_name_to_nodes[output]: - children.append(node) # noqa: PERF402 + if output_index is not None: + if output_index < len(node.output): + output = node.output[output_index] + if output in input_name_to_nodes: + for node in input_name_to_nodes[output]: + children.append(node) + else: + for output in node.output: + if output in input_name_to_nodes: + for node in input_name_to_nodes[output]: + children.append(node) # noqa: PERF402 + return children def get_parents(self, node, output_name_to_node=None): @@ -436,48 +444,59 @@ def match_child_path( self, node, child_op_types, - child_output_index=None, - return_indice=None, + edges:Optional[List[Tuple[int, int]]]=None, + input_name_to_nodes=None, exclude=[], # noqa: B006 ): """ Find a sequence of input edges based on constraints on parent op_type and index. - When input_index is None, we will find the first parent node based on constraints, - and return_indice will be appended the corresponding input index. + Note that we use greedy approach and only consider the first matched child, so it has chance to miss matching. Args: node (str): current node name. child_op_types (str): constraint of child node op_type of each input edge. - child_output_index (list): constraint of input index of each input edge. None means no constraint. - return_indice (list): a list to append the input index - When there is no constraint on input index of an edge. + edges (list): each edge is represented by two integers: output index of parent node, input index of child node. + None means no constraint. + exclude(list): list of nodes that are excluded (not allowed to match as child). Returns: children: a list of matched children node. """ - if child_output_index is not None: - assert len(child_output_index) == len(child_op_types) + if edges is not None: + assert len(edges) == len(child_op_types) + for edge in edges: + assert isinstance(edge, tuple) and len(edge) == 2 and isinstance(edge[0], int) and isinstance(edge[1], int) + + if input_name_to_nodes is None: + input_name_to_nodes = self.input_name_to_nodes() current_node = node matched_children = [] for i, op_type in enumerate(child_op_types): matched_child = None - node_children = self.get_children(current_node) - for child_i, child in enumerate(node_children): + + if edges is None: + children_nodes = self.get_children(current_node, input_name_to_nodes=input_name_to_nodes) + else: + children_nodes = self.get_children(current_node, input_name_to_nodes=input_name_to_nodes, output_index=edges[i][0]) + + for child in children_nodes: if child.op_type == op_type and child not in exclude: - if child_output_index is not None and child_output_index[i] != child_i: - logger.debug( - f"Failed to match index={i} child_output_index={child_output_index[i]} op_type={op_type}", - stack_info=True, - ) - return None + if edges is not None and child.input[edges[i][1]] != current_node.output[edges[i][0]]: + continue + + # Here we use greedy approach and only consider the first matched child. + # TODO: match recursively if we encounter cases that the correct child is not the first matched. 
matched_child = child + break + if matched_child is None: - logger.debug(f"Failed to match child op_type={op_type}", stack_info=True) + logger.debug(f"Failed to match child {i} op_type={op_type}", stack_info=True) return None matched_children.append(matched_child) current_node = matched_child + return matched_children def find_first_parent_by_type(self, node, parent_type, output_name_to_node=None, recursive=True): From a58b68cd015f183704d5aad13ae4e804d1b00e20 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 13 Dec 2024 22:22:08 +0000 Subject: [PATCH 06/26] clean up --- .../tools/transformers/fusion_attention.py | 39 ------------------- .../fusion_simplified_layernorm.py | 2 +- 2 files changed, 1 insertion(+), 40 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index a9ff623fb6967..efc0441b0cf4a 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -399,45 +399,6 @@ def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str): self.node_name_to_graph_name[gather_k_name] = self.this_graph_name self.node_name_to_graph_name[gather_v_name] = self.this_graph_name - def transpose_kv(self, past_k: str, past_v: str): - """Transpose past_k and past_v from (B,N,P,H) to (B,P,N,H) - - Args: - past_k (str): name of past K value of shape (B,N,P,H) - past_v (str): name of past V value of shape (B,N,P,H) - - Returns: - past_k_transpose (str): name of past K value of shape (B,P,N,H) - past_v_transpose (str): name of past V value of shape (B,P,N,H) - """ - past_k_transpose = (past_k + "_transposed").replace(".", "_") - past_v_transpose = (past_v + "_transposed").replace(".", "_") - transpose_k_name = self.model.create_node_name("Transpose") - transpose_v_name = self.model.create_node_name("Transpose") - - transpose_k = helper.make_node( - "Transpose", - inputs=[past_k], - outputs=[past_k_transpose], - name=transpose_k_name, - perm=[0, 2, 1, 3], - ) - transpose_v = helper.make_node( - "Transpose", - inputs=[past_v], - outputs=[past_v_transpose], - name=transpose_v_name, - perm=[0, 2, 1, 3], - ) - - # Add reshape nodes to graph - self.nodes_to_add.append(transpose_k) - self.nodes_to_add.append(transpose_v) - self.node_name_to_graph_name[transpose_k_name] = self.this_graph_name - self.node_name_to_graph_name[transpose_v_name] = self.this_graph_name - - return past_k_transpose, past_v_transpose - def create_combined_qkv_bias( self, q_add: NodeProto, diff --git a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py index daf511981d08a..ca7ff6462b9ff 100644 --- a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py @@ -121,7 +121,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): "SimplifiedLayerNormalization", inputs=[root_input, node.input[1 - return_indice[0]]], outputs=[node.output[0]], - name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"), + name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="RMSNorm"), ) normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))]) normalize_node.attribute.extend([helper.make_attribute("axis", axes[0])]) From c7317cbd797b50172776b7ae4f3c805e21ded00f Mon Sep 17 00:00:00 2001 From: Tianlei Wu 
Date: Fri, 13 Dec 2024 22:22:27 +0000 Subject: [PATCH 07/26] MHA fusion for MMDit --- .../tools/transformers/fusion_mha_mmdit.py | 386 ++++++++++++++++++ .../tools/transformers/onnx_model_mmdit.py | 39 +- .../python/tools/transformers/optimizer.py | 2 +- 3 files changed, 401 insertions(+), 26 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/fusion_mha_mmdit.py diff --git a/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py b/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py new file mode 100644 index 0000000000000..c27c333930c60 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py @@ -0,0 +1,386 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from logging import getLogger +from typing import Optional, Tuple + +import numpy as np +from fusion_base import Fusion +from fusion_utils import FusionUtils +from onnx import NodeProto, helper, numpy_helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + +class FusionMultiHeadAttentionMMDit(Fusion): + """ + Fuse MultiHeadAttention for Multimodal Diffusion Transformer (MMDiT). + """ + + def __init__(self, model: OnnxModel): + super().__init__(model, fused_op_type = "MultiHeadAttention", search_op_types = ["Softmax"]) + + def get_num_heads(self, node: NodeProto, output_name_to_node) -> int: + """ + Detect num_heads and hidden_size from Concat node in the following subgraph: + MatMul .. [-1] [24] .. + | | | / / + Add<1536> Concat + | / + Reshape + | + Transpose(perm=0,2,1,3) + | + SimplifiedLayerNorm -- scale<64> + | + (node) + + The num_heads can be read directly from the third input of Concat node. + + Here we deduce num_heads=hidden_size/head_size from the following two nodes: + The hidden_size can be read from the bias input of Add node + The head_size can be read from the scale input of SimplifiedLayerNormalization node + """ + k_proj_nodes = self.model.match_parent_path( + node, + ["SimplifiedLayerNormalization", "Transpose", "Reshape", "Add"], + [0, 0, 0, 0], + output_name_to_node=output_name_to_node) + + num_heads = 0 + if k_proj_nodes: + simplified_layernorm, _transpose, _reshape, add = k_proj_nodes + _i, bias = self.model.get_constant_input(add) + + hidden_size = 0 + if isinstance(bias, np.ndarray) and len(bias.shape) == 1: + hidden_size = bias.shape[0] + + weight = self.model.get_constant_value(simplified_layernorm.input[1]) + if isinstance(weight, np.ndarray) and len(weight.shape) == 1: + head_size = weight.shape[0] + if (hidden_size % head_size) == 0: + num_heads = hidden_size // head_size + + return num_heads + + def reshape_to_3d(self, input_name: str, output_name: str) -> str: + # Add a shape to convert 4D BxSxNxH to 3D BxSxD, which is required by MHA operator. 
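A small numeric sketch of the deduction described in the docstring above, using the sizes from the diagram (a 1536-element Add bias and a 64-element per-head norm scale); the arrays are stand-ins for the actual initializers.

```
import numpy as np

bias = np.zeros(1536, dtype=np.float32)   # stand-in for the Add bias        -> hidden_size
scale = np.ones(64, dtype=np.float32)     # stand-in for the per-head scale  -> head_size

hidden_size, head_size = bias.shape[0], scale.shape[0]
num_heads = hidden_size // head_size if hidden_size % head_size == 0 else 0
print(num_heads)                          # 24
```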
+ new_dims_name = "bsnh_to_bsd_reshape_dims" + new_dims = self.model.get_initializer(new_dims_name) + if new_dims is None: + new_dims = numpy_helper.from_array(np.array([0, 0, -1], dtype="int64"), name=new_dims_name) + self.model.add_initializer(new_dims, self.this_graph_name) + reshape_q = helper.make_node( + "Reshape", + inputs=[input_name, new_dims_name], + outputs=[output_name], + name=self.model.create_node_name("Reshape"), + ) + self.nodes_to_add.append(reshape_q) + self.node_name_to_graph_name[reshape_q.name] = self.this_graph_name + return reshape_q.output[0] + + def get_num_heads_with_concat(self, transpose_k: NodeProto, output_name_to_node) -> int: + """ + Detect num_heads and hidden_size from Concat node in the following subgraph: + + / | + MatMul MatMul .. [-1] [24] .. + | | | | / / + Add Add<1536> Concat + | | / + Reshape Reshape + | | + Transpose Transpose(perm=0,2,1,3) + | | + SimplifiedLayerNorm SimplifiedLayerNorm -- scale<64> + | / + Concat + | + Transpose(perm=0,1,3,2) + """ + nodes = self.model.match_parent_path( + transpose_k, + ["Concat"], + [0], + output_name_to_node=output_name_to_node) + + return self.get_num_heads(nodes[0], output_name_to_node) if nodes else 0 + + def adjust_query_from_bnsh_to_bsd_no_concat(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: + """ + Before: + MatMul .. [-1] [24] .. + | | | / / + Add Concat + | / + Reshape + | + Transpose(perm=0,2,1,3) + | + SimplifiedLayerNorm + | + Mul + + After: + MatMul .. [-1] [24] .. + | | | / / + Add Concat + | / + Reshape + | + SimplifiedLayerNorm + | + Reshape (shape=[0, 0, -1]) + """ + + path = self.model.match_parent_path( + mul_q, + ["SimplifiedLayerNormalization", "Transpose"], + [0, 0], + ) + if path is None: + return None + sln_a, transpose_a = path + + if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]): + return None + + # Update the graph + sln_a.input[0] = transpose_a.input[0] + sln_output = sln_a.output[0] + sln_a.output[0] = sln_output + "_BSNH" + + return self.reshape_to_3d(sln_a.output[0], sln_output + "_BSD") + + + + def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: + """ + Before: + MatMul MatMul .. [-1] [24] .. + | | | | / / + Add Concat Add Concat + | / | / + Reshape Reshape + | | + Transpose(perm=0,2,1,3) Transpose(perm=0,2,1,3) + | | + SimplifiedLayerNorm SimplifiedLayerNorm + | / + Concat(axis=2) + | + Mul + + After: + MatMul MatMul .. [-1] [24] .. 
+ | | | | / / + Add Concat Add Concat + | / | / + Reshape Reshape + | | + SimplifiedLayerNorm SimplifiedLayerNorm + | / + Concat(axis=1) + | + Reshape (shape=[0, 0, -1]) + """ + + path = self.model.match_parent_path( + mul_q, + ["Concat", "SimplifiedLayerNormalization", "Transpose"], + [0, 0, 0], + ) + if path is None: + return None + concat, sln_a, transpose_a = path + + if len(concat.input) != 2: + return None + + path = self.model.match_parent_path( + concat, + ["SimplifiedLayerNormalization", "Transpose"], + [1, 0], + + ) + if path is None: + return None + sln_b, transpose_b = path + + if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]): + return None + + if not FusionUtils.check_node_attribute(transpose_b, "perm", [0, 2, 1, 3]): + return None + + # Update the graph + sln_a.input[0] = transpose_a.input[0] + sln_b.input[0] = transpose_b.input[0] + + new_concat_node = helper.make_node( + "Concat", + inputs=[sln_a.output[0], sln_b.output[0]], + outputs=[concat.output[0] + "_BSNH"], + name=self.model.create_node_name("Concat"), + axis=1, + ) + self.nodes_to_add.append(new_concat_node) + self.node_name_to_graph_name[new_concat_node.name] = self.this_graph_name + + return self.reshape_to_3d(new_concat_node.output[0], concat.output[0] + "_BSD") + + + def create_multihead_attention_node( + self, + q: str, + k: str, + v: str, + output: str, + num_heads: int, + ) -> NodeProto: + """ + Create a MultiHeadAttention node. + + Args: + q (str): name of q + k (str): name of k + v (str): name of v + output (str): output name of MHA + num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. + + Returns: + NodeProto: the node created. + """ + + assert num_heads > 0 + + # Add inputs for MHA: Query, Key, Value (Proj_Bias, Mask, Attention_Bias, Past_K, Past_V are optional) + mha_inputs = [q, k, v] + + # Add outputs for MHA (Present_K, Present_V are optional) + mha_outputs = [output] + + mha_node = helper.make_node( + "MultiHeadAttention", + inputs=mha_inputs, + outputs=mha_outputs, + name=self.model.create_node_name("MultiHeadAttention"), + ) + + mha_node.domain = "com.microsoft" + mha_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + + # No mask is used in MMDit model, so we need not set the optional mask_filter_value attribute. + return mha_node + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + assert node.op_type == "Softmax" + softmax = node + + # Softmax output shall not be graph output. 
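For reference, the equivalent standalone construction of the fused contrib node with `onnx.helper`; tensor names and the head count below are illustrative, not taken from any specific model.

```
from onnx import helper

# Query is 3D (B, S, D); key/value stay 4D BNSH. Optional inputs/outputs of the contrib op are omitted.
mha = helper.make_node(
    "MultiHeadAttention",
    inputs=["query_BSD", "key_BNSH", "value_BNSH"],
    outputs=["attention_out"],
    name="MultiHeadAttention_0",
    domain="com.microsoft",
    num_heads=24,
)
print(mha)
```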
+ if self.model.find_graph_output(softmax.output[0]): + return + + nodes = self.model.match_child_path(softmax, + ["MatMul", "Transpose", "Reshape"], + [(0, 0), (0, 0), (0, 0)], + input_name_to_nodes) + if nodes is None: + return + + matmul_s_v, transpose_out, reshape_out = nodes + if not FusionUtils.check_node_attribute(transpose_out, "perm", [0, 2, 1, 3]): + return + + q_nodes = self.model.match_parent_path( + softmax, + ["MatMul", "Mul", "Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape"], + [0, 0, 1, 0, 1, 0, 0, 0], + ) + + if q_nodes is None: + return + + matmul_qk, mul_q, sqrt_q_2, div_q, sqrt_q, _, _, shape_q = q_nodes + + if mul_q.input[0] != shape_q.input[0]: + return + + k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose"], [1, 0]) + if k_nodes is None: + return + + mul_k, transpose_k = k_nodes + k = transpose_k.input[0] + if not FusionUtils.check_node_attribute(transpose_k, "perm", [0, 1, 3, 2]): + return + + k_scale_nodes = self.model.match_parent_path(mul_k, ["Sqrt", "Div"], [1, 0]) + if k_scale_nodes is None: + return + if k_scale_nodes[0].input[0] != sqrt_q_2.input[0]: + return + + v = matmul_s_v.input[1] + + # Here we sanity check the v path to make sure it is in the expected BNSH format. + concat = self.model.match_parent(matmul_s_v, "Concat", input_index=1, output_name_to_node=output_name_to_node) + if concat is not None: + # Match v path like: + # -- Transpose (perm=[0,2,1,3]) ----+ + # | + # v + # -- Transpose (perm=[0,2,1,3]) -> Concat -> (v) + transpose_1 = self.model.match_parent(concat, "Transpose", input_index=0, output_name_to_node=output_name_to_node) + if transpose_1 is None: + return + if not FusionUtils.check_node_attribute(transpose_1, "perm", [0, 2, 1, 3]): + return + + transpose_2 = self.model.match_parent(concat, "Transpose", input_index=1, output_name_to_node=output_name_to_node) + if transpose_2 is None: + return + if not FusionUtils.check_node_attribute(transpose_2, "perm", [0, 2, 1, 3]): + return + else: + # Match v path like: + # -- Transpose (perm=[0,2,1,3]) -> (v) + transpose_1 = self.model.match_parent(matmul_s_v, "Transpose", input_index=1, output_name_to_node=output_name_to_node) + if transpose_1 is None: + return + if not FusionUtils.check_node_attribute(transpose_1, "perm", [0, 2, 1, 3]): + return + + if concat is not None: + num_heads = self.get_num_heads_with_concat(transpose_k, output_name_to_node) + else: + num_heads = self.get_num_heads(transpose_k, output_name_to_node) + if num_heads <= 0: + return + + # Q is in BNSH format, we need to adjust it to BSD format due to limitation of MHA op. 
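Numerically, the query adjustment performed by the helpers above is just a transpose plus flatten; a numpy sketch with illustrative sizes:

```
import numpy as np

B, N, S, H = 2, 24, 160, 64                       # illustrative sizes
q_bnsh = np.random.rand(B, N, S, H).astype(np.float32)

# MultiHeadAttention wants query as (B, S, D); the fusion rewires the q subgraph so the
# normalized output is produced in BSNH order and then reshaped with [0, 0, -1].
q_bsd = q_bnsh.transpose(0, 2, 1, 3).reshape(B, S, N * H)
print(q_bsd.shape)                                # (2, 160, 1536)
```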
+ if concat is not None: + query = self.adjust_query_from_bnsh_to_bsd(mul_q, output_name_to_node) + else: + query = self.adjust_query_from_bnsh_to_bsd_no_concat(mul_q, output_name_to_node) + + if query is None: + return + + new_node = self.create_multihead_attention_node( + q=query, + k=k, + v=v, + output=reshape_out.output[0], + num_heads=num_heads, + ) + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([matmul_s_v, transpose_out, reshape_out]) + + # Use prune graph to remove nodes + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py index dcf796b309822..8014011f394fb 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py +++ b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py @@ -6,7 +6,8 @@ import logging from typing import Optional -from fusion_attention_unet import FusionAttentionUnet +from fusion_layernorm import FusionLayerNormalization +from fusion_mha_mmdit import FusionMultiHeadAttentionMMDit from fusion_bias_add import FusionBiasAdd from fusion_options import FusionOptions from import_utils import is_installed @@ -52,34 +53,22 @@ def remove_useless_div(self): self.remove_nodes(nodes_to_remove) logger.info("Removed %d Div nodes", len(nodes_to_remove)) - def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None): - # Self Attention - self_attention_fusion = FusionAttentionUnet( - self, - self.hidden_size, - self.num_heads, - is_cross_attention=False, - enable_packed_qkv=False, - enable_packed_kv=False, - ) - self_attention_fusion.apply() - - # Cross Attention - cross_attention_fusion = FusionAttentionUnet( - self, - self.hidden_size, - self.num_heads, - is_cross_attention=True, - enable_packed_qkv=False, - enable_packed_kv=False, - ) - cross_attention_fusion.apply() + def fuse_layer_norm(self): + # TODO: set check_constant_and_dimension=True when LayerNormalization supports broadcasting scale and bias. + fusion = FusionLayerNormalization(self, check_constant_and_dimension=False) + fusion.apply() + + def fuse_multi_head_attention(self): + fusion = FusionMultiHeadAttentionMMDit(self) + fusion.apply() def fuse_bias_add(self): fusion = FusionBiasAdd(self) fusion.apply() - def optimize(self, options: Optional[FusionOptions] = None): + def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + assert not add_dynamic_axes + if is_installed("tqdm"): import tqdm from tqdm.contrib.logging import logging_redirect_tqdm @@ -125,7 +114,7 @@ def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): progress_bar.update(1) if (options is None) or options.enable_attention: - self.fuse_multi_head_attention(options) + self.fuse_multi_head_attention() if progress_bar: progress_bar.update(1) diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 4c9f84a7d1181..b924a19ebd8a8 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -239,7 +239,7 @@ def optimize_by_fusion( Returns: object of an optimizer class. 
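Once the new model type is registered, the transformer submodel can be optimized through the regular entry point. A hedged sketch follows; the paths are illustrative, and the `mmdit` type exists only with this patch series applied.

```
from onnxruntime.transformers.optimizer import optimize_model

m = optimize_model(
    "sd3_onnx_fp32/transformer/model.onnx",   # illustrative path
    model_type="mmdit",                       # num_heads/hidden_size are detected from the graph
    num_heads=0,
    hidden_size=0,
    opt_level=0,
)
m.save_model_to_file("sd3_onnx_fp16/transformer/model.onnx", use_external_data_format=True)
```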
""" - if model_type not in ["bert", "swin", "unet", "vae", "clip", "sam2"] and (num_heads == 0 or hidden_size == 0): + if model_type not in ["bert", "swin", "unet", "vae", "clip", "sam2", "mmdit"] and (num_heads == 0 or hidden_size == 0): logger.warning(f"Please specify parameters of num_heads and hidden_size for model_type {model_type}") if model_type not in MODEL_TYPES: From 2f5b9b9c357b9233abe7781ab7e0a5c25c9ed491 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sun, 15 Dec 2024 01:15:36 +0000 Subject: [PATCH 08/26] cuda layernorm support broadcast --- .../contrib_ops/cuda/bert/skip_layer_norm.cc | 1 + .../core/providers/cuda/nn/layer_norm.cc | 33 +++++-- .../core/providers/cuda/nn/layer_norm_impl.cu | 17 ++-- .../core/providers/cuda/nn/layer_norm_impl.h | 1 + .../tools/transformers/fusion_layernorm.py | 10 ++- .../tools/transformers/fusion_mha_mmdit.py | 89 +++++++++---------- .../python/tools/transformers/onnx_model.py | 18 ++-- .../tools/transformers/onnx_model_mmdit.py | 77 ++++------------ .../python/tools/transformers/optimizer.py | 4 +- 9 files changed, 121 insertions(+), 129 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 3299bc2cb11de..91e8577df487b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -101,6 +101,7 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const (double)epsilon_, // epsilon reinterpret_cast(gamma->Data()), // gamma (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, // beta + 0, // broadcast stride for gamma/beta reinterpret_cast(skip->Data()), // skip or residual to add (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, // bias to add sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr); diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm.cc b/onnxruntime/core/providers/cuda/nn/layer_norm.cc index 7dd10f9c2960c..c430ffe5aa97d 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm.cc +++ b/onnxruntime/core/providers/cuda/nn/layer_norm.cc @@ -44,19 +44,36 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) con auto bias_data = (simplified || (nullptr == bias)) ? nullptr : reinterpret_cast(bias->Data()); const TensorShape& x_shape = X->Shape(); - const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions()); + auto x_num_dims = x_shape.NumDimensions(); + const int64_t axis = HandleNegativeAxis(axis_, x_num_dims); int n1 = gsl::narrow(x_shape.SizeToDimension(axis)); int n2 = gsl::narrow(x_shape.SizeFromDimension(axis)); const auto scale_size = scale->Shape().Size(); const auto bias_size = (bias_data) ? bias->Shape().Size() : 0; + + int broadcast = 0; if (n2 == 1 || scale_size != n2 || (bias_data && bias_size != n2)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Size of X.shape()[axis:] == ", n2, - ". Size of scale and bias (if provided) must match this " - "and the size must not be 1. Got scale size of ", - scale_size, " and bias size of ", bias_size); + // Handle a special case for MMDit where scale and bias need broadcast. + // X shape is (B, S, D), scale and bias shape is (B, 1, D), and we store S as broadcast stride. 
+ if (x_num_dims == 3 && axis == 2 && n2 > 1 && + scale->Shape().NumDimensions() == x_num_dims && + scale->Shape().GetDims()[0] == x_shape.GetDims()[0] && + scale->Shape().GetDims()[1] == 1 && + scale->Shape().GetDims()[2] == x_shape.GetDims()[2] && + bias->Shape().NumDimensions() == x_num_dims && + bias->Shape().GetDims()[0] == x_shape.GetDims()[0] && + bias->Shape().GetDims()[1] == 1 && + bias->Shape().GetDims()[2] == x_shape.GetDims()[2]) { + broadcast = static_cast(x_shape.GetDims()[1]); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Size of X.shape()[axis:] == ", n2, + ". Size of scale and bias (if provided) must match this " + "and the size must not be 1. Got scale size of ", + scale_size, " and bias size of ", bias_size); + } } // Outputs @@ -65,7 +82,7 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) con // Mean and variance std::vector mean_inv_std_var_dim; - for (int i = 0; i < static_cast(x_shape.NumDimensions()); ++i) { + for (int i = 0; i < static_cast(x_num_dims); ++i) { if (i < axis) { mean_inv_std_var_dim.emplace_back(x_shape.GetDims()[i]); } else { @@ -94,7 +111,7 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) con } HostApplyLayerNorm(GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data, - X_data, n1, n2, epsilon_, scale_data, bias_data); + X_data, n1, n2, epsilon_, scale_data, bias_data, broadcast); CUDA_RETURN_IF_ERROR(cudaGetLastError()); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu index b9e8b45307079..c21943649775b 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu @@ -334,6 +334,7 @@ __global__ void cuApplyLayerNorm( const U epsilon, const V* __restrict__ gamma, const V* __restrict__ beta, + int broadcast, const T* __restrict__ skip, const T* __restrict__ bias, T* __restrict__ skip_input_bias_add_output) { @@ -366,8 +367,13 @@ __global__ void cuApplyLayerNorm( curr += static_cast(skip_vals[i]); } - U gamma_i = (gamma != nullptr) ? (U)gamma[i] : (U)1; - U beta_i = (beta != nullptr) ? (U)beta[i] : (U)0; + // onnx operator LayerNormalization support broadcast. + // gamma and beta should be unidirectional broadcastable to tensor x. + // Here we support a special case for transformer models that x is (B, S, D) and gamma/beta is (B, 1, D) + int index = (broadcast > 0) ? ((i1 / broadcast) * n2 + i) : i; + U gamma_i = (gamma != nullptr) ? (U)gamma[index] : (U)1; + U beta_i = (beta != nullptr) ? 
(U)beta[index] : (U)0; + if (simplified) { ovals[i] = static_cast(gamma_i * c_inv_std_dev * curr); } else { @@ -409,6 +415,7 @@ void HostApplyLayerNorm( double epsilon, const V* gamma, const V* beta, + int broadcast, const T* skip, const T* bias, T* skip_input_bias_add_output) { @@ -442,15 +449,15 @@ void HostApplyLayerNorm( input, n1, n2, U(epsilon), - gamma, beta, + gamma, beta, broadcast, skip, bias, skip_input_bias_add_output); } #define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \ template void HostApplyLayerNorm(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \ U* mean, U* inv_std_dev, const T* input, int n1, int n2, \ - double epsilon, const V* gamma, const V* beta, const T* skip, \ - const T* bias, T* skip_input_bias_add_output); + double epsilon, const V* gamma, const V* beta, int broadcast, \ + const T* skip, const T* bias, T* skip_input_bias_add_output); LAYERNORM_LINEAR_IMPL(float, float, float, true) LAYERNORM_LINEAR_IMPL(half, float, half, true) diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h index e3952eefae35d..3ba895e8829b6 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h @@ -41,6 +41,7 @@ void HostApplyLayerNorm( double epsilon, const V* gamma, const V* beta, + int broadcast = 0, // broadcast stride for gamma/beta const T* skip = nullptr, const T* bias = nullptr, T* skip_input_bias_add_output = nullptr); diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py index 8c99fe3e1d4c6..d1e30351564a9 100644 --- a/onnxruntime/python/tools/transformers/fusion_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py @@ -13,7 +13,7 @@ class FusionLayerNormalization(Fusion): - def __init__(self, model: OnnxModel, check_constant_and_dimension:bool=True): + def __init__(self, model: OnnxModel, check_constant_and_dimension: bool = True): super().__init__(model, "LayerNormalization", "ReduceMean") self.check_constant_and_dimension = check_constant_and_dimension @@ -132,11 +132,15 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)] - if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"): + if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension( + weight_input, 1, "layernorm weight" + ): return bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)] - if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(bias_input, 1, "layernorm bias"): + if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension( + bias_input, 1, "layernorm bias" + ): return self.nodes_to_remove.extend(subgraph_nodes) diff --git a/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py b/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py index c27c333930c60..65c98349f5cbb 100644 --- a/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py +++ b/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. 
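The indexing added to the kernel can be checked against plain numpy broadcasting; a sketch with toy sizes, where `broadcast` stores the sequence length S as in the C++ change:

```
import numpy as np

B, S, D = 2, 16, 8
x = np.random.rand(B, S, D).astype(np.float32)
gamma = np.random.rand(B, 1, D).astype(np.float32)   # scale with shape (B, 1, D)
beta = np.random.rand(B, 1, D).astype(np.float32)    # bias with shape (B, 1, D)

# Reference result via numpy broadcasting.
mean = x.mean(-1, keepdims=True)
var = x.var(-1, keepdims=True)
ref = (x - mean) / np.sqrt(var + 1e-5) * gamma + beta

# Kernel view: rows are n1 = B*S, columns are n2 = D, and broadcast = S, so row i1 reads
# gamma/beta starting at (i1 // broadcast) * n2 -- the same formula used in the .cu change.
n1, n2, broadcast = B * S, D, S
xf, g, b = x.reshape(n1, n2), gamma.reshape(-1), beta.reshape(-1)
out = np.empty((n1, n2), dtype=np.float32)
for i1 in range(n1):
    row = xf[i1]
    inv_std = 1.0 / np.sqrt(row.var() + 1e-5)
    base = (i1 // broadcast) * n2
    out[i1] = (row - row.mean()) * inv_std * g[base:base + n2] + b[base:base + n2]

assert np.allclose(out.reshape(B, S, D), ref, atol=1e-5)
```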
# -------------------------------------------------------------------------- from logging import getLogger -from typing import Optional, Tuple +from typing import Optional import numpy as np from fusion_base import Fusion @@ -13,13 +13,14 @@ logger = getLogger(__name__) + class FusionMultiHeadAttentionMMDit(Fusion): """ Fuse MultiHeadAttention for Multimodal Diffusion Transformer (MMDiT). """ def __init__(self, model: OnnxModel): - super().__init__(model, fused_op_type = "MultiHeadAttention", search_op_types = ["Softmax"]) + super().__init__(model, fused_op_type="MultiHeadAttention", search_op_types=["Softmax"]) def get_num_heads(self, node: NodeProto, output_name_to_node) -> int: """ @@ -46,7 +47,8 @@ def get_num_heads(self, node: NodeProto, output_name_to_node) -> int: node, ["SimplifiedLayerNormalization", "Transpose", "Reshape", "Add"], [0, 0, 0, 0], - output_name_to_node=output_name_to_node) + output_name_to_node=output_name_to_node, + ) num_heads = 0 if k_proj_nodes: @@ -101,11 +103,7 @@ def get_num_heads_with_concat(self, transpose_k: NodeProto, output_name_to_node) | Transpose(perm=0,1,3,2) """ - nodes = self.model.match_parent_path( - transpose_k, - ["Concat"], - [0], - output_name_to_node=output_name_to_node) + nodes = self.model.match_parent_path(transpose_k, ["Concat"], [0], output_name_to_node=output_name_to_node) return self.get_num_heads(nodes[0], output_name_to_node) if nodes else 0 @@ -155,37 +153,35 @@ def adjust_query_from_bnsh_to_bsd_no_concat(self, mul_q: NodeProto, output_name_ return self.reshape_to_3d(sln_a.output[0], sln_output + "_BSD") - - def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: """ - Before: - MatMul MatMul .. [-1] [24] .. - | | | | / / - Add Concat Add Concat - | / | / - Reshape Reshape - | | - Transpose(perm=0,2,1,3) Transpose(perm=0,2,1,3) - | | - SimplifiedLayerNorm SimplifiedLayerNorm - | / - Concat(axis=2) - | - Mul - - After: - MatMul MatMul .. [-1] [24] .. - | | | | / / - Add Concat Add Concat - | / | / - Reshape Reshape - | | + Before: + MatMul MatMul .. [-1] [24] .. + | | | | / / + Add Concat Add Concat + | / | / + Reshape Reshape + | | + Transpose(perm=0,2,1,3) Transpose(perm=0,2,1,3) + | | SimplifiedLayerNorm SimplifiedLayerNorm - | / - Concat(axis=1) - | - Reshape (shape=[0, 0, -1]) + | / + Concat(axis=2) + | + Mul + + After: + MatMul MatMul .. [-1] [24] .. 
+ | | | | / / + Add Concat Add Concat + | / | / + Reshape Reshape + | | + SimplifiedLayerNorm SimplifiedLayerNorm + | / + Concat(axis=1) + | + Reshape (shape=[0, 0, -1]) """ path = self.model.match_parent_path( @@ -204,7 +200,6 @@ def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) - concat, ["SimplifiedLayerNormalization", "Transpose"], [1, 0], - ) if path is None: return None @@ -232,7 +227,6 @@ def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) - return self.reshape_to_3d(new_concat_node.output[0], concat.output[0] + "_BSD") - def create_multihead_attention_node( self, q: str, @@ -284,10 +278,9 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): if self.model.find_graph_output(softmax.output[0]): return - nodes = self.model.match_child_path(softmax, - ["MatMul", "Transpose", "Reshape"], - [(0, 0), (0, 0), (0, 0)], - input_name_to_nodes) + nodes = self.model.match_child_path( + softmax, ["MatMul", "Transpose", "Reshape"], [(0, 0), (0, 0), (0, 0)], input_name_to_nodes + ) if nodes is None: return @@ -334,13 +327,17 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): # | # v # -- Transpose (perm=[0,2,1,3]) -> Concat -> (v) - transpose_1 = self.model.match_parent(concat, "Transpose", input_index=0, output_name_to_node=output_name_to_node) + transpose_1 = self.model.match_parent( + concat, "Transpose", input_index=0, output_name_to_node=output_name_to_node + ) if transpose_1 is None: return if not FusionUtils.check_node_attribute(transpose_1, "perm", [0, 2, 1, 3]): return - transpose_2 = self.model.match_parent(concat, "Transpose", input_index=1, output_name_to_node=output_name_to_node) + transpose_2 = self.model.match_parent( + concat, "Transpose", input_index=1, output_name_to_node=output_name_to_node + ) if transpose_2 is None: return if not FusionUtils.check_node_attribute(transpose_2, "perm", [0, 2, 1, 3]): @@ -348,7 +345,9 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): else: # Match v path like: # -- Transpose (perm=[0,2,1,3]) -> (v) - transpose_1 = self.model.match_parent(matmul_s_v, "Transpose", input_index=1, output_name_to_node=output_name_to_node) + transpose_1 = self.model.match_parent( + matmul_s_v, "Transpose", input_index=1, output_name_to_node=output_name_to_node + ) if transpose_1 is None: return if not FusionUtils.check_node_attribute(transpose_1, "perm", [0, 2, 1, 3]): diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index fa898a750178f..2a6f9c3d758db 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -242,13 +242,11 @@ def get_children(self, node, input_name_to_nodes=None, output_index=None): if output_index < len(node.output): output = node.output[output_index] if output in input_name_to_nodes: - for node in input_name_to_nodes[output]: - children.append(node) + children = list(input_name_to_nodes[output]) else: for output in node.output: if output in input_name_to_nodes: - for node in input_name_to_nodes[output]: - children.append(node) # noqa: PERF402 + children.extend(input_name_to_nodes[output]) return children @@ -444,7 +442,7 @@ def match_child_path( self, node, child_op_types, - edges:Optional[List[Tuple[int, int]]]=None, + edges: Optional[List[Tuple[int, int]]] = None, input_name_to_nodes=None, exclude=[], # noqa: B006 ): @@ -465,10 +463,12 @@ def match_child_path( if edges is not None: assert len(edges) == 
len(child_op_types) for edge in edges: - assert isinstance(edge, tuple) and len(edge) == 2 and isinstance(edge[0], int) and isinstance(edge[1], int) + assert ( + isinstance(edge, tuple) and len(edge) == 2 and isinstance(edge[0], int) and isinstance(edge[1], int) + ) if input_name_to_nodes is None: - input_name_to_nodes = self.input_name_to_nodes() + input_name_to_nodes = self.input_name_to_nodes() current_node = node matched_children = [] @@ -478,7 +478,9 @@ def match_child_path( if edges is None: children_nodes = self.get_children(current_node, input_name_to_nodes=input_name_to_nodes) else: - children_nodes = self.get_children(current_node, input_name_to_nodes=input_name_to_nodes, output_index=edges[i][0]) + children_nodes = self.get_children( + current_node, input_name_to_nodes=input_name_to_nodes, output_index=edges[i][0] + ) for child in children_nodes: if child.op_type == op_type and child not in exclude: diff --git a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py index 8014011f394fb..7593450f7dd74 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py +++ b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py @@ -8,7 +8,6 @@ from fusion_layernorm import FusionLayerNormalization from fusion_mha_mmdit import FusionMultiHeadAttentionMMDit -from fusion_bias_add import FusionBiasAdd from fusion_options import FusionOptions from import_utils import is_installed from onnx import ModelProto @@ -30,42 +29,23 @@ def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0): super().__init__(model, num_heads=num_heads, hidden_size=hidden_size) - def preprocess(self): - self.remove_useless_div() - def postprocess(self): self.prune_graph() self.remove_unused_constant() - def remove_useless_div(self): - """Remove Div by 1""" - div_nodes = [node for node in self.nodes() if node.op_type == "Div"] - - nodes_to_remove = [] - for div in div_nodes: - if self.find_constant_input(div, 1.0) == 1: - nodes_to_remove.append(div) - - for node in nodes_to_remove: - self.replace_input_of_all_nodes(node.output[0], node.input[0]) - - if nodes_to_remove: - self.remove_nodes(nodes_to_remove) - logger.info("Removed %d Div nodes", len(nodes_to_remove)) - def fuse_layer_norm(self): - # TODO: set check_constant_and_dimension=True when LayerNormalization supports broadcasting scale and bias. - fusion = FusionLayerNormalization(self, check_constant_and_dimension=False) + layernorm_support_broadcast = True + logger.warning( + "The optimized model requires LayerNormalization with broadcast support. " + "Please use onnxruntime-gpu>=1.21 for inference." 
+ ) + fusion = FusionLayerNormalization(self, check_constant_and_dimension=not layernorm_support_broadcast) fusion.apply() def fuse_multi_head_attention(self): fusion = FusionMultiHeadAttentionMMDit(self) fusion.apply() - def fuse_bias_add(self): - fusion = FusionBiasAdd(self) - fusion.apply() - def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): assert not add_dynamic_axes @@ -74,7 +54,7 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo from tqdm.contrib.logging import logging_redirect_tqdm with logging_redirect_tqdm(): - steps = 12 + steps = 5 progress_bar = tqdm.tqdm(range(steps), initial=0, desc="fusion") self._optimize(options, progress_bar) else: @@ -85,10 +65,6 @@ def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): if (options is not None) and not options.enable_shape_inference: self.disable_shape_inference() - self.utils.remove_identity_nodes() - if progress_bar: - progress_bar.update(1) - # Remove cast nodes that having same data type of input and output based on symbolic shape inference. self.utils.remove_useless_cast_nodes() if progress_bar: @@ -105,38 +81,19 @@ def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): if progress_bar: progress_bar.update(1) - self.preprocess() - if progress_bar: - progress_bar.update(1) - - self.fuse_reshape() - if progress_bar: - progress_bar.update(1) - if (options is None) or options.enable_attention: self.fuse_multi_head_attention() if progress_bar: progress_bar.update(1) - if (options is None) or options.enable_skip_layer_norm: - self.fuse_skip_layer_norm() - if progress_bar: - progress_bar.update(1) - - self.fuse_shape() - if progress_bar: - progress_bar.update(1) - - # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. - self.utils.remove_useless_reshape_nodes() - if progress_bar: - progress_bar.update(1) - - if (options is None) or options.enable_bias_skip_layer_norm: - # Fuse SkipLayerNormalization and Add Bias before it. - self.fuse_add_bias_skip_layer_norm() - if progress_bar: - progress_bar.update(1) + # TODO: SkipLayerNormalization does not support broadcast yet. + # if (options is None) or options.enable_skip_layer_norm: + # self.fuse_skip_layer_norm() + # if (options is None) or options.enable_bias_skip_layer_norm: + # # Fuse SkipLayerNormalization and Add Bias before it. + # self.fuse_add_bias_skip_layer_norm() + # if progress_bar: + # progress_bar.update(1) self.postprocess() if progress_bar: @@ -150,9 +107,11 @@ def get_fused_operator_statistics(self): """ op_count = {} ops = [ + "FastGelu", "MultiHeadAttention", "LayerNormalization", - "SkipLayerNormalization", + # "SkipLayerNormalization", + "SimplifiedLayerNormalization", ] for op in ops: diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index b924a19ebd8a8..33737a7d34998 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -239,7 +239,9 @@ def optimize_by_fusion( Returns: object of an optimizer class. 
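Given the warning emitted by `fuse_layer_norm`, inference code can guard on the runtime version up front. A small sketch; the model path is illustrative.

```
import onnxruntime as ort
from packaging.version import Version

assert Version(ort.__version__) >= Version("1.21.0"), \
    "broadcast LayerNormalization in the fused model needs onnxruntime-gpu >= 1.21"

session = ort.InferenceSession(
    "sd3_onnx_fp16/transformer/model.onnx",   # illustrative path
    providers=["CUDAExecutionProvider"],
)
```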
""" - if model_type not in ["bert", "swin", "unet", "vae", "clip", "sam2", "mmdit"] and (num_heads == 0 or hidden_size == 0): + if model_type not in ["bert", "swin", "unet", "vae", "clip", "sam2", "mmdit"] and ( + num_heads == 0 or hidden_size == 0 + ): logger.warning(f"Please specify parameters of num_heads and hidden_size for model_type {model_type}") if model_type not in MODEL_TYPES: From 699a64cf6c56e0e4f45697567a62c90a0adee1b1 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sun, 15 Dec 2024 01:16:08 +0000 Subject: [PATCH 09/26] force fuse layernorm --- .../tools/transformers/fusion_layernorm.py | 126 ++++++++++-------- .../tools/transformers/onnx_model_mmdit.py | 7 +- 2 files changed, 74 insertions(+), 59 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py index d1e30351564a9..277bd0799cf16 100644 --- a/onnxruntime/python/tools/transformers/fusion_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py @@ -13,9 +13,10 @@ class FusionLayerNormalization(Fusion): - def __init__(self, model: OnnxModel, check_constant_and_dimension: bool = True): + def __init__(self, model: OnnxModel, check_constant_and_dimension: bool = True, force: bool = False): super().__init__(model, "LayerNormalization", "ReduceMean") self.check_constant_and_dimension = check_constant_and_dimension + self.force = force def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): """ @@ -97,63 +98,74 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): if div_node.output[0] not in input_name_to_nodes: return - temp_node = input_name_to_nodes[div_node.output[0]][0] - if temp_node.op_type == "Cast": - # Div --> Cast --> Mul - subgraph_nodes.append(temp_node) # add Cast node to list of subgraph nodes - if temp_node.output[0] not in input_name_to_nodes: - return - mul_node = input_name_to_nodes[temp_node.output[0]][0] - else: - # Div --> Mul - mul_node = temp_node - if mul_node.op_type != "Mul": - return - - if mul_node.output[0] not in input_name_to_nodes: - return - last_add_node = input_name_to_nodes[mul_node.output[0]][0] - if last_add_node.op_type != "Add": - return - - subgraph_nodes.append(node) - subgraph_nodes.extend(children) - subgraph_nodes.extend(parent_nodes[:-1]) - - subgraph_nodes.extend([last_add_node, mul_node, div_node]) - if not self.model.is_safe_to_fuse_nodes( - subgraph_nodes, - last_add_node.output, - input_name_to_nodes, - output_name_to_node, - ): - logger.debug("It is not safe to fuse LayerNormalization node. 
Skip") - return - node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node - weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)] - if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension( - weight_input, 1, "layernorm weight" - ): - return - - bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)] - if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension( - bias_input, 1, "layernorm bias" - ): - return - - self.nodes_to_remove.extend(subgraph_nodes) - - normalize_node = helper.make_node( - "LayerNormalization", - inputs=[node.input[0], weight_input, bias_input], - outputs=[last_add_node.output[0]], - name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"), - ) - normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))]) - self.nodes_to_add.append(normalize_node) - self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name + # In MMDit model, Div might have two Mul+Add children paths. + div_children = input_name_to_nodes[div_node.output[0]] + for temp_node in div_children: + if temp_node.op_type == "Cast": + # Div --> Cast --> Mul + subgraph_nodes.append(temp_node) # add Cast node to list of subgraph nodes + if temp_node.output[0] not in input_name_to_nodes: + continue + mul_node = input_name_to_nodes[temp_node.output[0]][0] + else: + # Div --> Mul + mul_node = temp_node + if mul_node.op_type != "Mul": + continue + + if mul_node.output[0] not in input_name_to_nodes: + continue + last_add_node = input_name_to_nodes[mul_node.output[0]][0] + if last_add_node.op_type != "Add": + continue + + subgraph_nodes.append(node) + subgraph_nodes.extend(children) + subgraph_nodes.extend(parent_nodes[:-1]) + + subgraph_nodes.extend([last_add_node, mul_node, div_node]) + + node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node + weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)] + if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension( + weight_input, 1, "layernorm weight" + ): + continue + + bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)] + if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension( + bias_input, 1, "layernorm bias" + ): + continue + + layer_norm_output = last_add_node.output[0] + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + last_add_node.output, + input_name_to_nodes, + output_name_to_node, + ): + # If it is not safe to fuse, somce computation may be duplicated if we force to fuse it. + # It it unknown that force fusion might bring performance gain/loss. + # User need test performance impact to see whether forcing fusion can help. + if self.force: + self.prune_graph = True + else: + logger.debug("It is not safe to fuse LayerNormalization node. 
Skip") + continue + else: + self.nodes_to_remove.extend(subgraph_nodes) + + normalize_node = helper.make_node( + "LayerNormalization", + inputs=[node.input[0], weight_input, bias_input], + outputs=[layer_norm_output], + name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"), + ) + normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))]) + self.nodes_to_add.append(normalize_node) + self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name class FusionLayerNormalizationNCHW(Fusion): diff --git a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py index 7593450f7dd74..9e30130d033e7 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py +++ b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py @@ -39,7 +39,9 @@ def fuse_layer_norm(self): "The optimized model requires LayerNormalization with broadcast support. " "Please use onnxruntime-gpu>=1.21 for inference." ) - fusion = FusionLayerNormalization(self, check_constant_and_dimension=not layernorm_support_broadcast) + fusion = FusionLayerNormalization( + self, check_constant_and_dimension=not layernorm_support_broadcast, force=True + ) fusion.apply() def fuse_multi_head_attention(self): @@ -88,7 +90,8 @@ def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): # TODO: SkipLayerNormalization does not support broadcast yet. # if (options is None) or options.enable_skip_layer_norm: - # self.fuse_skip_layer_norm() + # self.fuse_skip_simplified_layer_norm() + # self.fuse_skip_layer_norm() # if (options is None) or options.enable_bias_skip_layer_norm: # # Fuse SkipLayerNormalization and Add Bias before it. # self.fuse_add_bias_skip_layer_norm() From c1d0160069b97e50526561b3366b7adc0402ed38 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sun, 15 Dec 2024 04:15:00 +0000 Subject: [PATCH 10/26] refactoring --- .../models/stable_diffusion/optimize_pipeline.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 313c3b304a258..7d4d00010a28e 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -47,9 +47,17 @@ def has_external_data(onnx_model_path): return False +def is_sd_3(source_dir: Path): + return (source_dir / "text_encoder_3").exists() + + +def is_sdxl(source_dir: Path): + return (source_dir / "text_encoder_2").exists() and not (source_dir / "text_encoder_3").exists() + + def _get_model_list(source_dir: Path): - is_xl = (source_dir / "text_encoder_2").exists() - is_sd3 = (source_dir / "text_encoder_3").exists() + is_xl = is_sdxl(source_dir) + is_sd3 = is_sd_3(source_dir) model_list_sd3 = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae_encoder", "vae_decoder"] model_list_sdxl = ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder"] model_list_sd = ["text_encoder", "unet", "vae_encoder", "vae_decoder"] @@ -163,8 +171,7 @@ def _optimize_sd_pipeline( if float16: # For SD-XL, use FP16 in VAE decoder will cause NaN and black image so we keep it in FP32. 
- is_xl = (source_dir / "text_encoder_2").exists() - if is_xl and name == "vae_decoder": + if is_sdxl(source_dir) and name == "vae_decoder": logger.info("Skip converting %s to float16 to avoid NaN", name) else: logger.info("Convert %s to float16 ...", name) From 1b9ea543cbfda7ddc0933589ef3fed36ad538aa8 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 19 Dec 2024 19:17:56 +0000 Subject: [PATCH 11/26] mha fusion for flux --- .../tools/transformers/fusion_mha_mmdit.py | 156 ++++++++++-------- .../models/stable_diffusion/README.md | 4 +- .../models/stable_diffusion/benchmark.py | 21 ++- .../stable_diffusion/optimize_pipeline.py | 58 +++++-- 4 files changed, 152 insertions(+), 87 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py b/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py index 65c98349f5cbb..ef81ce6a3a5a5 100644 --- a/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py +++ b/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py @@ -22,69 +22,41 @@ class FusionMultiHeadAttentionMMDit(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, fused_op_type="MultiHeadAttention", search_op_types=["Softmax"]) - def get_num_heads(self, node: NodeProto, output_name_to_node) -> int: - """ - Detect num_heads and hidden_size from Concat node in the following subgraph: - MatMul .. [-1] [24] .. - | | | / / - Add<1536> Concat - | / - Reshape - | - Transpose(perm=0,2,1,3) - | - SimplifiedLayerNorm -- scale<64> - | - (node) - The num_heads can be read directly from the third input of Concat node. + def get_num_heads(self, v_node: NodeProto, output_name_to_node, input_index=0) -> int: + """ + Detect num_heads and hidden_size from Concat node in the value subgraph for Flux: - Here we deduce num_heads=hidden_size/head_size from the following two nodes: - The hidden_size can be read from the bias input of Add node - The head_size can be read from the scale input of SimplifiedLayerNormalization node + | | + MatMul MatMul .. [-1] [24] .. + | | | | / / + Add Add Concat(axis=0) + | | / + Reshape Reshape + | | + Transpose(perm=0,1,3,2) Transpose(perm=0,1,3,2) + | | + Concat (axis=2) """ - k_proj_nodes = self.model.match_parent_path( - node, - ["SimplifiedLayerNormalization", "Transpose", "Reshape", "Add"], - [0, 0, 0, 0], - output_name_to_node=output_name_to_node, - ) + nodes = self.model.match_parent_path(v_node, ["Transpose", "Reshape", "Concat"], [input_index, 0, 1], + output_name_to_node=output_name_to_node) + if nodes is None: + return 0 - num_heads = 0 - if k_proj_nodes: - simplified_layernorm, _transpose, _reshape, add = k_proj_nodes - _i, bias = self.model.get_constant_input(add) + concat_shape = nodes[-1] + if len(concat_shape.input) != 4: + return 0 - hidden_size = 0 - if isinstance(bias, np.ndarray) and len(bias.shape) == 1: - hidden_size = bias.shape[0] + value = self.model.get_constant_value(concat_shape.input[2]) + if value is None: + return 0 - weight = self.model.get_constant_value(simplified_layernorm.input[1]) - if isinstance(weight, np.ndarray) and len(weight.shape) == 1: - head_size = weight.shape[0] - if (hidden_size % head_size) == 0: - num_heads = hidden_size // head_size + if len(value.shape) != 1: + return 0 - return num_heads + return int(value[0]) - def reshape_to_3d(self, input_name: str, output_name: str) -> str: - # Add a shape to convert 4D BxSxNxH to 3D BxSxD, which is required by MHA operator. 
- new_dims_name = "bsnh_to_bsd_reshape_dims" - new_dims = self.model.get_initializer(new_dims_name) - if new_dims is None: - new_dims = numpy_helper.from_array(np.array([0, 0, -1], dtype="int64"), name=new_dims_name) - self.model.add_initializer(new_dims, self.this_graph_name) - reshape_q = helper.make_node( - "Reshape", - inputs=[input_name, new_dims_name], - outputs=[output_name], - name=self.model.create_node_name("Reshape"), - ) - self.nodes_to_add.append(reshape_q) - self.node_name_to_graph_name[reshape_q.name] = self.this_graph_name - return reshape_q.output[0] - - def get_num_heads_with_concat(self, transpose_k: NodeProto, output_name_to_node) -> int: + def get_num_heads_from_k(self, transpose_k: NodeProto, output_name_to_node, has_concat:bool) -> int: """ Detect num_heads and hidden_size from Concat node in the following subgraph: @@ -103,9 +75,33 @@ def get_num_heads_with_concat(self, transpose_k: NodeProto, output_name_to_node) | Transpose(perm=0,1,3,2) """ - nodes = self.model.match_parent_path(transpose_k, ["Concat"], [0], output_name_to_node=output_name_to_node) + if has_concat: + nodes = self.model.match_parent_path(transpose_k, ["Concat", "SimplifiedLayerNormalization"], [0, 1], output_name_to_node=output_name_to_node) + if nodes: + return self.get_num_heads(nodes[1], output_name_to_node) + + nodes = self.model.match_parent_path(transpose_k, ["SimplifiedLayerNormalization"], [0], output_name_to_node=output_name_to_node) + if nodes: + return self.get_num_heads(nodes[0], output_name_to_node) - return self.get_num_heads(nodes[0], output_name_to_node) if nodes else 0 + return 0 + + def reshape_to_3d(self, input_name: str, output_name: str) -> str: + # Add a shape to convert 4D BxSxNxH to 3D BxSxD, which is required by MHA operator. + new_dims_name = "bsnh_to_bsd_reshape_dims" + new_dims = self.model.get_initializer(new_dims_name) + if new_dims is None: + new_dims = numpy_helper.from_array(np.array([0, 0, -1], dtype="int64"), name=new_dims_name) + self.model.add_initializer(new_dims, self.this_graph_name) + reshape_q = helper.make_node( + "Reshape", + inputs=[input_name, new_dims_name], + outputs=[output_name], + name=self.model.create_node_name("Reshape"), + ) + self.nodes_to_add.append(reshape_q) + self.node_name_to_graph_name[reshape_q.name] = self.this_graph_name + return reshape_q.output[0] def adjust_query_from_bnsh_to_bsd_no_concat(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: """ @@ -211,6 +207,9 @@ def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) - if not FusionUtils.check_node_attribute(transpose_b, "perm", [0, 2, 1, 3]): return None + if not FusionUtils.check_node_attribute(concat, "axis", 2): + return None + # Update the graph sln_a.input[0] = transpose_a.input[0] sln_b.input[0] = transpose_b.input[0] @@ -227,6 +226,19 @@ def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) - return self.reshape_to_3d(new_concat_node.output[0], concat.output[0] + "_BSD") + def transpose_reshape_bnsh_to_bsd(self, q: str, output_name_to_node) -> Optional[str]: + transpose_q = helper.make_node( + "Transpose", + [q], + [q + "_BSNH"], + name=self.model.create_node_name("Transpose", name_prefix="Transpose_BNSH_to_BSNH"), + perm=[0, 2, 1, 3], + ) + self.nodes_to_add.append(transpose_q) + self.node_name_to_graph_name[transpose_q.name] = self.this_graph_name + + return self.reshape_to_3d(q + "_BSNH", q + "_BSD") + def create_multihead_attention_node( self, q: str, @@ -299,7 +311,8 @@ def fuse(self, node, 
input_name_to_nodes, output_name_to_node): matmul_qk, mul_q, sqrt_q_2, div_q, sqrt_q, _, _, shape_q = q_nodes - if mul_q.input[0] != shape_q.input[0]: + q_bnsh = mul_q.input[0] + if q_bnsh != shape_q.input[0]: return k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose"], [1, 0]) @@ -320,15 +333,15 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): v = matmul_s_v.input[1] # Here we sanity check the v path to make sure it is in the expected BNSH format. - concat = self.model.match_parent(matmul_s_v, "Concat", input_index=1, output_name_to_node=output_name_to_node) - if concat is not None: + concat_v = self.model.match_parent(matmul_s_v, "Concat", input_index=1, output_name_to_node=output_name_to_node) + if concat_v is not None: # Match v path like: # -- Transpose (perm=[0,2,1,3]) ----+ # | # v # -- Transpose (perm=[0,2,1,3]) -> Concat -> (v) transpose_1 = self.model.match_parent( - concat, "Transpose", input_index=0, output_name_to_node=output_name_to_node + concat_v, "Transpose", input_index=0, output_name_to_node=output_name_to_node ) if transpose_1 is None: return @@ -336,7 +349,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): return transpose_2 = self.model.match_parent( - concat, "Transpose", input_index=1, output_name_to_node=output_name_to_node + concat_v, "Transpose", input_index=1, output_name_to_node=output_name_to_node ) if transpose_2 is None: return @@ -353,21 +366,24 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): if not FusionUtils.check_node_attribute(transpose_1, "perm", [0, 2, 1, 3]): return - if concat is not None: - num_heads = self.get_num_heads_with_concat(transpose_k, output_name_to_node) - else: - num_heads = self.get_num_heads(transpose_k, output_name_to_node) - if num_heads <= 0: - return + # Match patterns for Flux. + num_heads = self.get_num_heads(concat_v, output_name_to_node) if concat_v else \ + self.get_num_heads(matmul_s_v, output_name_to_node, input_index=1) + + if num_heads == 0: + # Match patterns for Stable Diffusion 3.5. + num_heads = self.get_num_heads_from_k(transpose_k, output_name_to_node, concat_v is not None) + if num_heads <= 0: + return # Q is in BNSH format, we need to adjust it to BSD format due to limitation of MHA op. - if concat is not None: + if concat_v is not None: query = self.adjust_query_from_bnsh_to_bsd(mul_q, output_name_to_node) else: query = self.adjust_query_from_bnsh_to_bsd_no_concat(mul_q, output_name_to_node) if query is None: - return + query = self.transpose_reshape_bnsh_to_bsd(q_bnsh, output_name_to_node) new_node = self.create_multihead_attention_node( q=query, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 1230cbffa7815..7cd5f273c556e 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -284,10 +284,12 @@ For ROCm EP, the `--tuning` is mandatory because we heavily rely on tuning to fi The default parameters are stable diffusion version=1.5, height=512, width=512, steps=50, batch_count=5. Run `python benchmark.py --help` for more information. 
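A hedged inference sketch for the exported/optimized directories above: load one of them with optimum's text-to-image pipeline class and generate an image. The path and prompt are illustrative; Flux pipelines do not accept a negative prompt, so none is passed.

```
from optimum.onnxruntime import ORTPipelineForText2Image

pipe = ORTPipelineForText2Image.from_pretrained(
    "flux1_schnell_onnx/fp16",               # illustrative path
    provider="CUDAExecutionProvider",
)
image = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    height=1024,
    width=1024,
    num_inference_steps=20,
).images[0]
image.save("output.png")
```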
#### Stable Diffusion 3.x and Flux 1.0 -Example of benchmark with optimum using CUDA provider on stable diffusion 3.5: +Example of benchmark with optimum using CUDA provider on stable diffusion 3.5 medium and Flux 1.0: ``` python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v 3.5M -p sd3.5_medium_onnx/fp32 python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v 3.5M -p sd3.5_medium_onnx/fp16 +python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v Flux.1S -p flux1_schnell_onnx/fp16 +python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v Flux.1D -p flux1_dev_onnx/fp16 ``` ### Run Benchmark with xFormers diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 617d1ee461851..063a8cd4637ef 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -361,18 +361,29 @@ def run_optimum_ort_pipeline( use_num_images_per_prompt=False, ): print("Pipeline type", type(pipe)) + from optimum.onnxruntime.modeling_diffusion import ORTFluxPipeline + is_flux = isinstance(pipe, ORTFluxPipeline) + prompts, negative_prompt = example_prompts() + def get_negative_prompt_kwargs(negative, use_num_images_per_prompt, is_flux): + negative_prompt_kwargs = {"negative_prompt": negative} if use_num_images_per_prompt else {"negative_prompt": [negative] * batch_size} + # Flux does not support negative prompt + if is_flux: + negative_prompt_kwargs = {} + return negative_prompt_kwargs + def warmup(): prompt, negative = warmup_prompts() + extra_kwargs = get_negative_prompt_kwargs(negative, use_num_images_per_prompt, is_flux) if use_num_images_per_prompt: pipe( prompt=prompt, height=height, width=width, num_inference_steps=steps, - negative_prompt=negative, num_images_per_prompt=batch_count, + **extra_kwargs ) else: pipe( @@ -380,7 +391,7 @@ def warmup(): height=height, width=width, num_inference_steps=steps, - negative_prompt=[negative] * batch_size, + **extra_kwargs ) # Run warm up, and measure GPU memory of two runs. 
@@ -390,6 +401,8 @@ def warmup(): warmup() + extra_kwargs = get_negative_prompt_kwargs(negative_prompt, use_num_images_per_prompt, is_flux) + latency_list = [] for i, prompt in enumerate(prompts): if i >= num_prompts: @@ -401,8 +414,8 @@ def warmup(): height=height, width=width, num_inference_steps=steps, - negative_prompt=negative_prompt, num_images_per_prompt=batch_size, + **extra_kwargs ).images else: images = pipe( @@ -410,7 +423,7 @@ def warmup(): height=height, width=width, num_inference_steps=steps, - negative_prompt=[negative_prompt] * batch_size, + **extra_kwargs ).images inference_end = time.time() latency = inference_end - inference_start diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 7d4d00010a28e..feb3b27611e6c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -52,22 +52,54 @@ def is_sd_3(source_dir: Path): def is_sdxl(source_dir: Path): - return (source_dir / "text_encoder_2").exists() and not (source_dir / "text_encoder_3").exists() + return ( + (source_dir / "text_encoder_2").exists() + and not (source_dir / "text_encoder_3").exists() + and not (source_dir / "transformer").exists() + ) + + +def is_flux(source_dir: Path): + return ( + (source_dir / "text_encoder_2").exists() + and not (source_dir / "text_encoder_3").exists() + and (source_dir / "transformer").exists() + ) + + +def _classify_pipeline_type(source_dir: Path): + # May also check _class_name in model_index.json like `StableDiffusion3Pipeline` or `FluxPipeline` etc to classify. + if is_sd_3(source_dir): + return "sd3" + + if is_flux(source_dir): + return "flux" + + if is_sdxl(source_dir): + return "sdxl" + + # sd 1.x and 2.x + return "sd" + + +def _get_model_list(pipeline_type: str): + if pipeline_type == "sd3": + return ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae_encoder", "vae_decoder"] + + if pipeline_type == "flux": + return ["text_encoder", "text_encoder_2", "transformer", "vae_encoder", "vae_decoder"] + if pipeline_type == "sdxl": + return ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder"] -def _get_model_list(source_dir: Path): - is_xl = is_sdxl(source_dir) - is_sd3 = is_sd_3(source_dir) - model_list_sd3 = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae_encoder", "vae_decoder"] - model_list_sdxl = ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder"] - model_list_sd = ["text_encoder", "unet", "vae_encoder", "vae_decoder"] - model_list = model_list_sd3 if is_sd3 else (model_list_sdxl if is_xl else model_list_sd) - return model_list + assert pipeline_type == "sd" + return ["text_encoder", "unet", "vae_encoder", "vae_decoder"] def _optimize_sd_pipeline( source_dir: Path, target_dir: Path, + pipeline_type: str, model_list: List[str], use_external_data_format: Optional[bool], float16: bool, @@ -96,9 +128,9 @@ def _optimize_sd_pipeline( "vae_encoder": "vae", "vae_decoder": "vae", "text_encoder": "clip", - "text_encoder_2": "clip", + "text_encoder_2": "clip" if pipeline_type != "flux" else "bert", "safety_checker": "unet", - "text_encoder_3": "clip", + "text_encoder_3": "clip", # t5? 
} model_type_class_mapping = { @@ -266,13 +298,15 @@ def optimize_stable_diffusion_pipeline( target_dir = Path(output_dir) target_dir.mkdir(parents=True, exist_ok=True) - model_list = _get_model_list(source_dir) + pipeline_type = _classify_pipeline_type(source_dir) + model_list = _get_model_list(pipeline_type) _copy_extra_directory(source_dir, target_dir, model_list) _optimize_sd_pipeline( source_dir, target_dir, + pipeline_type, model_list, use_external_data_format, float16, From 5528276b53cd72ab95c092a04056df32e5e564b2 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 20 Dec 2024 01:22:29 +0000 Subject: [PATCH 12/26] remove transpose for query --- .../tools/transformers/fusion_mha_mmdit.py | 407 +++++++++++++++--- .../python/tools/transformers/fusion_utils.py | 13 + .../models/stable_diffusion/benchmark.py | 37 +- .../tools/transformers/onnx_model_mmdit.py | 3 - 4 files changed, 369 insertions(+), 91 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py b/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py index ef81ce6a3a5a5..dcad55c13eb49 100644 --- a/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py +++ b/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py @@ -3,12 +3,12 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Optional +from typing import Dict, Optional import numpy as np from fusion_base import Fusion from fusion_utils import FusionUtils -from onnx import NodeProto, helper, numpy_helper +from onnx import NodeProto, TensorProto, helper, numpy_helper from onnx_model import OnnxModel logger = getLogger(__name__) @@ -21,25 +21,25 @@ class FusionMultiHeadAttentionMMDit(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, fused_op_type="MultiHeadAttention", search_op_types=["Softmax"]) + self.unsqueeze_update_map = {} - - def get_num_heads(self, v_node: NodeProto, output_name_to_node, input_index=0) -> int: + def get_num_heads(self, start_node: NodeProto, output_name_to_node, input_index=0) -> int: """ - Detect num_heads and hidden_size from Concat node in the value subgraph for Flux: - - | | - MatMul MatMul .. [-1] [24] .. - | | | | / / - Add Add Concat(axis=0) - | | / - Reshape Reshape - | | - Transpose(perm=0,1,3,2) Transpose(perm=0,1,3,2) - | | - Concat (axis=2) + Detect num_heads from Reshape & Transpose of q/k/v for both Stable Diffusion 3.x and Flux 1.x: + + MatMul .. [-1] [24] .. + | | | / / + Add Concat(axis=0) + | / + Reshape + | + Transpose(perm=0,1,3,2) + | + (start_node) """ - nodes = self.model.match_parent_path(v_node, ["Transpose", "Reshape", "Concat"], [input_index, 0, 1], - output_name_to_node=output_name_to_node) + nodes = self.model.match_parent_path( + start_node, ["Transpose", "Reshape", "Concat"], [input_index, 0, 1], output_name_to_node=output_name_to_node + ) if nodes is None: return 0 @@ -56,38 +56,66 @@ def get_num_heads(self, v_node: NodeProto, output_name_to_node, input_index=0) - return int(value[0]) - def get_num_heads_from_k(self, transpose_k: NodeProto, output_name_to_node, has_concat:bool) -> int: + def get_num_heads_from_k(self, transpose_k: NodeProto, output_name_to_node, concat_before_transpose: bool) -> int: """ - Detect num_heads and hidden_size from Concat node in the following subgraph: - - / | - MatMul MatMul .. [-1] [24] .. 
- | | | | / / - Add Add<1536> Concat - | | / - Reshape Reshape - | | - Transpose Transpose(perm=0,2,1,3) - | | - SimplifiedLayerNorm SimplifiedLayerNorm -- scale<64> - | / - Concat - | - Transpose(perm=0,1,3,2) + Detect num_heads from subgraph like the following (num_heads=24 in this example): + MatMu .. [-1] [24] .. + | | | / / + Add Concat + | / + Reshape + | + Transpose(perm=0,2,1,3) + | + SimplifiedLayerNormalization + | + Transpose(perm=0,1,3,2) + + Another variant is to an extra Concat node to join two symmetrical subgraphs: + + | | + MatMul MatMul .. [-1] [24] .. + | | | | / / + Add Concat Add Concat + | / | / + Reshape Reshape + | | + Transpose Transpose(perm=0,2,1,3) + | | + SimplifiedLayerNormalization SimplifiedLayerNormalization + | / + Concat + | + Transpose(perm=0,1,3,2) + + Both patterns are used in stable diffusion 3.5 model. """ - if has_concat: - nodes = self.model.match_parent_path(transpose_k, ["Concat", "SimplifiedLayerNormalization"], [0, 1], output_name_to_node=output_name_to_node) + if concat_before_transpose: + nodes = self.model.match_parent_path( + transpose_k, ["Concat", "SimplifiedLayerNormalization"], [0, 1], output_name_to_node=output_name_to_node + ) if nodes: return self.get_num_heads(nodes[1], output_name_to_node) - - nodes = self.model.match_parent_path(transpose_k, ["SimplifiedLayerNormalization"], [0], output_name_to_node=output_name_to_node) - if nodes: - return self.get_num_heads(nodes[0], output_name_to_node) + else: + nodes = self.model.match_parent_path( + transpose_k, ["SimplifiedLayerNormalization"], [0], output_name_to_node=output_name_to_node + ) + if nodes: + return self.get_num_heads(nodes[0], output_name_to_node) return 0 def reshape_to_3d(self, input_name: str, output_name: str) -> str: - # Add a shape to convert 4D BxSxNxH to 3D BxSxD, which is required by MHA operator. + """Add a Reshape node to convert 4D BxSxNxH to 3D BxSxD. + + Args: + input_name (str): input name for the 4D tensor of shape BxSxNxH. + output_name (str): output name for the 3D tensor of shape BxSxD, where D = N * H. + + Returns: + str: the output name + """ + new_dims_name = "bsnh_to_bsd_reshape_dims" new_dims = self.model.get_initializer(new_dims_name) if new_dims is None: @@ -105,10 +133,12 @@ def reshape_to_3d(self, input_name: str, output_name: str) -> str: def adjust_query_from_bnsh_to_bsd_no_concat(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: """ + MultiHeadAttenion requires query in BSD format. This function adjusts query from BNSH to BSD format. + Before: - MatMul .. [-1] [24] .. - | | | / / - Add Concat + MatMul + | + Add Concat | / Reshape | @@ -119,9 +149,9 @@ def adjust_query_from_bnsh_to_bsd_no_concat(self, mul_q: NodeProto, output_name_ Mul After: - MatMul .. [-1] [24] .. - | | | / / - Add Concat + MatMul + | + Add Concat | / Reshape | @@ -151,9 +181,11 @@ def adjust_query_from_bnsh_to_bsd_no_concat(self, mul_q: NodeProto, output_name_ def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: """ + MultiHeadAttenion requires query in BSD format. This function adjusts query from BNSH to BSD format. + Before: - MatMul MatMul .. [-1] [24] .. - | | | | / / + MatMul MatMul + | | Add Concat Add Concat | / | / Reshape Reshape @@ -167,17 +199,17 @@ def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) - Mul After: - MatMul MatMul .. [-1] [24] .. 
- | | | | / / - Add Concat Add Concat - | / | / - Reshape Reshape - | | - SimplifiedLayerNorm SimplifiedLayerNorm - | / - Concat(axis=1) - | - Reshape (shape=[0, 0, -1]) + MatMul MatMul + | | + Add Concat Add Concat + | / | / + Reshape Reshape + | | + SimplifiedLayerNorm SimplifiedLayerNorm + | / + Concat(axis=1) + | + Reshape (shape=[0, 0, -1]) """ path = self.model.match_parent_path( @@ -226,14 +258,238 @@ def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) - return self.reshape_to_3d(new_concat_node.output[0], concat.output[0] + "_BSD") + def update_unsqueeze_axes_1_to_2(self, unsqueeze: NodeProto) -> str: + updated_unsqueeze_output = self.unsqueeze_update_map.get(unsqueeze.name) + if updated_unsqueeze_output is None: + if len(unsqueeze.input) == 1: + new_node = helper.make_node( + "Unsqueeze", + inputs=unsqueeze.input, + outputs=[unsqueeze.output[0] + "_BSNH"], + name=self.model.create_node_name("Unsqueeze"), + axes=[2], + ) + else: + initializer_name = "unsqueeze_axes_2" + if self.model.get_initializer(initializer_name) is None: + unsqueeze_axes_2 = helper.make_tensor( + name=initializer_name, + data_type=TensorProto.INT64, + dims=[1], # Shape of the tensor + vals=[2], # Tensor values + ) + self.model.add_initializer(unsqueeze_axes_2, self.this_graph_name) + + new_node = helper.make_node( + "Unsqueeze", + inputs=[unsqueeze.input[0], initializer_name], + outputs=[unsqueeze.output[0] + "_BSNH"], + name=self.model.create_node_name("Unsqueeze"), + ) + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + updated_unsqueeze_output = new_node.output[0] + self.unsqueeze_update_map[unsqueeze.name] = updated_unsqueeze_output + + return updated_unsqueeze_output + + def update_unsqueeze_axes(self, add: NodeProto, output_name_to_node: Dict[str, NodeProto]) -> bool: + """ + Update axes of Unsqueeze from [1] to [2] in the following pattern: + Unsqueeze Unsqueeze + (axes=[0]) (axes=[0]) + | | + Unsqueeze Unsqueeze + ... (axes=[1]) ... (axes=[1]) + | / | / + Mul Mul + | / + Add + Args: + add (NodeProto): the Add node + output_name_to_node (Dict[str, NodeProto]): mapping from output name to node + + Returns: + bool: True if the pattern is matched and updated successfully, False otherwise. + """ + if len(add.input) != 2: + return False + + # Check axes of Unsqueeze nodes are [0] and [1], and change to [0] and [2] respectively. + nodes_b = self.model.match_parent_path(add, ["Mul", "Unsqueeze", "Unsqueeze"], [1, 1, 0], output_name_to_node) + if nodes_b is None: + return False + + fusion_utils = FusionUtils(self.model) + axes_1 = fusion_utils.get_squeeze_or_unsqueeze_axes(nodes_b[1]) + if axes_1 is None or axes_1 != [1]: + return False + + axes_0 = fusion_utils.get_squeeze_or_unsqueeze_axes(nodes_b[2]) + if axes_0 is None or axes_0 != [0]: + return False + + # Check axes of Unsqueeze nodes are [0] and [1], and change to [0] and [2] respectively. 
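+        # Why the axes move: the factors feeding the Mul nodes get a singleton head dimension inserted by
+        # Unsqueeze at axis 1, which broadcasts against BNSH [B, N, S, H]; once query/key are kept in
+        # BSNH [B, S, N, H], the heads axis sits at position 2, so the singleton must be inserted there instead.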
+ nodes_a = self.model.match_parent_path(add, ["Mul", "Unsqueeze", "Unsqueeze"], [0, 1, 0], output_name_to_node) + if nodes_a is None: + return False + + axes_1 = fusion_utils.get_squeeze_or_unsqueeze_axes(nodes_a[1]) + if axes_1 is None or axes_1 != [1]: + return False + + axes_0 = fusion_utils.get_squeeze_or_unsqueeze_axes(nodes_a[2]) + if axes_0 is None or axes_0 != [0]: + return False + + nodes_a[0].input[1] = self.update_unsqueeze_axes_1_to_2(nodes_a[1]) + nodes_b[0].input[1] = self.update_unsqueeze_axes_1_to_2(nodes_b[1]) + return True + + def adjust_flux_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: + """ + Adjust graph to change query format from BNSH to BSD for Flux model. + Note that the graph pattern is complex, and we only do a shallow match here. + + Before: + | | + Transpose(perm=0,2,1,3) Transpose(perm=0,2,1,3) + | | + SimplifiedLayerNorm SimplifiedLayerNorm + | / + Concat(axis=2) + | + Mul Mul + | / + Add + | + Mul + + After (Transpose nods are removed, and a Reshape is added): + + | | + SimplifiedLayerNorm SimplifiedLayerNorm + | / + Concat(axis=1) + | + Mul Mul + | / + Add + | + Reshape (shape=[0, 0, -1]) + """ + + path = self.model.match_parent_path( + mul_q, + ["Add", "Mul", "Concat", "SimplifiedLayerNormalization", "Transpose"], + [0, 0, 0, 0, 0], + ) + if path is None: + return None + add, _mul_a, concat, sln_a, transpose_a = path + + if len(concat.input) != 2: + return None + + path = self.model.match_parent_path( + concat, + ["SimplifiedLayerNormalization", "Transpose"], + [1, 0], + ) + if path is None: + return None + sln_b, transpose_b = path + + if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]): + return None + + if not FusionUtils.check_node_attribute(transpose_b, "perm", [0, 2, 1, 3]): + return None + + if not FusionUtils.check_node_attribute(concat, "axis", 2): + return None + + # Need adjust axes of Unsqueeze nodes from [1] to [2] so that the tensors to Mul nodes are BSNH instead of BNSH. + if not self.update_unsqueeze_axes(add, output_name_to_node): + return None + + # Update the graph + sln_a.input[0] = transpose_a.input[0] + sln_b.input[0] = transpose_b.input[0] + + new_concat_node = helper.make_node( + "Concat", + inputs=[sln_a.output[0], sln_b.output[0]], + outputs=[concat.output[0] + "_BSNH"], + name=self.model.create_node_name("Concat"), + axis=1, + ) + self.nodes_to_add.append(new_concat_node) + self.node_name_to_graph_name[new_concat_node.name] = self.this_graph_name + self.model.replace_input_of_all_nodes(concat.output[0], new_concat_node.output[0]) + + return self.reshape_to_3d(add.output[0], add.output[0] + "_BSD") + + def adjust_flux_single_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: + """ + Adjust graph to change query format from BNSH to BSD for Flux model. + Note that the graph pattern is complex, and we only do a shallow match here. 
+ + Before: + | + Transpose(perm=0,2,1,3) + | + SimplifiedLayerNorm + | + Mul Mul + | / + Add + | + Mul + + After (Transpose is removed, and a Reshape is added): + + | + SimplifiedLayerNorm + | + Mul Mul + | / + Add + | + Reshape (shape=[0, 0, -1]) + """ + + path = self.model.match_parent_path( + mul_q, + ["Add", "Mul", "SimplifiedLayerNormalization", "Transpose"], + [0, 0, 0, 0], + ) + if path is None: + return None + add, _mul_a, sln_a, transpose_a = path + + if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]): + return None + + # Need adjust axes of Unsqueeze nodes from [1] to [2] so that the tensors to Mul nodes are BSNH instead of BNSH. + if not self.update_unsqueeze_axes(add, output_name_to_node): + return None + + # Update the graph + sln_a.input[0] = transpose_a.input[0] + add.output[0] = add.output[0] + "_BSNH" + + return self.reshape_to_3d(add.output[0], add.output[0] + "_BSD") + def transpose_reshape_bnsh_to_bsd(self, q: str, output_name_to_node) -> Optional[str]: transpose_q = helper.make_node( - "Transpose", - [q], - [q + "_BSNH"], - name=self.model.create_node_name("Transpose", name_prefix="Transpose_BNSH_to_BSNH"), - perm=[0, 2, 1, 3], - ) + "Transpose", + [q], + [q + "_BSNH"], + name=self.model.create_node_name("Transpose", name_prefix="Transpose_BNSH_to_BSNH"), + perm=[0, 2, 1, 3], + ) self.nodes_to_add.append(transpose_q) self.node_name_to_graph_name[transpose_q.name] = self.this_graph_name @@ -367,8 +623,11 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): return # Match patterns for Flux. - num_heads = self.get_num_heads(concat_v, output_name_to_node) if concat_v else \ - self.get_num_heads(matmul_s_v, output_name_to_node, input_index=1) + num_heads = ( + self.get_num_heads(concat_v, output_name_to_node) + if concat_v + else self.get_num_heads(matmul_s_v, output_name_to_node, input_index=1) + ) if num_heads == 0: # Match patterns for Stable Diffusion 3.5. @@ -377,13 +636,21 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): return # Q is in BNSH format, we need to adjust it to BSD format due to limitation of MHA op. + # TODO: MHA op support BNSH format to reduce the effort in fusion. if concat_v is not None: query = self.adjust_query_from_bnsh_to_bsd(mul_q, output_name_to_node) else: query = self.adjust_query_from_bnsh_to_bsd_no_concat(mul_q, output_name_to_node) if query is None: - query = self.transpose_reshape_bnsh_to_bsd(q_bnsh, output_name_to_node) + query = self.adjust_flux_query_from_bnsh_to_bsd(mul_q, output_name_to_node) + if query is None: + query = self.adjust_flux_single_query_from_bnsh_to_bsd(mul_q, output_name_to_node) + if query is None: + # fallback to use Transpose and Add to adjust query from BNSH to BSD + # This is more general approach. + # However, it might be slower if the extra Transpose node cannot be removed by ORT optimizer. 
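+                    # Concretely: for q_bnsh of shape (B, N, S, H), the fallback inserts Transpose(perm=[0, 2, 1, 3])
+                    # to reach (B, S, N, H) and then Reshape(shape=[0, 0, -1]) to reach (B, S, N*H), the same
+                    # BNSH -> BSD change that the pattern-specific helpers above achieve by rewiring existing nodes.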
+ query = self.transpose_reshape_bnsh_to_bsd(q_bnsh, output_name_to_node) new_node = self.create_multihead_attention_node( q=query, diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index dbd9e828198ca..3084b84278994 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -127,6 +127,19 @@ def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes, node_i return parent_can_be_removed + def get_squeeze_or_unsqueeze_axes(self, node: NodeProto) -> Optional[ndarray]: + assert node.op_type in ["Squeeze", "Unsqueeze"] + + # For opset >= 13, axes is an input instead of an attribute. + if len(node.input) > 1: + return self.model.get_constant_value(node.input[1]) + + axes = None + for attr in node.attribute: + if attr.name == "axes": + axes = helper.get_attribute_value(attr) + return axes + @staticmethod def check_node_attribute(node, attribute_name: str, expected_value, default_value=None): """Verify that a node has expected value for an attribute. diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 063a8cd4637ef..f6e49f484db38 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -362,15 +362,23 @@ def run_optimum_ort_pipeline( ): print("Pipeline type", type(pipe)) from optimum.onnxruntime.modeling_diffusion import ORTFluxPipeline + is_flux = isinstance(pipe, ORTFluxPipeline) prompts, negative_prompt = example_prompts() - def get_negative_prompt_kwargs(negative, use_num_images_per_prompt, is_flux): - negative_prompt_kwargs = {"negative_prompt": negative} if use_num_images_per_prompt else {"negative_prompt": [negative] * batch_size} + def get_negative_prompt_kwargs(negative, use_num_images_per_prompt, is_flux) -> dict: # Flux does not support negative prompt - if is_flux: - negative_prompt_kwargs = {} + negative_prompt_kwargs = ( + ( + {"negative_prompt": negative} + if use_num_images_per_prompt + else {"negative_prompt": [negative] * batch_size} + ) + if not is_flux + else {} + ) + return negative_prompt_kwargs def warmup(): @@ -383,16 +391,10 @@ def warmup(): width=width, num_inference_steps=steps, num_images_per_prompt=batch_count, - **extra_kwargs + **extra_kwargs, ) else: - pipe( - prompt=[prompt] * batch_size, - height=height, - width=width, - num_inference_steps=steps, - **extra_kwargs - ) + pipe(prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, **extra_kwargs) # Run warm up, and measure GPU memory of two runs. # The first run has algo search for cuDNN/MIOpen, so it might need more memory. @@ -402,6 +404,9 @@ def warmup(): warmup() extra_kwargs = get_negative_prompt_kwargs(negative_prompt, use_num_images_per_prompt, is_flux) + # Fix the random seed so that we can inspect the output quality easily. 
+ if torch.cuda.is_available(): + extra_kwargs["generator"] = torch.Generator(device="cuda").manual_seed(123) latency_list = [] for i, prompt in enumerate(prompts): @@ -415,15 +420,11 @@ def warmup(): width=width, num_inference_steps=steps, num_images_per_prompt=batch_size, - **extra_kwargs + **extra_kwargs, ).images else: images = pipe( - prompt=[prompt] * batch_size, - height=height, - width=width, - num_inference_steps=steps, - **extra_kwargs + prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, **extra_kwargs ).images inference_end = time.time() latency = inference_end - inference_start diff --git a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py index 9e30130d033e7..80d408e671979 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py +++ b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py @@ -26,7 +26,6 @@ def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0): hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically). """ assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0) - super().__init__(model, num_heads=num_heads, hidden_size=hidden_size) def postprocess(self): @@ -95,8 +94,6 @@ def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): # if (options is None) or options.enable_bias_skip_layer_norm: # # Fuse SkipLayerNormalization and Add Bias before it. # self.fuse_add_bias_skip_layer_norm() - # if progress_bar: - # progress_bar.update(1) self.postprocess() if progress_bar: From 89950d13838868cefe249b4fe154f0a1eb4329b9 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 23 Dec 2024 04:34:20 +0000 Subject: [PATCH 13/26] t5 optimization and mixed precision conversion --- .../stable_diffusion/optimize_pipeline.py | 154 +++++++++++- .../tools/transformers/onnx_model_t5.py | 225 ++++++++++++++---- .../python/tools/transformers/optimizer.py | 2 +- 3 files changed, 333 insertions(+), 48 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index feb3b27611e6c..38ed031658838 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -29,6 +29,7 @@ from fusion_options import FusionOptions from onnx_model_clip import ClipOnnxModel from onnx_model_mmdit import MmditOnnxModel +from onnx_model_t5 import T5OnnxModel from onnx_model_unet import UnetOnnxModel from onnx_model_vae import VaeOnnxModel from optimizer import optimize_by_onnxruntime, optimize_model @@ -122,21 +123,23 @@ def _optimize_sd_pipeline( RuntimeError: input onnx model does not exist RuntimeError: output onnx model path existed """ + is_flux_pipeline = pipeline_type == "flux" model_type_mapping = { "transformer": "mmdit", "unet": "unet", "vae_encoder": "vae", "vae_decoder": "vae", "text_encoder": "clip", - "text_encoder_2": "clip" if pipeline_type != "flux" else "bert", + "text_encoder_2": "t5" if is_flux_pipeline else "clip", + "text_encoder_3": "t5", # t5-v1_1-xxl is used in SD 3.x text_encoder_3 and Flux text_encoder_2. "safety_checker": "unet", - "text_encoder_3": "clip", # t5? 
} model_type_class_mapping = { "unet": UnetOnnxModel, "vae": VaeOnnxModel, "clip": ClipOnnxModel, + "t5": T5OnnxModel, "mmdit": MmditOnnxModel, } @@ -151,6 +154,134 @@ def _optimize_sd_pipeline( "transformer": [], } + # The node block list is generated by running the fp32 model and duming node inputs and outputs statistics. + # Nodes with any input or output of float or double data type, but value ouf of range of float16 are candidates. + # python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp32_opt + # export ORT_DEBUG_NODE_IO_DUMP_STATISTICS_DATA=1 + # export ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA=1 + # export ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA=1 + # python benchmark.py --height 1024 --width 1024 --steps 4 -b 1 -v Flux.1S -p flux1_schnell_onnx/fp32_opt -e optimum >dump.txt 2>err.txt + # Warning: The node name might change in different export settings. We used python 3.10 and the following packages: + # diffusers==0.31.0 transformers==4.46.3 optimum==1.24.0.dev0 torch==2.5.1 onnx==1.17.0 protobuf==5.29.2 + flux_node_block_list = { + "text_encoder_2": [ + "/encoder/block.10/layer.1/DenseReluDense/wo/MatMul", + "SkipLayerNorm_20", + "SkipLayerNorm_21", + "SkipLayerNorm_22", + "SkipLayerNorm_23", + "SkipLayerNorm_24", + "SkipLayerNorm_25", + "SkipLayerNorm_26", + "SkipLayerNorm_27", + "SkipLayerNorm_28", + "SkipLayerNorm_29", + "SkipLayerNorm_30", + "SkipLayerNorm_31", + "SkipLayerNorm_32", + "SkipLayerNorm_33", + "SkipLayerNorm_34", + "SkipLayerNorm_35", + "SkipLayerNorm_36", + "SkipLayerNorm_37", + "SkipLayerNorm_38", + "SkipLayerNorm_39", + "SkipLayerNorm_40", + "SkipLayerNorm_41", + "SkipLayerNorm_42", + "SkipLayerNorm_43", + "SkipLayerNorm_44", + "SkipLayerNorm_45", + "SkipLayerNorm_46", + ], + "vae_decoder": [ + "/decoder/mid_block/attentions.0/MatMul", + "/decoder/mid_block/attentions.0/Softmax", + ], + "transformer": [ + "/transformer_blocks.18/Add_7", + "/Concat_1", + "LayerNorm_76", + "/single_transformer_blocks.0/Add", + "LayerNorm_77", + "/single_transformer_blocks.1/Add", + "LayerNorm_78", + "/single_transformer_blocks.2/Add", + "LayerNorm_79", + "/single_transformer_blocks.3/Add", + "LayerNorm_80", + "/single_transformer_blocks.4/Add", + "LayerNorm_81", + "/single_transformer_blocks.5/Add", + "LayerNorm_82", + "/single_transformer_blocks.6/Add", + "LayerNorm_83", + "/single_transformer_blocks.7/Add", + "LayerNorm_84", + "/single_transformer_blocks.8/Add", + "LayerNorm_85", + "/single_transformer_blocks.9/Add", + "LayerNorm_86", + "/single_transformer_blocks.10/Add", + "LayerNorm_87", + "/single_transformer_blocks.11/Add", + "LayerNorm_88", + "/single_transformer_blocks.12/Add", + "LayerNorm_89", + "/single_transformer_blocks.13/Add", + "LayerNorm_90", + "/single_transformer_blocks.14/Add", + "LayerNorm_91", + "/single_transformer_blocks.15/Add", + "LayerNorm_92", + "/single_transformer_blocks.16/Add", + "LayerNorm_93", + "/single_transformer_blocks.17/Add", + "LayerNorm_94", + "/single_transformer_blocks.18/Add", + "LayerNorm_95", + "/single_transformer_blocks.19/Add", + "LayerNorm_96", + "/single_transformer_blocks.20/Add", + "LayerNorm_97", + "/single_transformer_blocks.21/Add", + "LayerNorm_98", + "/single_transformer_blocks.22/Add", + "LayerNorm_99", + "/single_transformer_blocks.23/Add", + "LayerNorm_100", + "/single_transformer_blocks.24/Add", + "LayerNorm_101", + "/single_transformer_blocks.25/Add", + "LayerNorm_102", + "/single_transformer_blocks.26/Add", + "LayerNorm_103", + "/single_transformer_blocks.27/Add", + "LayerNorm_104", + 
"/single_transformer_blocks.28/Add", + "LayerNorm_105", + "/single_transformer_blocks.29/Add", + "LayerNorm_106", + "/single_transformer_blocks.30/Add", + "LayerNorm_107", + "/single_transformer_blocks.31/Add", + "LayerNorm_108", + "/single_transformer_blocks.32/Add", + "LayerNorm_109", + "/single_transformer_blocks.33/Add", + "LayerNorm_110", + "/single_transformer_blocks.34/Add", + "LayerNorm_111", + "/single_transformer_blocks.35/Add", + "LayerNorm_112", + "/single_transformer_blocks.36/Add", + "LayerNorm_113", + "/single_transformer_blocks.37/Add", + "/Shape", + "/Slice", + ], + } + if force_fp32_ops: for fp32_operator in force_fp32_ops: parts = fp32_operator.split(":") @@ -171,6 +302,10 @@ def _optimize_sd_pipeline( # Prepare output directory optimized_model_path = target_dir / name / "model.onnx" + if os.path.exists(optimized_model_path): + if not args.overwrite: + logger.warning("Skipped optimization since the target file existed: %s", optimized_model_path) + continue output_dir = optimized_model_path.parent output_dir.mkdir(parents=True, exist_ok=True) @@ -202,8 +337,17 @@ def _optimize_sd_pipeline( ) if float16: + if is_flux_pipeline and name in flux_node_block_list: + m.convert_float_to_float16( + keep_io_types=False, + node_block_list=flux_node_block_list[name], + ) + elif model_type == "t5": + assert isinstance(m, T5OnnxModel) + # TODO: follow t5 model in Flux model to use node block list instead. + m.convert_mixed_precision(force_dense_output_fp32=True, keep_io_types=True) # For SD-XL, use FP16 in VAE decoder will cause NaN and black image so we keep it in FP32. - if is_sdxl(source_dir) and name == "vae_decoder": + elif pipeline_type in ["sdxl"] and name in ["vae_decoder"]: logger.info("Skip converting %s to float16 to avoid NaN", name) else: logger.info("Convert %s to float16 ...", name) @@ -256,6 +400,8 @@ def _copy_extra_directory(source_dir: Path, target_dir: Path, model_list: List[s continue target_path = target_dir / name + if target_path.exists(): + shutil.rmtree(target_path) shutil.copytree(source_path, target_path) logger.info("%s => %s", source_path, target_path) @@ -291,8 +437,6 @@ def optimize_stable_diffusion_pipeline( if os.path.exists(output_dir): if overwrite: shutil.rmtree(output_dir, ignore_errors=True) - else: - raise RuntimeError(f"output directory existed:{output_dir}. Add --overwrite to empty the directory.") source_dir = Path(input_dir) target_dir = Path(output_dir) diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py index 9cc4878e8022d..5619be870107d 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_t5.py +++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py @@ -75,9 +75,10 @@ def create_attention_node( k_weight = self.model.get_initializer(k_matmul.input[1]) v_weight = self.model.get_initializer(v_matmul.input[1]) - if q_weight is None: + if q_weight is None or k_weight is None or v_weight is None: + matmul = q_matmul if q_weight is None else k_matmul if k_weight is None else v_matmul print( - f"{q_matmul.input[1]} is not an initializer. " + f"{matmul.input[1]} is not an initializer. 
" "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion" ) return None @@ -222,9 +223,7 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no return qkv_nodes = self.model.match_parent_path( - normalize_node, - ["MatMul", "Reshape", "Transpose", "MatMul"], - [1, 0, 0, 0], + normalize_node, ["MatMul", "Reshape", "Transpose", "MatMul"], [1, 0, 0, 0], output_name_to_node ) if qkv_nodes is None: return @@ -235,6 +234,7 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no reshape_qkv, ["Concat", "Unsqueeze", "Gather", "Shape"], [1, 0, 0, 0], + output_name_to_node, ) if qkv_shape_nodes is None: return @@ -244,6 +244,7 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no matmul_qkv, ["Transpose", "Reshape", "MatMul"], [1, 0, 0], + output_name_to_node, ) if v_nodes is None: return @@ -254,28 +255,64 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0], + output_name_to_node, ) if qk_nodes is None: return _, add_qk, matmul_qk = qk_nodes - mask_index = None mask_nodes = self.model.match_parent_path( add_qk, ["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [1, 1, 0, 1, 0, 0], + output_name_to_node, ) + + is_pattern_for_one_graph_input = mask_nodes is None if mask_nodes is None: - return - mul_node = mask_nodes[1] - if mask_nodes[1].op_type != "Mul": - return + # Pattern for SD3 and Flux. + mask_nodes = self.model.match_parent_path( + add_qk, + ["Add", "Slice", "Mul", "Sub", "Unsqueeze", "Unsqueeze"], + [1, 1, 0, 0, 1, 0], + output_name_to_node, + ) + if mask_nodes is None: + return + mul_node = mask_nodes[2] + else: + mul_node = mask_nodes[1] _, mul_val = self.model.get_constant_input(mul_node) - if mul_val != -10000: - self.mask_filter_value = mul_val + if mul_val is None: + return - mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) + if mul_val != -10000: + self.mask_filter_value = float(mul_val) + + # If the mask is derived from shape of input_ids, it means there is no padding mask. + mask_nodes_2 = self.model.match_parent_path( + mask_nodes[-1], + ["ConstantOfShape", "Concat", "Unsqueeze", "Gather", "Shape"], + [0, 0, 0, 0, 0], + output_name_to_node, + ) + mask_nodes_3 = self.model.match_parent_path( + mask_nodes[-1], + ["ConstantOfShape", "Concat", "Unsqueeze", "Gather", "Shape"], + [0, 0, 1, 0, 0], + output_name_to_node, + ) + if ( + mask_nodes_2 is not None + and any(input.name == mask_nodes_2[-1].input[0] for input in self.model.graph().input) + and mask_nodes_3 is not None + and mask_nodes_2[-1].input[0] == mask_nodes_3[-1].input[0] + and len(mask_nodes_2[1].input) == 2 + ): + mask_index = "" + else: + mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) res_pos_bias = None rpb_nodes = self.model.match_parent_path( @@ -283,10 +320,17 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no ["Add", "RelativePositionBias"], [1, 0], ) + if rpb_nodes is None and is_pattern_for_one_graph_input: + # Pattern for SD3 and Flux. 
+ rpb_nodes = self.model.match_parent_path( + add_qk, + ["Add", "Slice", "RelativePositionBias"], + [1, 0, 0], + ) if rpb_nodes is None: return - rpb_add_node = rpb_nodes[0] - res_pos_bias = rpb_add_node.input[0] + + res_pos_bias = rpb_nodes[-1].output[0] k_nodes = self.model.match_parent_path( matmul_qk, @@ -332,13 +376,7 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name - self.nodes_to_remove.extend(qkv_nodes[1:]) - self.nodes_to_remove.extend(qk_nodes) - self.nodes_to_remove.extend(k_nodes[:-1]) - if v_nodes is not None: - self.nodes_to_remove.extend(v_nodes[:-1]) - self.nodes_to_remove.extend(q_nodes[:-1]) - + self.nodes_to_remove.append(reshape_qkv) self.prune_graph = True def fuse_t5_decoder(self, normalize_node, input_name_to_nodes, output_name_to_node): @@ -591,12 +629,7 @@ def fuse_t5_decoder(self, normalize_node, input_name_to_nodes, output_name_to_no self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name - self.nodes_to_remove.extend(qkv_nodes[1:]) - self.nodes_to_remove.extend(qk_nodes) - self.nodes_to_remove.extend(k_nodes[:-1]) - if v_nodes is not None: - self.nodes_to_remove.extend(v_nodes[:-1]) - self.nodes_to_remove.extend(q_nodes[:-1]) + self.nodes_to_remove.append(reshape_qkv) self.prune_graph = True @@ -605,7 +638,6 @@ class FusionRelativePositionBiasBlock(Fusion): def __init__(self, model: OnnxModel, max_distance: int): super().__init__(model, "RelativePositionBias", ["Add", "Slice"]) self.max_distance = max_distance - # bidirectional=(not self.is_decoder) self.is_bidirectional = False def fuse(self, node, input_name_to_nodes, output_name_to_node): @@ -615,11 +647,11 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): return compute_bias_nodes = self.model.match_parent_path( - node, ["Unsqueeze", "Transpose", "Gather", "Where"], [0, 0, 0, 1] + node, ["Unsqueeze", "Transpose", "Gather", "Where"], [0, 0, 0, 1], output_name_to_node ) if compute_bias_nodes is None: compute_bias_nodes = self.model.match_parent_path( - node, ["Unsqueeze", "Transpose", "Gather", "Add", "Where"], [0, 0, 0, 1, 1] + node, ["Unsqueeze", "Transpose", "Gather", "Add", "Where"], [0, 0, 0, 1, 1], output_name_to_node ) if compute_bias_nodes is None: return @@ -632,20 +664,29 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): where, ["Min", "ConstantOfShape", "Shape", "Add", "Cast", "Mul", "Div", "Log", "Div"], [2, 1, 0, 0, 0, 0, 0, 0, 0], + output_name_to_node, ) if compute_buckets_nodes is None: return + # It is possible to deduce max_distance from a Div node: + # The value of self.model.get_constant_value(compute_buckets_nodes[-3].input[1]) is close to + # math.log(max_distance / (relative_attention_num_buckets // (4 if is_bidirectional else 2))) + # See https://github.com/huggingface/transformers/blob/608e163b527eaee41e650ffb9eb4c422d2679902/src/transformers/models/t5/modeling_t5.py#L397. + # Most t5 models use max_distance=128, so we hardcode it unitl we see a model with different value. + # TODO: maybe add a sanity check here. 
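+        # For example, with the common T5 defaults relative_attention_num_buckets=32 and max_distance=128:
+        #   encoder (bidirectional):  math.log(128 / (32 // 4)) = math.log(16.0), roughly 2.77
+        #   decoder (unidirectional): math.log(128 / (32 // 2)) = math.log(8.0), roughly 2.08
+        # A sanity check could compare the Div constant against the expected value before fusing.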
+ div = compute_buckets_nodes[-1] range_nodes = self.model.match_parent_path( div, ["Cast", "Neg", "Min", "ConstantOfShape", "Shape", "Sub", "Unsqueeze", "Range"], [0, 0, 0, 1, 0, 0, 0, 0], + output_name_to_node, ) if range_nodes is None: range_nodes = self.model.match_parent_path( - div, ["Cast", "Abs", "Sub", "Unsqueeze", "Range"], [0, 0, 0, 0, 0] + div, ["Cast", "Abs", "Sub", "Unsqueeze", "Range"], [0, 0, 0, 0, 0], output_name_to_node ) self.is_bidirectional = True if range_nodes is None: @@ -653,17 +694,20 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): range_node = range_nodes[-1] - self.nodes_to_remove.extend(compute_bias_nodes) - self.nodes_to_remove.extend(compute_buckets_nodes) - self.nodes_to_remove.extend(range_nodes) + self.nodes_to_remove.append(unsqueeze) + self.prune_graph = True - node_name_prefix = "encoder" if self.is_bidirectional else "decoder" + node_name = self.model.create_node_name( + "RelativePositionBias", name_prefix="RelPosBias_" + ("encoder" if self.is_bidirectional else "decoder") + ) table_weight_i = self.model.get_initializer(gather.input[0]) + if table_weight_i is None: + return table_weight = NumpyHelper.to_array(table_weight_i) table_weight_t = np.transpose(table_weight) bias_table = helper.make_tensor( - name=self.model.create_node_name("bias_table_weight", name_prefix=node_name_prefix), + name=node_name + "_bias_table_weight", data_type=TensorProto.FLOAT, dims=[np.shape(table_weight)[0], np.shape(table_weight)[1]], vals=table_weight_t.tobytes(), @@ -677,7 +721,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): "RelativePositionBias", inputs=inputs, outputs=outputs, - name=self.model.create_node_name("RelativePositionBias", name_prefix=node_name_prefix), + name=node_name, ) rpb_node.domain = "com.microsoft" rpb_node.attribute.extend([helper.make_attribute("max_distance", self.max_distance)]) @@ -688,14 +732,19 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): class T5OnnxModel(BertOnnxModel): - def __init__(self, model, num_heads, hidden_size): + def __init__(self, model, num_heads: int = 0, hidden_size: int = 0): super().__init__(model, num_heads, hidden_size) self.attention_mask = AttentionMask(self) + + # When the model has only one input (input_ids), there is no padding mask. + if len(self.model.graph.input) == 1: + from fusion_options import AttentionMaskFormat + + self.attention_mask.mask_format = AttentionMaskFormat.NoMask + self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask) self.layer_norm_fusion = FusionSimplifiedLayerNormalization(self) self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self) - # TODO: consider retrieve max_distance from model. - # math.log(max_distance / (num_buckets // 2)) self.rpb_fusion = FusionRelativePositionBiasBlock(self, 128) def fuse_attention(self): @@ -704,9 +753,65 @@ def fuse_attention(self): def fuse_layer_norm(self): self.layer_norm_fusion.apply() - def fuse_skip_layer_norm(self): + def fuse_skip_layer_norm(self, shape_infer=True): self.skip_layer_norm_fusion.apply() + def adjust_rel_pos_bis_length_input(self): + # For T5 encoder, it uses complex logic to compute the query and key length when there is only one graph input (input_ids) + # We can directly get the length from shape (the 2nd dimension) of input_ids. 
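+        # input_ids has shape [batch_size, sequence_length], so Shape followed by Gather(indices=[1]) yields the
+        # sequence length, which is wired into both the query_length and key_length inputs of the fused
+        # RelativePositionBias node below.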
+ for node in self.nodes(): + if node.op_type == "RelativePositionBias": + nodes = self.match_parent_path( + node, + [ + "Gather", + "Shape", + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "SimplifiedLayerNormalization", + "Gather", + ], + [1, 0, 0, 0, 1, 0, 0, 0, 0, 0], + ) + # TODO: more validation on node attributes + if nodes is not None: + graph_input_names = [input.name for input in self.model.graph.input] + if nodes[-1].input[1] in graph_input_names: + node_name = self.create_node_name("Shape", name_prefix="Added_Shape_") + shape_node = helper.make_node( + "Shape", + inputs=[nodes[-1].input[1]], + outputs=[node_name + "_Output"], + name=node_name, + ) + + indices_1 = helper.make_tensor( + name="Constant_Index_1", + data_type=TensorProto.INT64, + dims=[1], # Shape of the tensor + vals=[1], # Tensor values + ) + self.add_initializer(indices_1) + + gather = helper.make_node( + "Gather", + inputs=[node_name + "_Output", "Constant_Index_1"], + outputs=[node_name + "_Output_Gather_1"], + name=node_name, + axis=0, + ) + + self.add_node(shape_node) + self.add_node(gather) + node.input[1] = node_name + "_Output_Gather_1" + node.input[2] = node_name + "_Output_Gather_1" + + break + # Remove get_extended_attention_mask() since it generates all zeros. def remove_extended_mask_decoder_init(self): nodes_to_remove = [] @@ -787,5 +892,41 @@ def postprocess(self): # remove get_extended_attention_mask() since it generates all zeros. self.remove_extended_mask_decoder_init() self.remove_extended_mask_decoder() + self.adjust_rel_pos_bis_length_input() self.prune_graph() + + def convert_mixed_precision(self, force_dense_output_fp32: bool = True, keep_io_types=True): + """ + Convert model to mixed precision (float32 and float16). + This shall be done after the model is fully optimized and pruned. + + Args: + force_dense_output_fp32: force the output MatMul in MatMul-Rel-MatMul to float32. + keep_io_types: keep graph input/output data type. + """ + + node_block_list = [] + + output_name_to_node_map = self.output_name_to_node() + # See https://github.com/huggingface/transformers/issues/20287#issuecomment-1342219429 + if force_dense_output_fp32: + for node in self.nodes(): + if node.op_type == "SkipSimplifiedLayerNormalization": + parent, i = self.match_first_parent(node, "MatMul", output_name_to_node=output_name_to_node_map) + if parent: + if parent.name: + node_block_list.append(parent.name) + else: + logger.warning(f"Node has no name. Its first output is {parent.output[0]}") + + parameters = { + "keep_io_types": keep_io_types, + "op_block_list": ["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization"], + "node_block_list": node_block_list, + } + + logger.info(f"mixed precision parameters: {parameters}") + self.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters) + + return parameters diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 33737a7d34998..a83c54e345d7d 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -239,7 +239,7 @@ def optimize_by_fusion( Returns: object of an optimizer class. 
""" - if model_type not in ["bert", "swin", "unet", "vae", "clip", "sam2", "mmdit"] and ( + if model_type not in ["bert", "t5", "swin", "unet", "vae", "clip", "sam2", "mmdit"] and ( num_heads == 0 or hidden_size == 0 ): logger.warning(f"Please specify parameters of num_heads and hidden_size for model_type {model_type}") From c8691511b0066dc17b4796ddf5568a12c0c52046 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 23 Dec 2024 07:09:54 +0000 Subject: [PATCH 14/26] fix node name --- onnxruntime/python/tools/transformers/onnx_model_t5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py index 5619be870107d..ae62ee10c56a0 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_t5.py +++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py @@ -801,7 +801,7 @@ def adjust_rel_pos_bis_length_input(self): "Gather", inputs=[node_name + "_Output", "Constant_Index_1"], outputs=[node_name + "_Output_Gather_1"], - name=node_name, + name=self.create_node_name("Gather", name_prefix="Added_Gather_"), axis=0, ) From 84b1a51515011c8584a81ffe3e24b5955e77193a Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 24 Dec 2024 19:10:04 +0000 Subject: [PATCH 15/26] Add option to use bfloat16 --- .../python/tools/transformers/float16.py | 66 +++++++++++++++---- .../models/stable_diffusion/README.md | 22 ++++--- .../models/stable_diffusion/benchmark.py | 4 +- .../stable_diffusion/optimize_pipeline.py | 39 ++++++++--- .../tools/transformers/onnx_model_t5.py | 35 ---------- 5 files changed, 100 insertions(+), 66 deletions(-) diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index 2398bb9d6031b..874020640ef62 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -143,6 +143,8 @@ def make_value_info_from_tensor(tensor): "Upsample", ] +# Some operators do not support float16 in CUDA. This is not a full list, just some common operators in transformers. +BF16_OP_BLACK_LIST = ["SkipSimplifiedLayerNormalization", "Attention", "MultiHeadAttention"] # Some operators has data type fixed as float for some inputs. Key is op_type, value is list of input indices # Note that DirectML allows float16 gamma and beta in GroupNorm. Use force_fp16_inputs parameter could overwrite this. 
@@ -154,14 +156,19 @@ class InitializerTracker: def __init__(self, initializer: TensorProto): self.initializer = initializer + self.bf16_nodes = [] self.fp32_nodes = [] self.fp16_nodes = [] - def add_node(self, node: NodeProto, is_node_blocked): - if is_node_blocked: + def add_node(self, node: NodeProto, dtype: int): + if dtype == TensorProto.FLOAT: self.fp32_nodes.append(node) - else: + elif dtype == TensorProto.BFLOAT16: + self.bf16_nodes.append(node) + elif dtype == TensorProto.FLOAT16: self.fp16_nodes.append(node) + else: + raise ValueError("Invalid dtype") def convert_float_to_float16( @@ -333,11 +340,19 @@ def convert_float_to_float16( for i, input_name in enumerate(n.input): if input_name in fp32_initializers: # For Resize/GroupNorm, only the first input can be float16 - use_fp32_weight = is_node_blocked or ( - i in ALWAYS_FLOAT_INPUTS.get(n.op_type, []) - and i not in force_fp16_inputs_dict.get(n.op_type, []) - ) - fp32_initializers[input_name].add_node(n, use_fp32_weight) + if i in ALWAYS_FLOAT_INPUTS.get(n.op_type, []) and i not in force_fp16_inputs_dict.get( + n.op_type, [] + ): + dtype = TensorProto.FLOAT + elif is_node_blocked: + dtype = ( + TensorProto.BFLOAT16 + if (use_bfloat16_as_blocked_nodes_dtype and n.op_type not in BF16_OP_BLACK_LIST) + else TensorProto.FLOAT + ) + else: + dtype = TensorProto.FLOAT16 + fp32_initializers[input_name].add_node(n, dtype) if is_node_blocked: node_list.append(n) @@ -404,15 +419,21 @@ def convert_float_to_float16( queue = next_level + initializers_to_be_casted_to_bf16: Dict[str, TensorProto] = {} for value in fp32_initializers.values(): # By default, to avoid precision loss, do not convert an initializer to fp16 when it is used only by fp32 nodes. if force_fp16_initializers or value.fp16_nodes: value.initializer = convert_tensor_float_to_float16(value.initializer, min_positive_val, max_finite_val) value_info_list.append(make_value_info_from_tensor(value.initializer)) - if value.fp32_nodes and not force_fp16_initializers: + if (value.fp32_nodes or value.bf16_nodes) and not force_fp16_initializers: logger.info( - f"initializer is used by both fp32 and fp16 nodes. Consider add these nodes to block list:{value.fp16_nodes}" + f"initializer is used by both fp32/bf16 and fp16 nodes. Consider add these nodes to block list:{value.fp16_nodes}" ) + elif value.bf16_nodes: + # If float initializer is only used by bfloat16 nodes, need to convert it to bfloat16. + # However, numpy does not support bfloat16, so we will add a Cast node to conver it later. + initializers_to_be_casted_to_bf16[value.initializer.name] = value.initializer + continue # Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs. 
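Downstream, optimize_pipeline.py drives this conversion once per sub-model. A minimal usage sketch, with an illustrative model path and node names borrowed from the Flux block list (not a verbatim excerpt of the script):

```python
import onnx
from onnx_model import OnnxModel  # lives in onnxruntime/python/tools/transformers

m = OnnxModel(onnx.load("flux1_schnell_onnx/fp32_opt/transformer/model.onnx"))
m.convert_float_to_float16(
    keep_io_types=False,
    node_block_list=["/transformer_blocks.18/Add_7", "/Concat_1"],  # nodes excluded from float16
    use_bfloat16_as_blocked_nodes_dtype=True,  # blocked nodes fall back to bfloat16 where the op allows it
)
onnx.save_model(m.model, "model_fp16.onnx", save_as_external_data=True)
```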
for node in mixed_float_type_node_list: @@ -435,14 +456,16 @@ def convert_float_to_float16( node.input[i] = output_name break - accuracy_type = TensorProto.BFLOAT16 if use_bfloat16_as_blocked_nodes_dtype else TensorProto.FLOAT # process the nodes in block list that doesn't support tensor(float16) for node in node_list: # if input's name is in the value_info_list meaning input is tensor(float16) type, - # insert a float16 to float Cast node before the node, + # insert a float16 to target type (float or bfloat16) Cast node before the node, # change current node's input name and create new value_info for the new name + use_bf16 = use_bfloat16_as_blocked_nodes_dtype and node.op_type not in BF16_OP_BLACK_LIST + accuracy_type = TensorProto.BFLOAT16 if use_bf16 else TensorProto.FLOAT for i in range(len(node.input)): input_name = node.input[i] + is_input_converted = False for value_info in value_info_list: if input_name == value_info.name: # create new value_info for current node's new input name @@ -457,9 +480,24 @@ def convert_float_to_float16( model.graph.node.extend(new_node) # change current node's input name node.input[i] = output_name + is_input_converted = True break - # if output's name is in the value_info_list meaning output is tensor(float16) type, insert a float to - # float16 Cast node after the node, change current node's output name and create new value_info for the new name + + # For bfloat16 nodes, we need to convert float initializers to bfloat16. + if (not is_input_converted) and use_bf16 and (input_name in initializers_to_be_casted_to_bf16): + output_name = node.name + "_input_cast_" + str(i) + value_info = helper.make_tensor_value_info( + name=output_name, elem_type=accuracy_type, shape=initializers_to_be_casted_to_bf16[input_name].dims + ) + new_value_info = model.graph.value_info.add() + new_value_info.CopyFrom(value_info) + node_name = node.name + "_input_cast" + str(i) + new_node = [helper.make_node("Cast", [input_name], [output_name], to=accuracy_type, name=node_name)] + model.graph.node.extend(new_node) + node.input[i] = output_name + + # if output's name is in the value_info_list meaning output is tensor(float16) type, insert a Cast (to float16) + # node after it, change current node's output name and create new value_info for the new name. for i in range(len(node.output)): output = node.output[i] for value_info in value_info_list: diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 7cd5f273c556e..20451d768133f 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -214,17 +214,19 @@ optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task #### Stable Diffusion 3.x and Flux 1.0 -Stable Diffusion 3.x and Flux 1.0 requires transformers >= 4.45, and optimum > 1.23.3: +Stable Diffusion 3.x and Flux 1.0 requires transformers >= 4.45, and optimum > 1.23.3. +The default opset version is 12 for T5. To support bfloat16, please set `--opset` verison explicitly like below example. + ``` git clone https://github.com/huggingface/optimum cd optimum pip install -e . 
-optimum-cli export onnx --model stabilityai/stable-diffusion-3-medium-diffusers ./sd3_onnx/fp32 -optimum-cli export onnx --model stabilityai/stable-diffusion-3.5-medium ./sd3.5_medium_onnx/fp32 -optimum-cli export onnx --model stabilityai/stable-diffusion-3.5-large ./sd3.5_large_onnx/fp32 -optimum-cli export onnx --model black-forest-labs/FLUX.1-schnell ./flux1_schnell_onnx/fp32 -optimum-cli export onnx --model black-forest-labs/FLUX.1-dev ./flux1_dev_onnx/fp32 +optimum-cli export onnx --model stabilityai/stable-diffusion-3-medium-diffusers ./sd3_onnx/fp32 --opset 15 +optimum-cli export onnx --model stabilityai/stable-diffusion-3.5-medium ./sd3.5_medium_onnx/fp32 --opset 15 +optimum-cli export onnx --model stabilityai/stable-diffusion-3.5-large ./sd3.5_large_onnx/fp32 --opset 15 +optimum-cli export onnx --model black-forest-labs/FLUX.1-schnell ./flux1_schnell_onnx/fp32 --opset 15 +optimum-cli export onnx --model black-forest-labs/FLUX.1-dev ./flux1_dev_onnx/fp32 --opset 15 ``` ### Optimize ONNX Pipeline @@ -242,9 +244,12 @@ cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion python optimize_pipeline.py -i ./sdxl_onnx/fp32 -o ./sdxl_onnx/fp16 --float16 python optimize_pipeline.py -i ./sd3_onnx/fp32 -o ./sd3_onnx/fp16 --float16 python optimize_pipeline.py -i ./sd3.5_medium_onnx/fp32 -o ./sd3.5_medium_onnx/fp16 --float16 -python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp16 --float16 -python optimize_pipeline.py -i ./flux1_dev_onnx/fp32 -o ./flux1_dev_onnx/fp16 --float16 +python optimize_pipeline.py -i ./sd3.5_large_onnx/fp32 -o ./sd3.5_large_onnx/fp16 --float16 +python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp16 --float16 --bfloat16 +python optimize_pipeline.py -i ./flux1_dev_onnx/fp32 -o ./flux1_dev_onnx/fp16 --float16 --bfloat16 ``` +When converting model to float16, some nodes has overflow risk and we can force those nodes to run in either float32 or bfloat16. +Option `--bfloat16` enables the later. If an operator does not support bfloat16, it will fallback to float32. For SDXL model, it is recommended to use a machine with 48 GB or more memory to optimize. @@ -288,6 +293,7 @@ Example of benchmark with optimum using CUDA provider on stable diffusion 3.5 me ``` python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v 3.5M -p sd3.5_medium_onnx/fp32 python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v 3.5M -p sd3.5_medium_onnx/fp16 +python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v 3.5L -p sd3.5_large_onnx/fp16 python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v Flux.1S -p flux1_schnell_onnx/fp16 python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v Flux.1D -p flux1_dev_onnx/fp16 ``` diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index f6e49f484db38..3a3fd9d7ce8d8 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -9,6 +9,7 @@ import statistics import sys import time +from pathlib import Path import __init__ # noqa: F401. 
Walk-around to run this script directly import coloredlogs @@ -473,7 +474,8 @@ def run_optimum_ort( load_end = time.time() print(f"Model loading took {load_end - load_start} seconds") - image_filename_prefix = get_image_filename_prefix("optimum", model_name, batch_size, disable_safety_checker) + full_model_name = model_name + "_" + Path(directory).name if directory else model_name + image_filename_prefix = get_image_filename_prefix("optimum", full_model_name, batch_size, disable_safety_checker) result = run_optimum_ort_pipeline( pipe, batch_size, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 38ed031658838..6ee23c64fbec8 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -104,6 +104,7 @@ def _optimize_sd_pipeline( model_list: List[str], use_external_data_format: Optional[bool], float16: bool, + bfloat16: bool, force_fp32_ops: List[str], enable_runtime_optimization: bool, args, @@ -116,6 +117,7 @@ def _optimize_sd_pipeline( model_list (List[str]): list of directory names with onnx model. use_external_data_format (Optional[bool]): use external data format. float16 (bool): use half precision + bfloat16 (bool): use bfloat16 as fallback if float16 is also provided. force_fp32_ops(List[str]): operators that are forced to run in float32. enable_runtime_optimization(bool): run graph optimization using Onnx Runtime. @@ -160,7 +162,7 @@ def _optimize_sd_pipeline( # export ORT_DEBUG_NODE_IO_DUMP_STATISTICS_DATA=1 # export ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA=1 # export ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA=1 - # python benchmark.py --height 1024 --width 1024 --steps 4 -b 1 -v Flux.1S -p flux1_schnell_onnx/fp32_opt -e optimum >dump.txt 2>err.txt + # python benchmark.py --height 1024 --width 1024 --steps 4 -b 1 -v Flux.1S -p flux1_schnell_onnx/fp32_opt -e optimum >stdout.txt 2>stderr.txt # Warning: The node name might change in different export settings. We used python 3.10 and the following packages: # diffusers==0.31.0 transformers==4.46.3 optimum==1.24.0.dev0 torch==2.5.1 onnx==1.17.0 protobuf==5.29.2 flux_node_block_list = { @@ -199,6 +201,7 @@ def _optimize_sd_pipeline( "/decoder/mid_block/attentions.0/Softmax", ], "transformer": [ + "/transformer_blocks.18/Mul_5", "/transformer_blocks.18/Add_7", "/Concat_1", "LayerNorm_76", @@ -282,6 +285,8 @@ def _optimize_sd_pipeline( ], } + sd3_node_block_list = {"text_encoder_3": flux_node_block_list["text_encoder_2"]} + if force_fp32_ops: for fp32_operator in force_fp32_ops: parts = fp32_operator.split(":") @@ -337,15 +342,23 @@ def _optimize_sd_pipeline( ) if float16: - if is_flux_pipeline and name in flux_node_block_list: + model_node_block_list = ( + flux_node_block_list if is_flux_pipeline else sd3_node_block_list if pipeline_type == "sd3" else {} + ) + if name in model_node_block_list: + # Opset 12 does not support bfloat16. + # By default, optimum exports T5 model with opset 12. So we need to check the opset version. + use_bfloat16 = bfloat16 + for opset in m.model.opset_import: + if opset.domain in ["", "ai.onnx"] and opset.version < 13: + logger.warning("onnx model requires opset 13 or higher to use bfloat16. 
Fall back to float32.") + use_bfloat16 = False + m.convert_float_to_float16( keep_io_types=False, - node_block_list=flux_node_block_list[name], + node_block_list=model_node_block_list[name], + use_bfloat16_as_blocked_nodes_dtype=use_bfloat16, ) - elif model_type == "t5": - assert isinstance(m, T5OnnxModel) - # TODO: follow t5 model in Flux model to use node block list instead. - m.convert_mixed_precision(force_dense_output_fp32=True, keep_io_types=True) # For SD-XL, use FP16 in VAE decoder will cause NaN and black image so we keep it in FP32. elif pipeline_type in ["sdxl"] and name in ["vae_decoder"]: logger.info("Skip converting %s to float16 to avoid NaN", name) @@ -454,6 +467,7 @@ def optimize_stable_diffusion_pipeline( model_list, use_external_data_format, float16, + args.bfloat16, args.force_fp32_ops, enable_runtime_optimization, args, @@ -488,10 +502,18 @@ def parse_arguments(argv: Optional[List[str]] = None): "--float16", required=False, action="store_true", - help="Output models of half or mixed precision.", + help="Output models of float16, except some nodes falls back to float32 or bfloat16 to avoid overflow.", ) parser.set_defaults(float16=False) + parser.add_argument( + "--bfloat16", + required=False, + action="store_true", + help="Allow bfloat16 as fallback if --float16 is also provided.", + ) + parser.set_defaults(bfloat16=False) + parser.add_argument( "--force_fp32_ops", required=False, @@ -544,6 +566,7 @@ def parse_arguments(argv: Optional[List[str]] = None): def main(argv: Optional[List[str]] = None): args = parse_arguments(argv) + logger.info("Arguments: %s", str(args)) optimize_stable_diffusion_pipeline( args.input, args.output, args.overwrite, args.use_external_data_format, args.float16, args.inspect, args diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py index ae62ee10c56a0..70742bb5f52e3 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_t5.py +++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py @@ -895,38 +895,3 @@ def postprocess(self): self.adjust_rel_pos_bis_length_input() self.prune_graph() - - def convert_mixed_precision(self, force_dense_output_fp32: bool = True, keep_io_types=True): - """ - Convert model to mixed precision (float32 and float16). - This shall be done after the model is fully optimized and pruned. - - Args: - force_dense_output_fp32: force the output MatMul in MatMul-Rel-MatMul to float32. - keep_io_types: keep graph input/output data type. - """ - - node_block_list = [] - - output_name_to_node_map = self.output_name_to_node() - # See https://github.com/huggingface/transformers/issues/20287#issuecomment-1342219429 - if force_dense_output_fp32: - for node in self.nodes(): - if node.op_type == "SkipSimplifiedLayerNormalization": - parent, i = self.match_first_parent(node, "MatMul", output_name_to_node=output_name_to_node_map) - if parent: - if parent.name: - node_block_list.append(parent.name) - else: - logger.warning(f"Node has no name. 
Its first output is {parent.output[0]}") - - parameters = { - "keep_io_types": keep_io_types, - "op_block_list": ["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization"], - "node_block_list": node_block_list, - } - - logger.info(f"mixed precision parameters: {parameters}") - self.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters) - - return parameters From b7041d1e344ae802d60bd8488b464e4eb1a0b18d Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 25 Dec 2024 08:58:34 +0000 Subject: [PATCH 16/26] fix attention --- .../cuda/bert/attention_prepare_qkv.cu | 61 ++++++++----------- .../transformers/compare_bert_results.py | 12 +++- 2 files changed, 35 insertions(+), 38 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index c8c0191967d40..282ba2403b135 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -125,42 +125,31 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - if (data.bias == nullptr) { - assert(nullptr == fused_runner); - // For quantized attention, bias has been added so only need transpose here. - // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH - assert(qk_head_size == v_head_size); - int matrix_to_trans = (past_present_share_buffer ? 1 : 3); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.gemm_buffer, qkv, 3)); - data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } else { - // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) - // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) - // For unfused kernel, transpose to 3xBxNxSxH (format 1) - // For fused causal kernel, use format 1 since we need have K and V to update present state, - // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. - const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); - data.qkv_format = use_fused_kernel - ? AttentionQkvFormat::QKV_BSN3H - : (use_flash_or_efficient_attention - ? AttentionQkvFormat::Q_K_V_BSNH - : (use_fused_causal - ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH - : AttentionQkvFormat::Q_K_V_BNSH)); - - // For fused causal, we will update gemm_buffer with bias directly. - T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; - - int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 
1 : 3); - // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v - // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H) - LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, - 3, parameters.do_rotary, parameters.rotary_embedding, - parameters.past_sequence_length); - } + // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) + // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) + // For unfused kernel, transpose to 3xBxNxSxH (format 1) + // For fused causal kernel, use format 1 since we need have K and V to update present state, + // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. + const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); + data.qkv_format = use_fused_kernel + ? AttentionQkvFormat::QKV_BSN3H + : (use_flash_or_efficient_attention + ? AttentionQkvFormat::Q_K_V_BSNH + : (use_fused_causal + ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH + : AttentionQkvFormat::Q_K_V_BNSH)); + + // For fused causal, we will update gemm_buffer with bias directly. + T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; + + int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3); + // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v + // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H) + LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, + 3, parameters.do_rotary, parameters.rotary_embedding, + parameters.past_sequence_length); return Status::OK(); } diff --git a/onnxruntime/python/tools/transformers/compare_bert_results.py b/onnxruntime/python/tools/transformers/compare_bert_results.py index 0c5125e74c8a4..ed60af6383273 100644 --- a/onnxruntime/python/tools/transformers/compare_bert_results.py +++ b/onnxruntime/python/tools/transformers/compare_bert_results.py @@ -37,16 +37,23 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3): # Validate the output of baseline and treatment, to make sure the results are similar. 
diff_count = 0 max_abs_diff = 0 + max_diff_percentage = 0 + case_passed = True for test_case_id, results in enumerate(baseline_results): - case_passed = True for i in range(len(results)): treatment_output = treatment_results[test_case_id][i] - abs_diff = np.amax(np.abs(treatment_output - results[i])) + abs_diff_tensor = np.abs(treatment_output - results[i]) + abs_diff = np.amax(abs_diff_tensor) if verbose and abs_diff > atol: print("abs_diff", abs_diff) print("treatment", treatment_output) print("baseline", results[i]) + count_exceeding = np.sum(abs_diff_tensor > atol) + total_elements = abs_diff_tensor.size + percentage_exceeding = (count_exceeding / total_elements) * 100 + max_diff_percentage = max(max_diff_percentage, percentage_exceeding) + max_abs_diff = max(max_abs_diff, abs_diff) if not np.allclose(results[i].tolist(), treatment_output.tolist(), rtol=rtol, atol=atol): if case_passed: @@ -66,6 +73,7 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3): ) print(f"maximum absolute difference={max_abs_diff}") + print(f"maximum percentage of elements that exceeds atol={atol} is {max_diff_percentage:.3f}%") return max_abs_diff, case_passed From 455a3ea9470f30cc0bb9b9ee67d43c058cd80e28 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 25 Dec 2024 08:59:27 +0000 Subject: [PATCH 17/26] update node block list of t5 encoder --- onnxruntime/core/framework/print_tensor_statistics_utils.h | 2 +- .../transformers/models/stable_diffusion/optimize_pipeline.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/print_tensor_statistics_utils.h b/onnxruntime/core/framework/print_tensor_statistics_utils.h index 65360674e88d0..e4bb8b917d2b4 100644 --- a/onnxruntime/core/framework/print_tensor_statistics_utils.h +++ b/onnxruntime/core/framework/print_tensor_statistics_utils.h @@ -30,7 +30,7 @@ void PrintFloatStats(const T* data, size_t count) { size_t zero = 0; size_t subnormal = 0; for (size_t i = 0; i < count; i++) { - switch (my_fpclassify(*data)) { + switch (my_fpclassify(data[i])) { case FP_INFINITE: inf++; break; diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 6ee23c64fbec8..9e97bae6d2a5c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -194,6 +194,7 @@ def _optimize_sd_pipeline( "SkipLayerNorm_43", "SkipLayerNorm_44", "SkipLayerNorm_45", + "/encoder/block.23/layer.1/DenseReluDense/wo/MatMul", "SkipLayerNorm_46", ], "vae_decoder": [ From dad0ac408220732364641a1898bd0a78dc95ec4f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 25 Dec 2024 09:13:15 +0000 Subject: [PATCH 18/26] benchmark torch eager mode --- .../models/stable_diffusion/README.md | 18 +++-- .../models/stable_diffusion/benchmark.py | 67 +++++++++++-------- 2 files changed, 51 insertions(+), 34 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 20451d768133f..dc83f4dc220f0 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -215,7 +215,7 @@ optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task #### Stable Diffusion 3.x and 
Flux 1.0 Stable Diffusion 3.x and Flux 1.0 requires transformers >= 4.45, and optimum > 1.23.3. -The default opset version is 12 for T5. To support bfloat16, please set `--opset` verison explicitly like below example. +The default opset version for T5 is 12, which does not support bfloat16. To support bfloat16, please set opset version explicitly like below example. ``` git clone https://github.com/huggingface/optimum @@ -291,11 +291,17 @@ The default parameters are stable diffusion version=1.5, height=512, width=512, #### Stable Diffusion 3.x and Flux 1.0 Example of benchmark with optimum using CUDA provider on stable diffusion 3.5 medium and Flux 1.0: ``` -python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v 3.5M -p sd3.5_medium_onnx/fp32 -python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v 3.5M -p sd3.5_medium_onnx/fp16 -python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v 3.5L -p sd3.5_large_onnx/fp16 -python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v Flux.1S -p flux1_schnell_onnx/fp16 -python benchmark.py -e optimum --height 1024 --width 1024 --steps 20 -b 1 -v Flux.1D -p flux1_dev_onnx/fp16 +python benchmark.py -e optimum --height 1024 --width 1024 --steps 30 -b 1 -v 3.0M -p sd3_onnx/fp32 +python benchmark.py -e optimum --height 1024 --width 1024 --steps 30 -b 1 -v 3.5M -p sd3.5_medium_onnx/fp16 +python benchmark.py -e optimum --height 1024 --width 1024 --steps 30 -b 1 -v 3.5L -p sd3.5_large_onnx/fp16 +python benchmark.py -e optimum --height 1024 --width 1024 --steps 4 -b 1 -v Flux.1S -p flux1_schnell_onnx/fp16 +python benchmark.py -e optimum --height 1024 --width 1024 --steps 30 -b 1 -v Flux.1D -p flux1_dev_onnx/fp16 +``` + +Benchmark PyTorch eager mode performance: +``` +python benchmark.py -e torch --height 1024 --width 1024 --steps 30 -b 1 -v 3.5L +python benchmark.py -e torch --height 1024 --width 1024 --steps 30 -b 1 -v Flux.1D ``` ### Run Benchmark with xFormers diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 3a3fd9d7ce8d8..c645b231455ae 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -96,6 +96,16 @@ def get_ort_pipeline(model_name: str, directory: str, provider, disable_safety_c def get_torch_pipeline(model_name: str, disable_safety_checker: bool, enable_torch_compile: bool, use_xformers: bool): + if "FLUX" in model_name: + from diffusers import FluxPipeline + + return FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda") + + if "stable-diffusion-3" in model_name: + from diffusers import StableDiffusion3Pipeline + + return StableDiffusion3Pipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda") + from diffusers import DDIMScheduler, StableDiffusionPipeline from torch import channels_last, float16 @@ -199,6 +209,25 @@ def warmup(): } +def get_negative_prompt_kwargs(negative_prompt, use_num_images_per_prompt, is_flux, batch_size) -> dict: + # Flux does not support negative prompt + kwargs = ( + ( + {"negative_prompt": negative_prompt} + if use_num_images_per_prompt + else {"negative_prompt": [negative_prompt] * batch_size} + ) + if not is_flux + else {} + ) + + # Fix the random seed so that we can inspect the output quality easily. 
+ if torch.cuda.is_available(): + kwargs["generator"] = torch.Generator(device="cuda").manual_seed(123) + + return kwargs + + def run_torch_pipeline( pipe, batch_size: int, @@ -213,16 +242,14 @@ def run_torch_pipeline( ): prompts, negative_prompt = example_prompts() - # total 2 runs of warm up, and measure GPU memory for CUDA EP + import diffusers + + is_flux = isinstance(pipe, diffusers.FluxPipeline) + def warmup(): prompt, negative = warmup_prompts() - pipe( - prompt=[prompt] * batch_size, - height=height, - width=width, - num_inference_steps=steps, - negative_prompt=[negative] * batch_size, - ) + extra_kwargs = get_negative_prompt_kwargs(negative, False, is_flux, batch_size) + pipe(prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, **extra_kwargs) # Run warm up, and measure GPU memory of two runs (The first run has cuDNN algo search so it might need more memory) first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) @@ -238,13 +265,14 @@ def warmup(): break torch.cuda.synchronize() inference_start = time.time() + extra_kwargs = get_negative_prompt_kwargs(negative_prompt, False, is_flux, batch_size) images = pipe( prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, - negative_prompt=[negative_prompt] * batch_size, generator=None, # torch.Generator + **extra_kwargs, ).images torch.cuda.synchronize() @@ -368,23 +396,9 @@ def run_optimum_ort_pipeline( prompts, negative_prompt = example_prompts() - def get_negative_prompt_kwargs(negative, use_num_images_per_prompt, is_flux) -> dict: - # Flux does not support negative prompt - negative_prompt_kwargs = ( - ( - {"negative_prompt": negative} - if use_num_images_per_prompt - else {"negative_prompt": [negative] * batch_size} - ) - if not is_flux - else {} - ) - - return negative_prompt_kwargs - def warmup(): prompt, negative = warmup_prompts() - extra_kwargs = get_negative_prompt_kwargs(negative, use_num_images_per_prompt, is_flux) + extra_kwargs = get_negative_prompt_kwargs(negative, use_num_images_per_prompt, is_flux, batch_size) if use_num_images_per_prompt: pipe( prompt=prompt, @@ -404,10 +418,7 @@ def warmup(): warmup() - extra_kwargs = get_negative_prompt_kwargs(negative_prompt, use_num_images_per_prompt, is_flux) - # Fix the random seed so that we can inspect the output quality easily. - if torch.cuda.is_available(): - extra_kwargs["generator"] = torch.Generator(device="cuda").manual_seed(123) + extra_kwargs = get_negative_prompt_kwargs(negative_prompt, use_num_images_per_prompt, is_flux, batch_size) latency_list = [] for i, prompt in enumerate(prompts): From 840055806390574322cf73e178a56e5082086366 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 25 Dec 2024 09:13:34 +0000 Subject: [PATCH 19/26] update comment --- onnxruntime/python/tools/transformers/float16.py | 5 ++++- .../models/stable_diffusion/optimize_pipeline.py | 13 ++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index 874020640ef62..306d89b430277 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -143,7 +143,7 @@ def make_value_info_from_tensor(tensor): "Upsample", ] -# Some operators do not support float16 in CUDA. This is not a full list, just some common operators in transformers. +# Some operators do not support bfloat16 in CUDA. 
This is not a full list, just some common operators in transformers. BF16_OP_BLACK_LIST = ["SkipSimplifiedLayerNormalization", "Attention", "MultiHeadAttention"] # Some operators has data type fixed as float for some inputs. Key is op_type, value is list of input indices @@ -201,6 +201,9 @@ def convert_float_to_float16( Default to false, which will convert only the one needed to avoid precision loss. force_fp16_inputs(Dict[str, List[int]]): Force the conversion of the inputs of some operators to float16, even if this script's preference it to keep them in float32. + use_bfloat16_as_blocked_nodes_dtype(bool): use bfloat16 as the data type for blocked nodes. Default to False. + If the node does not support bfloat16, it will remain in float. + Raises: ValueError: input type is not ModelProto. diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 9e97bae6d2a5c..caac8d4b60035 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -156,7 +156,7 @@ def _optimize_sd_pipeline( "transformer": [], } - # The node block list is generated by running the fp32 model and duming node inputs and outputs statistics. + # The node block list is generated by running the fp32 model and get statistics of node inputs and outputs. # Nodes with any input or output of float or double data type, but value ouf of range of float16 are candidates. # python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp32_opt # export ORT_DEBUG_NODE_IO_DUMP_STATISTICS_DATA=1 @@ -350,10 +350,13 @@ def _optimize_sd_pipeline( # Opset 12 does not support bfloat16. # By default, optimum exports T5 model with opset 12. So we need to check the opset version. use_bfloat16 = bfloat16 - for opset in m.model.opset_import: - if opset.domain in ["", "ai.onnx"] and opset.version < 13: - logger.warning("onnx model requires opset 13 or higher to use bfloat16. Fall back to float32.") - use_bfloat16 = False + if use_bfloat16: + for opset in m.model.opset_import: + if opset.domain in ["", "ai.onnx"] and opset.version < 13: + logger.warning( + "onnx model requires opset 13 or higher to use bfloat16. Fall back to float32." 
+ ) + use_bfloat16 = False m.convert_float_to_float16( keep_io_types=False, From 9e43e2068210a5308fcc461ed2fd02e619084f23 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 26 Dec 2024 08:40:26 +0000 Subject: [PATCH 20/26] benchmark torch compile --- .../models/stable_diffusion/benchmark.py | 33 +++-- .../models/stable_diffusion/benchmark_flux.sh | 116 ++++++++++++++++++ 2 files changed, 137 insertions(+), 12 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index c645b231455ae..0452cff235c11 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -99,12 +99,20 @@ def get_torch_pipeline(model_name: str, disable_safety_checker: bool, enable_tor if "FLUX" in model_name: from diffusers import FluxPipeline - return FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda") + pipe = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda") + if enable_torch_compile: + pipe.transformer.to(memory_format=torch.channels_last) + pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True) + return pipe if "stable-diffusion-3" in model_name: from diffusers import StableDiffusion3Pipeline - return StableDiffusion3Pipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda") + pipe = StableDiffusion3Pipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda") + if enable_torch_compile: + pipe.transformer.to(memory_format=torch.channels_last) + pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True) + return pipe from diffusers import DDIMScheduler, StableDiffusionPipeline from torch import channels_last, float16 @@ -132,9 +140,9 @@ def get_torch_pipeline(model_name: str, disable_safety_checker: bool, enable_tor return pipe -def get_image_filename_prefix(engine: str, model_name: str, batch_size: int, disable_safety_checker: bool): +def get_image_filename_prefix(engine: str, model_name: str, batch_size: int, steps: int, disable_safety_checker: bool): short_model_name = model_name.split("/")[-1].replace("stable-diffusion-", "sd") - return f"{engine}_{short_model_name}_b{batch_size}" + ("" if disable_safety_checker else "_safe") + return f"{engine}_{short_model_name}_b{batch_size}_s{steps}" + ("" if disable_safety_checker else "_safe") def run_ort_pipeline( @@ -271,7 +279,6 @@ def warmup(): height=height, width=width, num_inference_steps=steps, - generator=None, # torch.Generator **extra_kwargs, ).images @@ -323,7 +330,7 @@ def run_ort( load_end = time.time() print(f"Model loading took {load_end - load_start} seconds") - image_filename_prefix = get_image_filename_prefix("ort", model_name, batch_size, disable_safety_checker) + image_filename_prefix = get_image_filename_prefix("ort", model_name, batch_size, steps, disable_safety_checker) result = run_ort_pipeline( pipe, batch_size, @@ -486,7 +493,9 @@ def run_optimum_ort( print(f"Model loading took {load_end - load_start} seconds") full_model_name = model_name + "_" + Path(directory).name if directory else model_name - image_filename_prefix = get_image_filename_prefix("optimum", full_model_name, batch_size, disable_safety_checker) + image_filename_prefix = get_image_filename_prefix( + 
"optimum", full_model_name, batch_size, steps, disable_safety_checker + ) result = run_optimum_ort_pipeline( pipe, batch_size, @@ -591,7 +600,7 @@ def warmup(): warmup() - image_filename_prefix = get_image_filename_prefix("ort_trt", short_name, batch_size, disable_safety_checker) + image_filename_prefix = get_image_filename_prefix("ort_trt", short_name, batch_size, steps, disable_safety_checker) latency_list = [] prompts, negative_prompt = example_prompts() @@ -730,7 +739,7 @@ def warmup(): warmup() - image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, disable_safety_checker) + image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, steps, disable_safety_checker) latency_list = [] prompts, negative_prompt = example_prompts() @@ -885,7 +894,7 @@ def warmup(): warmup() model_name = pipeline_info.name() - image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, disable_safety_checker) + image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, steps, disable_safety_checker) latency_list = [] prompts, negative_prompt = example_prompts() @@ -980,7 +989,7 @@ def warmup(): warmup() model_name = pipeline.pipeline_info.name() - image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, disable_safety_checker) + image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, steps, disable_safety_checker) latency_list = [] prompts, negative_prompt = example_prompts() @@ -1048,7 +1057,7 @@ def run_torch( load_end = time.time() print(f"Model loading took {load_end - load_start} seconds") - image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, disable_safety_checker) + image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, steps, disable_safety_checker) if not enable_torch_compile: with torch.inference_mode(): diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh new file mode 100644 index 0000000000000..deffd21a9d8ca --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh @@ -0,0 +1,116 @@ +#!/bin/bash +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# ------------------------------------------------------------------------- + +# Please run this script under conda or virtual environment with Python 3.10, 3.11 or 3.12. 
+# bash benchmark_flux.sh + +# Installation directory (default: $HOME) +install_dir="${1:-$HOME}" + +# Root directory for the onnx models +onnx_dir="${2:-onnx_models}" + +# Which GPU to use +export CUDA_VISIBLE_DEVICES=0 + +# Function to install CUDA 12.6 +install_cuda_12() +{ + pushd $install_dir + wget https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.35.03_linux.run + sh cuda_12.6.2_560.35.03_linux.run --toolkit --toolkitpath=$install_dir/cuda12.6 --silent --override --no-man-page + + export PATH="$install_dir/cuda12.6/bin:$PATH" + export LD_LIBRARY_PATH="$install_dir/cuda12.6/lib64:$LD_LIBRARY_PATH" + popd +} + +# Function to install cuDNN 9.6 +install_cudnn_9() { + pushd "$install_dir" + wget -q https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-9.6.0.74_cuda12-archive.tar.xz + mkdir -p "$install_dir/cudnn9.6" + tar -Jxvf cudnn-linux-x86_64-9.6.0.74_cuda12-archive.tar.xz -C "$install_dir/cudnn9.6"--strip=1 + export LD_LIBRARY_PATH="$install_dir/cudnn9.5/lib:$LD_LIBRARY_PATH" + popd +} + +# Install optimum from source before 1.24 is released +install_optimum() { + pushd "$install_dir" + optimum_dir="$install_dir/optimum" + if [ ! -d "$optimum_dir" ]; then + git clone https://github.com/huggingface/optimum + fi + cd "$sam2_dir" + pip show optimum > /dev/null 2>&1 || pip install -e . + popd +} + +# Install onnxruntime-gpu from source before 1.21 is released +install_onnxruntime() { + pushd "$install_dir" + if ! [ -d onnxruntime ]; then + git clone https://github.com/microsoft/onnxruntime + fi + cd onnxruntime + CUDA_ARCH=$(python3 -c "import torch; cc = torch.cuda.get_device_capability(); print(f'{cc[0]}{cc[1]}')") + if [ -n "$CUDA_ARCH" ]; then + pip install --upgrade pip cmake psutil setuptools wheel packaging ninja numpy==2.2 + sh build.sh --config Release --build_dir build/cuda12 --build_shared_lib --parallel \ + --use_cuda --cuda_version 12.6 --cuda_home $install_dir/cuda12.6 \ + --cudnn_home $install_dir/cudnn9.6 \ + --build_wheel --skip_tests \ + --cmake_generator Ninja \ + --compile_no_warning_as_error \ + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=$CUDA_ARCH + pip install build/cuda12/Release/dist/onnxruntime_gpu-*-linux_x86_64.whl + else + echo "No CUDA device found." + exit 1 + fi + popd +} + +# Install GPU dependencies +install_gpu() { + [ ! -d "$install_dir/cuda12.6" ] && install_cuda_12 + [ ! -d "$install_dir/cudnn9.6" ] && install_cudnn_9 + pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 + + pip install diffusers==0.31.0 transformers==4.46.3 onnx==1.17.0 protobuf==5.29.2 + + install_onnxruntime + + install_optimum +} + +run_benchmark() { + local model=$1 + local dir=$2 + local version=$3 + local steps=$4 + local batch=$5 + + mkdir -p $dir + [ ! -d "$dir/fp32" ] && optimum-cli export onnx --model $model $dir/fp32 --opset 15 --task text-to-image + [ ! -d "$dir/fp16_fp32" ] && python optimize_pipeline.py -i $dir/fp32 -o $dir/fp16_fp32 --float16 + [ ! 
-d "$dir/fp16_bf16" ] && python optimize_pipeline.py -i $dir/fp32 -o $dir/fp16_bf16 --float16 --bfloat16 + python benchmark.py -e optimum --height 1024 --width 1024 --steps $steps -b $batch -v $version -p $dir/fp16_fp32 + python benchmark.py -e optimum --height 1024 --width 1024 --steps $steps -b $batch -v $version -p $dir/fp16_bf16 + python benchmark.py -e torch --height 1024 --width 1024 --steps $steps -b $batch -v $version + python benchmark.py -e torch --height 1024 --width 1024 --steps $steps -b $batch -v $version --enable_torch_compile +} + +install_gpu + +mkdir -p $root_dir + +run_benchmark black-forest-labs/FLUX.1-schnell ${root_dir}/flux1_schnell Flux.1S 4 1 > $root_dir/flux1_schnell_s4_b1.log +run_benchmark black-forest-labs/FLUX.1-dev ${root_dir}/flux1_dev Flux.1D 50 1 > $root_dir/flux1_dev_s50_b1.log + +run_benchmark stabilityai/stable-diffusion-3.5-large ${root_dir}/sd3.5_large 3.5L 50 1 > $root_dir/sd3.5_large_s50_b1.log +run_benchmark stabilityai/stable-diffusion-3.5-medium ${root_dir}/sd3.5_medium 3.5M 50 1 > $root_dir/sd3.5_medium_s50_b1.log From 4bf9f2524ebc234ee668189cec2b07fce8b3efa8 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 26 Dec 2024 09:47:23 +0000 Subject: [PATCH 21/26] refine benchmark_flux.sh --- .../models/stable_diffusion/benchmark_flux.sh | 119 ++++++++++-------- .../stable_diffusion/optimize_pipeline.py | 3 +- 2 files changed, 68 insertions(+), 54 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh index deffd21a9d8ca..82a0b4f0746a5 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh @@ -4,25 +4,33 @@ # Licensed under the MIT License. # ------------------------------------------------------------------------- -# Please run this script under conda or virtual environment with Python 3.10, 3.11 or 3.12. 
-# bash benchmark_flux.sh +set -euo pipefail -# Installation directory (default: $HOME) -install_dir="${1:-$HOME}" +# Script to benchmark Flux models with ONNX and PyTorch +# Usage: bash benchmark_flux.sh + +# Validate inputs and environment +command -v python3 &>/dev/null || { echo "Python3 is required but not installed."; exit 1; } +command -v wget &>/dev/null || { echo "wget is required but not installed."; exit 1; } -# Root directory for the onnx models +# Input arguments with defaults +install_dir="${1:-$HOME}" onnx_dir="${2:-onnx_models}" -# Which GPU to use +# GPU settings export CUDA_VISIBLE_DEVICES=0 -# Function to install CUDA 12.6 -install_cuda_12() -{ - pushd $install_dir - wget https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.35.03_linux.run - sh cuda_12.6.2_560.35.03_linux.run --toolkit --toolkitpath=$install_dir/cuda12.6 --silent --override --no-man-page +# Function to log messages +log() { + echo -e "\033[1;32m[INFO]\033[0m $1" +} +# Function to install CUDA 12.6 +install_cuda_12() { + log "Installing CUDA 12.6" + pushd "$install_dir" + wget -q https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.35.03_linux.run + sh cuda_12.6.2_560.35.03_linux.run --toolkit --toolkitpath="$install_dir/cuda12.6" --silent --override --no-man-page export PATH="$install_dir/cuda12.6/bin:$PATH" export LD_LIBRARY_PATH="$install_dir/cuda12.6/lib64:$LD_LIBRARY_PATH" popd @@ -30,64 +38,67 @@ install_cuda_12() # Function to install cuDNN 9.6 install_cudnn_9() { + log "Installing cuDNN 9.6" pushd "$install_dir" - wget -q https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-9.6.0.74_cuda12-archive.tar.xz + wget -q https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-9.6.0.74_cuda12-archive.tar.xz mkdir -p "$install_dir/cudnn9.6" - tar -Jxvf cudnn-linux-x86_64-9.6.0.74_cuda12-archive.tar.xz -C "$install_dir/cudnn9.6"--strip=1 - export LD_LIBRARY_PATH="$install_dir/cudnn9.5/lib:$LD_LIBRARY_PATH" + tar -Jxvf cudnn-linux-x86_64-9.6.0.74_cuda12-archive.tar.xz -C "$install_dir/cudnn9.6" --strip=1 + export LD_LIBRARY_PATH="$install_dir/cudnn9.6/lib:$LD_LIBRARY_PATH" popd } -# Install optimum from source before 1.24 is released +# Function to install optimum install_optimum() { - pushd "$install_dir" + log "Installing Optimum" optimum_dir="$install_dir/optimum" if [ ! -d "$optimum_dir" ]; then - git clone https://github.com/huggingface/optimum + git clone https://github.com/huggingface/optimum "$optimum_dir" fi - cd "$sam2_dir" - pip show optimum > /dev/null 2>&1 || pip install -e . + pushd "$optimum_dir" + pip show optimum &>/dev/null || pip install -e . popd } -# Install onnxruntime-gpu from source before 1.21 is released +# Function to build and install ONNX Runtime install_onnxruntime() { + log "Building ONNX Runtime" pushd "$install_dir" - if ! [ -d onnxruntime ]; then + if [ ! 
-d onnxruntime ]; then git clone https://github.com/microsoft/onnxruntime fi - cd onnxruntime + pushd onnxruntime CUDA_ARCH=$(python3 -c "import torch; cc = torch.cuda.get_device_capability(); print(f'{cc[0]}{cc[1]}')") - if [ -n "$CUDA_ARCH" ]; then - pip install --upgrade pip cmake psutil setuptools wheel packaging ninja numpy==2.2 - sh build.sh --config Release --build_dir build/cuda12 --build_shared_lib --parallel \ - --use_cuda --cuda_version 12.6 --cuda_home $install_dir/cuda12.6 \ - --cudnn_home $install_dir/cudnn9.6 \ - --build_wheel --skip_tests \ - --cmake_generator Ninja \ - --compile_no_warning_as_error \ - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=$CUDA_ARCH - pip install build/cuda12/Release/dist/onnxruntime_gpu-*-linux_x86_64.whl - else + if [ -z "$CUDA_ARCH" ]; then echo "No CUDA device found." exit 1 fi + pip install --upgrade pip cmake psutil setuptools wheel packaging ninja numpy==2.2 + sh build.sh --config Release --build_dir build/cuda12 --parallel \ + --use_cuda --cuda_version 12.6 --cuda_home "$install_dir/cuda12.6" \ + --cudnn_home "$install_dir/cudnn9.6" \ + --build_wheel --skip_tests \ + --cmake_generator Ninja \ + --compile_no_warning_as_error \ + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF CMAKE_CUDA_ARCHITECTURES="$CUDA_ARCH" + + log "Installing ONNX Runtime" + pip install build/cuda12/Release/dist/onnxruntime_gpu-*-linux_x86_64.whl + popd popd } -# Install GPU dependencies +# Function to install GPU dependencies install_gpu() { + log "Installing GPU dependencies" [ ! -d "$install_dir/cuda12.6" ] && install_cuda_12 [ ! -d "$install_dir/cudnn9.6" ] && install_cudnn_9 pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 - - pip install diffusers==0.31.0 transformers==4.46.3 onnx==1.17.0 protobuf==5.29.2 - + pip install diffusers==0.32.0 transformers==4.46.3 onnx==1.17.0 protobuf==5.29.2 py3nvml install_onnxruntime - install_optimum } +# Function to run benchmarks run_benchmark() { local model=$1 local dir=$2 @@ -95,22 +106,26 @@ run_benchmark() { local steps=$4 local batch=$5 - mkdir -p $dir - [ ! -d "$dir/fp32" ] && optimum-cli export onnx --model $model $dir/fp32 --opset 15 --task text-to-image - [ ! -d "$dir/fp16_fp32" ] && python optimize_pipeline.py -i $dir/fp32 -o $dir/fp16_fp32 --float16 - [ ! -d "$dir/fp16_bf16" ] && python optimize_pipeline.py -i $dir/fp32 -o $dir/fp16_bf16 --float16 --bfloat16 - python benchmark.py -e optimum --height 1024 --width 1024 --steps $steps -b $batch -v $version -p $dir/fp16_fp32 - python benchmark.py -e optimum --height 1024 --width 1024 --steps $steps -b $batch -v $version -p $dir/fp16_bf16 - python benchmark.py -e torch --height 1024 --width 1024 --steps $steps -b $batch -v $version - python benchmark.py -e torch --height 1024 --width 1024 --steps $steps -b $batch -v $version --enable_torch_compile + log "Running benchmark for model: $model" + mkdir -p "$dir" + [ ! -d "$dir/fp32" ] && optimum-cli export onnx --model "$model" "$dir/fp32" --opset 15 --task text-to-image + [ ! -d "$dir/fp16_fp32" ] && python optimize_pipeline.py -i "$dir/fp32" -o "$dir/fp16_fp32" --float16 + [ ! 
-d "$dir/fp16_bf16" ] && python optimize_pipeline.py -i "$dir/fp32" -o "$dir/fp16_bf16" --float16 --bfloat16 + python benchmark.py -e optimum --height 1024 --width 1024 --steps "$steps" -b "$batch" -v "$version" -p "$dir/fp16_fp32" + python benchmark.py -e optimum --height 1024 --width 1024 --steps "$steps" -b "$batch" -v "$version" -p "$dir/fp16_bf16" + python benchmark.py -e torch --height 1024 --width 1024 --steps "$steps" -b "$batch" -v "$version" + python benchmark.py -e torch --height 1024 --width 1024 --steps "$steps" -b "$batch" -v "$version" --enable_torch_compile } +# Main script execution install_gpu -mkdir -p $root_dir +log "Creating ONNX model directory: $onnx_dir" +mkdir -p "$onnx_dir" -run_benchmark black-forest-labs/FLUX.1-schnell ${root_dir}/flux1_schnell Flux.1S 4 1 > $root_dir/flux1_schnell_s4_b1.log -run_benchmark black-forest-labs/FLUX.1-dev ${root_dir}/flux1_dev Flux.1D 50 1 > $root_dir/flux1_dev_s50_b1.log +run_benchmark black-forest-labs/FLUX.1-schnell "$onnx_dir/flux1_schnell" Flux.1S 4 1 > "$onnx_dir/flux1_schnell_s4_b1.log" +run_benchmark black-forest-labs/FLUX.1-dev "$onnx_dir/flux1_dev" Flux.1D 50 1 > "$onnx_dir/flux1_dev_s50_b1.log" +run_benchmark stabilityai/stable-diffusion-3.5-large "$onnx_dir/sd3.5_large" 3.5L 50 1 > "$onnx_dir/sd3.5_large_s50_b1.log" +run_benchmark stabilityai/stable-diffusion-3.5-medium "$onnx_dir/sd3.5_medium" 3.5M 50 1 > "$onnx_dir/sd3.5_medium_s50_b1.log" -run_benchmark stabilityai/stable-diffusion-3.5-large ${root_dir}/sd3.5_large 3.5L 50 1 > $root_dir/sd3.5_large_s50_b1.log -run_benchmark stabilityai/stable-diffusion-3.5-medium ${root_dir}/sd3.5_medium 3.5M 50 1 > $root_dir/sd3.5_medium_s50_b1.log +log "Benchmark completed." diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index caac8d4b60035..cdb6518be000f 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -163,8 +163,7 @@ def _optimize_sd_pipeline( # export ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA=1 # export ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA=1 # python benchmark.py --height 1024 --width 1024 --steps 4 -b 1 -v Flux.1S -p flux1_schnell_onnx/fp32_opt -e optimum >stdout.txt 2>stderr.txt - # Warning: The node name might change in different export settings. We used python 3.10 and the following packages: - # diffusers==0.31.0 transformers==4.46.3 optimum==1.24.0.dev0 torch==2.5.1 onnx==1.17.0 protobuf==5.29.2 + # Warning: The node name might change in different export settings. See benchmark_flux.sh for the settings. 
flux_node_block_list = { "text_encoder_2": [ "/encoder/block.10/layer.1/DenseReluDense/wo/MatMul", From a47b6af567d91c33ec77b96f6f82613762a85119 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 10 Jan 2025 23:52:06 +0000 Subject: [PATCH 22/26] undo layer norm kernel --- .../contrib_ops/cuda/bert/skip_layer_norm.cc | 1 - .../core/providers/cuda/nn/layer_norm.cc | 33 +++++-------------- .../core/providers/cuda/nn/layer_norm_impl.cu | 17 +++------- .../core/providers/cuda/nn/layer_norm_impl.h | 1 - .../tools/transformers/onnx_model_mmdit.py | 9 ----- 5 files changed, 13 insertions(+), 48 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 91e8577df487b..3299bc2cb11de 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -101,7 +101,6 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const (double)epsilon_, // epsilon reinterpret_cast(gamma->Data()), // gamma (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, // beta - 0, // broadcast stride for gamma/beta reinterpret_cast(skip->Data()), // skip or residual to add (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, // bias to add sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr); diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm.cc b/onnxruntime/core/providers/cuda/nn/layer_norm.cc index c430ffe5aa97d..7dd10f9c2960c 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm.cc +++ b/onnxruntime/core/providers/cuda/nn/layer_norm.cc @@ -44,36 +44,19 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) con auto bias_data = (simplified || (nullptr == bias)) ? nullptr : reinterpret_cast(bias->Data()); const TensorShape& x_shape = X->Shape(); - auto x_num_dims = x_shape.NumDimensions(); - const int64_t axis = HandleNegativeAxis(axis_, x_num_dims); + const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions()); int n1 = gsl::narrow(x_shape.SizeToDimension(axis)); int n2 = gsl::narrow(x_shape.SizeFromDimension(axis)); const auto scale_size = scale->Shape().Size(); const auto bias_size = (bias_data) ? bias->Shape().Size() : 0; - - int broadcast = 0; if (n2 == 1 || scale_size != n2 || (bias_data && bias_size != n2)) { - // Handle a special case for MMDit where scale and bias need broadcast. - // X shape is (B, S, D), scale and bias shape is (B, 1, D), and we store S as broadcast stride. - if (x_num_dims == 3 && axis == 2 && n2 > 1 && - scale->Shape().NumDimensions() == x_num_dims && - scale->Shape().GetDims()[0] == x_shape.GetDims()[0] && - scale->Shape().GetDims()[1] == 1 && - scale->Shape().GetDims()[2] == x_shape.GetDims()[2] && - bias->Shape().NumDimensions() == x_num_dims && - bias->Shape().GetDims()[0] == x_shape.GetDims()[0] && - bias->Shape().GetDims()[1] == 1 && - bias->Shape().GetDims()[2] == x_shape.GetDims()[2]) { - broadcast = static_cast(x_shape.GetDims()[1]); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Size of X.shape()[axis:] == ", n2, - ". Size of scale and bias (if provided) must match this " - "and the size must not be 1. Got scale size of ", - scale_size, " and bias size of ", bias_size); - } + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Size of X.shape()[axis:] == ", n2, + ". Size of scale and bias (if provided) must match this " + "and the size must not be 1. 
Got scale size of ", + scale_size, " and bias size of ", bias_size); } // Outputs @@ -82,7 +65,7 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) con // Mean and variance std::vector mean_inv_std_var_dim; - for (int i = 0; i < static_cast(x_num_dims); ++i) { + for (int i = 0; i < static_cast(x_shape.NumDimensions()); ++i) { if (i < axis) { mean_inv_std_var_dim.emplace_back(x_shape.GetDims()[i]); } else { @@ -111,7 +94,7 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) con } HostApplyLayerNorm(GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data, - X_data, n1, n2, epsilon_, scale_data, bias_data, broadcast); + X_data, n1, n2, epsilon_, scale_data, bias_data); CUDA_RETURN_IF_ERROR(cudaGetLastError()); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu index c21943649775b..b9e8b45307079 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu @@ -334,7 +334,6 @@ __global__ void cuApplyLayerNorm( const U epsilon, const V* __restrict__ gamma, const V* __restrict__ beta, - int broadcast, const T* __restrict__ skip, const T* __restrict__ bias, T* __restrict__ skip_input_bias_add_output) { @@ -367,13 +366,8 @@ __global__ void cuApplyLayerNorm( curr += static_cast(skip_vals[i]); } - // onnx operator LayerNormalization support broadcast. - // gamma and beta should be unidirectional broadcastable to tensor x. - // Here we support a special case for transformer models that x is (B, S, D) and gamma/beta is (B, 1, D) - int index = (broadcast > 0) ? ((i1 / broadcast) * n2 + i) : i; - U gamma_i = (gamma != nullptr) ? (U)gamma[index] : (U)1; - U beta_i = (beta != nullptr) ? (U)beta[index] : (U)0; - + U gamma_i = (gamma != nullptr) ? (U)gamma[i] : (U)1; + U beta_i = (beta != nullptr) ? 
(U)beta[i] : (U)0; if (simplified) { ovals[i] = static_cast(gamma_i * c_inv_std_dev * curr); } else { @@ -415,7 +409,6 @@ void HostApplyLayerNorm( double epsilon, const V* gamma, const V* beta, - int broadcast, const T* skip, const T* bias, T* skip_input_bias_add_output) { @@ -449,15 +442,15 @@ void HostApplyLayerNorm( input, n1, n2, U(epsilon), - gamma, beta, broadcast, + gamma, beta, skip, bias, skip_input_bias_add_output); } #define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \ template void HostApplyLayerNorm(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \ U* mean, U* inv_std_dev, const T* input, int n1, int n2, \ - double epsilon, const V* gamma, const V* beta, int broadcast, \ - const T* skip, const T* bias, T* skip_input_bias_add_output); + double epsilon, const V* gamma, const V* beta, const T* skip, \ + const T* bias, T* skip_input_bias_add_output); LAYERNORM_LINEAR_IMPL(float, float, float, true) LAYERNORM_LINEAR_IMPL(half, float, half, true) diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h index 3ba895e8829b6..e3952eefae35d 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h @@ -41,7 +41,6 @@ void HostApplyLayerNorm( double epsilon, const V* gamma, const V* beta, - int broadcast = 0, // broadcast stride for gamma/beta const T* skip = nullptr, const T* bias = nullptr, T* skip_input_bias_add_output = nullptr); diff --git a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py index 80d408e671979..4c9b19c0c97ca 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py +++ b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py @@ -87,14 +87,6 @@ def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): if progress_bar: progress_bar.update(1) - # TODO: SkipLayerNormalization does not support broadcast yet. - # if (options is None) or options.enable_skip_layer_norm: - # self.fuse_skip_simplified_layer_norm() - # self.fuse_skip_layer_norm() - # if (options is None) or options.enable_bias_skip_layer_norm: - # # Fuse SkipLayerNormalization and Add Bias before it. - # self.fuse_add_bias_skip_layer_norm() - self.postprocess() if progress_bar: progress_bar.update(1) @@ -110,7 +102,6 @@ def get_fused_operator_statistics(self): "FastGelu", "MultiHeadAttention", "LayerNormalization", - # "SkipLayerNormalization", "SimplifiedLayerNormalization", ] From 55178d67d0bf7e7c3dcc588e7b8cc4ff501b5f91 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 11 Jan 2025 00:01:01 +0000 Subject: [PATCH 23/26] CMAKE_CUDA_ARCHITECTURES=native --- .../transformers/models/stable_diffusion/benchmark_flux.sh | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh index 82a0b4f0746a5..2c7785eb8f62f 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh @@ -67,11 +67,6 @@ install_onnxruntime() { git clone https://github.com/microsoft/onnxruntime fi pushd onnxruntime - CUDA_ARCH=$(python3 -c "import torch; cc = torch.cuda.get_device_capability(); print(f'{cc[0]}{cc[1]}')") - if [ -z "$CUDA_ARCH" ]; then - echo "No CUDA device found." 
- exit 1 - fi pip install --upgrade pip cmake psutil setuptools wheel packaging ninja numpy==2.2 sh build.sh --config Release --build_dir build/cuda12 --parallel \ --use_cuda --cuda_version 12.6 --cuda_home "$install_dir/cuda12.6" \ @@ -79,7 +74,7 @@ install_onnxruntime() { --build_wheel --skip_tests \ --cmake_generator Ninja \ --compile_no_warning_as_error \ - --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF CMAKE_CUDA_ARCHITECTURES="$CUDA_ARCH" + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF CMAKE_CUDA_ARCHITECTURES=native log "Installing ONNX Runtime" pip install build/cuda12/Release/dist/onnxruntime_gpu-*-linux_x86_64.whl From ebade480aec0b05acf6f0797624da6613e9540be Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sun, 12 Jan 2025 04:11:17 +0000 Subject: [PATCH 24/26] add tests --- .../stable_diffusion/optimize_pipeline.py | 11 +- .../test_optimizer_stable_diffusion.py | 148 ++++++++++++++++++ 2 files changed, 156 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index cdb6518be000f..52d332848357f 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -297,6 +297,7 @@ def _optimize_sd_pipeline( f"--force_fp32_ops shall be in the format of module:operator like unet:Attention, got {fp32_operator}" ) + op_counters = {} for name, model_type in model_type_mapping.items(): onnx_model_path = source_dir / name / "model.onnx" if not os.path.exists(onnx_model_path): @@ -391,11 +392,13 @@ def _optimize_sd_pipeline( m = model_type_class_mapping[model_type](model) m.get_operator_statistics() - m.get_fused_operator_statistics() + op_counters[name] = m.get_fused_operator_statistics() m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format) logger.info("%s is optimized", name) logger.info("*" * 20) + return op_counters + def _copy_extra_directory(source_dir: Path, target_dir: Path, model_list: List[str]): """Copy extra directory that does not have onnx model @@ -463,7 +466,7 @@ def optimize_stable_diffusion_pipeline( _copy_extra_directory(source_dir, target_dir, model_list) - _optimize_sd_pipeline( + return _optimize_sd_pipeline( source_dir, target_dir, pipeline_type, @@ -571,7 +574,9 @@ def main(argv: Optional[List[str]] = None): args = parse_arguments(argv) logger.info("Arguments: %s", str(args)) - optimize_stable_diffusion_pipeline( + + # Return op counters for testing purpose. 
+ return optimize_stable_diffusion_pipeline( args.input, args.output, args.overwrite, args.use_external_data_format, args.float16, args.inspect, args ) diff --git a/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py b/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py index dca250f39fae2..bc4b2f9e63027 100644 --- a/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py +++ b/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py @@ -29,6 +29,8 @@ TINY_MODELS = { "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", + "stable-diffusion-3": "optimum-internal-testing/tiny-random-stable-diffusion-3", + "flux": "optimum-internal-testing/tiny-random-flux", } @@ -267,5 +269,151 @@ def test_optimize_sdxl_fp16(self): self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) +class TestSD3FluxOptimization(unittest.TestCase): + def optimize_sd3_or_flux( + self, model_name, export_onnx_dir, optimized_onnx_dir, expected_op_counters, is_float16, atol + ): + from optimum.onnxruntime import ORTPipelineForText2Image + + if os.path.exists(export_onnx_dir): + shutil.rmtree(export_onnx_dir, ignore_errors=True) + + baseline = ORTPipelineForText2Image.from_pretrained(model_name, export=True, provider="CUDAExecutionProvider") + if not os.path.exists(export_onnx_dir): + baseline.save_pretrained(export_onnx_dir) + + argv = [ + "--input", + export_onnx_dir, + "--output", + optimized_onnx_dir, + "--overwrite", + "--disable_group_norm", + "--disable_bias_splitgelu", + ] + + if is_float16: + argv.append("--float16") + + op_counters = optimize_stable_diffusion(argv) + + for name in expected_op_counters: + self.assertTrue(name in op_counters) + for op, count in expected_op_counters[name].items(): + self.assertTrue(op in op_counters[name]) + self.assertEqual(op_counters[name][op], count) + + treatment = ORTPipelineForText2Image.from_pretrained(optimized_onnx_dir, provider="CUDAExecutionProvider") + batch_size, num_images_per_prompt, height, width = 1, 1, 64, 64 + inputs = { + "prompt": ["starry night by van gogh"] * batch_size, + "num_inference_steps": 3, + "num_images_per_prompt": num_images_per_prompt, + "height": height, + "width": width, + "output_type": "np", + } + + seed = 123 + np.random.seed(seed) + import torch + + baseline_outputs = baseline(**inputs, generator=torch.Generator(device="cuda").manual_seed(seed)) + + np.random.seed(seed) + treatment_outputs = treatment(**inputs, generator=torch.Generator(device="cuda").manual_seed(seed)) + + self.assertTrue(np.allclose(baseline_outputs.images[0], treatment_outputs.images[0], atol=atol)) + + @pytest.mark.slow + def test_sd3(self): + """This tests optimization of stable diffusion 3 pipeline""" + model_name = TINY_MODELS["stable-diffusion-3"] + + expected_op_counters = { + "transformer": { + "FastGelu": 3, + "MultiHeadAttention": 2, + "LayerNormalization": 8, + "SimplifiedLayerNormalization": 0, + }, + "vae_encoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 17}, + "vae_decoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 25}, + "text_encoder": { + "Attention": 2, + "Gelu": 0, + "LayerNormalization": 1, + "QuickGelu": 2, + "SkipLayerNormalization": 4, + }, + "text_encoder_2": { + "Attention": 2, + "Gelu": 0, + "LayerNormalization": 1, + "QuickGelu": 0, + "SkipLayerNormalization": 4, + }, + "text_encoder_3": { + "Attention": 2, + 
"MultiHeadAttention": 0, + "Gelu": 0, + "FastGelu": 2, + "BiasGelu": 0, + "GemmFastGelu": 0, + "LayerNormalization": 0, + "SimplifiedLayerNormalization": 2, + "SkipLayerNormalization": 0, + "SkipSimplifiedLayerNormalization": 3, + }, + } + + export_onnx_dir = "tiny-random-stable-diffusion-3" + optimized_onnx_dir = "tiny-random-stable-diffusion-3-optimized-fp32" + self.optimize_sd3_or_flux( + model_name, export_onnx_dir, optimized_onnx_dir, expected_op_counters, is_float16=False, atol=5e-3 + ) + + optimized_onnx_dir = "tiny-random-stable-diffusion-3-optimized-fp16" + self.optimize_sd3_or_flux( + model_name, export_onnx_dir, optimized_onnx_dir, expected_op_counters, is_float16=True, atol=5e-2 + ) + + @pytest.mark.slow + def test_flux(self): + """This tests optimization of flux pipeline""" + model_name = TINY_MODELS["flux"] + + expected_op_counters = { + "transformer": { + "FastGelu": 3, + "MultiHeadAttention": 2, + "LayerNormalization": 6, + "SimplifiedLayerNormalization": 6, + }, + "vae_encoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 8}, + "vae_decoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 10}, + "text_encoder": { + "Attention": 5, + "Gelu": 0, + "LayerNormalization": 1, + "QuickGelu": 0, + "SkipLayerNormalization": 10, + }, + # The tiny flux uses clip, but FLUX.1-dev uses t5, so we skip op count verification for text_encoder_2. + "text_encoder_2": {}, + } + + export_onnx_dir = "tiny-random-flux" + optimized_onnx_dir = "tiny-random-flux-optimized-fp32" + self.optimize_sd3_or_flux( + model_name, export_onnx_dir, optimized_onnx_dir, expected_op_counters, is_float16=False, atol=1e-3 + ) + + optimized_onnx_dir = "tiny-random-flux-optimized-fp16" + self.optimize_sd3_or_flux( + model_name, export_onnx_dir, optimized_onnx_dir, expected_op_counters, is_float16=True, atol=5e-2 + ) + + if __name__ == "__main__": unittest.main() From fd227bb3367924c859f783b35c25c0f9fbe80fc6 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sun, 12 Jan 2025 08:21:01 +0000 Subject: [PATCH 25/26] update tests --- .../tools/transformers/fusion_group_norm.py | 4 +- .../tools/transformers/onnx_model_clip.py | 1 + .../test_optimizer_stable_diffusion.py | 340 +++++++++--------- 3 files changed, 166 insertions(+), 179 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index c718d2c27e015..c9bf52234d696 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -84,6 +84,7 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): instance_norm_scale = self.model.get_constant_value(instance_norm.input[1]) if instance_norm_scale is None or len(instance_norm_scale.shape) != 1: return + num_groups = int(instance_norm_scale.shape[0]) instance_norm_bias = self.model.get_constant_value(instance_norm.input[2]) if instance_norm_bias is None or instance_norm_scale.shape != instance_norm_scale.shape: @@ -156,7 +157,8 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): ) new_node.attribute.extend(instance_norm.attribute) - new_node.attribute.extend([helper.make_attribute("groups", 32)]) + + new_node.attribute.extend([helper.make_attribute("groups", num_groups)]) new_node.attribute.extend([helper.make_attribute("activation", 1 if has_swish_activation else 0)]) if not self.channels_last: diff --git a/onnxruntime/python/tools/transformers/onnx_model_clip.py 
b/onnxruntime/python/tools/transformers/onnx_model_clip.py index 388d058c7856c..725be3c762e5a 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_clip.py +++ b/onnxruntime/python/tools/transformers/onnx_model_clip.py @@ -27,6 +27,7 @@ def get_fused_operator_statistics(self): "Gelu", "LayerNormalization", "QuickGelu", + "BiasGelu", "SkipLayerNormalization", ] for op in ops: diff --git a/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py b/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py index bc4b2f9e63027..692382a12da9f 100644 --- a/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py +++ b/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py @@ -30,7 +30,7 @@ "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", "stable-diffusion-3": "optimum-internal-testing/tiny-random-stable-diffusion-3", - "flux": "optimum-internal-testing/tiny-random-flux", + "flux": "tlwu/tiny-random-flux", } @@ -116,162 +116,17 @@ def test_clip_sd(self): float16=True, ) - @pytest.mark.slow - def test_clip_sdxl(self): - save_directory = "tiny-random-stable-diffusion-xl" - if os.path.exists(save_directory): - shutil.rmtree(save_directory, ignore_errors=True) - - model_type = "stable-diffusion-xl" - model_name = TINY_MODELS[model_type] - - from optimum.onnxruntime import ORTStableDiffusionXLPipeline - - base = ORTStableDiffusionXLPipeline.from_pretrained(model_name, export=True) - base.save_pretrained(save_directory) - - clip_onnx_path = os.path.join(save_directory, "text_encoder", "model.onnx") - optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder", "opt.onnx") - self.verify_clip_optimizer( - clip_onnx_path, - optimized_clip_onnx_path, - expected_counters={ - "EmbedLayerNormalization": 0, - "Attention": 5, - "SkipLayerNormalization": 10, - "LayerNormalization": 1, - "Gelu": 0, - "BiasGelu": 5, - }, - ) - - clip_onnx_path = os.path.join(save_directory, "text_encoder_2", "model.onnx") - optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder_2", "opt.onnx") - self.verify_clip_optimizer( - clip_onnx_path, - optimized_clip_onnx_path, - expected_counters={ - "EmbedLayerNormalization": 0, - "Attention": 5, - "SkipLayerNormalization": 10, - "LayerNormalization": 1, - "Gelu": 0, - "BiasGelu": 5, - }, - ) - - @pytest.mark.slow - def test_optimize_sdxl_fp32(self): - save_directory = "tiny-random-stable-diffusion-xl" - if os.path.exists(save_directory): - shutil.rmtree(save_directory, ignore_errors=True) - - model_type = "stable-diffusion-xl" - model_name = TINY_MODELS[model_type] - - from optimum.onnxruntime import ORTStableDiffusionXLPipeline - - baseline = ORTStableDiffusionXLPipeline.from_pretrained(model_name, export=True) - if not os.path.exists(save_directory): - baseline.save_pretrained(save_directory) - - batch_size, num_images_per_prompt, height, width = 2, 2, 64, 64 - latents = baseline.prepare_latents( - batch_size * num_images_per_prompt, - baseline.unet.config["in_channels"], - height, - width, - dtype=np.float32, - generator=np.random.RandomState(0), - ) - - optimized_directory = "tiny-random-stable-diffusion-xl-optimized" - argv = [ - "--input", - save_directory, - "--output", - optimized_directory, - "--disable_group_norm", - "--disable_bias_splitgelu", - "--overwrite", - ] - optimize_stable_diffusion(argv) - - treatment = ORTStableDiffusionXLPipeline.from_pretrained(optimized_directory, 
provider="CUDAExecutionProvider") - inputs = { - "prompt": ["starry night by van gogh"] * batch_size, - "num_inference_steps": 3, - "num_images_per_prompt": num_images_per_prompt, - "height": height, - "width": width, - "guidance_rescale": 0.1, - "output_type": "np", - } - - ort_outputs_1 = baseline(latents=latents, **inputs) - ort_outputs_2 = treatment(latents=latents, **inputs) - self.assertTrue(np.allclose(ort_outputs_1.images[0], ort_outputs_2.images[0], atol=1e-3)) - - @pytest.mark.slow - def test_optimize_sdxl_fp16(self): - """This tests optimized fp16 pipeline, and result is deterministic for a given seed""" - save_directory = "tiny-random-stable-diffusion-xl" - if os.path.exists(save_directory): - shutil.rmtree(save_directory, ignore_errors=True) - - model_type = "stable-diffusion-xl" - model_name = TINY_MODELS[model_type] - - from optimum.onnxruntime import ORTStableDiffusionXLPipeline - - baseline = ORTStableDiffusionXLPipeline.from_pretrained(model_name, export=True) - if not os.path.exists(save_directory): - baseline.save_pretrained(save_directory) - - optimized_directory = "tiny-random-stable-diffusion-xl-optimized-fp16" - argv = [ - "--input", - save_directory, - "--output", - optimized_directory, - "--disable_group_norm", - "--disable_bias_splitgelu", - "--float16", - "--overwrite", - ] - optimize_stable_diffusion(argv) - - fp16_pipeline = ORTStableDiffusionXLPipeline.from_pretrained( - optimized_directory, provider="CUDAExecutionProvider" - ) - batch_size, num_images_per_prompt, height, width = 1, 1, 64, 64 - inputs = { - "prompt": ["starry night by van gogh"] * batch_size, - "num_inference_steps": 3, - "num_images_per_prompt": num_images_per_prompt, - "height": height, - "width": width, - "guidance_rescale": 0.1, - "output_type": "latent", - } - - seed = 123 - np.random.seed(seed) - ort_outputs_1 = fp16_pipeline(**inputs) - - np.random.seed(seed) - ort_outputs_2 = fp16_pipeline(**inputs) - np.random.seed(seed) - ort_outputs_3 = fp16_pipeline(**inputs) - - self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) - self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) - - -class TestSD3FluxOptimization(unittest.TestCase): - def optimize_sd3_or_flux( - self, model_name, export_onnx_dir, optimized_onnx_dir, expected_op_counters, is_float16, atol +class TestStableDiffusionOrFluxPipelineOptimization(unittest.TestCase): + def verify_pipeline_optimization( + self, + model_name, + export_onnx_dir, + optimized_onnx_dir, + expected_op_counters, + is_float16, + atol, + disable_group_norm=False, ): from optimum.onnxruntime import ORTPipelineForText2Image @@ -288,26 +143,29 @@ def optimize_sd3_or_flux( "--output", optimized_onnx_dir, "--overwrite", - "--disable_group_norm", "--disable_bias_splitgelu", ] + if disable_group_norm: + argv.append("--disable_group_norm") + if is_float16: argv.append("--float16") op_counters = optimize_stable_diffusion(argv) + print(op_counters) for name in expected_op_counters: - self.assertTrue(name in op_counters) + self.assertIn(name, op_counters) for op, count in expected_op_counters[name].items(): - self.assertTrue(op in op_counters[name]) - self.assertEqual(op_counters[name][op], count) + self.assertIn(op, op_counters[name]) + self.assertEqual(op_counters[name][op], count, f"Expected {count} {op} in {name}") treatment = ORTPipelineForText2Image.from_pretrained(optimized_onnx_dir, provider="CUDAExecutionProvider") batch_size, num_images_per_prompt, height, width = 1, 1, 64, 64 inputs = { "prompt": 
["starry night by van gogh"] * batch_size, - "num_inference_steps": 3, + "num_inference_steps": 20, "num_images_per_prompt": num_images_per_prompt, "height": height, "width": width, @@ -325,6 +183,122 @@ def optimize_sd3_or_flux( self.assertTrue(np.allclose(baseline_outputs.images[0], treatment_outputs.images[0], atol=atol)) + @pytest.mark.slow + def test_sd(self): + """This tests optimization of stable diffusion 1.x pipeline""" + model_name = TINY_MODELS["stable-diffusion"] + + expected_op_counters = { + "unet": { + "Attention": 6, + "MultiHeadAttention": 6, + "LayerNormalization": 6, + "SkipLayerNormalization": 12, + "BiasSplitGelu": 0, + "GroupNorm": 0, + "SkipGroupNorm": 0, + "NhwcConv": 47, + "BiasAdd": 0, + }, + "vae_encoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 13}, + "vae_decoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 17}, + "text_encoder": { + "Attention": 5, + "Gelu": 0, + "LayerNormalization": 1, + "QuickGelu": 5, + "BiasGelu": 0, + "SkipLayerNormalization": 10, + }, + } + + export_onnx_dir = "tiny-random-sd" + optimized_onnx_dir = "tiny-random-sd-optimized-fp32" + # Disable GroupNorm due to limitation of current cuda kernel implementation. + self.verify_pipeline_optimization( + model_name, + export_onnx_dir, + optimized_onnx_dir, + expected_op_counters, + is_float16=False, + atol=5e-3, + disable_group_norm=True, + ) + + expected_op_counters["unet"].update({"Attention": 0, "MultiHeadAttention": 12}) + optimized_onnx_dir = "tiny-random-sd-optimized-fp16" + self.verify_pipeline_optimization( + model_name, + export_onnx_dir, + optimized_onnx_dir, + expected_op_counters, + is_float16=True, + atol=5e-2, + disable_group_norm=True, + ) + + @pytest.mark.slow + def test_sdxl(self): + """This tests optimization of SDXL pipeline""" + model_name = TINY_MODELS["stable-diffusion-xl"] + + expected_op_counters = { + "unet": { + "Attention": 12, + "MultiHeadAttention": 12, + "LayerNormalization": 6, + "SkipLayerNormalization": 30, + "BiasSplitGelu": 0, + "GroupNorm": 0, + "SkipGroupNorm": 0, + "NhwcConv": 35, + "BiasAdd": 0, + }, + "vae_encoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 13}, + "vae_decoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 17}, + "text_encoder": { + "Attention": 5, + "Gelu": 0, + "LayerNormalization": 1, + "QuickGelu": 0, + "BiasGelu": 5, + "SkipLayerNormalization": 10, + }, + "text_encoder_2": { + "Attention": 5, + "Gelu": 0, + "LayerNormalization": 1, + "QuickGelu": 0, + "BiasGelu": 5, + "SkipLayerNormalization": 10, + }, + } + + export_onnx_dir = "tiny-random-sdxl" + optimized_onnx_dir = "tiny-random-sdxl-optimized-fp32" + # Disable GroupNorm due to limitation of current cuda kernel implementation. 
+ self.verify_pipeline_optimization( + model_name, + export_onnx_dir, + optimized_onnx_dir, + expected_op_counters, + is_float16=False, + atol=5e-3, + disable_group_norm=True, + ) + + expected_op_counters["unet"].update({"Attention": 0, "MultiHeadAttention": 24}) + optimized_onnx_dir = "tiny-random-sdxl-optimized-fp16" + self.verify_pipeline_optimization( + model_name, + export_onnx_dir, + optimized_onnx_dir, + expected_op_counters, + is_float16=True, + atol=5e-2, + disable_group_norm=True, + ) + @pytest.mark.slow def test_sd3(self): """This tests optimization of stable diffusion 3 pipeline""" @@ -337,8 +311,8 @@ def test_sd3(self): "LayerNormalization": 8, "SimplifiedLayerNormalization": 0, }, - "vae_encoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 17}, - "vae_decoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 25}, + "vae_encoder": {"Attention": 0, "GroupNorm": 10, "SkipGroupNorm": 3, "NhwcConv": 17}, + "vae_decoder": {"Attention": 0, "GroupNorm": 14, "SkipGroupNorm": 7, "NhwcConv": 25}, "text_encoder": { "Attention": 2, "Gelu": 0, @@ -369,12 +343,12 @@ def test_sd3(self): export_onnx_dir = "tiny-random-stable-diffusion-3" optimized_onnx_dir = "tiny-random-stable-diffusion-3-optimized-fp32" - self.optimize_sd3_or_flux( + self.verify_pipeline_optimization( model_name, export_onnx_dir, optimized_onnx_dir, expected_op_counters, is_float16=False, atol=5e-3 ) optimized_onnx_dir = "tiny-random-stable-diffusion-3-optimized-fp16" - self.optimize_sd3_or_flux( + self.verify_pipeline_optimization( model_name, export_onnx_dir, optimized_onnx_dir, expected_op_counters, is_float16=True, atol=5e-2 ) @@ -385,32 +359,42 @@ def test_flux(self): expected_op_counters = { "transformer": { - "FastGelu": 3, - "MultiHeadAttention": 2, - "LayerNormalization": 6, - "SimplifiedLayerNormalization": 6, + "FastGelu": 8, + "MultiHeadAttention": 6, + "LayerNormalization": 13, + "SimplifiedLayerNormalization": 16, }, - "vae_encoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 8}, - "vae_decoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 10}, + "vae_encoder": {"Attention": 0, "GroupNorm": 10, "SkipGroupNorm": 3, "NhwcConv": 17}, + "vae_decoder": {"Attention": 0, "GroupNorm": 14, "SkipGroupNorm": 7, "NhwcConv": 25}, "text_encoder": { - "Attention": 5, + "Attention": 2, "Gelu": 0, "LayerNormalization": 1, - "QuickGelu": 0, - "SkipLayerNormalization": 10, + "QuickGelu": 2, + "SkipLayerNormalization": 4, + }, + "text_encoder_2": { + "Attention": 2, + "MultiHeadAttention": 0, + "Gelu": 0, + "FastGelu": 2, + "BiasGelu": 0, + "GemmFastGelu": 0, + "LayerNormalization": 0, + "SimplifiedLayerNormalization": 2, + "SkipLayerNormalization": 0, + "SkipSimplifiedLayerNormalization": 3, }, - # The tiny flux uses clip, but FLUX.1-dev uses t5, so we skip op count verification for text_encoder_2. 
- "text_encoder_2": {}, } export_onnx_dir = "tiny-random-flux" optimized_onnx_dir = "tiny-random-flux-optimized-fp32" - self.optimize_sd3_or_flux( + self.verify_pipeline_optimization( model_name, export_onnx_dir, optimized_onnx_dir, expected_op_counters, is_float16=False, atol=1e-3 ) optimized_onnx_dir = "tiny-random-flux-optimized-fp16" - self.optimize_sd3_or_flux( + self.verify_pipeline_optimization( model_name, export_onnx_dir, optimized_onnx_dir, expected_op_counters, is_float16=True, atol=5e-2 ) From 87bd3ecc79405bbcdd321c7f52505828482826e9 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 14 Jan 2025 06:09:39 +0000 Subject: [PATCH 26/26] undo some change (move to another PR) --- .../framework/print_tensor_statistics_utils.h | 2 +- .../python/tools/transformers/float16.py | 69 ++++--------------- 2 files changed, 15 insertions(+), 56 deletions(-) diff --git a/onnxruntime/core/framework/print_tensor_statistics_utils.h b/onnxruntime/core/framework/print_tensor_statistics_utils.h index e4bb8b917d2b4..65360674e88d0 100644 --- a/onnxruntime/core/framework/print_tensor_statistics_utils.h +++ b/onnxruntime/core/framework/print_tensor_statistics_utils.h @@ -30,7 +30,7 @@ void PrintFloatStats(const T* data, size_t count) { size_t zero = 0; size_t subnormal = 0; for (size_t i = 0; i < count; i++) { - switch (my_fpclassify(data[i])) { + switch (my_fpclassify(*data)) { case FP_INFINITE: inf++; break; diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index c87f08b8c07fa..74adc951c4aa3 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -144,8 +144,6 @@ def make_value_info_from_tensor(tensor): "Upsample", ] -# Some operators do not support bfloat16 in CUDA. This is not a full list, just some common operators in transformers. -BF16_OP_BLACK_LIST = ["SkipSimplifiedLayerNormalization", "Attention", "MultiHeadAttention"] # Some operators has data type fixed as float for some inputs. Key is op_type, value is list of input indices # Note that DirectML allows float16 gamma and beta in GroupNorm. Use force_fp16_inputs parameter could overwrite this. @@ -157,19 +155,14 @@ class InitializerTracker: def __init__(self, initializer: TensorProto): self.initializer = initializer - self.bf16_nodes = [] self.fp32_nodes = [] self.fp16_nodes = [] - def add_node(self, node: NodeProto, dtype: int): - if dtype == TensorProto.FLOAT: + def add_node(self, node: NodeProto, is_node_blocked): + if is_node_blocked: self.fp32_nodes.append(node) - elif dtype == TensorProto.BFLOAT16: - self.bf16_nodes.append(node) - elif dtype == TensorProto.FLOAT16: - self.fp16_nodes.append(node) else: - raise ValueError("Invalid dtype") + self.fp16_nodes.append(node) def convert_float_to_float16( @@ -202,9 +195,6 @@ def convert_float_to_float16( Default to false, which will convert only the one needed to avoid precision loss. force_fp16_inputs(Dict[str, List[int]]): Force the conversion of the inputs of some operators to float16, even if this script's preference it to keep them in float32. - use_bfloat16_as_blocked_nodes_dtype(bool): use bfloat16 as the data type for blocked nodes. Default to False. - If the node does not support bfloat16, it will remain in float. - Raises: ValueError: input type is not ModelProto. 
@@ -344,19 +334,11 @@ def convert_float_to_float16( for i, input_name in enumerate(n.input): if input_name in fp32_initializers: # For Resize/GroupNorm, only the first input can be float16 - if i in ALWAYS_FLOAT_INPUTS.get(n.op_type, []) and i not in force_fp16_inputs_dict.get( - n.op_type, [] - ): - dtype = TensorProto.FLOAT - elif is_node_blocked: - dtype = ( - TensorProto.BFLOAT16 - if (use_bfloat16_as_blocked_nodes_dtype and n.op_type not in BF16_OP_BLACK_LIST) - else TensorProto.FLOAT - ) - else: - dtype = TensorProto.FLOAT16 - fp32_initializers[input_name].add_node(n, dtype) + use_fp32_weight = is_node_blocked or ( + i in ALWAYS_FLOAT_INPUTS.get(n.op_type, []) + and i not in force_fp16_inputs_dict.get(n.op_type, []) + ) + fp32_initializers[input_name].add_node(n, use_fp32_weight) if is_node_blocked: node_list.append(n) @@ -423,21 +405,15 @@ def convert_float_to_float16( queue = next_level - initializers_to_be_casted_to_bf16: Dict[str, TensorProto] = {} for value in fp32_initializers.values(): # By default, to avoid precision loss, do not convert an initializer to fp16 when it is used only by fp32 nodes. if force_fp16_initializers or value.fp16_nodes: value.initializer = convert_tensor_float_to_float16(value.initializer, min_positive_val, max_finite_val) value_info_list.append(make_value_info_from_tensor(value.initializer)) - if (value.fp32_nodes or value.bf16_nodes) and not force_fp16_initializers: + if value.fp32_nodes and not force_fp16_initializers: logger.info( - f"initializer is used by both fp32/bf16 and fp16 nodes. Consider add these nodes to block list:{value.fp16_nodes}" + f"initializer is used by both fp32 and fp16 nodes. Consider add these nodes to block list:{value.fp16_nodes}" ) - elif value.bf16_nodes: - # If float initializer is only used by bfloat16 nodes, need to convert it to bfloat16. - # However, numpy does not support bfloat16, so we will add a Cast node to conver it later. - initializers_to_be_casted_to_bf16[value.initializer.name] = value.initializer - continue # Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs. for node in mixed_float_type_node_list: @@ -460,16 +436,14 @@ def convert_float_to_float16( node.input[i] = output_name break + accuracy_type = TensorProto.BFLOAT16 if use_bfloat16_as_blocked_nodes_dtype else TensorProto.FLOAT # process the nodes in block list that doesn't support tensor(float16) for node in node_list: # if input's name is in the value_info_list meaning input is tensor(float16) type, - # insert a float16 to target type (float or bfloat16) Cast node before the node, + # insert a float16 to float Cast node before the node, # change current node's input name and create new value_info for the new name - use_bf16 = use_bfloat16_as_blocked_nodes_dtype and node.op_type not in BF16_OP_BLACK_LIST - accuracy_type = TensorProto.BFLOAT16 if use_bf16 else TensorProto.FLOAT for i in range(len(node.input)): input_name = node.input[i] - is_input_converted = False for value_info in value_info_list: if input_name == value_info.name: # create new value_info for current node's new input name @@ -484,24 +458,9 @@ def convert_float_to_float16( model.graph.node.extend(new_node) # change current node's input name node.input[i] = output_name - is_input_converted = True break - - # For bfloat16 nodes, we need to convert float initializers to bfloat16. 
- if (not is_input_converted) and use_bf16 and (input_name in initializers_to_be_casted_to_bf16): - output_name = node.name + "_input_cast_" + str(i) - value_info = helper.make_tensor_value_info( - name=output_name, elem_type=accuracy_type, shape=initializers_to_be_casted_to_bf16[input_name].dims - ) - new_value_info = model.graph.value_info.add() - new_value_info.CopyFrom(value_info) - node_name = node.name + "_input_cast" + str(i) - new_node = [helper.make_node("Cast", [input_name], [output_name], to=accuracy_type, name=node_name)] - model.graph.node.extend(new_node) - node.input[i] = output_name - - # if output's name is in the value_info_list meaning output is tensor(float16) type, insert a Cast (to float16) - # node after it, change current node's output name and create new value_info for the new name. + # if output's name is in the value_info_list meaning output is tensor(float16) type, insert a float to + # float16 Cast node after the node, change current node's output name and create new value_info for the new name for i in range(len(node.output)): output = node.output[i] for value_info in value_info_list: