forked from huggingface/transformers
Commit
`HfQuantizer` class for quantization-related stuff in `modeling_utils.py` (huggingface#26610)

Squashed from earlier review commits covering the quantizers refactor: 4-bit saving through quantizers, AWQ and GPTQ fixes, the `QuantizationConfigParser` rework, moving quantization logic out of `save_pretrained`, docs, and tests.

Co-authored-by: younesbelkada <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: Arthur <[email protected]>
1 parent 854be0c · commit a33ede6
Showing 18 changed files with 1,443 additions and 487 deletions.
docs/source/en/hf_quantizer.md (new file)
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# Contribute new quantization method

Transformers supports and integrates many quantization methods such as QLoRA, GPTQ, LLM.int8, and AWQ. However, there are other quantization approaches that are not yet integrated. To make adding and using these quantization methods with Transformers models easier, you should use the [`HfQuantizer`] class. The [`HfQuantizer`] is designed as an internal helper class for adding a quantization method instead of something you apply to every PyTorch module.

This guide will show you how to integrate a new quantization method with the [`HfQuantizer`] class.

## Requirements

Before integrating a new quantization method into Transformers, ensure the method you are trying to add meets the following prerequisites. Only quantization methods that can be run with PyTorch modules are currently supported.

- The quantization method is available through a Python package that is pip-installable by anyone (it is also fine if you can only install the package from source). Ideally, pre-compiled kernels are included in the pip package.
- The method can run on commonly-used hardware (CPU, GPU, ...).
- The method is wrapped in a `nn.Module` (e.g., `Linear8bitLt`, `Linear4bit`), and the quantized linear layer should have the following definition:
```py
import torch.nn as nn


class Linear4bit(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        ...  # store the packed 4-bit weight (and optional bias) provided by your quantization package

    def forward(self, x):
        # my_4bit_kernel stands in for the fused dequantize + matmul kernel shipped by the package
        return my_4bit_kernel(x, self.weight, self.bias)
```
This way, Transformers models can be easily quantized by replacing some instances of `nn.Linear` with a target class (see the sketch right after this list).
- The quantization method should be serializable. You can save the quantized weights locally or push them to the Hub.
- Make sure the package that contains the quantization kernels/primitives is stable (no frequent breaking changes).
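To make that concrete, here is a minimal sketch of what such a replacement helper could look like. It is not an actual Transformers utility: it assumes the `Linear4bit` placeholder class above, and the `modules_to_not_convert` argument and the `lm_head` default are purely illustrative.

```py
import torch.nn as nn


def replace_with_quantized_linear(model: nn.Module, modules_to_not_convert=None):
    """Recursively swap `nn.Linear` layers for the placeholder `Linear4bit` class defined above."""
    modules_to_not_convert = modules_to_not_convert or ["lm_head"]
    for name, child in model.named_children():
        if isinstance(child, nn.Linear) and name not in modules_to_not_convert:
            setattr(model, name, Linear4bit(child.in_features, child.out_features, bias=child.bias is not None))
        else:
            replace_with_quantized_linear(child, modules_to_not_convert)
    return model
```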
Some quantization methods may require "pre-quantizing" the models through data calibration (e.g., AWQ). In this case, we prefer to only support inference in Transformers and let the third-party library maintained by the ML community deal with the model quantization itself.

## Build a new HfQuantizer class

1. 📕 Create a new quantization config class inside `src/transformers/utils/quantization_config.py` and make sure to expose the new quantization config inside Transformers' main init by adding it to the `_import_structure` object of `src/transformers/__init__.py`.
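As a very rough sketch (the `my_method` name and its fields are made up), such a config could look like the following. Existing configs in `quantization_config.py`, e.g. `AwqConfig`, subclass `QuantizationConfigMixin` and are a good reference:

```py
from transformers.utils.quantization_config import QuantizationConfigMixin


class MyMethodConfig(QuantizationConfigMixin):
    """Hypothetical config holding everything needed to load and run the new method."""

    def __init__(self, bits: int = 4, modules_to_not_convert=None, **kwargs):
        self.quant_method = "my_method"  # identifier used to resolve the quantizer in the auto-mapping
        self.bits = bits
        self.modules_to_not_convert = modules_to_not_convert
```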
2. 🗃 Create a new file inside `src/transformers/quantizers/` named `quantizer_your_method.py`, and make it inherit from `src/transformers/quantizers/base.py::HfQuantizer`. Make sure to add the new quantizer and quantization config to the quantization auto-mapping in `src/transformers/quantizers/auto.py`.
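The registration itself boils down to two dictionary entries. A sketch, assuming the mapping dictionaries in `auto.py` are named `AUTO_QUANTIZER_MAPPING` and `AUTO_QUANTIZATION_CONFIG_MAPPING` and reusing the hypothetical names from step 1:

```py
# Inside src/transformers/quantizers/auto.py, add entries to the existing mapping dicts
from ..utils.quantization_config import MyMethodConfig
from .quantizer_my_method import MyMethodHfQuantizer

AUTO_QUANTIZER_MAPPING["my_method"] = MyMethodHfQuantizer
AUTO_QUANTIZATION_CONFIG_MAPPING["my_method"] = MyMethodConfig
```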
3. 🔩 Define the following class attributes/property methods for your quantization method (a combined skeleton follows this list):

* `requires_calibration`: Whether the quantization method requires a data calibration process. If set to `True`, you can only support inference (with quantized weights) and not inference and quantization.
* `required_packages`: A list of strings of the required packages to use the quantized weights. You might need to define some new utility methods such as `is_auto_awq_available` in `src/transformers/utils/import_utils.py`.
* `requires_parameters_quantization`: Only required if your quantization method requires extra attention to the underlying `nn.Parameter` object. For example, bitsandbytes uses `Params4bit` and `Int8Param`, which require some extra attention when quantizing the model. Most recent quantization methods pack int2/int4 weights inside `torch.uint8` weights, so this flag should rarely be required (it is set to `False` by default).
* `is_serializable`: A property method to determine whether the method is serializable or not.
* `is_trainable`: A property method to determine whether you can fine-tune models on top of the quantization method (with or without PEFT approaches).
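Putting those attributes together, the skeleton mentioned above could look like this (the class name and the `my_quant_lib` package are hypothetical):

```py
# src/transformers/quantizers/quantizer_my_method.py
from .base import HfQuantizer


class MyMethodHfQuantizer(HfQuantizer):
    requires_calibration = False              # weights can be quantized on the fly while loading
    required_packages = ["my_quant_lib"]      # hypothetical pip package shipping the kernels
    requires_parameters_quantization = False  # no custom nn.Parameter subclass is needed

    @property
    def is_serializable(self):
        return True

    @property
    def is_trainable(self):
        return False  # flip to True once fine-tuning on top of the method is supported
```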
4. 🪛 Write the `validate_environment` and `update_torch_dtype` methods. These methods are called before creating the quantized model to ensure users use the right configuration. You can have a look at how this is done in other quantizers.
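Continuing the hypothetical skeleton from step 3, a sketch of both methods (the `is_my_quant_lib_available` helper stands for the utility you would add in `import_utils.py`):

```py
import torch

from ..utils import is_my_quant_lib_available  # hypothetical availability check added in import_utils.py
from .base import HfQuantizer


class MyMethodHfQuantizer(HfQuantizer):
    def validate_environment(self, *args, **kwargs):
        # Fail early, before any weight is touched, if the backend package is missing
        if not is_my_quant_lib_available():
            raise ImportError("Loading a my_method model requires the `my_quant_lib` package: `pip install my_quant_lib`")

    def update_torch_dtype(self, torch_dtype):
        # For example, default to float16 if the kernels only support half precision
        if torch_dtype is None:
            torch_dtype = torch.float16
        return torch_dtype
```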
5. 🖋 Write the `_process_model_before_weight_loading` method. In Transformers, the quantized models are initialized first on the `"meta"` device before loading the weights. This means the `_process_model_before_weight_loading` method takes care of manipulating the model skeleton to replace some modules (e.g., `nn.Linear`) with the target modules (quantization modules). You can define the module replacement logic or any other utility method by creating a new file in `src/transformers/integrations/` and exposing the relevant methods in that folder's `__init__.py` file. The best starting point is to have a look at another quantization method, such as `quantizer_awq.py`.
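A rough sketch of what this could look like, pretending the replacement helper from the Requirements section is exposed from the integrations folder:

```py
from .base import HfQuantizer


class MyMethodHfQuantizer(HfQuantizer):
    def _process_model_before_weight_loading(self, model, **kwargs):
        from ..integrations import replace_with_quantized_linear  # hypothetical helper exposed in integrations/__init__.py

        # The model skeleton still lives on the "meta" device: swap nn.Linear for the quantized
        # class now, so the real weights are later loaded straight into quantized modules.
        replace_with_quantized_linear(
            model, modules_to_not_convert=self.quantization_config.modules_to_not_convert
        )
        model.config.quantization_config = self.quantization_config
```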
6. 🖊 Write the `_process_model_after_weight_loading` method. This method enables implementing additional features that require manipulating the model after loading the weights.
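For many methods this can stay close to a no-op; a minimal sketch:

```py
class MyMethodHfQuantizer(HfQuantizer):
    def _process_model_after_weight_loading(self, model, **kwargs):
        # Called once the checkpoint weights are in place; e.g., repack buffers or fuse modules here.
        return model
```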
7. 📖 Document everything! Make sure your quantization method is documented in the `docs/source/en/quantization.md` file.

8. 🟢 Add tests! You should add tests by first adding the package in our nightly Dockerfile inside `docker/transformers-all-latest-gpu` and then adding a new test file in `tests/quantization/xxx`. Feel free to check out how it is implemented for other quantization methods.
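As a starting point, a minimal test could look like the sketch below; the checkpoint name is a placeholder for a small pre-quantized model you would push to the Hub:

```py
# tests/quantization/my_method/test_my_method.py
import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_torch_gpu


@require_torch_gpu
class MyMethodTest(unittest.TestCase):
    def test_quantized_model_generate(self):
        model_id = "my-org/opt-125m-my-method"  # placeholder pre-quantized checkpoint
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

        inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
        generated = model.generate(**inputs, max_new_tokens=10)
        self.assertTrue(generated.shape[1] > inputs.input_ids.shape[1])
```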