[Quantization] Add quantization support for bitsandbytes
#9213
Merged
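For orientation, here is a minimal sketch of how the quantization support added in this PR is meant to be used. The checkpoint id, model class, and exact import path are illustrative placeholders rather than taken from the diff, and the snippet assumes `bitsandbytes` and `accelerate` are installed alongside `diffusers`:

import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel  # illustrative model class

# Quantize the model to 4-bit NF4 while keeping compute in bfloat16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# `quantization_config` is the new argument this PR wires into ModelMixin.from_pretrained.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

An 8-bit variant works the same way with `load_in_8bit=True`.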
Changes from 10 commits

Commits (119):
e634ff2 - quantization config. (sayakpaul)
02a6dff - fix-copies (sayakpaul)
c385a2b - Merge branch 'main' into quantization-config (sayakpaul)
0355875 - Merge branch 'main' into quantization-config (sayakpaul)
e41b494 - Merge branch 'main' into quantization-config (sayakpaul)
dfb33eb - Merge branch 'main' into quantization-config (sayakpaul)
e492655 - Merge branch 'main' into quantization-config (sayakpaul)
6e86cc0 - fix (sayakpaul)
58a3d15 - modules_to_not_convert (sayakpaul)
1d477f9 - Merge branch 'main' into quantization-config (sayakpaul)
bd7f46d - Merge branch 'main' into quantization-config (sayakpaul)
d5d7bb6 - Merge branch 'main' into quantization-config (sayakpaul)
44c8a75 - Merge branch 'main' into quantization-config (sayakpaul)
6a0fcdc - add bitsandbytes utilities. (sayakpaul)
e4590fa - make progress. (sayakpaul)
77a1438 - Merge branch 'main' into quantization-config (sayakpaul)
335ab6b - fixes (sayakpaul)
d44ef85 - quality (sayakpaul)
210fa1e - up (sayakpaul)
f4feee1 - up (sayakpaul)
e8c1722 - Merge branch 'main' into quantization-config (sayakpaul)
7f86a71 - Merge branch 'main' into quantization-config (sayakpaul)
ba671b6 - minor (sayakpaul)
c1a9f13 - up (sayakpaul)
4489c54 - Merge branch 'main' into quantization-config (sayakpaul)
f2ca5e2 - up (sayakpaul)
d6b8954 - fix (sayakpaul)
45029e2 - provide credits where due. (sayakpaul)
4eb468a - make configurations work. (sayakpaul)
939965d - fixes (sayakpaul)
8557166 - Merge branch 'main' into quantization-config (sayakpaul)
d098d07 - fix (sayakpaul)
c4a0074 - update_missing_keys (sayakpaul)
ee45612 - fix (sayakpaul)
b24c0a7 - fix (sayakpaul)
473505c - make it work. (sayakpaul)
c795c82 - fix (sayakpaul)
c1d5b96 - Merge branch 'main' into quantization-config (sayakpaul)
af7caca - provide credits to transformers. (sayakpaul)
80967f5 - empty commit (sayakpaul)
3bdf25a - handle to() better. (sayakpaul)
27415cc - tests (sayakpaul)
51cac09 - change to bnb from bitsandbytes (sayakpaul)
15f3032 - fix tests (sayakpaul)
77c9fdb - better safeguard. (sayakpaul)
ddc9f29 - change merging status (sayakpaul)
44c4109 - courtesy to transformers. (sayakpaul)
27666a8 - move upper. (sayakpaul)
3464d83 - better (sayakpaul)
b106124 - Merge branch 'main' into quantization-config (sayakpaul)
330fa0a - Merge branch 'main' into quantization-config (sayakpaul)
abc8607 - make the unused kwargs warning friendlier. (sayakpaul)
31725aa - harmonize changes with https://github.com/huggingface/transformers/pu… (sayakpaul)
e5938a6 - style (sayakpaul)
444588f - trainin tests (sayakpaul)
d3360ce - Merge branch 'main' into quantization-config (sayakpaul)
d8b35f4 - Merge branch 'main' into quantization-config (sayakpaul)
859f2d7 - Merge branch 'main' into quantization-config (sayakpaul)
3b2d6e1 - feedback part i. (sayakpaul)
5799954 - Add Flux inpainting and Flux Img2Img (#9135) (Gothos)
8e4bd08 - Revert "Add Flux inpainting and Flux Img2Img (#9135)" (sayakpaul)
835d4ad - tests (sayakpaul)
27075fe - don (sayakpaul)
5c00c1c - Merge branch 'main' into quantization-config (sayakpaul)
5d633a0 - Merge branch 'main' into quantization-config (sayakpaul)
c381fe0 - Apply suggestions from code review (sayakpaul)
3c92878 - Merge branch 'main' into quantization-config (sayakpaul)
acdeb25 - contribution guide. (sayakpaul)
aa295b7 - Merge branch 'main' into quantization-config (sayakpaul)
7f7c9ce - Merge branch 'main' into quantization-config (sayakpaul)
55f96d8 - Merge branch 'main' into quantization-config (sayakpaul)
b28cc65 - changes (sayakpaul)
8328e86 - Merge branch 'main' into quantization-config (sayakpaul)
9758942 - empty (sayakpaul)
b1a9878 - fix tests (sayakpaul)
971305b - harmonize with https://github.com/huggingface/transformers/pull/33546. (sayakpaul)
f41adf1 - numpy_cosine_distance (sayakpaul)
0bcb88b - Merge branch 'main' into quantization-config (sayakpaul)
55b3696 - Merge branch 'main' into quantization-config (sayakpaul)
4cb3a6d - Merge branch 'main' into quantization-config (sayakpaul)
8a03eae - Merge branch 'main' into quantization-config (sayakpaul)
53f0a92 - Merge branch 'main' into quantization-config (sayakpaul)
6aab47c - Merge branch 'main' into quantization-config (sayakpaul)
9b9a610 - resolved conflicts, (sayakpaul)
510d57a - Merge branch 'main' into quantization-config (sayakpaul)
555a5ae - config_dict modification. (sayakpaul)
da10365 - remove if config comment. (sayakpaul)
71316a6 - note for load_state_dict changes. (sayakpaul)
12f5c59 - float8 check. (sayakpaul)
5e722cd - quantizer. (sayakpaul)
c78dd0c - raise an error for non-True low_cpu_mem_usage values when using quant. (sayakpaul)
af3ecea - low_cpu_mem_usage shenanigans when using fp32 modules. (sayakpaul)
a473d28 - don't re-assign _pre_quantization_type. (sayakpaul)
870d74f - make comments clear. (sayakpaul)
3e6cfeb - remove comments. (sayakpaul)
673993c - handle mixed types better when moving to cpu. (sayakpaul)
0d5f2f7 - add tests to check if we're throwing warning rightly. (sayakpaul)
3cb20fe - better check. (sayakpaul)
10940a9 - fix 8bit test_quality. (sayakpaul)
c0a88ae - Merge branch 'main' into quantization-config (sayakpaul)
dcc5bc5 - Merge branch 'main' into quantization-config (sayakpaul)
5e0b4eb - Merge branch 'main' into quantization-config (sayakpaul)
569dd96 - Merge branch 'main' into quantization-config (sayakpaul)
8bdc846 - Merge branch 'main' into quantization-config (sayakpaul)
ff8ddef - handle dtype more robustly. (sayakpaul)
de6394a - better message when keep_in_fp32_modules. (sayakpaul)
81bb48a - handle dtype casting. (sayakpaul)
c5e62ae - Merge branch 'main' into quantization-config (sayakpaul)
d023b40 - Merge branch 'main' into quantization-config (sayakpaul)
a3d2655 - Merge branch 'main' into quantization-config (sayakpaul)
700b0f3 - Merge branch 'main' into quantization-config (sayakpaul)
0ae70fe - fix dtype checks in pipeline. (sayakpaul)
ecdf1d0 - fix warning message. (sayakpaul)
aea3398 - Update src/diffusers/models/modeling_utils.py (sayakpaul)
3a91974 - Merge branch 'main' into quantization-config (sayakpaul)
5d8e844 - Merge branch 'main' into quantization-config (sayakpaul)
501a6ba - mitigate the confusing cpu warning (sayakpaul)
1a931cb - Merge branch 'main' into quantization-config (sayakpaul)
2fa8fb9 - Merge branch 'main' into quantization-config (sayakpaul)
@@ -0,0 +1 @@
from .base import HfQuantizer
@@ -0,0 +1,230 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Adapted from
https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/quantizers/base.py
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from ..utils import is_torch_available
from .quantization_config import QuantizationConfigMixin


if TYPE_CHECKING:
    from ..models.modeling_utils import ModelMixin

if is_torch_available():
    import torch


class HfQuantizer(ABC):
    """
    Abstract base class for HuggingFace quantizers. Currently supports quantizing HF diffusers models for inference.
    This class is used only within `diffusers.models.modeling_utils.ModelMixin.from_pretrained` and cannot easily be
    used outside the scope of that method yet.

    Attributes:
        quantization_config (`diffusers.quantizers.quantization_config.QuantizationConfigMixin`):
            The quantization config that defines the quantization parameters of the model you want to quantize.
        modules_to_not_convert (`List[str]`, *optional*):
            The list of module names that should not be converted when quantizing the model.
        required_packages (`List[str]`, *optional*):
            The list of required pip packages to install prior to using the quantizer.
        requires_calibration (`bool`):
            Whether the quantization method requires calibrating the model before using it.
    """

    requires_calibration = False
    required_packages = None

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        self.quantization_config = quantization_config

        # -- Handle extra kwargs below --
        self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
        self.pre_quantized = kwargs.pop("pre_quantized", True)

        if not self.pre_quantized and self.requires_calibration:
            raise ValueError(
                f"The quantization method {quantization_config.quant_method} requires the model to be pre-quantized."
                f" You explicitly passed `pre_quantized=False`, meaning your model weights are not quantized. Make sure to "
                f"pass `pre_quantized=True` while knowing what you are doing."
            )

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        """
        Some quantization methods require explicitly setting the dtype of the model to a target dtype. You need to
        override this method in case you want to make sure that behavior is preserved.

        Args:
            torch_dtype (`torch.dtype`):
                The input dtype that is passed in `from_pretrained`.
        """
        return torch_dtype

    def update_device_map(self, device_map: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """
        Override this method if you want to override the existing device map with a new one. E.g. for bitsandbytes,
        since `accelerate` is a hard requirement, if no device_map is passed, the device_map is set to `"auto"`.

        Args:
            device_map (`Union[dict, str]`, *optional*):
                The device_map that is passed through the `from_pretrained` method.
        """
        return device_map

    def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        """
        Override this method if you want to adjust the `target_dtype` variable used in `from_pretrained` to compute
        the device_map in case the device_map is a `str`. E.g. for bitsandbytes we force-set `target_dtype` to
        `torch.int8` and for 4-bit we pass a custom enum `accelerate.CustomDtype.int4`.

        Args:
            torch_dtype (`torch.dtype`, *optional*):
                The torch_dtype that is used to compute the device_map.
        """
        return torch_dtype

    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        """
        Override this method if you want to adjust the `missing_keys`.

        Args:
            missing_keys (`List[str]`, *optional*):
                The list of missing keys in the checkpoint compared to the state dict of the model.
        """
        return missing_keys

    def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
        """
        Returns dtypes for modules that are not quantized - used for the computation of the device_map in case one
        passes a str as a device_map. The method will use the `modules_to_not_convert` that is modified in
        `_process_model_before_weight_loading`. `diffusers` models don't have any `modules_to_not_convert` attributes
        yet but this can change soon in the future.

        Args:
            model (`~diffusers.models.modeling_utils.ModelMixin`):
                The model to quantize.
            torch_dtype (`torch.dtype`):
                The dtype passed in the `from_pretrained` method.
        """
        return {
            name: torch_dtype
            for name, _ in model.named_parameters()
            if any(m in name for m in self.modules_to_not_convert)
        }

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        """Adjust the max_memory argument for infer_auto_device_map() if extra memory is needed for quantization."""
        return max_memory

    def check_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        """
        Checks whether a loaded state_dict component is part of a quantized param and performs some validation; only
        defined for quantization methods that require creating new parameters for quantization.
        """
        return False

    def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
        """
        Takes the needed components from the state_dict and creates a quantized param.
        """
        if not hasattr(self, "check_quantized_param"):
            raise AttributeError(
                f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}."
            )

    def validate_environment(self, *args, **kwargs):
        """
        This method is used to check for potential conflicts with arguments that are passed in `from_pretrained`. You
        need to define it for all future quantizers that are integrated with diffusers. If no explicit checks are
        needed, simply return nothing.
        """
        return

    def preprocess_model(self, model: "ModelMixin", **kwargs):
        """
        Sets model attributes and/or converts the model before weights loading. At this point the model should be
        initialized on the meta device so you can freely manipulate the skeleton of the model in order to replace
        modules in-place. Make sure to override the abstract method `_process_model_before_weight_loading`.

        Args:
            model (`~diffusers.models.modeling_utils.ModelMixin`):
                The model to quantize.
            kwargs (`dict`, *optional*):
                The keyword arguments that are passed along to `_process_model_before_weight_loading`.
        """
        model.is_quantized = True
        model.quantization_method = self.quantization_config.quant_method
        return self._process_model_before_weight_loading(model, **kwargs)

    def postprocess_model(self, model: "ModelMixin", **kwargs):
        """
        Post-processes the model after weights loading. Make sure to override the abstract method
        `_process_model_after_weight_loading`.

        Args:
            model (`~diffusers.models.modeling_utils.ModelMixin`):
                The model to quantize.
            kwargs (`dict`, *optional*):
                The keyword arguments that are passed along to `_process_model_after_weight_loading`.
        """
        return self._process_model_after_weight_loading(model, **kwargs)

    def dequantize(self, model):
        """
        Potentially dequantizes the model to retrieve the original model, with some loss in accuracy / performance.
        Note that not all quantization schemes support this.
        """
        model = self._dequantize(model)

        # Delete quantizer and quantization config
        del model.hf_quantizer

        return model

    def _dequantize(self, model):
        raise NotImplementedError(
            f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
        )

    @abstractmethod
    def _process_model_before_weight_loading(self, model, **kwargs):
        ...

    @abstractmethod
    def _process_model_after_weight_loading(self, model, **kwargs):
        ...

    @property
    @abstractmethod
    def is_serializable(self):
        ...

    @property
    @abstractmethod
    def is_trainable(self):
        ...
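To make the contract above concrete, here is a hedged sketch, not code from this PR, of the minimum a concrete subclass has to provide. `DummyQuantizer` is a hypothetical name; the real bitsandbytes quantizers added in this PR additionally implement parameter creation, dtype handling, and device-map adjustments.

# Illustrative only: a no-op quantizer showing the minimal surface a concrete
# subclass of HfQuantizer must provide (the two abstract hooks and the two
# abstract properties), plus an environment check.
class DummyQuantizer(HfQuantizer):
    requires_calibration = False
    required_packages = None

    def validate_environment(self, *args, **kwargs):
        # Verify required runtime packages/devices here; return nothing if all is fine.
        return

    def _process_model_before_weight_loading(self, model, **kwargs):
        # Replace modules on the meta device here (e.g. swap nn.Linear for quantized layers).
        return model

    def _process_model_after_weight_loading(self, model, **kwargs):
        # Finalize anything that needs the fully loaded weights here.
        return model

    @property
    def is_serializable(self):
        return False

    @property
    def is_trainable(self):
        return False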
Review comment (on `check_quantized_param`): IMO `check_is_quantized_param` or `check_if_quantized_param` more explicitly conveys what this method does.