diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index dce79df7d69..77b48eb4698 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -847,6 +847,12 @@ def __init__(self, framework_specific_info): self.fp32_results = [] self.fp32_preds_as_label = False + if self.version.release >= Version("1.8").release: + static_quant_mapping = tq.quantization_mappings.get_default_static_quant_module_mappings() + self.fused_op_list = \ + [static_quant_mapping[key] for key in static_quant_mapping if "intrinsic." in str(key)] + self.fused_dict = {} + def calib_func(self, model, dataloader, tmp_iterations, conf=None): try: for idx, (input, label) in enumerate(dataloader): @@ -1229,926 +1235,1113 @@ def _combine_capability(self, bf16_ops, q_capability): q_capability['optypewise'][bf16_op[1]] = [bf16_config, fp32_config] return q_capability - def is_fused_module(self, module): - """This is a helper function for `_propagate_qconfig_helper` to detecte - if this module is fused. + def get_fused_list(self, model): + """This is a helper function to get fused op list. Args: - module (object): input module + model (object): input model Returns: - (bool): is fused or not + dict of op list """ - op_type = str(type(module)) - if 'fused' in op_type: - return True - else: - return False + fused_dict = {} + for op_name, child in model.named_modules(): + if type(child) in self.fused_op_list: + in_fused_loop = False + is_fused_module = False + type_name = str(child).split("(")[0] + prefix_index = op_name.rfind(".") + fp32_int8_ops = [] + for fp32_op_name, module in self.pre_optimized_model.model.named_modules(): + fp32_type_name = str(module).split("(")[0] + prefix_fp32_index = fp32_op_name.rfind(".") + if not is_fused_module: + is_fused_module = self.is_fused_module(module) + if is_fused_module: + in_fused_loop = True + continue + if is_fused_module and in_fused_loop: + if op_name == fp32_op_name[: fp32_op_name.rfind(".")]: + fp32_int8_ops.append(fp32_op_name) + continue + else: + is_fused_module =False + in_fused_loop = False + elif op_name == fp32_op_name and not in_fused_loop: + in_fused_loop = True + fp32_int8_ops.append(fp32_op_name) + elif in_fused_loop and \ + op_name[: prefix_index if prefix_index > -1 else 0] == \ + fp32_op_name[: prefix_fp32_index if prefix_fp32_index > -1 else 0]: + if "BatchNorm" in str(type(module)): + fp32_int8_ops.append(fp32_op_name) + continue + elif fp32_type_name in type_name.split(".")[-1][-len(fp32_type_name) - 2:]: + fp32_int8_ops.append(fp32_op_name) + in_fused_loop = False + break + else: + in_fused_loop = False + break + elif in_fused_loop: + in_fused_loop = False + break + if len(fp32_int8_ops) > 1: + fused_dict.update({op_name: fp32_int8_ops}) + return fused_dict - def calculate_hessian_trace(self, - fp32_model, - dataloader, - q_model, - criterion, - enable_act=False - ): - """Calculate hessian trace. + def diagnosis_helper(self, fp32_model, int8_model, tune_cfg=None, save_path=None): + """This is a helper function to diagnosis. Args: - fp32_model: The original fp32 model. - criterion: The loss function for calculate the hessian trace. # loss = criterion(output, target) - dataloader: The dataloader for calculate the gradient. - q_model: The INT8 AMAP model. - enable_act: Enabling quantization error or not. + fp32_model (object): Fp32 model (original) + int8_model (object): Quantized model + tune_cfg (dict): Quantization config + save_path (Path): The path to save min/max value of op outputs - Return: - hessian_trace(Dict[Tuple, float]), key: (op_name, op_type); value: hessian trace. + Returns: + Op name list for inspecting, tuning configuration """ - from .torch_utils.hawq_metric import hawq_top - op_to_traces = hawq_top(fp32_model=fp32_model, - dataloader=dataloader, - q_model=q_model, - criterion=criterion, - enable_act=enable_act) - return op_to_traces - - def smooth_quant(self, model, dataloader, calib_iter, alpha=0.5, folding=False, - percentile=None, op_types=None, scales_per_op=None, force_re_smooth=False, - record_max_info=False): - """ convert the model by smooth quant. + exclude_list = ["QuantStub", "DeQuantStub", "BatchNorm2d", "Sequential"] + optype_list = torch.quantization.get_default_qconfig_propagation_list() + supported_optype = [] + for optype in optype_list: + op_type = str(optype).rstrip('\'>').split('.')[-1] + if "intrinsic." not in str(optype) and op_type not in exclude_list: + supported_optype.append(optype) + inspect_node_list = [] + for name, child in fp32_model.model.named_modules(): + op_type = type(child) + if op_type in supported_optype: + inspect_node_list.append(name) + return inspect_node_list, tune_cfg - Args: - model: origin FP32 model - dataloader: calib dataloader - calib_iter: calib iters - alpha: smooth alpha in SmoothQuant, 1.0 will fallback to SPIQ - folding: whether insert mul(False) or just allow foldable layers(True) for SmoothQuant - percentile:Percentile of calibration to remove outliers, not supported now - op_types: The op types whose input tensor will be dumped - scales_per_op: True, each op will have an individual scale, mainly for accuracy - False, ops with the same input will share a scale, mainly for performance - record_max_info: whether record the max info in model for alpha tuning. + def inspect_tensor(self, + model, + dataloader, + op_list=None, + iteration_list=None, + inspect_type='activation', + save_to_disk=False, + save_path=None, + quantization_cfg=None): + assert self.version.release >= Version("1.8").release, "Inspect_tensor only support torch 1.8 or above!" + from neural_compressor.utils.utility import dump_data_to_local + from torch import dequantize + is_quantized = model.is_quantized + op_list_ = [] + fp32_int8_map = {} + for op_name in op_list: + op_list_.append(op_name) + for key in self.fused_dict: + if op_name in self.fused_dict[key]: + op_list_.pop() + fp32_int8_map[op_name] = \ + {'activation': self.fused_dict[key][-1], 'weight': self.fused_dict[key][0]} + if not is_quantized: + op_list_.append(self.fused_dict[key][-1]) + elif key not in op_list_: + op_list_.append(key) + break - Returns: - model: A modified fp32 model, inplace=True. - """ - # Note: we should make sure smoothquant is only executed once with inplacing fp32 model. - if hasattr(model._model, '_smoothquant_optimized') and model._model._smoothquant_optimized: - logger.info("The model is already optimized by SmoothQuant algorithm, skip it.") - return model - if self.__class__.__name__ == 'PyTorch_IPEXAdaptor' and self.version.release < \ - Version("2.1").release: - if folding is None: - folding = True - logger.info( - "IPEX version >= 2.1 is required for SmoothQuant folding=False, reset folding=True.") + assert min(iteration_list) > 0, \ + "Iteration number should great zero, 1 means first iteration." + iterations = max(iteration_list) if iteration_list is not None else -1 + new_model = self._pre_eval_hook(model, op_list=op_list_, iteration_list=iteration_list) + self.evaluate(new_model, dataloader, iteration=iterations) + observer_dict = {} + ret = {} + if inspect_type == 'activation' or inspect_type == 'all': + if self.version.release >= Version("2.0.0").release: + from torch.quantization.quantize import _get_observer_dict as get_observer_dict else: - assert folding, "IPEX version >= 2.1 is required for SmoothQuant folding=False." + from torch.quantization import get_observer_dict + ret['activation'] = [] + get_observer_dict(new_model.model, observer_dict) + if iteration_list is None: + iteration_list = [1] + for i in iteration_list: + summary = OrderedDict() + for key in observer_dict: + if isinstance(observer_dict[key], torch.nn.modules.linear.Identity): + continue + op_name = key.replace(".activation_post_process", "") + value = observer_dict[key].get_tensor_value()[i] + if op_name in op_list: + if type(value) is list: + summary[op_name] = {} + for index in range(len(value)): + summary[op_name].update({ + op_name + ".output" + str(index): + dequantize(value[index]).numpy() + if value[index].is_quantized else value[index].numpy() + }) + else: + summary[op_name] = { + op_name + ".output0": + dequantize(value).numpy() if value.is_quantized else value.numpy() + } + else: + if bool(self.fused_dict): + if is_quantized: + for a in fp32_int8_map: + if op_name == a: + tensor_name = fp32_int8_map[a]['weight'] + if type(value) is list: + summary[tensor_name] = {} + for index in range(len(value)): + summary[tensor_name].update({ + tensor_name + ".output" + str(index): + dequantize(value[index]).numpy() + if value[index].is_quantized else + value[index].numpy() + }) + else: + summary[tensor_name] = { + tensor_name + ".output0": + dequantize(value).numpy() + if value.is_quantized else value.numpy() + } + else: + for a in fp32_int8_map: # pragma: no cover + if op_name == fp32_int8_map[a]['activation']: + tensor_name = fp32_int8_map[a]['weight'] + if type(value) is list: + summary[tensor_name] = {} + for index in range(len(value)): + summary[tensor_name].update({ + tensor_name + ".output" + str(index): + dequantize(value[index]).numpy() + if value[index].is_quantized else + value[index].numpy() + }) + else: + summary[tensor_name] = { + tensor_name + ".output0": + dequantize(value).numpy() + if value.is_quantized else value.numpy() + } - if not hasattr(self, 'sq') or force_re_smooth: - from .torch_utils.smooth_quant import TorchSmoothQuant - self.sq = TorchSmoothQuant(model._model, dataloader=dataloader, - example_inputs=self.example_inputs, q_func=self.q_func) - kwargs = {} ## different backends may have different default values - self.sq.record_max_info = record_max_info # whether record the max info of input and weight. - if op_types != None: - kwargs["op_types"] = op_types - if percentile != None: - kwargs['percentile'] = percentile - if scales_per_op != None: - kwargs['scales_per_op'] = scales_per_op - model._model = self.sq.transform( - alpha=alpha, - folding=folding, - calib_iter=calib_iter, - **kwargs - ) - if self.sq.record_max_info: - model.sq_max_info = self.sq.max_value_info - return model + ret['activation'].append(summary) - def _apply_pre_optimization(self, model, tune_cfg, recover=False): - """update model parameters based on tune_cfg. + if inspect_type == 'weight' or inspect_type == 'all': + ret['weight'] = {} + state_dict = new_model._model.state_dict() - Args: - model (torch.nn.Module): smoothquant optimized model. - tune_cfg (dict): optimization config. - recover (dict): recover pre-optimization change. + for key in state_dict: + if not isinstance(state_dict[key], torch.Tensor): + continue + if 'weight' not in key and 'bias' not in key: + continue - Returns: - model: pre-optimized model. - """ - q_model = model._model - sq_max_info = model.sq_max_info - if sq_max_info: - from .torch_utils.smooth_quant import TorchSmoothQuant - tsq = TorchSmoothQuant(q_model, None) - alpha = tune_cfg['recipe_cfgs']['smooth_quant_args']['alpha'] - for op_name, info in sq_max_info.items(): - if alpha == 'auto': - alpha = info['alpha'] - absorb_layer = op_name - absorbed_layer = info['absorbed_layer'] - input_minmax = info['input_minmax'] - weight_max = info['weight_max'] - abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1])) - input_power = torch.pow(abs_input_max, alpha) - weight_power = torch.pow(weight_max, 1 - alpha) - scale = torch.clip(input_power / weight_power, min=1e-5) - with torch.no_grad(): - if recover: - scale = 1.0 / scale - for layer in absorbed_layer: - tsq._scale_layer_weight(layer, scale) - tsq._absorb_scales(absorb_layer, 1.0/scale) - logger.debug(f"Current smoothquant scale of {op_name} is {scale}, alpha is {alpha}") + op = key[:key.rfind('.')] + op = op.replace('._packed_params', '') - def qdq_quantize(self, model, tune_cfg): - """insert quant, dequant pairs before linear to simulate quantization. + if op in op_list: + if op in ret['weight']: + ret['weight'][op].update({ + key: + dequantize(state_dict[key]).numpy() + if state_dict[key].is_quantized else state_dict[key].detach().numpy() + }) + else: + ret['weight'][op] = { + key: + dequantize(state_dict[key]).numpy() + if state_dict[key].is_quantized else state_dict[key].detach().numpy() + } + else: + if bool(self.fused_dict): + if is_quantized: + for a in fp32_int8_map: + if op == a: + tensor_name = fp32_int8_map[a]['weight'] + if tensor_name in ret['weight']: + ret['weight'][tensor_name].update({ + key: + dequantize(state_dict[key]).numpy() + if state_dict[key].is_quantized else + state_dict[key].detach().numpy() + }) + else: + ret['weight'][tensor_name] = \ + {key: dequantize(state_dict[key]).numpy() + if state_dict[key].is_quantized else + state_dict[key].detach().numpy()} + break + else: + ret['weight'] = None - Args: - model (torch.nn.Module): smoothquant optimized model. - tune_cfg (dict): quantization config. + if save_to_disk: + if not save_path: + save_path = self.workspace_path + dump_data_to_local(ret, save_path, 'inspect_result.pkl') - Returns: - model: qdq quantized model. - """ - q_model = model._model - from .torch_utils.util import fetch_module, set_module - from .torch_utils.model_wrapper import QDQLinear, SQLinearWrapper - smoothquant_scale_info = {} - fallback_op_name_list = [] - stats_result = {} - for (op_name, op_type), qconfig in tune_cfg['op'].items(): - if op_type == 'Linear' and qconfig['weight']['dtype'] != 'int8': - fallback_op_name_list.append(op_name) + return ret - sq_max_info = model.sq_max_info - if sq_max_info: - assert not q_model._smoothquant_optimized, \ - "The model is already optimized by smoothquant, cannot apply new alpha." - for _, info in sq_max_info.items(): - alpha = info['alpha'] - absorbed_layer = info['absorbed_layer'] - input_minmax = info['input_minmax'] - weight_max = info['weight_max'] - abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1])) - input_power = torch.pow(abs_input_max, alpha) - weight_power = torch.pow(weight_max, 1 - alpha) - scale = torch.clip(input_power / weight_power, min=1e-5) - for op_name in absorbed_layer: - module = fetch_module(q_model, op_name) - new_module = SQLinearWrapper(module, 1.0/scale, input_minmax, alpha) - set_module(q_model, op_name, new_module) - logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}") + def _pre_eval_hook(self, model, op_list=None, iteration_list=None): + """The function is used to do some preprocession before evaluation phase. + Here, it used to add hook for dump output tensor for quantizable ops. - smoothquant_op_info = {'sq_linear': {}, 'qdq_linear': []} - stats_result['SQLinearWrapper'] = {'INT8(QDQ)': 0, 'BF16': 0, 'FP32': 0} - for name, module in q_model.named_modules(): - if isinstance(module, SQLinearWrapper): - smoothquant_op_info['sq_linear'][name] = module.input_scale - if name not in fallback_op_name_list: - smoothquant_scale_info[name] = { - 'input_scale_for_mul': module.input_scale, - 'quant_scale': module.scale, - 'quant_zero_point': module.zero_point, - 'quant_dtype': module.dtype, - } - smoothquant_op_info['qdq_linear'].append(name+'.sq_linear') - new_module = QDQLinear(module.sq_linear, module.scale, module.zero_point, module.dtype) - set_module(q_model, name+'.sq_linear', new_module) - stats_result['SQLinearWrapper']['INT8(QDQ)'] += 1 - else: - stats_result['SQLinearWrapper']['FP32'] += 1 + Args: + model (object): input model - tune_cfg['recipe_cfgs']['smoothquant_op_info'] = smoothquant_op_info - model._model = q_model - model.q_config = copy.deepcopy(tune_cfg) - field_names=["Op Type", "Total", "INT8", "BF16", "FP32"] - output_data = [[ - op_type, sum(stats_result[op_type].values()), stats_result[op_type]['INT8(QDQ)'], - stats_result[op_type]['BF16'], stats_result[op_type]['FP32']] - for op_type in stats_result.keys()] - Statistics(output_data, - header='Mixed Precision Statistics', - field_names=field_names).print_stat() + Returns: + model (object): model with hook + """ + from abc import ABCMeta - return model + def _with_args(cls_or_self, **kwargs): + r"""Wrapper that allows creation of class factories. + This can be useful when there is a need to create classes with the same + constructor arguments, but different instances. -unify_op_type_mapping = { - "ConvReLU2d": "Conv2d", - "ConvReLU3d": "Conv3d", - "LinearReLU": "Linear", - "ConvBn2d": "Conv2d", - "ConvBnReLU2d": "Conv2d" -} + Example:: + >>> Foo.with_args = classmethod(_with_args) + >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42) + >>> foo_instance1 = foo_builder() + >>> foo_instance2 = foo_builder() + >>> id(foo_instance1) == id(foo_instance2) + False + """ + class _PartialWrapper(object): + def __init__(self, p): + self.p = p -@adaptor_registry -class PyTorchAdaptor(TemplateAdaptor): - """Adaptor of PyTorch framework, all PyTorch API is in this class. + def __call__(self, *args, **keywords): + return self.p(*args, **keywords) - Args: - framework_specific_info (dict): dictionary of tuning configure from yaml file. - """ - def __init__(self, framework_specific_info): - super(PyTorchAdaptor, self).__init__(framework_specific_info) - """ - # Map for swapping float module to quantized ones, - # and this dictionary will change with different PoTorch versions - DEFAULT_MODULE_MAPPING = { - nn.Linear: nnq.Linear, - nn.ReLU: nnq.ReLU, - nn.ReLU6: nnq.ReLU6, - nn.Conv2d: nnq.Conv2d, - nn.Conv3d: nnq.Conv3d, - QuantStub: nnq.Quantize, - DeQuantStub: nnq.DeQuantize, - # Wrapper Modules: - nnq.FloatFunctional: nnq.QFunctional, - # Intrinsic modules: - nni.ConvReLU2d: nniq.ConvReLU2d, - nni.ConvReLU3d: nniq.ConvReLU3d, - nni.LinearReLU: nniq.LinearReLU, - nniqat.ConvReLU2d: nniq.ConvReLU2d, - nniqat.LinearReLU: nniq.LinearReLU, - nniqat.ConvBn2d: nnq.Conv2d, - nniqat.ConvBnReLU2d: nniq.ConvReLU2d, - # QAT modules: - nnqat.Linear: nnq.Linear, - nnqat.Conv2d: nnq.Conv2d, - } - """ + def __repr__(self): + return self.p.__repr__() - self.tune_cfg = None - if self.device == "cpu": - query_config_file = "pytorch_cpu.yaml" - elif self.device == "gpu": - query_config_file = "pytorch_gpu.yaml" - else: # pragma: no cover - assert False, "Unsupport this device {}".format(self.device) - self.query_handler = PyTorchQuery( - local_config_file=os.path.join(os.path.dirname(__file__), query_config_file)) + with_args = _with_args - self.white_list = get_torch_white_list(self.approach) + r = _PartialWrapper(partial(cls_or_self, **kwargs)) + return r - # for tensorboard - self.dump_times = 0 - self.fused_dict = {} + ABC = ABCMeta(str("ABC"), (object, ), {}) # compatible with Python 2 *and* 3: - self.optype_statistics = None + class _RecordingObserver(ABC, torch.nn.Module): + """The module is mainly for debug and records the tensor values during runtime. - @dump_elapsed_time("Pass quantize model") - def quantize(self, tune_cfg, model, dataloader, q_func=None): - """Execute the quantize process on the specified model. + Args: + iteration_list (list, optional): indexs of iteration which to dump tensor. + """ + def __init__(self, iteration_list=None, **kwargs): + super(_RecordingObserver, self).__init__(**kwargs) + self.output_tensors_dict = OrderedDict() + self.current_iter = 1 + self.iteration_list = iteration_list - Args: - tune_cfg (dict): quantization config. - model (object): model need to do quantization. - dataloader (object): calibration dataset. - q_func (objext, optional): training function for quantization aware training mode. + def forward(self, x): + if (self.iteration_list is None and self.current_iter == 1) or \ + (self.iteration_list is not None and + self.current_iter in self.iteration_list): + if type(x) is tuple or type(x) is list: + self.output_tensors_dict[self.current_iter] = \ + [i.to("cpu") if i.device != 'cpu' else i.clone() for i in x] + else: + self.output_tensors_dict[self.current_iter] = \ + x.to("cpu") if x.device != "cpu" else x.clone() + self.current_iter += 1 + return x - Returns: - (object): quantized model - """ - assert isinstance(model._model, torch.nn.Module), \ - "The model passed in is not the instance of torch.nn.Module" - if self.performance_only: - q_model = model - else: - try: - q_model = copy.deepcopy(model) - except Exception as e: # pragma: no cover - logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format( - repr(e))) - q_model = model + @torch.jit.export + def get_tensor_value(self): + return self.output_tensors_dict - # For smoothquant optimized model - recipe_cfgs = tune_cfg.get('recipe_cfgs', None) - if recipe_cfgs and recipe_cfgs.get('smooth_quant', False) \ - and not recipe_cfgs['smooth_quant_args']['folding'] \ - and self.approach != 'post_training_dynamic_quant': - return self.qdq_quantize(q_model, tune_cfg) + with_args = classmethod(_with_args) - if recipe_cfgs and recipe_cfgs.get('smooth_quant', False) \ - and recipe_cfgs['smooth_quant_args']['folding']: - self._apply_pre_optimization(q_model, tune_cfg) + def _observer_forward_hook(module, input, output): + """Forward hook that calls observer on the output - # For tensorboard display - self.tune_cfg = tune_cfg - self.tune_cfg["approach"] = self.approach - self.tune_cfg["reduce_range"] = REDUCE_RANGE - self.tune_cfg["framework"] = "pytorch" - op_cfgs = _cfg_to_qconfig(tune_cfg, self.approach) - self.tune_cfg['bf16_ops_list'] = op_cfgs['bf16_ops_list'] - del op_cfgs['bf16_ops_list'] - gc.collect() + Args: + module (object): input module + input (object): module input + output (object): module output - if self.version.release < Version("2.0.0").release: - from torch.quantization.quantize import add_observer_ - else: - from torch.quantization.quantize import _add_observer_ as add_observer_ + Returns: + module output tensor (object) + """ + return module.activation_post_process(output) - if self.approach == 'quant_aware_training': - q_model._model.train() - else: - q_model._model.eval() - if self.version.release < Version("1.7.0").release or \ - self.approach != 'quant_aware_training': - _propagate_qconfig(q_model._model, op_cfgs, approach=self.approach) - # sanity check common API misusage - if not any(hasattr(m, 'qconfig') and m.qconfig for m in q_model._model.modules()): - logger.warn("None of the submodule got qconfig applied. Make sure you " - "passed correct configuration through `qconfig_dict` or " - "by assigning the `.qconfig` attribute directly on submodules.") + def _add_observer_(module, op_list=None, prefix=""): + """Add observer for the leaf child of the module. - if self.approach in ['post_training_static_quant', 'post_training_auto_quant']: - add_observer_(q_model._model) - if q_func is None: - iterations = tune_cfg.get('calib_iteration', 1) - self.model_calibration(q_model._model, - dataloader, - iterations, - calib_sampling_size=tune_cfg.get('calib_sampling_size', 1)) - else: - q_func(q_model._model) - elif self.approach == 'quant_aware_training': - if self.version.release >= Version("1.7.0").release: - _propagate_qconfig(q_model._model, op_cfgs, is_qat_convert=True) - torch.quantization.convert(q_model._model, - mapping=self.q_mapping, - inplace=True, - remove_qconfig=False) - _propagate_qconfig(q_model._model, op_cfgs) - add_observer_(q_model._model, self.white_list, - set(self.q_mapping.values())) - else: # pragma: no cover - add_observer_(q_model._model) - torch.quantization.convert(q_model._model, self.q_mapping, inplace=True) - # q_func can be created by neural_compressor internal or passed by user. It's critical to - # distinguish how q_func is passed since neural_compressor built-in functions accept neural_compressor - # model and user defined func should accept framework model. - q_model._model = q_func( - q_model if getattr(q_func, 'builtin', None) else q_model._model) - assert q_model._model is not None, "Please return a trained model in train function!" - q_model._model.eval() + This function insert observer module to all leaf child module that + has a valid qconfig attribute. - if self.approach == 'quant_aware_training': - torch.quantization.convert(q_model._model, inplace=True) - else: - torch.quantization.convert(q_model._model, mapping=self.q_mapping, inplace=True) + Args: + module (object): input module with qconfig attributes for all the leaf modules that + we want to dump tensor + op_list (list, optional): list of ops which to be dumped in module + prefix (string): name of module - if len(self.tune_cfg['bf16_ops_list']) > 0 and \ - (self.version.release >= Version("1.11.0").release) and \ - (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'): # pragma: no cover - q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg) + Returns: + None, module is modified inplace with added observer modules and forward_hooks + """ + for name, child in module.named_children(): + op_name = name if prefix == "" else prefix + "." + name + if isinstance(child, torch.nn.quantized.FloatFunctional) and \ + (op_list is None or op_name in op_list): + if hasattr(child, 'qconfig') and child.qconfig is not None and ( + op_list is None or op_name in op_list): + child.activation_post_process = \ + child.qconfig.activation() + elif hasattr(child, 'qconfig') and child.qconfig is not None and \ + (op_list is None or op_name in op_list): + # observer and hook will be gone after we swap the module + child.add_module('activation_post_process', child.qconfig.activation()) + child.register_forward_hook(_observer_forward_hook) + else: + _add_observer_(child, op_list, op_name) - q_model.q_config = copy.deepcopy(self.tune_cfg) - if self.approach != 'post_training_dynamic_quant': - self._get_scale_zeropoint(q_model._model, q_model.q_config) - q_model.is_quantized = True + def _propagate_qconfig_helper(module, + qconfig_dict, + white_list=None, + qconfig_parent=None, + prefix='', + fused=False): + """This is a helper function for `propagate_qconfig_` - self._dump_model_op_stats(q_model._model, q_model.q_config) - torch_utils.util.get_embedding_contiguous(q_model._model) - return q_model + Args: + module (object): input module + qconfig_dict (dictionary): dictionary that maps from name of submodule to + quantization configuration + white_list (list, optional): list of quantizable modules + qconfig_parent (object, optional): config of parent module, we will fallback to + this config when there is no specified config + for current module + prefix (string, optional): corresponding prefix of the current module, + used as key in qconfig_dict + fused (bool, optional): Indicates whether the module is fused or not - def evaluate(self, - model, - dataloader, - postprocess=None, - metrics=None, - measurer=None, - iteration=-1, - tensorboard=False, - fp32_baseline=False): - """Execute the evaluate process on the specified model. + Return: + None, module is modified inplace with qconfig attached + """ + module.qconfig = qconfig_parent + if hasattr(module, '_modules'): + for name, child in module.named_children(): + module_prefix = prefix + '.' + name if prefix else name + _propagate_qconfig_helper(child, qconfig_dict, white_list, qconfig_parent, + module_prefix) - Args: - model (object): model to run evaluation. - dataloader (object): evaluation dataset. - postprocess (object, optional): process function after evaluation. - metrics (list, optional): list of metric function. - measurer (object, optional): measurer function. - iteration (int, optional): number of iterations to evaluate. - tensorboard (bool, optional): dump output tensor to tensorboard summary files. - fp32_baseline (boolen, optional): only for compare_label=False pipeline + def _prepare(model, inplace=True, op_list=[], white_list=None): + """The model will be attached with observer or fake quant modules, and qconfig + will be propagated. - Returns: - (object): accuracy - """ - self.is_baseline = fp32_baseline - if tensorboard: - model = self._pre_eval_hook(model) + Args: + model (object): input model to be modified in-place + inplace (bool, optional): carry out model transformations in-place, + the original module is mutated + op_list (list, optional): list of ops which to be dumped in module + white_list (list, optional): list of quantizable modules - model_ = model._model - assert isinstance( - model_, torch.nn.Module), "The model passed in is not the instance of torch.nn.Module" - model_.eval() - if self.device == "cpu": - model_.to("cpu") - elif self.device == "gpu": - if self.is_baseline: - model_.to("dpcpp") + Returns: + model (object): model with qconfig + """ + if not inplace: + model = copy.deepcopy(model) + _propagate_qconfig_helper(model, + qconfig_dict={}, + white_list=white_list, + qconfig_parent=model.qconfig) + # sanity check common API misusage + if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()): # pragma: no cover + logger.warn("None of the submodule got qconfig applied. Make sure you " + "passed correct configuration through `qconfig_dict` or " + "by assigning the `.qconfig` attribute directly on submodules") + _add_observer_(model, op_list=op_list) + return model - if metrics: - self.fp32_preds_as_label = any([hasattr(metric, "compare_label") and \ - not metric.compare_label for metric in metrics]) - acc = self.model_eval(model_, dataloader, postprocess, metrics, measurer, iteration) + model = model if model.is_quantized else copy.deepcopy(model) + model._model.qconfig = torch.quantization.QConfig( + weight=torch.quantization.default_debug_observer, + activation=_RecordingObserver.with_args(iteration_list=iteration_list)) + _prepare(model._model, op_list=op_list) - if tensorboard: - self._post_eval_hook(model, accuracy=acc) - return acc if not isinstance(acc, list) or len(acc) > 1 else acc[0] + return model - def _pre_hook_for_qat(self, dataloader=None): - # self.model._model is needed here. - self.model._model.qconfig = torch.quantization.QConfig( - activation=torch.quantization.FakeQuantize.with_args(dtype=torch.quint8, - qscheme=torch.per_tensor_affine, - reduce_range=REDUCE_RANGE), - weight=torch.quantization.default_weight_fake_quant) - self.non_quant_dict = self.get_non_quant_modules(self.model.kwargs) - quantizable_ops = [] - self._get_quantizable_ops_recursively(self.model._model, '', quantizable_ops) - bf16_ops = [] - if self.version.release >= Version("1.11.0").release and self.use_bf16 and \ - (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'): # pragma: no cover - self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16") - self._get_bf16_ops_recursively(self.model._model, '', bf16_ops) - bf16_ops_list = [(op) for op in bf16_ops if op not in quantizable_ops] - self.model.model.training = True - torch.quantization.prepare_qat(self.model._model, inplace=True) + def is_fused_module(self, module): + """This is a helper function for `_propagate_qconfig_helper` to detecte + if this module is fused. - # This is a flag for reloading - self.model.q_config = { - 'is_oneshot': True, - 'framework': 'pytorch', - 'reduce_range': REDUCE_RANGE, - 'approach': 'quant_aware_training', - 'bf16_ops_list': bf16_ops_list, - } + Args: + module (object): input module - def _post_hook_for_qat(self): - torch.quantization.convert(self.model._model, inplace=True) - if self.model.q_config is not None and len(self.model.q_config['bf16_ops_list']) > 0 and \ - self.version.release >= Version("1.11.0").release and self.use_bf16 and \ - (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'): # pragma: no cover - self.model._model = torch_utils.bf16_convert.Convert(self.model._model, self.model.q_config) + Returns: + (bool): is fused or not + """ + op_type = str(type(module)) + if 'fused' in op_type: + return True + else: + return False - def _pre_hook_for_hvd(self, dataloader=None): - # TODO: lazy init here - hvd.init() - hvd.broadcast_parameters(self.model._model.state_dict(), root_rank=0) - hvd.broadcast_optimizer_state(self.optimizer, root_rank=0) - self.optimizer = hvd.DistributedOptimizer( - self.optimizer, named_parameters=self.model._model.named_parameters()) + def calculate_hessian_trace(self, + fp32_model, + dataloader, + q_model, + criterion, + enable_act=False + ): + """Calculate hessian trace. - def train(self, model, dataloader, optimizer_tuple, criterion_tuple, hooks, **kwargs): - """Execute the train process on the specified model. + Args: + fp32_model: The original fp32 model. + criterion: The loss function for calculate the hessian trace. # loss = criterion(output, target) + dataloader: The dataloader for calculate the gradient. + q_model: The INT8 AMAP model. + enable_act: Enabling quantization error or not. + + Return: + hessian_trace(Dict[Tuple, float]), key: (op_name, op_type); value: hessian trace. + """ + from .torch_utils.hawq_metric import hawq_top + op_to_traces = hawq_top(fp32_model=fp32_model, + dataloader=dataloader, + q_model=q_model, + criterion=criterion, + enable_act=enable_act) + return op_to_traces + + def smooth_quant(self, model, dataloader, calib_iter, alpha=0.5, folding=False, + percentile=None, op_types=None, scales_per_op=None, force_re_smooth=False, + record_max_info=False): + """ convert the model by smooth quant. Args: - model (object): model to run evaluation. - dataloader (object): training dataset. - optimizer (tuple): It is a tuple of (cls, parameters) for optimizer. - criterion (tuple): It is a tuple of (cls, parameters) for criterion. - kwargs (dict, optional): other parameters. + model: origin FP32 model + dataloader: calib dataloader + calib_iter: calib iters + alpha: smooth alpha in SmoothQuant, 1.0 will fallback to SPIQ + folding: whether insert mul(False) or just allow foldable layers(True) for SmoothQuant + percentile:Percentile of calibration to remove outliers, not supported now + op_types: The op types whose input tensor will be dumped + scales_per_op: True, each op will have an individual scale, mainly for accuracy + False, ops with the same input will share a scale, mainly for performance + record_max_info: whether record the max info in model for alpha tuning. Returns: - None + model: A modified fp32 model, inplace=True. """ - model_ = model._model - device = "cuda:0" if self.device != "GPU" and torch.cuda.is_available() else self.device - # self.model is set to neural_compressor model here to hold the inplace change in FWK model. - self.model = model - optimizer = optimizer_tuple[0](model_.parameters(), **optimizer_tuple[1]) - self.optimizer = optimizer - criterion = criterion_tuple[0](**criterion_tuple[1]) - start_epochs = kwargs['kwargs']['start_epoch'] - end_epochs = kwargs['kwargs']['end_epoch'] - iters = kwargs['kwargs']['iteration'] - if hooks is not None: - on_train_begin = hooks['on_train_begin'] - on_train_end = hooks['on_train_end'] - on_epoch_begin = hooks['on_epoch_begin'] - on_epoch_end = hooks['on_epoch_end'] - on_step_begin = hooks['on_step_begin'] - on_step_end = hooks['on_step_end'] - on_after_compute_loss = hooks['on_after_compute_loss'] - on_before_optimizer_step = hooks['on_before_optimizer_step'] - if hooks is not None: - on_train_begin() - for nepoch in range(start_epochs, end_epochs): - model_.to(device) - model_.train() - cnt = 0 - if hooks is not None: - on_epoch_begin(nepoch) - if getattr(dataloader, 'distributed', False) \ - or isinstance(dataloader.sampler, \ - torch.utils.data.distributed.DistributedSampler): - dataloader.sampler.set_epoch(nepoch) - for image, target in dataloader: - # TODO: to support adjust lr with epoch - target = target.to(device) - if hooks is not None: - on_step_begin(cnt) - print('.', end='', flush=True) - cnt += 1 - output = pytorch_forward_wrapper(model_, image, device=device) - loss = criterion(output, target) - if hooks is not None: - loss = on_after_compute_loss(image, output, loss) - self.optimizer.zero_grad() - loss.backward() - if hooks is not None: - on_before_optimizer_step() - self.optimizer.step() - if hooks is not None: - on_step_end() - if cnt >= iters: - break - if hooks is not None: - on_epoch_end() + # Note: we should make sure smoothquant is only executed once with inplacing fp32 model. + if hasattr(model._model, '_smoothquant_optimized') and model._model._smoothquant_optimized: + logger.info("The model is already optimized by SmoothQuant algorithm, skip it.") + return model + if self.__class__.__name__ == 'PyTorch_IPEXAdaptor' and self.version.release < \ + Version("2.1").release: + if folding is None: + folding = True + logger.info( + "IPEX version >= 2.1 is required for SmoothQuant folding=False, reset folding=True.") + else: + assert folding, "IPEX version >= 2.1 is required for SmoothQuant folding=False." - if device != self.device: # pragma: no cover - model_.to(self.device) + if not hasattr(self, 'sq') or force_re_smooth: + from .torch_utils.smooth_quant import TorchSmoothQuant + self.sq = TorchSmoothQuant(model._model, dataloader=dataloader, + example_inputs=self.example_inputs, q_func=self.q_func) + kwargs = {} ## different backends may have different default values + self.sq.record_max_info = record_max_info # whether record the max info of input and weight. + if op_types != None: + kwargs["op_types"] = op_types + if percentile != None: + kwargs['percentile'] = percentile + if scales_per_op != None: + kwargs['scales_per_op'] = scales_per_op + model._model = self.sq.transform( + alpha=alpha, + folding=folding, + calib_iter=calib_iter, + **kwargs + ) + if self.sq.record_max_info: + model.sq_max_info = self.sq.max_value_info + return model - if hooks is not None: - on_train_end() + def _apply_pre_optimization(self, model, tune_cfg, recover=False): + """update model parameters based on tune_cfg. - return model_ + Args: + model (torch.nn.Module): smoothquant optimized model. + tune_cfg (dict): optimization config. + recover (dict): recover pre-optimization change. + + Returns: + model: pre-optimized model. + """ + q_model = model._model + sq_max_info = model.sq_max_info + if sq_max_info: + from .torch_utils.smooth_quant import TorchSmoothQuant + tsq = TorchSmoothQuant(q_model, None) + alpha = tune_cfg['recipe_cfgs']['smooth_quant_args']['alpha'] + for op_name, info in sq_max_info.items(): + if alpha == 'auto': + alpha = info['alpha'] + absorb_layer = op_name + absorbed_layer = info['absorbed_layer'] + input_minmax = info['input_minmax'] + weight_max = info['weight_max'] + abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1])) + input_power = torch.pow(abs_input_max, alpha) + weight_power = torch.pow(weight_max, 1 - alpha) + scale = torch.clip(input_power / weight_power, min=1e-5) + with torch.no_grad(): + if recover: + scale = 1.0 / scale + for layer in absorbed_layer: + tsq._scale_layer_weight(layer, scale) + tsq._absorb_scales(absorb_layer, 1.0/scale) + logger.debug(f"Current smoothquant scale of {op_name} is {scale}, alpha is {alpha}") + + def qdq_quantize(self, model, tune_cfg): + """insert quant, dequant pairs before linear to simulate quantization. - def _dump_model_op_stats(self, model, tune_cfg): - """This is a function to dump quantizable ops of model to user. Args: - model (object): input model - tune_cfg (dict): quantization config + model (torch.nn.Module): smoothquant optimized model. + tune_cfg (dict): quantization config. + Returns: - None + model: qdq quantized model. """ - res = {} - ignore_log = False - modules = dict(model.named_modules()) - # fetch quantizable ops supported in Neural Compressor from tune_cfg - for key in tune_cfg['op']: - op_name = key[0] - op_type = str(type(modules[op_name])).rstrip('\'>').split('.')[-1] - if op_type == 'BF16ModuleWrapper': # pragma: no cover - op_type = str(type(modules[op_name].module)).rstrip('\'>').split('.')[-1] - if op_type == 'DequantQuantWrapper': - op_type = str(type(modules[op_name].module)).rstrip('\'>').split('.')[-1] - if 'Functional' in op_type: - op_type = op_name.split('.')[-1] - if op_type not in res.keys(): - res[op_type] = {'INT8': 0, 'BF16': 0, 'FP32': 0} - value = tune_cfg['op'][key] - # Special cases: QuantStub, Embedding - if ('weight' in value and value['weight']['dtype'] == 'fp32') or \ - ('weight' not in value and value['activation']['dtype'] == 'fp32'): - res[op_type]['FP32'] += 1 - elif value['activation']['dtype'] == 'bf16': # pragma: no cover - res[op_type]['BF16'] += 1 - else: - res[op_type]['INT8'] += 1 - # fetch other quantizable ops supported in PyTorch from model - for name, child in modules.items(): - op_type = str(type(child)).rstrip('\'>').split('.')[-1] - if tune_cfg['approach'] != 'post_training_dynamic_quant': - if op_type == 'DeQuantize': - if op_type not in res.keys(): - res[op_type] = {'INT8': 0, 'BF16': 0, 'FP32': 0} - res[op_type]['INT8'] += 1 - if op_type in self.non_quant_dict['skipped_module_classes']: - ignore_log = True - if op_type not in res.keys(): - res[op_type] = {'INT8': 0, 'BF16': 0, 'FP32': 0} - res[op_type]['FP32'] += 1 - # show results to users - if ignore_log: - logger.info("Ignore LayerNorm, InstanceNorm3d and Embedding quantizable ops" \ - " due to accuracy issue in PyTorch.") + q_model = model._model + from .torch_utils.util import fetch_module, set_module + from .torch_utils.model_wrapper import QDQLinear, SQLinearWrapper + smoothquant_scale_info = {} + fallback_op_name_list = [] + stats_result = {} + for (op_name, op_type), qconfig in tune_cfg['op'].items(): + if op_type == 'Linear' and qconfig['weight']['dtype'] != 'int8': + fallback_op_name_list.append(op_name) + + sq_max_info = model.sq_max_info + if sq_max_info: + assert not q_model._smoothquant_optimized, \ + "The model is already optimized by smoothquant, cannot apply new alpha." + for _, info in sq_max_info.items(): + alpha = info['alpha'] + absorbed_layer = info['absorbed_layer'] + input_minmax = info['input_minmax'] + weight_max = info['weight_max'] + abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1])) + input_power = torch.pow(abs_input_max, alpha) + weight_power = torch.pow(weight_max, 1 - alpha) + scale = torch.clip(input_power / weight_power, min=1e-5) + for op_name in absorbed_layer: + module = fetch_module(q_model, op_name) + new_module = SQLinearWrapper(module, 1.0/scale, input_minmax, alpha) + set_module(q_model, op_name, new_module) + logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}") + + smoothquant_op_info = {'sq_linear': {}, 'qdq_linear': []} + stats_result['SQLinearWrapper'] = {'INT8(QDQ)': 0, 'BF16': 0, 'FP32': 0} + for name, module in q_model.named_modules(): + if isinstance(module, SQLinearWrapper): + smoothquant_op_info['sq_linear'][name] = module.input_scale + if name not in fallback_op_name_list: + smoothquant_scale_info[name] = { + 'input_scale_for_mul': module.input_scale, + 'quant_scale': module.scale, + 'quant_zero_point': module.zero_point, + 'quant_dtype': module.dtype, + } + smoothquant_op_info['qdq_linear'].append(name+'.sq_linear') + new_module = QDQLinear(module.sq_linear, module.scale, module.zero_point, module.dtype) + set_module(q_model, name+'.sq_linear', new_module) + stats_result['SQLinearWrapper']['INT8(QDQ)'] += 1 + else: + stats_result['SQLinearWrapper']['FP32'] += 1 + tune_cfg['recipe_cfgs']['smoothquant_op_info'] = smoothquant_op_info + model._model = q_model + model.q_config = copy.deepcopy(tune_cfg) field_names=["Op Type", "Total", "INT8", "BF16", "FP32"] output_data = [[ - op_type, sum(res[op_type].values()), - res[op_type]['INT8'], res[op_type]['BF16'], res[op_type]['FP32']] - for op_type in res.keys()] - + op_type, sum(stats_result[op_type].values()), stats_result[op_type]['INT8(QDQ)'], + stats_result[op_type]['BF16'], stats_result[op_type]['FP32']] + for op_type in stats_result.keys()] Statistics(output_data, header='Mixed Precision Statistics', field_names=field_names).print_stat() - self.optype_statistics = field_names, output_data + return model - def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops): - """This is a helper function for `query_fw_capability`, - and it will get all quantizable ops from model. - Args: - model (object): input model - prefix (string): prefix of op name - quantizable_ops (list): list of quantizable ops from model include op name and type. +unify_op_type_mapping = { + "ConvReLU2d": "Conv2d", + "ConvReLU3d": "Conv3d", + "LinearReLU": "Linear", + "ConvBn2d": "Conv2d", + "ConvBnReLU2d": "Conv2d" +} - Returns: - None - """ - module_dict = dict(model.named_modules()) - for op_name, child in model.named_modules(): - if self.is_fused_module(child): - for name, _ in child.named_children(): - module_prefix = op_name + '.' + name - if module_prefix in module_dict: - module_dict.pop(module_prefix) # remove sub-modules of fused modules - if op_name in self.fused_dict: - self.fused_dict[op_name] = [self.fused_dict[op_name], module_prefix] - else: - self.fused_dict[op_name] = module_prefix - for op_name, child in module_dict.items(): - # there is accuracy issue in quantized LayerNorm op in pytorch <1.8.1, - # so remove it here - if op_name in self.non_quant_dict['skipped_module_names'] or \ - str(child.__class__.__name__) in \ - self.non_quant_dict['skipped_module_classes']: - continue - if type(child) in self.white_list and type(child) != torch.nn.Sequential and \ - type(child) != torch.quantization.stubs.DeQuantStub: - quantizable_ops.append( - (op_name, unify_op_type_mapping[str(child.__class__.__name__)] - if str(child.__class__.__name__) in unify_op_type_mapping else str( - child.__class__.__name__))) +@adaptor_registry +class PyTorchAdaptor(TemplateAdaptor): + """Adaptor of PyTorch framework, all PyTorch API is in this class. - def _get_scale_zeropoint(self, model, tune_cfg): - """get activation scale and zero_point for converted model. + Args: + framework_specific_info (dict): dictionary of tuning configure from yaml file. + """ + def __init__(self, framework_specific_info): + super(PyTorchAdaptor, self).__init__(framework_specific_info) + """ + # Map for swapping float module to quantized ones, + # and this dictionary will change with different PoTorch versions + DEFAULT_MODULE_MAPPING = { + nn.Linear: nnq.Linear, + nn.ReLU: nnq.ReLU, + nn.ReLU6: nnq.ReLU6, + nn.Conv2d: nnq.Conv2d, + nn.Conv3d: nnq.Conv3d, + QuantStub: nnq.Quantize, + DeQuantStub: nnq.DeQuantize, + # Wrapper Modules: + nnq.FloatFunctional: nnq.QFunctional, + # Intrinsic modules: + nni.ConvReLU2d: nniq.ConvReLU2d, + nni.ConvReLU3d: nniq.ConvReLU3d, + nni.LinearReLU: nniq.LinearReLU, + nniqat.ConvReLU2d: nniq.ConvReLU2d, + nniqat.LinearReLU: nniq.LinearReLU, + nniqat.ConvBn2d: nnq.Conv2d, + nniqat.ConvBnReLU2d: nniq.ConvReLU2d, + # QAT modules: + nnqat.Linear: nnq.Linear, + nnqat.Conv2d: nnq.Conv2d, + } + """ - Args: - model (dir): Int8 model converted from fp32 model. - scale and zero_point is set with calibration for each module - tune_cfg (object): This file saves scale and zero_point of \ - output activation of each quantized module. + self.tune_cfg = None + if self.device == "cpu": + query_config_file = "pytorch_cpu.yaml" + elif self.device == "gpu": + query_config_file = "pytorch_gpu.yaml" + else: # pragma: no cover + assert False, "Unsupport this device {}".format(self.device) + self.query_handler = PyTorchQuery( + local_config_file=os.path.join(os.path.dirname(__file__), query_config_file)) - Returns: - None - """ - modules = dict(model.named_modules()) - for key, value in tune_cfg['op'].items(): - if hasattr(modules[key[0]], 'scale'): - value['activation']['scale'] = float(modules[key[0]].scale) - if hasattr(modules[key[0]], 'zero_point'): - value['activation']['zero_point'] = int(modules[key[0]].zero_point) + self.white_list = get_torch_white_list(self.approach) - def _pre_eval_hook(self, model, op_list=None, iteration_list=None): - """The function is used to do some preprocession before evaluation phase. - Here, it used to add hook for dump output tensor for quantizable ops. + # for tensorboard + self.dump_times = 0 + + self.optype_statistics = None + + @dump_elapsed_time("Pass quantize model") + def quantize(self, tune_cfg, model, dataloader, q_func=None): + """Execute the quantize process on the specified model. Args: - model (object): input model + tune_cfg (dict): quantization config. + model (object): model need to do quantization. + dataloader (object): calibration dataset. + q_func (objext, optional): training function for quantization aware training mode. Returns: - model (object): model with hook + (object): quantized model """ - from abc import ABCMeta + assert isinstance(model._model, torch.nn.Module), \ + "The model passed in is not the instance of torch.nn.Module" + if self.performance_only: + q_model = model + else: + try: + q_model = copy.deepcopy(model) + except Exception as e: # pragma: no cover + logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format( + repr(e))) + q_model = model - def _with_args(cls_or_self, **kwargs): - r"""Wrapper that allows creation of class factories. + # For smoothquant optimized model + recipe_cfgs = tune_cfg.get('recipe_cfgs', None) + if recipe_cfgs and recipe_cfgs.get('smooth_quant', False) \ + and not recipe_cfgs['smooth_quant_args']['folding'] \ + and self.approach != 'post_training_dynamic_quant': + return self.qdq_quantize(q_model, tune_cfg) - This can be useful when there is a need to create classes with the same - constructor arguments, but different instances. + if recipe_cfgs and recipe_cfgs.get('smooth_quant', False) \ + and recipe_cfgs['smooth_quant_args']['folding']: + self._apply_pre_optimization(q_model, tune_cfg) - Example:: + # For tensorboard display + self.tune_cfg = tune_cfg + self.tune_cfg["approach"] = self.approach + self.tune_cfg["reduce_range"] = REDUCE_RANGE + self.tune_cfg["framework"] = "pytorch" + op_cfgs = _cfg_to_qconfig(tune_cfg, self.approach) + self.tune_cfg['bf16_ops_list'] = op_cfgs['bf16_ops_list'] + del op_cfgs['bf16_ops_list'] + gc.collect() - >>> Foo.with_args = classmethod(_with_args) - >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42) - >>> foo_instance1 = foo_builder() - >>> foo_instance2 = foo_builder() - >>> id(foo_instance1) == id(foo_instance2) - False - """ - class _PartialWrapper(object): - def __init__(self, p): - self.p = p + if self.version.release < Version("2.0.0").release: + from torch.quantization.quantize import add_observer_ + else: + from torch.quantization.quantize import _add_observer_ as add_observer_ - def __call__(self, *args, **keywords): - return self.p(*args, **keywords) + if self.approach == 'quant_aware_training': + q_model._model.train() + else: + q_model._model.eval() + if self.version.release < Version("1.7.0").release or \ + self.approach != 'quant_aware_training': + _propagate_qconfig(q_model._model, op_cfgs, approach=self.approach) + # sanity check common API misusage + if not any(hasattr(m, 'qconfig') and m.qconfig for m in q_model._model.modules()): + logger.warn("None of the submodule got qconfig applied. Make sure you " + "passed correct configuration through `qconfig_dict` or " + "by assigning the `.qconfig` attribute directly on submodules.") - def __repr__(self): - return self.p.__repr__() + if self.approach in ['post_training_static_quant', 'post_training_auto_quant']: + add_observer_(q_model._model) + if q_func is None: + iterations = tune_cfg.get('calib_iteration', 1) + self.model_calibration(q_model._model, + dataloader, + iterations, + calib_sampling_size=tune_cfg.get('calib_sampling_size', 1)) + else: + q_func(q_model._model) + elif self.approach == 'quant_aware_training': + if self.version.release >= Version("1.7.0").release: + _propagate_qconfig(q_model._model, op_cfgs, is_qat_convert=True) + torch.quantization.convert(q_model._model, + mapping=self.q_mapping, + inplace=True, + remove_qconfig=False) + _propagate_qconfig(q_model._model, op_cfgs) + add_observer_(q_model._model, self.white_list, + set(self.q_mapping.values())) + else: # pragma: no cover + add_observer_(q_model._model) + torch.quantization.convert(q_model._model, self.q_mapping, inplace=True) + # q_func can be created by neural_compressor internal or passed by user. It's critical to + # distinguish how q_func is passed since neural_compressor built-in functions accept neural_compressor + # model and user defined func should accept framework model. + q_model._model = q_func( + q_model if getattr(q_func, 'builtin', None) else q_model._model) + assert q_model._model is not None, "Please return a trained model in train function!" + q_model._model.eval() - with_args = _with_args + if self.approach == 'quant_aware_training': + torch.quantization.convert(q_model._model, inplace=True) + else: + torch.quantization.convert(q_model._model, mapping=self.q_mapping, inplace=True) - r = _PartialWrapper(partial(cls_or_self, **kwargs)) - return r + if len(self.tune_cfg['bf16_ops_list']) > 0 and \ + (self.version.release >= Version("1.11.0").release) and \ + (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'): # pragma: no cover + q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg) - ABC = ABCMeta(str("ABC"), (object, ), {}) # compatible with Python 2 *and* 3: + self.fused_dict = self.get_fused_list(q_model.model) + q_model.q_config = copy.deepcopy(self.tune_cfg) + if self.approach != 'post_training_dynamic_quant': + self._get_scale_zeropoint(q_model._model, q_model.q_config) + q_model.is_quantized = True - class _RecordingObserver(ABC, torch.nn.Module): - """The module is mainly for debug and records the tensor values during runtime. + self._dump_model_op_stats(q_model._model, q_model.q_config) + torch_utils.util.get_embedding_contiguous(q_model._model) + return q_model - Args: - iteration_list (list, optional): indexs of iteration which to dump tensor. - """ - def __init__(self, iteration_list=None, **kwargs): - super(_RecordingObserver, self).__init__(**kwargs) - self.output_tensors_dict = OrderedDict() - self.current_iter = 1 - self.iteration_list = iteration_list + def evaluate(self, + model, + dataloader, + postprocess=None, + metrics=None, + measurer=None, + iteration=-1, + tensorboard=False, + fp32_baseline=False): + """Execute the evaluate process on the specified model. - def forward(self, x): - if (self.iteration_list is None and self.current_iter == 1) or \ - (self.iteration_list is not None and - self.current_iter in self.iteration_list): - if type(x) is tuple or type(x) is list: - self.output_tensors_dict[self.current_iter] = \ - [i.to("cpu") if i.device != 'cpu' else i.clone() for i in x] - else: - self.output_tensors_dict[self.current_iter] = \ - x.to("cpu") if x.device != "cpu" else x.clone() - self.current_iter += 1 - return x + Args: + model (object): model to run evaluation. + dataloader (object): evaluation dataset. + postprocess (object, optional): process function after evaluation. + metrics (list, optional): list of metric function. + measurer (object, optional): measurer function. + iteration (int, optional): number of iterations to evaluate. + tensorboard (bool, optional): dump output tensor to tensorboard summary files. + fp32_baseline (boolen, optional): only for compare_label=False pipeline - @torch.jit.export - def get_tensor_value(self): - return self.output_tensors_dict + Returns: + (object): accuracy + """ + self.is_baseline = fp32_baseline + if tensorboard: + model = self._pre_eval_hook(model) - with_args = classmethod(_with_args) + model_ = model._model + assert isinstance( + model_, torch.nn.Module), "The model passed in is not the instance of torch.nn.Module" + model_.eval() + if self.device == "cpu": + model_.to("cpu") + elif self.device == "gpu": + if self.is_baseline: + model_.to("dpcpp") - def _observer_forward_hook(module, input, output): - """Forward hook that calls observer on the output + if metrics: + self.fp32_preds_as_label = any([hasattr(metric, "compare_label") and \ + not metric.compare_label for metric in metrics]) + acc = self.model_eval(model_, dataloader, postprocess, metrics, measurer, iteration) - Args: - module (object): input module - input (object): module input - output (object): module output + if tensorboard: + self._post_eval_hook(model, accuracy=acc) + return acc if not isinstance(acc, list) or len(acc) > 1 else acc[0] - Returns: - module output tensor (object) - """ - return module.activation_post_process(output) + def _pre_hook_for_qat(self, dataloader=None): + # self.model._model is needed here. + self.model._model.qconfig = torch.quantization.QConfig( + activation=torch.quantization.FakeQuantize.with_args(dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=REDUCE_RANGE), + weight=torch.quantization.default_weight_fake_quant) + self.non_quant_dict = self.get_non_quant_modules(self.model.kwargs) + quantizable_ops = [] + self._get_quantizable_ops_recursively(self.model._model, '', quantizable_ops) + bf16_ops = [] + if self.version.release >= Version("1.11.0").release and self.use_bf16 and \ + (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'): # pragma: no cover + self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16") + self._get_bf16_ops_recursively(self.model._model, '', bf16_ops) + bf16_ops_list = [(op) for op in bf16_ops if op not in quantizable_ops] + self.model.model.training = True + torch.quantization.prepare_qat(self.model._model, inplace=True) - def _add_observer_(module, op_list=None, prefix=""): - """Add observer for the leaf child of the module. + # This is a flag for reloading + self.model.q_config = { + 'is_oneshot': True, + 'framework': 'pytorch', + 'reduce_range': REDUCE_RANGE, + 'approach': 'quant_aware_training', + 'bf16_ops_list': bf16_ops_list, + } - This function insert observer module to all leaf child module that - has a valid qconfig attribute. + def _post_hook_for_qat(self): + torch.quantization.convert(self.model._model, inplace=True) + if self.model.q_config is not None and len(self.model.q_config['bf16_ops_list']) > 0 and \ + self.version.release >= Version("1.11.0").release and self.use_bf16 and \ + (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'): # pragma: no cover + self.model._model = torch_utils.bf16_convert.Convert(self.model._model, self.model.q_config) - Args: - module (object): input module with qconfig attributes for all the leaf modules that - we want to dump tensor - op_list (list, optional): list of ops which to be dumped in module - prefix (string): name of module + def _pre_hook_for_hvd(self, dataloader=None): + # TODO: lazy init here + hvd.init() + hvd.broadcast_parameters(self.model._model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(self.optimizer, root_rank=0) + self.optimizer = hvd.DistributedOptimizer( + self.optimizer, named_parameters=self.model._model.named_parameters()) - Returns: - None, module is modified inplace with added observer modules and forward_hooks - """ - for name, child in module.named_children(): - op_name = name if prefix == "" else prefix + "." + name - if isinstance(child, torch.nn.quantized.FloatFunctional) and \ - (op_list is None or op_name in op_list): - if hasattr(child, 'qconfig') and child.qconfig is not None and ( - op_list is None or op_name in op_list): - child.activation_post_process = \ - child.qconfig.activation() - elif hasattr(child, 'qconfig') and child.qconfig is not None and \ - (op_list is None or op_name in op_list): - # observer and hook will be gone after we swap the module - child.add_module('activation_post_process', child.qconfig.activation()) - child.register_forward_hook(_observer_forward_hook) - else: - _add_observer_(child, op_list, op_name) + def train(self, model, dataloader, optimizer_tuple, criterion_tuple, hooks, **kwargs): + """Execute the train process on the specified model. - def _propagate_qconfig_helper(module, - qconfig_dict, - white_list=None, - qconfig_parent=None, - prefix='', - fused=False): - """This is a helper function for `propagate_qconfig_` + Args: + model (object): model to run evaluation. + dataloader (object): training dataset. + optimizer (tuple): It is a tuple of (cls, parameters) for optimizer. + criterion (tuple): It is a tuple of (cls, parameters) for criterion. + kwargs (dict, optional): other parameters. - Args: - module (object): input module - qconfig_dict (dictionary): dictionary that maps from name of submodule to - quantization configuration - white_list (list, optional): list of quantizable modules - qconfig_parent (object, optional): config of parent module, we will fallback to - this config when there is no specified config - for current module - prefix (string, optional): corresponding prefix of the current module, - used as key in qconfig_dict - fused (bool, optional): Indicates whether the module is fused or not + Returns: + None + """ + model_ = model._model + device = "cuda:0" if self.device != "GPU" and torch.cuda.is_available() else self.device + # self.model is set to neural_compressor model here to hold the inplace change in FWK model. + self.model = model + optimizer = optimizer_tuple[0](model_.parameters(), **optimizer_tuple[1]) + self.optimizer = optimizer + criterion = criterion_tuple[0](**criterion_tuple[1]) + start_epochs = kwargs['kwargs']['start_epoch'] + end_epochs = kwargs['kwargs']['end_epoch'] + iters = kwargs['kwargs']['iteration'] + if hooks is not None: + on_train_begin = hooks['on_train_begin'] + on_train_end = hooks['on_train_end'] + on_epoch_begin = hooks['on_epoch_begin'] + on_epoch_end = hooks['on_epoch_end'] + on_step_begin = hooks['on_step_begin'] + on_step_end = hooks['on_step_end'] + on_after_compute_loss = hooks['on_after_compute_loss'] + on_before_optimizer_step = hooks['on_before_optimizer_step'] + if hooks is not None: + on_train_begin() + for nepoch in range(start_epochs, end_epochs): + model_.to(device) + model_.train() + cnt = 0 + if hooks is not None: + on_epoch_begin(nepoch) + if getattr(dataloader, 'distributed', False) \ + or isinstance(dataloader.sampler, \ + torch.utils.data.distributed.DistributedSampler): + dataloader.sampler.set_epoch(nepoch) + for image, target in dataloader: + # TODO: to support adjust lr with epoch + target = target.to(device) + if hooks is not None: + on_step_begin(cnt) + print('.', end='', flush=True) + cnt += 1 + output = pytorch_forward_wrapper(model_, image, device=device) + loss = criterion(output, target) + if hooks is not None: + loss = on_after_compute_loss(image, output, loss) + self.optimizer.zero_grad() + loss.backward() + if hooks is not None: + on_before_optimizer_step() + self.optimizer.step() + if hooks is not None: + on_step_end() + if cnt >= iters: + break + if hooks is not None: + on_epoch_end() - Return: - None, module is modified inplace with qconfig attached - """ - if white_list is None: - white_list = \ - torch.quantization.default_mappings.DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST \ - if self.version.release < Version("1.7.0").release else \ - torch.quantization.quantization_mappings.get_qconfig_propagation_list() - - if type(module) in white_list and type(module) != torch.nn.Sequential: - module.qconfig = qconfig_parent - else: - module.qconfig = None - if hasattr(module, '_modules'): - for name, child in module.named_children(): - module_prefix = prefix + '.' + name if prefix else name - _propagate_qconfig_helper(child, qconfig_dict, white_list, qconfig_parent, - module_prefix) + if device != self.device: # pragma: no cover + model_.to(self.device) - def _prepare(model, inplace=True, op_list=[], white_list=None): - """The model will be attached with observer or fake quant modules, and qconfig - will be propagated. + if hooks is not None: + on_train_end() - Args: - model (object): input model to be modified in-place - inplace (bool, optional): carry out model transformations in-place, - the original module is mutated - op_list (list, optional): list of ops which to be dumped in module - white_list (list, optional): list of quantizable modules + return model_ - Returns: - model (object): model with qconfig - """ - if not inplace: - model = copy.deepcopy(model) - _propagate_qconfig_helper(model, - qconfig_dict={}, - white_list=white_list, - qconfig_parent=model.qconfig) - # sanity check common API misusage - if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()): # pragma: no cover - logger.warn("None of the submodule got qconfig applied. Make sure you " - "passed correct configuration through `qconfig_dict` or " - "by assigning the `.qconfig` attribute directly on submodules") - _add_observer_(model, op_list=op_list) - return model + def _dump_model_op_stats(self, model, tune_cfg): + """This is a function to dump quantizable ops of model to user. + Args: + model (object): input model + tune_cfg (dict): quantization config + Returns: + None + """ + res = {} + ignore_log = False + modules = dict(model.named_modules()) + # fetch quantizable ops supported in Neural Compressor from tune_cfg + for key in tune_cfg['op']: + op_name = key[0] + op_type = str(type(modules[op_name])).rstrip('\'>').split('.')[-1] + if op_type == 'BF16ModuleWrapper': # pragma: no cover + op_type = str(type(modules[op_name].module)).rstrip('\'>').split('.')[-1] + if op_type == 'DequantQuantWrapper': + op_type = str(type(modules[op_name].module)).rstrip('\'>').split('.')[-1] + if 'Functional' in op_type: + op_type = op_name.split('.')[-1] + if op_type not in res.keys(): + res[op_type] = {'INT8': 0, 'BF16': 0, 'FP32': 0} + value = tune_cfg['op'][key] + # Special cases: QuantStub, Embedding + if ('weight' in value and value['weight']['dtype'] == 'fp32') or \ + ('weight' not in value and value['activation']['dtype'] == 'fp32'): + res[op_type]['FP32'] += 1 + elif value['activation']['dtype'] == 'bf16': # pragma: no cover + res[op_type]['BF16'] += 1 + else: + res[op_type]['INT8'] += 1 + # fetch other quantizable ops supported in PyTorch from model + for name, child in modules.items(): + op_type = str(type(child)).rstrip('\'>').split('.')[-1] + if tune_cfg['approach'] != 'post_training_dynamic_quant': + if op_type == 'DeQuantize': + if op_type not in res.keys(): + res[op_type] = {'INT8': 0, 'BF16': 0, 'FP32': 0} + res[op_type]['INT8'] += 1 + if op_type in self.non_quant_dict['skipped_module_classes']: + ignore_log = True + if op_type not in res.keys(): + res[op_type] = {'INT8': 0, 'BF16': 0, 'FP32': 0} + res[op_type]['FP32'] += 1 + # show results to users + if ignore_log: + logger.info("Ignore LayerNorm, InstanceNorm3d and Embedding quantizable ops" \ + " due to accuracy issue in PyTorch.") - # create properties - if self.version.release < Version("1.7.0").release: # pragma: no cover - white_list = self.white_list | \ - (set(torch.quantization.default_mappings.DEFAULT_MODULE_MAPPING.values()) | - set(torch.quantization.default_mappings.DEFAULT_QAT_MODULE_MAPPING.values()) | - set(torch.quantization.default_mappings.DEFAULT_DYNAMIC_MODULE_MAPPING.values())) - elif self.version.release < Version("1.8.0").release: # pragma: no cover - white_list = torch.quantization.get_compare_output_module_list() - else: - white_list = torch.quantization.get_default_compare_output_module_list() + field_names=["Op Type", "Total", "INT8", "BF16", "FP32"] + output_data = [[ + op_type, sum(res[op_type].values()), + res[op_type]['INT8'], res[op_type]['BF16'], res[op_type]['FP32']] + for op_type in res.keys()] - model = model if model.is_quantized else copy.deepcopy(model) - model._model.qconfig = torch.quantization.QConfig( - weight=torch.quantization.default_debug_observer, - activation=_RecordingObserver.with_args(iteration_list=iteration_list)) - _prepare(model._model, op_list=op_list, white_list=white_list) + Statistics(output_data, + header='Mixed Precision Statistics', + field_names=field_names).print_stat() + self.optype_statistics = field_names, output_data - return model - def is_fused_child(self, op_name): - """This is a helper function for `_post_eval_hook` + def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops): + """This is a helper function for `query_fw_capability`, + and it will get all quantizable ops from model. Args: - op_name (string): op name + model (object): input model + prefix (string): prefix of op name + quantizable_ops (list): list of quantizable ops from model include op name and type. Returns: - (bool): if this op is fused - + None """ - op = op_name[:op_name.rfind('.')] - if op in self.fused_dict and op_name[op_name.rfind('.') + 1:].isdigit(): - return True - else: - return False + module_dict = dict(model.named_modules()) + for op_name, child in model.named_modules(): + if self.is_fused_module(child): + for name, _ in child.named_children(): + module_prefix = op_name + '.' + name + if module_prefix in module_dict: + module_dict.pop(module_prefix) # remove sub-modules of fused modules + for op_name, child in module_dict.items(): + # there is accuracy issue in quantized LayerNorm op in pytorch <1.8.1, + # so remove it here + if op_name in self.non_quant_dict['skipped_module_names'] or \ + str(child.__class__.__name__) in \ + self.non_quant_dict['skipped_module_classes']: + continue + if type(child) in self.white_list and type(child) != torch.nn.Sequential and \ + type(child) != torch.quantization.stubs.DeQuantStub: + quantizable_ops.append( + (op_name, unify_op_type_mapping[str(child.__class__.__name__)] + if str(child.__class__.__name__) in unify_op_type_mapping else str( + child.__class__.__name__))) - def is_fused_op(self, op_name): - """This is a helper function for `_post_eval_hook` + def _get_scale_zeropoint(self, model, tune_cfg): + """get activation scale and zero_point for converted model. Args: - op_name (string): op name + model (dir): Int8 model converted from fp32 model. + scale and zero_point is set with calibration for each module + tune_cfg (object): This file saves scale and zero_point of \ + output activation of each quantized module. Returns: - (bool): if this op is fused - + None """ - op = op_name[:op_name.rfind('.')] - if op in self.fused_dict: - return True - else: - return False + modules = dict(model.named_modules()) + for key, value in tune_cfg['op'].items(): + if hasattr(modules[key[0]], 'scale'): + value['activation']['scale'] = float(modules[key[0]].scale) + if hasattr(modules[key[0]], 'zero_point'): + value['activation']['zero_point'] = int(modules[key[0]].zero_point) - def is_last_fused_child(self, op_name): + def is_fused_child(self, op_name): """This is a helper function for `_post_eval_hook` Args: op_name (string): op name Returns: - (bool): if this op is last fused op + (bool): if this op is fused """ - op = op_name[:op_name.rfind('.')] - if op_name in self.fused_dict[op][-1]: - return True - else: - return False + for key in self.fused_dict: + if op_name in self.fused_dict[key]: + return True + return False + def _post_eval_hook(self, model, **args): """The function is used to do some post process after complete evaluation. @@ -2200,20 +2393,17 @@ def _post_eval_hook(self, model, **args): for key in observer_dict: if isinstance(observer_dict[key], torch.nn.modules.linear.Identity): continue - op_name = key.strip(".activation_post_process") + op_name = key.replace(".activation_post_process", "") summary[op_name + ".output"] = observer_dict[key].get_tensor_value() for iter in summary[op_name + ".output"]: # Only collect last fused child output op = op_name - if self.is_fused_child(op_name) == True and \ - self.is_last_fused_child(op_name) == True: - op = op_name[:op_name.rfind('.')] + if op_name in self.fused_dict: + op = self.fused_dict[op_name][0] else: - if self.is_fused_child(op_name) == True and \ - self.is_last_fused_child(op_name) == False: - continue - else: - op = op_name + for key in self.fused_dict: + if op_name in self.fused_dict[key]: + op = op_name if summary[op_name + ".output"][iter].is_quantized: writer.add_histogram(op + "/Output/int8", @@ -2225,7 +2415,6 @@ def _post_eval_hook(self, model, **args): for key in state_dict: if not isinstance(state_dict[key], torch.Tensor): continue - op = key[:key.rfind('.')] if self.is_fused_child(op) is True: # fused child tensorboard tag will be merge @@ -2252,171 +2441,6 @@ def _post_eval_hook(self, model, **args): def save(self, model, path=None): pass - def inspect_tensor(self, - model, - dataloader, - op_list=None, - iteration_list=None, - inspect_type='activation', - save_to_disk=False): - if self.version.release >= Version("1.8.0").release: - from torch.fx import GraphModule - if type(model._model) == GraphModule: # pragma: no cover - assert False, "Inspect_tensor didn't support fx graph model now!" - from torch import dequantize - import numpy as np - is_quantized = model.is_quantized - op_list_ = [] - fp32_int8_map = {} - for op_name in op_list: - op_list_.append(op_name) - for key in self.fused_dict: - if op_name in self.fused_dict[key]: - fp32_int8_map[op_name] = \ - {'activation': self.fused_dict[key][-1], 'weight': key} - if is_quantized: - op_list_.append(key) - op_list_.remove(op_name) - else: - op_list_.append(self.fused_dict[key][-1]) - - new_model = model if is_quantized else copy.deepcopy(model) - - assert min(iteration_list) > 0, \ - "Iteration number should great zero, 1 means first iteration." - iterations = max(iteration_list) if iteration_list is not None else -1 - new_model = self._pre_eval_hook(new_model, op_list=op_list_, iteration_list=iteration_list) - self.evaluate(new_model, dataloader, iteration=iterations) - observer_dict = {} - ret = {} - if inspect_type == 'activation' or inspect_type == 'all': - if self.version.release >= Version("2.0.0").release: - from torch.quantization.quantize import _get_observer_dict as get_observer_dict - else: - from torch.quantization import get_observer_dict - ret['activation'] = [] - get_observer_dict(new_model._model, observer_dict) - if iteration_list is None: - iteration_list = [1] - for i in iteration_list: - summary = OrderedDict() - for key in observer_dict: - if isinstance(observer_dict[key], torch.nn.modules.linear.Identity): - continue - op_name = key.replace(".activation_post_process", "") - value = observer_dict[key].get_tensor_value()[i] - if op_name in op_list: - if type(value) is list: - summary[op_name] = {} - for index in range(len(value)): - summary[op_name].update({ - op_name + ".output" + str(index): - dequantize(value[index]).numpy() - if value[index].is_quantized else value[index].numpy() - }) - else: - summary[op_name] = { - op_name + ".output0": - dequantize(value).numpy() if value.is_quantized else value.numpy() - } - else: - if bool(self.fused_dict): - if is_quantized: - for a in fp32_int8_map: - if op_name == fp32_int8_map[a]['weight']: - if type(value) is list: - summary[a] = {} - for index in range(len(value)): - summary[a].update({ - op_name + ".output" + str(index): - dequantize(value[index]).numpy() - if value[index].is_quantized else - value[index].numpy() - }) - else: - summary[a] = { - op_name + ".output0": - dequantize(value).numpy() - if value.is_quantized else value.numpy() - } - else: - for a in fp32_int8_map: # pragma: no cover - if op_name == fp32_int8_map[a]['activation']: - if type(value) is list: - summary[a] = {} - for index in range(len(value)): - summary[a].update({ - op_name + ".output" + str(index): - dequantize(value[index]).numpy() - if value[index].is_quantized else - value[index].numpy() - }) - else: - summary[a] = { - op_name + ".output0": - dequantize(value).numpy() - if value.is_quantized else value.numpy() - } - - if save_to_disk: - dump_dir = os.path.join(self.workspace_path, 'dump_tensor') - os.makedirs(dump_dir, exist_ok=True) - np.savez(os.path.join(dump_dir, 'activation_iter{}.npz'.format(i)), **summary) - - ret['activation'].append(summary) - - if inspect_type == 'weight' or inspect_type == 'all': - ret['weight'] = {} - state_dict = new_model._model.state_dict() - - for key in state_dict: - if not isinstance(state_dict[key], torch.Tensor): - continue - if 'weight' not in key and 'bias' not in key: - continue - - op = key[:key.rfind('.')] - op = op.replace('._packed_params', '') - - if op in op_list: - if op in ret['weight']: - ret['weight'][op].update({ - key: - dequantize(state_dict[key]).numpy() - if state_dict[key].is_quantized else state_dict[key].detach().numpy() - }) - else: - ret['weight'][op] = { - key: - dequantize(state_dict[key]).numpy() - if state_dict[key].is_quantized else state_dict[key].detach().numpy() - } - else: - if bool(self.fused_dict): - if is_quantized: - for a in fp32_int8_map: - if op == fp32_int8_map[a]['weight']: - if a in ret['weight']: - ret['weight'][a].update({ - key: - dequantize(state_dict[key]).numpy() - if state_dict[key].is_quantized else - state_dict[key].detach().numpy() - }) - else: - ret['weight'][a] = \ - {key: dequantize(state_dict[key]).numpy() - if state_dict[key].is_quantized else - state_dict[key].detach().numpy()} - break - - if save_to_disk: - np.savez(os.path.join(dump_dir, 'weight.npz'), **ret['weight']) - else: - ret['weight'] = None - - return ret - def set_tensor(self, model, tensor_dict): state_dict = model._model.state_dict() tensor_name = None @@ -2427,7 +2451,12 @@ def set_tensor(self, model, tensor_dict): weight_bias = key[end + 1:] for op in self.fused_dict: if op_name in self.fused_dict[op]: - state_op_name = op + if model.is_quantized: + state_op_name = op + else: + state_op_name = self.fused_dict[op][0] + # elif op_name in self.fused_dict[op]: + # state_op_name = op if state_op_name is None: state_op_name = op_name for state_dict_key in state_dict.keys(): @@ -3469,6 +3498,8 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None): (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'): # pragma: no cover q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg) + self.fused_dict = self.get_fused_list(q_model.model) + q_model.is_quantized = True q_model.q_config = copy.deepcopy(self.tune_cfg) if self.approach != 'post_training_dynamic_quant': self._get_scale_zeropoint(q_model._model, q_model.q_config) @@ -4602,7 +4633,6 @@ def _dump_model_op_stats(self, model, tune_cfg): field_names=field_names).print_stat() self.optype_statistics = field_names, output_data - def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops): """This is a helper function for `query_fw_capability`, and it will get all quantizable ops from model. diff --git a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py index 71e411d44cf..2ef90aec963 100644 --- a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py +++ b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py @@ -2,6 +2,7 @@ import neural_compressor.adaptor.pytorch as nc_torch import numpy as np import os +import pickle import shutil import torch import torch.nn as nn @@ -708,16 +709,17 @@ def test_tensor_dump_and_set(self): quantizer.strategy.adaptor.inspect_tensor( model, dataloader, op_list=['conv1.0', 'layer1.0.conv1.0'], iteration_list=[1, 2], inspect_type='all', save_to_disk=True) - load_array = lambda *a, **k: np.load(*a, allow_pickle=True, **k) - a = load_array('saved/dump_tensor/activation_iter1.npz') - w = load_array('saved/dump_tensor/weight.npz') + with open('saved/inspect_result.pkl', 'rb') as fp: + tensor_dict = pickle.load(fp) + a = tensor_dict["activation"][0] + w = tensor_dict["weight"] if PT_VERSION >= Version("1.8.0").release: - self.assertTrue(w['conv1.0'].item()['conv1.0.weight'].shape[0] == - a['conv1.0'].item()['conv1.0.output0'].shape[1]) + self.assertTrue(w['conv1.0']['conv1.0.weight'].shape[0] == + a['conv1.0']['conv1.0.output0'].shape[1]) else: - self.assertTrue(w['conv1.0'].item()['conv1.0.weight'].shape[0] == - a['conv1.0'].item()['conv1.1.output0'].shape[1]) - data = np.random.random(w['conv1.0'].item()['conv1.0.weight'].shape).astype(np.float32) + self.assertTrue(w['conv1.0']['conv1.0.weight'].shape[0] == + a['conv1.0']['conv1.1.output0'].shape[1]) + data = np.random.random(w['conv1.0']['conv1.0.weight'].shape).astype(np.float32) quantizer.strategy.adaptor.set_tensor(q_model, {'conv1.0.weight': data}) changed_tensor = q_model.get_weight('conv1.weight') scales = changed_tensor.q_per_channel_scales() @@ -1114,5 +1116,37 @@ def test_symbolic_trace(self): traced_model_qat = symbolic_trace(model_origin, is_qat=True) self.assertTrue(isinstance(traced_model_qat.sub, torch.fx.graph_module.GraphModule)) + def test_tensor_dump(self): + model = resnet18() + model = MODELS['pytorch'](model) + quantizer = Quantization('fx_ptq_yaml.yaml') + dataset = quantizer.dataset('dummy', (100, 3, 224, 224), label=True) + dataloader = common.DataLoader(dataset) + dataloader = common._generate_common_dataloader(dataloader, 'pytorch') + quantizer.eval_dataloader = dataloader + quantizer.calib_dataloader = dataloader + quantizer.model = model.model + q_model = quantizer.fit() + op_list, _ = quantizer.strategy.adaptor.diagnosis_helper(model, q_model, None) + quantizer.strategy.adaptor.inspect_tensor( + model, dataloader, op_list=op_list, + iteration_list=[1], inspect_type='all', save_to_disk=True) + with open('saved/inspect_result.pkl', 'rb') as fp: + tensor_dict = pickle.load(fp) + a = tensor_dict["activation"][0] + w = tensor_dict["weight"] + self.assertTrue(w['conv1']['conv1.weight'].shape[0] == + a['conv1']['conv1.output0'].shape[1]) + quantizer.strategy.adaptor.inspect_tensor( + q_model, dataloader, op_list=['conv1', 'layer2.0.downsample.0'], + iteration_list=[1, 2], inspect_type='all', save_to_disk=True) + with open('saved/inspect_result.pkl', 'rb') as fp: + tensor_dict = pickle.load(fp) + a = tensor_dict["activation"][0] + w = tensor_dict["weight"] + self.assertTrue(w['layer2.0.downsample.0']['layer2.0.downsample.0.weight'].shape[0] == + a['layer2.0.downsample.0']['layer2.0.downsample.0.output0'].shape[1]) + + if __name__ == "__main__": unittest.main()