diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile
index eda1fca6a61b..168aba9dbbc7 100644
--- a/docker/transformers-all-latest-gpu/Dockerfile
+++ b/docker/transformers-all-latest-gpu/Dockerfile
@@ -47,8 +47,11 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/acc
 # Add bitsandbytes for mixed int8 testing
 RUN python3 -m pip install --no-cache-dir bitsandbytes
 
-# For bettertransformer
-RUN python3 -m pip install --no-cache-dir optimum
+# Add auto-gptq for gptq quantization testing
+RUN python3 -m pip install --no-cache-dir auto-gptq
+
+# For bettertransformer + gptq
+RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum
 
 # For video model testing
 RUN python3 -m pip install --no-cache-dir decord av==9.2.0
diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md
index e37f8cd9987a..eae71c9cc9e4 100644
--- a/docs/source/en/main_classes/quantization.md
+++ b/docs/source/en/main_classes/quantization.md
@@ -16,6 +16,137 @@ rendered properly in your Markdown viewer.
 
 # Quantize 🤗 Transformers models
 
+## `AutoGPTQ` Integration
+
+🤗 Transformers has integrated the `optimum` API to perform GPTQ quantization on language models. You can load and quantize your model in 8, 6, 4 or even 2 bits without a big drop in performance and with faster inference speed! This is supported by most GPU hardware.
+
+To learn more about the quantization method, check out:
+- the [GPTQ](https://arxiv.org/pdf/2210.17323.pdf) paper
+- the [`AutoGPTQ`](https://github.com/PanQiWei/AutoGPTQ) library used as the backend
+
+### Requirements
+
+You need to have the following requirements installed to run the code below:
+
+- Install the latest `AutoGPTQ` library:
+`pip install auto-gptq`
+
+- Install the latest `optimum` from source:
+`pip install git+https://github.com/huggingface/optimum.git`
+
+- Install the latest `transformers` from source:
+`pip install git+https://github.com/huggingface/transformers.git`
+
+- Install the latest `accelerate` library:
+`pip install --upgrade accelerate`
+
+Note that the GPTQ integration only supports text models for now, and you may encounter unexpected behaviour for vision, speech or multi-modal models.
+
+### Load and quantize a model
+
+GPTQ is a quantization method that requires weight calibration before using the quantized model. If you want to quantize a transformers model from scratch, it might take some time to produce the quantized model (~10 min on a Google Colab for the `facebook/opt-350m` model).
+
+Hence, there are two different scenarios where you may want to use GPTQ-quantized models: loading a model that has already been quantized by other users and shared on the Hub, or quantizing your own model from scratch and saving it or pushing it to the Hub so that other users can use it as well.
+
+#### GPTQ Configuration
+
+In order to load and quantize a model, you need to create a [`GPTQConfig`]. You need to pass the number of `bits`, a `dataset` to calibrate the quantization, and the `tokenizer` of the model to prepare the dataset.
+
+```python
+from transformers import AutoTokenizer, GPTQConfig
+
+model_id = "facebook/opt-125m"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
+```
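+
+[`GPTQConfig`] also exposes the serialization helpers of the quantization config mixin, so you can double-check the exact settings before launching a (potentially long) quantization run. The snippet below is only an optional sanity check:
+
+```python
+from transformers import GPTQConfig
+
+# Optional: dump the configuration to a plain dictionary to review the settings
+print(GPTQConfig(bits=4, dataset="c4").to_dict())
+```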
+
+Note that you can pass your own dataset as a list of strings. However, it is highly recommended to use the dataset from the GPTQ paper.
+
+```python
+dataset = ["auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
+gptq_config = GPTQConfig(bits=4, dataset=dataset, tokenizer=tokenizer)
+```
+
+#### Quantization
+
+You can quantize a model by using `from_pretrained` and setting the `quantization_config`.
+
+```python
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=gptq_config)
+```
+
+Note that you will need a GPU to quantize a model. We will put the model on the CPU and move the modules back and forth to the GPU in order to quantize them.
+
+If you want to maximize your GPU usage while using CPU offload, you can set `device_map="auto"`.
+
+```python
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=gptq_config)
+```
+
+Note that disk offload is not supported. Furthermore, if you run out of memory because of the dataset, you may have to pass `max_memory` to `from_pretrained`. Check out this [guide](https://huggingface.co/docs/accelerate/usage_guides/big_modeling#designing-a-device-map) to learn more about `device_map` and `max_memory`.
+
+GPTQ quantization only works for text models for now. Furthermore, the quantization process can take a lot of time depending on your hardware (quantizing a 175B model takes about 4 GPU-hours on an NVIDIA A100). Please check on the Hub whether a GPTQ-quantized version of the model already exists. If not, you can submit a request on GitHub.
+
+### Push quantized model to 🤗 Hub
+
+You can push the quantized model to the Hub like any 🤗 model with `push_to_hub`. The quantization config will be saved and pushed along with the model.
+
+```python
+quantized_model.push_to_hub("opt-125m-gptq")
+tokenizer.push_to_hub("opt-125m-gptq")
+```
+
+If you want to save your quantized model on your local machine, you can also do it with `save_pretrained`:
+
+```python
+quantized_model.save_pretrained("opt-125m-gptq")
+tokenizer.save_pretrained("opt-125m-gptq")
+```
+
+Note that if you have quantized your model with a `device_map`, make sure to move the entire model to one of your GPUs or to the CPU before saving it.
+
+```python
+quantized_model.to("cpu")
+quantized_model.save_pretrained("opt-125m-gptq")
+```
+
+### Load a quantized model from the 🤗 Hub
+
+You can load a quantized model from the Hub by using `from_pretrained`.
+Make sure that the pushed weights are quantized by checking that the `quantization_config` attribute is present in the model configuration object.
+
+```python
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq")
+```
+
+If you want to load a model faster and without allocating more memory than needed, the `device_map` argument also works with quantized models. Make sure that you have the `accelerate` library installed.
+
+```python
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto")
+```
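+
+Once loaded, the quantized model can be used for inference like any other 🤗 Transformers model. The snippet below is a minimal sanity check, assuming the `{your_username}/opt-125m-gptq` checkpoint pushed in the previous steps:
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto")
+tokenizer = AutoTokenizer.from_pretrained("{your_username}/opt-125m-gptq")
+
+# Generate a few tokens to check that the quantized weights were loaded correctly
+inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)
+outputs = model.generate(**inputs, max_new_tokens=20)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```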
+
+### Exllama kernels for faster inference
+
+For 4-bit models, you can use the exllama kernels to get faster inference speed. They are activated by default. You can change that behavior by passing `disable_exllama` in [`GPTQConfig`]. This will overwrite the quantization config stored in the config. Note that you will only be able to overwrite the attributes related to the kernels. Furthermore, you need to have the entire model on GPUs if you want to use the exllama kernels.
+
+```py
+import torch
+
+gptq_config = GPTQConfig(bits=4, disable_exllama=False)
+model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto", quantization_config=gptq_config)
+```
+
+Note that only 4-bit models are supported for now. Furthermore, it is recommended to deactivate the exllama kernels if you are fine-tuning a quantized model with peft.
+
+#### Fine-tune a quantized model
+
+With the official support of adapters in the Hugging Face ecosystem, you can fine-tune models that have been quantized with GPTQ.
+Please have a look at the [`peft`](https://github.com/huggingface/peft) library for more details.
+
+### Example demo
+
+Check out the Google Colab [notebook](https://colab.research.google.com/drive/1_TIrmuKOFhuRRiTWN94iLKUFu6ZX4ceb?usp=sharing) to learn how to quantize your model with GPTQ and how to fine-tune the quantized model with peft.
+
+### GPTQConfig
+
+[[autodoc]] GPTQConfig
+
+
 ## `bitsandbytes` Integration
 
 🤗 Transformers is closely integrated with most used modules on `bitsandbytes`. You can load your model in 8-bit precision with few lines of code.
@@ -215,7 +346,7 @@ This section is intended to advanced users, that want to explore what it is poss
 One of the advanced use case of this is being able to load a model and dispatch the weights between `CPU` and `GPU`. Note that the weights that will be dispatched on CPU **will not** be converted in 8-bit, thus kept in `float32`. This feature is intended for users that want to fit a very large model and dispatch the model between GPU and CPU.
 
-First, load a `BitsAndBytesConfig` from `transformers` and set the attribute `llm_int8_enable_fp32_cpu_offload` to `True`:
+First, load a [`BitsAndBytesConfig`] from `transformers` and set the attribute `llm_int8_enable_fp32_cpu_offload` to `True`:
 
 ```python
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 2253bda3908a..695e51fbe293 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -731,7 +731,7 @@
         "logging",
     ],
     "utils.bitsandbytes": [],
-    "utils.quantization_config": ["BitsAndBytesConfig"],
+    "utils.quantization_config": ["BitsAndBytesConfig", "GPTQConfig"],
 }
 
 # sentencepiece-backed objects
@@ -4703,7 +4703,7 @@
     )
 
     # bitsandbytes config
-    from .utils.quantization_config import BitsAndBytesConfig
+    from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig
 
     try:
         if not is_sentencepiece_available():
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 9d6dc9d82b33..b1edfb72503f 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -64,6 +64,7 @@
     download_url,
     has_file,
     is_accelerate_available,
+    is_auto_gptq_available,
     is_bitsandbytes_available,
     is_offline_mode,
     is_optimum_available,
@@ -75,7 +76,7 @@
 )
 from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
 from .utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled, is_torch_fx_proxy
-from .utils.quantization_config import BitsAndBytesConfig
+from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig, QuantizationMethod
 from .utils.versions import require_version_core
 
 
@@ -1915,7 +1916,7 @@ def get_memory_footprint(self, return_buffers=True):
     @wraps(torch.nn.Module.cuda)
     def cuda(self, *args, **kwargs):
         # Checks if the model has been loaded in 8-bit
-        if getattr(self, "is_quantized", False):
+        if getattr(self,
"quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: raise ValueError( "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the" " model has already been set to the correct devices and casted to the correct `dtype`." @@ -1926,29 +1927,29 @@ def cuda(self, *args, **kwargs): @wraps(torch.nn.Module.to) def to(self, *args, **kwargs): # Checks if the model has been loaded in 8-bit - if getattr(self, "is_quantized", False): + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: raise ValueError( - "`.to` is not supported for `4-bit` or `8-bit` models. Please use the model as it is, since the" + "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the" " model has already been set to the correct devices and casted to the correct `dtype`." ) else: return super().to(*args, **kwargs) def half(self, *args): - # Checks if the model has been loaded in 8-bit + # Checks if the model is quantized if getattr(self, "is_quantized", False): raise ValueError( - "`.half()` is not supported for `4-bit` or `8-bit` models. Please use the model as it is, since the" + "`.half()` is not supported for quantized model. Please use the model as it is, since the" " model has already been casted to the correct `dtype`." ) else: return super().half(*args) def float(self, *args): - # Checks if the model has been loaded in 8-bit + # Checks if the model is quantized if getattr(self, "is_quantized", False): raise ValueError( - "`.float()` is not supported for `4-bit` or `8-bit` models. Please use the model as it is, since the" + "`.float()` is not supported for quantized model. Please use the model as it is, since the" " model has already been casted to the correct `dtype`." ) else: @@ -2130,9 +2131,9 @@ def from_pretrained( load_in_4bit (`bool`, *optional*, defaults to `False`): If `True`, will convert the loaded model into 4bit precision quantized model. To use this feature install the latest version of `bitsandbytes` (`pip install -U bitsandbytes`). - quantization_config (`Dict`, *optional*): - A dictionary of configuration parameters for the `bitsandbytes` library and loading the model using - advanced features such as offloading in fp32 on CPU or on disk. + quantization_config (`Union[QuantizationConfigMixin,Dict]`, *optional*): + A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g + bitsandbytes, gptq) subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here. 
@@ -2287,13 +2288,20 @@ def from_pretrained( "Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`" ) - if quantization_config is None: + quantization_method_from_args = None + if quantization_config is not None: + quantization_method_from_args = getattr( + quantization_config, "quant_method", QuantizationMethod.BITS_AND_BYTES + ) + + if quantization_config is None and (load_in_8bit or load_in_4bit): + quantization_method_from_args = QuantizationMethod.BITS_AND_BYTES quantization_config, kwargs = BitsAndBytesConfig.from_dict( config_dict={"load_in_8bit": load_in_8bit, "load_in_4bit": load_in_4bit}, return_unused_kwargs=True, **kwargs, ) - elif quantization_config is not None: + elif quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES: load_in_8bit = quantization_config.load_in_8bit load_in_4bit = quantization_config.load_in_4bit @@ -2375,15 +2383,63 @@ def from_pretrained( else: model_kwargs = kwargs - if is_8bit_serializable and quantization_config is not None and load_in_8bit: - if hasattr(config, "quantization_config"): + quantizer = None + quantization_method_from_config = None + if hasattr(config, "quantization_config"): + quantization_method_from_config = config.quantization_config.get( + "quant_method", QuantizationMethod.BITS_AND_BYTES + ) + + if quantization_method_from_config == QuantizationMethod.GPTQ and quantization_method_from_args is not None: + loading_attr_dict = quantization_config.get_loading_attributes() + for attr, val in loading_attr_dict.items(): + config.quantization_config[attr] = val + quantization_method_from_args = None + logger.warning( + "You passed `quantization_config` to `from_pretrained` but the model you're loading already has a " + "`quantization_config` attribute and has already quantized weights. However, loading attributes" + " (e.g. disable_exllama, use_cuda_fp16) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored." + ) + if ( + quantization_method_from_args == QuantizationMethod.GPTQ + or quantization_method_from_config == QuantizationMethod.GPTQ + ): + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to quantize or run quantize model.") + elif not (is_optimum_available() and is_auto_gptq_available()): + raise ImportError( + "Loading GPTQ quantized model requires optimum library : `pip install optimum` and auto-gptq library 'pip install auto-gptq'" + ) + else: + # Need to protect the import + from optimum.gptq import GPTQQuantizer + if quantization_method_from_config == QuantizationMethod.GPTQ: + quantization_config = GPTQConfig.from_dict(config.quantization_config) + config.quantization_config = quantization_config + logger.info( + f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to " + "requirements of `auto-gptq` to enable model quantization " + ) + torch_dtype = torch.float16 + quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict()) + + if ( + is_8bit_serializable + and quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES + and load_in_8bit + ): + if quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES: logger.warning( "You passed `quantization_config` to `from_pretrained` but the model you're loading already has a" " `quantization_config` attribute. The `quantization_config` attribute will be overwritten with the" " one you passed to `from_pretrained`." 
) config.quantization_config = quantization_config - elif is_8bit_serializable and not load_in_8bit and hasattr(config, "quantization_config"): + elif ( + is_8bit_serializable + and not load_in_8bit + and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES + ): quantization_config = config.quantization_config if isinstance(quantization_config, dict): quantization_config = BitsAndBytesConfig.from_dict(quantization_config, return_unused_kwargs=False) @@ -2413,7 +2469,11 @@ def from_pretrained( if low_cpu_mem_usage is None: low_cpu_mem_usage = True - elif not is_8bit_serializable and not load_in_8bit and hasattr(config, "quantization_config"): + elif ( + not is_8bit_serializable + and not load_in_8bit + and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES + ): logger.warning( "Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct" " `bitsandbytes` version to support int8 serialization. Please install the latest version of `bitsandbytes` with " @@ -2800,6 +2860,16 @@ def from_pretrained( "All non-linear modules will be loaded in full precision." " If you want to load the other modules in other precision, please specify a `torch_dtype` attribute." ) + if quantization_method_from_config == QuantizationMethod.GPTQ: + model = quantizer.convert_model(model) + model._is_quantized_training_enabled = True + + if quantization_method_from_config is not None: + model.quantization_method = quantization_method_from_config + elif quantization_method_from_args is not None: + model.quantization_method = quantization_method_from_args + if hasattr(model, "quantization_method"): + model.is_quantized = True if isinstance(device_map, str): special_dtypes = {} @@ -2951,13 +3021,12 @@ def from_pretrained( offload_folder=offload_folder, offload_state_dict=offload_state_dict, dtype=torch_dtype, - is_quantized=(load_in_8bit or load_in_4bit), + is_quantized=(getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES), keep_in_fp32_modules=keep_in_fp32_modules, ) model.is_loaded_in_4bit = load_in_4bit model.is_loaded_in_8bit = load_in_8bit - model.is_quantized = load_in_8bit or load_in_4bit # make sure token embedding weights are still tied if needed model.tie_weights() @@ -2995,6 +3064,17 @@ def from_pretrained( kwargs["skip_keys"] = model._skip_keys_device_placement dispatch_model(model, **kwargs) + if quantization_method_from_args == QuantizationMethod.GPTQ: + if quantization_config.tokenizer is None: + quantization_config.tokenizer = pretrained_model_name_or_path + if cls.main_input_name != "input_ids": + raise RuntimeError("We can only quantize pure text model.") + quantizer.quantize_model(model, quantization_config.tokenizer) + model.config.quantization_config = GPTQConfig.from_dict(quantizer.to_dict()) + model._is_quantized_training_enabled = True + if quantization_method_from_config == QuantizationMethod.GPTQ: + model = quantizer.post_init_model(model) + if output_loading_info: if loading_info is None: loading_info = { diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 36d3435eb7b5..a5e4fcbb90ee 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -475,6 +475,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): # meaningless in the context of the config object - torch.dtype values are acceptable if kwargs.get("torch_dtype", None) == "auto": _ 
= kwargs.pop("torch_dtype") + # to not overwrite the quantization_config if config has a quantization_config + if kwargs.get("quantization_config", None) is not None: + _ = kwargs.pop("quantization_config") config, kwargs = AutoConfig.from_pretrained( pretrained_model_name_or_path, @@ -487,6 +490,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): # if torch_dtype=auto was passed here, ensure to pass it on if kwargs_orig.get("torch_dtype", None) == "auto": kwargs["torch_dtype"] = "auto" + if kwargs_orig.get("quantization_config", None) is not None: + kwargs["quantization_config"] = kwargs_orig["quantization_config"] has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map has_local_code = type(config) in cls._model_mapping.keys() diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index b93a40daa2ac..94811a7bc8a2 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -51,6 +51,7 @@ from .utils import ( is_accelerate_available, is_apex_available, + is_auto_gptq_available, is_bitsandbytes_available, is_bs4_available, is_cython_available, @@ -776,6 +777,13 @@ def require_optimum(test_case): return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case) +def require_auto_gptq(test_case): + """ + Decorator for auto_gptq dependency + """ + return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case) + + def require_phonemizer(test_case): """ Decorator marking a test that requires phonemizer diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index cb4c04321027..31d3b95b2454 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -146,6 +146,7 @@ logging, strtobool, ) +from .utils.quantization_config import QuantizationMethod DEFAULT_CALLBACKS = [DefaultFlowCallback] @@ -396,10 +397,9 @@ def __init__( if getattr(model, "is_quantized", False): if getattr(model, "_is_quantized_training_enabled", False): logger.info( - "The model is loaded in 8-bit precision. To train this model you need to add additional modules" + "The model is quantized. To train this model you need to add additional modules" " inside the model such as adapters using `peft` library and freeze the model weights. Please" - " check " - " the examples in https://github.com/huggingface/peft for more details." + " check the examples in https://github.com/huggingface/peft for more details." ) else: raise ValueError( @@ -498,8 +498,11 @@ def __init__( self.eval_dataset = eval_dataset self.tokenizer = tokenizer - # Quantized models doesn't support `.to` operation. - if self.place_model_on_device and not getattr(model, "is_quantized", False): + # Bnb Quantized models doesn't support `.to` operation. 
+ if ( + self.place_model_on_device + and not getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ): self._move_model_to_device(model, args.device) # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 728fb5c911d1..83b0128fbc58 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -104,6 +104,7 @@ get_torch_version, is_accelerate_available, is_apex_available, + is_auto_gptq_available, is_bitsandbytes_available, is_bs4_available, is_coloredlogs_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index c0a8c80f0b09..54ed4030a2b6 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -98,6 +98,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _onnx_available = _is_package_available("onnx") _openai_available = _is_package_available("openai") _optimum_available = _is_package_available("optimum") +_auto_gptq_available = _is_package_available("auto_gptq") _pandas_available = _is_package_available("pandas") _peft_available = _is_package_available("peft") _phonemizer_available = _is_package_available("phonemizer") @@ -554,6 +555,10 @@ def is_optimum_available(): return _optimum_available +def is_auto_gptq_available(): + return _auto_gptq_available + + def is_optimum_neuron_available(): return _optimum_available and _is_package_available("optimum.neuron") diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index cc7d195e1752..f0c82602f02e 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -19,7 +19,8 @@ import json import os from dataclasses import dataclass -from typing import Any, Dict, Union +from enum import Enum +from typing import Any, Dict, List, Optional, Union from packaging import version @@ -33,8 +34,100 @@ logger = logging.get_logger(__name__) +class QuantizationMethod(str, Enum): + BITS_AND_BYTES = "bitsandbytes" + GPTQ = "gptq" + + @dataclass -class BitsAndBytesConfig: +class QuantizationConfigMixin: + """ + Mixin class for quantization config + """ + + quant_method: QuantizationMethod + + @classmethod + def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): + """ + Instantiates a [`QuantizationConfigMixin`] from a Python dictionary of parameters. + + Args: + config_dict (`Dict[str, Any]`): + Dictionary that will be used to instantiate the configuration object. + return_unused_kwargs (`bool`,*optional*, defaults to `False`): + Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in + `PreTrainedModel`. + kwargs (`Dict[str, Any]`): + Additional parameters from which to initialize the configuration object. + + Returns: + [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters. + """ + + config = cls(**config_dict) + + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + if return_unused_kwargs: + return config, kwargs + else: + return config + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. 
+ + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self, use_diff: bool = True) -> str: + """ + Serializes this instance to a JSON string. + + Args: + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default `PretrainedConfig()` + is serialized to JSON string. + + Returns: + `str`: String containing all the attributes that make up this configuration instance in JSON format. + """ + if use_diff is True: + config_dict = self.to_diff_dict() + else: + config_dict = self.to_dict() + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + +@dataclass +class BitsAndBytesConfig(QuantizationConfigMixin): """ This is a wrapper class about all possible attributes and features that you can play with a model that has been loaded using `bitsandbytes`. @@ -97,6 +190,7 @@ def __init__( bnb_4bit_use_double_quant=False, **kwargs, ): + self.quant_method = QuantizationMethod.BITS_AND_BYTES self.load_in_8bit = load_in_8bit self.load_in_4bit = load_in_4bit self.llm_int8_threshold = llm_int8_threshold @@ -168,88 +262,16 @@ def quantization_method(self): else: return None - @classmethod - def from_dict(cls, config_dict, return_unused_kwargs, **kwargs): - """ - Instantiates a [`BitsAndBytesConfig`] from a Python dictionary of parameters. - - Args: - config_dict (`Dict[str, Any]`): - Dictionary that will be used to instantiate the configuration object. - return_unused_kwargs (`bool`): - Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in - `PreTrainedModel`. - kwargs (`Dict[str, Any]`): - Additional parameters from which to initialize the configuration object. - - Returns: - [`BitsAndBytesConfig`]: The configuration object instantiated from those parameters. - """ - - config = cls(**config_dict) - - to_remove = [] - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) - - if return_unused_kwargs: - return config, kwargs - else: - return config - - def to_json_file(self, json_file_path: Union[str, os.PathLike]): - """ - Save this instance to a JSON file. - - Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file in which this configuration instance's parameters will be saved. - use_diff (`bool`, *optional*, defaults to `True`): - If set to `True`, only the difference between the config instance and the default - `BitsAndBytesConfig()` is serialized to JSON file. 
- """ - with open(json_file_path, "w", encoding="utf-8") as writer: - config_dict = self.to_dict() - json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - - writer.write(json_string) - def to_dict(self) -> Dict[str, Any]: """ Serializes this instance to a Python dictionary. Returns: `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. """ - output = copy.deepcopy(self.__dict__) output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1] return output - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - - def to_json_string(self, use_diff: bool = True) -> str: - """ - Serializes this instance to a JSON string. - - Args: - use_diff (`bool`, *optional*, defaults to `True`): - If set to `True`, only the difference between the config instance and the default `PretrainedConfig()` - is serialized to JSON string. - - Returns: - `str`: String containing all the attributes that make up this configuration instance in JSON format. - """ - if use_diff is True: - config_dict = self.to_diff_dict() - else: - config_dict = self.to_dict() - return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - def to_diff_dict(self) -> Dict[str, Any]: """ Removes all attributes from config which correspond to the default config attributes for better readability and @@ -271,3 +293,119 @@ def to_diff_dict(self) -> Dict[str, Any]: serializable_config_dict[key] = value return serializable_config_dict + + +@dataclass +class GPTQConfig(QuantizationConfigMixin): + """ + This is a wrapper class about all possible attributes and features that you can play with a model that has been + loaded using `optimum` api for gptq quantization relying on auto_gptq backend. + + Args: + bits (`int`): + The number of bits to quantize to, supported numbers are (2, 3, 4, 8). + tokenizer (`str` or `PreTrainedTokenizerBase`, *optional*): + The tokenizer used to process the dataset. You can pass either: + - A custom tokenizer object. + - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved + using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + dataset (`Union[List[str]]`, *optional*): + The dataset used for quantization. You can provide your own dataset in a list of string or just use the + original datasets used in GPTQ paper ['wikitext2','c4','c4-new','ptb','ptb-new'] + group_size (`int`, *optional*, defaults to 128): + The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization. + damp_percent (`float`, *optional*, defaults to 0.01): + The percent of the average Hessian diagonal to use for dampening. Recommended value is 0.01. + desc_act (`bool`, *optional*, defaults to `True`): + Whether to quantize columns in order of decreasing activation size. Setting it to False can significantly + speed up inference but the perplexity may become slightly worse. Also known as act-order. + sym (`bool`, *optional*, defaults to `True`): + Whether to use symetric quantization. + true_sequential (`bool`, *optional*, defaults to `True`): + Whether to perform sequential quantization even within a single Transformer block. 
Instead of quantizing + the entire block at once, we perform layer-wise quantization. As a result, each layer undergoes + quantization using inputs that have passed through the previously quantized layers. + use_cuda_fp16 (`bool`, *optional*, defaults to `False`): + Whether or not to use optimized cuda kernel for fp16 model. Need to have model in fp16. + model_seqlen (`int`, *optional*): + The maximum sequence length that the model can take. + block_name_to_quantize (`str`, *optional*): + The transformers block name to quantize. + module_name_preceding_first_block (`List[str]`, *optional*): + The layers that are preceding the first Transformer block. + batch_size (`int`, *optional*, defaults to 1): + The batch size used when processing the dataset + pad_token_id (`int`, *optional*): + The pad token id. Needed to prepare the dataset when `batch_size` > 1. + disable_exllama (`bool`, *optional*, defaults to `False`): + Whether to use exllama backend. Only works with `bits` = 4. + """ + + def __init__( + self, + bits: int, + tokenizer: Any = None, + dataset: Optional[Union[List[str], str]] = None, + group_size: int = 128, + damp_percent: float = 0.01, + desc_act: bool = True, + sym: bool = True, + true_sequential: bool = True, + use_cuda_fp16: bool = False, + model_seqlen: Optional[int] = None, + block_name_to_quantize: Optional[str] = None, + module_name_preceding_first_block: Optional[List[str]] = None, + batch_size: int = 1, + pad_token_id: Optional[int] = None, + disable_exllama: bool = False, + **kwargs, + ): + self.quant_method = QuantizationMethod.GPTQ + self.bits = bits + self.tokenizer = tokenizer + self.dataset = dataset + self.group_size = group_size + self.damp_percent = damp_percent + self.desc_act = desc_act + self.sym = sym + self.true_sequential = true_sequential + self.use_cuda_fp16 = use_cuda_fp16 + self.model_seqlen = model_seqlen + self.block_name_to_quantize = block_name_to_quantize + self.module_name_preceding_first_block = module_name_preceding_first_block + self.batch_size = batch_size + self.pad_token_id = pad_token_id + self.disable_exllama = disable_exllama + self.post_init() + + def get_loading_attributes(self): + attibutes_dict = copy.deepcopy(self.__dict__) + loading_attibutes = ["disable_exllama", "use_cuda_fp16"] + loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes} + return loading_attibutes_dict + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + if self.bits not in [2, 4, 6, 8]: + raise ValueError(f"Only support quantization to [2,4,6,8] bits but found {self.bits}") + if self.group_size != -1 and self.group_size <= 0: + raise ValueError("group_size must be greater than 0 or equal to -1") + if not (0 < self.damp_percent < 1): + raise ValueError("damp_percent must between 0 and 1.") + if self.dataset is not None: + if isinstance(self.dataset, str): + if self.dataset not in ["wikitext2", "c4", "c4-new", "ptb", "ptb-new"]: + raise ValueError( + f"""You have entered a string value for dataset. 
You can only choose between + ['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}""" + ) + elif not isinstance(self.dataset, list): + raise ValueError( + f"""dataset needs to be either a list of string or a value in + ['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}""" + ) diff --git a/tests/bnb/README.md b/tests/quantization/bnb/README.md similarity index 100% rename from tests/bnb/README.md rename to tests/quantization/bnb/README.md diff --git a/tests/bnb/__init__.py b/tests/quantization/bnb/__init__.py similarity index 100% rename from tests/bnb/__init__.py rename to tests/quantization/bnb/__init__.py diff --git a/tests/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py similarity index 100% rename from tests/bnb/test_4bit.py rename to tests/quantization/bnb/test_4bit.py diff --git a/tests/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py similarity index 99% rename from tests/bnb/test_mixed_int8.py rename to tests/quantization/bnb/test_mixed_int8.py index 3e88a366d82b..55e75bb52bd0 100644 --- a/tests/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -753,6 +753,7 @@ def test_cpu_gpu_disk_loading_custom_device_map_kwargs(self): model_8bit = AutoModelForCausalLM.from_pretrained( self.model_name, device_map=device_map, + load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True, offload_folder=tmpdirname, ) diff --git a/tests/quantization/gptq/__init__.py b/tests/quantization/gptq/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/quantization/gptq/test_gptq.py b/tests/quantization/gptq/test_gptq.py new file mode 100644 index 000000000000..257c6f020dd3 --- /dev/null +++ b/tests/quantization/gptq/test_gptq.py @@ -0,0 +1,272 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tempfile +import unittest + +import pytest + +from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig +from transformers.testing_utils import ( + is_torch_available, + require_accelerate, + require_auto_gptq, + require_optimum, + require_torch_gpu, + require_torch_multi_gpu, + slow, +) + + +if is_torch_available(): + import torch + + +class GPTQConfigTest(unittest.TestCase): + def test_bits(self): + with self.assertRaises(ValueError): + GPTQConfig(bits="") + GPTQConfig(bits=1) + GPTQConfig(bits=2) + GPTQConfig(bits=4) + + def test_dataset(self): + with self.assertRaises(ValueError): + GPTQConfig(bits=2, dataset="auto_gpt") + GPTQConfig(bits=2, dataset="c4") + GPTQConfig(bits=2, dataset="ptb-new") + + def test_damp_percent(self): + with self.assertRaises(ValueError): + GPTQConfig(bits=2, damp_percent=10) + GPTQConfig(bits=2, damp_percent=-1) + GPTQConfig(bits=2, damp_percent="0") + GPTQConfig(bits=2, damp_percent=0.01) + + def test_to_dict(self): + quantization_config = GPTQConfig(bits=2) + quantization_config.to_dict() + + def test_from_dict(self): + dict = {"bits": 2} + quantization_config = GPTQConfig.from_dict(dict) + self.assertEqual(dict["bits"], quantization_config.bits) + + @require_optimum + def test_optimum_config(self): + from optimum.gptq import GPTQQuantizer + + config = GPTQConfig(bits=2) + optimum_config = GPTQQuantizer.from_dict(config.to_dict()) + self.assertEqual(optimum_config.bits, config.bits) + new_config = GPTQConfig.from_dict(optimum_config.to_dict()) + self.assertEqual(optimum_config.bits, new_config.bits) + + +@slow +@require_optimum +@require_auto_gptq +@require_torch_gpu +class GPTQTest(unittest.TestCase): + model_name = "bigscience/bloom-560m" + + input_text = "Hello my name is" + + EXPECTED_OUTPUTS = set() + EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I") + EXPECTED_OUTPUTS.add("Hello my name is John and I am a very good looking man.") + EXPECTED_OUTPUTS.add("Hello my name is Alyson and I am a professional photographer") + + # this seems a little small considering that we are doing 4bit quant but we have a small model and ww don't quantize the embeddings + EXPECTED_RELATIVE_DIFFERENCE = 1.664253062 + + bits = 4 + group_size = 128 + desc_act = False + disable_exllama = True + + dataset = [ + "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm." 
+ ] + + device_map = None + + # called only once for all test in this class + @classmethod + def setUpClass(cls): + """ + Setup quantized model + """ + cls.model_fp16 = AutoModelForCausalLM.from_pretrained( + cls.model_name, torch_dtype=torch.float16, device_map=cls.device_map + ) + cls.mem_fp16 = cls.model_fp16.get_memory_footprint() + + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True) + + quantization_config = GPTQConfig( + bits=cls.bits, + dataset=cls.dataset, + tokenizer=cls.tokenizer, + group_size=cls.group_size, + desc_act=cls.desc_act, + disable_exllama=cls.disable_exllama, + ) + + cls.quantized_model = AutoModelForCausalLM.from_pretrained( + cls.model_name, + torch_dtype=torch.float16, + device_map=cls.device_map, + quantization_config=quantization_config, + ) + + def test_memory_footprint(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model + """ + + mem_quantized = self.quantized_model.get_memory_footprint() + + self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE) + + def test_quantized_layers_class(self): + """ + Simple test to check if the model conversion has been done correctly by checking on + the class type of the linear layers of the converted models + """ + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear + + QuantLinear = dynamically_import_QuantLinear( + use_triton=False, + desc_act=self.desc_act, + group_size=self.group_size, + bits=self.bits, + disable_exllama=self.disable_exllama, + ) + self.assertTrue(self.quantized_model.transformer.h[0].mlp.dense_4h_to_h.__class__ == QuantLinear) + + def check_inference_correctness(self, model): + r""" + Test the generation quality of the quantized model and see that we are matching the expected output. + Given that we are operating on small numbers + the testing model is relatively small, we might not get + the same output across GPUs. So we'll generate few tokens (5-10) and check their output. + """ + # Check that inference pass works on the model + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") + + # Check the exactness of the results + output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + + # Get the generation + self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) + + def test_generate_quality(self): + """ + Simple test to check the quality of the model by comapring the the generated tokens with the expected tokens + """ + if self.device_map is None: + self.check_inference_correctness(self.quantized_model.to(0)) + else: + self.check_inference_correctness(self.quantized_model) + + def test_serialization(self): + """ + Test the serialization of the model and the loading of the quantized weights works + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + if self.disable_exllama: + quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname).to(0) + else: + # we need to put it directly to the gpu. 
Otherwise, we won't be able to initialize the exllama kernel + quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": 0}) + self.check_inference_correctness(quantized_model_from_saved) + + @require_accelerate + def test_serialization_big_model_inference(self): + """ + Test the serialization of the model and the loading of the quantized weights with big model inference + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto") + self.check_inference_correctness(quantized_model_from_saved) + + def test_change_loading_attributes(self): + """ + Test the serialization of the model and the loading of the quantized weights works with another config file + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + if self.disable_exllama: + self.assertEqual(self.quantized_model.config.quantization_config.disable_exllama, True) + # we need to put it directly to the gpu. Otherwise, we won't be able to initialize the exllama kernel + quantized_model_from_saved = AutoModelForCausalLM.from_pretrained( + tmpdirname, quantization_config=GPTQConfig(disable_exllama=False, bits=6), device_map={"": 0} + ) + self.assertEqual(quantized_model_from_saved.config.quantization_config.disable_exllama, False) + self.assertEqual(quantized_model_from_saved.config.quantization_config.bits, self.bits) + self.check_inference_correctness(quantized_model_from_saved) + + +@require_accelerate +@require_torch_multi_gpu +class GPTQTestDeviceMap(GPTQTest): + device_map = "auto" + + +@require_accelerate +@require_torch_multi_gpu +class GPTQTestDeviceMapExllama(GPTQTest): + device_map = "auto" + disable_exllama = False + + +# fail when run all together +@pytest.mark.skip +@require_accelerate +@require_torch_multi_gpu +class GPTQTestDeviceMapCPUOffload(GPTQTest): + device_map = { + "transformer.word_embeddings": 0, + "transformer.word_embeddings_layernorm": 0, + "lm_head": 0, + "transformer.h.0": 0, + "transformer.h.1": 0, + "transformer.h.2": 0, + "transformer.h.3": 0, + "transformer.h.4": 0, + "transformer.h.5": 0, + "transformer.h.6": 0, + "transformer.h.7": 0, + "transformer.h.8": 0, + "transformer.h.9": 0, + "transformer.h.10": 1, + "transformer.h.11": 1, + "transformer.h.12": 1, + "transformer.h.13": 1, + "transformer.h.14": 1, + "transformer.h.15": 1, + "transformer.h.16": 1, + "transformer.h.17": 0, + "transformer.h.18": "cpu", + "transformer.h.19": "cpu", + "transformer.h.20": "cpu", + "transformer.h.21": "cpu", + "transformer.h.22": "cpu", + "transformer.h.23": 1, + "transformer.ln_f": 0, + }