Add TorchAOHfQuantizer
Summary:
Enable loading torchao quantized models in Hugging Face Transformers.

Test Plan:
local test

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Aug 1, 2024
1 parent 811a9ca commit a2b7a42
Showing 9 changed files with 244 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -161,6 +161,8 @@
      title: FBGEMM_FP8
    - local: quantization/optimum
      title: Optimum
    - local: quantization/torchao
      title: TorchAO
    - local: quantization/contribute
      title: Contribute new quantization method
  title: Quantization Methods
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/quantization.md
@@ -61,3 +61,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.

[[autodoc]] FbgemmFp8Config

## TorchAOConfig

[[autodoc]] TorchAOConfig

38 changes: 38 additions & 0 deletions docs/source/en/quantization/torchao.md
@@ -0,0 +1,38 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# TorchAO

[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training.

Before you begin, make sure the following libraries are installed at their latest versions:

```bash
pip install --upgrade torch torchao-nightly
```
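
The quickstart below loads Meta-Llama-3-8B with on-the-fly int4 weight-only quantization (group size 128) and runs generation: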


```py
import torch
from transformers import TorchAOConfig, AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = TorchAOConfig("int4_weight_only", group_size=128)
# int4_weight_only currently requires bfloat16 weights
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
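
Note that `int4_weight_only` quantization currently only works with the `torch.bfloat16` dtype, hence the `torch_dtype=torch.bfloat16` in the example above. Besides `"int4_weight_only"`, `TorchAOConfig` also accepts `"int8_weight_only"` and `"int8_dynamic_activation_int8_weight"` as quantization types; a minimal sketch:

```py
from transformers import TorchAOConfig

# int8 weight-only quantization; no extra kwargs are required here
quantization_config = TorchAOConfig("int8_weight_only")
```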

torchao quantization is implemented with tensor subclasses. It currently does not work with Hugging Face serialization, in either the safetensors format or the [non-safetensors format](https://github.com/huggingface/transformers/issues/32364); we'll update this page with instructions once serialization works.
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -939,6 +939,7 @@
        "GPTQConfig",
        "HqqConfig",
        "QuantoConfig",
        "TorchAOConfig",
    ],
}

@@ -5673,6 +5674,7 @@
        GPTQConfig,
        HqqConfig,
        QuantoConfig,
        TorchAOConfig,
    )

    try:
4 changes: 4 additions & 0 deletions src/transformers/quantizers/auto.py
@@ -26,6 +26,7 @@
    QuantizationConfigMixin,
    QuantizationMethod,
    QuantoConfig,
    TorchAOConfig,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
@@ -36,6 +37,7 @@
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_torchao import TorchAOHfQuantizer


AUTO_QUANTIZER_MAPPING = {
@@ -48,6 +50,7 @@
    "eetq": EetqHfQuantizer,
    "hqq": HqqHfQuantizer,
    "fbgemm_fp8": FbgemmFp8HfQuantizer,
    "torchao": TorchAOHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -60,6 +63,7 @@
    "quanto": QuantoConfig,
    "hqq": HqqConfig,
    "fbgemm_fp8": FbgemmFp8Config,
    "torchao": TorchAOConfig,
}


157 changes: 157 additions & 0 deletions src/transformers/quantizers/quantizer_torchao.py
@@ -0,0 +1,157 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List

from ..utils import is_torch_available, is_torchao_available, logging
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

if is_torch_available():
    import torch

if is_torchao_available():
    from torchao.quantization import (
        int4_weight_only,
        int8_dynamic_activation_int8_weight,
        int8_weight_only,
        quantize_,
    )

logger = logging.get_logger(__name__)


# Finds the parent module of the module named `name`
def find_parent(model, name):
    module_tree = name.split(".")[:-1]
    parent = model
    for m in module_tree:
        parent = parent._modules[m]
    return parent
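
# e.g. for name = "model.layers.0.self_attn.q_proj.weight" (an illustrative,
# model-specific path), find_parent returns the model.layers.0.self_attn module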


class TorchAOHfQuantizer(HfQuantizer):
    """
    Quantizer for torchao: https://github.com/pytorch/ao/
    """

    requires_parameters_quantization = True
    requires_calibration = False
    required_packages = ["torchao"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.torch_dtype = None

    def validate_environment(self, device_map, **kwargs):
        if not is_torchao_available():
            raise ImportError("Loading a torchao quantized model requires the torchao library (`pip install torchao`)")

        if self.torch_dtype is None:
            if "torch_dtype" in kwargs:
                self.torch_dtype = kwargs["torch_dtype"]
            else:
                self.torch_dtype = torch.float32
                logger.info("Setting torch_dtype to torch.float32 as the default value since it was not specified.")

    def update_torch_dtype(self, torch_dtype):
        if torch_dtype is None:
            torch_dtype = torch.float32
        return torch_dtype

    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        module, tensor_name = get_module_from_name(model, param_name)
        # only the weights of nn.Linear modules are quantized
        return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: List[str],
    ):
        """
        Each nn.Linear layer is processed here.
        We populate the module with the weight/bias entries from the model state_dict, move it to the
        target device, and quantize it in place with the configured torchao quantization method.
        """
        module, tensor_name = get_module_from_name(model, param_name)

        layer_name = param_name.replace(".weight", "").replace(".bias", "")
        parent_module = find_parent(model, layer_name)
        node = layer_name.split(".")[-1]

        # Step 0: collect this module's entries from the model state_dict
        module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}

        # Step 1: populate the module with weight/bias from the module state dict
        for key in module_state_dict:
            setattr(module, key, torch.nn.Parameter(module_state_dict[key]))

        # Step 2: move the module to the target device and quantize it with torchao.
        # We re-attach it via setattr on the parent module, as doing it on the module
        # directly doesn't work.

        _STR_TO_METHOD = {
            "int4_weight_only": int4_weight_only,
            "int8_weight_only": int8_weight_only,
            "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
        }

        module = module.to(dtype=self.torch_dtype, device=target_device)
        setattr(parent_module, node, module)

        if self.quantization_config is not None:
            assert self.quantization_config.quant_type in _STR_TO_METHOD, (
                f"Requested quantization type: {self.quantization_config.quant_type} is not supported yet, "
                "please add support in TorchAOHfQuantizer."
            )
            # quantize_ swaps the module's weight for a torchao tensor subclass in place
            quantize_(module, _STR_TO_METHOD[self.quantization_config.quant_type](**self.quantization_config.kwargs))
            setattr(parent_module, node, module)

        torch.cuda.empty_cache()

    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
        """No processing required before weight loading for a torchao quantized model"""
        return

    def _process_model_after_weight_loading(self, model):
        """No processing required after weight loading for a torchao quantized model"""
        return

    @property
    def is_serializable(self):
        return False

    @property
    def is_trainable(self):
        # torchao does not have official support for QAT (Quantization Aware Training) or PEFT yet.
        # TODO: if this is supported in the future, do a version check here.
        return False
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -208,6 +208,7 @@
    is_torch_tpu_available,
    is_torch_xla_available,
    is_torch_xpu_available,
    is_torchao_available,
    is_torchaudio_available,
    is_torchdistx_available,
    is_torchdynamo_available,
4 changes: 4 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -171,6 +171,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_timm_available = _is_package_available("timm")
_tokenizers_available = _is_package_available("tokenizers")
_torchaudio_available = _is_package_available("torchaudio")
_torchao_available = _is_package_available("torchao")
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
_mlx_available = _is_package_available("mlx")
@@ -1044,6 +1045,9 @@ def is_nltk_available():
def is_torchaudio_available():
    return _torchaudio_available


def is_torchao_available():
    return _torchao_available


def is_speech_available():
    # For now this depends on torchaudio but the exact dependency might evolve in the future.
32 changes: 32 additions & 0 deletions src/transformers/utils/quantization_config.py
@@ -43,6 +43,7 @@ class QuantizationMethod(str, Enum):
    EETQ = "eetq"
    HQQ = "hqq"
    FBGEMM_FP8 = "fbgemm_fp8"
    TORCHAO = "torchao"


class AWQLinearVersion(str, Enum):
@@ -1079,3 +1080,34 @@ def get_loading_attributes(self):
        loading_attibutes = ["activation_scale_ub"]
        loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
        return loading_attibutes_dict

@dataclass
class TorchAOConfig(QuantizationConfigMixin):
    """This is a config class for torchao quantization/sparsity techniques.

    Currently exposes three APIs:
        - int4_weight_only
        - int8_weight_only
        - int8_dynamic_activation_int8_weight

    More API examples can be found at
    https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques

    Example::

        quantization_config = TorchAOConfig("int4_weight_only", group_size=32)
        # int4_weight_only quant only works with the `torch.bfloat16` dtype right now
        model = AutoModelForCausalLM.from_pretrained(
            model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config
        )
    """

    def __init__(
        self,
        quant_type: str,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.TORCHAO
        self.quant_type = quant_type
        self.kwargs = kwargs

    def __repr__(self):
        return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.kwargs.items())})"
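
# Illustrative: with the __repr__ above, TorchAOConfig("int4_weight_only", group_size=32)
# prints as: int4_weight_only(group_size=32)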
