Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AQLM quantizer support #28928

Merged
merged 24 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docker/transformers-all-latest-gpu/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu
# Add einops for additional model testing
RUN python3 -m pip install --no-cache-dir einops

# Add aqlm for quantization testing
RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.0

# Add autoawq for quantization testing
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl

Expand Down
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.

</Tip>

## AqlmConfig

[[autodoc]] AqlmConfig

## AwqConfig

[[autodoc]] AwqConfig
Expand Down
27 changes: 27 additions & 0 deletions docs/source/en/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,33 @@ Interested in adding a new quantization method to Transformers? Read the [HfQuan

</Tip>

## AQLM



Try AQLM on [Google Colab](https://colab.research.google.com/drive/1-xZmBRXT5Fm3Ghn4Mwa2KRypORXb855X?usp=sharing)!

Additive Quantization of Language Models ([AQLM](https://arxiv.org/abs/2401.06118)) is a Large Language Models compression method. It quantizes multiple weights together and take advantage of interdependencies between them. AQLM represents groups of 8-16 weights as a sum of multiple vector codes.

Inference support for AQLM is realised in the `aqlm` library. Make sure to install it to run the models:
BlackSamorez marked this conversation as resolved.
Show resolved Hide resolved
```bash
pip install aqlm[gpu,cpu]
```

The library provides efficient kernels for both GPU and CPU inference.

The instructions on how to quantize models yourself, as well as all the relevant code can be found in the corresponding GitHub [repository](https://github.com/Vahe1994/AQLM).

### AQLM configurations

AQLM quantization setpus vary mainly on the number of codebooks used as well as codebook sizes in bits. The most popular setups, as well as inference kernels they support are:

| Number of codebooks | Codebook size, bits | Notation | Performance | Speedup | Fast GPU inference | Fast CPU inference |
|---------------------|---------------------|----------|-------------|-------------|--------------------|--------------------|
| 1 | 16 | 1x16 | Best | Up to ~1.3x | ✅ | ❌ |
| 2 | 8 | 2x8 | OK | Up to ~3.0x | ✅ | ❌ |
| K | 8 | Kx8 | Good | Up to ~4.0x | ❌ | ✅ |

## AWQ

<Tip>
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,7 +1086,7 @@
"is_vision_available",
"logging",
],
"utils.quantization_config": ["AwqConfig", "BitsAndBytesConfig", "GPTQConfig"],
"utils.quantization_config": ["AqlmConfig", "AwqConfig", "BitsAndBytesConfig", "GPTQConfig"],
}

# sentencepiece-backed objects
Expand Down Expand Up @@ -5832,7 +5832,7 @@
)

# bitsandbytes config
from .utils.quantization_config import AwqConfig, BitsAndBytesConfig, GPTQConfig
from .utils.quantization_config import AqlmConfig, AwqConfig, BitsAndBytesConfig, GPTQConfig

try:
if not is_sentencepiece_available():
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/quantizers/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

from ..models.auto.configuration_auto import AutoConfig
from ..utils.quantization_config import (
AqlmConfig,
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
QuantizationConfigMixin,
QuantizationMethod,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
Expand All @@ -33,13 +35,15 @@
"bitsandbytes_4bit": Bnb4BitHfQuantizer,
"bitsandbytes_8bit": Bnb8BitHfQuantizer,
"gptq": GptqHfQuantizer,
"aqlm": AqlmHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
"awq": AwqConfig,
"bitsandbytes_4bit": BitsAndBytesConfig,
"bitsandbytes_8bit": BitsAndBytesConfig,
"gptq": GPTQConfig,
"aqlm": AqlmConfig,
}


Expand Down
144 changes: 144 additions & 0 deletions src/transformers/quantizers/quantizer_aqlm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from .base import HfQuantizer


if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel

from ..utils import is_accelerate_available, is_aqlm_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin


if is_torch_available():
import torch
from torch import nn

logger = logging.get_logger(__name__)


class AqlmHfQuantizer(HfQuantizer):
"""
Quantizer of the AQLM method. Enables the loading of prequantized models.
"""

requires_calibration = True
required_packages = ["aqlm"]
optimum_quantizer = None

def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
self.quantization_config = quantization_config

def validate_environment(self, *args, **kwargs):
if not is_accelerate_available():
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved
raise ImportError("Using `aqlm` quantization requires Accelerate: `pip install accelerate`")

if not is_aqlm_available():
raise ImportError("Using `aqlm` quantization requires AQLM: `pip install aqlm[gpu,cpu]`")

def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
if torch.cuda.is_available():
torch_dtype = torch.float16
logger.info(
"CUDA available. Assuming AQLM inference on GPU and loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually."
)
else:
torch_dtype = torch.float32
logger.info(
"CUDA is unavailable. Assuming AQLM inference on CPU and loading the model in `torch.float32`. To overwrite it, set `torch_dtype` manually."
)
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved
return torch_dtype

def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
**kwargs,
):
_replace_with_aqlm_linear(
model,
linear_weights_not_to_quantize=self.quantization_config.linear_weights_not_to_quantize,
quantization_config=self.quantization_config,
)
model.config.quantization_config = self.quantization_config

def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
model._is_quantized_training_enabled = False
return model

@property
def is_trainable(self, model: Optional["PreTrainedModel"] = None):
return False

@property
def is_serializable(self):
return True


def _replace_with_aqlm_linear(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can move this method and make it public under integrations/aqlm.py and import locally the method inside _process_model_before_weight_loading

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

model,
linear_weights_not_to_quantize=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
):
"""
Private method that wraps the recursion for module replacement.

Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
"""
from accelerate import init_empty_weights
from aqlm import QuantizedLinear

for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)

