fix tune_cfg issue for 3.x static quant (#1718)
Signed-off-by: Cheng, Zixuan <[email protected]>
violetch24 authored Apr 16, 2024
1 parent 137fa3a commit ba16504
Showing 3 changed files with 199 additions and 476 deletions.
290 changes: 2 additions & 288 deletions neural_compressor/torch/algorithms/smooth_quant/utility.py
@@ -16,125 +16,28 @@
import json
import os
import re
import subprocess
from collections import UserDict

import cpuinfo
import intel_extension_for_pytorch as ipex
import numpy
import psutil
import torch
import tqdm
from packaging.version import Version

from neural_compressor.torch.algorithms.static_quant import (
TransformerBasedModelBlockPatternDetector,
dump_model_op_stats,
get_quantizable_ops_from_cfgs,
generate_activation_observer,
get_quantizable_ops_recursively,
ipex_config_path,
paser_cfgs,
simple_inference,
unify_op_type_mapping_ipex,
)
from neural_compressor.torch.utils import get_ipex_version, get_torch_version, logger

version = get_torch_version()
ipex_ver = get_ipex_version()


def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False): # pragma: no cover
"""This is a helper method to generate an activation observer.
Args:
scheme (str): Quantization scheme to be used.
algorithm (str): What algorithm for computing the quantization parameters based on.
Returns:
An observer.
"""
kl_activation_observer = {
"name": "HistogramObserver",
"bins": 2048,
"upsample_rate": 128,
"dtype": "torch.quint8",
"qscheme": "torch.per_tensor_affine",
"reduce_range": False,
"quant_min": 0,
"quant_max": 255,
}
minmax_activation_observer = {
"name": "MinMaxObserver",
"dtype": "torch.quint8",
"qscheme": "torch.per_tensor_affine",
"reduce_range": False,
"quant_min": 0,
"quant_max": 255,
}
smoothquant_kl_activation_observer = {
"name": "SmoothQuantActivationObserver",
"smooth_quant_enabled": smooth_quant_enable,
"dtype": "torch.quint8",
"qscheme": "torch.per_tensor_affine",
"reduce_range": False,
"quant_min": 0,
"quant_max": 255,
"alpha": 0.5,
"act_observer": kl_activation_observer,
"act_ic_observer": {
"name": "PerChannelMinMaxObserver",
"ch_axis": -1,
"dtype": "torch.quint8",
"qscheme": "torch.per_channel_affine",
"reduce_range": False,
"quant_min": 0,
"quant_max": 255,
},
}
smoothquant_minmax_activation_observer = {
"name": "SmoothQuantActivationObserver",
"smooth_quant_enabled": smooth_quant_enable,
"dtype": "torch.quint8",
"qscheme": "torch.per_tensor_affine",
"reduce_range": False,
"quant_min": 0,
"quant_max": 255,
"alpha": 0.5,
"act_observer": minmax_activation_observer,
"act_ic_observer": {
"name": "PerChannelMinMaxObserver",
"ch_axis": -1,
"dtype": "torch.quint8",
"qscheme": "torch.per_channel_affine",
"reduce_range": False,
"quant_min": 0,
"quant_max": 255,
},
}
REDUCE_RANGE = not CpuInfo().vnni
if REDUCE_RANGE:
minmax_activation_observer["reduce_range"] = REDUCE_RANGE
kl_activation_observer["reduce_range"] = REDUCE_RANGE
if scheme == "sym":
minmax_activation_observer["qscheme"] = "torch.per_tensor_symmetric"
minmax_activation_observer["dtype"] = "torch.qint8"
minmax_activation_observer["quant_min"] = -128
minmax_activation_observer["quant_max"] = 127
kl_activation_observer["qscheme"] = "torch.per_tensor_symmetric"
kl_activation_observer["dtype"] = "torch.qint8"
kl_activation_observer["quant_min"] = -128
kl_activation_observer["quant_max"] = 127
if smooth_quant and smooth_quant_enable:
if algorithm == "kl":
return smoothquant_kl_activation_observer
if algorithm == "minmax":
return smoothquant_minmax_activation_observer
else:
if algorithm == "kl":
return kl_activation_observer
if algorithm == "minmax":
return minmax_activation_observer
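
For reference, a minimal usage sketch of this helper as it stands here (argument values are illustrative; the asserted keys come from the observer dicts above):

# Usage sketch: symmetric per-tensor settings from the minmax algorithm.
obs = generate_activation_observer(scheme="sym", algorithm="minmax")
assert obs["name"] == "MinMaxObserver"
assert obs["qscheme"] == "torch.per_tensor_symmetric"  # flipped by the scheme == "sym" branch

# With SmoothQuant active, the wrapping observer is returned instead.
sq_obs = generate_activation_observer("asym", "minmax", smooth_quant=True, smooth_quant_enable=True)
assert sq_obs["name"] == "SmoothQuantActivationObserver"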


def check_cfg_and_qconfig(
tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name, smooth_quant=False
): # pragma: no cover
@@ -223,131 +126,6 @@ def cfg_to_qconfig(
return None
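
The collapsed hunk above holds the tune_cfg fix this commit targets: cfg_to_qconfig reconciles the framework-level tune_cfg with the per-op cfgs read back from the IPEX JSON summary. A rough, hypothetical sketch of that reconciliation shape (every name and key below is illustrative, not the actual implementation):

# Hypothetical sketch: merge user tuning choices into the IPEX op configs.
def merge_tune_cfg(tune_cfg, cfgs):
    for (op_name, op_type), op_cfg in tune_cfg["op"].items():  # illustrative key layout
        activation = op_cfg["activation"]  # e.g. {"dtype": "uint8", "scheme": "sym", "algorithm": "minmax"}
        if activation["dtype"] == "fp32":
            continue  # op falls back to fp32; leave its IPEX entry untouched
        cfgs[op_name]["activation_observer"] = generate_activation_observer(
            activation["scheme"], activation["algorithm"]
        )
    return cfgs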


def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover
"""Get all quantizable ops from model.
Args:
model (object): input model
example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model.
Returns:
quantizable_ops (list): list of tuples of op_name and op_type.
cfgs (dict): dict of configuration
"""
quantizable_ops = []
# group ops by position for transformer-based models
detector = TransformerBasedModelBlockPatternDetector(model)
detect_result = detector.detect_block()
attention_block = detect_result.get("attention_blocks", None)
ffn_blocks = detect_result.get("ffn_blocks", None)
logger.info(f"Attention Blocks: {len(attention_block)}")
logger.info(f"FFN Blocks: {len(ffn_blocks)}")
if not os.path.exists(ipex_config_path):
assert isinstance(model, torch.nn.Module), "The model passed in is not an instance of torch.nn.Module"

if hasattr(model, "save_qconf_summary"): # pragma: no cover
os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True)
model.save_qconf_summary(qconf_summary=ipex_config_path)
else:
model.eval()

# create a quantization config file for intel pytorch extension model
os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True)
assert example_inputs is not None, "IPEX needs q_dataloader or example_inputs to prepare the model"
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

if ipex_ver.release >= Version("2.1").release:
# HistogramObserver will cause a performance issue.
# static_qconfig = ipex.quantization.default_static_qconfig_mapping
qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)
from torch.ao.quantization import QConfigMapping

static_qconfig = QConfigMapping().set_global(qconfig)
else:
static_qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)

if isinstance(example_inputs, dict):
model = ipex.quantization.prepare(model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=True)
else:
model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=True)
simple_inference(model, example_inputs, iterations=1)
model.save_qconf_summary(qconf_summary=ipex_config_path)

map_op_name_to_fqn = {}
with open(ipex_config_path, "r") as f:
cfgs = json.load(f)
if ipex_ver.release < Version("1.12.0").release: # pragma: no cover
for op_cfg in cfgs:
if op_cfg["name"] in unify_op_type_mapping_ipex:
quantizable_ops.append((op_cfg["id"], unify_op_type_mapping_ipex[op_cfg["name"]]))
else:
re_flag = False
for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
if re.match(pattern, op_cfg["name"]):
re_flag = True
quantizable_ops.append((op_cfg["id"], unify_op_type))
break
if not re_flag:
quantizable_ops.append((op_cfg["id"], op_cfg["name"]))
else:
(
ops_name,
op_infos_from_cfgs,
input_tensor_id_op_name,
output_tensor_id_op_name,
) = paser_cfgs(cfgs)
quantizable_op_names = get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_id_op_name)
for name in quantizable_op_names:
# name: list of op identifiers (module_key at index 0, op_cfg_id at index 2)
if len(name) == 1:
module_key = name[0][0]
op_cfg_id = name[0][2]
ipex_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"]
module_fqn = cfgs[module_key]["q_op_infos"][op_cfg_id].get("fqn", None)

if ipex_op_type in unify_op_type_mapping_ipex:
quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type]))
map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
else:
re_flag = False
for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
if re.match(pattern, ipex_op_type):
re_flag = True
quantizable_ops.append((tuple(name), unify_op_type))
map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn
break
if not re_flag:
quantizable_ops.append((tuple(name), ipex_op_type))
map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
else:
op_type = ""
for op_name in name:
module_key = op_name[0]
op_cfg_id = op_name[2]
single_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"]
if single_op_type in unify_op_type_mapping_ipex:
single_op_type = unify_op_type_mapping_ipex[single_op_type]
op_type += "&" + single_op_type if op_type else single_op_type
quantizable_ops.append((tuple(name), op_type))
_module_key = name[0][0]
_op_cfg_id = name[0][2]
module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"]
map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn

logger.debug("Map op name to fqn: ")
logger.debug(map_op_name_to_fqn)
logger.info("Attention Blocks : ")
logger.info(attention_block)
logger.info("FFN Blocks : ")
logger.info(ffn_blocks)
return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name
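
A minimal driving sketch for this entry point, assuming intel_extension_for_pytorch is installed (the toy model and inputs are placeholders; the four return values match the tuple above):

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
example_inputs = torch.randn(1, 8)
quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(
    model, example_inputs
)
for op_name, op_type in quantizable_ops:
    print(op_name, op_type)  # each op identifier paired with a unified op type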


def get_parent(node, all_parents=False): # pragma: no cover
if node.inputs() is None:
return None
@@ -2275,67 +2053,3 @@ def forward(self, x):
output = self.orig_layer(x)
self.output = output
return output
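
The forward above is the tail of an output-caching wrapper used during calibration; a self-contained sketch of the same pattern (the class name and constructor are assumptions, not this file's actual definition):

import torch

class OutputCachingWrapper(torch.nn.Module):  # hypothetical name mirroring the forward() above
    """Wrap a layer and keep its most recent output for calibration."""

    def __init__(self, orig_layer):
        super().__init__()
        self.orig_layer = orig_layer
        self.output = None

    def forward(self, x):
        output = self.orig_layer(x)
        self.output = output  # cached for later inspection, e.g. SmoothQuant scale search
        return output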


class CpuInfo(object): # pragma: no cover
"""Get CPU Info."""

def __init__(self):
"""Get whether the cpu numerical format is bf16, the number of sockets, cores and cores per socket."""
self._bf16 = False
self._vnni = False
info = cpuinfo.get_cpu_info()
if "arch" in info and "X86" in info["arch"]:
cpuid = cpuinfo.CPUID()
max_extension_support = cpuid.get_max_extension_support()
if max_extension_support >= 7:
ecx = cpuid._run_asm(
b"\x31\xC9", # xor ecx, ecx
b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\x89\xC8" b"\xC3", # mov eax, 7 # cpuid # mov ax, cx # ret
)
self._vnni = bool(ecx & (1 << 11))
eax = cpuid._run_asm(
b"\xB9\x01\x00\x00\x00", # mov ecx, 1
b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\xC3", # mov eax, 7 # cpuid # ret
)
self._bf16 = bool(eax & (1 << 5))
if "arch" in info and "ARM" in info["arch"]: # pragma: no cover
self._sockets = 1
else:
self._sockets = self.get_number_of_sockets()
self._cores = psutil.cpu_count(logical=False)
self._cores_per_socket = int(self._cores / self._sockets)

@property
def bf16(self):
"""Get whether it is bf16."""
return self._bf16

@property
def vnni(self):
"""Get whether it is vnni."""
return self._vnni

@property
def cores_per_socket(self):
"""Get the cores per socket."""
return self._cores_per_socket

def get_number_of_sockets(self) -> int:
"""Get number of sockets in platform."""
cmd = "cat /proc/cpuinfo | grep 'physical id' | sort -u | wc -l"
if psutil.WINDOWS:
cmd = r'wmic cpu get DeviceID | C:\Windows\System32\find.exe /C "CPU"'

with subprocess.Popen(
args=cmd,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=False,
) as proc:
proc.wait()
if proc.stdout:
for line in proc.stdout:
return int(line.decode("utf-8", errors="ignore").strip())
return 0
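
A quick usage sketch for this helper (values vary by machine; reduce_range mirrors how generate_activation_observer consumes it above):

cpu = CpuInfo()
print("bf16:", cpu.bf16)                      # AVX512_BF16 support
print("vnni:", cpu.vnni)                      # AVX512_VNNI support
print("cores/socket:", cpu.cores_per_socket)
reduce_range = not cpu.vnni  # observers widen reduce_range on non-VNNI CPUs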
