diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 740bb4b0719c61..8e91437c38b8f8 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -161,6 +161,8 @@
       title: FBGEMM_FP8
     - local: quantization/optimum
       title: Optimum
+    - local: quantization/torchao
+      title: TorchAO
     - local: quantization/contribute
       title: Contribute new quantization method
     title: Quantization Methods
diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md
index fc5808415cbe5f..ce1cff27c9949c 100755
--- a/docs/source/en/main_classes/quantization.md
+++ b/docs/source/en/main_classes/quantization.md
@@ -61,3 +61,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
 
 [[autodoc]] FbgemmFp8Config
 
+## TorchAOConfig
+
+[[autodoc]] TorchAOConfig
+
diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md
new file mode 100644
index 00000000000000..40761c6d0b5274
--- /dev/null
+++ b/docs/source/en/quantization/torchao.md
@@ -0,0 +1,38 @@
+
+# TorchAO
+
+[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training.
+
+Before you begin, make sure the following libraries are installed with their latest versions:
+
+```bash
+pip install --upgrade torch torchao-nightly
+```
+
+```py
+import torch
+from transformers import TorchAOConfig, AutoModelForCausalLM, AutoTokenizer
+
+model_name = "meta-llama/Meta-Llama-3-8B"
+# int4_weight_only quantization currently only works with torch.bfloat16
+quantization_config = TorchAOConfig("int4_weight_only", group_size=128)
+quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+input_text = "What are we having for dinner?"
+input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
+
+output = quantized_model.generate(**input_ids, max_new_tokens=10)
+print(tokenizer.decode(output[0], skip_special_tokens=True))
+```
+
+torchao quantization is implemented with tensor subclasses. It currently does not work with Hugging Face serialization, neither the safetensors option nor the [non-safetensors option](https://github.com/huggingface/transformers/issues/32364); we will update this page with instructions once serialization is supported.
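Since quantized checkpoints cannot be serialized yet, models are simply re-quantized from the original weights at load time. The same pattern works for the other two quantization types exposed by this PR; a minimal sketch (the checkpoint name is just the one used above, not a requirement):

```py
import torch
from transformers import TorchAOConfig, AutoModelForCausalLM

# "int8_weight_only" and "int8_dynamic_activation_int8_weight" are the other quant_type strings
# supported by TorchAOHfQuantizer; any extra kwargs are forwarded to the matching torchao API.
quantization_config = TorchAOConfig("int8_weight_only")
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
```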
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 9108367f35b321..0d39bb39c1374f 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -939,6 +939,7 @@
         "GPTQConfig",
         "HqqConfig",
         "QuantoConfig",
+        "TorchAOConfig",
     ],
 }
 
@@ -5673,6 +5674,7 @@
         GPTQConfig,
         HqqConfig,
         QuantoConfig,
+        TorchAOConfig,
     )
 
 try:
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index 40aa86fc37c733..895af5c3fa742e 100755
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -26,6 +26,7 @@
     QuantizationConfigMixin,
     QuantizationMethod,
     QuantoConfig,
+    TorchAOConfig,
 )
 from .quantizer_aqlm import AqlmHfQuantizer
 from .quantizer_awq import AwqQuantizer
@@ -36,6 +37,7 @@
 from .quantizer_gptq import GptqHfQuantizer
 from .quantizer_hqq import HqqHfQuantizer
 from .quantizer_quanto import QuantoHfQuantizer
+from .quantizer_torchao import TorchAOHfQuantizer
 
 
 AUTO_QUANTIZER_MAPPING = {
@@ -48,6 +50,7 @@
     "eetq": EetqHfQuantizer,
     "hqq": HqqHfQuantizer,
     "fbgemm_fp8": FbgemmFp8HfQuantizer,
+    "torchao": TorchAOHfQuantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -60,6 +63,7 @@
     "quanto": QuantoConfig,
     "hqq": HqqConfig,
     "fbgemm_fp8": FbgemmFp8Config,
+    "torchao": TorchAOConfig,
 }
 
diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py
new file mode 100644
index 00000000000000..e139e32b821bda
--- /dev/null
+++ b/src/transformers/quantizers/quantizer_torchao.py
@@ -0,0 +1,157 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from .base import HfQuantizer
+from .quantizers_utils import get_module_from_name
+
+
+if TYPE_CHECKING:
+    from ..modeling_utils import PreTrainedModel
+
+from typing import Any, Dict, List
+
+from ..utils import is_torch_available, is_torchao_available, logging
+
+
+if is_torch_available():
+    import torch
+
+if is_torchao_available():
+    from torchao.quantization import (
+        int4_weight_only,
+        int8_dynamic_activation_int8_weight,
+        int8_weight_only,
+        quantize_,
+    )
+
+logger = logging.get_logger(__name__)
+
+
+# Finds the parent of a module named `name`, e.g. for "decoder.layers.0.fc1" it returns the module holding "fc1"
+def find_parent(model, name):
+    module_tree = name.split(".")[:-1]
+    parent = model
+    for m in module_tree:
+        parent = parent._modules[m]
+    return parent
+
+
+class TorchAOHfQuantizer(HfQuantizer):
+    """
+    Quantizer for torchao: https://github.com/pytorch/ao/
+    """
+
+    requires_parameters_quantization = True
+    requires_calibration = False
+    required_packages = ["torchao"]
+
+    def __init__(self, quantization_config, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+        self.torch_dtype = None
+
+    def validate_environment(self, device_map, **kwargs):
+        if not is_torchao_available():
+            raise ImportError("Loading a torchao quantized model requires the torchao library (`pip install torchao`)")
+
+        if self.torch_dtype is None:
+            if "torch_dtype" in kwargs:
+                self.torch_dtype = kwargs["torch_dtype"]
+            else:
+                self.torch_dtype = torch.float32
+                logger.info("Setting torch_dtype to torch.float32 as the default value since it was not specified.")
+
+    def update_torch_dtype(self, torch_dtype):
+        if torch_dtype is None:
+            torch_dtype = torch.float32
+        return torch_dtype
+
+    def check_quantized_param(
+        self,
+        model: "PreTrainedModel",
+        param_value: "torch.Tensor",
+        param_name: str,
+        state_dict: Dict[str, Any],
+        **kwargs,
+    ) -> bool:
+        module, tensor_name = get_module_from_name(model, param_name)
+
+        return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
+
+    def create_quantized_param(
+        self,
+        model: "PreTrainedModel",
+        param_value: "torch.Tensor",
+        param_name: str,
+        target_device: "torch.device",
+        state_dict: Dict[str, Any],
+        unexpected_keys: List[str],
+    ):
+        """
+        Each nn.Linear layer is processed here.
+        We first check whether the corresponding module state_dict already contains torchao quantized parameters.
+        If not, we populate the linear layer with the module state_dict params and quantize it in place.
+        """
+        module, tensor_name = get_module_from_name(model, param_name)
+
+        layer_name = param_name.replace(".weight", "").replace(".bias", "")
+        parent_module = find_parent(model, layer_name)
+        node = layer_name.split(".")[-1]
+
+        # Step 0: collect the state_dict entries that belong to this layer
+        module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}
+
+        # Step 1: populate the module with weight/bias from the module state dict
+        for key in module_state_dict:
+            setattr(module, key, torch.nn.Parameter(module_state_dict[key]))
+
+        _STR_TO_METHOD = {
+            "int4_weight_only": int4_weight_only,
+            "int8_weight_only": int8_weight_only,
+            "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
+        }
+
+        # Step 2: move the module to the target device/dtype, quantize it in place with torchao, and re-attach it.
+        # We do this via setattr on the parent module, as doing it on the module directly doesn't work.
+        module = module.to(dtype=self.torch_dtype, device=target_device)
+        setattr(parent_module, node, module)
+
+        if self.quantization_config is not None:
+            assert self.quantization_config.quant_type in _STR_TO_METHOD, (
+                f"Requested quantization type: {self.quantization_config.quant_type} is not supported yet, "
+                "please add support in TorchAOHfQuantizer."
+            )
+            quantize_(module, _STR_TO_METHOD[self.quantization_config.quant_type](**self.quantization_config.kwargs))
+            setattr(parent_module, node, module)
+
+        torch.cuda.empty_cache()
+
+    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
+        """No processing required for a torchao quantized model."""
+        return
+
+    def _process_model_after_weight_loading(self, model):
+        """No processing required for a torchao quantized model."""
+        return
+
+    @property
+    def is_serializable(self):
+        return False
+
+    @property
+    def is_trainable(self):
+        # torchao does not have official support for QAT (Quantization Aware Training) or PEFT yet.
+        # TODO: if this is supported in the future, do a version check here.
+        return False
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index efe473a6cdeda2..c98837594a3e1d 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -208,6 +208,7 @@
     is_torch_tpu_available,
     is_torch_xla_available,
     is_torch_xpu_available,
+    is_torchao_available,
     is_torchaudio_available,
     is_torchdistx_available,
     is_torchdynamo_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 168d8b5d9c98af..b631463978ce3e 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -171,6 +171,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _timm_available = _is_package_available("timm")
 _tokenizers_available = _is_package_available("tokenizers")
 _torchaudio_available = _is_package_available("torchaudio")
+_torchao_available = _is_package_available("torchao")
 _torchdistx_available = _is_package_available("torchdistx")
 _torchvision_available = _is_package_available("torchvision")
 _mlx_available = _is_package_available("mlx")
@@ -1044,6 +1045,9 @@ def is_nltk_available():
 def is_torchaudio_available():
     return _torchaudio_available
 
 
+def is_torchao_available():
+    return _torchao_available
+
+
 def is_speech_available():
     # For now this depends on torchaudio but the exact dependency might evolve in the future.
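As a sanity check on the per-layer flow in `create_quantized_param` above, the following standalone sketch mirrors it on a toy module: resolve the parent with `find_parent`, populate the linear layer from a state dict entry, move it to the target device/dtype, and quantize it in place with torchao. The toy module names and the CUDA device are illustrative assumptions, not part of the PR.

```py
import torch
from torchao.quantization import int8_weight_only, quantize_

from transformers.quantizers.quantizer_torchao import find_parent  # helper added in this PR

# Toy model with the dotted naming that find_parent walks
model = torch.nn.Module()
model.decoder = torch.nn.Module()
model.decoder.proj = torch.nn.Linear(64, 64)

state_dict = {"decoder.proj.weight": torch.randn(64, 64), "decoder.proj.bias": torch.randn(64)}

param_name = "decoder.proj.weight"
layer_name = param_name.replace(".weight", "")    # "decoder.proj"
parent_module = find_parent(model, layer_name)    # -> model.decoder
node = layer_name.split(".")[-1]                  # "proj"
module = model.decoder.proj

# Populate, move, quantize in place, re-attach (mirrors Steps 1-2 above)
module.weight = torch.nn.Parameter(state_dict["decoder.proj.weight"])
module.bias = torch.nn.Parameter(state_dict["decoder.proj.bias"])
module = module.to(dtype=torch.bfloat16, device="cuda")
quantize_(module, int8_weight_only())
setattr(parent_module, node, module)
```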
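The wiring in `auto.py` and `import_utils.py` is what makes the string `"torchao"` resolve to the new classes. End users never call this machinery directly, but roughly, the dispatch looks like the sketch below (illustrative, assuming torchao is installed so `is_torchao_available()` returns True):

```py
from transformers import TorchAOConfig
from transformers.quantizers.auto import AutoHfQuantizer
from transformers.utils import is_torchao_available

assert is_torchao_available()  # availability check added in import_utils.py

config = TorchAOConfig("int8_weight_only")
# config.quant_method == "torchao", so AUTO_QUANTIZER_MAPPING resolves to the new quantizer
quantizer = AutoHfQuantizer.from_config(config)
print(type(quantizer).__name__)  # TorchAOHfQuantizer
```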
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index 5de8307c3bd79b..e64dfcd268b74c 100755
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -43,6 +43,7 @@ class QuantizationMethod(str, Enum):
     EETQ = "eetq"
     HQQ = "hqq"
     FBGEMM_FP8 = "fbgemm_fp8"
+    TORCHAO = "torchao"
 
 
 class AWQLinearVersion(str, Enum):
@@ -1079,3 +1080,34 @@ def get_loading_attributes(self):
         loading_attibutes = ["activation_scale_ub"]
         loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
         return loading_attibutes_dict
+
+
+@dataclass
+class TorchAOConfig(QuantizationConfigMixin):
+    """This is a config class for torchao quantization/sparsity techniques.
+
+    Currently exposing 3 APIs:
+        - int4_weight_only
+        - int8_weight_only
+        - int8_dynamic_activation_int8_weight
+
+    More API examples can be found in https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
+
+    Example::
+
+        quantization_config = TorchAOConfig("int4_weight_only", group_size=32)
+        # int4_weight_only quantization only works with `torch.bfloat16` dtype right now
+        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+    """
+
+    def __init__(
+        self,
+        quant_type: str,
+        **kwargs,
+    ):
+        self.quant_method = QuantizationMethod.TORCHAO
+        self.quant_type = quant_type
+        self.kwargs = kwargs
+
+    def __repr__(self):
+        return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.kwargs.items())})"
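`TorchAOConfig` is deliberately thin: `quant_type` selects the torchao API out of `_STR_TO_METHOD` in the quantizer, and all remaining kwargs are forwarded to that API untouched. A small illustration of the resulting mapping (the `group_size` value is just an example):

```py
from transformers import TorchAOConfig

config = TorchAOConfig("int4_weight_only", group_size=32)

print(config.quant_type)  # int4_weight_only
print(config.kwargs)      # {'group_size': 32}
print(config)             # int4_weight_only(group_size=32), via the custom __repr__
# Inside the quantizer this becomes: quantize_(module, int4_weight_only(group_size=32))
```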