From ee6bc284b1e65c37b2f46857da9cd315019b208a Mon Sep 17 00:00:00 2001 From: yintong-lu <108845308+yintong-lu@users.noreply.github.com> Date: Wed, 13 Dec 2023 09:17:50 +0800 Subject: [PATCH] Lyt/blockwise (#1441) * [Algo] blockwise tuning Signed-off-by: Lu, Yintong * [Algo] code update Signed-off-by: Lu, Yintong * [Algo] sq argument update Signed-off-by: Lu, Yintong * [Algo] log update Signed-off-by: Lu, Yintong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Algo] code update Signed-off-by: Lu, Yintong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Algo] fix bugs Signed-off-by: Lu, Yintong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Algo] log update Signed-off-by: Lu, Yintong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Algo] enable blockwise on Llama models Signed-off-by: Lu, Yintong * [Algo] enable blockwise on Llama models Signed-off-by: Lu, Yintong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Algo] code update Signed-off-by: Lu, Yintong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Algo] format code Signed-off-by: Lu, Yintong * [Algo] fix bug Signed-off-by: Lu, Yintong * [Algo] add ut Signed-off-by: Lu, Yintong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Algo] fix format issue Signed-off-by: Lu, Yintong * [Algo] log update Signed-off-by: Lu, Yintong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Algo] move do_blockwise arg Signed-off-by: Lu, Yintong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Algo] fix bug Signed-off-by: Lu, Yintong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Algo] fix bug Signed-off-by: Lu, Yintong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Algo] fix bug Signed-off-by: Lu, Yintong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Algo] fix bug Signed-off-by: Lu, Yintong * [Algo] fix bug Signed-off-by: Lu, Yintong --------- Signed-off-by: Lu, Yintong Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- neural_compressor/adaptor/onnxrt.py | 9 +- neural_compressor/adaptor/pytorch.py | 9 +- neural_compressor/adaptor/tensorflow.py | 9 +- .../adaptor/torch_utils/smooth_quant.py | 275 +++++++++++++++--- neural_compressor/adaptor/torch_utils/util.py | 2 + neural_compressor/strategy/strategy.py | 11 +- test/algorithm/test_smooth_quant.py | 27 ++ 7 files changed, 300 insertions(+), 42 deletions(-) diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 13dd168ed02..9a735fa43aa 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -175,7 +175,13 @@ def smooth_quant( scales_per_op=True, record_max_info=False, weight_clip=True, - auto_alpha_args={"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.1, "shared_criterion": "mean"}, + auto_alpha_args={ + "alpha_min": 0.0, + "alpha_max": 1.0, + "alpha_step": 0.1, + "shared_criterion": "mean", + 
"do_blockwise": False, + }, default_alpha=0.5, ): """Get augmented model with smooth quant. @@ -194,6 +200,7 @@ def smooth_quant( weight_clip: Whether to clip weight when calculating scales; by default it is on. auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning. By default the search space is 0.0-1.0 with step_size 0.1. + do_blockwise: Whether to do blockwise auto-tuning. default_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5. Returns: diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index a3cb3f1ea09..8839444bc2e 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -1737,7 +1737,13 @@ def smooth_quant( force_re_smooth=False, record_max_info=False, weight_clip=True, - auto_alpha_args={"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.1, "shared_criterion": "mean"}, + auto_alpha_args={ + "alpha_min": 0.0, + "alpha_max": 1.0, + "alpha_step": 0.1, + "shared_criterion": "mean", + "do_blockwise": False, + }, default_alpha=0.5, ): """Convert the model by smooth quant. @@ -1756,6 +1762,7 @@ def smooth_quant( weight_clip: Whether to clip weight when calculating scales; by default it is on. auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning. By default the search space is 0.0-1.0 with step_size 0.1. + do_blockwise determines whether to do blockwise auto-tuning. default_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5. Returns: diff --git a/neural_compressor/adaptor/tensorflow.py b/neural_compressor/adaptor/tensorflow.py index c5d7731e3d3..212c233a530 100644 --- a/neural_compressor/adaptor/tensorflow.py +++ b/neural_compressor/adaptor/tensorflow.py @@ -1833,7 +1833,13 @@ def smooth_quant( scales_per_op=True, record_max_info=False, weight_clip=True, - auto_alpha_args={"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.1, "shared_criterion": "mean"}, + auto_alpha_args={ + "alpha_min": 0.0, + "alpha_max": 1.0, + "alpha_step": 0.1, + "shared_criterion": "mean", + "do_blockwise": False, + }, default_alpha=0.5, ): """Convert the model by smooth quant. @@ -1852,6 +1858,7 @@ def smooth_quant( weight_clip: Whether to clip weight when calculating scales; by default it is on. auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning. By default the search space is 0.0-1.0 with step_size 0.1. + do_blockwise: Whether to do blockwise auto-tuning. default_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5. 
Returns: diff --git a/neural_compressor/adaptor/torch_utils/smooth_quant.py b/neural_compressor/adaptor/torch_utils/smooth_quant.py index 6aa295f40c5..0cf32183084 100644 --- a/neural_compressor/adaptor/torch_utils/smooth_quant.py +++ b/neural_compressor/adaptor/torch_utils/smooth_quant.py @@ -263,6 +263,7 @@ def __init__(self, layer, input_min, input_max, save_q_input=False): self.weight_scale = None self.input_scale = None self.save_q_input = save_q_input + self.do_blockwise = False def enable_quant(self): self.quant = True @@ -289,12 +290,25 @@ def q_dq_forward(self, x, input_scale, weight_scale): output = layer_copy(x) return output + def q_dq_forward_blockwise(self, x, input_scale): + layer_copy = copy.deepcopy(self.orig_layer) + if input_scale is None: + x = quant_dequant_x(x, self.input_min, self.input_max) + else: + x = input_scale * x + x = quant_dequant_x(x, self.input_min * input_scale, self.input_max * input_scale) ##FIXME + output = layer_copy(x) + return output + def forward(self, x): if self.quant: # self.q_input = x * scale ##save the q_input if self.save_q_input: self.q_input = x - output = self.q_dq_forward(x, self.input_scale, self.weight_scale) + if not self.do_blockwise: + output = self.q_dq_forward(x, self.input_scale, self.weight_scale) + else: + output = self.q_dq_forward_blockwise(x, self.input_scale) else: output = self.orig_layer(x) @@ -351,6 +365,10 @@ def __init__(self, model, dataloader=None, example_inputs=None, q_func=None, tra self._save_scale = False self.weight_scale_dict = {} + self.do_blockwise = False + self.block_inputs = {} + self.block_outputs = {} + def _get_device(self): """Get the model device :return:Model device.""" @@ -462,6 +480,15 @@ def _reshape_scale_for_weight(self, layer, scale): return scale + def get_blocks(self): + block_names = [] + for n, m in self.model.named_modules(): + if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__: + for nn, mm in m.named_children(): + block_name = n + "." 
+ nn + block_names.append(block_name) + return block_names + def _reshape_scale_for_input(self, layer, scale): """Reshape the scale for input feature in channel :param layer: @@ -764,44 +791,125 @@ def _update_scales_for_auto(self, absorb_scales, weight_scales): weight_scale = self._reshape_scale_for_weight(layer, weight_scale) layer.update_scale(input_scale, weight_scale) ##FIXME + def _add_blockwise_observer(self, block_modules): + """ + :param block_modules: the block modules which the observer will insert to + :return: + """ + self.blockwise_hook_handles = [] + for key in block_modules.keys(): + hook_func = self._save_blockwise_hook(key) + hook_handle = block_modules[key].register_forward_hook(hook_func) + self.blockwise_hook_handles.append(hook_handle) + + def _save_blockwise_hook(self, name): + """A forward hook to save inputs/outputs of a block + :param name: the block name + :return: A hook function.""" + + def save_blockwise_hook(module, inputs, outputs): + self.block_inputs[name] = inputs[0] + self.block_outputs[name] = outputs[0] + + return save_blockwise_hook + def _get_one_batch_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes): self._change_qdq_for_auto(enable=False) + module_names = self._get_sq_layer_names() + + if self.do_blockwise: + block_modules = {} + for key in self.block_names: + block_modules[key] = get_module(self.model, key) + self._add_blockwise_observer(block_modules) forward_wrapper(self.model, input, self.device) ##disable quant and get fp32 output - module_names = self._get_sq_layer_names() + fp32_output = {} - for name in module_names: - module = get_module(self.model, name) - fp32_output[name] = module.output - module.output = None + if not self.do_blockwise: + for name in module_names: + module = get_module(self.model, name) + fp32_output[name] = module.output + module.output = None + else: + for block_name in self.block_names: + fp32_output[block_name] = self.block_outputs[block_name] self._change_qdq_for_auto(enable=True) absorb_input_scales, weight_scales = self._cal_scales( self.absorb_to_layer, input_maxes, orig_best_alpha, tuning=True ) self._update_scales_for_auto(absorb_input_scales, weight_scales) forward_wrapper(self.model, input, self.device) ##save quant_input + for mod_name in module_names: # save fp32 values + mod = get_module(self.model, mod_name) + if mod_name in self.fp32_output_val: + self.fp32_output_val[mod_name].append(torch.norm(mod.output)) + else: + self.fp32_output_val[mod_name] = [torch.norm(mod.output)] + del mod + loss_alphas = {} - for name in module_names: - module = get_module(self.model, name) - loss = self._get_auto_loss(fp32_output[name], module.output) - cur_alpha = orig_best_alpha - if isinstance(orig_best_alpha, dict): - cur_alpha = orig_best_alpha[name] - key_name = str(cur_alpha) - loss_alphas[name] = {key_name: loss} + if not self.do_blockwise: + for name in module_names: + module = get_module(self.model, name) + loss = self._get_auto_loss(fp32_output[name], module.output) + cur_alpha = orig_best_alpha + if isinstance(orig_best_alpha, dict): + cur_alpha = orig_best_alpha[name] + key_name = str(cur_alpha) + loss_alphas[name] = {key_name: loss} + else: + for block_name in self.block_names: + block = get_module(self.model, block_name) + loss = self._get_auto_loss(fp32_output[block_name], self.block_outputs[block_name]) + cur_alpha = orig_best_alpha + if isinstance(orig_best_alpha, dict): + cur_alpha = orig_best_alpha[self.block_to_module[block_name][0]] + key_name = str(cur_alpha) + 
loss_alphas[block_name] = {key_name: loss} # for name in module_names: # loss_alphas[name]={} for alpha in alpha_space: absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, alpha, tuning=True) self._update_scales_for_auto(absorb_input_scales, weight_scales) - for name in module_names: - losses = loss_alphas[name] - if str(alpha) in losses.keys(): - continue - module = get_module(self.model, name) - output = module.q_dq_forward(module.q_input, module.input_scale, module.weight_scale) - loss = self._get_auto_loss(fp32_output[name], output) - loss_alphas[name][str(alpha)] = loss + if not self.do_blockwise: + for name in module_names: + losses = loss_alphas[name] + if str(alpha) in losses.keys(): + continue + module = get_module(self.model, name) + output = module.q_dq_forward(module.q_input, module.input_scale, module.weight_scale) + loss = self._get_auto_loss(fp32_output[name], output) + loss_alphas[name][str(alpha)] = loss + else: + for block_name in self.block_names: + losses = loss_alphas[block_name] + if str(alpha) in losses.keys(): + continue + block = get_module(self.model, block_name) + block_copy = copy.deepcopy(block) + for name in self.block_to_module[block_name]: + if name == block_name and len(self.block_to_module[block_name]) == 1: + module, module_copy = block, block_copy + else: + module = get_module(block, name) + module_copy = copy.deepcopy(module) + if module.weight_scale is not None: + module_copy.orig_layer.weight *= module.weight_scale + q_dq_weight = quant_dequant_w(module_copy.orig_layer) + module_copy.orig_layer.weight.data.copy_(q_dq_weight) + module_copy.do_blockwise = True + if not (name == block_name and len(self.block_to_module[block_name]) == 1): + set_module(block_copy, name, module_copy) + try: + output = block_copy(self.block_inputs[block_name])[0] + except: # Llama model decoder_layer forward requires position_id + position_ids = torch.arange(self.block_inputs[block_name].size()[1]) + position_ids = position_ids.view(self.block_inputs[block_name].size()[0], -1) + output = block_copy(self.block_inputs[block_name], position_ids=position_ids)[0] + loss = self._get_auto_loss(fp32_output[block_name], output) + loss_alphas[block_name][str(alpha)] = loss + del block_copy # release memory return loss_alphas def _get_best_alpha(self, absorb_to_layer, loss_alphas, shared_criterion): @@ -845,7 +953,14 @@ def dict_to_list(dic): return best_alpha def _auto_tune_alpha( - self, input_maxes, calib_sample_num=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min" + self, + input_maxes, + calib_sample_num=32, + alpha_min=0.3, + alpha_max=0.7, + alpha_step=0.05, + shared_criterion="min", + do_blockwise=False, ): """Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly. 
@@ -887,6 +1002,7 @@ def _auto_tune_alpha( # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha tune_cnt = 4 multiply_factor = calib_sample_num // tune_cnt if calib_sample_num >= tune_cnt else calib_sample_num + self.fp32_output_val = {} best_alphas = default_alpha if not self.dataloader: @@ -905,13 +1021,25 @@ def _auto_tune_alpha( best_alphas_per_module[layer_name] = best_alphas_per_module[key] loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes) - if loss_alphas == {}: - loss_alphas = loss_tmp + if self.do_blockwise: + if loss_alphas == {}: + for block_name in self.block_names: + for key in self.block_to_module[block_name]: + loss_alphas[key] = loss_tmp[block_name] + else: + for block_name in self.block_names: + for key in self.block_to_module[block_name]: + cur_loss = loss_alphas[key] + for alpha_key in cur_loss.keys(): + cur_loss[alpha_key] += loss_tmp[block_name][alpha_key] else: - for key in loss_alphas.keys(): - cur_loss = loss_alphas[key] - for alpha_key in cur_loss.keys(): - cur_loss[alpha_key] += loss_tmp[key][alpha_key] + if loss_alphas == {}: + loss_alphas = loss_tmp + else: + for key in loss_alphas.keys(): + cur_loss = loss_alphas[key] + for alpha_key in cur_loss.keys(): + cur_loss[alpha_key] += loss_tmp[key][alpha_key] total_cnt += self.dataloader.batch_size tmp_cnt += self.dataloader.batch_size if tmp_cnt // multiply_factor >= 1: @@ -939,13 +1067,25 @@ def _auto_tune_alpha( best_alphas_per_module[layer_name] = best_alphas_per_module[key] loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes) - if loss_alphas == {}: - loss_alphas = loss_tmp + if self.do_blockwise: + if loss_alphas == {}: + for block_name in self.block_names: + for key in self.block_to_module[block_name]: + loss_alphas[key] = loss_tmp[block_name] + else: + for block_name in self.block_names: + for key in self.block_to_module[block_name]: + cur_loss = loss_alphas[key] + for alpha_key in cur_loss.keys(): + cur_loss[alpha_key] += loss_tmp[block_name][alpha_key] else: - for key in loss_alphas.keys(): - cur_loss = loss_alphas[key] - for alpha_key in cur_loss.keys(): - cur_loss[alpha_key] += loss_tmp[key][alpha_key] + if loss_alphas == {}: + loss_alphas = loss_tmp + else: + for key in loss_alphas.keys(): + cur_loss = loss_alphas[key] + for alpha_key in cur_loss.keys(): + cur_loss[alpha_key] += loss_tmp[key][alpha_key] total_cnt += self.dataloader.batch_size tmp_cnt += self.dataloader.batch_size if tmp_cnt // multiply_factor >= 1: @@ -966,6 +1106,33 @@ def _auto_tune_alpha( best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion) for key in best_alphas.keys(): logger.info(f"Final alpha {key}:{best_alphas[key]}") + max_op, max_ratio, max_key = "", 0, "" + ratio_info = {} + for key in self.absorb_to_layer: + for op_name in self.absorb_to_layer[key]: + fp32_norm, loss_ = ( + torch.sum(torch.stack(self.fp32_output_val[op_name])), + loss_alphas[op_name][str(best_alphas[key])], + ) + ratio = loss_ / fp32_norm + max_op = op_name if ratio > max_ratio else max_op + max_key = key if ratio > max_ratio else max_key + max_ratio = max(ratio, max_ratio) + ratio_info[op_name] = ratio + logger.debug( + f"final loss: {op_name}: {loss_}; @alpha {best_alphas[key]}; \ + fp32_output norm: {fp32_norm}; ratio: {ratio}" + ) + import operator + + ratio_info = dict(sorted(ratio_info.items(), key=operator.itemgetter(1), reverse=True)) + for key in list(ratio_info.keys()): + 
logger.debug(f"sorted opname-ratio: {key}: {ratio_info[key]}") + if max_op != "": + logger.debug( + f"max loss: {max_op}: {loss_alphas[max_op][str(best_alphas[max_key])]} @alpha {best_alphas[max_key]}\ + fp32_output norm: {torch.sum(torch.stack(self.fp32_output_val[max_op]))}; ratio: {max_ratio}" + ) self._qdq_model_unwrapper_for_auto() logger.info("auto tuning done") return best_alphas @@ -978,7 +1145,13 @@ def transform( op_types=[torch.nn.Linear, torch.nn.Conv2d], scales_per_op=False, calib_iter=100, - auto_alpha_args={"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.1, "shared_criterion": "mean"}, + auto_alpha_args={ + "alpha_min": 0.0, + "alpha_max": 1.0, + "alpha_step": 0.1, + "shared_criterion": "mean", + "do_blockwise": False, + }, weight_clip=True, default_alpha=0.5, ): @@ -994,10 +1167,18 @@ def transform( :param auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning. By default the search space is 0.0-1.0 with step_size 0.1. + do_blockwise: Whether to do blockwise auto-tuning. :param default_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5. :return: A FP32 model with the same architecture as the orig model but with different weight which will be benefit to quantization. """ + if isinstance(auto_alpha_args, dict): + self.do_blockwise = auto_alpha_args.get("do_blockwise", False) + else: + self.do_blockwise = False + if self.do_blockwise: + self.block_names = self.get_blocks() + logger.info("Blockwise auto-tuning will be performed") if not isinstance(self.model, torch.nn.Module): logger.warning("smooth quant is ignored since the model is not a torch module") return self.model @@ -1055,6 +1236,23 @@ def transform( ) return self.model + if self.do_blockwise: + module_names = self._get_sq_layer_names() + block_names, self.block_to_module = self.block_names, {} + for block in block_names: + self.block_to_module[block] = [] + for module in module_names: + checked = False + for block in block_names: + if block + "." 
in module: + self.block_to_module[block].append(module) + checked = True + if not checked: + self.block_to_module[module] = [module] + self.block_names = list(self.block_to_module.keys()) + logger.info(f"Blockwise auto-tuning: {len(self.block_names)} blocks found") + logger.debug(f"Blockwise auto-tuning blocks info: {self.block_to_module}") + input_maxes_abs = self._calibrate(self.absorb_to_layer, calib_iter, percentile) # Check if input_maxes match self.absorb_to_layer @@ -1181,7 +1379,10 @@ def _trace(self, op_types, skip_unsupported_layers=True): tg = GraphTrace() self._get_example_input() absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer( - self.traced_model, self.example_inputs, op_types, skip_unsupported_layers=skip_unsupported_layers + self.traced_model, + self.example_inputs, + op_types, + skip_unsupported_layers=skip_unsupported_layers, ) if not skip_unsupported_layers: return absorb_to_layer diff --git a/neural_compressor/adaptor/torch_utils/util.py b/neural_compressor/adaptor/torch_utils/util.py index d9e10679998..d758b349c7f 100644 --- a/neural_compressor/adaptor/torch_utils/util.py +++ b/neural_compressor/adaptor/torch_utils/util.py @@ -324,6 +324,8 @@ def check_cfg_and_qconfig( # to int8 ipex_op_cfg = op_infos_from_cfgs[name] input_tensor_infos = ipex_op_cfg["input_tensor_infos"] + if op_name[1] == "Linear" or op_name[1] == "Linear&add": # record op_name for possible op-wise fallback + logger.debug(f"ipex_op_cfg['fqn'] - op_name {ipex_op_cfg['fqn']} {op_name}") for index, input_tensor_info in enumerate(input_tensor_infos): if "force_dtype" not in input_tensor_info.keys(): continue diff --git a/neural_compressor/strategy/strategy.py b/neural_compressor/strategy/strategy.py index 96c1dd34085..95691a38142 100644 --- a/neural_compressor/strategy/strategy.py +++ b/neural_compressor/strategy/strategy.py @@ -952,8 +952,15 @@ def set_param_for_pre_tuning_algos(self, algo_scheduler, config, fp32_model) -> "weight_clip", True ) # make weight_clipping a default_on option. sq_algo.auto_alpha_args = smooth_quant_args.get( - "auto_alpha_args", {"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.1, "shared_criterion": "mean"} - ) # default alpha search space parameters. + "auto_alpha_args", + { + "alpha_min": 0.0, + "alpha_max": 1.0, + "alpha_step": 0.1, + "shared_criterion": "mean", + "do_blockwise": False, + }, + ) # default alpha search space parameters. By default, do_blockwise is set to False. sq_algo.default_alpha = smooth_quant_args.get( "default_alpha", 0.5 ) # default value for alpha in auto-tuning diff --git a/test/algorithm/test_smooth_quant.py b/test/algorithm/test_smooth_quant.py index f97fee392a9..08003161662 100644 --- a/test/algorithm/test_smooth_quant.py +++ b/test/algorithm/test_smooth_quant.py @@ -1548,5 +1548,32 @@ def forward(self, x): assert sq.auto_alpha_args["alpha_min"] == 0.5 +class TestAlphaAutoLinearBlockwise(unittest.TestCase): + @classmethod + def test_sq_linear_Blockwise_auto(self): + model = transformers.AutoModelForCausalLM.from_pretrained( + "facebook/opt-125m", + torchscript=True, + ) + sq = TorchSmoothQuant(model, LLMCalibDataloader()) + sq.transform( + alpha="auto", + calib_iter=1, + folding=False, + auto_alpha_args={ + "alpha_min": 0.45, + "alpha_max": 0.55, + "alpha_step": 0.01, + "shared_criterion": "mean", + "do_blockwise": True, + }, + ) + for i in range(12): + op_name1 = "model.decoder.layers." + str(i) + ".self_attn.out_proj" + op_name2 = "model.decoder.layers." 
+ str(i) + ".fc1" + assert sq.alpha_per_layer[op_name1] == sq.alpha_per_layer[op_name2] + assert len(sq.block_names) == 13 + + if __name__ == "__main__": unittest.main()