Skip to content

Commit

Permalink
[CUDA] StableDiffusion XL demo with CUDA EP (#17997)
Browse files Browse the repository at this point in the history
Add CUDA EP to the StableDiffusion XL Demo including:
(1) Add fp16 VAE support for CUDA EP.
(2) Configuration for each model separately (For example, some models
can run with CUDA graph but some models cannot).

Some remaining works will boost performance further later:
(1) Enable CUDA Graph for Clip2 and UNet. Currently, some part of graph
is partitioned to CPU, which blocks CUDA graph.
(2) Update GroupNorm CUDA kernel for refiner. Currently, the cuda kernel
only supports limited number of channels in refiner so we shall see some
gain there if we remove the limitation.

Some extra works that are nice to have (thus lower priority):
(3) Support denoising_end to ensemble base and refiner.
(4) Support classifier free guidance (The idea is from
https://www.baseten.co/blog/sdxl-inference-in-under-2-seconds-the-ultimate-guide-to-stable-diffusion-optimiza/).


#### Performance on A100-SXM4-80GB

Example commands to test an engine built with static shape or dynamic
shape:
```
engine_name=ORT_CUDA
python demo_txt2img_xl.py --engine $engine_name "some prompt"
python demo_txt2img_xl.py --engine $engine_name --disable-cuda-graph --build-dynamic-batch --build-dynamic-shape "some prompt"
```
Engine built with dynamic shape could support different batch size (1 to
4 for TRT; 1 to 16 for CUDA) and image size (256x256 to 1024x1024).
Engine built with static shape could only support fixed batch size (1)
and image size (1024x1024).

The latency (ms) of generating an image of size 1024x1024 (sorted by
total latency):

 Engine | Base (30 Steps)* | Refiner (9 Steps) | Total Latency (ms)
-- | -- | -- | --
ORT_TRT (static shape) | 2467 | 1033 | 3501
TRT (static shape) | 2507 | 1048 | 3555
ORT_CUDA (static shape) | 2630 | 1015 | 3645
ORT_CUDA (dynamic shape) | 2639 | 1016 | 3654
TRT (dynamic shape) | 2777 | 1099 | 3876
ORT_TRT (dynamic shape) | 2890 | 1166 | 4057

\* VAE decoder is not used in Base since the output from base is latent,
which is consumed by refiner to output image.

We can see that ORT_CUDA is faster on dynamic shape, while slower in
static shape (The cause is Clip2 and UNet cannot run with CUDA Graph
right now, and we will address the issue later).

### Motivation and Context
Follow up of #17536
  • Loading branch information
tianleiwu authored Oct 18, 2023
1 parent 61f1a16 commit 59ae3fd
Show file tree
Hide file tree
Showing 14 changed files with 379 additions and 196 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ def init_pipeline(pipeline_class, pipeline_info):
base_pipeline_info = PipelineInfo(version)
demo_base = init_pipeline(Txt2ImgXLPipeline, base_pipeline_info)

refiner_pipeline_info = PipelineInfo(version, is_sd_xl_refiner=True)
refiner_pipeline_info = PipelineInfo(version, is_refiner=True)
demo_refiner = init_pipeline(Img2ImgXLPipeline, refiner_pipeline_info)

max_device_memory = max(demo_base.backend.max_device_memory(), demo_refiner.backend.max_device_memory())
Expand All @@ -887,7 +887,7 @@ def run_sd_xl_inference(prompt, negative_prompt, seed=None, warmup=False):
guidance=5.0,
warmup=warmup,
seed=seed,
return_type="latents",
return_type="latent",
)

images, time_refiner = demo_refiner.run(
Expand Down Expand Up @@ -1037,7 +1037,7 @@ def init_pipeline(pipeline_class, pipeline_info):
base_pipeline_info = PipelineInfo(version)
demo_base = init_pipeline(Txt2ImgXLPipeline, base_pipeline_info)

refiner_pipeline_info = PipelineInfo(version, is_sd_xl_refiner=True)
refiner_pipeline_info = PipelineInfo(version, is_refiner=True)
demo_refiner = init_pipeline(Img2ImgXLPipeline, refiner_pipeline_info)

demo_base.load_resources(image_height, image_width, batch_size)
Expand All @@ -1053,7 +1053,7 @@ def run_sd_xl_inference(prompt, negative_prompt, seed=None, warmup=False):
guidance=5.0,
warmup=warmup,
seed=seed,
return_type="latents",
return_type="latent",
)
images, time_refiner = demo_refiner.run(
prompt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def run_inference(warmup=False):
denoising_steps=args.denoising_steps,
guidance=args.guidance,
seed=args.seed,
return_type="images",
return_type="image",
)

if not args.disable_cuda_graph:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@
from pipeline_img2img_xl import Img2ImgXLPipeline
from pipeline_txt2img_xl import Txt2ImgXLPipeline

if __name__ == "__main__":
coloredlogs.install(fmt="%(funcName)20s: %(message)s")

def run_demo():
"""Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image."""

args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo")

prompt, negative_prompt = repeat_prompt(args)

# Recommend image size as one of those used in training (see Appendix I in https://arxiv.org/pdf/2307.01952.pdf).
image_height = args.height
image_width = args.width

Expand All @@ -44,39 +48,32 @@
init_trt_plugins()

max_batch_size = 16
if args.build_dynamic_shape or image_height > 512 or image_width > 512:
if (engine_type in [EngineType.ORT_TRT, EngineType.TRT]) and (
args.build_dynamic_shape or image_height > 512 or image_width > 512
):
max_batch_size = 4

batch_size = len(prompt)
if batch_size > max_batch_size:
raise ValueError(
f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4"
)
raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.")

base_info = PipelineInfo(args.version, use_vae_in_xl_base=not args.enable_refiner)
# No VAE decoder in base when it outputs latent instead of image.
base_info = PipelineInfo(args.version, use_vae=False)
base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size)

