diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/core.py b/neural_compressor/torch/algorithms/weight_only/hqq/core.py
index f13e2410432..041e173671d 100644
--- a/neural_compressor/torch/algorithms/weight_only/hqq/core.py
+++ b/neural_compressor/torch/algorithms/weight_only/hqq/core.py
@@ -19,7 +19,7 @@
 # NOTICE: the original `Quantizer` has been modified to `HQQTensorHandle`
 # and `QTensor` to decouple the data structure and the quantization logic.
 
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Mapping, Tuple
 
 import torch
 
@@ -278,3 +278,61 @@ def from_float(
         # !!! Delete the float explicitly to save memory
         del float_module
         return new_mod
+
+    def state_dict(self, *args, **kwargs):  # nn.Module override compatible
+        state_dict = self.q_weight.to_state_dict()
+        if self.bias is not None:
+            state_dict["bias"] = self.bias
+        if "destination" in kwargs and "prefix" in kwargs:
+            for key, value in state_dict.items():
+                kwargs["destination"][kwargs["prefix"] + key] = value
+        return state_dict
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        all_expected_keys = ["val", "scale_quantized", "zero_quantized", "meta_info"]
+        if self.bias is not None:
+            all_expected_keys.append("bias")
+
+        for key in all_expected_keys:
+            if prefix + key not in state_dict:
+                missing_keys.append(key)
+        if missing_keys:
+            return  # Can't load weights if either weight or meta is missing
+
+        cur_state_dict = {}
+        for key in all_expected_keys:
+            cur_state_dict[key] = state_dict.pop(prefix + key)
+
+        unexpected_keys += state_dict.keys()
+        self._assign_state_dict(cur_state_dict, strict)
+
+    def _assign_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
+        _scale_quantized = state_dict["scale_quantized"]
+        _zero_quantized = state_dict["zero_quantized"]
+        scale_state = state_dict["meta_info"]["scale"]
+        zero_state = state_dict["meta_info"]["zero"]
+        if _scale_quantized:
+            scale = HQQTensorHandle._create_q_tensor(scale_state["val"], scale_state["meta_info"])
+        else:
+            scale = state_dict["meta_info"]["scale"]
+        if _zero_quantized:
+            zero = HQQTensorHandle._create_q_tensor(zero_state["val"], zero_state["meta_info"])
+        else:
+            zero = state_dict["meta_info"]["zero"]
+        meta = state_dict["meta_info"]
+        meta["scale"] = scale
+        meta["zero"] = zero
+        self.q_weight = HQQTensorHandle._create_q_tensor(state_dict["val"], meta)
+        if self.bias is not None:
+            self.bias = state_dict["bias"]
+        self.quantized = True
+        return self
diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py b/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py
index 950fa2243dc..f1fbd5bce3a 100644
--- a/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py
+++ b/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py
@@ -115,3 +115,19 @@ def half(self):
         if self.zero is not None:
             self.zero = self.zero.half()
         return self
+
+    def to_state_dict(self):
+        state = {}
+        state["val"] = self.val
+        state["meta_info"] = self.meta_info.to_dict()
+        state["scale_quantized"] = self.is_scale_quantized()
+        state["zero_quantized"] = self.is_zero_quantized()
+        if self.is_scale_quantized():
+            state["meta_info"]["scale"] = self.scale.to_state_dict()
+        else:
+            state["meta_info"]["scale"] = self.scale
+        if self.is_zero_quantized():
+            state["meta_info"]["zero"] = self.zero.to_state_dict()
+        else:
+            state["meta_info"]["zero"] = self.zero
+        return state
diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py
index 231502c32c6..4a6e6a0d488 100644
--- a/neural_compressor/torch/algorithms/weight_only/save_load.py
+++ b/neural_compressor/torch/algorithms/weight_only/save_load.py
@@ -124,7 +124,6 @@ def load_inc_format_woq_model(self, qmodel_weight_file_path, qconfig_file_path):
         with open(qconfig_file_path, "r") as file:
             self.quantization_config = json.load(file)
 
-        model = self._build_woq_model()
         model.load_state_dict(qweights, assign=True)
         model.eval()
 
@@ -157,8 +156,19 @@ def load_hf_format_woq_model(self):
 
         return model
 
+    def _is_hqq_model(self):
+        for name, module in self.original_model.named_modules():
+            pattern = rf"(\(.*{re.escape(name)}.*{re.escape(type(module).__name__)}.*\))"
+            for q_config_key, q_config_value in self.quantization_config.items():
+                if re.search(pattern, q_config_key):
+                    if isinstance(q_config_value, dict) and [algo for algo in q_config_value.keys()][0] == "hqq":
+                        return True
+
     def _build_woq_model(self):
         """Build weight-only quantization model."""
+        if self._is_hqq_model():
+            return self._build_hqq_model()
+
         from neural_compressor.torch.utils import set_module
 
         from .modules import MulLinear
@@ -228,6 +238,23 @@
         woq_model = self.original_model
         return woq_model
 
+    def _build_hqq_model(self):
+        """Replace quantized Linear with HQQLinear."""
+        from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear
+        from neural_compressor.torch.utils import set_module
+
+        for name, module in self.original_model.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                loaded_state_dict_keys_set = set(self.loaded_state_dict_keys)
+                if name + ".val" not in loaded_state_dict_keys_set:
+                    continue
+                new_module = HQQLinear(
+                    in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None
+                )
+                set_module(self.original_model, name, new_module)
+        woq_model = self.original_model
+        return woq_model
+
     def _get_model_class_and_config(self):
         from transformers import AutoConfig, AutoModelForCausalLM
         from transformers.dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index 856961af532..b8a1e3b9202 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -517,11 +517,14 @@ def hqq_entry(
     **kwargs,
 ) -> torch.nn.Module:
     from neural_compressor.torch.algorithms.weight_only.hqq import HQQuantizer
+    from neural_compressor.torch.algorithms.weight_only.save_load import save
 
     logger.info("Quantize model with the HQQ algorithm.")
 
     quantizer = get_quantizer(model, quantizer_cls=HQQuantizer, quant_config=configs_mapping)
     model = quantizer.execute(model, mode=mode)
+    model.qconfig = configs_mapping
+    model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
     dump_model_op_stats(mode, configs_mapping)
 
diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py
index d20f828659d..584b853014f 100644
--- a/neural_compressor/torch/quantization/load_entry.py
+++ b/neural_compressor/torch/quantization/load_entry.py
@@ -22,6 +22,7 @@
     AWQConfig,
     FP8Config,
     GPTQConfig,
+    HQQConfig,
     RTNConfig,
     TEQConfig,
 )
@@ -89,7 +90,9 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
 
     # select load function
     config_object = config_mapping[next(iter(config_mapping))]
-    if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)):  # WOQ
+    if isinstance(
+        config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig, HQQConfig)
+    ):  # WOQ
         from neural_compressor.torch.algorithms import weight_only
 
         return weight_only.load(model_name_or_path, original_model, format=LoadFormat.DEFAULT)
diff --git a/test/3x/torch/quantization/weight_only/test_hqq.py b/test/3x/torch/quantization/weight_only/test_hqq.py
index 1d68a553859..d6e0352c312 100644
--- a/test/3x/torch/quantization/weight_only/test_hqq.py
+++ b/test/3x/torch/quantization/weight_only/test_hqq.py
@@ -1,4 +1,6 @@
+import copy
 import os
+import time
 from copy import deepcopy
 
 import pytest
@@ -6,6 +8,7 @@
 import transformers
 from transformers import AutoModelForCausalLM
 
+from neural_compressor.common import options
 from neural_compressor.common.utils import logger
 from neural_compressor.torch.algorithms.weight_only.hqq.config import HQQModuleConfig, QTensorConfig, hqq_global_option
 from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear
@@ -93,6 +96,27 @@ def test_hqq_quant(self, force_use_cpu, force_not_half):
             q_label_1.eq(q_label_2)
         ), "The results of calling `convert` + `prepare` and calling `quantize` should be equal."
 
+    def test_hqq_load_save(self, force_use_cpu, force_not_half):
+
+        hqq_global_option.use_half = False
+        fp32_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-OPTForCausalLM")
+        example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long, device="cpu")
+        # test_default_config
+        quant_config = get_default_hqq_config()
+
+        # prepare + convert API
+        model = prepare(deepcopy(fp32_model), quant_config)
+        qmodel = convert(model)
+        qmodel_out_ref = model(example_inputs)[0]
+        save_path = options.workspace + f"/_hqq_model_{time.time()}.pth"
+        qmodel.save(save_path)
+        from neural_compressor.torch.quantization import load
+
+        # loading compressed model
+        loaded_model = load(save_path, copy.deepcopy(fp32_model))
+        loaded_model_out = loaded_model(example_inputs)[0]
+        assert torch.allclose(qmodel_out_ref, loaded_model_out), "Unexpected result. Please double check."
+
     def test_hqq_fallback(self, force_use_cpu, force_not_half):
 
         class ToyModel(torch.nn.Module):
@@ -181,3 +205,57 @@ def test_hqq_module(
             scale_quant_group_size=scale_quant_group_size,
             device=torch.device(device_name),
         )
+
+    @pytest.mark.parametrize(
+        "nbits, group_size, quant_zero, quant_scale, scale_quant_group_size",
+        [
+            (4, 64, True, False, 128),
+            (4, 64, False, False, 128),
+            (4, 64, True, True, 128),
+            (4, 64, False, True, 128),
+            (8, 64, True, False, 128),
+        ],
+    )
+    def test_hqq_linear_save_and_load(
+        self,
+        nbits,
+        group_size,
+        quant_zero,
+        quant_scale,
+        scale_quant_group_size,
+    ):
+        hqq_global_option.use_half = False
+        # Parse config
+        weight_qconfig = QTensorConfig(
+            nbits=nbits,
+            channel_wise=True,
+            group_size=group_size,
+            optimize=True,
+            round_zero=True if nbits == 4 else False,
+        )
+        zero_qconfig = None
+        if quant_zero:
+            zero_qconfig = QTensorConfig(nbits=8, channel_wise=False, group_size=None, optimize=False)
+        scale_qconfig = None
+        if quant_scale:
+            scale_qconfig = QTensorConfig(nbits=8, channel_wise=True, group_size=scale_quant_group_size, optimize=False)
+        hqq_quant_config = HQQModuleConfig(weight=weight_qconfig, scale=scale_qconfig, zero=zero_qconfig)
+        # Create HQQ Linear
+        bs = 4
+        in_features = 64
+        out_features = 128
+        float_linear = torch.nn.Linear(in_features=in_features, out_features=out_features)
+        float_linear.to(device)
+        float_linear_copy = deepcopy(float_linear)
+        input = torch.randn(bs, in_features, device=device)
+        hqq_linear = HQQLinear.from_float(float_linear_copy, quant_config=hqq_quant_config)
+        out_ref = hqq_linear(input)
+        state_dict = hqq_linear.state_dict()
+        hqq_module_path = options.workspace + f"/_hqq_linear_{time.time()}.pth"
+        torch.save(state_dict, hqq_module_path)
+        reload_state_dict = torch.load(hqq_module_path)
+        new_float = torch.nn.Linear(in_features=in_features, out_features=out_features)
+        new_hqq_linear = HQQLinear.from_float(new_float, quant_config=hqq_quant_config)
+        new_hqq_linear.load_state_dict(reload_state_dict)
+        out = new_hqq_linear(input)
+        assert torch.equal(out_ref, out), f"out_ref: {out_ref}, out: {out}"
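For reference, a minimal end-to-end sketch of the save/load flow this patch adds, condensed from test_hqq_load_save above. The tiny OPT checkpoint is simply the model that test uses, and the save path is a placeholder; substitute any causal LM and writable location.

    from copy import deepcopy

    import torch
    from transformers import AutoModelForCausalLM

    from neural_compressor.torch.algorithms.weight_only.hqq.config import hqq_global_option
    from neural_compressor.torch.quantization import convert, get_default_hqq_config, load, prepare

    # run in float32 on CPU, as the test does
    hqq_global_option.use_half = False

    # quantize a small FP32 model with the default HQQ config
    fp32_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-OPTForCausalLM")
    example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long, device="cpu")
    qmodel = convert(prepare(deepcopy(fp32_model), get_default_hqq_config()))
    out_ref = qmodel(example_inputs)[0]

    # hqq_entry attaches save() to the quantized model; load() rebuilds HQQLinear modules
    # from a fresh FP32 copy and restores the quantized state dict
    save_path = "saved_results/hqq_model.pth"  # placeholder path
    qmodel.save(save_path)
    loaded_model = load(save_path, deepcopy(fp32_model))
    assert torch.allclose(out_ref, loaded_model(example_inputs)[0])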