From 977bb295deea9b7a5ba08c4bbdc1c898fb370394 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Mon, 13 Nov 2023 20:12:20 +0800 Subject: [PATCH 1/5] simplify stable_diffusion_pipeline_compiler --- README.md | 2 + .../stable_diffusion_pipeline_compiler.py | 222 +++++++----------- sfast/jit/trace_helper.py | 13 +- ...test_stable_diffusion_pipeline_compiler.py | 40 ++-- 4 files changed, 115 insertions(+), 162 deletions(-) diff --git a/README.md b/README.md index f3f56f6..8a1204d 100644 --- a/README.md +++ b/README.md @@ -261,6 +261,8 @@ The following code assumes you have already load a LoRA and compiled the model, and you want to switch to another LoRA. ```python +# load_state_dict with assign=True requires torch >= 2.1.0 + def update_state_dict(dst, src): for key, value in src.items(): # Do inplace copy. diff --git a/sfast/compilers/stable_diffusion_pipeline_compiler.py b/sfast/compilers/stable_diffusion_pipeline_compiler.py index 6b95466..452666c 100644 --- a/sfast/compilers/stable_diffusion_pipeline_compiler.py +++ b/sfast/compilers/stable_diffusion_pipeline_compiler.py @@ -6,7 +6,7 @@ import torch import sfast from sfast.jit import passes -from sfast.jit.trace_helper import lazy_trace, to_module +from sfast.jit.trace_helper import lazy_trace from sfast.jit import utils as jit_utils from sfast.cuda.graphs import make_dynamic_graphed_callable from sfast.utils import gpu_device @@ -24,159 +24,117 @@ class Default: ) enable_jit: bool = True enable_jit_freeze: bool = True + preserve_parameters: bool = True enable_cnn_optimization: bool = True prefer_lowp_gemm: bool = True enable_xformers: bool = False enable_cuda_graph: bool = False enable_triton: bool = False - enable_quantization: bool = False trace_scheduler: bool = False def compile(m, config): - with torch.no_grad(): - enable_cuda_graph = config.enable_cuda_graph and m.device.type == "cuda" + enable_cuda_graph = config.enable_cuda_graph and m.device.type == 'cuda' - scheduler = m.scheduler - scheduler._set_timesteps = scheduler.set_timesteps + scheduler = m.scheduler + scheduler._set_timesteps = scheduler.set_timesteps - def set_timesteps(self, num_timesteps: int, device: Union[str, torch.device]): - return self._set_timesteps(num_timesteps, device=torch.device("cpu")) + def set_timesteps(self, num_timesteps: int, device: Union[str, torch.device]): + return self._set_timesteps(num_timesteps, device=torch.device('cpu')) - scheduler.set_timesteps = set_timesteps.__get__(scheduler) - - if config.enable_xformers: - if config.enable_jit: - from sfast.utils.xformers_attention import ( - xformers_memory_efficient_attention, - ) - from xformers import ops - - ops.memory_efficient_attention = xformers_memory_efficient_attention - - m.enable_xformers_memory_efficient_attention() - - if config.memory_format == torch.channels_last: - m.unet.to(memory_format=torch.channels_last) - m.vae.to(memory_format=torch.channels_last) - if hasattr(m, "controlnet"): - m.controlnet.to(memory_format=torch.channels_last) - - if config.enable_quantization: - m.unet = torch.quantization.quantize_dynamic( - m.unet, {torch.nn.Linear}, dtype=torch.qint8, inplace=True - ) + scheduler.set_timesteps = set_timesteps.__get__(scheduler) + if config.enable_xformers: if config.enable_jit: - modify_model = functools.partial( - _modify_model, - enable_cnn_optimization=config.enable_cnn_optimization, - prefer_lowp_gemm=config.prefer_lowp_gemm, - enable_triton=config.enable_triton, - memory_format=config.memory_format, + from sfast.utils.xformers_attention import ( + 
xformers_memory_efficient_attention, ) + from xformers import ops + + ops.memory_efficient_attention = xformers_memory_efficient_attention + m.enable_xformers_memory_efficient_attention() + + if config.memory_format == torch.channels_last: + m.unet.to(memory_format=torch.channels_last) + m.vae.to(memory_format=torch.channels_last) + if hasattr(m, 'controlnet'): + m.controlnet.to(memory_format=torch.channels_last) + + if config.enable_jit: + modify_model = functools.partial( + _modify_model, + enable_cnn_optimization=config.enable_cnn_optimization, + prefer_lowp_gemm=config.prefer_lowp_gemm, + enable_triton=config.enable_triton, + memory_format=config.memory_format, + ) - def ts_compiler( - m, - call_helper, - inputs, - kwarg_inputs, - freeze=False, - enable_cuda_graph=False, - ): - with torch.jit.optimized_execution(True): - if freeze: - # raw freeze causes Tensor reference leak - # because the constant Tensors in the GraphFunction of - # the compilation unit are never freed. - m = jit_utils.better_freeze( + def ts_compiler( + m, + call_helper, + inputs, + kwarg_inputs, + freeze=False, + ): + with torch.jit.optimized_execution(True): + if freeze: + # raw freeze causes Tensor reference leak + # because the constant Tensors in the GraphFunction of + # the compilation unit are never freed. + m = jit_utils.better_freeze( m, - # preserve_parameters=True is probably not needed - # for load_state_dict() to work, because the - # traced graph and CUDA graph shares the same underlying - # data (pointer) for the parameters. - # preserve_parameters=False, + preserve_parameters=config.preserve_parameters, ) - modify_model(m) - - if enable_cuda_graph: - m = make_dynamic_graphed_callable(m) - return m - - lazy_trace_ = functools.partial( - lazy_trace, - ts_compiler=functools.partial( - ts_compiler, - freeze=config.enable_jit_freeze, - ), - check_trace=False, - strict=False, - ) - - m.text_encoder.forward = lazy_trace_(to_module(m.text_encoder.forward)) - unet_forward = lazy_trace( - to_module(m.unet.forward), - ts_compiler=functools.partial( - ts_compiler, - freeze=config.enable_jit_freeze, - enable_cuda_graph=enable_cuda_graph, - ), - check_trace=False, - strict=False, - ) + modify_model(m) + + return m + + lazy_trace_ = functools.partial( + lazy_trace, + ts_compiler=functools.partial( + ts_compiler, + freeze=config.enable_jit_freeze, + ), + check_trace=False, + strict=False, + ) - @functools.wraps(m.unet.forward) - def unet_forward_wrapper(sample, t, *args, **kwargs): - t = t.to(device=sample.device) - return unet_forward(sample, t, *args, **kwargs) - - m.unet.forward = unet_forward_wrapper - - if ( - not packaging.version.parse("2.0.0") - <= packaging.version.parse(torch.__version__) - < packaging.version.parse("2.1.0") - ): - """ - Weird bug in PyTorch 2.0.x - - RuntimeError: shape '[512, 512, 64, 64]' is invalid for input of size 2097152 - - When executing AttnProcessor in TorchScript - """ - m.vae.decode = lazy_trace_(to_module(m.vae.decode)) - # For img2img - m.vae.encoder.forward = lazy_trace_(to_module(m.vae.encoder.forward)) - m.vae.quant_conv.forward = lazy_trace_( - to_module(m.vae.quant_conv.forward) - ) - - if config.trace_scheduler: - m.scheduler.scale_model_input = lazy_trace_( - to_module(m.scheduler.scale_model_input) - ) - m.scheduler.step = lazy_trace_(to_module(m.scheduler.step)) - - if hasattr(m, "controlnet"): - controlnet_forward = lazy_trace( - to_module(m.controlnet.forward), - ts_compiler=functools.partial( - ts_compiler, - freeze=False, - enable_cuda_graph=enable_cuda_graph, - 
), - check_trace=False, - strict=False, - ) - - @functools.wraps(m.controlnet.forward) - def controlnet_forward_wrapper(sample, t, *args, **kwargs): + m.text_encoder.forward = lazy_trace_(m.text_encoder.forward) + m.unet.forward = lazy_trace_(m.unet.forward) + if hasattr(m, 'controlnet'): + controlnet.forward = lazy_trace_(controlnet.forward) + if ( + not packaging.version.parse('2.0.0') + <= packaging.version.parse(torch.__version__) + < packaging.version.parse('2.1.0') + ): + """ + Weird bug in PyTorch 2.0.x + + RuntimeError: shape '[512, 512, 64, 64]' is invalid for input of size 2097152 + + When executing AttnProcessor in TorchScript + """ + m.vae.decode = lazy_trace_(m.vae.decode) + # For img2img + m.vae.encoder.forward = lazy_trace_(m.vae.encoder.forward) + m.vae.quant_conv.forward = lazy_trace_(m.vae.quant_conv.forward) + if config.trace_scheduler: + m.scheduler.scale_model_input = lazy_trace_(m.scheduler.scale_model_input) + m.scheduler.step = lazy_trace_(m.scheduler.step) + + if enable_cuda_graph: + for sub_m in [m.unet] + ([m.controlnet] if hasattr(m, 'controlnet') else []): + cuda_graphed_forward = make_dynamic_graphed_callable(sub_m.forward) + + @functools.wraps(sub_m.forward) + def forward_with_timestamp(sample, t, *args, **kwargs): t = t.to(device=sample.device) - return controlnet_forward(sample, t, *args, **kwargs) + return cuda_graphed_forward(sample, t, *args, **kwargs) - m.controlnet.forward = controlnet_forward_wrapper + sub_m.forward = forward_with_timestamp - return m + return m def _modify_model( diff --git a/sfast/jit/trace_helper.py b/sfast/jit/trace_helper.py index 391c4ed..d7725e6 100644 --- a/sfast/jit/trace_helper.py +++ b/sfast/jit/trace_helper.py @@ -32,8 +32,11 @@ def lazy_trace(func, *, ts_compiler=None, **kwargs_): lock = threading.Lock() traced_modules = {} - @functools.wraps( - func.forward if isinstance(func, torch.nn.Module) else func) + name = getattr(func, '__name__', func.__class__.__name__) + wraped = func.forward if isinstance(func, torch.nn.Module) else func + module_to_be_traced = to_module(wraped) + + @functools.wraps(wraped) def wrapper(*args, **kwargs): nonlocal lock, traced_modules key = (hash_arg(args), hash_arg(kwargs)) @@ -42,11 +45,9 @@ def wrapper(*args, **kwargs): with lock: traced_module = traced_modules.get(key) if traced_module is None: - logger.info( - f'Tracing {getattr(func, "__name__", func.__class__.__name__)}' - ) + logger.info(f'Tracing {name}') traced_m, call_helper = trace_with_kwargs( - func, args, kwargs, **kwargs_) + module_to_be_traced, args, kwargs, **kwargs_) if ts_compiler is not None: traced_m = ts_compiler(traced_m, call_helper, args, kwargs) diff --git a/tests/compilers/test_stable_diffusion_pipeline_compiler.py b/tests/compilers/test_stable_diffusion_pipeline_compiler.py index fd0c167..1285b79 100644 --- a/tests/compilers/test_stable_diffusion_pipeline_compiler.py +++ b/tests/compilers/test_stable_diffusion_pipeline_compiler.py @@ -2,6 +2,7 @@ import logging import functools +import packaging.version import os import glob import cv2 @@ -19,7 +20,7 @@ basic_kwarg_inputs = dict( - prompt="(masterpiece:1,2), best quality, masterpiece, best detail face, realistic, unreal engine, a beautiful girl", + prompt='(masterpiece:1,2), best quality, masterpiece, best detail face, realistic, unreal engine, a beautiful girl', height=512, width=512, num_inference_steps=30, @@ -31,7 +32,7 @@ def display_image(image): def get_images_from_path(path): - image_paths = sorted(glob.glob(os.path.join(path, "*.*"))) + image_paths = 
sorted(glob.glob(os.path.join(path, '*.*'))) images = [cv2.imread(image_path, cv2.IMREAD_COLOR) for image_path in image_paths] images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in images] return images @@ -91,15 +92,6 @@ def test_benchmark_sd15_model_with_controlnet( ) -# def test_compile_sd15_model_with_quantization(sd15_model_path, skip_comparsion=True): -# benchmark_sd_model( -# sd15_model_path, -# kwarg_inputs=basic_kwarg_inputs, -# enable_quantization=True, -# skip_comparsion=skip_comparsion, -# ) - - def test_benchmark_sd21_model(sd21_model_path, skip_comparsion=False): benchmark_sd_model( sd21_model_path, @@ -154,7 +146,6 @@ def benchmark_sd_model( scheduler_class=None, controlnet_model_path=None, enable_cuda_graph=True, - enable_quantization=False, skip_comparsion=False, lora_a_path=None, lora_b_path=None, @@ -178,7 +169,7 @@ def load_model(): controlnet_model = ControlNetModel.from_pretrained( controlnet_model_path, torch_dtype=torch.float16 ) - model_init_kwargs["controlnet"] = controlnet_model + model_init_kwargs['controlnet'] = controlnet_model model = model_class.from_pretrained( model_path, torch_dtype=torch.float16, **model_init_kwargs @@ -186,7 +177,7 @@ def load_model(): if scheduler_class is not None: model.scheduler = scheduler_class.from_config(model.scheduler.config) model.safety_checker = None - model.to(torch.device("cuda")) + model.to(torch.device('cuda')) if lora_a_path is not None: model.unet.load_attn_procs(lora_a_path) @@ -196,7 +187,7 @@ def load_model(): with AutoProfiler(0.02) as profiler, low_compute_precision(): if not skip_comparsion: - logger.info("Benchmarking StableDiffusionPipeline") + logger.info('Benchmarking StableDiffusionPipeline') model = load_model() def call_original_model(): @@ -210,12 +201,12 @@ def call_original_model(): del model - if hasattr(torch, "compile"): + if hasattr(torch, 'compile'): model = load_model() - logger.info("Benchmarking StableDiffusionPipeline with torch.compile") + logger.info('Benchmarking StableDiffusionPipeline with torch.compile') model.unet.to(memory_format=torch.channels_last) model.unet = torch.compile(model.unet) - if hasattr(model, "controlnet"): + if hasattr(model, 'controlnet'): model.controlnet.to(memory_format=torch.channels_last) model.controlnet = torch.compile(model.controlnet) @@ -230,7 +221,7 @@ def call_torch_compiled_model(): del model - # logger.info("Benchmarking compiled StableDiffusionPipeline") + # logger.info('Benchmarking compiled StableDiffusionPipeline') # config = CompilationConfig.Default() # compiled_model = compile(load_model(), config) @@ -246,22 +237,21 @@ def call_torch_compiled_model(): # del compiled_model logger.info( - "Benchmarking compiled StableDiffusionPipeline with xformers, Triton and CUDA Graph" + 'Benchmarking compiled StableDiffusionPipeline with xformers, Triton and CUDA Graph' ) config = CompilationConfig.Default() try: import xformers config.enable_xformers = True except ImportError: - logger.warning("xformers not installed, skip") + logger.warning('xformers not installed, skip') try: import triton config.enable_triton = True except ImportError: - logger.warning("triton not installed, skip") + logger.warning('triton not installed, skip') # config.trace_scheduler = True config.enable_cuda_graph = enable_cuda_graph - config.enable_quantization = enable_quantization compiled_model = compile(load_model(), config) def call_faster_compiled_model(): @@ -273,7 +263,9 @@ def call_faster_compiled_model(): output_image = 
profiler.with_cProfile(call_faster_compiled_model)() display_image(output_image) - if lora_a_path is not None and lora_b_path is not None: + if lora_a_path is not None and lora_b_path is not None and packaging.version.parse( + torch.__version__) >= packaging.version.parse('2.1.0'): + # load_state_dict with assign=True requires torch >= 2.1.0 def update_state_dict(dst, src): for key, value in src.items(): From 9eef0df1f3e69405974edc673011dbf1b7f3ba7a Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Mon, 13 Nov 2023 20:52:11 +0800 Subject: [PATCH 2/5] simplify stable_diffusion_pipeline_compiler --- .../stable_diffusion_pipeline_compiler.py | 81 +++++++++---------- 1 file changed, 37 insertions(+), 44 deletions(-) diff --git a/sfast/compilers/stable_diffusion_pipeline_compiler.py b/sfast/compilers/stable_diffusion_pipeline_compiler.py index 452666c..82a1fd4 100644 --- a/sfast/compilers/stable_diffusion_pipeline_compiler.py +++ b/sfast/compilers/stable_diffusion_pipeline_compiler.py @@ -36,14 +36,6 @@ class Default: def compile(m, config): enable_cuda_graph = config.enable_cuda_graph and m.device.type == 'cuda' - scheduler = m.scheduler - scheduler._set_timesteps = scheduler.set_timesteps - - def set_timesteps(self, num_timesteps: int, device: Union[str, torch.device]): - return self._set_timesteps(num_timesteps, device=torch.device('cpu')) - - scheduler.set_timesteps = set_timesteps.__get__(scheduler) - if config.enable_xformers: if config.enable_jit: from sfast.utils.xformers_attention import ( @@ -54,11 +46,11 @@ def set_timesteps(self, num_timesteps: int, device: Union[str, torch.device]): ops.memory_efficient_attention = xformers_memory_efficient_attention m.enable_xformers_memory_efficient_attention() - if config.memory_format == torch.channels_last: - m.unet.to(memory_format=torch.channels_last) - m.vae.to(memory_format=torch.channels_last) + if config.memory_format is not None: + m.unet.to(memory_format=config.memory_format) + m.vae.to(memory_format=config.memory_format) if hasattr(m, 'controlnet'): - m.controlnet.to(memory_format=torch.channels_last) + m.controlnet.to(memory_format=config.memory_format) if config.enable_jit: modify_model = functools.partial( @@ -69,32 +61,16 @@ def set_timesteps(self, num_timesteps: int, device: Union[str, torch.device]): memory_format=config.memory_format, ) - def ts_compiler( - m, - call_helper, - inputs, - kwarg_inputs, - freeze=False, - ): - with torch.jit.optimized_execution(True): - if freeze: - # raw freeze causes Tensor reference leak - # because the constant Tensors in the GraphFunction of - # the compilation unit are never freed. 
- m = jit_utils.better_freeze( - m, - preserve_parameters=config.preserve_parameters, - ) - modify_model(m) - - return m + ts_compiler = functools.partial( + _ts_compiler, + freeze=config.enable_jit_freeze, + preserve_parameters=config.preserve_parameters, + modify_model_fn=modify_model, + ) lazy_trace_ = functools.partial( lazy_trace, - ts_compiler=functools.partial( - ts_compiler, - freeze=config.enable_jit_freeze, - ), + ts_compiler=ts_compiler, check_trace=False, strict=False, ) @@ -125,14 +101,7 @@ def ts_compiler( if enable_cuda_graph: for sub_m in [m.unet] + ([m.controlnet] if hasattr(m, 'controlnet') else []): - cuda_graphed_forward = make_dynamic_graphed_callable(sub_m.forward) - - @functools.wraps(sub_m.forward) - def forward_with_timestamp(sample, t, *args, **kwargs): - t = t.to(device=sample.device) - return cuda_graphed_forward(sample, t, *args, **kwargs) - - sub_m.forward = forward_with_timestamp + sub_m.forward = make_dynamic_graphed_callable(sub_m.forward) return m @@ -165,7 +134,7 @@ def _modify_model( if memory_format is not None: sfast._C._jit_pass_convert_op_input_tensors( - m.graph, "aten::_convolution", indices=[0], memory_format=memory_format + m.graph, 'aten::_convolution', indices=[0], memory_format=memory_format ) if enable_cnn_optimization: @@ -174,3 +143,27 @@ def _modify_model( if prefer_lowp_gemm: passes.jit_pass_prefer_lowp_gemm(m.graph) passes.jit_pass_fuse_lowp_linear_add(m.graph) + + +def _ts_compiler( + m, + call_helper, + inputs, + kwarg_inputs, + modify_model_fn=None, + freeze=False, + preserve_parameters=False, +): + with torch.jit.optimized_execution(True): + if freeze and not getattr(m, 'training', False): + # raw freeze causes Tensor reference leak + # because the constant Tensors in the GraphFunction of + # the compilation unit are never freed. 
+ m = jit_utils.better_freeze( + m, + preserve_parameters=preserve_parameters, + ) + if modify_model_fn is not None: + modify_model_fn(m) + + return m \ No newline at end of file From d1169b5c8f197f4acbcb90817a1f4629567834b9 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Mon, 13 Nov 2023 21:09:12 +0800 Subject: [PATCH 3/5] add compile_unet --- .../stable_diffusion_pipeline_compiler.py | 100 +++++++++++------- 1 file changed, 61 insertions(+), 39 deletions(-) diff --git a/sfast/compilers/stable_diffusion_pipeline_compiler.py b/sfast/compilers/stable_diffusion_pipeline_compiler.py index 82a1fd4..9a759b6 100644 --- a/sfast/compilers/stable_diffusion_pipeline_compiler.py +++ b/sfast/compilers/stable_diffusion_pipeline_compiler.py @@ -34,51 +34,22 @@ class Default: def compile(m, config): + m.unet = compile_unet(m.unet, config) + if hasattr(m, 'controlnet'): + m.controlnet = compile_unet(m.controlnet, config) + enable_cuda_graph = config.enable_cuda_graph and m.device.type == 'cuda' if config.enable_xformers: - if config.enable_jit: - from sfast.utils.xformers_attention import ( - xformers_memory_efficient_attention, - ) - from xformers import ops - - ops.memory_efficient_attention = xformers_memory_efficient_attention - m.enable_xformers_memory_efficient_attention() + _enable_xformers(m) if config.memory_format is not None: - m.unet.to(memory_format=config.memory_format) m.vae.to(memory_format=config.memory_format) - if hasattr(m, 'controlnet'): - m.controlnet.to(memory_format=config.memory_format) if config.enable_jit: - modify_model = functools.partial( - _modify_model, - enable_cnn_optimization=config.enable_cnn_optimization, - prefer_lowp_gemm=config.prefer_lowp_gemm, - enable_triton=config.enable_triton, - memory_format=config.memory_format, - ) - - ts_compiler = functools.partial( - _ts_compiler, - freeze=config.enable_jit_freeze, - preserve_parameters=config.preserve_parameters, - modify_model_fn=modify_model, - ) - - lazy_trace_ = functools.partial( - lazy_trace, - ts_compiler=ts_compiler, - check_trace=False, - strict=False, - ) + lazy_trace_ = _build_lazy_trace(config) m.text_encoder.forward = lazy_trace_(m.text_encoder.forward) - m.unet.forward = lazy_trace_(m.unet.forward) - if hasattr(m, 'controlnet'): - controlnet.forward = lazy_trace_(controlnet.forward) if ( not packaging.version.parse('2.0.0') <= packaging.version.parse(torch.__version__) @@ -99,9 +70,24 @@ def compile(m, config): m.scheduler.scale_model_input = lazy_trace_(m.scheduler.scale_model_input) m.scheduler.step = lazy_trace_(m.scheduler.step) - if enable_cuda_graph: - for sub_m in [m.unet] + ([m.controlnet] if hasattr(m, 'controlnet') else []): - sub_m.forward = make_dynamic_graphed_callable(sub_m.forward) + return m + + +def compile_unet(m, config): + enable_cuda_graph = config.enable_cuda_graph and m.device.type == 'cuda' + + if config.enable_xformers: + _enable_xformers(m) + + if config.memory_format is not None: + m.to(memory_format=config.memory_format) + + if config.enable_jit: + lazy_trace_ = _build_lazy_trace(config) + m.forward = lazy_trace_(m.forward) + + if enable_cuda_graph: + m.forward = make_dynamic_graphed_callable(m.forward) return m @@ -166,4 +152,40 @@ def _ts_compiler( if modify_model_fn is not None: modify_model_fn(m) - return m \ No newline at end of file + return m + + +def _build_lazy_trace(config): + modify_model = functools.partial( + _modify_model, + enable_cnn_optimization=config.enable_cnn_optimization, + prefer_lowp_gemm=config.prefer_lowp_gemm, + enable_triton=config.enable_triton, + 
memory_format=config.memory_format, + ) + + ts_compiler = functools.partial( + _ts_compiler, + freeze=config.enable_jit_freeze, + preserve_parameters=config.preserve_parameters, + modify_model_fn=modify_model, + ) + + lazy_trace_ = functools.partial( + lazy_trace, + ts_compiler=ts_compiler, + check_trace=False, + strict=False, + ) + + return lazy_trace_ + + +def _enable_xformers(m): + from xformers import ops + from sfast.utils.xformers_attention import ( + xformers_memory_efficient_attention, + ) + + ops.memory_efficient_attention = xformers_memory_efficient_attention + m.enable_xformers_memory_efficient_attention() \ No newline at end of file From 5bc9fb8b1c856c29443d110e9903bfc1d3575896 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Mon, 13 Nov 2023 21:10:17 +0800 Subject: [PATCH 4/5] remove unused enable_cuda_graph --- sfast/compilers/stable_diffusion_pipeline_compiler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sfast/compilers/stable_diffusion_pipeline_compiler.py b/sfast/compilers/stable_diffusion_pipeline_compiler.py index 9a759b6..b448af4 100644 --- a/sfast/compilers/stable_diffusion_pipeline_compiler.py +++ b/sfast/compilers/stable_diffusion_pipeline_compiler.py @@ -38,8 +38,6 @@ def compile(m, config): if hasattr(m, 'controlnet'): m.controlnet = compile_unet(m.controlnet, config) - enable_cuda_graph = config.enable_cuda_graph and m.device.type == 'cuda' - if config.enable_xformers: _enable_xformers(m) From be7b5fe9da38c8ec71a5ed9baf455070226604f5 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Tue, 14 Nov 2023 09:01:52 +0800 Subject: [PATCH 5/5] refactor stable_diffusion_pipeline_compiler.py and trace_helper.py --- README.md | 5 +- .../stable_diffusion_pipeline_compiler.py | 84 ++++++++++++++----- sfast/jit/trace_helper.py | 2 + ...test_stable_diffusion_pipeline_compiler.py | 84 ++++++++++++------- 4 files changed, 124 insertions(+), 51 deletions(-) diff --git a/README.md b/README.md index 8a1204d..388bbdc 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ __NOTE__: `stable-fast` is currently only in beta stage and is prone to be buggy Performance varies very greatly across different hardware/software/platform/driver configurations. It is very hard to benchmark accurately. And preparing the environment for benchmarking is also a hard job. I have tested on some platforms before but the results may still be inaccurate. +Note that when benchmarking, the progress bar showed by `tqdm` may be inaccurate because of the asynchronous nature of CUDA. `stable-fast` is expected to work better on newer GPUs and newer CUDA versions. __On older GPUs, the performance increase might be limited.__ @@ -253,7 +254,7 @@ output_image = compiled_model(**kwarg_inputs).images[0] ### Dynamically Switch LoRA Switching LoRA dynamically is supported but you need to do some extra work. -It is possible because the compiled graph and `CUDA Graph` shares the same +It is possible because the compiled graph and `CUDA Graph` share the same underlaying data (pointers) with the original UNet model. So all you need to do is to update the original UNet model's parameters inplace. @@ -266,7 +267,7 @@ and you want to switch to another LoRA. def update_state_dict(dst, src): for key, value in src.items(): # Do inplace copy. - # As the traced forward function shares the same reference of the tensors, + # As the traced forward function shares the same underlaying data (pointers), # this modification will be reflected in the traced forward function. 
dst[key].copy_(value) diff --git a/sfast/compilers/stable_diffusion_pipeline_compiler.py b/sfast/compilers/stable_diffusion_pipeline_compiler.py index b448af4..f515719 100644 --- a/sfast/compilers/stable_diffusion_pipeline_compiler.py +++ b/sfast/compilers/stable_diffusion_pipeline_compiler.py @@ -15,13 +15,49 @@ class CompilationConfig: + @dataclass class Default: + ''' + Default compilation config + + memory_format: + channels_last if tensor core is available, otherwise contiguous_format. + On GPUs with tensor core, channels_last is faster + enable_jit: + Whether to enable JIT, most optimizations are done with JIT + enable_jit_freeze: + Whether to freeze the model after JIT tracing. + Freezing the model will enable us to optimize the model further. + preserve_parameters: + Whether to preserve parameters when freezing the model. + If True, parameters will be preserved, but the model will be a bit slower. + If False, parameters will be marked as constants, and the model will be faster. + However, if parameters are not preserved, LoRA cannot be switched dynamically. + enable_cnn_optimization: + Whether to enable CNN optimization by fusion. + prefer_lowp_gemm: + Whether to prefer low-precision GEMM and a series of fusion optimizations. + This will make the model faster, but may cause numerical issues. + enable_xformers: + Whether to enable xformers and hijack it to make it compatible with JIT tracing. + enable_cuda_graph: + Whether to enable CUDA graph. CUDA Graph will significantly speed up the model, + by reducing the overhead of CUDA kernel launch, memory allocation, etc. + However, it will also increase the memory usage. + Our implementation of CUDA graph supports dynamic shape by caching graphs of + different shapes. + enable_triton: + Whether to enable Triton generated CUDA kernels. + Triton generated CUDA kernels are faster than PyTorch's CUDA kernels. + However, Triton has a lot of bugs, and can increase the CPU overhead, + though the overhead can be reduced by enabling CUDA graph. + trace_scheduler: + Whether to trace the scheduler. 
+ ''' memory_format: torch.memory_format = ( - torch.channels_last - if gpu_device.device_has_tensor_core() - else torch.contiguous_format - ) + torch.channels_last if gpu_device.device_has_tensor_core() else + torch.contiguous_format) enable_jit: bool = True enable_jit_freeze: bool = True preserve_parameters: bool = True @@ -48,11 +84,8 @@ def compile(m, config): lazy_trace_ = _build_lazy_trace(config) m.text_encoder.forward = lazy_trace_(m.text_encoder.forward) - if ( - not packaging.version.parse('2.0.0') - <= packaging.version.parse(torch.__version__) - < packaging.version.parse('2.1.0') - ): + if (not packaging.version.parse('2.0.0') <= packaging.version.parse( + torch.__version__) < packaging.version.parse('2.1.0')): """ Weird bug in PyTorch 2.0.x @@ -65,14 +98,19 @@ def compile(m, config): m.vae.encoder.forward = lazy_trace_(m.vae.encoder.forward) m.vae.quant_conv.forward = lazy_trace_(m.vae.quant_conv.forward) if config.trace_scheduler: - m.scheduler.scale_model_input = lazy_trace_(m.scheduler.scale_model_input) + m.scheduler.scale_model_input = lazy_trace_( + m.scheduler.scale_model_input) m.scheduler.step = lazy_trace_(m.scheduler.step) return m def compile_unet(m, config): - enable_cuda_graph = config.enable_cuda_graph and m.device.type == 'cuda' + # attribute `device` is not generally available + device = m.device if hasattr(m, 'device') else torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + + enable_cuda_graph = config.enable_cuda_graph and device.type == 'cuda' if config.enable_xformers: _enable_xformers(m) @@ -118,8 +156,10 @@ def _modify_model( if memory_format is not None: sfast._C._jit_pass_convert_op_input_tensors( - m.graph, 'aten::_convolution', indices=[0], memory_format=memory_format - ) + m.graph, + 'aten::_convolution', + indices=[0], + memory_format=memory_format) if enable_cnn_optimization: passes.jit_pass_optimize_cnn(m.graph) @@ -144,9 +184,9 @@ def _ts_compiler( # because the constant Tensors in the GraphFunction of # the compilation unit are never freed. m = jit_utils.better_freeze( - m, - preserve_parameters=preserve_parameters, - ) + m, + preserve_parameters=preserve_parameters, + ) if modify_model_fn is not None: modify_model_fn(m) @@ -182,8 +222,14 @@ def _build_lazy_trace(config): def _enable_xformers(m): from xformers import ops from sfast.utils.xformers_attention import ( - xformers_memory_efficient_attention, - ) + xformers_memory_efficient_attention, ) ops.memory_efficient_attention = xformers_memory_efficient_attention - m.enable_xformers_memory_efficient_attention() \ No newline at end of file + + if hasattr(m, 'enable_xformers_memory_efficient_attention'): + m.enable_xformers_memory_efficient_attention() + else: + logger.warning( + 'enable_xformers_memory_efficient_attention() is not available.' + ' If you have enabled xformers by other means, ignore this warning.' 
+ ) diff --git a/sfast/jit/trace_helper.py b/sfast/jit/trace_helper.py index d7725e6..4b0adb4 100644 --- a/sfast/jit/trace_helper.py +++ b/sfast/jit/trace_helper.py @@ -55,6 +55,8 @@ def wrapper(*args, **kwargs): traced_modules[key] = traced_module return traced_module(*args, **kwargs) + wrapper._traced_modules = traced_modules + return wrapper diff --git a/tests/compilers/test_stable_diffusion_pipeline_compiler.py b/tests/compilers/test_stable_diffusion_pipeline_compiler.py index 1285b79..c940929 100644 --- a/tests/compilers/test_stable_diffusion_pipeline_compiler.py +++ b/tests/compilers/test_stable_diffusion_pipeline_compiler.py @@ -18,9 +18,9 @@ logger = logging.getLogger() - basic_kwarg_inputs = dict( - prompt='(masterpiece:1,2), best quality, masterpiece, best detail face, realistic, unreal engine, a beautiful girl', + prompt= + '(masterpiece:1,2), best quality, masterpiece, best detail face, realistic, unreal engine, a beautiful girl', height=512, width=512, num_inference_steps=30, @@ -33,7 +33,9 @@ def display_image(image): def get_images_from_path(path): image_paths = sorted(glob.glob(os.path.join(path, '*.*'))) - images = [cv2.imread(image_path, cv2.IMREAD_COLOR) for image_path in image_paths] + images = [ + cv2.imread(image_path, cv2.IMREAD_COLOR) for image_path in image_paths + ] images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in images] return images @@ -50,7 +52,10 @@ def test_compile_sd15_model(sd15_model_path, skip_comparsion=True): test_benchmark_sd15_model(sd15_model_path, skip_comparsion=skip_comparsion) -def test_benchmark_sd15_model_with_lora(sd15_model_path, sd15_lora_t4_path, sd15_lora_dog_path, skip_comparsion=False): +def test_benchmark_sd15_model_with_lora(sd15_model_path, + sd15_lora_t4_path, + sd15_lora_dog_path, + skip_comparsion=False): benchmark_sd_model( sd15_model_path, kwarg_inputs=basic_kwarg_inputs, @@ -59,7 +64,11 @@ def test_benchmark_sd15_model_with_lora(sd15_model_path, sd15_lora_t4_path, sd15 skip_comparsion=skip_comparsion, ) -def test_compile_sd15_model_with_lora(sd15_model_path, sd15_lora_t4_path, sd15_lora_dog_path, skip_comparsion=True): + +def test_compile_sd15_model_with_lora(sd15_model_path, + sd15_lora_t4_path, + sd15_lora_dog_path, + skip_comparsion=True): benchmark_sd_model( sd15_model_path, kwarg_inputs=basic_kwarg_inputs, @@ -69,10 +78,10 @@ def test_compile_sd15_model_with_lora(sd15_model_path, sd15_lora_t4_path, sd15_l ) -def test_benchmark_sd15_model_with_controlnet( - sd15_model_path, sd_controlnet_canny_model_path, diffusers_dog_example_path, - skip_comparsion=False -): +def test_benchmark_sd15_model_with_controlnet(sd15_model_path, + sd_controlnet_canny_model_path, + diffusers_dog_example_path, + skip_comparsion=False): from diffusers import StableDiffusionControlNetPipeline dog_image = get_images_from_path(diffusers_dog_example_path)[0] @@ -119,21 +128,21 @@ def test_compile_sdxl_model(sdxl_model_path, skip_comparsion=True): test_benchmark_sdxl_model(sdxl_model_path, skip_comparsion=skip_comparsion) -def test_compile_sd15_model_with_controlnet( - sd15_model_path, sd_controlnet_canny_model_path, diffusers_dog_example_path, - skip_comparsion=True -): - test_benchmark_sd15_model_with_controlnet( - sd15_model_path, sd_controlnet_canny_model_path, diffusers_dog_example_path, - skip_comparsion=skip_comparsion - ) +def test_compile_sd15_model_with_controlnet(sd15_model_path, + sd_controlnet_canny_model_path, + diffusers_dog_example_path, + skip_comparsion=True): + test_benchmark_sd15_model_with_controlnet(sd15_model_path, + 
sd_controlnet_canny_model_path, + diffusers_dog_example_path, + skip_comparsion=skip_comparsion) def call_model(model, inputs=None, kwarg_inputs=None): - inputs = tuple() if inputs is None else inputs() if callable(inputs) else inputs - kwarg_inputs = dict() if kwarg_inputs is None else kwarg_inputs() if callable( - kwarg_inputs - ) else kwarg_inputs + inputs = tuple() if inputs is None else inputs() if callable( + inputs) else inputs + kwarg_inputs = dict() if kwarg_inputs is None else kwarg_inputs( + ) if callable(kwarg_inputs) else kwarg_inputs torch.manual_seed(0) output_image = model(*inputs, **kwarg_inputs).images[0] return output_image @@ -167,20 +176,33 @@ def load_model(): from diffusers import ControlNetModel controlnet_model = ControlNetModel.from_pretrained( - controlnet_model_path, torch_dtype=torch.float16 - ) + controlnet_model_path, torch_dtype=torch.float16) model_init_kwargs['controlnet'] = controlnet_model - model = model_class.from_pretrained( - model_path, torch_dtype=torch.float16, **model_init_kwargs - ) + model = model_class.from_pretrained(model_path, + torch_dtype=torch.float16, + **model_init_kwargs) if scheduler_class is not None: - model.scheduler = scheduler_class.from_config(model.scheduler.config) + model.scheduler = scheduler_class.from_config( + model.scheduler.config) + model.safety_checker = None model.to(torch.device('cuda')) if lora_a_path is not None: model.unet.load_attn_procs(lora_a_path) + + # This is only for benchmarking purpose. + # Patch the scheduler to force a synchronize to make the progress bar work properly. + scheduler_step = model.scheduler.step + + def scheduler_step_(*args, **kwargs): + ret = scheduler_step(*args, **kwargs) + torch.cuda.synchronize() + return ret + + model.scheduler.step = scheduler_step_ + return model call_model_ = functools.partial(call_model, kwarg_inputs=kwarg_inputs) @@ -203,7 +225,8 @@ def call_original_model(): if hasattr(torch, 'compile'): model = load_model() - logger.info('Benchmarking StableDiffusionPipeline with torch.compile') + logger.info( + 'Benchmarking StableDiffusionPipeline with torch.compile') model.unet.to(memory_format=torch.channels_last) model.unet = torch.compile(model.unet) if hasattr(model, 'controlnet'): @@ -216,7 +239,8 @@ def call_torch_compiled_model(): for _ in range(3): call_torch_compiled_model() - output_image = profiler.with_cProfile(call_torch_compiled_model)() + output_image = profiler.with_cProfile( + call_torch_compiled_model)() display_image(output_image) del model @@ -264,7 +288,7 @@ def call_faster_compiled_model(): display_image(output_image) if lora_a_path is not None and lora_b_path is not None and packaging.version.parse( - torch.__version__) >= packaging.version.parse('2.1.0'): + torch.__version__) >= packaging.version.parse('2.1.0'): # load_state_dict with assign=True requires torch >= 2.1.0 def update_state_dict(dst, src):
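
Taken together, these patches make `preserve_parameters=True` the default and note that `load_state_dict(..., assign=True)` needs torch >= 2.1.0, which is what allows LoRA weights to be swapped into an already compiled pipeline. The sketch below shows how the pieces could fit together end to end; the model identifier and LoRA paths are placeholders, the import path is assumed to match the module touched by these patches, and the `switch_lora` helper is an illustrative assumption built around the `update_state_dict` function from the README change rather than code taken verbatim from the repository.

```python
import torch
from diffusers import StableDiffusionPipeline

# Assumed import path, matching the module modified by these patches.
from sfast.compilers.stable_diffusion_pipeline_compiler import (
    compile, CompilationConfig)

model = StableDiffusionPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5',  # placeholder model id
    torch_dtype=torch.float16)
model.safety_checker = None
model.to(torch.device('cuda'))
model.unet.load_attn_procs('/path/to/lora_a')  # placeholder LoRA path

config = CompilationConfig.Default()
config.enable_xformers = True  # requires xformers to be installed
config.enable_triton = True    # requires triton to be installed
config.enable_cuda_graph = True
# preserve_parameters is True by default; keeping it enabled is what makes
# the in-place LoRA switch below possible after freezing.

compiled_model = compile(model, config)

kwarg_inputs = dict(
    prompt='a beautiful girl',
    height=512,
    width=512,
    num_inference_steps=30,
)
output_image = compiled_model(**kwarg_inputs).images[0]


def update_state_dict(dst, src):
    for key, value in src.items():
        # In-place copy: the traced graph and the CUDA graph share the same
        # underlying tensors, so the update is visible to the compiled forward.
        dst[key].copy_(value)


def switch_lora(unet, lora_path):
    # Keep references to the parameter tensors the compiled graphs point at.
    state_dict = unet.state_dict()
    # Load the new LoRA into the UNet.
    unet.load_attn_procs(lora_path)
    # Copy the updated values into the original tensors in place.
    update_state_dict(state_dict, unet.state_dict())
    # Reattach the original tensor objects.
    # load_state_dict with assign=True requires torch >= 2.1.0.
    unet.load_state_dict(state_dict, assign=True)


switch_lora(compiled_model.unet, '/path/to/lora_b')  # placeholder LoRA path
output_image = compiled_model(**kwarg_inputs).images[0]
```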