
Commit

demonstrate the ability to quantize models
chengzeyi committed Dec 17, 2023
1 parent defe16e commit 1f46389
Showing 5 changed files with 77 additions and 33 deletions.
26 changes: 21 additions & 5 deletions README.md
@@ -207,6 +207,9 @@ is to update the original UNet model's parameters inplace.
The following code assumes you have already loaded a LoRA and compiled the model,
and now want to switch to another LoRA.

If you don't enable CUDA Graph and keep `preserve_parameters = True`, things are much simpler:
the following code might not even be needed.

```python
# load_state_dict with assign=True requires torch >= 2.1.0
```

@@ -235,14 +238,27 @@ switch_lora(compiled_model.unet, lora_b_path)
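
The full switching code is collapsed in the diff above. As a minimal sketch of the underlying idea (illustrative only; `compiled_model` follows the naming above, while `update_state_dict` and `unet_with_lora_b` are hypothetical), the key point is to copy the new weights into the existing parameter tensors instead of replacing them, so the references captured by tracing/CUDA Graph stay valid:

```python
import torch

def update_state_dict(dst, src):
    # Copy in place: the traced forward function keeps references to these
    # tensors, so mutating them (instead of replacing them) keeps it valid.
    with torch.no_grad():
        for key, value in src.items():
            dst[key].copy_(value)

# Hypothetical usage: `unet_with_lora_b` is a plain (non-compiled) UNet of the
# same architecture with the new LoRA already applied to its weights.
update_state_dict(compiled_model.unet.state_dict(),
                  unet_with_lora_b.state_dict())
```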

### Model Quantization

`stable-fast` extends PyTorch's `quantize_dynamic` functionality and provides a fast quantized linear operator.
`stable-fast` extends PyTorch's `quantize_dynamic` functionality and provides a dynamically quantized linear operator on the CUDA backend.
By enabling it, you could get a slight VRAM reduction for `diffusers` and a significant VRAM reduction for `transformers`,
and could get a potential speedup.
and could get a potential speedup (though not always).

However, since `diffusers` implements its own `Linear` layer as `LoRACompatibleLinear`,
you need some hacks to make it work, which is a little complex and tricky.
For `SD XL`, you can expect a VRAM reduction of about `2GB` with an image size of `1024x1024`.

Refer to [tests/compilers/test_stable_diffusion_pipeline_compiler.py](tests/compilers/test_stable_diffusion_pipeline_compiler.py) to see how to do it.
```python
# `model` is an already-loaded diffusers pipeline.
def quantize_unet(m):
    # Requires the PEFT backend, so the UNet uses plain torch.nn.Linear layers.
    from diffusers.utils import USE_PEFT_BACKEND
    assert USE_PEFT_BACKEND
    # Dynamically quantize all nn.Linear layers to int8.
    m = torch.quantization.quantize_dynamic(m, {torch.nn.Linear},
                                            dtype=torch.qint8,
                                            inplace=True)
    return m

model.unet = quantize_unet(model.unet)
if hasattr(model, 'controlnet'):
    model.controlnet = quantize_unet(model.controlnet)
```

Refer to [examples/optimize_stable_diffusion_pipeline.py](examples/optimize_stable_diffusion_pipeline.py) for more details.
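
As a quick sanity check (a sketch, not part of the repository; `model` is the pipeline from the snippet above), you can count how many `Linear` layers were actually converted: `quantize_dynamic` replaces matching `torch.nn.Linear` modules with `torch.ao.nn.quantized.dynamic.Linear`.

```python
import torch
from torch.ao.nn.quantized.dynamic import Linear as DynamicQuantizedLinear

def count_dynamic_quantized_linears(module):
    # Modules produced by torch.quantization.quantize_dynamic for nn.Linear.
    return sum(isinstance(m, DynamicQuantizedLinear) for m in module.modules())

print('Quantized linears in the UNet:', count_dynamic_quantized_linears(model.unet))
print(f'Peak VRAM: {torch.cuda.max_memory_allocated() / 1024**3:.2f}GiB')
```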

### Some Common Methods To Speed Up PyTorch

18 changes: 18 additions & 0 deletions examples/optimize_lcm_lora.py
@@ -50,6 +50,7 @@ def parse_args():
                        type=str,
                        default='sfast',
                        choices=['none', 'sfast', 'compile', 'compile-max-autotune'])
    parser.add_argument('--quantize', action='store_true')
    return parser.parse_args()


@@ -159,6 +160,21 @@ def main():
        lora=args.lora,
        controlnet=args.controlnet,
    )

    if args.quantize:

        def quantize_unet(m):
            from diffusers.utils import USE_PEFT_BACKEND
            assert USE_PEFT_BACKEND
            m = torch.quantization.quantize_dynamic(m, {torch.nn.Linear},
                                                    dtype=torch.qint8,
                                                    inplace=True)
            return m

        model.unet = quantize_unet(model.unet)
        if hasattr(model, 'controlnet'):
            model.controlnet = quantize_unet(model.controlnet)

    if args.compiler == 'none':
        pass
    elif args.compiler == 'sfast':
@@ -243,6 +259,8 @@ def get_kwarg_inputs():
    iter_per_sec = iter_profiler.get_iter_per_sec()
    if iter_per_sec is not None:
        print(f'Iterations per second: {iter_per_sec:.3f}')
    peak_mem = torch.cuda.max_memory_allocated()
    print(f'Peak memory: {peak_mem / 1024**3:.3f}GiB')

    if args.output_image is not None:
        output_images[0].save(args.output_image)
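
A side note on the peak-memory readout added above (an illustrative aside, not part of this commit): `torch.cuda.max_memory_allocated()` reports the high-water mark since the process started (or since the last reset), so to attribute the number to the timed inference alone you can reset the counter just before the measured run. A minimal sketch, reusing the example's `model` and `kwarg_inputs` names:

```python
import torch

# Reset the allocator's high-water mark so the next reading
# reflects only the work performed after this point.
torch.cuda.reset_peak_memory_stats()

output_images = model(**kwarg_inputs).images  # the measured run

peak_mem = torch.cuda.max_memory_allocated()
print(f'Peak memory: {peak_mem / 1024**3:.3f}GiB')
```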
18 changes: 18 additions & 0 deletions examples/optimize_lcm_pipeline.py
@@ -50,6 +50,7 @@ def parse_args():
                        type=str,
                        default='sfast',
                        choices=['none', 'sfast', 'compile', 'compile-max-autotune'])
    parser.add_argument('--quantize', action='store_true')
    return parser.parse_args()


@@ -159,6 +160,21 @@ def main():
        lora=args.lora,
        controlnet=args.controlnet,
    )

    if args.quantize:

        def quantize_unet(m):
            from diffusers.utils import USE_PEFT_BACKEND
            assert USE_PEFT_BACKEND
            m = torch.quantization.quantize_dynamic(m, {torch.nn.Linear},
                                                    dtype=torch.qint8,
                                                    inplace=True)
            return m

        model.unet = quantize_unet(model.unet)
        if hasattr(model, 'controlnet'):
            model.controlnet = quantize_unet(model.controlnet)

    if args.compiler == 'none':
        pass
    elif args.compiler == 'sfast':
@@ -243,6 +259,8 @@ def get_kwarg_inputs():
    iter_per_sec = iter_profiler.get_iter_per_sec()
    if iter_per_sec is not None:
        print(f'Iterations per second: {iter_per_sec:.3f}')
    peak_mem = torch.cuda.max_memory_allocated()
    print(f'Peak memory: {peak_mem / 1024**3:.3f}GiB')

    if args.output_image is not None:
        output_images[0].save(args.output_image)
18 changes: 18 additions & 0 deletions examples/optimize_stable_diffusion_pipeline.py
@@ -50,6 +50,7 @@ def parse_args():
                        type=str,
                        default='sfast',
                        choices=['none', 'sfast', 'compile', 'compile-max-autotune'])
    parser.add_argument('--quantize', action='store_true')
    return parser.parse_args()


@@ -159,6 +160,21 @@ def main():
        lora=args.lora,
        controlnet=args.controlnet,
    )

    if args.quantize:

        def quantize_unet(m):
            from diffusers.utils import USE_PEFT_BACKEND
            assert USE_PEFT_BACKEND
            m = torch.quantization.quantize_dynamic(m, {torch.nn.Linear},
                                                    dtype=torch.qint8,
                                                    inplace=True)
            return m

        model.unet = quantize_unet(model.unet)
        if hasattr(model, 'controlnet'):
            model.controlnet = quantize_unet(model.controlnet)

    if args.compiler == 'none':
        pass
    elif args.compiler == 'sfast':
@@ -243,6 +259,8 @@ def get_kwarg_inputs():
    iter_per_sec = iter_profiler.get_iter_per_sec()
    if iter_per_sec is not None:
        print(f'Iterations per second: {iter_per_sec:.3f}')
    peak_mem = torch.cuda.max_memory_allocated()
    print(f'Peak memory: {peak_mem / 1024**3:.3f}GiB')

    if args.output_image is not None:
        output_images[0].save(args.output_image)
30 changes: 2 additions & 28 deletions tests/compilers/test_stable_diffusion_pipeline_compiler.py
@@ -305,39 +305,13 @@ def load_model():
    )
    if quantize:

        def replace_linear(m):
            # Replace LoraCompatibleLinear with torch.nn.Linear
            new_m = torch.nn.Linear(m.in_features,
                                    m.out_features,
                                    bias=m.bias is not None).eval()
            new_m = new_m.to(device=m.weight.device,
                             dtype=m.weight.dtype)
            new_m.weight.copy_(m.weight)
            if m.bias is not None:
                new_m.bias.copy_(m.bias)
            return new_m

        def make_linear_compatible(m):
            forward = m.forward

            def new_forward(x, *args, **kwargs):
                return forward(x)

            m.forward = new_forward
            return m

        def quantize_unet(m):
            m = patch_module(
                m, lambda stack: isinstance(stack[-1][1], torch.nn.Linear),
                replace_linear)
            from diffusers.utils import USE_PEFT_BACKEND
            assert USE_PEFT_BACKEND
            m = torch.quantization.quantize_dynamic(m,
                                                    {torch.nn.Linear},
                                                    dtype=torch.qint8,
                                                    inplace=True)
            m = patch_module(
                m, lambda stack: isinstance(
                    stack[-1][1], torch.ao.nn.quantized.Linear),
                make_linear_compatible)
            return m

        model.unet = quantize_unet(model.unet)
