GPTQ integration (#25062)
* GPTQ integration

* Add tests for gptq

* support for more quantization models

* fix style

* typo

* fix method

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <[email protected]>

* add dataclass and fix quantization_method

* fix doc

* Update tests/quantization/gptq/test_gptq.py

Co-authored-by: Younes Belkada <[email protected]>

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>

* modify dataclass

* add gptqconfig import

* fix typo

* fix tests

* remove dataset as req arg

* remove tokenizer import

* add offload cpu quantization test

* fix check dataset

* modify dockerfile

* protect trainer

* style

* test for config

* add more log

* overwrite torch_dtype

* draft doc

* modify quantization_config docstring

* fix class name in docstring

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>

* more warning

* fix 8bit kwargs tests

* peft compatibility

* remove var

* fix is_gptq_quantized

* remove is_gptq_quantized

* fix wrap

* Update src/transformers/modeling_utils.py

Co-authored-by: Younes Belkada <[email protected]>

* add exllama

* skip test

* overwrite float16

* style

* fix skip test

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <[email protected]>

* fix docstring formatting

* add doc

* better test

---------

Co-authored-by: Sylvain Gugger <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
3 people authored Aug 10, 2023
1 parent 3470012 commit 55db70c
Showing 16 changed files with 750 additions and 103 deletions.
7 changes: 5 additions & 2 deletions docker/transformers-all-latest-gpu/Dockerfile
@@ -47,8 +47,11 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/acc
# Add bitsandbytes for mixed int8 testing
RUN python3 -m pip install --no-cache-dir bitsandbytes

# For bettertransformer
RUN python3 -m pip install --no-cache-dir optimum
# Add auto-gptq for gptq quantization testing
RUN python3 -m pip install --no-cache-dir auto-gptq

# For bettertransformer + gptq
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum

# For video model testing
RUN python3 -m pip install --no-cache-dir decord av==9.2.0
133 changes: 132 additions & 1 deletion docs/source/en/main_classes/quantization.md
@@ -16,6 +16,137 @@ rendered properly in your Markdown viewer.

# Quantize 🤗 Transformers models

## `AutoGPTQ` Integration

🤗 Transformers has integrated the `optimum` API to perform GPTQ quantization on language models. You can load and quantize your model in 8, 6, 4, or even 2 bits without a big drop in performance and with faster inference speed! This is supported by most GPU hardware.

To learn more about the quantization method, check out:
- the [GPTQ](https://arxiv.org/pdf/2210.17323.pdf) paper
<!-- - the `optimum` [guide]() on GPTQ quantization -->
- the [`AutoGPTQ`](https://github.com/PanQiWei/AutoGPTQ) library used as the backend

### Requirements

You need to have the following requirements installed to run the code below:

- Install the latest `AutoGPTQ` library
`pip install auto-gptq`

- Install the latest `optimum` from source
`pip install git+https://github.com/huggingface/optimum.git`

- Install the latest `transformers` from source
`pip install git+https://github.com/huggingface/transformers.git`

- Install the latest `accelerate` library
`pip install --upgrade accelerate`

For now, GPTQ integration only supports text models; you may encounter unexpected behaviour with vision, speech, or multi-modal models.

### Load and quantize a model

GPTQ is a quantization method that requires weight calibration before the quantized model can be used. If you want to quantize a transformers model from scratch, it might take some time to produce the quantized model (~10 min on a Google Colab for the `facebook/opt-350m` model).

Hence, there are two different scenarios where you may want to use GPTQ-quantized models. The first is to load a model that has already been quantized by other users and is available on the Hub; the second is to quantize your own model from scratch and save it or push it to the Hub so that other users can also use it.

#### GPTQ Configuration

In order to load and quantize a model, you need to create a [`GPTQConfig`]. You need to pass the number of `bits`, a `dataset` to calibrate the quantization, and the `tokenizer` of the model in order to prepare the dataset.

```python
from transformers import AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
```

Note that you can pass your own dataset as a list of strings. However, it is highly recommended to use the dataset from the GPTQ paper.
```python
dataset = ["auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
gptq_config = GPTQConfig(bits=4, dataset=dataset, tokenizer=tokenizer)
```

#### Quantization

You can quantize a model by using `from_pretrained` and setting the `quantization_config`.

```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=gptq_config)
```
Note that you will need a GPU to quantize a model. The model will be placed on the CPU, and the modules will be moved back and forth to the GPU in order to quantize them.

If you want to maximize your GPU usage while using CPU offload, you can set `device_map = "auto"`.
```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=gptq_config)
```
Note that disk offload is not supported. Furthermore, if you run out of memory because of the dataset, you may have to pass `max_memory` to `from_pretrained`. Check out this [guide](https://huggingface.co/docs/accelerate/usage_guides/big_modeling#designing-a-device-map) to learn more about `device_map` and `max_memory`.
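
For instance, a sketch of capping memory usage per device, reusing `model_id` and `gptq_config` from above (the memory values are illustrative only and depend on your hardware):

```python
from transformers import AutoModelForCausalLM

# Illustrative memory caps per device; adjust them to your own hardware.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    max_memory={0: "20GiB", "cpu": "60GiB"},
    quantization_config=gptq_config,
)
```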

<Tip warning={true}>
GPTQ quantization only works for text models for now. Furthermore, the quantization process can take a lot of time depending on your hardware (175B model = 4 GPU hours using an NVIDIA A100). Please check on the Hub whether a GPTQ-quantized version of the model already exists. If not, you can submit a request on GitHub.
</Tip>

### Push quantized model to 🤗 Hub

You can push the quantized model to the Hub like any 🤗 model with `push_to_hub`. The quantization config will be saved and pushed along with the model.

```python
quantized_model.push_to_hub("opt-125m-gptq")
tokenizer.push_to_hub("opt-125m-gptq")
```

If you want to save your quantized model on your local machine, you can also do it with `save_pretrained`:
```python
quantized_model.save_pretrained("opt-125m-gptq")
tokenizer.save_pretrained("opt-125m-gptq")
```

Note that if you have quantized your model with a `device_map`, make sure to move the entire model to one of your GPUs or to the CPU before saving it.
```python
quantized_model.to("cpu")
quantized_model.save_pretrained("opt-125m-gptq")
```

### Load a quantized model from the 🤗 Hub

You can load a quantized model from the Hub by using `from_pretrained`.
Make sure that the pushed weights are quantized by checking that the attribute `quantization_config` is present in the model configuration object.

```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq")
```

If you want to load a model faster and without allocating more memory than needed, the `device_map` argument also works with quantized models. Make sure that you have the `accelerate` library installed.
```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto")
```
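
Once loaded, the quantized model can be used for generation like any other model. The snippet below is a minimal sketch using the placeholder repository name from this section:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("{your_username}/opt-125m-gptq")
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto")

# Generate a short completion to check that the quantized model works as expected.
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```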

### Exllama kernels for faster inference

For 4-bit models, you can use the exllama kernels for faster inference speed. They are activated by default. You can change this behaviour by passing `disable_exllama` in [`GPTQConfig`]. This will overwrite the quantization config stored in the model's config. Note that you can only overwrite the attributes related to the kernels. Furthermore, you need to have the entire model on GPUs if you want to use the exllama kernels.

```py
from transformers import AutoModelForCausalLM, GPTQConfig

gptq_config = GPTQConfig(bits=4, disable_exllama=False)
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto", quantization_config=gptq_config)
```

Note that only 4-bit models are supported for now. Furthermore, it is recommended to deactivate the exllama kernels if you are fine-tuning a quantized model with PEFT.

#### Fine-tune a quantized model

With the official support of adapters in the Hugging Face ecosystem, you can fine-tune models that have been quantized with GPTQ.
Please have a look at the [`peft`](https://github.com/huggingface/peft) library for more details.
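
As a minimal sketch (assuming a recent `peft` release with GPTQ support; the LoRA hyperparameters below are illustrative only), you could attach LoRA adapters to a GPTQ-quantized model as follows, deactivating the exllama kernels as recommended above:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, GPTQConfig

# Deactivate the exllama kernels for fine-tuning, as recommended above.
gptq_config = GPTQConfig(bits=4, disable_exllama=True)
model = AutoModelForCausalLM.from_pretrained(
    "{your_username}/opt-125m-gptq", device_map="auto", quantization_config=gptq_config
)

# Illustrative LoRA hyperparameters; tune them for your own use case.
lora_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```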

### Example demo

Check out the Google Colab [notebook](https://colab.research.google.com/drive/1_TIrmuKOFhuRRiTWN94iLKUFu6ZX4ceb?usp=sharing) to learn how to quantize your model with GPTQ and how to fine-tune the quantized model with PEFT.

### GPTQConfig

[[autodoc]] GPTQConfig


## `bitsandbytes` Integration

🤗 Transformers is closely integrated with the most-used modules of `bitsandbytes`. You can load your model in 8-bit precision with a few lines of code.
@@ -215,7 +346,7 @@ This section is intended to advanced users, that want to explore what it is poss

One of the advanced use cases of this is being able to load a model and dispatch the weights between `CPU` and `GPU`. Note that the weights that will be dispatched on CPU **will not** be converted to 8-bit, and are thus kept in `float32`. This feature is intended for users that want to fit a very large model and dispatch the model between GPU and CPU.

First, load a `BitsAndBytesConfig` from `transformers` and set the attribute `llm_int8_enable_fp32_cpu_offload` to `True`:
First, load a [`BitsAndBytesConfig`] from `transformers` and set the attribute `llm_int8_enable_fp32_cpu_offload` to `True`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -731,7 +731,7 @@
"logging",
],
"utils.bitsandbytes": [],
"utils.quantization_config": ["BitsAndBytesConfig"],
"utils.quantization_config": ["BitsAndBytesConfig", "GPTQConfig"],
}

# sentencepiece-backed objects
@@ -4703,7 +4703,7 @@
)

# bitsandbytes config
from .utils.quantization_config import BitsAndBytesConfig
from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig

try:
if not is_sentencepiece_available():
