diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index 2ca2ef6b9bf..068e53a6ddd 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -1388,19 +1388,21 @@ def qdq_quantize(self, model, tune_cfg):
             assert not q_model._smoothquant_optimized, \
                 "The model is already optimized by smoothquant, cannot apply new alpha."
             alpha = tune_cfg['recipe_cfgs']['smooth_quant_args']['alpha']
-            for op_name, info in sq_max_info.items():
+            for _, info in sq_max_info.items():
                 if alpha == 'auto':
                     alpha = info['alpha']
+                absorbed_layer = info['absorbed_layer']
                 input_minmax = info['input_minmax']
                 weight_max = info['weight_max']
                 abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
                 input_power = torch.pow(abs_input_max, alpha)
                 weight_power = torch.pow(weight_max, 1 - alpha)
                 scale = torch.clip(input_power / weight_power, min=1e-5)
-                module = fetch_module(q_model, op_name)
-                new_module = SQLinearWrapper(module, 1.0/scale, input_minmax, alpha)
-                set_module(q_model, op_name, new_module)
-                logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}")
+                for op_name in absorbed_layer:
+                    module = fetch_module(q_model, op_name)
+                    new_module = SQLinearWrapper(module, 1.0/scale, input_minmax, alpha)
+                    set_module(q_model, op_name, new_module)
+                    logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}")
 
         smoothquant_op_info = {'sq_linear': {}, 'qdq_linear': []}
         stats_result['SQLinearWrapper'] = {'INT8(QDQ)': 0, 'BF16': 0, 'FP32': 0}
@@ -3117,27 +3119,29 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):
             from .torch_utils.model_wrapper import SQLinearWrapper
             from .torch_utils.util import fetch_module
             alpha = tune_cfg['recipe_cfgs']['smooth_quant_args']['alpha']
-            for op_name, info in sq_max_info.items():
+            for _, info in sq_max_info.items():
                 if alpha == 'auto':
                     alpha = info['alpha']
+                absorbed_layer = info['absorbed_layer']
                 input_minmax = info['input_minmax']
                 weight_max = info['weight_max']
                 abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
                 input_power = torch.pow(abs_input_max, alpha)
                 weight_power = torch.pow(weight_max, 1 - alpha)
                 scale = torch.clip(input_power / weight_power, min=1e-5)
-                module = fetch_module(q_model._model, op_name)
-                new_module = SQLinearWrapper(module, 1.0/scale, input_minmax, alpha)
-                weight_scale = new_module._get_weight_scale()
-                smoothquant_scale_info[op_name] = {
-                    'alpha': new_module.alpha,
-                    'input_scale_for_mul': new_module.input_scale,
-                    'input_scale_after_mul': new_module.scale,
-                    'input_zero_point_after_mul': new_module.zero_point,
-                    'input_dtype': new_module.dtype,
-                    'weight_scale_after_mul': weight_scale,
-                }
-                logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}")
+                for op_name in absorbed_layer:
+                    module = fetch_module(q_model._model, op_name)
+                    new_module = SQLinearWrapper(module, 1.0/scale, input_minmax, alpha)
+                    weight_scale = new_module._get_weight_scale()
+                    smoothquant_scale_info[op_name] = {
+                        'alpha': new_module.alpha,
+                        'input_scale_for_mul': new_module.input_scale,
+                        'input_scale_after_mul': new_module.scale,
+                        'input_zero_point_after_mul': new_module.zero_point,
+                        'input_dtype': new_module.dtype,
+                        'weight_scale_after_mul': weight_scale,
+                    }
+                    logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}")
 
         # Check save_qconf_summary part is a workaroud for IPEX bug.
         # Sometimes the prepared model from get_op_capablitiy loss this attribute
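Note on the two `pytorch.py` hunks: each `sq_max_info` record now carries an `absorbed_layer` list, and the one scale computed from the record is applied to every layer in that list rather than to a single `op_name`. This matters for layers that consume the same activation (e.g. the q/k/v projections fed by one LayerNorm): a single mul is inserted on their shared input, so they must all be wrapped with the same scale. A minimal sketch of the shared-scale computation, reusing the field names from the hunks (the helper itself is illustrative, not part of the patch):

```python
import torch

def shared_sq_scale(info, alpha=0.5):
    """Compute the per-channel SmoothQuant scale shared by all layers
    in info['absorbed_layer'] (field names as in the hunks above)."""
    if alpha == 'auto':
        alpha = info['alpha']  # alpha tuned per group during auto-tuning
    input_minmax = info['input_minmax']
    abs_input_max = torch.max(torch.abs(input_minmax[0]),
                              torch.abs(input_minmax[1]))
    input_power = torch.pow(abs_input_max, alpha)
    weight_power = torch.pow(info['weight_max'], 1 - alpha)
    # clip so near-zero weight channels cannot blow up the ratio
    return torch.clip(input_power / weight_power, min=1e-5)
```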
diff --git a/neural_compressor/adaptor/torch_utils/smooth_quant.py b/neural_compressor/adaptor/torch_utils/smooth_quant.py
index 2593ad7f9a8..5f7140fd9a3 100644
--- a/neural_compressor/adaptor/torch_utils/smooth_quant.py
+++ b/neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -706,6 +706,15 @@ def transform(self, alpha=0.5, folding=False, percentile=99.999, op_types=['Line
         if need_calibration:  ##avoid multiple calibaration during tuning if the only difference is alpha
             if self.insert_mul:
                 self.self_absorb_layers = self._get_all_layer_names()  # TODO: only support linear now.
+                # fetch groups of modules that share the same input
+                group_modules = self._trace(op_types, skip_unsupported_layers=False)
+                for _, v in group_modules.items():
+                    # keep one entry per group so layers sharing an input (e.g. q/k/v) use one scale
+                    for i in v:
+                        if i in self.self_absorb_layers:
+                            self.self_absorb_layers.pop(i)
+                    self.self_absorb_layers[v[0]] = v
+                logger.debug(f"self_absorb_layers:{self.self_absorb_layers}")
             if self.allow_absorb:
                 self.absorb_to_layer, no_absorb_layers = self._trace(
                     op_types)  ##TODO we need to insert mul layer for no_absorb_layers later
@@ -836,7 +845,7 @@ def _get_example_input(self):
 
         return self.example_inputs
 
-    def _trace(self, op_types):
+    def _trace(self, op_types, skip_unsupported_layers=True):
         """
         Try the model to find the layers which can be smooth quantized.
         :param op_types: The op types to be smooth quantized
@@ -846,7 +855,12 @@ def _trace(self, op_types):
         """
         tg = GraphTrace()
         self._get_example_input()
-        absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer(self.traced_model, self.example_inputs, op_types)
+        absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer(
+            self.traced_model, self.example_inputs, op_types,
+            skip_unsupported_layers=skip_unsupported_layers
+        )
+        if not skip_unsupported_layers:
+            return absorb_to_layer
         if absorb_to_layer == None and no_absorb_layers == None:
             logger.warning("sorry, could not trace the model, smooth quant is skipped")
             logger.warning("if you are using huggingface model,"
@@ -994,7 +1008,7 @@ def mapping_torch_module_to_aten(self, op_types):
         res = list(set(res))
         return res
 
-    def get_absorb_to_layer(self, model, example_input, op_types):
+    def get_absorb_to_layer(self, model, example_input, op_types, skip_unsupported_layers=True):
         traced_model = self.trace(model, example_input)
         if traced_model == None:
             return None, None
@@ -1019,7 +1033,8 @@ def get_absorb_to_layer(self, model, example_input, op_types):
                     absorb_to_layer[absorb_name].append(layer_name)
                 else:
                     absorb_to_layer[absorb_name] = [layer_name]
-        absorb_to_layer = self.remove_unsupported_layers(model, absorb_to_layer, no_absorb_layers)
+        if skip_unsupported_layers:
+            absorb_to_layer = self.remove_unsupported_layers(model, absorb_to_layer, no_absorb_layers)
         return absorb_to_layer, no_absorb_layers
 
     def remove_unsupported_layers(self, model, absorb_to_layer, no_absorb_layers):
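Note on the `smooth_quant.py` hunks: `_trace()` and `get_absorb_to_layer()` gain a `skip_unsupported_layers` flag; with it disabled, `_trace()` returns the raw absorb-to-layer mapping, which `transform()` uses to re-key `self_absorb_layers` so that Linears sharing an input live under one entry. A standalone sketch of that grouping step, assuming `self_absorb_layers` starts with one entry per Linear and `group_modules` maps each absorbing op to the layers it feeds (the helper name is ours, the logic mirrors the added loop):

```python
def merge_shared_input_layers(self_absorb_layers, group_modules):
    """Collapse layers that consume the same input into one entry,
    mirroring the loop added to transform() above."""
    for layers in group_modules.values():
        for name in layers:
            # drop the per-layer entries for every member of the group
            self_absorb_layers.pop(name, None)
        # re-key the whole group under its first member
        self_absorb_layers[layers[0]] = layers
    return self_absorb_layers

# e.g. {'q_proj': ['q_proj'], 'k_proj': ['k_proj'], 'v_proj': ['v_proj']}
# with group_modules {'ln_1': ['q_proj', 'k_proj', 'v_proj']} becomes
# {'q_proj': ['q_proj', 'k_proj', 'v_proj']}
```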
diff --git a/test/algorithm/test_smooth_quant.py b/test/algorithm/test_smooth_quant.py
index 294d36a1edd..e5181307540 100644
--- a/test/algorithm/test_smooth_quant.py
+++ b/test/algorithm/test_smooth_quant.py
@@ -12,6 +12,7 @@
 from neural_compressor.data import Datasets, DATALOADERS
 from neural_compressor.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
 from neural_compressor.adaptor.torch_utils.smooth_quant import TorchSmoothQuant
+from neural_compressor.adaptor.torch_utils.model_wrapper import SQLinearWrapper
 import logging
 logger = logging.getLogger("neural_compressor")
 
@@ -22,6 +23,31 @@
 TEST_IPEX = False
 
+class DemoModel(torch.nn.Module):
+    def __init__(self):
+        super(DemoModel, self).__init__()
+        self.fc1 = torch.nn.Linear(3, 4)
+        self.fc2 = torch.nn.Linear(4, 3)
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out = self.fc2(out)
+        return out
+
+class DemoCalibDataloader:
+    def __init__(self):
+        self.batch_size = 1
+    def __iter__(self):
+        yield torch.randn([1, 3])
+
+
+class LLMCalibDataloader:
+    def __init__(self):
+        self.batch_size = 1
+    def __iter__(self):
+        yield torch.ones([1, 3], dtype=torch.long)
+
+
 class TestSqDepthwiseConv(unittest.TestCase):
     @classmethod
     def setUpClass(self):
@@ -579,7 +605,6 @@ def forward(self, x):
         sq = TorchSmoothQuant(model, self.linear_dl)
         sq.transform(alpha=0.5, calib_iter=1)  # By default, folding=False
-        from neural_compressor.adaptor.torch_utils.model_wrapper import SQLinearWrapper
         assert isinstance(sq.model.fc1, SQLinearWrapper)
 
     def test_sq_quant(self):
@@ -617,7 +642,6 @@ def calib_func(model):
             calib_dataloader=CalibDataloader(),
             eval_func=lambda x: 0.1,
         )
-        from neural_compressor.adaptor.torch_utils.model_wrapper import SQLinearWrapper
         assert isinstance(q_model.model.fc1, SQLinearWrapper)
 
         q_model.save('saved_result')
@@ -642,6 +666,7 @@ def calib_func(model):
         # with calib_func
         conf = PostTrainingQuantConfig(
+            example_inputs=input_ids,
             recipes={"smooth_quant": True,
                      "smooth_quant_args": {'alpha': 'auto', 'folding': False}}
         )
@@ -748,7 +773,17 @@ def forward(self, x):
         sq = TorchSmoothQuant(model, self.linear_dl)
         sq.transform(alpha='auto', calib_iter=1, folding=True)
         #the layernorm could not used for sq-absorb because it outputs to an add op.
-        assert len(sq.absorb_to_layer) == 0
+        assert len(sq.absorb_to_layer) == 0
+
+    def test_sq_no_skip_op_auto(self):
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            'facebook/opt-125m', torchscript=True,
+        )
+        sq = TorchSmoothQuant(model, LLMCalibDataloader())
+        sq.transform(alpha='auto', calib_iter=0, folding=False)
+        # folding=False absorbs every Linear with an inserted mul; q/k/v then share one input.
+        assert len(sq.absorb_to_layer['model.decoder.layers.2.self_attn.q_proj']) == 3
+
 
 class TestSqSkipOp_attn(unittest.TestCase):
     @classmethod
@@ -801,30 +836,6 @@ def forward(self, x):
         assert len(sq.absorb_to_layer) == 0
 
 
-class DemoModel(torch.nn.Module):
-    def __init__(self):
-        super(DemoModel, self).__init__()
-        self.fc1 = torch.nn.Linear(3, 4)
-        self.fc2 = torch.nn.Linear(4, 3)
-
-    def forward(self, x):
-        out = self.fc1(x)
-        out = self.fc2(out)
-        return out
-
-class DemoCalibDataloader:
-    def __init__(self):
-        self.batch_size = 1
-    def __iter__(self):
-        yield torch.randn([1, 3])
-
-
-class LLMCalibDataloader:
-    def __init__(self):
-        self.batch_size = 1
-    def __iter__(self):
-        yield torch.ones([1, 3], dtype=torch.long)
-
 class TestTuneSqAlpha(unittest.TestCase):
     @classmethod
     def setUpClass(self):
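For reference, the behavior exercised by the new `test_sq_no_skip_op_auto` can be reproduced outside `unittest`; the snippet below restates the test body with the same model, calibration dataloader, and assertion as the diff:

```python
import torch
import transformers
from neural_compressor.adaptor.torch_utils.smooth_quant import TorchSmoothQuant

class LLMCalibDataloader:
    def __init__(self):
        self.batch_size = 1
    def __iter__(self):
        yield torch.ones([1, 3], dtype=torch.long)  # dummy token ids

model = transformers.AutoModelForCausalLM.from_pretrained(
    'facebook/opt-125m', torchscript=True)
sq = TorchSmoothQuant(model, LLMCalibDataloader())
sq.transform(alpha='auto', calib_iter=0, folding=False)
# with folding=False the q/k/v projections of a block consume one input,
# so all three land under a single absorb key
assert len(sq.absorb_to_layer['model.decoder.layers.2.self_attn.q_proj']) == 3
```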