From 1f4aec3c7149d93568106326e98a91933826475e Mon Sep 17 00:00:00 2001 From: yintong-lu <108845308+yintong-lu@users.noreply.github.com> Date: Thu, 16 Nov 2023 14:43:47 +0800 Subject: [PATCH] [SmoothQuant] make weight_clipping a default_on option (#1386) Signed-off-by: Lu, Yintong --- neural_compressor/adaptor/onnxrt.py | 1 + neural_compressor/adaptor/pytorch.py | 17 ++++++--- neural_compressor/adaptor/tensorflow.py | 1 + .../adaptor/torch_utils/smooth_quant.py | 5 +++ neural_compressor/algorithm/smooth_quant.py | 2 ++ neural_compressor/strategy/strategy.py | 3 ++ test/algorithm/test_smooth_quant.py | 36 +++++++++++++++++++ 7 files changed, 61 insertions(+), 4 deletions(-) diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 037c88a3b15..d81c05a92cc 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -172,6 +172,7 @@ def smooth_quant( op_types=["MatMul", "Gemm", "Conv", "FusedConv"], scales_per_op=True, record_max_info=False, + weight_clip=True, ): """Get augmented model with smooth quant. diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index 3d28437e849..e1cd20a8998 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -1761,6 +1761,7 @@ def smooth_quant( scales_per_op=None, force_re_smooth=False, record_max_info=False, + weight_clip=True, ): """Convert the model by smooth quant. @@ -1804,7 +1805,9 @@ def smooth_quant( kwargs["percentile"] = percentile if scales_per_op is not None: kwargs["scales_per_op"] = scales_per_op - model._model = self.sq.transform(alpha=alpha, folding=folding, calib_iter=calib_iter, **kwargs) + model._model = self.sq.transform( + alpha=alpha, folding=folding, calib_iter=calib_iter, weight_clip=weight_clip, **kwargs + ) if self.sq.record_max_info: model.sq_max_info = self.sq.max_value_info return model @@ -1833,7 +1836,9 @@ def _apply_pre_optimization(self, model, tune_cfg, recover=False): absorb_layer = op_name absorbed_layer = info["absorbed_layer"] input_minmax = info["input_minmax"] - weight_max = info["weight_max"].clamp(min=1e-5) + weight_max = info["weight_max"] + if self.sq.weight_clip: + weight_max = weight_max.clamp(min=1e-5) abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1])) input_power = torch.pow(abs_input_max, alpha) weight_power = torch.pow(weight_max, 1 - alpha) @@ -1877,7 +1882,9 @@ def qdq_quantize(self, model, tune_cfg): alpha = info["alpha"] absorbed_layer = info["absorbed_layer"] input_minmax = info["input_minmax"] - weight_max = info["weight_max"].clamp(min=1e-5) + weight_max = info["weight_max"] + if self.sq.weight_clip: + weight_max = weight_max.clamp(min=1e-5) abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1])) input_power = torch.pow(abs_input_max, alpha) weight_power = torch.pow(weight_max, 1 - alpha) @@ -3279,7 +3286,9 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func): absorbed_layer = info["absorbed_layer"] input_minmax = info["input_minmax"] # for peft model,lora_B weights is 0. - weight_max = info["weight_max"].clamp(min=1e-5) + weight_max = info["weight_max"] + if self.sq.weight_clip: + weight_max = weight_max.clamp(min=1e-5) abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1])) input_power = torch.pow(abs_input_max, alpha) weight_power = torch.pow(weight_max, 1 - alpha) diff --git a/neural_compressor/adaptor/tensorflow.py b/neural_compressor/adaptor/tensorflow.py index b28cf65175f..338fbfe82e8 100644 --- a/neural_compressor/adaptor/tensorflow.py +++ b/neural_compressor/adaptor/tensorflow.py @@ -1822,6 +1822,7 @@ def smooth_quant( op_types=["MatMul", "Conv2D"], scales_per_op=True, record_max_info=False, + weight_clip=True, ): """Convert the model by smooth quant. diff --git a/neural_compressor/adaptor/torch_utils/smooth_quant.py b/neural_compressor/adaptor/torch_utils/smooth_quant.py index b91b95e8563..d2b15cae0f0 100644 --- a/neural_compressor/adaptor/torch_utils/smooth_quant.py +++ b/neural_compressor/adaptor/torch_utils/smooth_quant.py @@ -330,6 +330,7 @@ def __init__(self, model, dataloader=None, example_inputs=None, q_func=None, tra self.self_absorb_layers = {} self.absorb_to_layer = {} self.adjust_alpha_space = False + self.weight_clip = True def _get_device(self): """Get the model device @@ -577,6 +578,8 @@ def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False): weights.append(weight) weight_max_per_channel = torch.max(torch.abs(torch.cat(weights, dim=0)), dim=0)[0] + if self.weight_clip: + weight_max_per_channel = weight_max_per_channel.clamp(min=1e-5) if self.record_max_info and not tuning: # the input of layers with same absorb layer is the same. input_minmax = [self.input_mins[layer_names[0]], self.input_maxes[layer_names[0]]] @@ -946,6 +949,7 @@ def transform( scales_per_op=False, calib_iter=100, auto_alpha_args={"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.1, "shared_criterion": "mean"}, + weight_clip=True, ): """The main entry of smooth quant :param alpha: Alpha value to balance the quantization difficulty of activation and weight, please refer @@ -971,6 +975,7 @@ def transform( alpha = numpy.clip(alpha, 0.0, 1.0) + self.weight_clip = weight_clip self.recover() need_calibration = self._check_need_calibration(alpha, percentile, op_types, scales_per_op, calib_iter) with torch.no_grad(): diff --git a/neural_compressor/algorithm/smooth_quant.py b/neural_compressor/algorithm/smooth_quant.py index faffca4c2e7..89848e545c6 100644 --- a/neural_compressor/algorithm/smooth_quant.py +++ b/neural_compressor/algorithm/smooth_quant.py @@ -52,6 +52,7 @@ def __init__(self, alpha=0.5): self.op_types = None self.scales_per_op = None self.tune_cfg = None + self.weight_clip = None def __call__(self, origin_model, q_model, adaptor, dataloader, calib_iter): """Return the processed model via SmoothQuant algorithm. @@ -80,6 +81,7 @@ def __call__(self, origin_model, q_model, adaptor, dataloader, calib_iter): kwargs["scales_per_op"] = self.scales_per_op kwargs["folding"] = self.folding kwargs["record_max_info"] = True + kwargs["weight_clip"] = self.weight_clip q_model = adaptor.smooth_quant( origin_model, dataloader, diff --git a/neural_compressor/strategy/strategy.py b/neural_compressor/strategy/strategy.py index a81b3d16b17..c07f6552051 100644 --- a/neural_compressor/strategy/strategy.py +++ b/neural_compressor/strategy/strategy.py @@ -948,6 +948,9 @@ def set_param_for_pre_tuning_algos(self, algo_scheduler, config, fp32_model) -> if self.framework == "pytorch_ipex": smooth_quant_args["folding"] = None # will reset it to True if IPEX version < 2.1. sq_algo.folding = smooth_quant_args["folding"] + sq_algo.weight_clip = smooth_quant_args.get( + "weight_clip", True + ) # make weight_clipping a default_on option. logger.debug(f"Set smooth quant with alpha {sq_algo.alpha} as the pre-tuning algo.") algo_scheduler.append_algorithm("pre_quantization", sq_algo) diff --git a/test/algorithm/test_smooth_quant.py b/test/algorithm/test_smooth_quant.py index 3f0cb63a4d9..c4bfc646bae 100644 --- a/test/algorithm/test_smooth_quant.py +++ b/test/algorithm/test_smooth_quant.py @@ -1427,5 +1427,41 @@ def calib_func(model): out2 = q_model.model(example_input)[0] +class TestInputConfig(unittest.TestCase): + @classmethod + def setUpClass(self): + class RandDataloader: + def __init__(self): + self.batch_size = 1 + + def __iter__(self): + yield torch.rand((1, 3)) + + self.linear_dl = RandDataloader() + + @classmethod + def test_sq_weight_clipping(self): + class Model(torch.nn.Module): + device = torch.device("cpu") + + def __init__(self): + super(Model, self).__init__() + self.fc1 = torch.nn.Linear(3, 4) + self.norm = LlamaRMSNorm(4) + self.fc2 = torch.nn.Linear(4, 3) + + def forward(self, x): + out = self.fc1(x) + out = self.norm(out) + out = self.fc2(out) + return out + + model = Model() + + sq = TorchSmoothQuant(model, self.linear_dl) + sq.transform(alpha="auto", calib_iter=1, folding=True, weight_clip=False) + assert sq.weight_clip is False + + if __name__ == "__main__": unittest.main()