map ipex op_name w/ pt op_name (#1740)
Signed-off-by: Cheng, Zixuan <[email protected]>
violetch24 authored Apr 23, 2024
1 parent e87c95f commit 855c10c
Showing 6 changed files with 58 additions and 18 deletions.
6 changes: 2 additions & 4 deletions neural_compressor/common/base_config.py
@@ -410,11 +410,9 @@ def to_config_mapping(
                 if self.global_config is not None:
                     config_mapping[(op_name, op_type)] = global_config
                 if op_type in op_type_config_dict:
-                    config_mapping[(op_name, op_type)] = op_name_config_dict[op_type]
+                    config_mapping[(op_name, op_type)] = op_type_config_dict[op_type]
                 for op_name_pattern in op_name_config_dict:
-                    if isinstance(op_name, str) and re.match(op_name_pattern, op_name):
-                        config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
-                    elif op_name_pattern == op_name:  # TODO: map ipex opname to stock pt op_name
+                    if re.match(op_name_pattern, op_name):
                         config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
         return config_mapping

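For context, the revised lookup above can rely on re.match alone because an exact op name such as "fc1" matches itself when treated as a pattern. A minimal, self-contained sketch of the resulting precedence (global config, then op_type config, then op_name pattern), using hypothetical plain dicts in place of the real config objects:

import re
from collections import OrderedDict

# Hypothetical stand-ins for the real config objects and model info.
global_config = {"act_dtype": "uint8", "w_dtype": "int8"}
op_type_config_dict = {"Linear": {"act_dtype": "uint8", "w_dtype": "int8"}}
op_name_config_dict = {"fc1": {"act_dtype": "fp32", "w_dtype": "fp32"}}  # exact name or regex pattern
model_info = [("fc1", "Linear"), ("fc2", "Linear")]

config_mapping = OrderedDict()
for op_name, op_type in model_info:
    config_mapping[(op_name, op_type)] = global_config
    if op_type in op_type_config_dict:
        config_mapping[(op_name, op_type)] = op_type_config_dict[op_type]
    for op_name_pattern in op_name_config_dict:
        if re.match(op_name_pattern, op_name):  # an exact name is also a valid pattern
            config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]

print(config_mapping[("fc1", "Linear")]["act_dtype"])  # fp32 -- the op_name override wins
print(config_mapping[("fc2", "Linear")]["act_dtype"])  # uint8 -- falls back to the op_type config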
@@ -56,7 +56,7 @@ def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     """
     assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant."

-    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(model, example_inputs)
+    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(model, example_inputs)

     # check smoothquant folding value
     recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
@@ -121,7 +121,7 @@ def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
     model.ipex_config_path = ipex_config_path
-    dump_model_op_stats(tune_cfg)
+    dump_model_op_stats(tune_cfg["op"])
     return model


@@ -185,7 +185,7 @@ def qdq_quantize(
     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
     model.ipex_config_path = ipex_config_path
-    dump_model_op_stats(tune_cfg)
+    dump_model_op_stats(tune_cfg["op"])
     return model


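As a rough, hypothetical illustration (not part of this diff), tune_cfg["op"] is keyed by (op_name, op_type) pairs, which is why the statistics dump now receives that sub-dict rather than the whole tuning config; an entry might look roughly like this:

# Hypothetical shape of tune_cfg["op"]; the exact fields come from the framework config.
tune_cfg = {
    "op": {
        ("fc1", "Linear"): {
            "weight": {"dtype": "int8"},
            "activation": {"dtype": "uint8"},
        },
    },
}
for (op_name, op_type), op_cfg in tune_cfg["op"].items():
    print(op_name, op_type, op_cfg["weight"]["dtype"])  # fc1 Linear int8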
@@ -51,8 +51,9 @@ def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     Returns:
         A quantized model.
     """
-    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(model, example_inputs)
-    cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)  # update json file in ipex_config_path
+    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(model, example_inputs)
+    # update json file in ipex_config_path; map ipex op_name to pt op_name
+    user_cfg = cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
     model.eval()

     # Check save_qconf_summary part is a workaround for IPEX bug.
@@ -82,7 +83,7 @@ def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
     model.ipex_config_path = ipex_config_path
-    dump_model_op_stats(tune_cfg)
+    dump_model_op_stats(user_cfg)
     return model


38 changes: 32 additions & 6 deletions neural_compressor/torch/algorithms/static_quant/utility.py
@@ -16,6 +16,7 @@
 import json
 import os
 import re
+from collections import OrderedDict
 from typing import Dict, List, Union

 import torch
@@ -66,9 +67,10 @@
 def cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name):  # pragma: no cover
     assert cfgs is not None, "No configure for IPEX int8 model..."
     op_infos = copy.deepcopy(op_infos_from_cfgs)
-    cfgs = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name)
+    cfgs, user_cfg = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name)
     with open(ipex_config_path, "w") as write_f:
         json.dump(cfgs, write_f, indent=4)
+    return user_cfg


 def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name):  # pragma: no cover
@@ -83,6 +85,15 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
     Returns:
         cfgs (dict): updated configs.
     """
+    tmp_user_cfg = OrderedDict()
+    for op in user_cfg:  # map ipex op_name to pt op_name
+        for i, op_name in enumerate(op):
+            for ops, _ in op_infos_from_cfgs.items():
+                if "fqn" in op_infos_from_cfgs[ops].keys() and op_infos_from_cfgs[ops]["fqn"] == op_name:
+                    ori_op = (tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]])
+                    tmp_user_cfg[((ori_op[0],), ori_op[1])] = user_cfg[op]
+                    break
+    user_cfg = tmp_user_cfg
     for op_name in user_cfg:
         inc_op_cfg = user_cfg[op_name]
         for i, name in enumerate(op_name[0]):
@@ -142,7 +153,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
                     else:
                         pass
             cfgs[name[0]][name[1]][name[2]] = ipex_op_cfg
-    return cfgs
+    return cfgs, user_cfg


def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False): # pragma: no cover
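To make the new key remapping above concrete, here is a toy sketch with hypothetical data: op_infos_from_cfgs records each IPEX op id together with its fully qualified name ("fqn") and IPEX op type, while the incoming user_cfg is keyed by stock PyTorch (op_name, op_type) pairs; the loop translates those keys back into the IPEX-style keys that the rest of check_cfg_and_qconfig expects:

from collections import OrderedDict

# Hypothetical data; real keys come from IPEX's prepared-model config.
unify_op_type_mapping_ipex = {"<class 'torch.nn.modules.linear.Linear'>": "Linear"}
op_infos_from_cfgs = {
    ("fc1", "q_op_infos", "0"): {"fqn": "fc1", "op_type": "<class 'torch.nn.modules.linear.Linear'>"},
}
# user_cfg arrives keyed by stock PyTorch (op_name, op_type) pairs.
user_cfg = {("fc1", "Linear"): {"weight": {"dtype": "fp32"}, "activation": {"dtype": "fp32"}}}

tmp_user_cfg = OrderedDict()
for op in user_cfg:  # op is a (pt_op_name, pt_op_type) tuple
    for op_name in op:
        for ops, info in op_infos_from_cfgs.items():
            if info.get("fqn") == op_name:
                ori_op = (tuple(ops), unify_op_type_mapping_ipex[info["op_type"]])
                tmp_user_cfg[((ori_op[0],), ori_op[1])] = user_cfg[op]
                break

print(list(tmp_user_cfg))  # [((('fc1', 'q_op_infos', '0'),), 'Linear')]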
@@ -212,6 +223,7 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
         cfgs (dict): dict of configuration
     """
     quantizable_ops = []
+    op_name_info = []
     # group ops by position for transform-based model
     detector = TransformerBasedModelBlockPatternDetector(model)
     detect_result = detector.detect_block()
@@ -277,17 +289,30 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
             if ipex_op_type in unify_op_type_mapping_ipex:
                 quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type]))
                 map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
+                if "class" in ipex_op_type:  # "<class 'torch.nn.modules.activation.ReLU'>"
+                    op_type = ipex_op_type.split("'")[1]
+                    op_name_info.append((module_fqn, eval(op_type)))
+                elif "method" in ipex_op_type:  # "<method 'add' of 'torch._C._TensorBase' objects>"
+                    method = ipex_op_type.split("'")[1]
+                    op_type = getattr(
+                        torch._C._TensorBase if ipex_ver.release < Version("2.2") else torch._C.TensorBase, method
+                    )
+                    op_name_info.append((module_fqn, op_type))
+                else:
+                    op_name_info.append((module_fqn, op_type))
             else:
                 re_flag = False
                 for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
                     if re.match(pattern, ipex_op_type):
                         re_flag = True
                         quantizable_ops.append((tuple(name), unify_op_type))
                         map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn
+                        op_name_info.append((module_fqn, ipex_op_type))
                         break
                 if not re_flag:
                     quantizable_ops.append((tuple(name), ipex_op_type))
                     map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
+                    op_name_info.append((module_fqn, ipex_op_type))
         else:
             op_type = ""
             for op_name in name:
@@ -302,14 +327,15 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
             _op_cfg_id = name[0][2]
             module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"]
             map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn
+            op_name_info.append((module_fqn, op_type))

     logger.debug("Map op name to fqn: ")
     logger.debug(map_op_name_to_fqn)
     logger.info("Attention Blocks : ")
     logger.info(attention_block)
     logger.info("FFN Blocks : ")
     logger.info(ffn_blocks)
-    return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name
+    return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, op_name_info


 def simple_inference(q_model, example_inputs, iterations=1):
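The op_name_info branches added above turn IPEX's op-type strings back into Python callables so they can later be matched against op-type configs such as torch.nn.Linear. A standalone sketch of the two string shapes involved (the real code selects torch._C._TensorBase or torch._C.TensorBase by IPEX version; hasattr is used here only to keep the snippet self-contained):

import torch

class_style = "<class 'torch.nn.modules.activation.ReLU'>"
method_style = "<method 'add' of 'torch._C._TensorBase' objects>"

# "<class '...'>" strings carry the importable class path between the quotes.
op_type = class_style.split("'")[1]
print(eval(op_type) is torch.nn.ReLU)  # True

# "<method '...' of '...'>" strings carry the tensor method name instead.
method = method_style.split("'")[1]
tensor_base = torch._C._TensorBase if hasattr(torch._C, "_TensorBase") else torch._C.TensorBase
print(getattr(tensor_base, method))  # <method 'add' of ...>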
@@ -323,16 +349,16 @@ def simple_inference(q_model, example_inputs, iterations=1):
             q_model(example_inputs)


-def dump_model_op_stats(tune_cfg):
+def dump_model_op_stats(user_cfg):
     """This is a function to dump quantizable ops of model to user.
     Args:
-        tune_cfg (dict): quantization config
+        user_cfg (dict): quantization config
     Returns:
         None
     """
     res = dict()
-    for k, v in tune_cfg["op"].items():
+    for k, v in user_cfg.items():
         op_type_list = k[-1].split("><")
         op_type = ""
         for op in op_type_list:
4 changes: 2 additions & 2 deletions neural_compressor/torch/quantization/config.py
@@ -818,7 +818,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
     def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
         from neural_compressor.torch.algorithms.static_quant import get_quantizable_ops_recursively

-        model_info, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
+        _, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
         return model_info

     @classmethod
@@ -923,7 +923,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
     def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
         from neural_compressor.torch.algorithms.smooth_quant import get_quantizable_ops_recursively

-        model_info, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
+        model_info, _, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
         return model_info

     @classmethod
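With this change, get_model_info for static quantization returns the new op_name_info list, i.e. (fully qualified op name, operator callable) pairs, which are the op_name/op_type values that to_config_mapping and the set_local fallbacks (by type or by name, as exercised in the test below) match against. A hypothetical example of its shape for a small model:

import torch

# Hypothetical entries for a model with two Linear layers and a ReLU:
model_info = [
    ("fc1", torch.nn.Linear),
    ("fc2", torch.nn.Linear),
    ("relu", torch.nn.ReLU),
]
for op_name, op_type in model_info:
    print(op_name, op_type)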
15 changes: 15 additions & 0 deletions test/3x/torch/quantization/test_static_quant.py
@@ -49,6 +49,21 @@ def test_static_quant_default(self):
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"

+    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
+    def test_static_quant_fallback(self):
+        fp32_model = copy.deepcopy(self.fp32_model)
+        quant_config = get_default_static_config()
+        example_inputs = self.input
+        # fallback by op_type
+        quant_config.set_local(torch.nn.modules.linear.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+
+        # fallback by op_name
+        quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+
     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
     @pytest.mark.parametrize(
         "act_sym, act_algo",
