Smoothquant refactor for 3.x API #1792

Merged
merged 36 commits on May 20, 2024
Changes from 31 commits

Commits (36)
e2b5ced  Smoothquant refactor for 3.x API  (violetch24, May 13, 2024)
aa15c3f  Merge branch 'master' into zixuan/sq_refactor  (violetch24, May 13, 2024)
5ffa853  modify smoothquant ut  (violetch24, May 15, 2024)
3b09ba1  Update utility.py  (violetch24, May 15, 2024)
fe9a810  Merge branch 'master' into zixuan/sq_refactor  (violetch24, May 15, 2024)
4824301  modify sq example  (violetch24, May 15, 2024)
19fbf86  minor fix  (violetch24, May 16, 2024)
968b12e  Merge branch 'master' into zixuan/sq_refactor  (violetch24, May 16, 2024)
5c5ccd1  minor fix  (violetch24, May 16, 2024)
987ba49  modify ut  (violetch24, May 16, 2024)
5e8b7ab  update requirements  (violetch24, May 16, 2024)
b04a9a0  update requirements  (violetch24, May 17, 2024)
1ef4415  Merge branch 'master' into zixuan/sq_refactor  (violetch24, May 17, 2024)
55dfaed  modify ut for coverage  (violetch24, May 17, 2024)
9b8628b  minor fix  (violetch24, May 17, 2024)
276b029  Update smooth_quant.py  (violetch24, May 17, 2024)
605bcb0  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], May 17, 2024)
f08aaae  minor fix  (violetch24, May 17, 2024)
5a40e16  minor fix  (violetch24, May 17, 2024)
f0f5b58  Merge branch 'master' into zixuan/sq_refactor  (violetch24, May 17, 2024)
18b2990  Merge branch 'master' into zixuan/sq_refactor  (violetch24, May 17, 2024)
a98e202  Merge branch 'master' into zixuan/sq_refactor  (violetch24, May 17, 2024)
cd805d0  code fix  (violetch24, May 17, 2024)
ca311ca  Merge branch 'master' into zixuan/sq_refactor  (violetch24, May 17, 2024)
cf9096e  Update requirements.txt  (violetch24, May 17, 2024)
e365ce7  Update requirements.txt  (violetch24, May 17, 2024)
ce214bd  Update run_clm_no_trainer.py  (violetch24, May 17, 2024)
b79e74c  modify ut  (violetch24, May 17, 2024)
a67daff  Merge branch 'master' into zixuan/sq_refactor  (violetch24, May 19, 2024)
757ce29  ut coverage  (violetch24, May 19, 2024)
e058667  minor fix  (violetch24, May 19, 2024)
718f163  minor fix  (violetch24, May 19, 2024)
28fbb09  remove overrides  (violetch24, May 20, 2024)
2e93927  ut for 2.x and 3.x API  (violetch24, May 20, 2024)
1c0f4aa  Update test_smooth_quant.py  (violetch24, May 20, 2024)
531e8d9  Update test_smooth_quant.py  (violetch24, May 20, 2024)
@@ -361,20 +361,12 @@ def run_fn(model):

from utils import get_example_inputs
example_inputs = get_example_inputs(user_model, calib_dataloader)
if args.sq:
# currently, smooth quant only support quantize API
# TODO: support prepare/convert API for smooth quant
from neural_compressor.torch.quantization import quantize

user_model = quantize(
model=user_model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn
)
else:
from neural_compressor.torch.quantization import prepare, convert

user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(user_model)
user_model = convert(user_model)
from neural_compressor.torch.quantization import prepare, convert
user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(user_model)
user_model = convert(user_model)

user_model.save(args.output_dir)


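For context, the calibration flow the example now follows is sketched below. This is a minimal, self-contained illustration rather than the example script itself: the toy model, random calibration data, and the alpha value are placeholders, the SmoothQuantConfig keyword is inferred from attributes referenced later in this PR, and intel_extension_for_pytorch must be installed since the SmoothQuant path runs through IPEX.

import torch
from neural_compressor.torch.quantization import SmoothQuantConfig, convert, prepare

# toy stand-ins for the real model and calibration dataloader
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 16))
example_inputs = torch.randn(1, 64)
quant_config = SmoothQuantConfig(alpha=0.5)  # alpha chosen for illustration only

def run_fn(m):
    # calibration: a few forward passes so the observers can record activation ranges
    for _ in range(8):
        m(torch.randn(4, 64))

model = prepare(model=model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(model)            # calibration happens between prepare and convert
model = convert(model)   # applies the smoothing scales and emits the quantized model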
14 changes: 12 additions & 2 deletions neural_compressor/torch/algorithms/base_algorithm.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Optional

import torch
@@ -111,5 +111,15 @@ def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any):
elif mode == Mode.CONVERT:
model = self.convert(model, *args, **kwargs)
elif mode == Mode.QUANTIZE:
model = self.quantize(model, *args, **kwargs)
if not isinstance(self.quant_config, dict):
user_cfg = copy.deepcopy(self.quant_config).to_dict()
else:
user_cfg = copy.deepcopy(self.quant_config)
if "recipe_cfgs" in user_cfg: # keep quantize API for smoothquant
run_fn = kwargs.get("run_fn", None)
example_inputs = kwargs.get("example_inputs", None)
inplace = kwargs.get("inplace", True)
model = self.quantize(model, self.quant_config, run_fn, example_inputs, inplace)
else:
model = self.quantize(model, *args, **kwargs)
return model
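Stated outside the class, the new dispatch rule in the Mode.QUANTIZE branch amounts to the small predicate below. This is an illustrative restatement for readers, not library code.

def _uses_sq_quantize_path(quant_config) -> bool:
    # A config carrying "recipe_cfgs" (the SmoothQuant-style mapping) keeps the dedicated
    # quantize() entry; any other config falls through to the generic quantize(*args, **kwargs).
    cfg = quant_config if isinstance(quant_config, dict) else quant_config.to_dict()
    return "recipe_cfgs" in cfg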
@@ -14,5 +14,5 @@
# limitations under the License.

from .utility import *
from .smooth_quant import smooth_quantize
from .smooth_quant import SmoothQuantQuantizer
from .save_load import save, load, recover_model_from_json
268 changes: 193 additions & 75 deletions neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py

Large diffs are not rendered by default.
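Because the smooth_quant.py diff is collapsed, the skeleton below shows roughly the shape a Quantizer subclass takes in the 3.x API. The class body is a placeholder to orient the reader, not the actual SmoothQuantQuantizer implementation, and the base-class import path is an assumption.

from neural_compressor.torch.algorithms.base_algorithm import Quantizer  # assumed import path

class SmoothQuantQuantizerSketch(Quantizer):  # illustrative skeleton only
    def prepare(self, model, *args, **kwargs):
        # insert observers and record smoothing max info via IPEX (placeholder)
        return model

    def convert(self, model, *args, **kwargs):
        # fold smoothing scales and convert to the quantized model (placeholder)
        return model

    def quantize(self, model, quant_config, run_fn, example_inputs, inplace=True):
        # one-shot path kept for recipe_cfgs-style configs (see the execute() dispatch above)
        model = self.prepare(model, example_inputs=example_inputs, inplace=inplace)
        run_fn(model)
        return self.convert(model)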

132 changes: 125 additions & 7 deletions neural_compressor/torch/algorithms/smooth_quant/utility.py
@@ -28,16 +28,136 @@
TransformerBasedModelBlockPatternDetector,
dump_model_op_stats,
generate_activation_observer,
get_quantizable_ops_recursively,
get_quantizable_ops_from_cfgs,
ipex_config_path,
parse_cfgs,
simple_inference,
unify_op_type_mapping_ipex,
)
from neural_compressor.torch.utils import get_ipex_version, get_torch_version, logger

version = get_torch_version()
ipex_ver = get_ipex_version()


def get_quantizable_ops_recursively(model, alpha, act_algo, example_inputs, inplace=True): # pragma: no cover
"""Get all quantizable ops from model.

Args:
model (object): input model
example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model.
alpha (float|str): smoothquant alpha.
act_algo (str): activation algorithm, minmax or kl.

Returns:
quantizable_ops (list): list of tuples of op_name and op_type.
cfgs (dict): dict of configuration
"""
quantizable_ops = []
# group ops by position for transform-based model
detector = TransformerBasedModelBlockPatternDetector(model)
detect_result = detector.detect_block()
attention_block = detect_result.get("attention_blocks", None)
ffn_blocks = detect_result.get("ffn_blocks", None)
logger.info(f"Attention Blocks: {len(attention_block)}")
logger.info(f"FFN Blocks: {len(ffn_blocks)}")
if not os.path.exists(ipex_config_path):
assert isinstance(model, torch.nn.Module), "The model passed in is not the instance of torch.nn.Module"

if hasattr(model, "save_qconf_summary"):
os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True)
model.save_qconf_summary(qconf_summary=ipex_config_path)
else: # pragma: no cover
model.eval()

# create a quantization config file for intel pytorch extension model
os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True)
assert example_inputs is not None, "IPEX need q_dataloader or example_inputs to prepare the model"

from torch.ao.quantization import MinMaxObserver

if ipex_ver.release >= Version("2.1.1").release:
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
alpha=alpha, act_observer=MinMaxObserver
)
else: # pragma: no cover
if act_algo == "minmax":
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
alpha=alpha, act_observer=MinMaxObserver()
)
logger.warning(
"The int8 model accuracy will be close to 0 with MinMaxobserver, "
+ "the suggested IPEX version is higher or equal than 2.1.100+cpu."
)
else:
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=alpha)

if isinstance(example_inputs, dict):
model = ipex.quantization.prepare(
model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
)
else:
model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)

simple_inference(model, example_inputs, iterations=1)
model.save_qconf_summary(qconf_summary=ipex_config_path)

map_op_name_to_fqn = {}
with open(ipex_config_path, "r") as f:
cfgs = json.load(f)
(
ops_name,
op_infos_from_cfgs,
input_tensor_id_op_name,
output_tensor_id_op_name,
) = parse_cfgs(cfgs)
quantizable_op_names = get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_id_op_name)
for name in quantizable_op_names:
# name : list
if len(name) == 1:
module_key = name[0][0]
op_cfg_id = name[0][2]
ipex_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"]
module_fqn = cfgs[module_key]["q_op_infos"][op_cfg_id].get("fqn", None)

if ipex_op_type in unify_op_type_mapping_ipex:
quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type]))
map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
else:
re_flag = False
for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
if re.match(pattern, ipex_op_type):
re_flag = True
quantizable_ops.append((tuple(name), unify_op_type))
map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn
break
if not re_flag:
quantizable_ops.append((tuple(name), ipex_op_type))
map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
else: # pragma: no cover
op_type = ""
for op_name in name:
module_key = op_name[0]
op_cfg_id = op_name[2]
single_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"]
if single_op_type in unify_op_type_mapping_ipex:
single_op_type = unify_op_type_mapping_ipex[single_op_type]
op_type += "&" + single_op_type if op_type else single_op_type
quantizable_ops.append((tuple(name), op_type))
_module_key = name[0][0]
_op_cfg_id = name[0][2]
module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"]
map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn

logger.debug("Map op name to fqn: ")
logger.debug(map_op_name_to_fqn)
logger.info("Attention Blocks : ")
logger.info(attention_block)
logger.info("FFN Blocks : ")
logger.info(ffn_blocks)
return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name


def check_cfg_and_qconfig(
tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name, smooth_quant=False
): # pragma: no cover
@@ -539,8 +659,6 @@ def calibrate(self, calib_iter, op_types=[torch.nn.Conv2d, torch.nn.Linear]): #


class GraphTrace: # pragma: no cover
""""""

def __init__(self):
self.supported_torch_module_to_aten = {
"Linear": "aten::linear",
@@ -729,7 +847,7 @@ def remove_unsupported_layers(self, model, absorb_to_layer, no_absorb_layers):


@register_autotune("version1")
class AutoAlpha:
class AutoAlpha: # pragma: no cover
def __init__(
self,
model,
@@ -1354,7 +1472,7 @@ def _auto_tune_alpha_blockwise(self):
return best_alphas


class TorchSmoothQuant:
class TorchSmoothQuant: # pragma: no cover
"""Fake input channel quantization, for more details please refer to
[1] SmoothQuant: Accurate and Efficient
Post-Training Quantization for Large Language Models
@@ -1929,7 +2047,7 @@ def _trace(self, op_types, skip_unsupported_layers=True):
return absorb_to_layer, no_absorb_layers


class SQLinearWrapper(torch.nn.Module):
class SQLinearWrapper(torch.nn.Module): # pragma: no cover
def __init__(self, module, input_scale, input_minmax, alpha=0.5, dtype=torch.quint8):
super().__init__()
self.register_buffer("input_scale", input_scale)
@@ -1990,7 +2108,7 @@ def _recover_sq_linear(self):
self.sq_linear.weight *= scale


class WrapperLayer(torch.nn.Module):
class WrapperLayer(torch.nn.Module): # pragma: no cover
def __init__(self, layer, input_min, input_max, save_q_input=False):
super(WrapperLayer, self).__init__()
self.add_module("orig_layer", layer) # set orig_layer in get/set_module
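The relocated get_quantizable_ops_recursively now returns the parsed IPEX configuration alongside the op list, so callers can attach it to the model (as SmoothQuantConfig.get_model_info does later in this PR). A direct call looks roughly like the sketch below; fp32_model and example_inputs are placeholders (for instance the toy model and tensor from the first sketch), and the alpha/act_algo values are illustrative.

from neural_compressor.torch.algorithms.smooth_quant import get_quantizable_ops_recursively

quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(
    fp32_model, alpha=0.5, act_algo="minmax", example_inputs=example_inputs, inplace=True
)
for op_name, op_type in quantizable_ops:
    print(op_name, op_type)  # a module/op identifier paired with a unified op type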
4 changes: 2 additions & 2 deletions neural_compressor/torch/algorithms/static_quant/utility.py
@@ -277,7 +277,7 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover
op_infos_from_cfgs,
input_tensor_id_op_name,
output_tensor_id_op_name,
) = paser_cfgs(cfgs)
) = parse_cfgs(cfgs)
quantizable_op_names = get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_id_op_name)
for name in quantizable_op_names:
# name : list
@@ -426,7 +426,7 @@ def get_element_under_depth(d, ops_lst):
ops_lst.append(d)


def paser_cfgs(cfgs): # pragma: no cover
def parse_cfgs(cfgs): # pragma: no cover
"""Parse configs.

Args:
26 changes: 13 additions & 13 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -235,10 +235,14 @@ def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode,
@register_algo(name=SMOOTH_QUANT)
@torch.no_grad()
def smooth_quant_entry(
model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], SmoothQuantConfig], *args, **kwargs
model: torch.nn.Module,
configs_mapping: Dict[Tuple[str, callable], SmoothQuantConfig],
mode: Mode = Mode.QUANTIZE,
*args,
**kwargs,
) -> torch.nn.Module:
logger.info("Quantize model with the smooth quant algorithm.")
from neural_compressor.torch.algorithms.smooth_quant import save, smooth_quantize
from neural_compressor.torch.algorithms.smooth_quant import SmoothQuantQuantizer, TorchSmoothQuant

# convert the user config into internal format
quant_config_mapping = {}
@@ -277,17 +281,13 @@ def smooth_quant_entry(
example_inputs = kwargs.get("example_inputs", None)
inplace = kwargs.get("inplace", True)
assert example_inputs is not None, "Please provide example_inputs for smooth quantization."
q_model = smooth_quantize(
model=model,
tune_cfg=quant_config_mapping,
run_fn=run_fn,
example_inputs=example_inputs,
inplace=inplace,
)
logger.info("Smooth quantization done.")
q_model.ori_save = q_model.save
q_model.save = MethodType(save, q_model)
return q_model
model.sq_info = TorchSmoothQuant(model, example_inputs=example_inputs, q_func=run_fn, record_max_info=True)

quantizer = get_quantizer(model, quantizer_cls=SmoothQuantQuantizer, quant_config=quant_config_mapping)
model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace)
postprocess_model(model, mode, quantizer)

return model


###################### AWQ Algo Entry ##################################
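At the quantizer level, the reworked entry supports the staged flow as well as the one-shot path. A hedged sketch of the staged calls follows; the Mode import path and the quant_config keyword are assumptions, and quant_config_mapping, run_fn, and example_inputs are the objects built in the entry above.

from neural_compressor.torch.algorithms.smooth_quant import SmoothQuantQuantizer
from neural_compressor.torch.utils import Mode  # assumed import path for the Mode enum

quantizer = SmoothQuantQuantizer(quant_config=quant_config_mapping)
model = quantizer.execute(model, mode=Mode.PREPARE, example_inputs=example_inputs, inplace=True)
run_fn(model)  # user calibration between the prepare and convert stages
model = quantizer.execute(model, mode=Mode.CONVERT, example_inputs=example_inputs, inplace=True)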
13 changes: 11 additions & 2 deletions neural_compressor/torch/quantization/config.py
@@ -1013,10 +1013,19 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
cls.supported_configs = supported_configs

@staticmethod
def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
def get_model_info(
model: torch.nn.Module, alpha, act_algo, example_inputs, inplace=True
) -> List[Tuple[str, Callable]]:
from neural_compressor.torch.algorithms.smooth_quant import get_quantizable_ops_recursively

model_info, _, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
model_info, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(
model, alpha, act_algo, example_inputs=example_inputs, inplace=inplace
)
model.cfgs, model.op_infos_from_cfgs, model.output_tensor_id_op_name = (
cfgs,
op_infos_from_cfgs,
output_tensor_id_op_name,
)
return model_info

@classmethod
32 changes: 29 additions & 3 deletions neural_compressor/torch/quantization/quantize.py
@@ -67,7 +67,23 @@ def quantize(
if is_ipex_available and (
isinstance(quant_config, StaticQuantConfig) or isinstance(quant_config, SmoothQuantConfig)
):
model_info = quant_config.get_model_info(q_model, example_inputs)
if isinstance(quant_config, SmoothQuantConfig):
from neural_compressor.torch.algorithms.smooth_quant import TorchSmoothQuant

sq = TorchSmoothQuant(
model, dataloader=None, example_inputs=example_inputs, q_func=run_fn, record_max_info=True
)
model = sq.transform(
alpha=quant_config.alpha,
folding=quant_config.folding,
auto_alpha_args=quant_config.auto_alpha_args,
scale_sharing=quant_config.scale_sharing,
)
model_info = quant_config.get_model_info(
q_model, quant_config.alpha, quant_config.act_algo, example_inputs, inplace=True
)
else:
model_info = quant_config.get_model_info(q_model, example_inputs)
else:
model_info = quant_config.get_model_info(model=q_model)
configs_mapping = quant_config.to_config_mapping(model_info=model_info)
@@ -122,7 +138,12 @@ def prepare(
if is_ipex_available and (
isinstance(quant_config, StaticQuantConfig) or isinstance(quant_config, SmoothQuantConfig)
):
model_info = quant_config.get_model_info(prepared_model, example_inputs)
if isinstance(quant_config, SmoothQuantConfig):
model_info = quant_config.get_model_info(
prepared_model, quant_config.alpha, quant_config.act_algo, example_inputs, inplace=True
)
else:
model_info = quant_config.get_model_info(prepared_model, example_inputs)
else:
model_info = quant_config.get_model_info(model=prepared_model)
configs_mapping = quant_config.to_config_mapping(model_info=model_info)
@@ -185,7 +206,12 @@ def convert(
if is_ipex_available and (
isinstance(quant_config, StaticQuantConfig) or isinstance(quant_config, SmoothQuantConfig)
):
model_info = quant_config.get_model_info(q_model, example_inputs)
if isinstance(quant_config, SmoothQuantConfig):
model_info = quant_config.get_model_info(
q_model, quant_config.alpha, quant_config.act_algo, example_inputs, inplace=True
)
else:
model_info = quant_config.get_model_info(q_model, example_inputs)
else:
model_info = quant_config.get_model_info(model=q_model)
configs_mapping = quant_config.to_config_mapping(model_info=model_info)
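The one-shot quantize() path above now runs TorchSmoothQuant.transform to record smoothing info before building model_info. A short sketch of the config fields that path reads follows; the constructor keywords are assumed from the attribute names used above (alpha, folding, act_algo), and the values are illustrative.

from neural_compressor.torch.quantization import SmoothQuantConfig, quantize

quant_config = SmoothQuantConfig(
    alpha=0.5,          # smoothing strength; quant_config.alpha feeds sq.transform above
    folding=False,      # whether smoothing scales are folded into preceding layers
    act_algo="minmax",  # activation observer algorithm, "minmax" or "kl"
)
q_model = quantize(model=model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn)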
1 change: 1 addition & 0 deletions requirements_pt.txt
@@ -1,5 +1,6 @@
auto-round
intel_extension_for_pytorch
overrides
peft==0.10.0
psutil
py-cpuinfo