Smoothquant refactor for 3.x API #1792

Merged

36 commits merged into master from zixuan/sq_refactor on May 20, 2024
Commits
e2b5ced
Smoothquant refactor for 3.x API
violetch24 May 13, 2024
aa15c3f
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 13, 2024
5ffa853
modify smoothquant ut
violetch24 May 15, 2024
3b09ba1
Update utility.py
violetch24 May 15, 2024
fe9a810
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 15, 2024
4824301
modify sq example
violetch24 May 15, 2024
19fbf86
minor fix
violetch24 May 16, 2024
968b12e
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 16, 2024
5c5ccd1
minor fix
violetch24 May 16, 2024
987ba49
modify ut
violetch24 May 16, 2024
5e8b7ab
update requirements
violetch24 May 16, 2024
b04a9a0
update requirements
violetch24 May 17, 2024
1ef4415
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 17, 2024
55dfaed
modify ut for coverage
violetch24 May 17, 2024
9b8628b
minor fix
violetch24 May 17, 2024
276b029
Update smooth_quant.py
violetch24 May 17, 2024
605bcb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 17, 2024
f08aaae
minor fix
violetch24 May 17, 2024
5a40e16
minor fix
violetch24 May 17, 2024
f0f5b58
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 17, 2024
18b2990
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 17, 2024
a98e202
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 17, 2024
cd805d0
code fix
violetch24 May 17, 2024
ca311ca
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 17, 2024
cf9096e
Update requirements.txt
violetch24 May 17, 2024
e365ce7
Update requirements.txt
violetch24 May 17, 2024
ce214bd
Update run_clm_no_trainer.py
violetch24 May 17, 2024
b79e74c
modify ut
violetch24 May 17, 2024
a67daff
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 19, 2024
757ce29
ut coverage
violetch24 May 19, 2024
e058667
minor fix
violetch24 May 19, 2024
718f163
minor fix
violetch24 May 19, 2024
28fbb09
remove overrides
violetch24 May 20, 2024
2e93927
ut for 2.x and 3.x API
violetch24 May 20, 2024
1c0f4aa
Update test_smooth_quant.py
violetch24 May 20, 2024
531e8d9
Update test_smooth_quant.py
violetch24 May 20, 2024
@@ -361,20 +361,12 @@ def run_fn(model):

from utils import get_example_inputs
example_inputs = get_example_inputs(user_model, calib_dataloader)
if args.sq:
# currently, smooth quant only support quantize API
# TODO: support prepare/convert API for smooth quant
from neural_compressor.torch.quantization import quantize

user_model = quantize(
model=user_model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn
)
else:
from neural_compressor.torch.quantization import prepare, convert

user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(user_model)
user_model = convert(user_model)
from neural_compressor.torch.quantization import prepare, convert
user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(user_model)
user_model = convert(user_model)

user_model.save(args.output_dir)


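For readers tracking the API change: the hunk above removes the smooth-quant-only quantize() branch from the example script, so smooth quant now goes through the same prepare/convert flow as the other 3.x algorithms. A minimal end-to-end sketch of that flow follows; SmoothQuantConfig, the toy model, and the calibration function are assumptions added for illustration, while the prepare/convert/save calls mirror the diff.

import torch
from neural_compressor.torch.quantization import SmoothQuantConfig, convert, prepare  # SmoothQuantConfig is assumed

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)

def run_fn(model):
    # Calibration: a few representative forward passes.
    for _ in range(4):
        model(torch.randn(2, 8))

user_model = ToyModel()
example_inputs = torch.randn(2, 8)
quant_config = SmoothQuantConfig(alpha=0.5)  # assumed constructor; smooth quant also requires IPEX to be installed

user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(user_model)
user_model = convert(user_model)
user_model.save("saved_results")  # convert() binds save(); the output path is illustrative
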
14 changes: 12 additions & 2 deletions neural_compressor/torch/algorithms/base_algorithm.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Optional

import torch
@@ -111,5 +111,15 @@ def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any):
elif mode == Mode.CONVERT:
model = self.convert(model, *args, **kwargs)
elif mode == Mode.QUANTIZE:
model = self.quantize(model, *args, **kwargs)
if not isinstance(self.quant_config, dict):
user_cfg = copy.deepcopy(self.quant_config).to_dict()
else:
user_cfg = copy.deepcopy(self.quant_config)
if "recipe_cfgs" in user_cfg: # keep quantize API for smoothquant
run_fn = kwargs.get("run_fn", None)
example_inputs = kwargs.get("example_inputs", None)
inplace = kwargs.get("inplace", True)
model = self.quantize(model, self.quant_config, run_fn, example_inputs, inplace)
else:
model = self.quantize(model, *args, **kwargs)
return model
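The net effect of the new branch: when the resolved config carries recipe_cfgs (the smooth quant recipe), execute() keeps the one-shot quantize() entry point and forwards run_fn, example_inputs, and inplace explicitly; every other algorithm still reaches self.quantize(*args, **kwargs) unchanged. For reference, a config dict that takes the recipe branch has roughly the shape sketched below; the field names come from this PR's smooth_quant.py, while the concrete op key and values are placeholders.

# Illustrative shape of a quant_config that routes Mode.QUANTIZE through the recipe branch above.
user_cfg = {
    "recipe_cfgs": {
        "smooth_quant": True,
        "smooth_quant_args": {
            "alpha": 0.5,  # or "auto"
            "folding": False,
            "auto_alpha_args": {},
            "scale_sharing": False,
        },
    },
    "op": {
        ("lm_head", "Linear"): {  # (op_name, op_type) key is a placeholder
            "activation": {"algorithm": "minmax"},
        },
    },
}
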
@@ -14,5 +14,5 @@
# limitations under the License.

from .utility import *
from .smooth_quant import smooth_quantize
from .smooth_quant import SmoothQuantQuantizer
from .save_load import save, load, recover_model_from_json
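The package export changes to match: smooth_quantize is replaced by the SmoothQuantQuantizer class, so downstream imports shift as sketched below (the 2.x-style call is shown only for contrast; the class is normally driven through the framework's prepare/convert/quantize entry points rather than instantiated directly).

# Before this PR:
#   from neural_compressor.torch.algorithms.smooth_quant import smooth_quantize
#   model = smooth_quantize(model, tune_cfg, run_fn, example_inputs)
# After this PR:
from neural_compressor.torch.algorithms.smooth_quant import SmoothQuantQuantizer

quantizer = SmoothQuantQuantizer(quant_config=tune_cfg)  # tune_cfg: the op/recipe dict sketched earlier
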
264 changes: 190 additions & 74 deletions neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py
@@ -21,11 +21,16 @@

try:
import intel_extension_for_pytorch as ipex
except:
except: # pragma: no cover
assert False, "Please install IPEX for smooth quantization."

from collections import OrderedDict
from types import MethodType

from packaging.version import Version

from neural_compressor.torch.algorithms import Quantizer

from .utility import (
TorchSmoothQuant,
cfg_to_qconfig,
@@ -41,88 +46,199 @@
ipex_ver = get_ipex_version()


def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
"""Execute the quantize process on the specified model.
class SmoothQuantQuantizer(Quantizer):
def __init__(self, quant_config: OrderedDict = {}):
"""Init a SmoothQuantQuantizer object.

Args:
model: a float model to be quantized.
tune_cfg: quantization config for ops.
run_fn: a calibration function for calibrating the model.
example_inputs: used to trace torch model.
inplace: whether to carry out model transformations in-place.
Args:
quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
"""
super().__init__(quant_config)

Returns:
A quantized model.
"""
assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant."
def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
"""Prepares a given model for quantization.

Args:
model: A float model to be quantized.
example_inputs: Used to trace torch model.
inplace: Whether to carry out model transformations in-place. Defaults to True.

Returns:
A prepared model.
"""
assert example_inputs is not None, "Please provide example_inputs for smooth quantization."
assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant."

# Note: we should make sure smoothquant is only executed once with inplacing fp32 model.
if hasattr(model, "_smoothquant_optimized") and model._smoothquant_optimized: # pragma: no cover
logger.info("The model is already optimized by SmoothQuant algorithm, skip it.")
return model

cfgs, op_infos_from_cfgs, output_tensor_id_op_name = (
model.cfgs,
model.op_infos_from_cfgs,
model.output_tensor_id_op_name,
)

# Update json file in ipex_config_path
cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
model.eval()

# check smoothquant alpha and act_algo value
recipe_cfgs = self.quant_config.get("recipe_cfgs", None)
alpha = recipe_cfgs["smooth_quant_args"]["alpha"]
for op, _ in self.quant_config["op"].items():
act_algo = self.quant_config["op"][op]["activation"]["algorithm"]

_, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(model, example_inputs)
# Check save_qconf_summary part is a workaround for IPEX bug.
# Sometimes the prepared model from get_op_capability loses this attribute.
if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
from torch.ao.quantization.observer import MinMaxObserver

# check smoothquant folding value
recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
if "smooth_quant_args" in recipe_cfgs and "folding" in recipe_cfgs["smooth_quant_args"]:
if recipe_cfgs["smooth_quant_args"]["folding"] is None:
if ipex_ver.release < Version("2.1").release:
folding = True
if ipex_ver.release >= Version("2.1.1").release:
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
alpha=alpha, act_observer=MinMaxObserver
)
else: # pragma: no cover
if act_algo == "minmax":
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
alpha=alpha, act_observer=MinMaxObserver()
)
logger.warning(
"The int8 model accuracy will be close to 0 with MinMaxobserver, "
+ "the suggested IPEX version is higher or equal than 2.1.100+cpu."
)
else:
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=alpha)

if isinstance(example_inputs, dict):
model = ipex.quantization.prepare(
model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
)
else:
folding = False
else:
folding = recipe_cfgs["smooth_quant_args"]["folding"]
model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)

# Note: we should make sure smoothquant is only executed once with inplacing fp32 model.
if hasattr(model, "_smoothquant_optimized") and model._smoothquant_optimized:
logger.info("The model is already optimized by SmoothQuant algorithm, skip it.")
cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, smooth_quant=True)
model.load_qconf_summary(qconf_summary=ipex_config_path)
return model

sq = TorchSmoothQuant(model, dataloader=None, example_inputs=example_inputs, q_func=run_fn, record_max_info=True)
model = sq.transform(
alpha=recipe_cfgs["smooth_quant_args"]["alpha"],
folding=folding,
auto_alpha_args=recipe_cfgs["smooth_quant_args"]["auto_alpha_args"],
scale_sharing=recipe_cfgs["smooth_quant_args"]["scale_sharing"],
)

# Update model parameter when smoothquant folding = False
if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and not folding:
return qdq_quantize(
model, tune_cfg, run_fn, example_inputs, inplace, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, sq
)
def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
"""Converts a prepared model to a quantized model.

# Update model parameter when smoothquant folding = True
if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and folding:
_apply_pre_optimization(model, tune_cfg, sq)
model.eval()
Args:
model: The prepared model to be converted.
example_inputs: Used to trace torch model.
inplace: Whether to carry out model transformations in-place. Defaults to True.

# Check save_qconf_summary part is a workaround for IPEX bug.
# Sometimes the prepared model from get_op_capability loses this attribute.
if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
static_qconfig = ipex.quantization.default_static_qconfig_mapping
if isinstance(example_inputs, dict):
model = ipex.quantization.prepare(
model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
Returns:
A quantized model.
"""
model.save_qconf_summary(qconf_summary=ipex_config_path)
model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)

with open(ipex_config_path, "r") as f:
model.tune_cfg = json.load(f)
model.ipex_config_path = ipex_config_path
dump_model_op_stats(self.quant_config["op"])

from neural_compressor.torch.algorithms.smooth_quant import save

logger.info("Smooth quantization done.")
model.ori_save = model.save
model.save = MethodType(save, model)
return model

def quantize(self, model, tune_cfg, run_fn, example_inputs, inplace=True, *args, **kwargs):
"""Execute the quantize process on the specified model.

Args:
model: a float model to be quantized.
tune_cfg: quantization config for ops.
run_fn: a calibration function for calibrating the model.
example_inputs: used to trace torch model.
inplace: whether to carry out model transformations in-place.

Returns:
A quantized model.
"""
assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant."

cfgs, op_infos_from_cfgs, output_tensor_id_op_name = (
model.cfgs,
model.op_infos_from_cfgs,
model.output_tensor_id_op_name,
)

# check smoothquant folding value
recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
if "smooth_quant_args" in recipe_cfgs and "folding" in recipe_cfgs["smooth_quant_args"]:
if recipe_cfgs["smooth_quant_args"]["folding"] is None: # pragma: no cover
if ipex_ver.release < Version("2.1").release:
folding = True
else:
folding = False
else:
folding = recipe_cfgs["smooth_quant_args"]["folding"]

# Note: we should make sure smoothquant is only executed once with inplacing fp32 model.
if hasattr(model, "_smoothquant_optimized") and model._smoothquant_optimized: # pragma: no cover
logger.info("The model is already optimized by SmoothQuant algorithm, skip it.")
return model

sq_info = model.sq_info

# Update model parameter when smoothquant folding = False
if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and not folding:
return qdq_quantize(
model,
tune_cfg,
run_fn,
example_inputs,
inplace,
cfgs,
op_infos_from_cfgs,
output_tensor_id_op_name,
sq_info,
)
else:
model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)

model.load_qconf_summary(qconf_summary=ipex_config_path)
run_fn(model)
model.save_qconf_summary(qconf_summary=ipex_config_path)
model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)
# Update model parameter when smoothquant folding = True
if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and folding:
_apply_pre_optimization(model, tune_cfg, sq_info)

# Recover model parameter when smoothquant folding = True
if (
recipe_cfgs
and recipe_cfgs.get("smooth_quant", False)
and recipe_cfgs["smooth_quant_args"]["folding"]
and not inplace
): # pragma: no cover
_apply_pre_optimization(model, tune_cfg, sq, recover=True)
# Update json file in ipex_config_path
cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
model.eval()

with open(ipex_config_path, "r") as f:
model.tune_cfg = json.load(f)
model.ipex_config_path = ipex_config_path
dump_model_op_stats(tune_cfg["op"])
return model
# Check save_qconf_summary part is a workaround for IPEX bug.
# Sometimes the prepared model from get_op_capability loses this attribute.
if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): # pragma: no cover
static_qconfig = ipex.quantization.default_static_qconfig_mapping
if isinstance(example_inputs, dict):
model = ipex.quantization.prepare(
model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
)
else:
model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)

model.load_qconf_summary(qconf_summary=ipex_config_path)
run_fn(model)
model.save_qconf_summary(qconf_summary=ipex_config_path)
model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)

# Recover model parameter when smoothquant folding = True
if (
recipe_cfgs
and recipe_cfgs.get("smooth_quant", False)
and recipe_cfgs["smooth_quant_args"]["folding"]
and not inplace
): # pragma: no cover
_apply_pre_optimization(model, tune_cfg, sq_info, recover=True)

with open(ipex_config_path, "r") as f:
model.tune_cfg = json.load(f)
model.ipex_config_path = ipex_config_path
dump_model_op_stats(tune_cfg["op"])
return model


def qdq_quantize(
@@ -133,12 +249,12 @@ def qdq_quantize(

# Check save_qconf_summary part is a workaround for IPEX bug.
# Sometimes the prepared model from get_op_capability loses this attribute.
if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): # pragma: no cover
from torch.ao.quantization.observer import MinMaxObserver

if ipex_ver.release >= Version("2.1.1").release:
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver)
else:
else: # pragma: no cover
if sq_minmax_init:
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
alpha=0.5, act_observer=MinMaxObserver()
@@ -169,7 +285,7 @@ def qdq_quantize(
# IPEX may raise an error on the second iteration.
# OverflowError: cannot convert float infinity to integer
run_fn(model)
except:
except: # pragma: no cover
logger.warning(
"The calibration failed when calibrating with ipex, "
+ "using scale info from SmoothQuant for Linear and "
@@ -197,7 +313,7 @@ def _apply_pre_optimization(model, tune_cfg, sq, recover=False):
tsq = TorchSmoothQuant(model, None)
alpha = tune_cfg["recipe_cfgs"]["smooth_quant_args"]["alpha"]
for op_name, info in sq_max_info.items():
if alpha == "auto":
if alpha == "auto": # pragma: no cover
alpha = info["alpha"]
absorb_layer = op_name
absorbed_layer = info["absorbed_layer"]
@@ -237,7 +353,7 @@ def _ipex_post_quant_process(model, example_inputs, inplace=False):
else:
model = torch.jit.trace(model, example_inputs)
model = torch.jit.freeze(model.eval())
except:
except: # pragma: no cover
if isinstance(example_inputs, dict):
model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
else:
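Taken together, smooth_quant.py now exposes the algorithm through the Quantizer lifecycle instead of a single smooth_quantize() call. A rough sketch of that lifecycle follows; it is illustrative only, and it elides the folding, qdq_quantize, and IPEX-workaround branches visible in the diff.

# Rough, illustrative lifecycle of the refactored SmoothQuantQuantizer.
# The surrounding framework populates attributes such as model.cfgs and
# model.op_infos_from_cfgs before prepare() is reached.
quantizer = SmoothQuantQuantizer(quant_config)

# prepare(): writes the IPEX qconfig JSON from quant_config and calls
# ipex.quantization.prepare(...) with a smooth-quant qconfig mapping.
model = quantizer.prepare(model, example_inputs=example_inputs)

# Calibration pass feeds the observers inserted by IPEX.
run_fn(model)

# convert(): saves the qconf summary, traces/freezes with torch.jit,
# dumps op statistics, and binds a save() method onto the model.
model = quantizer.convert(model, example_inputs=example_inputs)
model.save("saved_results")  # output path is illustrative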