diff --git a/neural_compressor/torch/algorithms/smooth_quant/utility.py b/neural_compressor/torch/algorithms/smooth_quant/utility.py index ceb2657b89a..3448d705ea7 100644 --- a/neural_compressor/torch/algorithms/smooth_quant/utility.py +++ b/neural_compressor/torch/algorithms/smooth_quant/utility.py @@ -16,13 +16,10 @@ import json import os import re -import subprocess from collections import UserDict -import cpuinfo import intel_extension_for_pytorch as ipex import numpy -import psutil import torch import tqdm from packaging.version import Version @@ -30,11 +27,10 @@ from neural_compressor.torch.algorithms.static_quant import ( TransformerBasedModelBlockPatternDetector, dump_model_op_stats, - get_quantizable_ops_from_cfgs, + generate_activation_observer, + get_quantizable_ops_recursively, ipex_config_path, - paser_cfgs, simple_inference, - unify_op_type_mapping_ipex, ) from neural_compressor.torch.utils import get_ipex_version, get_torch_version, logger @@ -42,99 +38,6 @@ ipex_ver = get_ipex_version() -def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False): # pragma: no cover - """This is a helper method to generate an activation observer. - - Args: - scheme (str): Quantization scheme to be used. - algorithm (str): What algorithm for computing the quantization parameters based on. - - Returns: - An observer. - """ - kl_activation_observer = { - "name": "HistogramObserver", - "bins": 2048, - "upsample_rate": 128, - "dtype": "torch.quint8", - "qscheme": "torch.per_tensor_affine", - "reduce_range": False, - "quant_min": 0, - "quant_max": 255, - } - minmax_activation_observer = { - "name": "MinMaxObserver", - "dtype": "torch.quint8", - "qscheme": "torch.per_tensor_affine", - "reduce_range": False, - "quant_min": 0, - "quant_max": 255, - } - smoothquant_kl_activation_observer = { - "name": "SmoothQuantActivationObserver", - "smooth_quant_enabled": smooth_quant_enable, - "dtype": "torch.quint8", - "qscheme": "torch.per_tensor_affine", - "reduce_range": False, - "quant_min": 0, - "quant_max": 255, - "alpha": 0.5, - "act_observer": kl_activation_observer, - "act_ic_observer": { - "name": "PerChannelMinMaxObserver", - "ch_axis": -1, - "dtype": "torch.quint8", - "qscheme": "torch.per_channel_affine", - "reduce_range": False, - "quant_min": 0, - "quant_max": 255, - }, - } - smoothquant_minmax_activation_observer = { - "name": "SmoothQuantActivationObserver", - "smooth_quant_enabled": smooth_quant_enable, - "dtype": "torch.quint8", - "qscheme": "torch.per_tensor_affine", - "reduce_range": False, - "quant_min": 0, - "quant_max": 255, - "alpha": 0.5, - "act_observer": minmax_activation_observer, - "act_ic_observer": { - "name": "PerChannelMinMaxObserver", - "ch_axis": -1, - "dtype": "torch.quint8", - "qscheme": "torch.per_channel_affine", - "reduce_range": False, - "quant_min": 0, - "quant_max": 255, - }, - } - REDUCE_RANGE = False if CpuInfo().vnni else True - if REDUCE_RANGE: - minmax_activation_observer["reduce_range"] = REDUCE_RANGE - kl_activation_observer["reduce_range"] = REDUCE_RANGE - if scheme == "sym": - minmax_activation_observer["qscheme"] = "torch.per_tensor_symmetric" - minmax_activation_observer["dtype"] = "torch.qint8" - minmax_activation_observer["quant_min"] = -128 - minmax_activation_observer["quant_max"] = 127 - kl_activation_observer["qscheme"] = "torch.per_tensor_symmetric" - kl_activation_observer["dtype"] = "torch.qint8" - kl_activation_observer["quant_min"] = -128 - kl_activation_observer["quant_max"] = 127 - if smooth_quant and 
smooth_quant_enable: - if algorithm == "kl": - return smoothquant_kl_activation_observer - if algorithm == "minmax": - return smoothquant_minmax_activation_observer - else: - if algorithm == "kl": - return kl_activation_observer - if algorithm == "minmax": - return minmax_activation_observer - - def check_cfg_and_qconfig( tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name, smooth_quant=False ): # pragma: no cover @@ -223,131 +126,6 @@ def cfg_to_qconfig( return None -def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover - """Get all quantizable ops from model. - - Args: - model (object): input model - example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model. - Returns: - quantizable_ops (list): list of tuples of op_name and op_type. - cfgs (dict): dict of configuration - """ - quantizable_ops = [] - # group ops by position for transform-based model - detector = TransformerBasedModelBlockPatternDetector(model) - detect_result = detector.detect_block() - attention_block = detect_result.get("attention_blocks", None) - ffn_blocks = detect_result.get("ffn_blocks", None) - logger.info(f"Attention Blocks: {len(attention_block)}") - logger.info(f"FFN Blocks: {len(ffn_blocks)}") - if not os.path.exists(ipex_config_path): - assert isinstance(model, torch.nn.Module), "The model passed in is not the instance of torch.nn.Module" - - if hasattr(model, "save_qconf_summary"): # pragma: no cover - os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True) - model.save_qconf_summary(qconf_summary=ipex_config_path) - else: - model.eval() - - # create a quantization config file for intel pytorch extension model - os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True) - assert example_inputs is not None, "IPEX need q_dataloader or example_inputs to prepare the model" - from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig - - if ipex_ver.release >= Version("2.1").release: - # HistogramObserver will cause a performance issue. 
- # static_qconfig = ipex.quantization.default_static_qconfig_mapping - qconfig = QConfig( - activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), - weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric), - ) - from torch.ao.quantization import QConfigMapping - - static_qconfig = QConfigMapping().set_global(qconfig) - else: - static_qconfig = QConfig( - activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), - weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric), - ) - - if isinstance(example_inputs, dict): - model = ipex.quantization.prepare(model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=True) - else: - model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=True) - simple_inference(model, example_inputs, iterations=1) - model.save_qconf_summary(qconf_summary=ipex_config_path) - - map_op_name_to_fqn = {} - with open(ipex_config_path, "r") as f: - cfgs = json.load(f) - if ipex_ver.release < Version("1.12.0").release: # pragma: no cover - for op_cfg in cfgs: - if op_cfg["name"] in unify_op_type_mapping_ipex: - quantizable_ops.append((op_cfg["id"], unify_op_type_mapping_ipex[op_cfg["name"]])) - else: - re_flag = False - for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items(): - if re.match(pattern, op_cfg["name"]): - re_flag = True - quantizable_ops.append((op_cfg["id"], unify_op_type)) - break - if not re_flag: - quantizable_ops.append((op_cfg["id"], op_cfg["name"])) - else: - ( - ops_name, - op_infos_from_cfgs, - input_tensor_id_op_name, - output_tensor_id_op_name, - ) = paser_cfgs(cfgs) - quantizable_op_names = get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_id_op_name) - for name in quantizable_op_names: - # name : list - if len(name) == 1: - module_key = name[0][0] - op_cfg_id = name[0][2] - ipex_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"] - module_fqn = cfgs[module_key]["q_op_infos"][op_cfg_id].get("fqn", None) - - if ipex_op_type in unify_op_type_mapping_ipex: - quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type])) - map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn - else: - re_flag = False - for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items(): - if re.match(pattern, ipex_op_type): - re_flag = True - quantizable_ops.append((tuple(name), unify_op_type)) - map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn - break - if not re_flag: - quantizable_ops.append((tuple(name), ipex_op_type)) - map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn - else: - op_type = "" - for op_name in name: - module_key = op_name[0] - op_cfg_id = op_name[2] - single_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"] - if single_op_type in unify_op_type_mapping_ipex: - single_op_type = unify_op_type_mapping_ipex[single_op_type] - op_type += "&" + single_op_type if op_type else single_op_type - quantizable_ops.append((tuple(name), op_type)) - _module_key = name[0][0] - _op_cfg_id = name[0][2] - module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"] - map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn - - logger.debug("Map op name to fqn: ") - logger.debug(map_op_name_to_fqn) - logger.info("Attention Blocks : ") - logger.info(attention_block) - logger.info("FFN Blocks : ") - logger.info(ffn_blocks) - return quantizable_ops, cfgs, 
op_infos_from_cfgs, output_tensor_id_op_name - - def get_parent(node, all_parents=False): # pragma: no cover if node.inputs() is None: return None @@ -2275,67 +2053,3 @@ def forward(self, x): output = self.orig_layer(x) self.output = output return output - - -class CpuInfo(object): # pragma: no cover - """Get CPU Info.""" - - def __init__(self): - """Get whether the cpu numerical format is bf16, the number of sockets, cores and cores per socket.""" - self._bf16 = False - self._vnni = False - info = cpuinfo.get_cpu_info() - if "arch" in info and "X86" in info["arch"]: - cpuid = cpuinfo.CPUID() - max_extension_support = cpuid.get_max_extension_support() - if max_extension_support >= 7: - ecx = cpuid._run_asm( - b"\x31\xC9", # xor ecx, ecx - b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\x89\xC8" b"\xC3", # mov eax, 7 # cpuid # mov ax, cx # ret - ) - self._vnni = bool(ecx & (1 << 11)) - eax = cpuid._run_asm( - b"\xB9\x01\x00\x00\x00", # mov ecx, 1 - b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\xC3", # mov eax, 7 # cpuid # ret - ) - self._bf16 = bool(eax & (1 << 5)) - if "arch" in info and "ARM" in info["arch"]: # pragma: no cover - self._sockets = 1 - else: - self._sockets = self.get_number_of_sockets() - self._cores = psutil.cpu_count(logical=False) - self._cores_per_socket = int(self._cores / self._sockets) - - @property - def bf16(self): - """Get whether it is bf16.""" - return self._bf16 - - @property - def vnni(self): - """Get whether it is vnni.""" - return self._vnni - - @property - def cores_per_socket(self): - """Get the cores per socket.""" - return self._cores_per_socket - - def get_number_of_sockets(self) -> int: - """Get number of sockets in platform.""" - cmd = "cat /proc/cpuinfo | grep 'physical id' | sort -u | wc -l" - if psutil.WINDOWS: - cmd = r'wmic cpu get DeviceID | C:\Windows\System32\find.exe /C "CPU"' - - with subprocess.Popen( - args=cmd, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - universal_newlines=False, - ) as proc: - proc.wait() - if proc.stdout: - for line in proc.stdout: - return int(line.decode("utf-8", errors="ignore").strip()) - return 0 diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py index b3dccdafb00..626d0f60a2e 100644 --- a/neural_compressor/torch/algorithms/static_quant/static_quant.py +++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py @@ -51,50 +51,38 @@ def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True): Returns: A quantized model. """ + _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(model, example_inputs) + cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name) # update json file in ipex_config_path model.eval() - if ipex_ver.release >= Version("1.12.0").release: - # Check save_qconf_summary part is a workaround for IPEX bug. 
- # Sometimes the prepared model from get_op_capablitiy loss this attribute - if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): - from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig - - if ipex_ver.release >= Version("2.1").release: - static_qconfig = ipex.quantization.default_static_qconfig_mapping - else: - static_qconfig = QConfig( - activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), - weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric), - ) - if isinstance(example_inputs, dict): - model = ipex.quantization.prepare( - model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace - ) - else: - model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace) - - model.load_qconf_summary(qconf_summary=ipex_config_path) - run_fn(model) - model.save_qconf_summary(qconf_summary=ipex_config_path) - model = _ipex_post_quant_process(model, example_inputs, inplace=inplace) - - else: # pragma: no cover - # for IPEX version < 1.12 - _, cfgs, default_cfgs, fuse_ops = get_quantizable_ops_recursively(model, example_inputs) - qscheme = cfg_to_qconfig(tune_cfg, cfgs, default_cfgs, fuse_ops) - ipex_conf = ipex.quantization.QuantConf( - configure_file=ipex_config_path, qscheme=qscheme - ) # pylint: disable=E1101 - run_fn(model) - ipex_conf.save(ipex_config_path) - ipex_conf = ipex.quantization.QuantConf(ipex_config_path) # pylint: disable=E1101 - model = ipex.quantization.convert(model, ipex_conf, example_inputs, inplace=True) # pylint: disable=E1121 + # Check save_qconf_summary part is a workaround for IPEX bug. + # Sometimes the prepared model from get_op_capablitiy loss this attribute + if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): + from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig + + if ipex_ver.release >= Version("2.1").release: + static_qconfig = ipex.quantization.default_static_qconfig_mapping + else: + static_qconfig = QConfig( + activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), + weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric), + ) + if isinstance(example_inputs, dict): + model = ipex.quantization.prepare( + model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace + ) + else: + model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace) + + model.load_qconf_summary(qconf_summary=ipex_config_path) + run_fn(model) + model.save_qconf_summary(qconf_summary=ipex_config_path) + model = _ipex_post_quant_process(model, example_inputs, inplace=inplace) with open(ipex_config_path, "r") as f: model.tune_cfg = json.load(f) model.ipex_config_path = ipex_config_path - if ipex_ver.release >= Version("1.12.0").release: - dump_model_op_stats(tune_cfg) + dump_model_op_stats(tune_cfg) return model diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py index cdfd3cb72d0..dd073f50aab 100644 --- a/neural_compressor/torch/algorithms/static_quant/utility.py +++ b/neural_compressor/torch/algorithms/static_quant/utility.py @@ -27,7 +27,7 @@ except: pass -from neural_compressor.common.utils import DEFAULT_WORKSPACE +from neural_compressor.common.utils import DEFAULT_WORKSPACE, CpuInfo from 
neural_compressor.torch.utils import get_ipex_version, get_torch_version, logger version = get_torch_version() @@ -63,57 +63,142 @@ ] -def cfg_to_qconfig(tune_cfg, cfgs, default_cfgs, fuse_ops): # pragma: no cover +def cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name): # pragma: no cover assert cfgs is not None, "No configure for IPEX int8 model..." - for key in tune_cfg["op"]: - try: - scheme = tune_cfg["op"][key]["activation"]["scheme"] - except: - scheme = "asym" - if scheme not in ["asym", "sym"]: - scheme = "asym" - break - for key in tune_cfg["op"]: - value = tune_cfg["op"][key] - pattern = get_pattern(key, fuse_ops) - assert isinstance(value, dict) - assert "activation" in value - if value["activation"]["dtype"] == "fp32": - if "weight" in value: - assert value["weight"]["dtype"] == "fp32" - for op_cfg in cfgs: - if op_cfg["id"] == key[0]: - if key[1] in ["relu_", "add_"]: - continue - num_inputs = len(op_cfg["inputs_quantized"]) - num_outputs = len(op_cfg["outputs_quantized"]) - for i_num in range(num_inputs): - op_cfg["inputs_quantized"][i_num] = False - for o_num in range(num_outputs): - op_cfg["outputs_quantized"][o_num] = False - if pattern: - if pattern[1] in ["relu_", "add_"]: - continue - tune_cfg["op"][pattern]["activation"]["dtype"] = "fp32" - if "weight" in tune_cfg["op"][pattern]: - tune_cfg["op"][pattern]["weight"]["dtype"] = "fp32" - else: - for op_cfg in cfgs: - if op_cfg["id"] == key[0]: - if key[1] in ["relu_", "add_"]: - continue - num_inputs = len(op_cfg["inputs_quantized"]) - num_outputs = len(op_cfg["outputs_quantized"]) - for i_num in range(num_inputs): - op_cfg["inputs_quantized"][i_num] = default_cfgs[key[0]]["inputs_quantized"][i_num] - for o_num in range(num_outputs): - op_cfg["outputs_quantized"][o_num] = default_cfgs[key[0]]["outputs_quantized"][o_num] + op_infos = copy.deepcopy(op_infos_from_cfgs) + cfgs = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name) with open(ipex_config_path, "w") as write_f: - json.dump(cfgs, write_f) - if scheme == "asym": - return torch.per_tensor_affine + json.dump(cfgs, write_f, indent=4) + + +def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name): # pragma: no cover + """Check configs and quantization configs. + + Args: + user_cfg (dict): quantization configuration for ops. + cfgs (dict): configs loaded from ipex config path. + op_infos_from_cfgs (dict): dict containing configs that have been parsed for each op. + output_tensor_ids_op_name (dict): dict containing op names corresponding to 'op_infos_from_cfgs'. + + Returns: + cfgs (dict): updated configs. 
+ """ + for op_name in user_cfg: + inc_op_cfg = user_cfg[op_name] + for i, name in enumerate(op_name[0]): + # to int8 + ipex_op_cfg = op_infos_from_cfgs[name] + input_tensor_infos = ipex_op_cfg["input_tensor_infos"] + if op_name[1] == "Linear" or op_name[1] == "Linear&add": # record op_name for possible op-wise fallback + logger.debug(f"ipex_op_cfg['fqn'] - op_name {ipex_op_cfg['fqn']} {op_name}") + for index, input_tensor_info in enumerate(input_tensor_infos): + if "force_dtype" not in input_tensor_info.keys(): + continue + if ( + input_tensor_info["force_dtype"] == "torch.qint8" + or input_tensor_info["force_dtype"] == "torch.quint8" + ): + # int8 -> int8 + if inc_op_cfg["weight"]["dtype"] == "int8": + inc_scheme = inc_op_cfg["activation"]["scheme"] + inc_algorithm = inc_op_cfg["activation"]["algorithm"] + ipex_op_cfg["input_tensor_infos"] = input_tensor_infos + if ( + "op_type" in ipex_op_cfg + and ipex_op_cfg["op_type"] == "" + ): + smooth_quant_enable = True + else: + smooth_quant_enable = False + activation_observer = generate_activation_observer( + inc_scheme, inc_algorithm, smooth_quant=False, smooth_quant_enable=smooth_quant_enable + ) + if inc_scheme == "sym": + input_tensor_infos[index]["force_dtype"] = "torch.qint8" + if inc_scheme == "asym": + input_tensor_infos[index]["force_dtype"] = "torch.quint8" + ipex_op_cfg["activation_observer"] = activation_observer + # int8 -> fp32 + else: + input_tensor_infos[index]["force_dtype"] = "torch.float32" + # modify pre_op output inf_dtype + if i == 0: + input_tensor_id = input_tensor_info["id"] + input_tensor_dtype = input_tensor_info["force_dtype"] + if input_tensor_id in output_tensor_ids_op_name.keys(): + pre_op_name = output_tensor_ids_op_name[input_tensor_id] + pre_op_module = pre_op_name[0][0] + pre_op_state = pre_op_name[0][1] + pre_op_index = pre_op_name[0][2] + pre_op_infos = cfgs[pre_op_module][pre_op_state][pre_op_index] + pre_op_output_infos = pre_op_infos["output_tensor_infos"] + for index, pre_op_output in enumerate(pre_op_output_infos): + if pre_op_output["id"] == input_tensor_id: + pre_op_output_infos[index]["inf_dtype"] = input_tensor_dtype + else: + pass + pre_op_infos["output_tensor_infos"] = pre_op_output_infos + cfgs[pre_op_module][pre_op_state][pre_op_index] = pre_op_infos + else: + pass + cfgs[name[0]][name[1]][name[2]] = ipex_op_cfg + return cfgs + + +def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False): # pragma: no cover + """This is a helper method to generate a dict containing activation observer info. + + Args: + scheme (str): Quantization scheme to be used. + algorithm (str): What algorithm for computing the quantization parameters based on. 
+ + Returns: + A dict containing observer info.zs + """ + from intel_extension_for_pytorch.quantization._smooth_quant import SmoothQuantActivationObserver + from intel_extension_for_pytorch.quantization._utils import _get_observer_setting + from torch.quantization import HistogramObserver, MinMaxObserver + + kl_activation_observer = _get_observer_setting(HistogramObserver(reduce_range=False)) + minmax_activation_observer = _get_observer_setting( + MinMaxObserver(qscheme=torch.per_tensor_affine, dtype=torch.quint8) + ) + smoothquant_kl_activation_observer = _get_observer_setting( + SmoothQuantActivationObserver( + reduce_range=False, + smooth_quant_enabled=smooth_quant_enable, + ) + ) + smoothquant_minmax_activation_observer = _get_observer_setting( + SmoothQuantActivationObserver( + reduce_range=False, + smooth_quant_enabled=smooth_quant_enable, + ) + ) + + REDUCE_RANGE = False if CpuInfo().vnni else True + if REDUCE_RANGE: + minmax_activation_observer["reduce_range"] = REDUCE_RANGE + kl_activation_observer["reduce_range"] = REDUCE_RANGE + if scheme == "sym": + minmax_activation_observer["qscheme"] = "torch.per_tensor_symmetric" + minmax_activation_observer["dtype"] = "torch.qint8" + minmax_activation_observer["quant_min"] = -128 + minmax_activation_observer["quant_max"] = 127 + kl_activation_observer["qscheme"] = "torch.per_tensor_symmetric" + kl_activation_observer["dtype"] = "torch.qint8" + kl_activation_observer["quant_min"] = -128 + kl_activation_observer["quant_max"] = 127 + if smooth_quant and smooth_quant_enable: + if algorithm == "kl": + return smoothquant_kl_activation_observer + if algorithm == "minmax": + return smoothquant_minmax_activation_observer else: - return torch.per_tensor_symmetric + if algorithm == "kl": + return kl_activation_observer + if algorithm == "minmax": + return minmax_activation_observer def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover @@ -174,67 +259,49 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover map_op_name_to_fqn = {} with open(ipex_config_path, "r") as f: cfgs = json.load(f) - default_cfgs = {} - fuse_ops = [] - if ipex_ver.release < Version("1.12.0").release: # pragma: no cover - default_cfgs = copy.deepcopy(cfgs) - fuse_ops = get_fuse_ops(cfgs) - for op_cfg in cfgs: - if op_cfg["name"] in unify_op_type_mapping_ipex: - quantizable_ops.append((op_cfg["id"], unify_op_type_mapping_ipex[op_cfg["name"]])) + ( + ops_name, + op_infos_from_cfgs, + input_tensor_id_op_name, + output_tensor_id_op_name, + ) = paser_cfgs(cfgs) + quantizable_op_names = get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_id_op_name) + for name in quantizable_op_names: + # name : list + if len(name) == 1: + module_key = name[0][0] + op_cfg_id = name[0][2] + ipex_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"] + module_fqn = cfgs[module_key]["q_op_infos"][op_cfg_id].get("fqn", None) + + if ipex_op_type in unify_op_type_mapping_ipex: + quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type])) + map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn else: re_flag = False for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items(): - if re.match(pattern, op_cfg["name"]): + if re.match(pattern, ipex_op_type): re_flag = True - quantizable_ops.append((op_cfg["id"], unify_op_type)) + quantizable_ops.append((tuple(name), unify_op_type)) + map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn break if not re_flag: - 
quantizable_ops.append((op_cfg["id"], op_cfg["name"])) - else: - ( - ops_name, - op_infos_from_cfgs, - input_tensor_id_op_name, - output_tensor_id_op_name, - ) = paser_cfgs(cfgs) - quantizable_op_names = get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_id_op_name) - for name in quantizable_op_names: - # name : list - if len(name) == 1: - module_key = name[0][0] - op_cfg_id = name[0][2] - ipex_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"] - module_fqn = cfgs[module_key]["q_op_infos"][op_cfg_id].get("fqn", None) - - if ipex_op_type in unify_op_type_mapping_ipex: - quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type])) + quantizable_ops.append((tuple(name), ipex_op_type)) map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn - else: - re_flag = False - for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items(): - if re.match(pattern, ipex_op_type): - re_flag = True - quantizable_ops.append((tuple(name), unify_op_type)) - map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn - break - if not re_flag: - quantizable_ops.append((tuple(name), ipex_op_type)) - map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn - else: - op_type = "" - for op_name in name: - module_key = op_name[0] - op_cfg_id = op_name[2] - single_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"] - if single_op_type in unify_op_type_mapping_ipex: - single_op_type = unify_op_type_mapping_ipex[single_op_type] - op_type += "&" + single_op_type if op_type else single_op_type - quantizable_ops.append((tuple(name), op_type)) - _module_key = name[0][0] - _op_cfg_id = name[0][2] - module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"] - map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn + else: + op_type = "" + for op_name in name: + module_key = op_name[0] + op_cfg_id = op_name[2] + single_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"] + if single_op_type in unify_op_type_mapping_ipex: + single_op_type = unify_op_type_mapping_ipex[single_op_type] + op_type += "&" + single_op_type if op_type else single_op_type + quantizable_ops.append((tuple(name), op_type)) + _module_key = name[0][0] + _op_cfg_id = name[0][2] + module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"] + map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn logger.debug("Map op name to fqn: ") logger.debug(map_op_name_to_fqn) @@ -242,7 +309,7 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover logger.info(attention_block) logger.info("FFN Blocks : ") logger.info(ffn_blocks) - return quantizable_ops, cfgs, default_cfgs, fuse_ops + return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name def simple_inference(q_model, example_inputs, iterations=1): @@ -309,42 +376,6 @@ def dump_model_op_stats(tune_cfg): ).print_stat() -def get_fuse_ops(default_cfgs): # pragma: no cover - elt_wise = ["relu", "sigmoid", "gelu"] - inplace_ops = ["relu_", "add_"] - op_patterns = [] - num_ops = len(default_cfgs) - for cur_id in range(num_ops): - cur_op = default_cfgs[cur_id]["name"] - if cur_op == "dropout": - continue - inputs = default_cfgs[cur_id]["inputs_flow"] - num_input = len(inputs) - pre_ops = {} - for i_num in range(num_input): - inp = inputs[i_num] - for pre_id in range(cur_id): - pre_op = default_cfgs[pre_id]["name"] - pre_out = default_cfgs[pre_id]["outputs_flow"] - num_out = len(pre_out) - for o_num in range(num_out): - if pre_out[o_num] == inp: - if cur_op in inplace_ops and (pre_op 
in ["conv2d", "conv3d", "linear"]): - op_patterns.append([(pre_id, pre_op), (cur_id, cur_op)]) - if cur_op in elt_wise and (pre_op in ["conv2d", "conv3d", "linear", "add"]): - op_patterns.append([(pre_id, pre_op), (cur_id, cur_op)]) - if cur_op == "add": - pre_ops[i_num] = [pre_id, pre_op] - if len(pre_ops) > 0: - for key, value in pre_ops.items(): - if ( - value[1] in ["conv2d", "conv3d", "linear"] - and default_cfgs[cur_id]["inputs_quantized"][key] is False - ): - op_patterns.append([(value[0], value[1]), (cur_id, cur_op)]) - return op_patterns - - def get_depth(d) -> int: """Query the depth of the dict.""" if isinstance(d, dict): @@ -491,16 +522,6 @@ def get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_ids return quantizable_ops -def get_pattern(fallback_op, fuse_ops): # pragma: no cover - for fuse_pattern in fuse_ops: - if fuse_pattern[0] == fallback_op: - if fuse_pattern[1] in ["relu_", "add_"]: - return None - else: - return fuse_pattern[1] - return None - - class Statistics: # pragma: no cover """The statistics printer."""