GPTQ integration #25062

Merged · 50 commits · Aug 10, 2023

Commits
d963b97
GTPQ integration
SunMarc Jul 24, 2023
93f0d84
Add tests for gptq
SunMarc Jul 24, 2023
380baea
support for more quantization model
SunMarc Jul 25, 2023
810d537
fix style
SunMarc Jul 25, 2023
c3f5248
typo
SunMarc Jul 25, 2023
fc70ef4
fix method
SunMarc Jul 25, 2023
6a04bb8
Update src/transformers/modeling_utils.py
SunMarc Jul 25, 2023
271dab6
add dataclass and fix quantization_method
SunMarc Jul 25, 2023
992881e
fix doc
SunMarc Jul 26, 2023
3c2d940
Update tests/quantization/gptq/test_gptq.py
SunMarc Jul 26, 2023
9bbb336
Apply suggestions from code review
SunMarc Jul 26, 2023
0134c79
modify dataclass
SunMarc Jul 26, 2023
a2a7f5d
add gtpqconfig import
SunMarc Jul 26, 2023
70e1416
fix typo
SunMarc Jul 26, 2023
0e2014b
fix tests
SunMarc Jul 26, 2023
69e3c88
remove dataset as req arg
SunMarc Jul 26, 2023
cb46d75
remove tokenizer import
SunMarc Jul 26, 2023
9a3cafd
add offload cpu quantization test
SunMarc Jul 26, 2023
27e9b79
fix check dataset
SunMarc Jul 26, 2023
f47ecb4
modify dockerfile
SunMarc Jul 26, 2023
19d05d3
protect trainer
SunMarc Jul 26, 2023
76dffe2
style
SunMarc Jul 26, 2023
0f61037
test for config
SunMarc Jul 26, 2023
b0eccd5
add more log
SunMarc Jul 27, 2023
2e7a025
overwrite torch_dtype
SunMarc Jul 27, 2023
a07126a
draft doc
SunMarc Jul 27, 2023
c9d3f26
modify quantization_config docstring
SunMarc Jul 31, 2023
ecce1da
fix class name in docstring
SunMarc Jul 31, 2023
2226184
Apply suggestions from code review
SunMarc Jul 31, 2023
eff99cb
more warning
SunMarc Jul 31, 2023
159cf87
fix 8bit kwargs tests
SunMarc Jul 31, 2023
98db723
peft compatibility
SunMarc Jul 31, 2023
0144760
remove var
SunMarc Aug 1, 2023
fd8d70c
fix is_gptq_quantized
SunMarc Aug 1, 2023
0f96fb2
Merge branch 'main' into gptq_integration
SunMarc Aug 1, 2023
be19916
remove is_gptq_quantized
SunMarc Aug 2, 2023
9e8f487
Merge remote-tracking branch 'upstream/main' into gptq_integration
SunMarc Aug 2, 2023
4b4336e
fix wrap
SunMarc Aug 2, 2023
42d0049
Update src/transformers/modeling_utils.py
SunMarc Aug 8, 2023
a9658e2
Merge remote-tracking branch 'upstream/main' into gptq_integration
SunMarc Aug 8, 2023
62aa293
add exllama
SunMarc Aug 9, 2023
39137eb
skip test
SunMarc Aug 9, 2023
f23ce7e
Merge remote-tracking branch 'upstream/main' into gptq_integration
SunMarc Aug 9, 2023
0b0633b
overwrite float16
SunMarc Aug 9, 2023
c3c4a16
style
SunMarc Aug 9, 2023
a45b5b0
fix skip test
SunMarc Aug 9, 2023
69c8fce
Apply suggestions from code review
SunMarc Aug 10, 2023
bf98799
fix docsting formatting
SunMarc Aug 10, 2023
7adf9cb
add doc
SunMarc Aug 10, 2023
c93d1d0
better test
SunMarc Aug 10, 2023
7 changes: 5 additions & 2 deletions docker/transformers-all-latest-gpu/Dockerfile
@@ -47,8 +47,11 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/acc
# Add bitsandbytes for mixed int8 testing
RUN python3 -m pip install --no-cache-dir bitsandbytes

# For bettertransformer
RUN python3 -m pip install --no-cache-dir optimum
# Add auto-gptq for gptq quantization testing
RUN python3 -m pip install --no-cache-dir auto-gptq

# For bettertransformer + gptq
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum

# For video model testing
RUN python3 -m pip install --no-cache-dir decord av==9.2.0
127 changes: 123 additions & 4 deletions docs/source/en/main_classes/quantization.md
@@ -16,6 +16,128 @@ rendered properly in your Markdown viewer.

# Quantize 🤗 Transformers models

## `AutoGPTQ` Integration

🤗 Transformers has integrated the `optimum` API to perform GPTQ quantization on language models. You can load and quantize your model in 8, 4, 3, or even 2 bits without a big drop in performance and with faster inference speed! This is supported by most GPU hardware.

To learn more about the quantization method, check out:
- the [GPTQ](https://arxiv.org/pdf/2210.17323.pdf) paper
<!-- - the `optimum` [guide]() on GPTQ quantization -->
- the [`AutoGPTQ`](https://github.com/PanQiWei/AutoGPTQ) library used as the backend

### Requirements

You need to have the following requirements installed to run the code below:

- Install the latest `AutoGPTQ` library:
`pip install auto-gptq`

- Install the latest `optimum` from source:
`pip install git+https://github.com/huggingface/optimum.git`

- Install the latest `transformers` from source:
`pip install git+https://github.com/huggingface/transformers.git`

- Install the latest `accelerate` library:
`pip install --upgrade accelerate`

Note that the GPTQ integration supports only text models for now; you may encounter unexpected behaviour with vision, speech, or multi-modal models.
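To confirm that the required libraries are available in your environment, you can run a quick check (a minimal sketch; the names are the PyPI distribution names):

```python
# Optional sanity check: print the installed version of each required package
import importlib.metadata as importlib_metadata

for package in ["auto-gptq", "optimum", "transformers", "accelerate"]:
    print(package, importlib_metadata.version(package))
```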

### Load and quantize a model

GPTQ is a quantization method that requires weight calibration before using the quantized model. If you want to quantize a Transformers model from scratch, it might take some time to produce the quantized model (around 10 minutes on a Google Colab for the `facebook/opt-350m` model).

Hence, there are two different scenarios where you may want to use GPTQ-quantized models. The first is to load a model that has already been quantized by other users and is available on the Hub; the second is to quantize your own model from scratch and save it or push it to the Hub so that other users can use it as well.

#### GPTQ Configuration

In order to load and quantize a model, you need to create a [`GPTQConfig`]. You need to pass the number of `bits`, a `dataset` to calibrate the quantization, and the `tokenizer` of the model to prepare the dataset.

```python
from transformers import AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
```

Note that you can pass your own dataset as a list of strings. However, it is highly recommended to use the dataset from the GPTQ paper.
```python
dataset = ["auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
gptq_config = GPTQConfig(bits=4, dataset=dataset, tokenizer=tokenizer)
```

#### Quantization

You can quantize a model by using `from_pretrained` and setting the `quantization_config`.

```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=gptq_config)
```
Note that you will need a GPU to quantize a model. We will put the model on the CPU and move the modules back and forth between the CPU and the GPU in order to quantize them.

If you want to maximize your GPU usage while using CPU offload, you can set `device_map = "auto"`.
```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=gptq_config)
```
Note that disk offload is not supported. Furthermore, if you run out of memory because of the dataset, you may have to pass `max_memory` in `from_pretrained`. Check out this [guide](https://huggingface.co/docs/accelerate/usage_guides/big_modeling#designing-a-device-map) to learn more about `device_map` and `max_memory`.
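For example, you can cap the memory used on each device during quantization like this (a minimal sketch; the limits are placeholders to adapt to your hardware):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    max_memory={0: "20GiB", "cpu": "60GiB"},  # hypothetical limits for GPU 0 and CPU RAM
    quantization_config=gptq_config,
)
```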

<Tip warning={true}>
GPTQ quantization only works for text models for now. Furthermore, the quantization process can take a lot of time depending on your hardware (a 175B model takes about 4 GPU-hours on an NVIDIA A100). Please check on the Hub whether a GPTQ-quantized version of the model already exists. If not, you can submit a request on GitHub.
</Tip>

### Push quantized model to 🤗 Hub

You can push the quantized model to the Hub like any 🤗 model with `push_to_hub`:

```python
quantized_model.push_to_hub("opt-125m-gptq")
tokenizer.push_to_hub("opt-125m-gptq")
```

If you want to save your quantized model on your local machine, you can also do it with `save_pretrained`:
```python
quantized_model.save_pretrained("opt-125m-gptq")
tokenizer.save_pretrained("opt-125m-gptq")
```

Note that if you have quantized your model with a `device_map`, make sure to move the entire model to one of your GPUs or to the `cpu` before saving it.
```python
quantized_model.to("cpu")
quantized_model.save_pretrained("opt-125m-gptq")
```

### Load a quantized model from the 🤗 Hub

You can load a quantized model from the Hub by using `from_pretrained`.
Make sure that the pushed weights are quantized by checking that the attribute `quantization_config` is present in the model configuration object.

```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq")
```
Note that in this case, you don't need to specify the `quantization_config`. It will look for the `quantization_config` in the model configuration and prepare the model before loading the quantized weights. However, you need to make sure that `optimum` and `auto-gptq` are installed.
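To check whether a checkpoint on the Hub is GPTQ-quantized without downloading its weights, you can inspect its configuration first (a minimal sketch; the repository name is a placeholder):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("{your_username}/opt-125m-gptq")
# A GPTQ checkpoint stores its quantization settings in the configuration
print(getattr(config, "quantization_config", None))
```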

If you want to load a model faster and without allocating more memory than needed, the `device_map` argument also works with quantized models. Make sure that you have the `accelerate` library installed.
```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto")
```

### Exllama kernels for faster inference

For 4-bit models, you can use the exllama kernels to get faster inference speed. You just need to pass `disable_exllama=False` in [`GPTQConfig`]. This will overwrite the quantization config stored in the model's config. Note that you will only be able to overwrite the attributes related to the kernels. Furthermore, you need to have the entire model on GPUs.

```py
from transformers import AutoModelForCausalLM, GPTQConfig

gptq_config = GPTQConfig(bits=4, disable_exllama=False)
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto", quantization_config=gptq_config)
```

Note that only 4-bit models are supported for now.
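As a quick usage check, you can run generation with the loaded model (a minimal sketch; it assumes a tokenizer was pushed alongside the quantized checkpoint and that a GPU is available):

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("{your_username}/opt-125m-gptq")
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```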

### GPTQConfig
[[autodoc]] GPTQConfig

## `bitsandbytes` Integration

🤗 Transformers is closely integrated with the most commonly used modules of `bitsandbytes`. You can load your model in 8-bit precision with a few lines of code.
@@ -215,7 +337,7 @@ This section is intended for advanced users who want to explore what it is poss

One of the advanced use cases is being able to load a model and dispatch the weights between `CPU` and `GPU`. Note that the weights that will be dispatched to the CPU **will not** be converted to 8-bit, and are thus kept in `float32`. This feature is intended for users that want to fit a very large model and dispatch it between GPU and CPU.

First, load a `BitsAndBytesConfig` from `transformers` and set the attribute `llm_int8_enable_fp32_cpu_offload` to `True`:
First, load a [`BitsAndBytesConfig`] from `transformers` and set the attribute `llm_int8_enable_fp32_cpu_offload` to `True`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
@@ -297,10 +419,7 @@ This enables fine-tuning large models such as `flan-t5-large` or `facebook/opt-6
Note that you don't need to pass `device_map` when loading the model for training. It will automatically load your model on your GPU. You can also set the device map to a specific device if needed (e.g. `cuda:0`, `0`, `torch.device('cuda:0')`). Please note that `device_map=auto` should be used for inference only.
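As an illustration, here is a minimal sketch of loading a model in 8-bit for training without passing `device_map` (the checkpoint name is just an example):

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# The quantized model is placed on the available GPU automatically; no device_map is needed here
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
```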

### BitsAndBytesConfig

[[autodoc]] BitsAndBytesConfig


## Quantization with 🤗 `optimum`

Please have a look at [Optimum documentation](https://huggingface.co/docs/optimum/index) to learn more about quantization methods that are supported by `optimum` and see if these are applicable for your use case.
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -731,7 +731,7 @@
"logging",
],
"utils.bitsandbytes": [],
"utils.quantization_config": ["BitsAndBytesConfig"],
"utils.quantization_config": ["BitsAndBytesConfig", "GPTQConfig"],
}

# sentencepiece-backed objects
@@ -4703,7 +4703,7 @@
)

# bitsandbytes config
from .utils.quantization_config import BitsAndBytesConfig
from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig

try:
if not is_sentencepiece_available():