if isinstance(module, nn.Linear):
# Check if the current key is not in the `linear_weights_not_to_quantize`
if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw in the config of the model you pushed on the Hub that you also included layer norm weights inside linear_weights_not_to_quantize , I think these can be excluded from the config as they are not an insitance of nn.Linear right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They certainly can be excluded. It's just that converting from a freshly quantized AQLM format it would be troublesome to check if an unquantized .weight parameter is of nn.Linear or not. So I simply included all of them just in case. That Mixtral config can, indeed, be made somewhat shorter.

with init_empty_weights():
in_features = module.in_features
out_features = module.out_features

model._modules[name] = QuantizedLinear(
in_features,
out_features,
bias=module.bias is not None,
in_group_size=quantization_config.in_group_size,
out_group_size=quantization_config.out_group_size,
num_codebooks=quantization_config.num_codebooks,
nbits_per_codebook=quantization_config.nbits_per_codebook,
)
has_been_replaced = True

# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_aqlm_linear(
module,
linear_weights_not_to_quantize,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from .utils import (
is_accelerate_available,
is_apex_available,
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
is_bitsandbytes_available,
Expand Down Expand Up @@ -955,6 +956,13 @@ def require_apex(test_case):
return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)


def require_aqlm(test_case):
"""
Decorator marking a test that requires aqlm
"""
return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)


def require_bitsandbytes(test_case):
"""
Decorator for bits and bytes (bnb) dependency
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
get_torch_version,
is_accelerate_available,
is_apex_available,
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
is_bitsandbytes_available,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[

_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
_aqlm_available = _is_package_available("aqlm")
_bitsandbytes_available = _is_package_available("bitsandbytes")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
Expand Down Expand Up @@ -570,6 +571,10 @@ def is_apex_available():
return _apex_available


def is_aqlm_available():
return _aqlm_available


def is_ninja_available():
r"""
Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
Expand Down
61 changes: 61 additions & 0 deletions src/transformers/utils/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class QuantizationMethod(str, Enum):
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
AWQ = "awq"
AQLM = "aqlm"


class AWQLinearVersion(str, Enum):
Expand Down Expand Up @@ -731,3 +732,63 @@ def get_loading_attributes(self):
loading_attibutes = ["do_fuse", "modules_to_fuse", "fuse_max_seq_len"]
loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
return loading_attibutes_dict


@dataclass
class AqlmConfig(QuantizationConfigMixin):
"""
This is a wrapper class about `aqlm` parameters.

Args:
in_group_size (`int`):
The group size along the input dimension.
out_group_size (`int`):
The group size along the output dimension. It's recommended to always use 1.
num_codebooks (`int`):
Number of codebooks for the Additive Quantization procedure.
nbits_per_codebook (`int`):
Number of bits encoding a single codebook vector. Codebooks size is 2**nbits_per_codebook.
linear_weights_not_to_quantize (`Optional[List[str]]`, *optional*, default to `None`):
BlackSamorez marked this conversation as resolved.
Show resolved Hide resolved
List of full paths of `nn.Linear` weight parameters that shall not be quantized.
kwargs (`Dict[str, Any]`, *optional*):
Additional parameters from which to initialize the configuration object.
"""

def __init__(
self,
in_group_size: int = 8,
out_group_size: int = 1,
num_codebooks: int = 1,
nbits_per_codebook: int = 16,
linear_weights_not_to_quantize: Optional[List[str]] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.AQLM
self.in_group_size = in_group_size
self.out_group_size = out_group_size
self.num_codebooks = num_codebooks
self.nbits_per_codebook = nbits_per_codebook
self.linear_weights_not_to_quantize = linear_weights_not_to_quantize

self.post_init()

def post_init(self):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
if not isinstance(self.in_group_size, int):
raise ValueError("in_group_size must be a float")
if not isinstance(self.out_group_size, int):
raise ValueError("out_group_size must be a float")
if not isinstance(self.num_codebooks, int):
raise ValueError("num_codebooks must be a float")
if not isinstance(self.nbits_per_codebook, int):
raise ValueError("nbits_per_codebook must be a float")

if self.linear_weights_not_to_quantize is not None and not isinstance(
self.linear_weights_not_to_quantize, list
):
raise ValueError("linear_weights_not_to_quantize must be a list of strings")

if self.linear_weights_not_to_quantize is None:
self.linear_weights_not_to_quantize = []
Empty file.
Loading