Commit
Summary: Enable loading torchao quantized models in Hugging Face. Test Plan: local test
1 parent 811a9ca, commit 55d94b7
Showing 10 changed files with 258 additions and 3 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
<!--Copyright 2024 The HuggingFace Team. All rights reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
the License. You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations under the License. | ||
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be | ||
rendered properly in your Markdown viewer. | ||
--> | ||
|
||
# TorchAO | ||
|
||
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training. | ||
|
||
Before you begin, make sure the latest versions of the following libraries are installed: | ||
|
||
```bash | ||
pip install --upgrade torch torchao | ||
``` | ||
|
||
|
||
```py | ||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer | ||
|
||
model_name = "meta-llama/Meta-Llama-3-8B" | ||
# Supported quantization types: int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight | ||
# More examples and documentation for the arguments can be found at https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques | ||
quantization_config = TorchAoConfig("int4_weight_only", group_size=128) | ||
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config) | ||
|
||
tokenizer = AutoTokenizer.from_pretrained(model_name) | ||
input_text = "What are we having for dinner?" | ||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") | ||
|
||
output = quantized_model.generate(**input_ids, max_new_tokens=10) | ||
print(tokenizer.decode(output[0], skip_special_tokens=True)) | ||
``` | ||
|
||
torchao quantization is implemented with tensor subclasses. It currently does not work with Hugging Face serialization, with either the safetensors option or the [non-safetensors option](https://github.com/huggingface/transformers/issues/32364). We'll update this page with instructions once it works. |
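As a rough illustration of what a `group_size` argument means for weight-only quantization, the sketch below quantizes a flat list of weights in independent groups, each with its own scale and offset. It is a simplified pure-Python model, not torchao's implementation: the function names `quantize_groupwise` and `dequantize_groupwise` are hypothetical, and the asymmetric min/max scheme is an assumption chosen for clarity.

```python
def quantize_groupwise(weights, group_size, bits=4):
    """Asymmetric per-group quantization of a flat list of floats (illustrative only)."""
    if len(weights) % group_size != 0:
        raise ValueError("len(weights) must be a multiple of group_size")
    qmax = 2**bits - 1  # 15 for 4-bit codes
    groups = []
    for start in range(0, len(weights), group_size):
        chunk = weights[start:start + group_size]
        lo, hi = min(chunk), max(chunk)
        # Fall back to 1.0 when the group is constant, to avoid dividing by zero.
        scale = (hi - lo) / qmax or 1.0
        codes = [min(qmax, max(0, round((w - lo) / scale))) for w in chunk]
        groups.append((scale, lo, codes))
    return groups


def dequantize_groupwise(groups):
    """Map integer codes back to approximate float values."""
    return [lo + scale * code for scale, lo, codes in groups for code in codes]


# Toy weights; a real layer would use a much larger group such as 128.
weights = [0.10, -0.32, 0.05, 0.27, 1.50, -2.00, 0.75, 0.00]
groups = quantize_groupwise(weights, group_size=4)
restored = dequantize_groupwise(groups)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(f"groups: {len(groups)}, max reconstruction error: {max_error:.4f}")
```

In this model, a smaller group size gives each scale fewer values to cover, which tends to lower the reconstruction error at the cost of storing more scales and offsets.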
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import TYPE_CHECKING | ||
|
||
from .base import HfQuantizer | ||
from .quantizers_utils import get_module_from_name | ||
|
||
|
||
if TYPE_CHECKING: | ||
from ..modeling_utils import PreTrainedModel | ||
|
||
from typing import Any, Dict, List | ||
|
||
from ..utils import is_torch_available, is_torchao_available, logging | ||
|
||
|
||
if is_torch_available(): | ||
import torch | ||
|
||
if is_torchao_available(): | ||
from torchao.quantization import ( | ||
quantize_, | ||
) | ||
|
||
logger = logging.get_logger(__name__) | ||
|
||
|
||
# Finds the parent of a node module named "name" | ||
def find_parent(model, name): | ||
module_tree = name.split(".")[:-1] | ||
parent = model | ||
for m in module_tree: | ||
parent = parent._modules[m] | ||
return parent | ||
|
||
|
||
class TorchAoHfQuantizer(HfQuantizer): | ||
""" | ||
Quantizer for torchao: https://github.com/pytorch/ao/ | ||
""" | ||
|
||
requires_parameters_quantization = True | ||
requires_calibration = False | ||
required_packages = ["torchao"] | ||
|
||
def __init__(self, quantization_config, **kwargs): | ||
super().__init__(quantization_config, **kwargs) | ||
self.torch_dtype = None | ||
|
||
def validate_environment(self, device_map, **kwargs): | ||
if not is_torchao_available(): | ||
raise ImportError("Loading a torchao quantized model requires the torchao library (`pip install torchao`)") | ||
|
||
if self.torch_dtype is None: | ||
if "torch_dtype" in kwargs: | ||
self.torch_dtype = kwargs["torch_dtype"] | ||
else: | ||
self.torch_dtype = torch.float32 | ||
logger.info("Setting torch_dtype to torch.float32 as the default value since it was not specified.") | ||
|
||
def check_quantized_param( | ||
self, | ||
model: "PreTrainedModel", | ||
param_value: "torch.Tensor", | ||
param_name: str, | ||
state_dict: Dict[str, Any], | ||
**kwargs, | ||
) -> bool: | ||
module, tensor_name = get_module_from_name(model, param_name) | ||
|
||
return isinstance(module, torch.nn.Linear) and (tensor_name == "weight") | ||
|
||
def create_quantized_param( | ||
self, | ||
model: "PreTrainedModel", | ||
param_value: "torch.Tensor", | ||
param_name: str, | ||
target_device: "torch.device", | ||
state_dict: Dict[str, Any], | ||
unexpected_keys: List[str], | ||
): | ||
""" | ||
Each nn.Linear layer is processed here. | ||
We first check whether the corresponding module state_dict already contains torchao quantized parameters. | ||
If not, we create a temporary linear layer from the module state_dict params and use it for quantization. | ||
""" | ||
module, tensor_name = get_module_from_name(model, param_name) | ||
|
||
layer_name = param_name.replace(".weight", "").replace(".bias", "") | ||
parent_module = find_parent(model, layer_name) | ||
node = layer_name.split(".")[-1] | ||
|
||
# Step 0: set module state_dict | ||
module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key} | ||
|
||
# Step 1: populate module with weight/bias from module state dict | ||
for key in module_state_dict: | ||
setattr(module, key, torch.nn.Parameter(module_state_dict[key])) | ||
|
||
# Step 2: Update the module using the `quantize_` API from TorchAO | ||
|
||
module = module.to(dtype=self.torch_dtype, device=target_device) | ||
quantize_(module, self.quantization_config.get_apply_tensor_subclass()) | ||
setattr(parent_module, node, module) | ||
|
||
def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs): | ||
"""No processing required for torchao quantized models""" | ||
return | ||
|
||
def _process_model_after_weight_loading(self, model): | ||
"""No processing required for torchao quantized models""" | ||
return | ||
|
||
@property | ||
def is_serializable(self): | ||
return False | ||
|
||
@property | ||
def is_trainable(self): | ||
# torchao does not have official support for QAT (Quantization Aware Training). | ||
# It does support nf4/PEFT, but that is not integrated yet. | ||
# TODO: if this is supported in the future, do a version check here. | ||
return False |
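The name handling in `create_quantized_param` can be traced without loading a model. The sketch below reuses the `find_parent` traversal from the diff on a hypothetical `Node` class that models only the `_modules` mapping of `torch.nn.Module`. The module tree, the parameter names, and the placeholder strings standing in for tensors are all made up for illustration.

```python
class Node:
    """Hypothetical stand-in for torch.nn.Module; only `_modules` is modelled."""

    def __init__(self, children=None):
        self._modules = dict(children or {})


def find_parent(model, name):
    # Same traversal as in the diff: follow every path component except the last.
    parent = model
    for part in name.split(".")[:-1]:
        parent = parent._modules[part]
    return parent


# Toy module tree: model.layers.0.q_proj
q_proj = Node()
block = Node({"q_proj": q_proj})
layers = Node({"0": block})
model = Node({"model": Node({"layers": layers})})

# Placeholder strings stand in for tensors.
state_dict = {
    "model.layers.0.q_proj.weight": "W",
    "model.layers.0.q_proj.bias": "b",
    "lm_head.weight": "H",
}

param_name = "model.layers.0.q_proj.weight"
layer_name = param_name.replace(".weight", "").replace(".bias", "")
parent_module = find_parent(model, layer_name)
node = layer_name.split(".")[-1]

# Step 0 from the diff: collect this layer's entries, keyed by the last path component.
module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}

print(layer_name, node, module_state_dict)
```

Note that `layer_name in key` is a substring test, so a layer whose full name is contained in another layer's name would also pick up that layer's entries. A stricter filter such as `key.startswith(layer_name + ".")` is one possible alternative; that is a suggestion, not something this commit does.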