Support xpu for ipex static quant (#1916)
Signed-off-by: violetch24 <[email protected]>
violetch24 authored Jul 17, 2024
1 parent a1cc618 commit 53e6ee6
Showing 5 changed files with 206 additions and 55 deletions.
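
Before the per-file diffs, a minimal end-to-end sketch of how this change is reached from the 3.x PyTorch API: the accelerator is auto-detected, so the same entry points route to the new XPU branches shown below. The toy model, the calibration loop, device placement of the calibration inputs, and an installed intel_extension_for_pytorch with an available XPU device are assumptions for illustration, not part of this commit.

```python
# Hedged usage sketch (not from this commit): static quant through the 3.x entry points.
import torch
from neural_compressor.torch.quantization import StaticQuantConfig, convert, prepare

fp32_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()  # toy model (assumption)
example_inputs = torch.randn(1, 4)

quant_config = StaticQuantConfig(act_algo="minmax", act_sym=False)
prepared = prepare(fp32_model, quant_config, example_inputs=example_inputs)
for _ in range(2):                 # user-supplied calibration loop (assumption)
    prepared(example_inputs)
q_model = convert(prepared)
q_model.save("./saved_results")    # dispatches to the device-aware save() patched in below
```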
13 changes: 10 additions & 3 deletions neural_compressor/torch/algorithms/static_quant/save_load.py
@@ -32,9 +32,16 @@ def save(model, output_dir="./saved_results"):
 
     qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
     qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
-    model.ori_save(qmodel_file_path)
-    with open(qconfig_file_path, "w") as f:
-        json.dump(model.tune_cfg, f, indent=4)
+    device = next(model.parameters(), None).device.type if next(model.parameters(), None) else "cpu"
+    if device == "cpu":
+        model.ori_save(qmodel_file_path)
+        with open(qconfig_file_path, "w") as f:
+            json.dump(model.tune_cfg, f, indent=4)
+    else:  # pragma: no cover
+        from neural_compressor.common.utils import save_config_mapping
+
+        torch.jit.save(model, qmodel_file_path)
+        save_config_mapping(model.qconfig, qconfig_file_path)
 
     logger.info("Save quantized model to {}.".format(qmodel_file_path))
     logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))
108 changes: 69 additions & 39 deletions neural_compressor/torch/algorithms/static_quant/static_quant.py
@@ -33,11 +33,13 @@
 
 from neural_compressor.torch.algorithms import Quantizer
 from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
 from .utility import (
     CpuInfo,
     cfg_to_qconfig,
     dump_model_op_stats,
+    generate_xpu_qconfig,
     get_ipex_version,
     get_quantizable_ops_recursively,
     ipex_config_path,
@@ -56,6 +58,7 @@ def __init__(self, quant_config: OrderedDict = {}):
         """
         super().__init__(quant_config)
         self.user_cfg = OrderedDict()
+        self.device = auto_detect_accelerator().current_device()
 
     def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         """Prepares a given model for quantization.
@@ -70,43 +73,61 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         """
         assert example_inputs is not None, "Please provide example_inputs for static quantization."
 
-        _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
-            model, example_inputs
-        )
-        # update json file in ipex_config_path; map ipex op_name to pt op_name
-        self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
-        model.eval()
+        if self.device == "cpu":
+            _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
+                model, example_inputs
+            )
+            # update json file in ipex_config_path; map ipex op_name to pt op_name
+            self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
+        else:  # pragma: no cover
+            model = model.to("xpu")
 
-        use_bf16 = self.quant_config.get("use_bf16", None)
+        model.eval()
 
         # Check save_qconf_summary part is a workaround for IPEX bug.
-        # Sometimes the prepared model from get_op_capablitiy loss this attribute
-        if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
-            from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
-
-            if ipex_ver.release >= Version("2.1").release:
-                # HistogramObserver will cause a performance issue.
-                # static_qconfig = ipex.quantization.default_static_qconfig_mapping
-                qconfig = QConfig(
-                    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
-                    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
-                )
-                from torch.ao.quantization import QConfigMapping
-
-                static_qconfig = QConfigMapping().set_global(qconfig)
-            else:
-                static_qconfig = QConfig(
-                    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
-                    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
-                )
-            if isinstance(example_inputs, dict):
-                model = ipex.quantization.prepare(
-                    model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
-                )
+        # Sometimes the prepared model from get_op_capablitiy loss this attributes
+        if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):  # pragma: no cover
+            from torch.ao.quantization import HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, QConfig
+
+            if self.device != "cpu":  # pragma: no cover
+                from torch.quantization.quantize_jit import prepare_jit
+
+                with torch.no_grad():
+                    modelJit = torch.jit.trace(model, example_inputs)
+                qconfig = generate_xpu_qconfig(self.quant_config)
+                model = prepare_jit(modelJit, qconfig, inplace)
             else:
-                model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)
+                if ipex_ver.release >= Version("2.1").release:
+                    # HistogramObserver will cause a performance issue.
+                    # static_qconfig = ipex.quantization.default_static_qconfig_mapping
+                    qconfig = QConfig(
+                        activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+                        weight=PerChannelMinMaxObserver.with_args(
+                            dtype=torch.qint8, qscheme=torch.per_channel_symmetric
+                        ),
+                    )
+                    from torch.ao.quantization import QConfigMapping
+
+                    static_qconfig = QConfigMapping().set_global(qconfig)
+                else:  # pragma: no cover
+                    static_qconfig = QConfig(
+                        activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+                        weight=PerChannelMinMaxObserver.with_args(
+                            dtype=torch.qint8, qscheme=torch.per_channel_symmetric
+                        ),
+                    )
+                if isinstance(example_inputs, dict):
+                    model = ipex.quantization.prepare(
+                        model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
+                    )
+                else:
+                    model = ipex.quantization.prepare(
+                        model, static_qconfig, example_inputs=example_inputs, inplace=inplace
+                    )
+
+        if self.device == "cpu":
+            model.load_qconf_summary(qconf_summary=ipex_config_path)
 
-        model.load_qconf_summary(qconf_summary=ipex_config_path)
         return model
 
     def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
@@ -124,18 +145,27 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
 
         from neural_compressor.torch.algorithms.static_quant import save
 
-        model.save_qconf_summary(qconf_summary=ipex_config_path)
-        model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)
+        if self.device != "cpu":  # pragma: no cover
+            from torch.quantization.quantize_jit import convert_jit
 
-        with open(ipex_config_path, "r") as f:
-            model.tune_cfg = json.load(f)
-        model.ipex_config_path = ipex_config_path
+            model = convert_jit(model, inplace)
+            simple_inference(model, example_inputs, iterations=2)
+            model.qconfig = self.quant_config["op"]
+            dump_model_op_stats(model.qconfig)
+        else:
+            model.save_qconf_summary(qconf_summary=ipex_config_path)
+            model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)
 
-        dump_model_op_stats(self.user_cfg)
+            with open(ipex_config_path, "r") as f:
+                model.tune_cfg = json.load(f)
+            model.ipex_config_path = ipex_config_path
+
+            dump_model_op_stats(self.user_cfg)
 
-        logger.info("Static quantization done.")
         model.ori_save = model.save
        model.save = MethodType(save, model)
 
+        logger.info("Static quantization done.")
         return model


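Taken on its own, the XPU branch above is PyTorch's JIT quantization flow: trace the eager model, attach observers with prepare_jit, run calibration, then rewrite the graph with convert_jit. A self-contained, hedged sketch of that flow outside the Quantizer class; the toy module, the qconfig choice, and the calibration loop are assumptions.

```python
# Hedged sketch of the JIT flow the XPU branch relies on: trace -> prepare_jit -> calibrate -> convert_jit.
import torch
from torch.ao.quantization import MinMaxObserver, QConfig
from torch.quantization.quantize_jit import convert_jit, prepare_jit

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()  # toy module (assumption)
example_inputs = torch.randn(1, 4)

# Global qconfig keyed by "", the same shape generate_xpu_qconfig() builds.
qconfig_dict = {
    "": QConfig(
        activation=MinMaxObserver.with_args(quant_min=0, quant_max=127),
        weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric),
    )
}

with torch.no_grad():
    traced = torch.jit.trace(model, example_inputs)
prepared = prepare_jit(traced, qconfig_dict, inplace=False)  # insert observers
with torch.no_grad():
    for _ in range(2):  # calibration loop (assumption)
        prepared(example_inputs)
quantized = convert_jit(prepared, inplace=False)  # fold observers into quantized ops
```
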
41 changes: 41 additions & 0 deletions neural_compressor/torch/algorithms/static_quant/utility.py
@@ -163,6 +163,47 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
     return cfgs, ori_user_cfg
 
 
+def generate_xpu_qconfig(tune_cfg):  # pragma: no cover
+    # qconfig observer & config constants for ipex-xpu
+    from torch.ao.quantization import HistogramObserver, MinMaxObserver, QConfig
+
+    act_observer_minmax_asym = MinMaxObserver.with_args(quant_min=0, quant_max=127)
+    act_observer_minmax_sym = MinMaxObserver.with_args(
+        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127
+    )
+    act_observer_kl_asym = HistogramObserver.with_args(quant_min=0, quant_max=127)
+    act_observer_kl_sym = HistogramObserver.with_args(
+        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127
+    )
+    # no tuning for granularity due to tuning space
+    weight_observer_minmax_sym = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
+
+    qconfig = {}
+    user_cfg = copy.deepcopy(tune_cfg["op"])
+    for _, cfg in user_cfg.items():
+        act_algo = cfg["activation"]["algorithm"]
+        act_sym = cfg["activation"]["scheme"]
+        break
+
+    if act_algo == "minmax":
+        if act_sym == "sym":
+            activation = act_observer_minmax_sym
+        else:
+            activation = act_observer_minmax_asym
+    else:
+        if act_sym == "sym":
+            activation = act_observer_kl_sym
+        else:
+            activation = act_observer_kl_asym
+
+    qconfig[""] = QConfig(activation=activation, weight=weight_observer_minmax_sym)
+
+    for (op_name, op_type), cfg in user_cfg.items():
+        if cfg["weight"]["dtype"] == "fp32":
+            qconfig[op_name] = None
+    return qconfig
+
+
 def generate_activation_observer(
     scheme, algorithm, smooth_quant=False, smooth_quant_enable=False, alpha=0.5
 ):  # pragma: no cover
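
To make the new helper concrete, a hedged illustration of the dictionary it would build for a hand-written tuning config; the op names and the exact config layout are invented for the example and only exercise the keys the function actually reads.

```python
# Hedged illustration (hypothetical op names): what generate_xpu_qconfig would return.
from neural_compressor.torch.algorithms.static_quant.utility import generate_xpu_qconfig

tune_cfg = {
    "op": {
        ("conv1", "Conv2d"): {
            "activation": {"algorithm": "minmax", "scheme": "asym", "dtype": "uint8"},
            "weight": {"dtype": "int8"},
        },
        ("lm_head", "Linear"): {
            "activation": {"algorithm": "minmax", "scheme": "asym", "dtype": "fp32"},
            "weight": {"dtype": "fp32"},
        },
    }
}

qconfig = generate_xpu_qconfig(tune_cfg)
# qconfig[""]        -> global QConfig: asymmetric min-max activation, per-tensor symmetric int8 weight
# qconfig["lm_head"] -> None, because its weight dtype is "fp32" (left unquantized)
```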
26 changes: 23 additions & 3 deletions neural_compressor/torch/quantization/config.py
@@ -1095,6 +1095,7 @@ def __init__(
         act_algo: str = "minmax",
         excluded_precisions: list = [],
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+        model_info: Optional[List[Tuple[str, Callable]]] = None,
     ):
         """Init Static Quant Configs."""
         super().__init__(white_list=white_list)
@@ -1107,6 +1108,7 @@ def __init__(
         self.act_granularity = act_granularity
         self.act_algo = act_algo
         self.excluded_precisions = excluded_precisions
+        self.model_info = model_info
         self._post_init()
 
     @classmethod
@@ -1124,10 +1126,28 @@ def get_model_info_for_ipex(model: torch.nn.Module, example_inputs) -> List[Tupl
         _, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
         return model_info
 
-    @staticmethod
-    def get_model_info(model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]:
+    def get_model_info_for_ipex_xpu(self, model: torch.nn.Module) -> List[Tuple[str, Callable]]:  # pragma: no cover
+        if self.model_info:
+            return self.model_info
+        else:
+            white_list = torch.quantization.quantization_mappings.get_default_qconfig_propagation_list()
+            filter_result = []
+            for op_name, module in model.named_modules():
+                if type(module) in white_list:
+                    pair = (op_name, type(module).__name__)
+                    filter_result.append(pair)
+            logger.debug(f"Get model info: {filter_result}")
+            self.model_info = filter_result
+            return filter_result
+
+    def get_model_info(self, model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]:
+        from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
+
         if is_ipex_imported():
-            return StaticQuantConfig.get_model_info_for_ipex(model, example_inputs)
+            if auto_detect_accelerator().current_device() == "cpu":
+                return StaticQuantConfig.get_model_info_for_ipex(model, example_inputs)
+            else:
+                return StaticQuantConfig.get_model_info_for_ipex_xpu(self, model)
 
     def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
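
A hedged sketch of the new XPU side of get_model_info: with no precomputed list, get_model_info_for_ipex_xpu filters named_modules against torch's default qconfig propagation list and caches the (name, type-name) pairs on the config, while a list passed through the new model_info argument is returned as-is. The toy model is an assumption.

```python
# Hedged sketch: exercising the XPU-side model-info helpers added above.
import torch
from neural_compressor.torch.quantization import StaticQuantConfig

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()  # toy model (assumption)

cfg = StaticQuantConfig()
info = cfg.get_model_info_for_ipex_xpu(model)  # e.g. entries like ("0", "Linear")
assert cfg.model_info is info                  # result is cached on the config

# A precomputed list injected through the new constructor argument is returned verbatim.
cfg2 = StaticQuantConfig(model_info=[("0", "Linear")])
assert cfg2.get_model_info_for_ipex_xpu(model) == [("0", "Linear")]
```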