[Quantization] Add quantization support for bitsandbytes (#9213)
* quantization config.

* fix-copies

* fix

* modules_to_not_convert

* add bitsandbytes utilities.

* make progress.

* fixes

* quality

* up

* up

rotary embedding refactor 2: update comments, fix dtype for use_real=False (#9312)

fix notes and dtype

up

up

* minor

* up

* up

* fix

* provide credits where due.

* make configurations work.

* fixes

* fix

* update_missing_keys

* fix

* fix

* make it work.

* fix

* provide credits to transformers.

* empty commit

* handle to() better.

* tests

* change to bnb from bitsandbytes

* fix tests

fix slow quality tests

SD3 remark

fix

complete int4 tests

add a readme to the test files.

add model cpu offload tests

warning test

* better safeguard.

* change merging status

* courtesy to transformers.

* move upper.

* better

* make the unused kwargs warning friendlier.

* harmonize changes with huggingface/transformers#33122

* style

* training tests

* feedback part i.

* Add Flux inpainting and Flux Img2Img (#9135)

---------

Co-authored-by: yiyixuxu <[email protected]>

Update `UNet2DConditionModel`'s error messages (#9230)

* refactor

[CI] Update Single file Nightly Tests (#9357)

* update

* update

feedback.

improve README for flux dreambooth lora (#9290)

* improve readme

* improve readme

* improve readme

* improve readme

fix one uncaught deprecation warning for accessing vae_latent_channels in VaeImagePreprocessor (#9372)

deprecation warning vae_latent_channels

add mixed int8 tests and more tests to nf4.

[core] Freenoise memory improvements (#9262)

* update

* implement prompt interpolation

* make style

* resnet memory optimizations

* more memory optimizations; todo: refactor

* update

* update animatediff controlnet with latest changes

* refactor chunked inference changes

* remove print statements

* update

* chunk -> split

* remove changes from incorrect conflict resolution

* remove changes from incorrect conflict resolution

* add explanation of SplitInferenceModule

* update docs

* Revert "update docs"

This reverts commit c55a50a.

* update docstring for freenoise split inference

* apply suggestions from review

* add tests

* apply suggestions from review

quantization docs.

docs.

* Revert "Add Flux inpainting and Flux Img2Img (#9135)"

This reverts commit 5799954.

* tests

* don

* Apply suggestions from code review

Co-authored-by: Steven Liu <[email protected]>

* contribution guide.

* changes

* empty

* fix tests

* harmonize with huggingface/transformers#33546.

* numpy_cosine_distance

* config_dict modification.

* remove if config comment.

* note for load_state_dict changes.

* float8 check.

* quantizer.

* raise an error for non-True low_cpu_mem_usage values when using quant.

* low_cpu_mem_usage shenanigans when using fp32 modules.

* don't re-assign _pre_quantization_type.

* make comments clear.

* remove comments.

* handle mixed types better when moving to cpu.

* add tests to check if we're throwing warning rightly.

* better check.

* fix 8bit test_quality.

* handle dtype more robustly.

* better message when keep_in_fp32_modules.

* handle dtype casting.

* fix dtype checks in pipeline.

* fix warning message.

* Update src/diffusers/models/modeling_utils.py

Co-authored-by: YiYi Xu <[email protected]>

* mitigate the confusing cpu warning

---------

Co-authored-by: Vishnu V Jaddipal <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
4 people authored Oct 21, 2024
1 parent 24281f8 commit b821f00
Showing 25 changed files with 3,606 additions and 30 deletions.
8 changes: 8 additions & 0 deletions docs/source/en/_toctree.yml
@@ -150,6 +150,12 @@
title: Reinforcement learning training with DDPO
title: Methods
title: Training
- sections:
- local: quantization/overview
title: Getting Started
- local: quantization/bitsandbytes
title: bitsandbytes
title: Quantization Methods
- sections:
- local: optimization/fp16
title: Speed up inference
@@ -209,6 +215,8 @@
title: Logging
- local: api/outputs
title: Outputs
- local: api/quantization
title: Quantization
title: Main Classes
- isExpanded: false
sections:
33 changes: 33 additions & 0 deletions docs/source/en/api/quantization.md
@@ -0,0 +1,33 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Quantization

Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index).

Quantization techniques that aren't natively supported in Diffusers can be added with the [`DiffusersQuantizer`] class.

<Tip>

Learn how to quantize models in the [Quantization](../quantization/overview) guide.

</Tip>


## BitsAndBytesConfig

[[autodoc]] BitsAndBytesConfig

## DiffusersQuantizer

[[autodoc]] quantizers.base.DiffusersQuantizer
267 changes: 267 additions & 0 deletions docs/source/en/quantization/bitsandbytes.md
@@ -0,0 +1,267 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# bitsandbytes

[bitsandbytes](https://huggingface.co/docs/bitsandbytes/index) is the easiest option for quantizing a model to 8-bit and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the outputs in fp16. This reduces the degradative effect outlier values have on a model's performance.

4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs.


To use bitsandbytes, make sure you have the following libraries installed:

```bash
pip install diffusers transformers accelerate bitsandbytes -U
```

Now you can quantize a model by passing a [`BitsAndBytesConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.

<hfoptions id="bnb">
<hfoption id="8-bit">

Quantizing a model in 8-bit halves its memory usage:

```py
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config
)
```

By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:

```py
import torch
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.float32
)
# inspect a non-quantized weight (an RMSNorm inside the attention blocks)
model_8bit.transformer_blocks[-1].attn.norm_q.weight.dtype
```

Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights.

```py
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config
)
# push to the Hub (the repository id below is a placeholder; use your own)
model_8bit.push_to_hub("your-username/FLUX.1-dev-8bit-transformer")
```

</hfoption>
<hfoption id="4-bit">

Quantizing a model in 4-bit reduces memory usage by 4x:

```py
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config
)
```

By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:

```py
import torch
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.float32
)
# inspect a non-quantized weight (an RMSNorm inside the attention blocks)
model_4bit.transformer_blocks[-1].attn.norm_q.weight.dtype
```

Call [`~ModelMixin.push_to_hub`] after loading the model in 4-bit precision to push it to the Hub. You can also save the serialized 4-bit model locally with [`~ModelMixin.save_pretrained`], as shown in the sketch below.
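
A minimal sketch of both operations, assuming `model_4bit` from the example above; the local directory and repository id are hypothetical placeholders:

```py
# save the serialized 4-bit model locally
model_4bit.save_pretrained("flux.1-dev-transformer-4bit")

# push it to the Hub under your own namespace
model_4bit.push_to_hub("your-username/flux.1-dev-transformer-4bit")
```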

</hfoption>
</hfoptions>

<Tip warning={true}>

Training with 8-bit and 4-bit weights is only supported for training *extra* parameters (for example, adapter layers); the quantized base weights remain frozen, as sketched below.

</Tip>
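
For instance, the *extra* parameters can be LoRA adapters attached on top of the frozen quantized base model. A minimal sketch using [PEFT](https://huggingface.co/docs/peft/index) (requires `pip install peft`), assuming `model_4bit` from the 4-bit example above and the usual attention projection names; the training loop itself is omitted:

```py
from peft import LoraConfig

# LoRA injects small trainable matrices; the 4-bit base weights stay frozen
lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
model_4bit.add_adapter(lora_config)
```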

Check your memory footprint with the `get_memory_footprint` method:

```py
print(model.get_memory_footprint())
```
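
The footprint is reported in bytes (assuming the method mirrors its Transformers counterpart), so a quick conversion makes it easier to read:

```py
# convert the footprint from bytes to GiB for readability
print(f"{model.get_memory_footprint() / 1024**3:.2f} GiB")
```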

Quantized models can be loaded with the [`~ModelMixin.from_pretrained`] method without specifying the `quantization_config` parameter, because the quantization configuration is stored alongside the weights:

```py
from diffusers import FluxTransformer2DModel

model_4bit = FluxTransformer2DModel.from_pretrained(
    "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
)
```

## 8-bit (LLM.int8() algorithm)

<Tip>

Learn more about the details of 8-bit quantization in this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration)!

</Tip>

This section explores some of the specific features of 8-bit models, such as outlier thresholds and skipping module conversion.

### Outlier threshold

An "outlier" is a hidden state value greater than a certain threshold, and these values are computed in fp16. While the values are usually normally distributed ([-3.5, 3.5]), this distribution can be very different for large models ([-60, 6] or [6, 60]). 8-bit quantization works well for values ~5, but beyond that, there is a significant performance penalty. A good default threshold value is 6, but a lower threshold may be needed for more unstable models (small models or finetuning).

To find the best threshold for your model, we recommend experimenting with the `llm_int8_threshold` parameter in [`BitsAndBytesConfig`]:

```py
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True, llm_int8_threshold=10,
)

model_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
)
```

### Skip module conversion

For some models, you don't need to quantize every module to 8-bit; in fact, quantizing certain modules can cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module can be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]:

```py
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True, llm_int8_skip_modules=["proj_out"],
)

model_8bit = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=quantization_config,
)
```
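
To double-check that the module was left unquantized, you can inspect its weight dtype; a small sanity check, assuming the SD3 transformer exposes `proj_out` as a top-level submodule:

```py
# a skipped module keeps regular floating-point weights (e.g. torch.float16),
# while quantized linear layers hold int8 parameters
print(model_8bit.proj_out.weight.dtype)
```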


## 4-bit (QLoRA algorithm)

<Tip>

Learn more about its details in this [blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).

</Tip>

This section explores some of the specific features of 4-bit models, such as changing the compute data type, using the Normal Float 4 (NF4) data type, and using nested quantization.


### Compute data type

To speed up computation, you can change the data type from float32 (the default value) to bf16 using the `bnb_4bit_compute_dtype` parameter in [`BitsAndBytesConfig`]:

```py
import torch
from diffusers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
```

### Normal Float 4 (NF4)

NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter in the [`BitsAndBytesConfig`]:

```py
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
)

model_nf4 = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=nf4_config,
)
```

For inference, the `bnb_4bit_quant_type` does not have a huge impact on performance. However, to remain consistent with the model weights, you should use the same `bnb_4bit_compute_dtype` and `torch_dtype` values.
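
For example, the following sketch keeps the compute dtype and `torch_dtype` aligned in bfloat16 for an NF4-quantized model:

```py
import torch
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_nf4 = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
```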

### Nested quantization

Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter.

```py
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig

double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

double_quant_model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=double_quant_config,
)
```

## Dequantizing `bitsandbytes` models

Once quantized, you can dequantize a model back to its original precision, but this might result in a small loss of quality. Make sure you have enough GPU RAM to fit the dequantized model.

```python
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig

double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

double_quant_model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=double_quant_config,
)
double_quant_model.dequantize()
```
35 changes: 35 additions & 0 deletions docs/source/en/quantization/overview.md
@@ -0,0 +1,35 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Quantization

Quantization techniques focus on representing data with less information while also trying to not lose too much accuracy. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size, which makes it easier to store and reduces memory usage. Lower precision can also speed up inference because it takes less time to perform calculations with fewer bits.
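
As a rough back-of-the-envelope illustration (weights only, ignoring activations and any quantization overhead), here is how the storage for a hypothetical 12B-parameter model scales with precision:

```py
# approximate weight storage for a hypothetical 12B-parameter model
num_params = 12e9
for name, bits in [("fp32", 32), ("fp16/bf16", 16), ("int8", 8), ("4-bit", 4)]:
    gib = num_params * bits / 8 / 1024**3
    print(f"{name:>9}: ~{gib:.1f} GiB")
```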

<Tip>

Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more.

</Tip>

<Tip>

If you are new to quantization, we recommend checking out these beginner-friendly courses on quantization created in collaboration with DeepLearning.AI:

* [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/)
* [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/)

</Tip>

## When to use what?

This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.