if args.enable_refiner:
refiner_info = PipelineInfo(args.version, is_sd_xl_refiner=True)
refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size)
refiner_info = PipelineInfo(args.version, is_refiner=True)
refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size)

if engine_type == EngineType.TRT:
max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory())
_, shared_device_memory = cudart.cudaMalloc(max_device_memory)
base.backend.activate_engines(shared_device_memory)
refiner.backend.activate_engines(shared_device_memory)

base.load_resources(image_height, image_width, batch_size)
refiner.load_resources(image_height, image_width, batch_size)
else:
if engine_type == EngineType.TRT:
max_device_memory = max(base.backend.max_device_memory(), base.backend.max_device_memory())
_, shared_device_memory = cudart.cudaMalloc(max_device_memory)
base.backend.activate_engines(shared_device_memory)
if engine_type == EngineType.TRT:
max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory())
_, shared_device_memory = cudart.cudaMalloc(max_device_memory)
base.backend.activate_engines(shared_device_memory)
refiner.backend.activate_engines(shared_device_memory)

base.load_resources(image_height, image_width, batch_size)
base.load_resources(image_height, image_width, batch_size)
refiner.load_resources(image_height, image_width, batch_size)

def run_sd_xl_inference(enable_refiner: bool, warmup=False):
def run_base_and_refiner(warmup=False):
images, time_base = base.run(
prompt,
negative_prompt,
Expand All @@ -86,44 +83,46 @@ def run_sd_xl_inference(enable_refiner: bool, warmup=False):
denoising_steps=args.denoising_steps,
guidance=args.guidance,
seed=args.seed,
return_type="latents" if enable_refiner else "images",
return_type="latent",
)

images, time_refiner = refiner.run(
prompt,
negative_prompt,
images,
image_height,
image_width,
warmup=warmup,
denoising_steps=args.denoising_steps,
guidance=args.guidance,
seed=args.seed,
)

if enable_refiner:
images, time_refiner = refiner.run(
prompt,
negative_prompt,
images,
image_height,
image_width,
warmup=warmup,
denoising_steps=args.denoising_steps,
guidance=args.guidance,
seed=args.seed,
)
return images, time_base + time_refiner
else:
return images, time_base
return images, time_base + time_refiner

if not args.disable_cuda_graph:
# inference once to get cuda graph
images, _ = run_sd_xl_inference(args.enable_refiner, warmup=True)
_, _ = run_base_and_refiner(warmup=True)

print("[I] Warming up ..")
for _ in range(args.num_warmup_runs):
images, _ = run_sd_xl_inference(args.enable_refiner, warmup=True)
_, _ = run_base_and_refiner(warmup=True)

print("[I] Running StableDiffusion XL pipeline")
if args.nvtx_profile:
cudart.cudaProfilerStart()
images, pipeline_time = run_sd_xl_inference(args.enable_refiner, warmup=False)
_, latency = run_base_and_refiner(warmup=False)
if args.nvtx_profile:
cudart.cudaProfilerStop()

base.teardown()

if args.enable_refiner:
print("|------------|--------------|")
print("| {:^10} | {:>9.2f} ms |".format("e2e", pipeline_time))
print("|------------|--------------|")
refiner.teardown()
print("|------------|--------------|")
print("| {:^10} | {:>9.2f} ms |".format("e2e", latency))
print("|------------|--------------|")
refiner.teardown()


if __name__ == "__main__":
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
run_demo()
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatte
def parse_arguments(is_xl: bool, description: str):
parser = argparse.ArgumentParser(description=description, formatter_class=RawTextArgumentDefaultsHelpFormatter)

engines = ["ORT_TRT", "TRT"] if is_xl else ["ORT_CUDA", "ORT_TRT", "TRT"]
engines = ["ORT_CUDA", "ORT_TRT", "TRT"]

parser.add_argument(
"--engine",
Expand Down Expand Up @@ -95,7 +95,7 @@ def parse_arguments(is_xl: bool, description: str):
"--denoising-steps",
type=int,
default=30 if is_xl else 50,
help="Number of denoising steps" + (" in each of base and refiner." if is_xl else "."),
help="Number of denoising steps" + (" in base." if is_xl else "."),
)

parser.add_argument(
Expand Down Expand Up @@ -158,12 +158,6 @@ def parse_arguments(is_xl: bool, description: str):
"--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources."
)

# Pipeline options
if is_xl:
parser.add_argument(
"--enable-refiner", action="store_true", help="Enable refiner and run both base and refiner pipelines."
)

args = parser.parse_args()

if (
Expand Down Expand Up @@ -203,6 +197,7 @@ def repeat_prompt(args):
raise ValueError(
f"`--negative-prompt` must be of type `str` or `str` list, but is {type(args.negative_prompt)}"
)

if len(args.negative_prompt) == 1:
negative_prompt = args.negative_prompt * len(prompt)
else:
Expand Down Expand Up @@ -236,16 +231,11 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si
engine_dir=engine_dir,
framework_model_dir=framework_model_dir,
onnx_dir=onnx_dir,
onnx_opset=args.onnx_opset,
opt_image_height=args.height,
opt_image_width=args.height,
opt_batch_size=batch_size,
force_engine_rebuild=args.force_engine_build,
device_id=torch.cuda.current_device(),
disable_cuda_graph_models=[
"clip2", # TODO: Add ArgMax cuda kernel to enable cuda graph for clip2.
"unetxl",
],
)
elif engine_type == EngineType.ORT_TRT:
# Build TensorRT EP engines and load pytorch modules
Expand Down
Loading

0 comments on commit 59ae3fd

Please sign in to comment.