Add TorchAOHfQuantizer
Summary:
Enable loading torchao quantized models in Hugging Face Transformers.

Test Plan:
local test

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Aug 2, 2024
1 parent 811a9ca commit 55d94b7
Showing 10 changed files with 258 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -161,6 +161,8 @@
title: FBGEMM_FP8
- local: quantization/optimum
title: Optimum
- local: quantization/torchao
title: TorchAO
- local: quantization/contribute
title: Contribute new quantization method
title: Quantization Methods
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/quantization.md
@@ -61,3 +61,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.

[[autodoc]] FbgemmFp8Config

## TorchAoConfig

[[autodoc]] TorchAoConfig

2 changes: 1 addition & 1 deletion docs/source/en/quantization/overview.md
@@ -56,4 +56,4 @@ Use the table below to help you decide which quantization method to use.
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
| [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto |
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | partial support | | 4 / 8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
40 changes: 40 additions & 0 deletions docs/source/en/quantization/torchao.md
@@ -0,0 +1,40 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# TorchAO

[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training.

Before you begin, make sure the following libraries are installed with their latest versions:

```bash
pip install --upgrade torch torchao
```


```py
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"
# We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight
# More examples and documentation for the arguments can be found at https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config)

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

torchao quantization is implemented with tensor subclasses. It currently does not work with Hugging Face serialization, with either the safetensors option or the [non-safetensors option](https://github.com/huggingface/transformers/issues/32364); we'll update this page with instructions once serialization works.
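
Because the quantized weights are tensor subclasses, one quick way to confirm that a layer was actually converted is to inspect its weight; a minimal sketch, assuming the standard Llama module layout of the example model above:

```py
# Hypothetical check on the model quantized in the example above; the module
# path assumes the standard Llama layout of meta-llama/Meta-Llama-3-8B.
layer = quantized_model.model.layers[0].self_attn.q_proj
print(type(layer.weight))  # a torchao tensor subclass, not a plain torch.Tensor
```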
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -939,6 +939,7 @@
"GPTQConfig",
"HqqConfig",
"QuantoConfig",
"TorchAoConfig",
],
}

@@ -5673,6 +5674,7 @@
GPTQConfig,
HqqConfig,
QuantoConfig,
TorchAoConfig,
)

try:
4 changes: 4 additions & 0 deletions src/transformers/quantizers/auto.py
@@ -26,6 +26,7 @@
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
TorchAoConfig,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
@@ -36,6 +37,7 @@
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer


AUTO_QUANTIZER_MAPPING = {
@@ -48,6 +50,7 @@
"eetq": EetqHfQuantizer,
"hqq": HqqHfQuantizer,
"fbgemm_fp8": FbgemmFp8HfQuantizer,
"torchao": TorchAoHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -60,6 +63,7 @@
"quanto": QuantoConfig,
"hqq": HqqConfig,
"fbgemm_fp8": FbgemmFp8Config,
"torchao": TorchAoConfig,
}


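For context, a minimal sketch of how these two registries resolve the new `"torchao"` key; the driver code below is illustrative, not part of this diff:

```py
from transformers.quantizers.auto import (
    AUTO_QUANTIZATION_CONFIG_MAPPING,
    AUTO_QUANTIZER_MAPPING,
)

# The "torchao" key in a model's quantization config selects both the
# quantizer class and its matching config class.
quantizer_cls = AUTO_QUANTIZER_MAPPING["torchao"]         # TorchAoHfQuantizer
config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING["torchao"]  # TorchAoConfig

config = config_cls("int8_weight_only")
quantizer = quantizer_cls(config)  # instantiated with the parsed config
```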
134 changes: 134 additions & 0 deletions src/transformers/quantizers/quantizer_torchao.py
@@ -0,0 +1,134 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from .base import HfQuantizer
from .quantizers_utils import get_module_from_name


if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel

from typing import Any, Dict, List

from ..utils import is_torch_available, is_torchao_available, logging


if is_torch_available():
import torch

if is_torchao_available():
from torchao.quantization import (
quantize_,
)

logger = logging.get_logger(__name__)


# Finds the parent module of the submodule referenced by the dotted name `name`
def find_parent(model, name):
module_tree = name.split(".")[:-1]
parent = model
for m in module_tree:
parent = parent._modules[m]
return parent


class TorchAoHfQuantizer(HfQuantizer):
"""
Quantizer for torchao: https://github.com/pytorch/ao/
"""

requires_parameters_quantization = True
requires_calibration = False
required_packages = ["torchao"]

def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.torch_dtype = None

def validate_environment(self, device_map, **kwargs):
if not is_torchao_available():
raise ImportError("Loading a torchao quantized model requires the torchao library (`pip install torchao`)")

if self.torch_dtype is None:
if "torch_dtype" in kwargs:
self.torch_dtype = kwargs["torch_dtype"]
else:
self.torch_dtype = torch.float32
logger.info("Setting torch_dtype to torch.float32 as the default value since it was not specified.")

def check_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
module, tensor_name = get_module_from_name(model, param_name)

return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")

def create_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: List[str],
):
"""
Each nn.Linear layer is processed here.
We first check if the corresponding module state_dict already contains torchao quantized parameters.
If not, we create a temporary linear layer with the module state_dict params and use it for quantization.
"""
module, tensor_name = get_module_from_name(model, param_name)

layer_name = param_name.replace(".weight", "").replace(".bias", "")
parent_module = find_parent(model, layer_name)
node = layer_name.split(".")[-1]

# Step 0: set module state_dict
module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}

# Step 1: populate module with weight/bias from module state dict
for key in module_state_dict:
setattr(module, key, torch.nn.Parameter(module_state_dict[key]))

# Step 2: Update the module using the `quantize_` API from TorchAO

module = module.to(dtype=self.torch_dtype, device=target_device)
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
setattr(parent_module, node, module)

def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
"""No process required for torchao quantized model"""
return

def _process_model_after_weight_loading(self, model):
"""No process required for torchao quantized model"""
return

@property
def is_serializable(self):
return False

@property
def is_trainable(self):
# torchao does not have official support for QAT (Quantization Aware Training) yet
# torchao supports NF4/PEFT, but that support is not integrated here yet
# TODO: if this is supported in the future, do a version check here.
return False
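
To make the per-layer flow in `create_quantized_param` concrete, here is a small self-contained sketch of the same mechanism applied to a toy module, using the torchao `quantize_` API the quantizer calls in Step 2 (illustrative code, not part of this commit):

```py
import torch
from torchao.quantization import int8_weight_only, quantize_

# A toy model standing in for a PreTrainedModel.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())

# Mirrors the quantizer's find_parent + setattr flow: fetch the target linear
# layer from its parent, quantize it in place, and reattach it.
parent, node = model, "0"                  # find_parent() walks dotted names; trivial here
module = getattr(parent, node)
module = module.to(dtype=torch.bfloat16)   # analogous to self.torch_dtype
quantize_(module, int8_weight_only())      # swaps the weight for a tensor subclass
setattr(parent, node, module)

print(type(model[0].weight))  # a torchao quantized tensor subclass
```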
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -208,6 +208,7 @@
is_torch_tpu_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchao_available,
is_torchaudio_available,
is_torchdistx_available,
is_torchdynamo_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -171,6 +171,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_timm_available = _is_package_available("timm")
_tokenizers_available = _is_package_available("tokenizers")
_torchaudio_available = _is_package_available("torchaudio")
_torchao_available = _is_package_available("torchao")
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
_mlx_available = _is_package_available("mlx")
@@ -1045,6 +1046,10 @@ def is_torchaudio_available():
return _torchaudio_available


def is_torchao_available():
return _torchao_available


def is_speech_available():
# For now this depends on torchaudio but the exact dependency might evolve in the future.
return _torchaudio_available
67 changes: 65 additions & 2 deletions src/transformers/utils/quantization_config.py
@@ -24,13 +24,12 @@

from packaging import version

from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, logging
from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, is_torchao_available, logging


if is_torch_available():
import torch


logger = logging.get_logger(__name__)


@@ -43,6 +42,7 @@ class QuantizationMethod(str, Enum):
EETQ = "eetq"
HQQ = "hqq"
FBGEMM_FP8 = "fbgemm_fp8"
TORCHAO = "torchao"


class AWQLinearVersion(str, Enum):
@@ -1079,3 +1079,66 @@ def get_loading_attributes(self):
loading_attibutes = ["activation_scale_ub"]
loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
return loading_attibutes_dict


@dataclass
class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques
Args:
quant_type (str): the type of quantization we want to use, currently supporting: int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight
kwargs: the keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments:
`group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in
https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
Example:
```python
quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
# int4_weight_only quant only works with the *torch.bfloat16* dtype right now
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
```
"""

def __init__(
self,
quant_type: str,
**kwargs,
):
self.quant_method = QuantizationMethod.TORCHAO
self.quant_type = quant_type
self.kwargs = kwargs
self._STR_TO_METHOD = {}
if is_torchao_available():
from torchao.quantization import (
int4_weight_only,
int8_dynamic_activation_int8_weight,
int8_weight_only,
)

self._STR_TO_METHOD = {
"int4_weight_only": int4_weight_only,
"int8_weight_only": int8_weight_only,
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
}
else:
raise ValueError(
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
)

self.post_init()

def post_init(self):
r"""
Safety checker that the configured `quant_type` is supported.
"""
if self.quant_type not in self._STR_TO_METHOD.keys():
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported yet, please add support in TorchAoConfig and TorchAoHfQuantizer."
)

def get_apply_tensor_subclass(self):
return self._STR_TO_METHOD[self.quant_type](**self.kwargs)

def __repr__(self):
return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.kwargs.items())})"
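
A short usage sketch for the config class above (assuming torchao is installed):

```py
from transformers import TorchAoConfig

# Keyword arguments are stored and later forwarded to the selected torchao
# method by get_apply_tensor_subclass(), i.e. int4_weight_only(group_size=128).
config = TorchAoConfig("int4_weight_only", group_size=128)
print(config)  # int4_weight_only(group_size=128), per __repr__ above

# The returned callable is what TorchAoHfQuantizer passes to torchao's quantize_.
apply_fn = config.get_apply_tensor_subclass()
```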
