diff --git a/README.md b/README.md
index 779884f..42b5cc5 100644
--- a/README.md
+++ b/README.md
@@ -207,6 +207,9 @@ is to update the original UNet model's parameters inplace.
 The following code assumes you have already load a LoRA and compiled the model,
 and you want to switch to another LoRA.
 
+If you don't enable CUDA graph and keep `preserve_parameters = True`, switching LoRA is much simpler.
+The following code might not even be needed.
+
 ```python
 # load_state_dict with assign=True requires torch >= 2.1.0
 
@@ -235,14 +238,27 @@ switch_lora(compiled_model.unet, lora_b_path)
 
 ### Model Quantization
 
-`stable-fast` extends PyTorch's `quantize_dynamic` functionality and provides a fast quantized linear operator.
+`stable-fast` extends PyTorch's `quantize_dynamic` functionality and provides a dynamically quantized linear operator on the CUDA backend.
 By enabling it, you could get a slight VRAM reduction for `diffusers` and significant VRAM reduction for `transformers`,
-and cound get a potential speedup.
+and could get a potential speedup (though not always).
 
-However, since `diffusers` implements its own `Linear` layer as `LoRACompatibleLinear`,
-you need to do some hacks to make it work and it is a little complex and tricky.
+For `SD XL`, you can expect a VRAM reduction of about `2GB` with an image size of `1024x1024`.
 
-Refer to [tests/compilers/test_stable_diffusion_pipeline_compiler.py](tests/compilers/test_stable_diffusion_pipeline_compiler.py) to see how to do it.
+```python
+def quantize_unet(m):
+    from diffusers.utils import USE_PEFT_BACKEND
+    assert USE_PEFT_BACKEND
+    m = torch.quantization.quantize_dynamic(m, {torch.nn.Linear},
+                                            dtype=torch.qint8,
+                                            inplace=True)
+    return m
+
+model.unet = quantize_unet(model.unet)
+if hasattr(model, 'controlnet'):
+    model.controlnet = quantize_unet(model.controlnet)
+```
+
+Refer to [examples/optimize_stable_diffusion_pipeline.py](examples/optimize_stable_diffusion_pipeline.py) for more details.
 
 ### Some Common Methods To Speed Up PyTorch
 
diff --git a/examples/optimize_lcm_lora.py b/examples/optimize_lcm_lora.py
index ee1c519..d97cfa2 100644
--- a/examples/optimize_lcm_lora.py
+++ b/examples/optimize_lcm_lora.py
@@ -50,6 +50,7 @@ def parse_args():
         type=str,
         default='sfast',
         choices=['none', 'sfast', 'compile', 'compile-max-autotune'])
+    parser.add_argument('--quantize', action='store_true')
     return parser.parse_args()
 
 
@@ -159,6 +160,21 @@ def main():
         lora=args.lora,
         controlnet=args.controlnet,
     )
+
+    if args.quantize:
+
+        def quantize_unet(m):
+            from diffusers.utils import USE_PEFT_BACKEND
+            assert USE_PEFT_BACKEND
+            m = torch.quantization.quantize_dynamic(m, {torch.nn.Linear},
+                                                    dtype=torch.qint8,
+                                                    inplace=True)
+            return m
+
+        model.unet = quantize_unet(model.unet)
+        if hasattr(model, 'controlnet'):
+            model.controlnet = quantize_unet(model.controlnet)
+
     if args.compiler == 'none':
         pass
     elif args.compiler == 'sfast':
@@ -243,6 +259,8 @@ def get_kwarg_inputs():
     iter_per_sec = iter_profiler.get_iter_per_sec()
     if iter_per_sec is not None:
         print(f'Iterations per second: {iter_per_sec:.3f}')
+    peak_mem = torch.cuda.max_memory_allocated()
+    print(f'Peak memory: {peak_mem / 1024**3:.3f}GiB')
 
     if args.output_image is not None:
         output_images[0].save(args.output_image)
diff --git a/examples/optimize_lcm_pipeline.py b/examples/optimize_lcm_pipeline.py
index 1d2e816..ea208cb 100644
--- a/examples/optimize_lcm_pipeline.py
+++ b/examples/optimize_lcm_pipeline.py
@@ -50,6 +50,7 @@ def parse_args():
         type=str,
         default='sfast',
         choices=['none', 'sfast', 'compile', 'compile-max-autotune'])
+    parser.add_argument('--quantize', action='store_true')
     return parser.parse_args()
 
 
@@ -159,6 +160,21 @@ def main():
         lora=args.lora,
         controlnet=args.controlnet,
     )
+
+    if args.quantize:
+
+        def quantize_unet(m):
+            from diffusers.utils import USE_PEFT_BACKEND
+            assert USE_PEFT_BACKEND
+            m = torch.quantization.quantize_dynamic(m, {torch.nn.Linear},
+                                                    dtype=torch.qint8,
+                                                    inplace=True)
+            return m
+
+        model.unet = quantize_unet(model.unet)
+        if hasattr(model, 'controlnet'):
+            model.controlnet = quantize_unet(model.controlnet)
+
     if args.compiler == 'none':
         pass
     elif args.compiler == 'sfast':
@@ -243,6 +259,8 @@ def get_kwarg_inputs():
     iter_per_sec = iter_profiler.get_iter_per_sec()
     if iter_per_sec is not None:
         print(f'Iterations per second: {iter_per_sec:.3f}')
+    peak_mem = torch.cuda.max_memory_allocated()
+    print(f'Peak memory: {peak_mem / 1024**3:.3f}GiB')
 
     if args.output_image is not None:
         output_images[0].save(args.output_image)
diff --git a/examples/optimize_stable_diffusion_pipeline.py b/examples/optimize_stable_diffusion_pipeline.py
index 5e54376..1f1af53 100644
--- a/examples/optimize_stable_diffusion_pipeline.py
+++ b/examples/optimize_stable_diffusion_pipeline.py
@@ -50,6 +50,7 @@ def parse_args():
         type=str,
         default='sfast',
         choices=['none', 'sfast', 'compile', 'compile-max-autotune'])
+    parser.add_argument('--quantize', action='store_true')
     return parser.parse_args()
 
 
@@ -159,6 +160,21 @@ def main():
         lora=args.lora,
         controlnet=args.controlnet,
     )
+
+    if args.quantize:
+
+        def quantize_unet(m):
+            from diffusers.utils import USE_PEFT_BACKEND
+            assert USE_PEFT_BACKEND
+            m = torch.quantization.quantize_dynamic(m, {torch.nn.Linear},
+                                                    dtype=torch.qint8,
+                                                    inplace=True)
+            return m
+
+        model.unet = quantize_unet(model.unet)
+        if hasattr(model, 'controlnet'):
+            model.controlnet = quantize_unet(model.controlnet)
+
     if args.compiler == 'none':
         pass
     elif args.compiler == 'sfast':
@@ -243,6 +259,8 @@ def get_kwarg_inputs():
     iter_per_sec = iter_profiler.get_iter_per_sec()
     if iter_per_sec is not None:
         print(f'Iterations per second: {iter_per_sec:.3f}')
+    peak_mem = torch.cuda.max_memory_allocated()
+    print(f'Peak memory: {peak_mem / 1024**3:.3f}GiB')
 
     if args.output_image is not None:
         output_images[0].save(args.output_image)
diff --git a/tests/compilers/test_stable_diffusion_pipeline_compiler.py b/tests/compilers/test_stable_diffusion_pipeline_compiler.py
index 08dcdf5..e4304c2 100644
--- a/tests/compilers/test_stable_diffusion_pipeline_compiler.py
+++ b/tests/compilers/test_stable_diffusion_pipeline_compiler.py
@@ -305,39 +305,13 @@ def load_model():
     )
 
     if quantize:
 
-        def replace_linear(m):
-            # Replace LoraCompatibleLinear with torch.nn.Linear
-            new_m = torch.nn.Linear(m.in_features,
-                                    m.out_features,
-                                    bias=m.bias is not None).eval()
-            new_m = new_m.to(device=m.weight.device,
-                             dtype=m.weight.dtype)
-            new_m.weight.copy_(m.weight)
-            if m.bias is not None:
-                new_m.bias.copy_(m.bias)
-            return new_m
-
-        def make_linear_compatible(m):
-            forward = m.forward
-
-            def new_forward(x, *args, **kwargs):
-                return forward(x)
-
-            m.forward = new_forward
-            return m
-
         def quantize_unet(m):
-            m = patch_module(
-                m, lambda stack: isinstance(stack[-1][1], torch.nn.
-                                            Linear), replace_linear)
+            from diffusers.utils import USE_PEFT_BACKEND
+            assert USE_PEFT_BACKEND
             m = torch.quantization.quantize_dynamic(m, {torch.nn.Linear},
                                                     dtype=torch.qint8,
                                                     inplace=True)
-            m = patch_module(
-                m, lambda stack: isinstance(
-                    stack[-1][1], torch.ao.nn.quantized.Linear),
-                make_linear_compatible)
             return m
 
         model.unet = quantize_unet(model.unet)
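The `--quantize` flag added above simply wires the standard `torch.quantization.quantize_dynamic` API into the example scripts. As a minimal, self-contained sketch of what that call does to `torch.nn.Linear` layers (the toy module and names below are illustrative only, not taken from the repository):

```python
import torch

# Toy stand-in for a UNet/ControlNet; any module containing nn.Linear works.
toy = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 8),
).eval()

# The call the diff uses (minus inplace=True): Linear weights are stored as
# int8 and the layers are swapped for dynamically quantized variants, which
# is where the memory savings come from.
quantized = torch.quantization.quantize_dynamic(toy, {torch.nn.Linear},
                                                dtype=torch.qint8)

print(quantized)                            # Linear -> DynamicQuantizedLinear
print(quantized(torch.randn(1, 64)).shape)  # forward still works: torch.Size([1, 8])
```

With the diff applied, the equivalent on a real pipeline is e.g. `python3 examples/optimize_stable_diffusion_pipeline.py --quantize` (plus whatever model arguments you normally pass); the new peak-memory print makes it easy to compare VRAM usage with and without the flag.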