From 855c10ca37d01bd371a4b9dcd953ce735f9bdea6 Mon Sep 17 00:00:00 2001 From: Zixuan Cheng <110808245+violetch24@users.noreply.github.com> Date: Tue, 23 Apr 2024 14:01:51 +0800 Subject: [PATCH] map ipex op_name w/ pt op_name (#1740) Signed-off-by: Cheng, Zixuan --- neural_compressor/common/base_config.py | 6 +-- .../algorithms/smooth_quant/smooth_quant.py | 6 +-- .../algorithms/static_quant/static_quant.py | 7 ++-- .../torch/algorithms/static_quant/utility.py | 38 ++++++++++++++++--- .../torch/quantization/config.py | 4 +- .../torch/quantization/test_static_quant.py | 15 ++++++++ 6 files changed, 58 insertions(+), 18 deletions(-) diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py index 0b7749a5c48..05e26d8b05d 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -410,11 +410,9 @@ def to_config_mapping( if self.global_config is not None: config_mapping[(op_name, op_type)] = global_config if op_type in op_type_config_dict: - config_mapping[(op_name, op_type)] = op_name_config_dict[op_type] + config_mapping[(op_name, op_type)] = op_type_config_dict[op_type] for op_name_pattern in op_name_config_dict: - if isinstance(op_name, str) and re.match(op_name_pattern, op_name): - config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern] - elif op_name_pattern == op_name: # TODO: map ipex opname to stock pt op_name + if re.match(op_name_pattern, op_name): config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern] return config_mapping diff --git a/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py b/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py index bd26dcdfc3b..e49d1bfbab8 100644 --- a/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py +++ b/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py @@ -56,7 +56,7 @@ def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True): """ assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant." - _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(model, example_inputs) + _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(model, example_inputs) # check smoothquant folding value recipe_cfgs = tune_cfg.get("recipe_cfgs", None) @@ -121,7 +121,7 @@ def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True): with open(ipex_config_path, "r") as f: model.tune_cfg = json.load(f) model.ipex_config_path = ipex_config_path - dump_model_op_stats(tune_cfg) + dump_model_op_stats(tune_cfg["op"]) return model @@ -185,7 +185,7 @@ def qdq_quantize( with open(ipex_config_path, "r") as f: model.tune_cfg = json.load(f) model.ipex_config_path = ipex_config_path - dump_model_op_stats(tune_cfg) + dump_model_op_stats(tune_cfg["op"]) return model diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py index 626d0f60a2e..2f4ed042e24 100644 --- a/neural_compressor/torch/algorithms/static_quant/static_quant.py +++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py @@ -51,8 +51,9 @@ def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True): Returns: A quantized model. 
""" - _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(model, example_inputs) - cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name) # update json file in ipex_config_path + _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(model, example_inputs) + # update json file in ipex_config_path; map ipex op_name to pt op_name + user_cfg = cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name) model.eval() # Check save_qconf_summary part is a workaround for IPEX bug. @@ -82,7 +83,7 @@ def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True): with open(ipex_config_path, "r") as f: model.tune_cfg = json.load(f) model.ipex_config_path = ipex_config_path - dump_model_op_stats(tune_cfg) + dump_model_op_stats(user_cfg) return model diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py index dd073f50aab..4657abd46d7 100644 --- a/neural_compressor/torch/algorithms/static_quant/utility.py +++ b/neural_compressor/torch/algorithms/static_quant/utility.py @@ -16,6 +16,7 @@ import json import os import re +from collections import OrderedDict from typing import Dict, List, Union import torch @@ -66,9 +67,10 @@ def cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name): # pragma: no cover assert cfgs is not None, "No configure for IPEX int8 model..." op_infos = copy.deepcopy(op_infos_from_cfgs) - cfgs = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name) + cfgs, user_cfg = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name) with open(ipex_config_path, "w") as write_f: json.dump(cfgs, write_f, indent=4) + return user_cfg def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name): # pragma: no cover @@ -83,6 +85,15 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_ Returns: cfgs (dict): updated configs. 
""" + tmp_user_cfg = OrderedDict() + for op in user_cfg: # map ipex op_name to pt op_name + for i, op_name in enumerate(op): + for ops, _ in op_infos_from_cfgs.items(): + if "fqn" in op_infos_from_cfgs[ops].keys() and op_infos_from_cfgs[ops]["fqn"] == op_name: + ori_op = (tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]]) + tmp_user_cfg[((ori_op[0],), ori_op[1])] = user_cfg[op] + break + user_cfg = tmp_user_cfg for op_name in user_cfg: inc_op_cfg = user_cfg[op_name] for i, name in enumerate(op_name[0]): @@ -142,7 +153,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_ else: pass cfgs[name[0]][name[1]][name[2]] = ipex_op_cfg - return cfgs + return cfgs, user_cfg def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False): # pragma: no cover @@ -212,6 +223,7 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover cfgs (dict): dict of configuration """ quantizable_ops = [] + op_name_info = [] # group ops by position for transform-based model detector = TransformerBasedModelBlockPatternDetector(model) detect_result = detector.detect_block() @@ -277,6 +289,17 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover if ipex_op_type in unify_op_type_mapping_ipex: quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type])) map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn + if "class" in ipex_op_type: # "" + op_type = ipex_op_type.split("'")[1] + op_name_info.append((module_fqn, eval(op_type))) + elif "method" in ipex_op_type: # "" + method = ipex_op_type.split("'")[1] + op_type = getattr( + torch._C._TensorBase if ipex_ver.release < Version("2.2") else torch._C.TensorBase, method + ) + op_name_info.append((module_fqn, op_type)) + else: + op_name_info.append((module_fqn, op_type)) else: re_flag = False for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items(): @@ -284,10 +307,12 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover re_flag = True quantizable_ops.append((tuple(name), unify_op_type)) map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn + op_name_info.append((module_fqn, ipex_op_type)) break if not re_flag: quantizable_ops.append((tuple(name), ipex_op_type)) map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn + op_name_info.append((module_fqn, ipex_op_type)) else: op_type = "" for op_name in name: @@ -302,6 +327,7 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover _op_cfg_id = name[0][2] module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"] map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn + op_name_info.append((module_fqn, op_type)) logger.debug("Map op name to fqn: ") logger.debug(map_op_name_to_fqn) @@ -309,7 +335,7 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover logger.info(attention_block) logger.info("FFN Blocks : ") logger.info(ffn_blocks) - return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name + return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, op_name_info def simple_inference(q_model, example_inputs, iterations=1): @@ -323,16 +349,16 @@ def simple_inference(q_model, example_inputs, iterations=1): q_model(example_inputs) -def dump_model_op_stats(tune_cfg): +def dump_model_op_stats(user_cfg): """This is a function to dump quantizable ops of model to user. 
Args: - tune_cfg (dict): quantization config + user_cfg (dict): quantization config Returns: None """ res = dict() - for k, v in tune_cfg["op"].items(): + for k, v in user_cfg.items(): op_type_list = k[-1].split("><") op_type = "" for op in op_type_list: diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 29a8177e9be..9de2ecb0a94 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -818,7 +818,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]: def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]: from neural_compressor.torch.algorithms.static_quant import get_quantizable_ops_recursively - model_info, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs) + _, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs) return model_info @classmethod @@ -923,7 +923,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]: def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]: from neural_compressor.torch.algorithms.smooth_quant import get_quantizable_ops_recursively - model_info, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs) + model_info, _, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs) return model_info @classmethod diff --git a/test/3x/torch/quantization/test_static_quant.py b/test/3x/torch/quantization/test_static_quant.py index 518e2240470..493191cae04 100644 --- a/test/3x/torch/quantization/test_static_quant.py +++ b/test/3x/torch/quantization/test_static_quant.py @@ -49,6 +49,21 @@ def test_static_quant_default(self): q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) assert q_model is not None, "Quantization failed!" + @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + def test_static_quant_fallback(self): + fp32_model = copy.deepcopy(self.fp32_model) + quant_config = get_default_static_config() + example_inputs = self.input + # fallback by op_type + quant_config.set_local(torch.nn.modules.linear.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32")) + q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) + assert q_model is not None, "Quantization failed!" + + # fallback by op_name + quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32")) + q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) + assert q_model is not None, "Quantization failed!" + @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") @pytest.mark.parametrize( "act_sym, act_algo",
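
Usage sketch (not part of the patch): the snippet below mirrors the fallback test added above and shows how the restored op_name/op_type mapping is exercised from the user side. It assumes IPEX is installed; the toy model, the layer name "fc1", and the input shape are illustrative, and the import path is assumed to match the 3.x torch API used by the test suite.

# Minimal sketch, assuming IPEX is available; ToyModel, "fc1" and the input shape are illustrative.
import torch
from neural_compressor.torch.quantization import StaticQuantConfig, get_default_static_config, quantize


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)  # op_name "fc1", used for the name-based fallback below
        self.fc2 = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(self.fc1(x))


model = ToyModel().eval()
example_inputs = torch.randn(1, 8)


def run_fn(m):
    m(example_inputs)  # one calibration pass


quant_config = get_default_static_config()
# Fall back all Linear modules to FP32 by op_type ...
quant_config.set_local(torch.nn.modules.linear.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
# ... and/or fall back a single op by its stock PyTorch op_name; this now takes effect
# because IPEX op names are mapped back to PyTorch names before the config mapping is applied.
quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))

q_model = quantize(model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)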