diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index e218e9878599..58218c0272bd 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -150,6 +150,12 @@
title: Reinforcement learning training with DDPO
title: Methods
title: Training
+- sections:
+ - local: quantization/overview
+ title: Getting Started
+ - local: quantization/bitsandbytes
+ title: bitsandbytes
+ title: Quantization Methods
- sections:
- local: optimization/fp16
title: Speed up inference
@@ -209,6 +215,8 @@
title: Logging
- local: api/outputs
title: Outputs
+ - local: api/quantization
+ title: Quantization
title: Main Classes
- isExpanded: false
sections:
diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md
new file mode 100644
index 000000000000..2fbde9e707ea
--- /dev/null
+++ b/docs/source/en/api/quantization.md
@@ -0,0 +1,33 @@
+
+
+# Quantization
+
+Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models that normally wouldn't fit into memory and speeds up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index).
+
+Quantization techniques that aren't natively supported in Diffusers can be added with the [`DiffusersQuantizer`] class.
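+
+At a high level, a new backend subclasses [`DiffusersQuantizer`] and implements its abstract hooks. The snippet below is only an illustrative sketch; the class name and the (empty) module-swapping logic are placeholders, not a real backend:
+
+```py
+from diffusers import DiffusersQuantizer
+
+
+class MyQuantizer(DiffusersQuantizer):
+    """Minimal illustrative skeleton; not an actual quantization backend."""
+
+    requires_calibration = False
+
+    def _process_model_before_weight_loading(self, model, **kwargs):
+        # The model skeleton is on the meta device here; swap in custom quantized modules.
+        return model
+
+    def _process_model_after_weight_loading(self, model, **kwargs):
+        # Finalize the model after the checkpoint weights have been loaded.
+        return model
+
+    @property
+    def is_serializable(self):
+        return False
+
+    @property
+    def is_trainable(self):
+        return False
+```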
+
+
+
+Learn how to quantize models in the [Quantization](../quantization/overview) guide.
+
+
+
+
+## BitsAndBytesConfig
+
+[[autodoc]] BitsAndBytesConfig
+
+## DiffusersQuantizer
+
+[[autodoc]] quantizers.base.DiffusersQuantizer
diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md
new file mode 100644
index 000000000000..f272346aa2e2
--- /dev/null
+++ b/docs/source/en/quantization/bitsandbytes.md
@@ -0,0 +1,267 @@
+
+
+# bitsandbytes
+
+[bitsandbytes](https://huggingface.co/docs/bitsandbytes/index) is the easiest option for quantizing a model to 8-bit and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the weights in fp16. This reduces the degradative effect outlier values have on a model's performance.
+
+4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs.
+
+
+To use bitsandbytes, make sure you have the following libraries installed:
+
+```bash
+pip install diffusers transformers accelerate bitsandbytes -U
+```
+
+Now you can quantize a model by passing a [`BitsAndBytesConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
+
+
+
+
+Quantizing a model in 8-bit halves the memory usage:
+
+```py
+from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+model_8bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quantization_config
+)
+```
+
+By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:
+
+```py
+import torch
+from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+model_8bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.float32
+)
+# non-Linear modules (for example, the norm layers) keep the `torch_dtype` you passed
+print(model_8bit.transformer_blocks[-1].attn.norm_q.weight.dtype)
+```
+
+Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights.
+
+```py
+from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+model_8bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quantization_config
+)
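+
+# the repository id below is a placeholder; replace it with your own
+model_8bit.push_to_hub("your-username/FLUX.1-dev-bnb-8bit")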
+```
+
+
+
+
+Quantizing a model in 4-bit reduces your memory usage by 4x:
+
+```py
+from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+
+model_4bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quantization_config
+)
+```
+
+By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:
+
+```py
+import torch
+from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+
+model_4bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.float32
+)
+# non-Linear modules (for example, the norm layers) keep the `torch_dtype` you passed
+print(model_4bit.transformer_blocks[-1].attn.norm_q.weight.dtype)
+```
+
+To push a 4-bit model to the Hub, call [`~ModelMixin.push_to_hub`] after loading it. You can also save the serialized 4-bit model locally with [`~ModelMixin.save_pretrained`].
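+
+For example, continuing from the snippet above (the repository id and local directory below are placeholders):
+
+```py
+# push the 4-bit model to the Hub
+model_4bit.push_to_hub("your-username/FLUX.1-dev-bnb-4bit")
+
+# or save it locally
+model_4bit.save_pretrained("FLUX.1-dev-bnb-4bit")
+```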
+
+
+
+
+
+
+Training with 8-bit and 4-bit weights is only supported for training *extra* parameters.
+
+
+
+Check your memory footprint with the `get_memory_footprint` method:
+
+```py
+print(model.get_memory_footprint())
+```
+
+Quantized models can be loaded with the [`~ModelMixin.from_pretrained`] method without needing to specify the `quantization_config` parameter:
+
+```py
+from diffusers import FluxTransformer2DModel
+
+model_4bit = FluxTransformer2DModel.from_pretrained(
+ "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
+)
+```
+
+## 8-bit (LLM.int8() algorithm)
+
+
+
+Learn more about the details of 8-bit quantization in this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration)!
+
+
+
+This section explores some of the specific features of 8-bit models, such as outlier thresholds and skipping module conversion.
+
+### Outlier threshold
+
+An "outlier" is a hidden state value greater than a certain threshold, and these values are computed in fp16. While the values are usually normally distributed ([-3.5, 3.5]), this distribution can be very different for large models ([-60, 6] or [6, 60]). 8-bit quantization works well for values ~5, but beyond that, there is a significant performance penalty. A good default threshold value is 6, but a lower threshold may be needed for more unstable models (small models or finetuning).
+
+To find the best threshold for your model, we recommend experimenting with the `llm_int8_threshold` parameter in [`BitsAndBytesConfig`]:
+
+```py
+from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(
+ load_in_8bit=True, llm_int8_threshold=10,
+)
+
+model_8bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+)
+```
+
+### Skip module conversion
+
+You don't need to quantize every module to 8-bit; for some models, doing so can actually cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module can be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]:
+
+```py
+from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(
+ load_in_8bit=True, llm_int8_skip_modules=["proj_out"],
+)
+
+model_8bit = SD3Transformer2DModel.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+)
+```
+
+
+## 4-bit (QLoRA algorithm)
+
+
+
+Learn more about its details in this [blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
+
+
+
+This section explores some of the specific features of 4-bit models, such as changing the compute data type, using the Normal Float 4 (NF4) data type, and using nested quantization.
+
+
+### Compute data type
+
+To speed up computation, you can change the data type from float32 (the default value) to bf16 using the `bnb_4bit_compute_dtype` parameter in [`BitsAndBytesConfig`]:
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
+```
+
+### Normal Float 4 (NF4)
+
+NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter in the [`BitsAndBytesConfig`]:
+
+```py
+from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
+
+nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+)
+
+model_nf4 = SD3Transformer2DModel.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers",
+ subfolder="transformer",
+ quantization_config=nf4_config,
+)
+```
+
+For inference, the `bnb_4bit_quant_type` does not have a huge impact on performance. However, to remain consistent with the model weights, you should use the same `bnb_4bit_compute_dtype` and `torch_dtype` values.
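+
+For example, a config that keeps these values aligned (a sketch assuming you want bfloat16 weights and compute) could look like this:
+
+```py
+import torch
+from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
+
+nf4_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16,
+)
+
+model_nf4 = SD3Transformer2DModel.from_pretrained(
+    "stabilityai/stable-diffusion-3-medium-diffusers",
+    subfolder="transformer",
+    quantization_config=nf4_config,
+    torch_dtype=torch.bfloat16,
+)
+```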
+
+### Nested quantization
+
+Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter.
+
+```py
+from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
+
+double_quant_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_use_double_quant=True,
+)
+
+double_quant_model = SD3Transformer2DModel.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers",
+ subfolder="transformer",
+ quantization_config=double_quant_config,
+)
+```
+
+## Dequantizing `bitsandbytes` models
+
+Once quantized, you can dequantize a model back to its original precision, but this might result in a small loss of quality. Make sure you have enough GPU RAM to fit the dequantized model.
+
+```py
+from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
+
+double_quant_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_use_double_quant=True,
+)
+
+double_quant_model = SD3Transformer2DModel.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers",
+ subfolder="transformer",
+ quantization_config=double_quant_config,
+)
+double_quant_model.dequantize()
+```
\ No newline at end of file
diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md
new file mode 100644
index 000000000000..d8adbc85a259
--- /dev/null
+++ b/docs/source/en/quantization/overview.md
@@ -0,0 +1,35 @@
+
+
+# Quantization
+
+Quantization techniques focus on representing data with less information while also trying not to lose too much accuracy. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size, which makes it easier to store and reduces memory usage. Lower precision can also speed up inference because it takes less time to perform calculations with fewer bits.
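+
+As a rough back-of-the-envelope sketch (the parameter count below is illustrative), the storage needed for the weights scales linearly with the bit width:
+
+```py
+num_params = 12_000_000_000  # illustrative ~12B-parameter model
+
+for bits in (32, 16, 8, 4):
+    gigabytes = num_params * bits / 8 / 1024**3
+    print(f"{bits}-bit weights: ~{gigabytes:.1f} GB")
+```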
+
+
+
+Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn how.
+
+
+
+
+
+If you are new to quantization, we recommend checking out these beginner-friendly courses about quantization, created in collaboration with DeepLearning.AI:
+
+* [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/)
+* [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/)
+
+
+
+## When to use what?
+
+This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
\ No newline at end of file
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 69fa08d03839..a1d126f3823b 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -31,6 +31,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
+ "quantizers.quantization_config": ["BitsAndBytesConfig"],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
@@ -124,7 +125,6 @@
"VQModel",
]
)
-
_import_structure["optimization"] = [
"get_constant_schedule",
"get_constant_schedule_with_warmup",
@@ -156,6 +156,7 @@
"StableDiffusionMixin",
]
)
+ _import_structure["quantizers"] = ["DiffusersQuantizer"]
_import_structure["schedulers"].extend(
[
"AmusedScheduler",
@@ -538,6 +539,7 @@
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
+ from .quantizers.quantization_config import BitsAndBytesConfig
try:
if not is_onnx_available():
@@ -632,6 +634,7 @@
ScoreSdeVePipeline,
StableDiffusionMixin,
)
+ from .quantizers import DiffusersQuantizer
from .schedulers import (
AmusedScheduler,
CMStochasticIterativeScheduler,
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 3dccd785cae4..11d45dc64d97 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -510,6 +510,9 @@ def extract_init_dict(cls, config_dict, **kwargs):
# remove private attributes
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
+ # remove quantization_config
+ config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}
+
# 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
init_dict = {}
for key in expected_keys:
@@ -586,10 +589,19 @@ def to_json_saveable(value):
value = value.as_posix()
return value
+ if "quantization_config" in config_dict:
+ config_dict["quantization_config"] = (
+ config_dict.quantization_config.to_dict()
+ if not isinstance(config_dict.quantization_config, dict)
+ else config_dict.quantization_config
+ )
+
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
# Don't save "_ignore_files" or "_use_default_values"
config_dict.pop("_ignore_files", None)
config_dict.pop("_use_default_values", None)
+ # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
+ _ = config_dict.pop("_pre_quantization_dtype", None)
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index c9eb664443b5..5277ad2f9389 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -25,6 +25,7 @@
import torch
from huggingface_hub.utils import EntryNotFoundError
+from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
@@ -54,11 +55,36 @@
# Adapted from `transformers` (see modeling_utils.py)
-def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
+def _determine_device_map(
+ model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
+):
if isinstance(device_map, str):
+ special_dtypes = {}
+ if hf_quantizer is not None:
+ special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
+ special_dtypes.update(
+ {
+ name: torch.float32
+ for name, _ in model.named_parameters()
+ if any(m in name for m in keep_in_fp32_modules)
+ }
+ )
+
+ target_dtype = torch_dtype
+ if hf_quantizer is not None:
+ target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)
+
no_split_modules = model._get_no_split_modules(device_map)
device_map_kwargs = {"no_split_module_classes": no_split_modules}
+ if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
+ device_map_kwargs["special_dtypes"] = special_dtypes
+ elif len(special_dtypes) > 0:
+ logger.warning(
+ "This model has some weights that should be kept in higher precision, you need to upgrade "
+ "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
+ )
+
if device_map != "sequential":
max_memory = get_balanced_memory(
model,
@@ -70,8 +96,14 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
else:
max_memory = get_max_memory(max_memory)
+ if hf_quantizer is not None:
+ max_memory = hf_quantizer.adjust_max_memory(max_memory)
+
device_map_kwargs["max_memory"] = max_memory
- device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
+ device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
+
+ if hf_quantizer is not None:
+ hf_quantizer.validate_environment(device_map=device_map)
return device_map
@@ -100,6 +132,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
+ # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
+ # when refactoring the _merge_sharded_checkpoints() method later.
+ if isinstance(checkpoint_file, dict):
+ return checkpoint_file
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
@@ -137,29 +173,60 @@ def load_model_dict_into_meta(
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
model_name_or_path: Optional[str] = None,
+ hf_quantizer=None,
+ keep_in_fp32_modules=None,
) -> List[str]:
- device = device or torch.device("cpu")
+ if hf_quantizer is None:
+ device = device or torch.device("cpu")
dtype = dtype or torch.float32
+ is_quantized = hf_quantizer is not None
+ is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
-
- unexpected_keys = []
empty_state_dict = model.state_dict()
+ unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
+
for param_name, param in state_dict.items():
if param_name not in empty_state_dict:
- unexpected_keys.append(param_name)
continue
- if empty_state_dict[param_name].shape != param.shape:
+ set_module_kwargs = {}
+ # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
+ # in int/uint/bool and not cast them.
+ # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
+ if torch.is_floating_point(param):
+ if (
+ keep_in_fp32_modules is not None
+ and any(
+ module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
+ )
+ and dtype == torch.float16
+ ):
+ param = param.to(torch.float32)
+ if accepts_dtype:
+ set_module_kwargs["dtype"] = torch.float32
+ else:
+ param = param.to(dtype)
+ if accepts_dtype:
+ set_module_kwargs["dtype"] = dtype
+
+ # bnb params are flattened.
+ if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)
- if accepts_dtype:
- set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+ if not is_quantized or (
+ not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)
+ ):
+ if accepts_dtype:
+ set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
+ else:
+ set_module_tensor_to_device(model, param_name, device, value=param)
else:
- set_module_tensor_to_device(model, param_name, device, value=param)
+ hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
+
return unexpected_keys
@@ -231,6 +298,35 @@ def _fetch_index_file(
return index_file
+# Adapted from
+# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
+def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
+ weight_map = sharded_metadata.get("weight_map", None)
+ if weight_map is None:
+ raise KeyError("'weight_map' key not found in the shard index file.")
+
+ # Collect all unique safetensors files from weight_map
+ files_to_load = set(weight_map.values())
+ is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
+ merged_state_dict = {}
+
+ # Load tensors from each unique file
+ for file_name in files_to_load:
+ part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
+ if not os.path.exists(part_file_path):
+ raise FileNotFoundError(f"Part file {file_name} not found.")
+
+ if is_safetensors:
+ with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
+ for tensor_key in f.keys():
+ if tensor_key in weight_map:
+ merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
+ else:
+ merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
+
+ return merged_state_dict
+
+
def _fetch_index_file_legacy(
is_local,
pretrained_model_name_or_path,
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index ad3433889fca..4a486fd4ce40 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -14,13 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import inspect
import itertools
import json
import os
import re
from collections import OrderedDict
-from functools import partial
+from functools import partial, wraps
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
@@ -31,6 +32,8 @@
from torch import Tensor, nn
from .. import __version__
+from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
+from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME,
@@ -43,6 +46,8 @@
_get_model_file,
deprecate,
is_accelerate_available,
+ is_bitsandbytes_available,
+ is_bitsandbytes_version,
is_torch_version,
logging,
)
@@ -56,6 +61,7 @@
_fetch_index_file,
_fetch_index_file_legacy,
_load_state_dict_into_model,
+ _merge_sharded_checkpoints,
load_model_dict_into_meta,
load_state_dict,
)
@@ -125,6 +131,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_supports_gradient_checkpointing = False
_keys_to_ignore_on_load_unexpected = None
_no_split_modules = None
+ _keep_in_fp32_modules = None
def __init__(self):
super().__init__()
@@ -308,6 +315,19 @@ def save_pretrained(
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
+ hf_quantizer = getattr(self, "hf_quantizer", None)
+ if hf_quantizer is not None:
+ quantization_serializable = (
+ hf_quantizer is not None
+ and isinstance(hf_quantizer, DiffusersQuantizer)
+ and hf_quantizer.is_serializable
+ )
+ if not quantization_serializable:
+ raise ValueError(
+ f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+ " the logger on the traceback to understand the reason why the quantized model is not serializable."
+ )
+
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
@@ -402,6 +422,18 @@ def save_pretrained(
create_pr=create_pr,
)
+ def dequantize(self):
+ """
+ Potentially dequantize the model in case it has been quantized by a quantization method that supports
+ dequantization.
+ """
+ hf_quantizer = getattr(self, "hf_quantizer", None)
+
+ if hf_quantizer is None:
+ raise ValueError("You need to first quantize your model in order to dequantize it")
+
+ return hf_quantizer.dequantize(self)
+
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
@@ -524,6 +556,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ quantization_config = kwargs.pop("quantization_config", None)
allow_pickle = False
if use_safetensors is None:
@@ -618,6 +651,60 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
user_agent=user_agent,
**kwargs,
)
+ # no in-place modification of the original config.
+ config = copy.deepcopy(config)
+
+ # determine initial quantization config.
+ #######################################
+ pre_quantized = "quantization_config" in config and config["quantization_config"] is not None
+ if pre_quantized or quantization_config is not None:
+ if pre_quantized:
+ config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs(
+ config["quantization_config"], quantization_config
+ )
+ else:
+ config["quantization_config"] = quantization_config
+ hf_quantizer = DiffusersAutoQuantizer.from_config(
+ config["quantization_config"], pre_quantized=pre_quantized
+ )
+ else:
+ hf_quantizer = None
+
+ if hf_quantizer is not None:
+ if device_map is not None:
+ raise NotImplementedError(
+ "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
+ )
+ hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
+ torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+
+ # Record the quantization method in the user agent so popular methods can be tracked and supported. Can be disabled with `disable_telemetry`
+ user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+
+ # Force-set to `True` for more mem efficiency
+ if low_cpu_mem_usage is None:
+ low_cpu_mem_usage = True
+ logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.")
+ elif not low_cpu_mem_usage:
+ raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")
+
+ # Check if `_keep_in_fp32_modules` is not None
+ use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+ (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+ )
+ if use_keep_in_fp32_modules:
+ keep_in_fp32_modules = cls._keep_in_fp32_modules
+ if not isinstance(keep_in_fp32_modules, list):
+ keep_in_fp32_modules = [keep_in_fp32_modules]
+
+ if low_cpu_mem_usage is None:
+ low_cpu_mem_usage = True
+ logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
+ elif not low_cpu_mem_usage:
+ raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
+ else:
+ keep_in_fp32_modules = []
+ #######################################
# Determine if we're loading from a directory of sharded checkpoints.
is_sharded = False
@@ -684,6 +771,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=revision,
subfolder=subfolder or "",
)
+ if hf_quantizer is not None:
+ model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
+ logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
+ is_sharded = False
elif use_safetensors and not is_sharded:
try:
@@ -729,13 +820,30 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
with accelerate.init_empty_weights():
model = cls.from_config(config, **unused_kwargs)
+ if hf_quantizer is not None:
+ hf_quantizer.preprocess_model(
+ model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
+ )
+
# if device_map is None, load the state dict and move the params from meta device to the cpu
if device_map is None and not is_sharded:
- param_device = "cpu"
+ # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
+ # It would error out during the `validate_environment()` call above in the absence of cuda.
+ is_quant_method_bnb = (
+ getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+ )
+ if hf_quantizer is None:
+ param_device = "cpu"
+ # TODO (sayakpaul, SunMarc): remove this after model loading refactor
+ elif is_quant_method_bnb:
+ param_device = torch.cuda.current_device()
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
+
# move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+ if hf_quantizer is not None:
+ missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
if len(missing_keys) > 0:
raise ValueError(
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
@@ -750,6 +858,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
device=param_device,
dtype=torch_dtype,
model_name_or_path=pretrained_model_name_or_path,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
)
if cls._keys_to_ignore_on_load_unexpected is not None:
@@ -765,7 +875,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# Load weights and dispatch according to the device_map
# by default the device_map is None and the weights are loaded on the CPU
force_hook = True
- device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
+ device_map = _determine_device_map(
+ model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
+ )
if device_map is None and is_sharded:
# we load the parameters on the cpu
device_map = {"": "cpu"}
@@ -843,14 +955,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
"error_msgs": error_msgs,
}
+ if hf_quantizer is not None:
+ hf_quantizer.postprocess_model(model)
+ model.hf_quantizer = hf_quantizer
+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
)
- elif torch_dtype is not None:
+ # When using `use_keep_in_fp32_modules`, doing a global `to()` here would
+ # completely lose the effect of `use_keep_in_fp32_modules`.
+ elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:
model = model.to(torch_dtype)
- model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+ if hf_quantizer is not None:
+ # We also make sure to purge `_pre_quantization_dtype` when we serialize
+ # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable.
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype)
+ else:
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
@@ -859,6 +982,76 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
return model
+ # Adapted from `transformers`.
+ @wraps(torch.nn.Module.cuda)
+ def cuda(self, *args, **kwargs):
+ # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+ if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+ if getattr(self, "is_loaded_in_8bit", False):
+ raise ValueError(
+ "Calling `cuda()` is not supported for `8-bit` quantized models. "
+ " Please use the model as it is, since the model has already been set to the correct devices."
+ )
+ elif is_bitsandbytes_version("<", "0.43.2"):
+ raise ValueError(
+ "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+ f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+ )
+ return super().cuda(*args, **kwargs)
+
+ # Adapted from `transformers`.
+ @wraps(torch.nn.Module.to)
+ def to(self, *args, **kwargs):
+ dtype_present_in_args = "dtype" in kwargs
+
+ if not dtype_present_in_args:
+ for arg in args:
+ if isinstance(arg, torch.dtype):
+ dtype_present_in_args = True
+ break
+
+ # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+ if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+ if dtype_present_in_args:
+ raise ValueError(
+ "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
+ " desired `dtype` by passing the correct `torch_dtype` argument."
+ )
+
+ if getattr(self, "is_loaded_in_8bit", False):
+ raise ValueError(
+ "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+ " model has already been set to the correct devices and casted to the correct `dtype`."
+ )
+ elif is_bitsandbytes_version("<", "0.43.2"):
+ raise ValueError(
+ "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+ f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+ )
+ return super().to(*args, **kwargs)
+
+ # Taken from `transformers`.
+ def half(self, *args):
+ # Checks if the model is quantized
+ if getattr(self, "is_quantized", False):
+ raise ValueError(
+ "`.half()` is not supported for quantized model. Please use the model as it is, since the"
+ " model has already been cast to the correct `dtype`."
+ )
+ else:
+ return super().half(*args)
+
+ # Taken from `transformers`.
+ def float(self, *args):
+ # Checks if the model is quantized
+ if getattr(self, "is_quantized", False):
+ raise ValueError(
+ "`.float()` is not supported for quantized model. Please use the model as it is, since the"
+ " model has already been cast to the correct `dtype`."
+ )
+ else:
+ return super().float(*args)
+
@classmethod
def _load_pretrained_model(
cls,
@@ -1041,19 +1234,63 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
859520964
```
"""
+ is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
+
+ if is_loaded_in_4bit:
+ if is_bitsandbytes_available():
+ import bitsandbytes as bnb
+ else:
+ raise ValueError(
+ "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
+ " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. "
+ )
if exclude_embeddings:
embedding_param_names = [
- f"{name}.weight"
- for name, module_type in self.named_modules()
- if isinstance(module_type, torch.nn.Embedding)
+ f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
]
- non_embedding_parameters = [
+ total_parameters = [
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
]
- return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
else:
- return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+ total_parameters = list(self.parameters())
+
+ total_numel = []
+
+ for param in total_parameters:
+ if param.requires_grad or not only_trainable:
+ # For 4-bit models, parameters are packed into their storage dtype, so multiply the stored element
+ # count by 2 and by the storage element size (in bytes) to recover the true parameter count.
+ if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
+ if hasattr(param, "element_size"):
+ num_bytes = param.element_size()
+ elif hasattr(param, "quant_storage"):
+ num_bytes = param.quant_storage.itemsize
+ else:
+ num_bytes = 1
+ total_numel.append(param.numel() * 2 * num_bytes)
+ else:
+ total_numel.append(param.numel())
+
+ return sum(total_numel)
+
+ def get_memory_footprint(self, return_buffers=True):
+ r"""
+ Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
+ Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
+ PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
+
+ Arguments:
+ return_buffers (`bool`, *optional*, defaults to `True`):
+ Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
+ are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
+ norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
+ """
+ mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
+ if return_buffers:
+ mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
+ mem = mem + mem_bufs
+ return mem
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
deprecated_attention_block_paths = []
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 2be0c5e7310c..2e1858b16148 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -44,6 +44,7 @@
from ..models import AutoencoderKL
from ..models.attention_processor import FusedAttnProcessor2_0
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
+from ..quantizers.bitsandbytes.utils import _check_bnb_status
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
CONFIG_NAME,
@@ -54,6 +55,7 @@
is_accelerate_version,
is_torch_npu_available,
is_torch_version,
+ is_transformers_version,
logging,
numpy_to_pil,
)
@@ -432,18 +434,23 @@ def module_is_offloaded(module):
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
for module in modules:
- is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit
+ _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
- if is_loaded_in_8bit and dtype is not None:
+ if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
logger.warning(
- f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
+ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
)
- if is_loaded_in_8bit and device is not None:
+ if is_loaded_in_8bit_bnb and device is not None:
logger.warning(
- f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
+ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
)
- else:
+
+ # This can happen for `transformer` models. CPU placement was added in
+ # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
+ if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
+ module.to(device=device)
+ elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
module.to(device, dtype)
if (
@@ -1040,9 +1047,18 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
hook = None
for model_str in self.model_cpu_offload_seq.split("->"):
model = all_model_components.pop(model_str, None)
+
if not isinstance(model, torch.nn.Module):
continue
+ # This is because the model would already be placed on a CUDA device.
+ _, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
+ if is_loaded_in_8bit_bnb:
+ logger.info(
+ f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
+ )
+ continue
+
_, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
self._all_hooks.append(hook)
diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py
new file mode 100644
index 000000000000..93852d29ef59
--- /dev/null
+++ b/src/diffusers/quantizers/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .auto import DiffusersAutoQuantizationConfig, DiffusersAutoQuantizer
+from .base import DiffusersQuantizer
diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py
new file mode 100644
index 000000000000..f231f279e13a
--- /dev/null
+++ b/src/diffusers/quantizers/auto.py
@@ -0,0 +1,137 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Adapted from
+https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py
+"""
+import warnings
+from typing import Dict, Optional, Union
+
+from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
+from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod
+
+
+AUTO_QUANTIZER_MAPPING = {
+ "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
+ "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
+}
+
+AUTO_QUANTIZATION_CONFIG_MAPPING = {
+ "bitsandbytes_4bit": BitsAndBytesConfig,
+ "bitsandbytes_8bit": BitsAndBytesConfig,
+}
+
+
+class DiffusersAutoQuantizationConfig:
+ """
+ The auto diffusers quantization config class that takes care of automatically dispatching to the correct
+ quantization config given a quantization config stored in a dictionary.
+ """
+
+ @classmethod
+ def from_dict(cls, quantization_config_dict: Dict):
+ quant_method = quantization_config_dict.get("quant_method", None)
+ # We need special care for bnb models to make sure everything stays backward compatible.
+ if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
+ suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
+ quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
+ elif quant_method is None:
+ raise ValueError(
+ "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
+ )
+
+ if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
+ raise ValueError(
+ f"Unknown quantization type, got {quant_method} - supported types are:"
+ f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
+ )
+
+ target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
+ return target_cls.from_dict(quantization_config_dict)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
+ if getattr(model_config, "quantization_config", None) is None:
+ raise ValueError(
+ f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
+ )
+ quantization_config_dict = model_config.quantization_config
+ quantization_config = cls.from_dict(quantization_config_dict)
+ # Update with potential kwargs that are passed through from_pretrained.
+ quantization_config.update(kwargs)
+ return quantization_config
+
+
+class DiffusersAutoQuantizer:
+ """
+ The auto diffusers quantizer class that takes care of automatically instantiating the correct
+ `DiffusersQuantizer` given the `QuantizationConfig`.
+ """
+
+ @classmethod
+ def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
+ # Convert it to a QuantizationConfig if the q_config is a dict
+ if isinstance(quantization_config, dict):
+ quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config)
+
+ quant_method = quantization_config.quant_method
+
+ # Again, we need a special care for bnb as we have a single quantization config
+ # class for both 4-bit and 8-bit quantization
+ if quant_method == QuantizationMethod.BITS_AND_BYTES:
+ if quantization_config.load_in_8bit:
+ quant_method += "_8bit"
+ else:
+ quant_method += "_4bit"
+
+ if quant_method not in AUTO_QUANTIZER_MAPPING.keys():
+ raise ValueError(
+ f"Unknown quantization type, got {quant_method} - supported types are:"
+ f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
+ )
+
+ target_cls = AUTO_QUANTIZER_MAPPING[quant_method]
+ return target_cls(quantization_config, **kwargs)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ quantization_config = DiffusersAutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+ return cls.from_config(quantization_config)
+
+ @classmethod
+ def merge_quantization_configs(
+ cls,
+ quantization_config: Union[dict, QuantizationConfigMixin],
+ quantization_config_from_args: Optional[QuantizationConfigMixin],
+ ):
+ """
+ handles situations where both quantization_config from args and quantization_config from model config are
+ present.
+ """
+ if quantization_config_from_args is not None:
+ warning_msg = (
+ "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
+ " already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
+ )
+ else:
+ warning_msg = ""
+
+ if isinstance(quantization_config, dict):
+ quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config)
+
+ if warning_msg != "":
+ warnings.warn(warning_msg)
+
+ return quantization_config
diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py
new file mode 100644
index 000000000000..017136a98854
--- /dev/null
+++ b/src/diffusers/quantizers/base.py
@@ -0,0 +1,230 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adapted from
+https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/quantizers/base.py
+"""
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from ..utils import is_torch_available
+from .quantization_config import QuantizationConfigMixin
+
+
+if TYPE_CHECKING:
+ from ..models.modeling_utils import ModelMixin
+
+if is_torch_available():
+ import torch
+
+
+class DiffusersQuantizer(ABC):
+ """
+ Abstract base class for HuggingFace quantizers. Currently supports quantizing HF diffusers models for
+ inference. This class is used only for diffusers.models.modeling_utils.ModelMixin.from_pretrained and cannot be
+ easily used outside the scope of that method yet.
+
+ Attributes
+ quantization_config (`diffusers.quantizers.quantization_config.QuantizationConfigMixin`):
+ The quantization config that defines the quantization parameters of your model that you want to quantize.
+ modules_to_not_convert (`List[str]`, *optional*):
+ The list of module names to not convert when quantizing the model.
+ required_packages (`List[str]`, *optional*):
+ The list of required pip packages to install prior to using the quantizer
+ requires_calibration (`bool`):
+ Whether the quantization method requires to calibrate the model before using it.
+ """
+
+ requires_calibration = False
+ required_packages = None
+
+ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
+ self.quantization_config = quantization_config
+
+ # -- Handle extra kwargs below --
+ self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
+ self.pre_quantized = kwargs.pop("pre_quantized", True)
+
+ if not self.pre_quantized and self.requires_calibration:
+ raise ValueError(
+ f"The quantization method {quantization_config.quant_method} does require the model to be pre-quantized."
+ f" You explicitly passed `pre_quantized=False` meaning your model weights are not quantized. Make sure to "
+ f"pass `pre_quantized=True` while knowing what you are doing."
+ )
+
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ """
+ Some quantization methods require to explicitly set the dtype of the model to a target dtype. You need to
+ override this method in case you want to make sure that behavior is preserved
+
+ Args:
+ torch_dtype (`torch.dtype`):
+ The input dtype that is passed in `from_pretrained`
+ """
+ return torch_dtype
+
+ def update_device_map(self, device_map: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+ """
+ Override this method if you want to override the existing device map with a new one. E.g. for
+ bitsandbytes, since `accelerate` is a hard requirement, if no device_map is passed, the device_map is set to
+ `"auto"``
+
+ Args:
+ device_map (`Union[dict, str]`, *optional*):
+ The device_map that is passed through the `from_pretrained` method.
+ """
+ return device_map
+
+ def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ """
+ Override this method if you want to adjust the `target_dtype` variable used in `from_pretrained` to compute the
+ device_map in case the device_map is a `str`. E.g. for bitsandbytes we force-set `target_dtype` to `torch.int8`
+ and for 4-bit we pass a custom enum `accelerate.CustomDtype.int4`.
+
+ Args:
+ torch_dtype (`torch.dtype`, *optional*):
+ The torch_dtype that is used to compute the device_map.
+ """
+ return torch_dtype
+
+ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
+ """
+ Override this method if you want to adjust the `missing_keys`.
+
+ Args:
+ missing_keys (`List[str]`, *optional*):
+ The list of missing keys in the checkpoint compared to the state dict of the model
+ """
+ return missing_keys
+
+ def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
+ """
+ returns dtypes for modules that are not quantized - used for the computation of the device_map in case one
+ passes a str as a device_map. The method will use the `modules_to_not_convert` that is modified in
+ `_process_model_before_weight_loading`. `diffusers` models don't have any `modules_to_not_convert` attributes
+ yet but this can change soon in the future.
+
+ Args:
+ model (`~diffusers.models.modeling_utils.ModelMixin`):
+ The model to quantize
+ torch_dtype (`torch.dtype`):
+ The dtype passed in `from_pretrained` method.
+ """
+
+ return {
+ name: torch_dtype
+ for name, _ in model.named_parameters()
+ if any(m in name for m in self.modules_to_not_convert)
+ }
+
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
+ return max_memory
+
+ def check_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ state_dict: Dict[str, Any],
+ **kwargs,
+ ) -> bool:
+ """
+ Checks if a loaded state_dict component is part of a quantized param, plus some validation; only defined for
+ quantization methods that require creating new parameters for quantization.
+ """
+ return False
+
+ def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
+ """
+ takes needed components from state_dict and creates quantized param.
+ """
+ if not hasattr(self, "check_quantized_param"):
+ raise AttributeError(
+ f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}."
+ )
+
+ def validate_environment(self, *args, **kwargs):
+ """
+ This method is used to check for potential conflicts with arguments that are passed in
+ `from_pretrained`. You need to define it for all future quantizers that are integrated with diffusers. If no
+ explicit checks are needed, simply return nothing.
+ """
+ return
+
+ def preprocess_model(self, model: "ModelMixin", **kwargs):
+ """
+ Setting model attributes and/or converting model before weights loading. At this point the model should be
+ initialized on the meta device so you can freely manipulate the skeleton of the model in order to replace
+ modules in-place. Make sure to override the abstract method `_process_model_before_weight_loading`.
+
+ Args:
+ model (`~diffusers.models.modeling_utils.ModelMixin`):
+ The model to quantize
+ kwargs (`dict`, *optional*):
+ The keyword arguments that are passed along `_process_model_before_weight_loading`.
+ """
+ model.is_quantized = True
+ model.quantization_method = self.quantization_config.quant_method
+ return self._process_model_before_weight_loading(model, **kwargs)
+
+ def postprocess_model(self, model: "ModelMixin", **kwargs):
+ """
+ Post-process the model post weights loading. Make sure to override the abstract method
+ `_process_model_after_weight_loading`.
+
+ Args:
+ model (`~diffusers.models.modeling_utils.ModelMixin`):
+ The model to quantize
+ kwargs (`dict`, *optional*):
+ The keyword arguments that are passed along `_process_model_after_weight_loading`.
+ """
+ return self._process_model_after_weight_loading(model, **kwargs)
+
+ def dequantize(self, model):
+ """
+ Potentially dequantize the model to retrieve the original model, with some loss in accuracy / performance. Note
+ not all quantization schemes support this.
+ """
+ model = self._dequantize(model)
+
+ # Delete quantizer and quantization config
+ del model.hf_quantizer
+
+ return model
+
+ def _dequantize(self, model):
+ raise NotImplementedError(
+ f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
+ )
+
+ @abstractmethod
+ def _process_model_before_weight_loading(self, model, **kwargs):
+ ...
+
+ @abstractmethod
+ def _process_model_after_weight_loading(self, model, **kwargs):
+ ...
+
+ @property
+ @abstractmethod
+ def is_serializable(self):
+ ...
+
+ @property
+ @abstractmethod
+ def is_trainable(self):
+ ...
diff --git a/src/diffusers/quantizers/bitsandbytes/__init__.py b/src/diffusers/quantizers/bitsandbytes/__init__.py
new file mode 100644
index 000000000000..9e745bc810fa
--- /dev/null
+++ b/src/diffusers/quantizers/bitsandbytes/__init__.py
@@ -0,0 +1,2 @@
+from .bnb_quantizer import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
+from .utils import dequantize_and_replace, dequantize_bnb_weight, replace_with_bnb_linear
diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
new file mode 100644
index 000000000000..e3041aba60ae
--- /dev/null
+++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
@@ -0,0 +1,549 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Adapted from
+https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/quantizer_bnb_4bit.py
+"""
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from ...utils import get_module_from_name
+from ..base import DiffusersQuantizer
+
+
+if TYPE_CHECKING:
+ from ...models.modeling_utils import ModelMixin
+
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ is_bitsandbytes_available,
+ is_bitsandbytes_version,
+ is_torch_available,
+ logging,
+)
+
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__)
+
+
+class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
+ """
+ 4-bit quantization from the bitsandbytes quantization method:
+
+ - before loading: converts transformer layers into `Linear4bit`
+ - during loading: loads the 16-bit weight and passes it to the layer object
+ - after loading: quantizes individual weights in `Linear4bit` into 4-bit at the first `.cuda()` call
+ - saving: from the state dict, as usual; saves weights and `quant_state` components
+ - loading: needs to locate the `quant_state` components and pass them to the `Params4bit` constructor
+ """
+
+ use_keep_in_fp32_modules = True
+ requires_calibration = False
+
+ def __init__(self, quantization_config, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+
+ if self.quantization_config.llm_int8_skip_modules is not None:
+ self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
+
+ def validate_environment(self, *args, **kwargs):
+ if not torch.cuda.is_available():
+ raise RuntimeError("No GPU found. A GPU is needed for quantization.")
+ if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
+ raise ImportError(
+ "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
+ )
+ if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"):
+ raise ImportError(
+ "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
+ )
+
+ if kwargs.get("from_flax", False):
+ raise ValueError(
+ "Converting into 4-bit weights from flax weights is currently not supported, please make"
+ " sure the weights are in PyTorch format."
+ )
+
+ device_map = kwargs.get("device_map", None)
+ if (
+ device_map is not None
+ and isinstance(device_map, dict)
+ and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
+ ):
+ device_map_without_no_convert = {
+ key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
+ }
+ if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
+ raise ValueError(
+ "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
+ "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
+ "in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to "
+ "`from_pretrained`. Check "
+ "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu "
+ "for more details. "
+ )
+
+ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+ if target_dtype != torch.int8:
+ from accelerate.utils import CustomDtype
+
+ logger.info(f"target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization")
+ return CustomDtype.INT4
+ else:
+ raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.")
+
+ def check_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ state_dict: Dict[str, Any],
+ **kwargs,
+ ) -> bool:
+ import bitsandbytes as bnb
+
+ module, tensor_name = get_module_from_name(model, param_name)
+ if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
+ # Add here check for loaded components' dtypes once serialization is implemented
+ return True
+ elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
+ # bias could be loaded by regular set_module_tensor_to_device() from accelerate,
+ # but it would wrongly use uninitialized weight there.
+ return True
+ else:
+ return False
+
+ def create_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ target_device: "torch.device",
+ state_dict: Dict[str, Any],
+ unexpected_keys: Optional[List[str]] = None,
+ ):
+ import bitsandbytes as bnb
+
+ module, tensor_name = get_module_from_name(model, param_name)
+
+ if tensor_name not in module._parameters:
+ raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
+
+ old_value = getattr(module, tensor_name)
+
+ if tensor_name == "bias":
+ if param_value is None:
+ new_value = old_value.to(target_device)
+ else:
+ new_value = param_value.to(target_device)
+
+ new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
+ module._parameters[tensor_name] = new_value
+ return
+
+ if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
+ raise ValueError("This function only loads `Linear4bit` components.")
+ if (
+ old_value.device == torch.device("meta")
+ and target_device not in ["meta", torch.device("meta")]
+ and param_value is None
+ ):
+ raise ValueError(f"{tensor_name} is on the meta device; we need a `value` to put it on {target_device}.")
+
+ # construct `new_value` for the module._parameters[tensor_name]:
+ if self.pre_quantized:
+ # 4bit loading. Collecting components for restoring quantized weight
+ # This can be expanded to make a universal call for any quantized weight loading
+
+ if not self.is_serializable:
+ raise ValueError(
+ "Detected int4 weights but the version of bitsandbytes is not compatible with int4 serialization. "
+ "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
+ )
+
+ if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
+ param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
+ ):
+ raise ValueError(
+ f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
+ )
+
+ quantized_stats = {}
+ for k, v in state_dict.items():
+ # `startswith` to guard against edge cases where the `param_name`
+ # substring can be present in multiple places in the `state_dict`
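+ # e.g. for a (hypothetical) parameter "foo.weight", this collects keys such as
+ # "foo.weight.quant_state.bitsandbytes__nf4" together with the other serialized `quant_state` components.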
+ if param_name + "." in k and k.startswith(param_name):
+ quantized_stats[k] = v
+ if unexpected_keys is not None and k in unexpected_keys:
+ unexpected_keys.remove(k)
+
+ new_value = bnb.nn.Params4bit.from_prequantized(
+ data=param_value,
+ quantized_stats=quantized_stats,
+ requires_grad=False,
+ device=target_device,
+ )
+ else:
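+ # Fresh (non-prequantized) load: wrap the fp16 weight in `Params4bit`; the actual 4-bit quantization
+ # happens when the parameter is moved to the target (CUDA) device below.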
+ new_value = param_value.to("cpu")
+ kwargs = old_value.__dict__
+ new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
+
+ module._parameters[tensor_name] = new_value
+
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ # need more space for buffers that are created during quantization
+ max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+ return max_memory
+
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ if torch_dtype is None:
+ # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
+ logger.info(
+ "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to "
+ "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
+ "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
+ " torch_dtype=torch.float16 to remove this warning.",
+ torch_dtype,
+ )
+ torch_dtype = torch.float16
+ return torch_dtype
+
+ # (sayakpaul): I think it could be better to disable custom `device_map`s
+ # for the first phase of the integration in the interest of simplicity.
+ # Commenting this for discussions on the PR.
+ # def update_device_map(self, device_map):
+ # if device_map is None:
+ # device_map = {"": torch.cuda.current_device()}
+ # logger.info(
+ # "The device_map was not initialized. "
+ # "Setting device_map to {'':torch.cuda.current_device()}. "
+ # "If you want to use the model for inference, please set device_map ='auto' "
+ # )
+ # return device_map
+
+ def _process_model_before_weight_loading(
+ self,
+ model: "ModelMixin",
+ device_map,
+ keep_in_fp32_modules: List[str] = [],
+ **kwargs,
+ ):
+ from .utils import replace_with_bnb_linear
+
+ load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
+
+ # We may keep some modules such as the `proj_out` in their original dtype for numerical stability reasons
+ self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
+
+ if not isinstance(self.modules_to_not_convert, list):
+ self.modules_to_not_convert = [self.modules_to_not_convert]
+
+ self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+ # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+ if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+ keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+
+ if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload:
+ raise ValueError(
+ "If you want to offload some keys to `cpu` or `disk`, you need to set "
+ "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
+ " converted to 8-bit but kept in 32-bit."
+ )
+ self.modules_to_not_convert.extend(keys_on_cpu)
+
+ # Purge `None`.
+ # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
+ # in case of diffusion transformer models. For language models and others alike, `lm_head`
+ # and tied modules are usually kept in FP32.
+ self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
+
+ model = replace_with_bnb_linear(
+ model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
+ )
+ model.config.quantization_config = self.quantization_config
+
+ def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
+ model.is_loaded_in_4bit = True
+ model.is_4bit_serializable = self.is_serializable
+ return model
+
+ @property
+ def is_serializable(self):
+ # Because we're mandating `bitsandbytes` 0.43.3.
+ return True
+
+ @property
+ def is_trainable(self) -> bool:
+ # Because we're mandating `bitsandbytes` 0.43.3.
+ return True
+
+ def _dequantize(self, model):
+ from .utils import dequantize_and_replace
+
+ is_model_on_cpu = model.device.type == "cpu"
+ if is_model_on_cpu:
+ logger.info(
+ "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
+ )
+ model.to(torch.cuda.current_device())
+
+ model = dequantize_and_replace(
+ model, self.modules_to_not_convert, quantization_config=self.quantization_config
+ )
+ if is_model_on_cpu:
+ model.to("cpu")
+ return model
+
+
+class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
+ """
+ 8-bit quantization from the bitsandbytes quantization method:
+
+ - before loading: converts transformer layers into `Linear8bitLt`
+ - during loading: loads the 16-bit weight and passes it to the layer object
+ - after loading: quantizes individual weights in `Linear8bitLt` into 8-bit at the first `.cuda()` call
+ - saving: from the state dict, as usual; saves weights and the `SCB` component
+ - loading: needs to locate the `SCB` component and pass it to the `Linear8bitLt` object
+ """
+
+ use_keep_in_fp32_modules = True
+ requires_calibration = False
+
+ def __init__(self, quantization_config, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+
+ if self.quantization_config.llm_int8_skip_modules is not None:
+ self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
+
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4-bit->8-bit
+ def validate_environment(self, *args, **kwargs):
+ if not torch.cuda.is_available():
+ raise RuntimeError("No GPU found. A GPU is needed for quantization.")
+ if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
+ raise ImportError(
+ "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
+ )
+ if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"):
+ raise ImportError(
+ "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
+ )
+
+ if kwargs.get("from_flax", False):
+ raise ValueError(
+ "Converting into 8-bit weights from flax weights is currently not supported, please make"
+ " sure the weights are in PyTorch format."
+ )
+
+ device_map = kwargs.get("device_map", None)
+ if (
+ device_map is not None
+ and isinstance(device_map, dict)
+ and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
+ ):
+ device_map_without_no_convert = {
+ key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
+ }
+ if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
+ raise ValueError(
+ "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
+ "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
+ "in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to "
+ "`from_pretrained`. Check "
+ "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu "
+ "for more details. "
+ )
+
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ # need more space for buffers that are created during quantization
+ max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+ return max_memory
+
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_torch_dtype
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ if torch_dtype is None:
+ # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
+ logger.info(
+ "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to "
+ "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
+ "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
+ " torch_dtype=torch.float16 to remove this warning.",
+ torch_dtype,
+ )
+ torch_dtype = torch.float16
+ return torch_dtype
+
+ # # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
+ # def update_device_map(self, device_map):
+ # if device_map is None:
+ # device_map = {"": torch.cuda.current_device()}
+ # logger.info(
+ # "The device_map was not initialized. "
+ # "Setting device_map to {'':torch.cuda.current_device()}. "
+ # "If you want to use the model for inference, please set device_map ='auto' "
+ # )
+ # return device_map
+
+ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+ if target_dtype != torch.int8:
+ logger.info(f"target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization")
+ return torch.int8
+
+ def check_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ state_dict: Dict[str, Any],
+ **kwargs,
+ ):
+ import bitsandbytes as bnb
+
+ module, tensor_name = get_module_from_name(model, param_name)
+ if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Int8Params):
+ if self.pre_quantized:
+ if param_name.replace("weight", "SCB") not in state_dict.keys():
+ raise ValueError("Missing quantization component `SCB`")
+ if param_value.dtype != torch.int8:
+ raise ValueError(
+ f"Incompatible dtype `{param_value.dtype}` when loading 8-bit prequantized weight. Expected `torch.int8`."
+ )
+ return True
+ return False
+
+ def create_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ target_device: "torch.device",
+ state_dict: Dict[str, Any],
+ unexpected_keys: Optional[List[str]] = None,
+ ):
+ import bitsandbytes as bnb
+
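+ # For pre-quantized checkpoints, the quantization statistics are stored alongside the int8 weight in a
+ # companion `SCB` entry (and possibly a `weight_format` entry); locate them here so they can be re-attached
+ # to the `Int8Params` weight below.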
+ fp16_statistics_key = param_name.replace("weight", "SCB")
+ fp16_weights_format_key = param_name.replace("weight", "weight_format")
+
+ fp16_statistics = state_dict.get(fp16_statistics_key, None)
+ fp16_weights_format = state_dict.get(fp16_weights_format_key, None)
+
+ module, tensor_name = get_module_from_name(model, param_name)
+ if tensor_name not in module._parameters:
+ raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
+
+ old_value = getattr(module, tensor_name)
+
+ if not isinstance(module._parameters[tensor_name], bnb.nn.Int8Params):
+ raise ValueError(f"Parameter `{tensor_name}` should only be a `bnb.nn.Int8Params` instance.")
+ if (
+ old_value.device == torch.device("meta")
+ and target_device not in ["meta", torch.device("meta")]
+ and param_value is None
+ ):
+ raise ValueError(f"{tensor_name} is on the meta device; we need a `value` to put it on {target_device}.")
+
+ new_value = param_value.to("cpu")
+ if self.pre_quantized and not self.is_serializable:
+ raise ValueError(
+ "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
+ "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
+ )
+
+ kwargs = old_value.__dict__
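+ # `Int8Params` quantizes the fp16 weight to int8 when the parameter is moved to the target (CUDA) device below.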
+ new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(target_device)
+
+ module._parameters[tensor_name] = new_value
+ if fp16_statistics is not None:
+ setattr(module.weight, "SCB", fp16_statistics.to(target_device))
+ if unexpected_keys is not None:
+ unexpected_keys.remove(fp16_statistics_key)
+
+ # We just need to pop the `weight_format` keys from the state dict to remove unneeded
+ # messages. The correct format is correctly retrieved during the first forward pass.
+ if fp16_weights_format is not None and unexpected_keys is not None:
+ unexpected_keys.remove(fp16_weights_format_key)
+
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit
+ def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
+ model.is_loaded_in_8bit = True
+ model.is_8bit_serializable = self.is_serializable
+ return model
+
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading
+ def _process_model_before_weight_loading(
+ self,
+ model: "ModelMixin",
+ device_map,
+ keep_in_fp32_modules: List[str] = [],
+ **kwargs,
+ ):
+ from .utils import replace_with_bnb_linear
+
+ load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
+
+ # We may keep some modules such as the `proj_out` in their original dtype for numerical stability reasons
+ self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
+
+ if not isinstance(self.modules_to_not_convert, list):
+ self.modules_to_not_convert = [self.modules_to_not_convert]
+
+ self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+ # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+ if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+ keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+
+ if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload:
+ raise ValueError(
+ "If you want to offload some keys to `cpu` or `disk`, you need to set "
+ "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
+ " converted to 8-bit but kept in 32-bit."
+ )
+ self.modules_to_not_convert.extend(keys_on_cpu)
+
+ # Purge `None`.
+ # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
+ # in case of diffusion transformer models. For language models and others alike, `lm_head`
+ # and tied modules are usually kept in FP32.
+ self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
+
+ model = replace_with_bnb_linear(
+ model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
+ )
+ model.config.quantization_config = self.quantization_config
+
+ @property
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable
+ def is_serializable(self):
+ # Because we're mandating `bitsandbytes` 0.43.3.
+ return True
+
+ @property
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_trainable
+ def is_trainable(self) -> bool:
+ # Because we're mandating `bitsandbytes` 0.43.3.
+ return True
+
+ def _dequantize(self, model):
+ from .utils import dequantize_and_replace
+
+ model = dequantize_and_replace(
+ model, self.modules_to_not_convert, quantization_config=self.quantization_config
+ )
+ return model
diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py
new file mode 100644
index 000000000000..03755db3d1ec
--- /dev/null
+++ b/src/diffusers/quantizers/bitsandbytes/utils.py
@@ -0,0 +1,306 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Adapted from
+https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py
+"""
+
+import inspect
+from inspect import signature
+from typing import Tuple
+
+from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging
+from ..quantization_config import QuantizationMethod
+
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+
+if is_bitsandbytes_available():
+ import bitsandbytes as bnb
+
+if is_accelerate_available():
+ import accelerate
+ from accelerate import init_empty_weights
+ from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+
+logger = logging.get_logger(__name__)
+
+
+def _replace_with_bnb_linear(
+ model,
+ modules_to_not_convert=None,
+ current_key_name=None,
+ quantization_config=None,
+ has_been_replaced=False,
+):
+ """
+ Private method that wraps the recursion for module replacement.
+
+ Returns the converted model and a boolean that indicates if the conversion has been successful or not.
+ """
+ for name, module in model.named_children():
+ if current_key_name is None:
+ current_key_name = []
+ current_key_name.append(name)
+
+ if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
+ # Check if the current key is not in the `modules_to_not_convert`
+ current_key_name_str = ".".join(current_key_name)
+ if not any(
+ (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
+ ):
+ with init_empty_weights():
+ in_features = module.in_features
+ out_features = module.out_features
+
+ if quantization_config.quantization_method() == "llm_int8":
+ model._modules[name] = bnb.nn.Linear8bitLt(
+ in_features,
+ out_features,
+ module.bias is not None,
+ has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
+ threshold=quantization_config.llm_int8_threshold,
+ )
+ has_been_replaced = True
+ else:
+ if (
+ quantization_config.llm_int8_skip_modules is not None
+ and name in quantization_config.llm_int8_skip_modules
+ ):
+ pass
+ else:
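+ # `quant_storage` is only forwarded when the installed `bitsandbytes` `Linear4bit` accepts it
+ # (older releases may not expose it), hence the signature check below.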
+ extra_kwargs = (
+ {"quant_storage": quantization_config.bnb_4bit_quant_storage}
+ if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters)
+ else {}
+ )
+ model._modules[name] = bnb.nn.Linear4bit(
+ in_features,
+ out_features,
+ module.bias is not None,
+ quantization_config.bnb_4bit_compute_dtype,
+ compress_statistics=quantization_config.bnb_4bit_use_double_quant,
+ quant_type=quantization_config.bnb_4bit_quant_type,
+ **extra_kwargs,
+ )
+ has_been_replaced = True
+ # Store the module class in case we need to transpose the weight later
+ model._modules[name].source_cls = type(module)
+ # Force requires grad to False to avoid unexpected errors
+ model._modules[name].requires_grad_(False)
+ if len(list(module.children())) > 0:
+ _, has_been_replaced = _replace_with_bnb_linear(
+ module,
+ modules_to_not_convert,
+ current_key_name,
+ quantization_config,
+ has_been_replaced=has_been_replaced,
+ )
+ # Remove the last key for recursion
+ current_key_name.pop(-1)
+ return model, has_been_replaced
+
+
+def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+ """
+ Helper function to replace the `nn.Linear` layers within `model` with either `bnb.nn.Linear8bitLt` or
+ `bnb.nn.Linear4bit` using the `bitsandbytes` library.
+
+ References:
+ * `bnb.nn.Linear8bitLt`: [LLM.int8(): 8-bit Matrix Multiplication for Transformers at
+ Scale](https://arxiv.org/abs/2208.07339)
+ * `bnb.nn.Linear4bit`: [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
+
+ Parameters:
+ model (`torch.nn.Module`):
+ Input model or `torch.nn.Module` as the function is run recursively.
+ modules_to_not_convert (`List[str]`, *optional*, defaults to `[]`):
+ Names of the modules to not convert to `Linear8bitLt`. In practice we keep the `modules_to_not_convert` in
+ full precision for numerical stability reasons.
+ current_key_name (`List[str]`, *optional*):
+ An array to track the current key of the recursion. This is used to check whether the current key (part of
+ it) is not in the list of modules to not convert (for instance, modules that are offloaded to `cpu` or
+ `disk`).
+ quantization_config (`BitsAndBytesConfig`):
+ To configure and manage settings related to quantization, a technique used to compress neural network
+ models by reducing the precision of the weights and activations, thus making models more efficient in terms
+ of both storage and computation.
+ """
+ model, has_been_replaced = _replace_with_bnb_linear(
+ model, modules_to_not_convert, current_key_name, quantization_config
+ )
+
+ if not has_been_replaced:
+ logger.warning(
+ "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+ " Please double check your model architecture, or submit an issue on github if you think this is"
+ " a bug."
+ )
+
+ return model
+
+
+# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
+def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
+ """
+ Helper function to dequantize 4bit or 8bit bnb weights.
+
+ If the weight is not a bnb quantized weight, it will be returned as is.
+ """
+ if not isinstance(weight, torch.nn.Parameter):
+ raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")
+
+ cls_name = weight.__class__.__name__
+ if cls_name not in ("Params4bit", "Int8Params"):
+ return weight
+
+ if cls_name == "Params4bit":
+ output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
+ logger.warning_once(
+ f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
+ )
+ return output_tensor
+
+ if state.SCB is None:
+ state.SCB = weight.SCB
+
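+ # 8-bit path: recover the fp16 weight by multiplying the int8 weight with an identity matrix through the
+ # bitsandbytes int8 matmul kernels and rescaling with the stored statistics (`SCB`).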
+ im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
+ im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
+ im, Sim = bnb.functional.transform(im, "col32")
+ if state.CxB is None:
+ state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
+ out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
+ return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+
+
+def _create_accelerate_new_hook(old_hook):
+ r"""
+ Creates a new hook based on the old hook. Use it only if you know what you are doing! This method is a copy of:
+ https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with
+ some changes
+ """
+ old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
+ old_hook_attr = old_hook.__dict__
+ filtered_old_hook_attr = {}
+ old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
+ for k in old_hook_attr.keys():
+ if k in old_hook_init_signature.parameters:
+ filtered_old_hook_attr[k] = old_hook_attr[k]
+ new_hook = old_hook_cls(**filtered_old_hook_attr)
+ return new_hook
+
+
+def _dequantize_and_replace(
+ model,
+ modules_to_not_convert=None,
+ current_key_name=None,
+ quantization_config=None,
+ has_been_replaced=False,
+):
+ """
+ Converts a quantized model into its dequantized original version. The newly converted model will have some
+ performance drop compared to the original model before quantization - use it only for specific use cases such as
+ QLoRA adapters merging.
+
+ Returns the converted model and a boolean that indicates if the conversion has been successful or not.
+ """
+ quant_method = quantization_config.quantization_method()
+
+ target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit
+
+ for name, module in model.named_children():
+ if current_key_name is None:
+ current_key_name = []
+ current_key_name.append(name)
+
+ if isinstance(module, target_cls) and name not in modules_to_not_convert:
+ # Check if the current key is not in the `modules_to_not_convert`
+ current_key_name_str = ".".join(current_key_name)
+
+ if not any(
+ (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
+ ):
+ bias = getattr(module, "bias", None)
+
+ device = module.weight.device
+ with init_empty_weights():
+ new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
+
+ if quant_method == "llm_int8":
+ state = module.state
+ else:
+ state = None
+
+ new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+
+ if bias is not None:
+ new_module.bias = bias
+
+ # Create a new hook and attach it in case we use accelerate
+ if hasattr(module, "_hf_hook"):
+ old_hook = module._hf_hook
+ new_hook = _create_accelerate_new_hook(old_hook)
+
+ remove_hook_from_module(module)
+ add_hook_to_module(new_module, new_hook)
+
+ new_module.to(device)
+ model._modules[name] = new_module
+ has_been_replaced = True
+ if len(list(module.children())) > 0:
+ _, has_been_replaced = _dequantize_and_replace(
+ module,
+ modules_to_not_convert,
+ current_key_name,
+ quantization_config,
+ has_been_replaced=has_been_replaced,
+ )
+ # Remove the last key for recursion
+ current_key_name.pop(-1)
+ return model, has_been_replaced
+
+
+def dequantize_and_replace(
+ model,
+ modules_to_not_convert=None,
+ quantization_config=None,
+):
+ model, has_been_replaced = _dequantize_and_replace(
+ model,
+ modules_to_not_convert=modules_to_not_convert,
+ quantization_config=quantization_config,
+ )
+
+ if not has_been_replaced:
+ logger.warning(
+ "For some reason the model has not been properly dequantized. You might see unexpected behavior."
+ )
+
+ return model
+
+
+def _check_bnb_status(module) -> Tuple[bool, bool, bool]:
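+ """
+ Returns a tuple of booleans: whether `module` is bnb-quantized at all, whether it was loaded in 4-bit, and
+ whether it was loaded in 8-bit.
+ """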
+ is_loaded_in_4bit_bnb = (
+ hasattr(module, "is_loaded_in_4bit")
+ and module.is_loaded_in_4bit
+ and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+ )
+ is_loaded_in_8bit_bnb = (
+ hasattr(module, "is_loaded_in_8bit")
+ and module.is_loaded_in_8bit
+ and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+ )
+ return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb
diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
new file mode 100644
index 000000000000..f521c5d717d6
--- /dev/null
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -0,0 +1,391 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adapted from
+https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/utils/quantization_config.py
+"""
+
+import copy
+import importlib.metadata
+import json
+import os
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Dict, Union
+
+from packaging import version
+
+from ..utils import is_torch_available, logging
+
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__)
+
+
+class QuantizationMethod(str, Enum):
+ BITS_AND_BYTES = "bitsandbytes"
+
+
+@dataclass
+class QuantizationConfigMixin:
+ """
+ Mixin class for quantization config
+ """
+
+ quant_method: QuantizationMethod
+ _exclude_attributes_at_init = []
+
+ @classmethod
+ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
+ """
+ Instantiates a [`QuantizationConfigMixin`] from a Python dictionary of parameters.
+
+ Args:
+ config_dict (`Dict[str, Any]`):
+ Dictionary that will be used to instantiate the configuration object.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
+ `PreTrainedModel`.
+ kwargs (`Dict[str, Any]`):
+ Additional parameters from which to initialize the configuration object.
+
+ Returns:
+ [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
+ """
+
+ config = cls(**config_dict)
+
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(config, key):
+ setattr(config, key, value)
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+
+ if return_unused_kwargs:
+ return config, kwargs
+ else:
+ return config
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this configuration instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ config_dict = self.to_dict()
+ json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ writer.write(json_string)
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ return copy.deepcopy(self.__dict__)
+
+ def __iter__(self):
+ """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
+ for attr, value in copy.deepcopy(self.__dict__).items():
+ yield attr, value
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ def to_json_string(self, use_diff: bool = True) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Args:
+ use_diff (`bool`, *optional*, defaults to `True`):
+ If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
+ is serialized to JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
+ """
+ if use_diff is True:
+ config_dict = self.to_diff_dict()
+ else:
+ config_dict = self.to_dict()
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ def update(self, **kwargs):
+ """
+ Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
+ returning all the unused kwargs.
+
+ Args:
+ kwargs (`Dict[str, Any]`):
+ Dictionary of attributes to tentatively update this class.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
+ """
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+ to_remove.append(key)
+
+ # Remove all the attributes that were updated, without modifying the input dict
+ unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
+ return unused_kwargs
+
+
+@dataclass
+class BitsAndBytesConfig(QuantizationConfigMixin):
+ """
+ This is a wrapper class for all the possible attributes and features that you can use with a model that has been
+ loaded using `bitsandbytes`.
+
+ This replaces `load_in_8bit` or `load_in_4bit`; therefore, both options are mutually exclusive.
+
+ Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
+ then more arguments will be added to this class.
+
+ Args:
+ load_in_8bit (`bool`, *optional*, defaults to `False`):
+ This flag is used to enable 8-bit quantization with LLM.int8().
+ load_in_4bit (`bool`, *optional*, defaults to `False`):
+ This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from
+ `bitsandbytes`.
+ llm_int8_threshold (`float`, *optional*, defaults to 6.0):
+ This corresponds to the outlier threshold for outlier detection as described in the `LLM.int8(): 8-bit Matrix
+ Multiplication for Transformers at Scale` paper (https://arxiv.org/abs/2208.07339). Any hidden states value
+ that is above this threshold will be considered an outlier and the operation on those values will be done
+ in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but
+ there are some exceptional systematic outliers that are very differently distributed for large models.
+ These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of
+ magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6,
+ but a lower threshold might be needed for more unstable models (small models, fine-tuning).
+ llm_int8_skip_modules (`List[str]`, *optional*):
+ An explicit list of the modules that we do not want to convert to 8-bit. This is useful for models such as
+ Jukebox, which has several heads in different places and not necessarily at the last position. For example,
+ for `CausalLM` models, the last `lm_head` is typically kept in its original `dtype`.
+ llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`):
+ This flag is used for advanced use cases and users that are aware of this feature. If you want to split
+ your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use
+ this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8
+ operations will not be run on CPU.
+ llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`):
+ This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do not
+ have to be converted back and forth for the backward pass.
+ bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`):
+ This sets the computational type which might be different from the input type. For example, inputs might be
+ fp32, but computation can be set to bf16 for speedups.
+ bnb_4bit_quant_type (`str`, *optional*, defaults to `"fp4"`):
+ This sets the quantization data type in the `bnb.nn.Linear4bit` layers. Options are FP4 and NF4 data types
+ which are specified by `fp4` or `nf4`.
+ bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`):
+ This flag is used for nested quantization where the quantization constants from the first quantization are
+ quantized again.
+ bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`):
+ This sets the storage type used to pack the quantized 4-bit params.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional parameters from which to initialize the configuration object.
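+
+ Example (an illustrative sketch of a 4-bit NF4 configuration; the values shown here are only an example, not a
+ recommendation):
+
+ ```py
+ from diffusers import BitsAndBytesConfig
+
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype="bfloat16",
+ )
+ ```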
+ """
+
+ _exclude_attributes_at_init = ["_load_in_4bit", "_load_in_8bit", "quant_method"]
+
+ def __init__(
+ self,
+ load_in_8bit=False,
+ load_in_4bit=False,
+ llm_int8_threshold=6.0,
+ llm_int8_skip_modules=None,
+ llm_int8_enable_fp32_cpu_offload=False,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=None,
+ bnb_4bit_quant_type="fp4",
+ bnb_4bit_use_double_quant=False,
+ bnb_4bit_quant_storage=None,
+ **kwargs,
+ ):
+ self.quant_method = QuantizationMethod.BITS_AND_BYTES
+
+ if load_in_4bit and load_in_8bit:
+ raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
+
+ self._load_in_8bit = load_in_8bit
+ self._load_in_4bit = load_in_4bit
+ self.llm_int8_threshold = llm_int8_threshold
+ self.llm_int8_skip_modules = llm_int8_skip_modules
+ self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
+ self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
+ self.bnb_4bit_quant_type = bnb_4bit_quant_type
+ self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
+
+ if bnb_4bit_compute_dtype is None:
+ self.bnb_4bit_compute_dtype = torch.float32
+ elif isinstance(bnb_4bit_compute_dtype, str):
+ self.bnb_4bit_compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
+ elif isinstance(bnb_4bit_compute_dtype, torch.dtype):
+ self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
+ else:
+ raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
+
+ if bnb_4bit_quant_storage is None:
+ self.bnb_4bit_quant_storage = torch.uint8
+ elif isinstance(bnb_4bit_quant_storage, str):
+ if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]:
+ raise ValueError(
+ "`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') "
+ )
+ self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage)
+ elif isinstance(bnb_4bit_quant_storage, torch.dtype):
+ self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
+ else:
+ raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype")
+
+ if kwargs and not all(k in self._exclude_attributes_at_init for k in kwargs):
+ logger.warning(f"Unused kwargs: {list(kwargs.keys())}. These kwargs are not used in {self.__class__}.")
+
+ self.post_init()
+
+ @property
+ def load_in_4bit(self):
+ return self._load_in_4bit
+
+ @load_in_4bit.setter
+ def load_in_4bit(self, value: bool):
+ if not isinstance(value, bool):
+ raise TypeError("load_in_4bit must be a boolean")
+
+ if self.load_in_8bit and value:
+ raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
+ self._load_in_4bit = value
+
+ @property
+ def load_in_8bit(self):
+ return self._load_in_8bit
+
+ @load_in_8bit.setter
+ def load_in_8bit(self, value: bool):
+ if not isinstance(value, bool):
+ raise TypeError("load_in_8bit must be a boolean")
+
+ if self.load_in_4bit and value:
+ raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
+ self._load_in_8bit = value
+
+ def post_init(self):
+ r"""
+ Safety checker that ensures arguments are correct; it also replaces some `None` arguments with their default values.
+ """
+ if not isinstance(self.load_in_4bit, bool):
+ raise TypeError("load_in_4bit must be a boolean")
+
+ if not isinstance(self.load_in_8bit, bool):
+ raise TypeError("load_in_8bit must be a boolean")
+
+ if not isinstance(self.llm_int8_threshold, float):
+ raise TypeError("llm_int8_threshold must be a float")
+
+ if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list):
+ raise TypeError("llm_int8_skip_modules must be a list of strings")
+ if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool):
+ raise TypeError("llm_int8_enable_fp32_cpu_offload must be a boolean")
+
+ if not isinstance(self.llm_int8_has_fp16_weight, bool):
+ raise TypeError("llm_int8_has_fp16_weight must be a boolean")
+
+ if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
+ raise TypeError("bnb_4bit_compute_dtype must be torch.dtype")
+
+ if not isinstance(self.bnb_4bit_quant_type, str):
+ raise TypeError("bnb_4bit_quant_type must be a string")
+
+ if not isinstance(self.bnb_4bit_use_double_quant, bool):
+ raise TypeError("bnb_4bit_use_double_quant must be a boolean")
+
+ if self.load_in_4bit and not version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse(
+ "0.39.0"
+ ):
+ raise ValueError(
+ "4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version"
+ )
+
+ def is_quantizable(self):
+ r"""
+ Returns `True` if the model is quantizable, `False` otherwise.
+ """
+ return self.load_in_8bit or self.load_in_4bit
+
+ def quantization_method(self):
+ r"""
+ This method returns the quantization method used for the model. If the model is not quantizable, it returns
+ `None`.
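+
+ For example, a config created with `load_in_4bit=True` and `bnb_4bit_quant_type="nf4"` returns `"nf4"`.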
+ """
+ if self.load_in_8bit:
+ return "llm_int8"
+ elif self.load_in_4bit and self.bnb_4bit_quant_type == "fp4":
+ return "fp4"
+ elif self.load_in_4bit and self.bnb_4bit_quant_type == "nf4":
+ return "nf4"
+ else:
+ return None
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1]
+ output["bnb_4bit_quant_storage"] = str(output["bnb_4bit_quant_storage"]).split(".")[1]
+ output["load_in_4bit"] = self.load_in_4bit
+ output["load_in_8bit"] = self.load_in_8bit
+
+ return output
+
+ def __repr__(self):
+ config_dict = self.to_dict()
+ return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
+
+ def to_diff_dict(self) -> Dict[str, Any]:
+ """
+ Removes all attributes from config which correspond to the default config attributes for better readability and
+ serializes to a Python dictionary.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary of the attributes of this configuration instance that differ from the default configuration.
+ """
+ config_dict = self.to_dict()
+
+ # get the default config dict
+ default_config_dict = BitsAndBytesConfig().to_dict()
+
+ serializable_config_dict = {}
+
+ # only serialize values that differ from the default config
+ for key, value in config_dict.items():
+ if value != default_config_dict[key]:
+ serializable_config_dict[key] = value
+
+ return serializable_config_dict
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index c7ea2bcc5b7f..c8f64adf3e8a 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -62,6 +62,7 @@
is_accelerate_available,
is_accelerate_version,
is_bitsandbytes_available,
+ is_bitsandbytes_version,
is_bs4_available,
is_flax_available,
is_ftfy_available,
@@ -94,7 +95,7 @@
is_xformers_available,
requires_backends,
)
-from .loading_utils import load_image, load_video
+from .loading_utils import get_module_from_name, load_image, load_video
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index eaab67c93b18..10d0399a6761 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -1020,6 +1020,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class DiffusersQuantizer(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AmusedScheduler(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index daecec4aa258..f1323bf00ea4 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -745,6 +745,20 @@ def is_peft_version(operation: str, version: str):
return compare_versions(parse(_peft_version), operation, version)
+def is_bitsandbytes_version(operation: str, version: str):
+ """
+ Compares the current bitsandbytes version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _bitsandbytes_version:
+ return False
+ return compare_versions(parse(_bitsandbytes_version), operation, version)
+
+
def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.
diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py
index b36664cb81ff..bac24fa23e63 100644
--- a/src/diffusers/utils/loading_utils.py
+++ b/src/diffusers/utils/loading_utils.py
@@ -1,6 +1,6 @@
import os
import tempfile
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
from urllib.parse import unquote, urlparse
import PIL.Image
@@ -135,3 +135,16 @@ def load_video(
pil_images = convert_method(pil_images)
return pil_images
+
+
+# Taken from `transformers`.
+def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
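+ """
+ Splits a dotted tensor name into its parent module and the final attribute name. For example (hypothetical
+ module name), `get_module_from_name(model, "proj_out.weight")` returns `(model.proj_out, "weight")`.
+ """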
+ if "." in tensor_name:
+ splits = tensor_name.split(".")
+ for split in splits[:-1]:
+ new_module = getattr(module, split)
+ if new_module is None:
+ raise ValueError(f"{module} has no attribute {split}.")
+ module = new_module
+ tensor_name = splits[-1]
+ return module, tensor_name
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index a2f283d0c4f5..1179b113d636 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -1,5 +1,6 @@
import functools
import importlib
+import importlib.metadata
import inspect
import io
import logging
@@ -27,6 +28,8 @@
from .import_utils import (
BACKENDS_MAPPING,
+ is_accelerate_available,
+ is_bitsandbytes_available,
is_compel_available,
is_flax_available,
is_note_seq_available,
@@ -371,6 +374,20 @@ def require_timm(test_case):
return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
+def require_bitsandbytes(test_case):
+ """
+ Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed.
+ """
+ return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
+
+
+def require_accelerate(test_case):
+ """
+ Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
+ """
+ return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
+
+
def require_peft_version_greater(peft_version):
"""
Decorator marking a test that requires PEFT backend with a specific version, this would require some specific
@@ -418,6 +435,18 @@ def decorator(test_case):
return decorator
+def require_bitsandbytes_version_greater(bnb_version):
+ def decorator(test_case):
+ correct_bnb_version = is_bitsandbytes_available() and version.parse(
+ version.parse(importlib.metadata.version("bitsandbytes")).base_version
+ ) > version.parse(bnb_version)
+ return unittest.skipUnless(
+ correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}."
+ )(test_case)
+
+ return decorator
+
+
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend
diff --git a/tests/quantization/bnb/README.md b/tests/quantization/bnb/README.md
new file mode 100644
index 000000000000..f1585581597d
--- /dev/null
+++ b/tests/quantization/bnb/README.md
@@ -0,0 +1,44 @@
+The tests here are adapted from [`transformers` tests](https://github.com/huggingface/transformers/tree/409fcfdfccde77a14b7cc36972b774cabc371ae1/tests/quantization/bnb).
+
+They were conducted on the `audace` machine, using a single RTX 4090. Below is `nvidia-smi`:
+
+```bash
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 550.90.07 Driver Version: 550.90.07 CUDA Version: 12.4 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M. |
+|=========================================+========================+======================|
+| 0 NVIDIA GeForce RTX 4090 Off | 00000000:01:00.0 Off | Off |
+| 30% 55C P0 61W / 450W | 1MiB / 24564MiB | 2% Default |
+| | | N/A |
++-----------------------------------------+------------------------+----------------------+
+| 1 NVIDIA GeForce RTX 4090 Off | 00000000:13:00.0 Off | Off |
+| 30% 51C P0 60W / 450W | 1MiB / 24564MiB | 0% Default |
+| | | N/A |
++-----------------------------------------+------------------------+----------------------+
+```
+
+`diffusers-cli`:
+
+```bash
+- 🤗 Diffusers version: 0.31.0.dev0
+- Platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.35
+- Running on Google Colab?: No
+- Python version: 3.10.12
+- PyTorch version (GPU?): 2.5.0.dev20240818+cu124 (True)
+- Flax version (CPU?/GPU?/TPU?): not installed (NA)
+- Jax version: not installed
+- JaxLib version: not installed
+- Huggingface_hub version: 0.24.5
+- Transformers version: 4.44.2
+- Accelerate version: 0.34.0.dev0
+- PEFT version: 0.12.0
+- Bitsandbytes version: 0.43.3
+- Safetensors version: 0.4.4
+- xFormers version: not installed
+- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
+NVIDIA GeForce RTX 4090, 24564 MiB
+- Using GPU in script?: Yes
+```
\ No newline at end of file
diff --git a/tests/quantization/bnb/__init__.py b/tests/quantization/bnb/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
new file mode 100644
index 000000000000..6c1b24e31e2a
--- /dev/null
+++ b/tests/quantization/bnb/test_4bit.py
@@ -0,0 +1,584 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import tempfile
+import unittest
+
+import numpy as np
+
+from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
+from diffusers.utils import logging
+from diffusers.utils.testing_utils import (
+ CaptureLogger,
+ is_bitsandbytes_available,
+ is_torch_available,
+ is_transformers_available,
+ load_pt,
+ numpy_cosine_similarity_distance,
+ require_accelerate,
+ require_bitsandbytes_version_greater,
+ require_torch,
+ require_torch_gpu,
+ require_transformers_version_greater,
+ slow,
+ torch_device,
+)
+
+
+def get_some_linear_layer(model):
+ if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
+ return model.transformer_blocks[0].attn.to_q
+ else:
+ raise NotImplementedError("Don't know what layer to retrieve here.")
+
+
+if is_transformers_available():
+ from transformers import T5EncoderModel
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+
+ class LoRALayer(nn.Module):
+ """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
+
+ Taken from
+ https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
+ """
+
+ def __init__(self, module: nn.Module, rank: int):
+ super().__init__()
+ self.module = module
+ self.adapter = nn.Sequential(
+ nn.Linear(module.in_features, rank, bias=False),
+ nn.Linear(rank, module.out_features, bias=False),
+ )
+ small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
+ nn.init.normal_(self.adapter[0].weight, std=small_std)
+ nn.init.zeros_(self.adapter[1].weight)
+ self.adapter.to(module.weight.device)
+
+ def forward(self, input, *args, **kwargs):
+ return self.module(input, *args, **kwargs) + self.adapter(input)
+
+
+if is_bitsandbytes_available():
+ import bitsandbytes as bnb
+
+
+@require_bitsandbytes_version_greater("0.43.2")
+@require_accelerate
+@require_torch
+@require_torch_gpu
+@slow
+class Base4bitTests(unittest.TestCase):
+ # We need to test on relatively large models (i.e. >1B parameters), otherwise quantization may not work as expected.
+ # Therefore, we only use SD3 here to test our module.
+ model_name = "stabilityai/stable-diffusion-3-medium-diffusers"
+
+ # This was obtained on audace so the number might slightly change
+ expected_rel_difference = 3.69
+
+ prompt = "a beautiful sunset amidst the mountains."
+ num_inference_steps = 10
+ seed = 0
+
+ def get_dummy_inputs(self):
+ prompt_embeds = load_pt(
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
+ )
+ pooled_prompt_embeds = load_pt(
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
+ )
+ latent_model_input = load_pt(
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
+ )
+
+ input_dict_for_transformer = {
+ "hidden_states": latent_model_input,
+ "encoder_hidden_states": prompt_embeds,
+ "pooled_projections": pooled_prompt_embeds,
+ "timestep": torch.Tensor([1.0]),
+ "return_dict": False,
+ }
+ return input_dict_for_transformer
+
+
+class BnB4BitBasicTests(Base4bitTests):
+ def setUp(self):
+ # Models
+ self.model_fp16 = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", torch_dtype=torch.float16
+ )
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+ )
+ self.model_4bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=nf4_config
+ )
+
+ def tearDown(self):
+ del self.model_fp16
+ del self.model_4bit
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_quantization_num_parameters(self):
+ r"""
+ Test if the number of returned parameters is correct
+ """
+ num_params_4bit = self.model_4bit.num_parameters()
+ num_params_fp16 = self.model_fp16.num_parameters()
+
+ self.assertEqual(num_params_4bit, num_params_fp16)
+
+ def test_quantization_config_json_serialization(self):
+ r"""
+ A simple test to check if the quantization config is correctly serialized and deserialized
+ """
+ config = self.model_4bit.config
+
+ self.assertTrue("quantization_config" in config)
+
+ _ = config["quantization_config"].to_dict()
+ _ = config["quantization_config"].to_diff_dict()
+
+ _ = config["quantization_config"].to_json_string()
+
+ def test_memory_footprint(self):
+ r"""
+ A simple test to check if the model conversion has been done correctly by checking the
+ memory footprint of the converted model and the class type of its linear layers
+ """
+ mem_fp16 = self.model_fp16.get_memory_footprint()
+ mem_4bit = self.model_4bit.get_memory_footprint()
+
+ self.assertAlmostEqual(mem_fp16 / mem_4bit, self.expected_rel_difference, delta=1e-2)
+ linear = get_some_linear_layer(self.model_4bit)
+ self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
+
+ def test_original_dtype(self):
+ r"""
+ A simple test to check if the model successfully stores the original dtype
+ """
+ self.assertTrue("_pre_quantization_dtype" in self.model_4bit.config)
+ self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config)
+ self.assertTrue(self.model_4bit.config["_pre_quantization_dtype"] == torch.float16)
+
+ def test_keep_modules_in_fp32(self):
+ r"""
+ A simple test to check if the modules under `_keep_in_fp32_modules` are kept in fp32.
+ Also ensures that inference works.
+ """
+ fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules
+ SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
+
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+ )
+ model = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=nf4_config
+ )
+
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name in model._keep_in_fp32_modules:
+ self.assertTrue(module.weight.dtype == torch.float32)
+ else:
+ # 4-bit parameters are packed in uint8 variables
+ self.assertTrue(module.weight.dtype == torch.uint8)
+
+ # test if inference works.
+ with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16):
+ input_dict_for_transformer = self.get_dummy_inputs()
+ model_inputs = {
+ k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
+ }
+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+ _ = model(**model_inputs)
+
+ SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules
+
+ def test_linear_are_4bit(self):
+ r"""
+ A simple test to check that the model conversion has been done correctly by checking the dtype of the
+ linear layers of the converted model (4-bit weights are packed in uint8)
+ """
+ self.model_fp16.get_memory_footprint()
+ self.model_4bit.get_memory_footprint()
+
+ for name, module in self.model_4bit.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name not in ["proj_out"]:
+ # 4-bit parameters are packed in uint8 variables
+ self.assertTrue(module.weight.dtype == torch.uint8)
+
+ def test_config_from_pretrained(self):
+ transformer_4bit = FluxTransformer2DModel.from_pretrained(
+ "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
+ )
+ linear = get_some_linear_layer(transformer_4bit)
+ self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
+ self.assertTrue(hasattr(linear.weight, "quant_state"))
+ self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState)
+
+ def test_device_assignment(self):
+ mem_before = self.model_4bit.get_memory_footprint()
+
+ # Move to CPU
+ self.model_4bit.to("cpu")
+ self.assertEqual(self.model_4bit.device.type, "cpu")
+ self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
+
+ # Move back to CUDA device
+ for device in [0, "cuda", "cuda:0", "call()"]:
+ if device == "call()":
+ self.model_4bit.cuda(0)
+ else:
+ self.model_4bit.to(device)
+ self.assertEqual(self.model_4bit.device, torch.device(0))
+ self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
+ self.model_4bit.to("cpu")
+
+ def test_device_and_dtype_assignment(self):
+ r"""
+ Test whether trying to cast (or assign a device to) a model after converting it to 4-bit throws an error.
+ Also checks that other models are cast correctly. Device placement, however, is supported.
+ """
+ with self.assertRaises(ValueError):
+ # Tries with a `dtype`
+ self.model_4bit.to(torch.float16)
+
+ with self.assertRaises(ValueError):
+ # Tries with a `device` and `dtype`
+ self.model_4bit.to(device="cuda:0", dtype=torch.float16)
+
+ with self.assertRaises(ValueError):
+ # Tries with a cast
+ self.model_4bit.float()
+
+ with self.assertRaises(ValueError):
+ # Tries with a cast
+ self.model_4bit.half()
+
+ # This should work
+ self.model_4bit.to("cuda")
+
+ # Test if we did not break anything
+ self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
+ input_dict_for_transformer = self.get_dummy_inputs()
+ model_inputs = {
+ k: v.to(dtype=torch.float32, device=torch_device)
+ for k, v in input_dict_for_transformer.items()
+ if not isinstance(v, bool)
+ }
+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+ with torch.no_grad():
+ _ = self.model_fp16(**model_inputs)
+
+ # Check this does not throw an error
+ _ = self.model_fp16.to("cpu")
+
+ # Check this does not throw an error
+ _ = self.model_fp16.half()
+
+ # Check this does not throw an error
+ _ = self.model_fp16.float()
+
+ # Check that this does not throw an error
+ _ = self.model_fp16.cuda()
+
+ def test_bnb_4bit_wrong_config(self):
+ r"""
+ Test whether creating a bnb config with unsupported values leads to errors.
+ """
+ with self.assertRaises(ValueError):
+ _ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add")
+
+
+class BnB4BitTrainingTests(Base4bitTests):
+ def setUp(self):
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+ )
+ self.model_4bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=nf4_config
+ )
+
+ def test_training(self):
+ # Step 1: freeze all parameters
+ for param in self.model_4bit.parameters():
+ param.requires_grad = False # freeze the model - train adapters later
+ if param.ndim == 1:
+ # cast the small parameters (e.g. layernorm) to fp32 for stability
+ param.data = param.data.to(torch.float32)
+
+ # Step 2: add adapters
+ for _, module in self.model_4bit.named_modules():
+ if "Attention" in repr(type(module)):
+ module.to_k = LoRALayer(module.to_k, rank=4)
+ module.to_q = LoRALayer(module.to_q, rank=4)
+ module.to_v = LoRALayer(module.to_v, rank=4)
+
+ # Step 3: dummy batch
+ input_dict_for_transformer = self.get_dummy_inputs()
+ model_inputs = {
+ k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
+ }
+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+
+ # Step 4: Check if the gradient is not None
+ with torch.amp.autocast("cuda", dtype=torch.float16):
+ out = self.model_4bit(**model_inputs)[0]
+ out.norm().backward()
+
+ for module in self.model_4bit.modules():
+ if isinstance(module, LoRALayer):
+ self.assertTrue(module.adapter[1].weight.grad is not None)
+ self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
+
+
+@require_transformers_version_greater("4.44.0")
+class SlowBnb4BitTests(Base4bitTests):
+ def setUp(self) -> None:
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+ )
+ model_4bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=nf4_config
+ )
+ self.pipeline_4bit = DiffusionPipeline.from_pretrained(
+ self.model_name, transformer=model_4bit, torch_dtype=torch.float16
+ )
+ self.pipeline_4bit.enable_model_cpu_offload()
+
+ def tearDown(self):
+ del self.pipeline_4bit
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_quality(self):
+ output = self.pipeline_4bit(
+ prompt=self.prompt,
+ num_inference_steps=self.num_inference_steps,
+ generator=torch.manual_seed(self.seed),
+ output_type="np",
+ ).images
+
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.1123, 0.1296, 0.1609, 0.1042, 0.1230, 0.1274, 0.0928, 0.1165, 0.1216])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ print(f"{max_diff=}")
+ self.assertTrue(max_diff < 1e-2)
+
+ def test_generate_quality_dequantize(self):
+ r"""
+ Test that loading the model and dequantizing it produces correct results.
+ """
+ self.pipeline_4bit.transformer.dequantize()
+ output = self.pipeline_4bit(
+ prompt=self.prompt,
+ num_inference_steps=self.num_inference_steps,
+ generator=torch.manual_seed(self.seed),
+ output_type="np",
+ ).images
+
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.1216, 0.1387, 0.1584, 0.1152, 0.1318, 0.1282, 0.1062, 0.1226, 0.1228])
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-3)
+
+ # Since we offloaded `pipeline_4bit.transformer` to CPU (as a result of `enable_model_cpu_offload()`), check
+ # the following.
+ self.assertTrue(self.pipeline_4bit.transformer.device.type == "cpu")
+ # calling it again shouldn't be a problem
+ _ = self.pipeline_4bit(
+ prompt=self.prompt,
+ num_inference_steps=2,
+ generator=torch.manual_seed(self.seed),
+ output_type="np",
+ ).images
+
+ def test_moving_to_cpu_throws_warning(self):
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+ )
+ model_4bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=nf4_config
+ )
+
+ logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
+ logger.setLevel(30)
+ with CaptureLogger(logger) as cap_logger:
+ # Because `model.dtype` will return torch.float16 as SD3 transformer has
+ # a conv layer as the first layer.
+ _ = DiffusionPipeline.from_pretrained(
+ self.model_name, transformer=model_4bit, torch_dtype=torch.float16
+ ).to("cpu")
+
+ assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out
+
+
+@require_transformers_version_greater("4.44.0")
+class SlowBnb4BitFluxTests(Base4bitTests):
+ def setUp(self) -> None:
+ # TODO: Copy sayakpaul/flux.1-dev-nf4-pkg to testing repo.
+ model_id = "sayakpaul/flux.1-dev-nf4-pkg"
+ t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+ transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
+ self.pipeline_4bit = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ text_encoder_2=t5_4bit,
+ transformer=transformer_4bit,
+ torch_dtype=torch.float16,
+ )
+ self.pipeline_4bit.enable_model_cpu_offload()
+
+ def tearDown(self):
+ del self.pipeline_4bit
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_quality(self):
+ # keep the resolution and max tokens to a lower number for faster execution.
+ output = self.pipeline_4bit(
+ prompt=self.prompt,
+ num_inference_steps=self.num_inference_steps,
+ generator=torch.manual_seed(self.seed),
+ height=256,
+ width=256,
+ max_sequence_length=64,
+ output_type="np",
+ ).images
+
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.0583, 0.0586, 0.0632, 0.0815, 0.0813, 0.0947, 0.1040, 0.1145, 0.1265])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-3)
+
+
+@slow
+class BaseBnb4BitSerializationTests(Base4bitTests):
+ def tearDown(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
+ r"""
+ Test whether it is possible to serialize a model in 4-bit. Uses the most typical params as defaults.
+ See the ExtendedSerializationTest class for more parameter combinations.
+ """
+
+ self.quantization_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type=quant_type,
+ bnb_4bit_use_double_quant=double_quant,
+ bnb_4bit_compute_dtype=torch.bfloat16,
+ )
+ model_0 = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=self.quantization_config
+ )
+ self.assertTrue("_pre_quantization_dtype" in model_0.config)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model_0.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
+
+ config = SD3Transformer2DModel.load_config(tmpdirname)
+ self.assertTrue("quantization_config" in config)
+ self.assertTrue("_pre_quantization_dtype" not in config)
+
+ model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname)
+
+ # checking quantized linear module weight
+ linear = get_some_linear_layer(model_1)
+ self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
+ self.assertTrue(hasattr(linear.weight, "quant_state"))
+ self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState)
+
+ # checking memory footprint
+ self.assertAlmostEqual(model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2)
+
+ # Matching all parameters and their quant_state items:
+ d0 = dict(model_0.named_parameters())
+ d1 = dict(model_1.named_parameters())
+ self.assertTrue(d0.keys() == d1.keys())
+
+ for k in d0.keys():
+ self.assertTrue(d0[k].shape == d1[k].shape)
+ self.assertTrue(d0[k].device.type == d1[k].device.type)
+ self.assertTrue(d0[k].device == d1[k].device)
+ self.assertTrue(d0[k].dtype == d1[k].dtype)
+ self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device)))
+
+ if isinstance(d0[k], bnb.nn.modules.Params4bit):
+ for v0, v1 in zip(
+ d0[k].quant_state.as_dict().values(),
+ d1[k].quant_state.as_dict().values(),
+ ):
+ if isinstance(v0, torch.Tensor):
+ self.assertTrue(torch.equal(v0, v1.to(v0.device)))
+ else:
+ self.assertTrue(v0 == v1)
+
+ # comparing forward() outputs
+ dummy_inputs = self.get_dummy_inputs()
+ inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)}
+ inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs})
+ out_0 = model_0(**inputs)[0]
+ out_1 = model_1(**inputs)[0]
+ self.assertTrue(torch.equal(out_0, out_1))
+
+
+class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
+ """
+ Tests more combinations of serialization parameters.
+ """
+
+ def test_nf4_single_unsafe(self):
+ self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=False)
+
+ def test_nf4_single_safe(self):
+ self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=True)
+
+ def test_nf4_double_unsafe(self):
+ self.test_serialization(quant_type="nf4", double_quant=True, safe_serialization=False)
+
+ # nf4 + double quantization + safetensors serialization is tested in the test_serialization() method of the parent class
+
+ def test_fp4_single_unsafe(self):
+ self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=False)
+
+ def test_fp4_single_safe(self):
+ self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=True)
+
+ def test_fp4_double_unsafe(self):
+ self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=False)
+
+ def test_fp4_double_safe(self):
+ self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
new file mode 100644
index 000000000000..2e4aec39b427
--- /dev/null
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -0,0 +1,538 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import tempfile
+import unittest
+
+import numpy as np
+
+from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
+from diffusers.utils.testing_utils import (
+ CaptureLogger,
+ is_bitsandbytes_available,
+ is_torch_available,
+ is_transformers_available,
+ load_pt,
+ numpy_cosine_similarity_distance,
+ require_accelerate,
+ require_bitsandbytes_version_greater,
+ require_torch,
+ require_torch_gpu,
+ require_transformers_version_greater,
+ slow,
+ torch_device,
+)
+
+
+def get_some_linear_layer(model):
+ if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
+ return model.transformer_blocks[0].attn.to_q
+ else:
+ raise NotImplementedError("Don't know what layer to retrieve here.")
+
+
+if is_transformers_available():
+ from transformers import T5EncoderModel
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+
+ class LoRALayer(nn.Module):
+ """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
+
+ Taken from
+ https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77
+ """
+
+ def __init__(self, module: nn.Module, rank: int):
+ super().__init__()
+ self.module = module
+ self.adapter = nn.Sequential(
+ nn.Linear(module.in_features, rank, bias=False),
+ nn.Linear(rank, module.out_features, bias=False),
+ )
+ small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
+ nn.init.normal_(self.adapter[0].weight, std=small_std)
+ nn.init.zeros_(self.adapter[1].weight)
+ self.adapter.to(module.weight.device)
+
+ def forward(self, input, *args, **kwargs):
+ return self.module(input, *args, **kwargs) + self.adapter(input)
+
+
+if is_bitsandbytes_available():
+ import bitsandbytes as bnb
+
+
+@require_bitsandbytes_version_greater("0.43.2")
+@require_accelerate
+@require_torch
+@require_torch_gpu
+@slow
+class Base8bitTests(unittest.TestCase):
+ # We need to test on relatively large models (i.e. >1B parameters), otherwise quantization may not work as expected.
+ # Therefore, we only use SD3 here to test our module.
+ model_name = "stabilityai/stable-diffusion-3-medium-diffusers"
+
+ # This was obtained on audace so the number might slightly change
+ expected_rel_difference = 1.94
+
+ prompt = "a beautiful sunset amidst the mountains."
+ num_inference_steps = 10
+ seed = 0
+
+ def get_dummy_inputs(self):
+ prompt_embeds = load_pt(
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
+ )
+ pooled_prompt_embeds = load_pt(
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
+ )
+ latent_model_input = load_pt(
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
+ )
+
+ input_dict_for_transformer = {
+ "hidden_states": latent_model_input,
+ "encoder_hidden_states": prompt_embeds,
+ "pooled_projections": pooled_prompt_embeds,
+ "timestep": torch.Tensor([1.0]),
+ "return_dict": False,
+ }
+ return input_dict_for_transformer
+
+
+class BnB8bitBasicTests(Base8bitTests):
+ def setUp(self):
+ # Models
+ self.model_fp16 = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", torch_dtype=torch.float16
+ )
+ mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+ self.model_8bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+ )
+
+ def tearDown(self):
+ del self.model_fp16
+ del self.model_8bit
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_quantization_num_parameters(self):
+ r"""
+ Test if the number of returned parameters is correct
+ """
+ num_params_8bit = self.model_8bit.num_parameters()
+ num_params_fp16 = self.model_fp16.num_parameters()
+
+ self.assertEqual(num_params_8bit, num_params_fp16)
+
+ def test_quantization_config_json_serialization(self):
+ r"""
+ A simple test to check if the quantization config is correctly serialized and deserialized
+ """
+ config = self.model_8bit.config
+
+ self.assertTrue("quantization_config" in config)
+
+ _ = config["quantization_config"].to_dict()
+ _ = config["quantization_config"].to_diff_dict()
+
+ _ = config["quantization_config"].to_json_string()
+
+ def test_memory_footprint(self):
+ r"""
+ A simple test to check if the model conversion has been done correctly by checking the
+ memory footprint of the converted model and the class type of its linear layers
+ """
+ mem_fp16 = self.model_fp16.get_memory_footprint()
+ mem_8bit = self.model_8bit.get_memory_footprint()
+
+ self.assertAlmostEqual(mem_fp16 / mem_8bit, self.expected_rel_difference, delta=1e-2)
+ linear = get_some_linear_layer(self.model_8bit)
+ self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
+
+ def test_original_dtype(self):
+ r"""
+ A simple test to check if the model successfully stores the original dtype
+ """
+ self.assertTrue("_pre_quantization_dtype" in self.model_8bit.config)
+ self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config)
+ self.assertTrue(self.model_8bit.config["_pre_quantization_dtype"] == torch.float16)
+
+ def test_keep_modules_in_fp32(self):
+ r"""
+ A simple test to check if the modules under `_keep_in_fp32_modules` are kept in fp32.
+ Also ensures that inference works.
+ """
+ fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules
+ SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
+
+ mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+ model = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+ )
+
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name in model._keep_in_fp32_modules:
+ self.assertTrue(module.weight.dtype == torch.float32)
+ else:
+ # 8-bit parameters are packed in int8 variables
+ self.assertTrue(module.weight.dtype == torch.int8)
+
+ # test if inference works.
+ with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16):
+ input_dict_for_transformer = self.get_dummy_inputs()
+ model_inputs = {
+ k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
+ }
+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+ _ = model(**model_inputs)
+
+ SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules
+
+ def test_linear_are_8bit(self):
+ r"""
+ A simple test to check that the model conversion has been done correctly by checking the dtype of the
+ linear layers of the converted model (8-bit weights are stored in int8)
+ """
+ self.model_fp16.get_memory_footprint()
+ self.model_8bit.get_memory_footprint()
+
+ for name, module in self.model_8bit.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name not in ["proj_out"]:
+ # 8-bit parameters are packed in int8 variables
+ self.assertTrue(module.weight.dtype == torch.int8)
+
+ def test_llm_skip(self):
+ r"""
+ A simple test to check if `llm_int8_skip_modules` works as expected
+ """
+ config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out"])
+ model_8bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=config
+ )
+ linear = get_some_linear_layer(model_8bit)
+ self.assertTrue(linear.weight.dtype == torch.int8)
+ self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt))
+
+ self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear))
+ self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8)
+
+ def test_config_from_pretrained(self):
+ transformer_8bit = FluxTransformer2DModel.from_pretrained(
+ "sayakpaul/flux.1-dev-int8-pkg", subfolder="transformer"
+ )
+ linear = get_some_linear_layer(transformer_8bit)
+ self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
+ self.assertTrue(hasattr(linear.weight, "SCB"))
+
+ def test_device_and_dtype_assignment(self):
+ r"""
+ Test whether trying to cast (or assign a device to) a model after converting it to 8-bit throws an error.
+ Also checks that other models are cast correctly.
+ """
+ with self.assertRaises(ValueError):
+ # Tries with `str`
+ self.model_8bit.to("cpu")
+
+ with self.assertRaises(ValueError):
+ # Tries with a `dtype`
+ self.model_8bit.to(torch.float16)
+
+ with self.assertRaises(ValueError):
+ # Tries with a `device`
+ self.model_8bit.to(torch.device("cuda:0"))
+
+ with self.assertRaises(ValueError):
+ # Tries with a cast
+ self.model_8bit.float()
+
+ with self.assertRaises(ValueError):
+ # Tries with a cast
+ self.model_8bit.half()
+
+ # Test if we did not break anything
+ self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
+ input_dict_for_transformer = self.get_dummy_inputs()
+ model_inputs = {
+ k: v.to(dtype=torch.float32, device=torch_device)
+ for k, v in input_dict_for_transformer.items()
+ if not isinstance(v, bool)
+ }
+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+ with torch.no_grad():
+ _ = self.model_fp16(**model_inputs)
+
+ # Check this does not throw an error
+ _ = self.model_fp16.to("cpu")
+
+ # Check this does not throw an error
+ _ = self.model_fp16.half()
+
+ # Check this does not throw an error
+ _ = self.model_fp16.float()
+
+ # Check that this does not throw an error
+ _ = self.model_fp16.cuda()
+
+
+class BnB8bitTrainingTests(Base8bitTests):
+ def setUp(self):
+ mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+ self.model_8bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+ )
+
+ def test_training(self):
+ # Step 1: freeze all parameters
+ for param in self.model_8bit.parameters():
+ param.requires_grad = False # freeze the model - train adapters later
+ if param.ndim == 1:
+ # cast the small parameters (e.g. layernorm) to fp32 for stability
+ param.data = param.data.to(torch.float32)
+
+ # Step 2: add adapters
+ for _, module in self.model_8bit.named_modules():
+ if "Attention" in repr(type(module)):
+ module.to_k = LoRALayer(module.to_k, rank=4)
+ module.to_q = LoRALayer(module.to_q, rank=4)
+ module.to_v = LoRALayer(module.to_v, rank=4)
+
+ # Step 3: dummy batch
+ input_dict_for_transformer = self.get_dummy_inputs()
+ model_inputs = {
+ k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
+ }
+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+
+ # Step 4: Check if the gradient is not None
+ with torch.amp.autocast("cuda", dtype=torch.float16):
+ out = self.model_8bit(**model_inputs)[0]
+ out.norm().backward()
+
+ for module in self.model_8bit.modules():
+ if isinstance(module, LoRALayer):
+ self.assertTrue(module.adapter[1].weight.grad is not None)
+ self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
+
+
+@require_transformers_version_greater("4.44.0")
+class SlowBnb8bitTests(Base8bitTests):
+ def setUp(self) -> None:
+ mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+ model_8bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+ )
+ self.pipeline_8bit = DiffusionPipeline.from_pretrained(
+ self.model_name, transformer=model_8bit, torch_dtype=torch.float16
+ )
+ self.pipeline_8bit.enable_model_cpu_offload()
+
+ def tearDown(self):
+ del self.pipeline_8bit
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_quality(self):
+ output = self.pipeline_8bit(
+ prompt=self.prompt,
+ num_inference_steps=self.num_inference_steps,
+ generator=torch.manual_seed(self.seed),
+ output_type="np",
+ ).images
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.0149, 0.0322, 0.0073, 0.0134, 0.0332, 0.011, 0.002, 0.0232, 0.0193])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-2)
+
+ def test_model_cpu_offload_raises_warning(self):
+ model_8bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
+ )
+ pipeline_8bit = DiffusionPipeline.from_pretrained(
+ self.model_name, transformer=model_8bit, torch_dtype=torch.float16
+ )
+ logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
+ logger.setLevel(30)
+
+ with CaptureLogger(logger) as cap_logger:
+ pipeline_8bit.enable_model_cpu_offload()
+
+ assert "has been loaded in `bitsandbytes` 8bit" in cap_logger.out
+
+ def test_moving_to_cpu_throws_warning(self):
+ model_8bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
+ )
+ logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
+ logger.setLevel(30)
+
+ with CaptureLogger(logger) as cap_logger:
+ # Because `model.dtype` will return torch.float16 as SD3 transformer has
+ # a conv layer as the first layer.
+ _ = DiffusionPipeline.from_pretrained(
+ self.model_name, transformer=model_8bit, torch_dtype=torch.float16
+ ).to("cpu")
+
+ assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out
+
+ def test_generate_quality_dequantize(self):
+ r"""
+ Test that loading the model and dequantizing it produces correct results.
+ """
+ self.pipeline_8bit.transformer.dequantize()
+ output = self.pipeline_8bit(
+ prompt=self.prompt,
+ num_inference_steps=self.num_inference_steps,
+ generator=torch.manual_seed(self.seed),
+ output_type="np",
+ ).images
+
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.0266, 0.0264, 0.0271, 0.0110, 0.0310, 0.0098, 0.0078, 0.0256, 0.0208])
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-2)
+
+ # 8bit models cannot be offloaded to CPU.
+ self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda")
+ # calling it again shouldn't be a problem
+ _ = self.pipeline_8bit(
+ prompt=self.prompt,
+ num_inference_steps=2,
+ generator=torch.manual_seed(self.seed),
+ output_type="np",
+ ).images
+
+
+@require_transformers_version_greater("4.44.0")
+class SlowBnb8bitFluxTests(Base8bitTests):
+ def setUp(self) -> None:
+ # TODO: Copy sayakpaul/flux.1-dev-int8-pkg to testing repo.
+ model_id = "sayakpaul/flux.1-dev-int8-pkg"
+ t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+ transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
+ self.pipeline_8bit = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ text_encoder_2=t5_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ )
+ self.pipeline_8bit.enable_model_cpu_offload()
+
+ def tearDown(self):
+ del self.pipeline_8bit
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_quality(self):
+ # keep the resolution and max tokens to a lower number for faster execution.
+ output = self.pipeline_8bit(
+ prompt=self.prompt,
+ num_inference_steps=self.num_inference_steps,
+ generator=torch.manual_seed(self.seed),
+ height=256,
+ width=256,
+ max_sequence_length=64,
+ output_type="np",
+ ).images
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.0574, 0.0554, 0.0581, 0.0686, 0.0676, 0.0759, 0.0757, 0.0803, 0.0930])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-3)
+
+
+@slow
+class BaseBnb8bitSerializationTests(Base8bitTests):
+ def setUp(self):
+ quantization_config = BitsAndBytesConfig(
+ load_in_8bit=True,
+ )
+ self.model_0 = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=quantization_config
+ )
+
+ def tearDown(self):
+ del self.model_0
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_serialization(self):
+ r"""
+ Test whether it is possible to serialize a model in 8-bit. Uses the most typical params as defaults.
+ """
+ self.assertTrue("_pre_quantization_dtype" in self.model_0.config)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ self.model_0.save_pretrained(tmpdirname)
+
+ config = SD3Transformer2DModel.load_config(tmpdirname)
+ self.assertTrue("quantization_config" in config)
+ self.assertTrue("_pre_quantization_dtype" not in config)
+
+ model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname)
+
+ # checking quantized linear module weight
+ linear = get_some_linear_layer(model_1)
+ self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
+ self.assertTrue(hasattr(linear.weight, "SCB"))
+
+ # checking memory footprint
+ self.assertAlmostEqual(self.model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2)
+
+ # Matching all parameters and their quant_state items:
+ d0 = dict(self.model_0.named_parameters())
+ d1 = dict(model_1.named_parameters())
+ self.assertTrue(d0.keys() == d1.keys())
+
+ # comparing forward() outputs
+ dummy_inputs = self.get_dummy_inputs()
+ inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)}
+ inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs})
+ out_0 = self.model_0(**inputs)[0]
+ out_1 = model_1(**inputs)[0]
+ self.assertTrue(torch.equal(out_0, out_1))
+
+ def test_serialization_sharded(self):
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ self.model_0.save_pretrained(tmpdirname, max_shard_size="200MB")
+
+ config = SD3Transformer2DModel.load_config(tmpdirname)
+ self.assertTrue("quantization_config" in config)
+ self.assertTrue("_pre_quantization_dtype" not in config)
+
+ model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname)
+
+ # checking quantized linear module weight
+ linear = get_some_linear_layer(model_1)
+ self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
+ self.assertTrue(hasattr(linear.weight, "SCB"))
+
+ # comparing forward() outputs
+ dummy_inputs = self.get_dummy_inputs()
+ inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)}
+ inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs})
+ out_0 = self.model_0(**inputs)[0]
+ out_1 = model_1(**inputs)[0]
+ self.assertTrue(torch.equal(out_0, out_1))