diff --git a/README.md b/README.md
index f3f56f6..388bbdc 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,7 @@ __NOTE__: `stable-fast` is currently only in beta stage and is prone to be buggy
 Performance varies very greatly across different hardware/software/platform/driver configurations.
 It is very hard to benchmark accurately. And preparing the environment for benchmarking is also a hard job.
 I have tested on some platforms before but the results may still be inaccurate.
+Note that when benchmarking, the progress bar shown by `tqdm` may be inaccurate because of the asynchronous nature of CUDA.
 
 `stable-fast` is expected to work better on newer GPUs and newer CUDA versions.
 __On older GPUs, the performance increase might be limited.__
@@ -253,7 +254,7 @@ output_image = compiled_model(**kwarg_inputs).images[0]
 ### Dynamically Switch LoRA
 
 Switching LoRA dynamically is supported but you need to do some extra work.
-It is possible because the compiled graph and `CUDA Graph` shares the same
+It is possible because the compiled graph and `CUDA Graph` share the same
 underlaying data (pointers) with the original UNet model.
 So all you need to do is to update the original UNet model's parameters inplace.
 
@@ -261,10 +262,12 @@ The following code assumes you have already load a LoRA and compiled the model,
 and you want to switch to another LoRA.
 
 ```python
+# load_state_dict with assign=True requires torch >= 2.1.0
+
 def update_state_dict(dst, src):
     for key, value in src.items():
         # Do inplace copy.
-        # As the traced forward function shares the same reference of the tensors,
+        # As the traced forward function shares the same underlying data (pointers),
         # this modification will be reflected in the traced forward function.
         dst[key].copy_(value)
 
diff --git a/sfast/compilers/stable_diffusion_pipeline_compiler.py b/sfast/compilers/stable_diffusion_pipeline_compiler.py
index 6b95466..f515719 100644
--- a/sfast/compilers/stable_diffusion_pipeline_compiler.py
+++ b/sfast/compilers/stable_diffusion_pipeline_compiler.py
@@ -6,7 +6,7 @@
 import torch
 import sfast
 from sfast.jit import passes
-from sfast.jit.trace_helper import lazy_trace, to_module
+from sfast.jit.trace_helper import lazy_trace
 from sfast.jit import utils as jit_utils
 from sfast.cuda.graphs import make_dynamic_graphed_callable
 from sfast.utils import gpu_device
@@ -15,168 +15,117 @@ class CompilationConfig:
+
     @dataclass
     class Default:
+        '''
+        Default compilation config
+
+        memory_format:
+            channels_last if tensor core is available, otherwise contiguous_format.
+            On GPUs with tensor core, channels_last is faster
+        enable_jit:
+            Whether to enable JIT, most optimizations are done with JIT
+        enable_jit_freeze:
+            Whether to freeze the model after JIT tracing.
+            Freezing the model will enable us to optimize the model further.
+        preserve_parameters:
+            Whether to preserve parameters when freezing the model.
+            If True, parameters will be preserved, but the model will be a bit slower.
+            If False, parameters will be marked as constants, and the model will be faster.
+            However, if parameters are not preserved, LoRA cannot be switched dynamically.
+        enable_cnn_optimization:
+            Whether to enable CNN optimization by fusion.
+        prefer_lowp_gemm:
+            Whether to prefer low-precision GEMM and a series of fusion optimizations.
+            This will make the model faster, but may cause numerical issues.
+        enable_xformers:
+            Whether to enable xformers and hijack it to make it compatible with JIT tracing.
+        enable_cuda_graph:
+            Whether to enable CUDA graph.
CUDA Graph will significantly speed up the model, + by reducing the overhead of CUDA kernel launch, memory allocation, etc. + However, it will also increase the memory usage. + Our implementation of CUDA graph supports dynamic shape by caching graphs of + different shapes. + enable_triton: + Whether to enable Triton generated CUDA kernels. + Triton generated CUDA kernels are faster than PyTorch's CUDA kernels. + However, Triton has a lot of bugs, and can increase the CPU overhead, + though the overhead can be reduced by enabling CUDA graph. + trace_scheduler: + Whether to trace the scheduler. + ''' memory_format: torch.memory_format = ( - torch.channels_last - if gpu_device.device_has_tensor_core() - else torch.contiguous_format - ) + torch.channels_last if gpu_device.device_has_tensor_core() else + torch.contiguous_format) enable_jit: bool = True enable_jit_freeze: bool = True + preserve_parameters: bool = True enable_cnn_optimization: bool = True prefer_lowp_gemm: bool = True enable_xformers: bool = False enable_cuda_graph: bool = False enable_triton: bool = False - enable_quantization: bool = False trace_scheduler: bool = False def compile(m, config): - with torch.no_grad(): - enable_cuda_graph = config.enable_cuda_graph and m.device.type == "cuda" + m.unet = compile_unet(m.unet, config) + if hasattr(m, 'controlnet'): + m.controlnet = compile_unet(m.controlnet, config) - scheduler = m.scheduler - scheduler._set_timesteps = scheduler.set_timesteps + if config.enable_xformers: + _enable_xformers(m) - def set_timesteps(self, num_timesteps: int, device: Union[str, torch.device]): - return self._set_timesteps(num_timesteps, device=torch.device("cpu")) + if config.memory_format is not None: + m.vae.to(memory_format=config.memory_format) - scheduler.set_timesteps = set_timesteps.__get__(scheduler) + if config.enable_jit: + lazy_trace_ = _build_lazy_trace(config) - if config.enable_xformers: - if config.enable_jit: - from sfast.utils.xformers_attention import ( - xformers_memory_efficient_attention, - ) - from xformers import ops + m.text_encoder.forward = lazy_trace_(m.text_encoder.forward) + if (not packaging.version.parse('2.0.0') <= packaging.version.parse( + torch.__version__) < packaging.version.parse('2.1.0')): + """ + Weird bug in PyTorch 2.0.x - ops.memory_efficient_attention = xformers_memory_efficient_attention + RuntimeError: shape '[512, 512, 64, 64]' is invalid for input of size 2097152 - m.enable_xformers_memory_efficient_attention() + When executing AttnProcessor in TorchScript + """ + m.vae.decode = lazy_trace_(m.vae.decode) + # For img2img + m.vae.encoder.forward = lazy_trace_(m.vae.encoder.forward) + m.vae.quant_conv.forward = lazy_trace_(m.vae.quant_conv.forward) + if config.trace_scheduler: + m.scheduler.scale_model_input = lazy_trace_( + m.scheduler.scale_model_input) + m.scheduler.step = lazy_trace_(m.scheduler.step) - if config.memory_format == torch.channels_last: - m.unet.to(memory_format=torch.channels_last) - m.vae.to(memory_format=torch.channels_last) - if hasattr(m, "controlnet"): - m.controlnet.to(memory_format=torch.channels_last) + return m - if config.enable_quantization: - m.unet = torch.quantization.quantize_dynamic( - m.unet, {torch.nn.Linear}, dtype=torch.qint8, inplace=True - ) - if config.enable_jit: - modify_model = functools.partial( - _modify_model, - enable_cnn_optimization=config.enable_cnn_optimization, - prefer_lowp_gemm=config.prefer_lowp_gemm, - enable_triton=config.enable_triton, - memory_format=config.memory_format, - ) +def compile_unet(m, 
config): + # attribute `device` is not generally available + device = m.device if hasattr(m, 'device') else torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') - def ts_compiler( - m, - call_helper, - inputs, - kwarg_inputs, - freeze=False, - enable_cuda_graph=False, - ): - with torch.jit.optimized_execution(True): - if freeze: - # raw freeze causes Tensor reference leak - # because the constant Tensors in the GraphFunction of - # the compilation unit are never freed. - m = jit_utils.better_freeze( - m, - # preserve_parameters=True is probably not needed - # for load_state_dict() to work, because the - # traced graph and CUDA graph shares the same underlying - # data (pointer) for the parameters. - # preserve_parameters=False, - ) - modify_model(m) - - if enable_cuda_graph: - m = make_dynamic_graphed_callable(m) - return m - - lazy_trace_ = functools.partial( - lazy_trace, - ts_compiler=functools.partial( - ts_compiler, - freeze=config.enable_jit_freeze, - ), - check_trace=False, - strict=False, - ) + enable_cuda_graph = config.enable_cuda_graph and device.type == 'cuda' - m.text_encoder.forward = lazy_trace_(to_module(m.text_encoder.forward)) - unet_forward = lazy_trace( - to_module(m.unet.forward), - ts_compiler=functools.partial( - ts_compiler, - freeze=config.enable_jit_freeze, - enable_cuda_graph=enable_cuda_graph, - ), - check_trace=False, - strict=False, - ) + if config.enable_xformers: + _enable_xformers(m) + + if config.memory_format is not None: + m.to(memory_format=config.memory_format) + + if config.enable_jit: + lazy_trace_ = _build_lazy_trace(config) + m.forward = lazy_trace_(m.forward) - @functools.wraps(m.unet.forward) - def unet_forward_wrapper(sample, t, *args, **kwargs): - t = t.to(device=sample.device) - return unet_forward(sample, t, *args, **kwargs) - - m.unet.forward = unet_forward_wrapper - - if ( - not packaging.version.parse("2.0.0") - <= packaging.version.parse(torch.__version__) - < packaging.version.parse("2.1.0") - ): - """ - Weird bug in PyTorch 2.0.x - - RuntimeError: shape '[512, 512, 64, 64]' is invalid for input of size 2097152 - - When executing AttnProcessor in TorchScript - """ - m.vae.decode = lazy_trace_(to_module(m.vae.decode)) - # For img2img - m.vae.encoder.forward = lazy_trace_(to_module(m.vae.encoder.forward)) - m.vae.quant_conv.forward = lazy_trace_( - to_module(m.vae.quant_conv.forward) - ) - - if config.trace_scheduler: - m.scheduler.scale_model_input = lazy_trace_( - to_module(m.scheduler.scale_model_input) - ) - m.scheduler.step = lazy_trace_(to_module(m.scheduler.step)) - - if hasattr(m, "controlnet"): - controlnet_forward = lazy_trace( - to_module(m.controlnet.forward), - ts_compiler=functools.partial( - ts_compiler, - freeze=False, - enable_cuda_graph=enable_cuda_graph, - ), - check_trace=False, - strict=False, - ) - - @functools.wraps(m.controlnet.forward) - def controlnet_forward_wrapper(sample, t, *args, **kwargs): - t = t.to(device=sample.device) - return controlnet_forward(sample, t, *args, **kwargs) - - m.controlnet.forward = controlnet_forward_wrapper - - return m + if enable_cuda_graph: + m.forward = make_dynamic_graphed_callable(m.forward) + + return m def _modify_model( @@ -207,8 +156,10 @@ def _modify_model( if memory_format is not None: sfast._C._jit_pass_convert_op_input_tensors( - m.graph, "aten::_convolution", indices=[0], memory_format=memory_format - ) + m.graph, + 'aten::_convolution', + indices=[0], + memory_format=memory_format) if enable_cnn_optimization: passes.jit_pass_optimize_cnn(m.graph) @@ -216,3 
+167,69 @@ def _modify_model( if prefer_lowp_gemm: passes.jit_pass_prefer_lowp_gemm(m.graph) passes.jit_pass_fuse_lowp_linear_add(m.graph) + + +def _ts_compiler( + m, + call_helper, + inputs, + kwarg_inputs, + modify_model_fn=None, + freeze=False, + preserve_parameters=False, +): + with torch.jit.optimized_execution(True): + if freeze and not getattr(m, 'training', False): + # raw freeze causes Tensor reference leak + # because the constant Tensors in the GraphFunction of + # the compilation unit are never freed. + m = jit_utils.better_freeze( + m, + preserve_parameters=preserve_parameters, + ) + if modify_model_fn is not None: + modify_model_fn(m) + + return m + + +def _build_lazy_trace(config): + modify_model = functools.partial( + _modify_model, + enable_cnn_optimization=config.enable_cnn_optimization, + prefer_lowp_gemm=config.prefer_lowp_gemm, + enable_triton=config.enable_triton, + memory_format=config.memory_format, + ) + + ts_compiler = functools.partial( + _ts_compiler, + freeze=config.enable_jit_freeze, + preserve_parameters=config.preserve_parameters, + modify_model_fn=modify_model, + ) + + lazy_trace_ = functools.partial( + lazy_trace, + ts_compiler=ts_compiler, + check_trace=False, + strict=False, + ) + + return lazy_trace_ + + +def _enable_xformers(m): + from xformers import ops + from sfast.utils.xformers_attention import ( + xformers_memory_efficient_attention, ) + + ops.memory_efficient_attention = xformers_memory_efficient_attention + + if hasattr(m, 'enable_xformers_memory_efficient_attention'): + m.enable_xformers_memory_efficient_attention() + else: + logger.warning( + 'enable_xformers_memory_efficient_attention() is not available.' + ' If you have enabled xformers by other means, ignore this warning.' + ) diff --git a/sfast/jit/trace_helper.py b/sfast/jit/trace_helper.py index 391c4ed..4b0adb4 100644 --- a/sfast/jit/trace_helper.py +++ b/sfast/jit/trace_helper.py @@ -32,8 +32,11 @@ def lazy_trace(func, *, ts_compiler=None, **kwargs_): lock = threading.Lock() traced_modules = {} - @functools.wraps( - func.forward if isinstance(func, torch.nn.Module) else func) + name = getattr(func, '__name__', func.__class__.__name__) + wraped = func.forward if isinstance(func, torch.nn.Module) else func + module_to_be_traced = to_module(wraped) + + @functools.wraps(wraped) def wrapper(*args, **kwargs): nonlocal lock, traced_modules key = (hash_arg(args), hash_arg(kwargs)) @@ -42,11 +45,9 @@ def wrapper(*args, **kwargs): with lock: traced_module = traced_modules.get(key) if traced_module is None: - logger.info( - f'Tracing {getattr(func, "__name__", func.__class__.__name__)}' - ) + logger.info(f'Tracing {name}') traced_m, call_helper = trace_with_kwargs( - func, args, kwargs, **kwargs_) + module_to_be_traced, args, kwargs, **kwargs_) if ts_compiler is not None: traced_m = ts_compiler(traced_m, call_helper, args, kwargs) @@ -54,6 +55,8 @@ def wrapper(*args, **kwargs): traced_modules[key] = traced_module return traced_module(*args, **kwargs) + wrapper._traced_modules = traced_modules + return wrapper diff --git a/tests/compilers/test_stable_diffusion_pipeline_compiler.py b/tests/compilers/test_stable_diffusion_pipeline_compiler.py index fd0c167..c940929 100644 --- a/tests/compilers/test_stable_diffusion_pipeline_compiler.py +++ b/tests/compilers/test_stable_diffusion_pipeline_compiler.py @@ -2,6 +2,7 @@ import logging import functools +import packaging.version import os import glob import cv2 @@ -17,9 +18,9 @@ logger = logging.getLogger() - basic_kwarg_inputs = dict( - 
prompt="(masterpiece:1,2), best quality, masterpiece, best detail face, realistic, unreal engine, a beautiful girl", + prompt= + '(masterpiece:1,2), best quality, masterpiece, best detail face, realistic, unreal engine, a beautiful girl', height=512, width=512, num_inference_steps=30, @@ -31,8 +32,10 @@ def display_image(image): def get_images_from_path(path): - image_paths = sorted(glob.glob(os.path.join(path, "*.*"))) - images = [cv2.imread(image_path, cv2.IMREAD_COLOR) for image_path in image_paths] + image_paths = sorted(glob.glob(os.path.join(path, '*.*'))) + images = [ + cv2.imread(image_path, cv2.IMREAD_COLOR) for image_path in image_paths + ] images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in images] return images @@ -49,7 +52,10 @@ def test_compile_sd15_model(sd15_model_path, skip_comparsion=True): test_benchmark_sd15_model(sd15_model_path, skip_comparsion=skip_comparsion) -def test_benchmark_sd15_model_with_lora(sd15_model_path, sd15_lora_t4_path, sd15_lora_dog_path, skip_comparsion=False): +def test_benchmark_sd15_model_with_lora(sd15_model_path, + sd15_lora_t4_path, + sd15_lora_dog_path, + skip_comparsion=False): benchmark_sd_model( sd15_model_path, kwarg_inputs=basic_kwarg_inputs, @@ -58,7 +64,11 @@ def test_benchmark_sd15_model_with_lora(sd15_model_path, sd15_lora_t4_path, sd15 skip_comparsion=skip_comparsion, ) -def test_compile_sd15_model_with_lora(sd15_model_path, sd15_lora_t4_path, sd15_lora_dog_path, skip_comparsion=True): + +def test_compile_sd15_model_with_lora(sd15_model_path, + sd15_lora_t4_path, + sd15_lora_dog_path, + skip_comparsion=True): benchmark_sd_model( sd15_model_path, kwarg_inputs=basic_kwarg_inputs, @@ -68,10 +78,10 @@ def test_compile_sd15_model_with_lora(sd15_model_path, sd15_lora_t4_path, sd15_l ) -def test_benchmark_sd15_model_with_controlnet( - sd15_model_path, sd_controlnet_canny_model_path, diffusers_dog_example_path, - skip_comparsion=False -): +def test_benchmark_sd15_model_with_controlnet(sd15_model_path, + sd_controlnet_canny_model_path, + diffusers_dog_example_path, + skip_comparsion=False): from diffusers import StableDiffusionControlNetPipeline dog_image = get_images_from_path(diffusers_dog_example_path)[0] @@ -91,15 +101,6 @@ def test_benchmark_sd15_model_with_controlnet( ) -# def test_compile_sd15_model_with_quantization(sd15_model_path, skip_comparsion=True): -# benchmark_sd_model( -# sd15_model_path, -# kwarg_inputs=basic_kwarg_inputs, -# enable_quantization=True, -# skip_comparsion=skip_comparsion, -# ) - - def test_benchmark_sd21_model(sd21_model_path, skip_comparsion=False): benchmark_sd_model( sd21_model_path, @@ -127,21 +128,21 @@ def test_compile_sdxl_model(sdxl_model_path, skip_comparsion=True): test_benchmark_sdxl_model(sdxl_model_path, skip_comparsion=skip_comparsion) -def test_compile_sd15_model_with_controlnet( - sd15_model_path, sd_controlnet_canny_model_path, diffusers_dog_example_path, - skip_comparsion=True -): - test_benchmark_sd15_model_with_controlnet( - sd15_model_path, sd_controlnet_canny_model_path, diffusers_dog_example_path, - skip_comparsion=skip_comparsion - ) +def test_compile_sd15_model_with_controlnet(sd15_model_path, + sd_controlnet_canny_model_path, + diffusers_dog_example_path, + skip_comparsion=True): + test_benchmark_sd15_model_with_controlnet(sd15_model_path, + sd_controlnet_canny_model_path, + diffusers_dog_example_path, + skip_comparsion=skip_comparsion) def call_model(model, inputs=None, kwarg_inputs=None): - inputs = tuple() if inputs is None else inputs() if callable(inputs) else inputs - 
kwarg_inputs = dict() if kwarg_inputs is None else kwarg_inputs() if callable( - kwarg_inputs - ) else kwarg_inputs + inputs = tuple() if inputs is None else inputs() if callable( + inputs) else inputs + kwarg_inputs = dict() if kwarg_inputs is None else kwarg_inputs( + ) if callable(kwarg_inputs) else kwarg_inputs torch.manual_seed(0) output_image = model(*inputs, **kwarg_inputs).images[0] return output_image @@ -154,7 +155,6 @@ def benchmark_sd_model( scheduler_class=None, controlnet_model_path=None, enable_cuda_graph=True, - enable_quantization=False, skip_comparsion=False, lora_a_path=None, lora_b_path=None, @@ -176,27 +176,40 @@ def load_model(): from diffusers import ControlNetModel controlnet_model = ControlNetModel.from_pretrained( - controlnet_model_path, torch_dtype=torch.float16 - ) - model_init_kwargs["controlnet"] = controlnet_model + controlnet_model_path, torch_dtype=torch.float16) + model_init_kwargs['controlnet'] = controlnet_model - model = model_class.from_pretrained( - model_path, torch_dtype=torch.float16, **model_init_kwargs - ) + model = model_class.from_pretrained(model_path, + torch_dtype=torch.float16, + **model_init_kwargs) if scheduler_class is not None: - model.scheduler = scheduler_class.from_config(model.scheduler.config) + model.scheduler = scheduler_class.from_config( + model.scheduler.config) + model.safety_checker = None - model.to(torch.device("cuda")) + model.to(torch.device('cuda')) if lora_a_path is not None: model.unet.load_attn_procs(lora_a_path) + + # This is only for benchmarking purpose. + # Patch the scheduler to force a synchronize to make the progress bar work properly. + scheduler_step = model.scheduler.step + + def scheduler_step_(*args, **kwargs): + ret = scheduler_step(*args, **kwargs) + torch.cuda.synchronize() + return ret + + model.scheduler.step = scheduler_step_ + return model call_model_ = functools.partial(call_model, kwarg_inputs=kwarg_inputs) with AutoProfiler(0.02) as profiler, low_compute_precision(): if not skip_comparsion: - logger.info("Benchmarking StableDiffusionPipeline") + logger.info('Benchmarking StableDiffusionPipeline') model = load_model() def call_original_model(): @@ -210,12 +223,13 @@ def call_original_model(): del model - if hasattr(torch, "compile"): + if hasattr(torch, 'compile'): model = load_model() - logger.info("Benchmarking StableDiffusionPipeline with torch.compile") + logger.info( + 'Benchmarking StableDiffusionPipeline with torch.compile') model.unet.to(memory_format=torch.channels_last) model.unet = torch.compile(model.unet) - if hasattr(model, "controlnet"): + if hasattr(model, 'controlnet'): model.controlnet.to(memory_format=torch.channels_last) model.controlnet = torch.compile(model.controlnet) @@ -225,12 +239,13 @@ def call_torch_compiled_model(): for _ in range(3): call_torch_compiled_model() - output_image = profiler.with_cProfile(call_torch_compiled_model)() + output_image = profiler.with_cProfile( + call_torch_compiled_model)() display_image(output_image) del model - # logger.info("Benchmarking compiled StableDiffusionPipeline") + # logger.info('Benchmarking compiled StableDiffusionPipeline') # config = CompilationConfig.Default() # compiled_model = compile(load_model(), config) @@ -246,22 +261,21 @@ def call_torch_compiled_model(): # del compiled_model logger.info( - "Benchmarking compiled StableDiffusionPipeline with xformers, Triton and CUDA Graph" + 'Benchmarking compiled StableDiffusionPipeline with xformers, Triton and CUDA Graph' ) config = CompilationConfig.Default() try: import 
xformers config.enable_xformers = True except ImportError: - logger.warning("xformers not installed, skip") + logger.warning('xformers not installed, skip') try: import triton config.enable_triton = True except ImportError: - logger.warning("triton not installed, skip") + logger.warning('triton not installed, skip') # config.trace_scheduler = True config.enable_cuda_graph = enable_cuda_graph - config.enable_quantization = enable_quantization compiled_model = compile(load_model(), config) def call_faster_compiled_model(): @@ -273,7 +287,9 @@ def call_faster_compiled_model(): output_image = profiler.with_cProfile(call_faster_compiled_model)() display_image(output_image) - if lora_a_path is not None and lora_b_path is not None: + if lora_a_path is not None and lora_b_path is not None and packaging.version.parse( + torch.__version__) >= packaging.version.parse('2.1.0'): + # load_state_dict with assign=True requires torch >= 2.1.0 def update_state_dict(dst, src): for key, value in src.items():
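For reference, below is a minimal usage sketch of the compiler entry points touched by this diff (`compile` and the new `CompilationConfig.Default` fields). It mirrors the pattern used in the benchmark tests above; the pipeline setup, model id and prompt are placeholders rather than part of the patch, so adjust them to your environment.

```python
# Minimal sketch (not part of the diff): how the refactored API is meant
# to be used. Model id and prompt are placeholders.
import torch
from diffusers import StableDiffusionPipeline

from sfast.compilers.stable_diffusion_pipeline_compiler import (
    compile, CompilationConfig)

model = StableDiffusionPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16)
model.safety_checker = None
model.to(torch.device('cuda'))

config = CompilationConfig.Default()
# xformers and Triton are optional accelerators; skip them if not installed.
try:
    import xformers
    config.enable_xformers = True
except ImportError:
    pass
try:
    import triton
    config.enable_triton = True
except ImportError:
    pass
config.enable_cuda_graph = True
# preserve_parameters stays at its default (True) so the frozen graph keeps
# pointing at the original parameter tensors and LoRA can still be switched.

compiled_model = compile(model, config)

output_image = compiled_model(
    prompt='a photo of a cat wearing sunglasses',
    height=512,
    width=512,
    num_inference_steps=30,
).images[0]
```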
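Similarly, a sketch of how dynamic LoRA switching can be wired up against a compiled pipeline, combining the `update_state_dict` helper from the README hunk with `load_state_dict(assign=True)` (the torch >= 2.1.0 requirement noted in both the README and the test). The `switch_lora` wrapper and `lora_b_path` are illustrative names, not part of this diff; `compiled_model` is the pipeline compiled in the previous sketch.

```python
# Illustrative sketch (not part of the diff): switch to another LoRA on an
# already compiled pipeline. Requires torch >= 2.1.0 for assign=True.

def update_state_dict(dst, src):
    for key, value in src.items():
        # In-place copy, so the tensors the traced graph / CUDA graph point to
        # receive the new values.
        dst[key].copy_(value)


def switch_lora(unet, lora_path):
    # Keep references to the original parameter tensors.
    state_dict = unet.state_dict()
    # Load the new LoRA; this may create fresh parameter tensors.
    unet.load_attn_procs(lora_path)
    # Copy the freshly loaded values into the original tensors in place.
    update_state_dict(state_dict, unet.state_dict())
    # Re-assign the original tensor objects to the module so the compiled
    # graph keeps operating on the tensors it was traced with.
    unet.load_state_dict(state_dict, assign=True)


lora_b_path = '/path/to/another_lora'  # placeholder
switch_lora(compiled_model.unet, lora_b_path)
```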