[SmoothQuant] make alpha search space a config argument (#1392)
Enhance SmoothQuant with tunable alpha search space

Signed-off-by: Lu, Yintong <[email protected]>
yintong-lu authored Nov 16, 2023
1 parent 163e32a commit f9663d0
Showing 7 changed files with 77 additions and 2 deletions.
6 changes: 6 additions & 0 deletions neural_compressor/adaptor/onnxrt.py
@@ -173,6 +173,8 @@ def smooth_quant(
scales_per_op=True,
record_max_info=False,
weight_clip=True,
auto_alpha_args={"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.1, "shared_criterion": "mean"},
default_alpha=0.5,
):
"""Get augmented model with smooth quant.
@@ -187,6 +189,10 @@ def smooth_quant(
scales_per_op (bool): True, each op will have an individual scale, mainly for accuracy
False, ops with the same input will share a scale, mainly for performance
record_max_info (bool): False, whether record the scale information
weight_clip: Whether to clip weight when calculating scales; by default it is on.
auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning.
By default the search space is 0.0-1.0 with a step size of 0.1.
default_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5.
Returns:
model: A modified onnx model
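For orientation, here is a minimal sketch (not code from this commit) of how a dictionary shaped like the default auto_alpha_args above could be expanded into a concrete list of candidate alpha values; the build_alpha_space helper is hypothetical.

import numpy as np

def build_alpha_space(auto_alpha_args):
    """Expand alpha_min/alpha_max/alpha_step into a list of candidate alphas."""
    lo, hi = auto_alpha_args["alpha_min"], auto_alpha_args["alpha_max"]
    step = auto_alpha_args["alpha_step"]
    # np.arange excludes the stop value, so pad it by half a step to keep alpha_max in the space.
    return [round(float(a), 2) for a in np.arange(lo, hi + step / 2, step)]

# With the defaults above this yields [0.0, 0.1, ..., 1.0].
print(build_alpha_space({"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.1, "shared_criterion": "mean"}))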
14 changes: 13 additions & 1 deletion neural_compressor/adaptor/pytorch.py
@@ -1762,6 +1762,8 @@ def smooth_quant(
force_re_smooth=False,
record_max_info=False,
weight_clip=True,
auto_alpha_args={"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.1, "shared_criterion": "mean"},
default_alpha=0.5,
):
"""Convert the model by smooth quant.
@@ -1776,6 +1778,10 @@ def smooth_quant(
scales_per_op: True, each op will have an individual scale, mainly for accuracy
False, ops with the same input will share a scale, mainly for performance
record_max_info: whether record the max info in model for alpha tuning.
weight_clip: Whether to clip weight when calculating scales; by default it is on.
auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning.
By default the search space is 0.0-1.0 with a step size of 0.1.
default_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5.
Returns:
model: A modified fp32 model, inplace=True.
@@ -1806,7 +1812,13 @@ def smooth_quant(
        if scales_per_op is not None:
            kwargs["scales_per_op"] = scales_per_op
        model._model = self.sq.transform(
-           alpha=alpha, folding=folding, calib_iter=calib_iter, weight_clip=weight_clip, **kwargs
+           alpha=alpha,
+           folding=folding,
+           calib_iter=calib_iter,
+           weight_clip=weight_clip,
+           default_alpha=default_alpha,
+           auto_alpha_args=auto_alpha_args,
+           **kwargs,
        )
        if self.sq.record_max_info:
            model.sq_max_info = self.sq.max_value_info
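To illustrate the call path above end to end, a hedged sketch of driving TorchSmoothQuant.transform directly with the new arguments; `model` and `calib_dataloader` are placeholders assumed to exist in the caller's scope.

from neural_compressor.adaptor.torch_utils.smooth_quant import TorchSmoothQuant

# `model` is any torch.nn.Module and `calib_dataloader` yields calibration batches (placeholders).
sq = TorchSmoothQuant(model, calib_dataloader)
smoothed = sq.transform(
    alpha="auto",          # request per-layer alpha auto-tuning
    folding=False,
    calib_iter=2,
    weight_clip=True,
    auto_alpha_args={"alpha_min": 0.4, "alpha_max": 0.8, "alpha_step": 0.1, "shared_criterion": "mean"},
    default_alpha=0.6,     # starting alpha used by the auto-tuning (see _auto_tune_alpha below)
)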
6 changes: 6 additions & 0 deletions neural_compressor/adaptor/tensorflow.py
@@ -1823,6 +1823,8 @@ def smooth_quant(
scales_per_op=True,
record_max_info=False,
weight_clip=True,
auto_alpha_args={"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.1, "shared_criterion": "mean"},
default_alpha=0.5,
):
"""Convert the model by smooth quant.
@@ -1838,6 +1840,10 @@ def smooth_quant(
scales_per_op: True, each op will have an individual scale, mainly for accuracy
False, ops with the same input will share a scale, mainly for performance
record_max_info: whether record the max info in model for alpha tuning.
weight_clip: Whether to clip weight when calculating scales; by default it is on.
auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning.
By default the search space is 0.0-1.0 with a step size of 0.1.
default_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5.
Returns:
model: A smoothed Tensorflow model
13 changes: 12 additions & 1 deletion neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -331,6 +331,7 @@ def __init__(self, model, dataloader=None, example_inputs=None, q_func=None, tra
self.absorb_to_layer = {}
self.adjust_alpha_space = False
self.weight_clip = True
self.default_alpha = 0.5

def _get_device(self):
"""Get the model device
@@ -851,6 +852,7 @@ def _auto_tune_alpha(
        default_alpha = alpha_space[len(alpha_space) // 2]
        if 0.5 in alpha_space:
            default_alpha = 0.5
        default_alpha = self.default_alpha
        absorb_input_scales, weight_scales = self._cal_scales(
            self.absorb_to_layer, input_maxes, default_alpha, tuning=True
        )
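With this hunk, the configured self.default_alpha overrides the midpoint-derived value as the baseline alpha for auto-tuning. Below is a toy, self-contained sketch of that selection loop; pick_per_layer_alpha and loss_fn are illustrative stand-ins, not the module's real API (the real logic lives in _auto_tune_alpha and scores layers against fp32 outputs).

def pick_per_layer_alpha(layers, alpha_space, default_alpha, loss_fn):
    """Start every layer at default_alpha, then keep whichever candidate alpha
    gives the lowest loss for that layer. loss_fn(layer, alpha) stands in for
    comparing quantized vs. fp32 layer outputs."""
    best = {layer: (default_alpha, loss_fn(layer, default_alpha)) for layer in layers}
    for alpha in alpha_space:
        for layer in layers:
            loss = loss_fn(layer, alpha)
            if loss < best[layer][1]:
                best[layer] = (alpha, loss)
    return {layer: alpha for layer, (alpha, _) in best.items()}

# Example: a fake loss that prefers alpha close to 0.7 for every layer.
print(pick_per_layer_alpha(["fc1", "fc2"], [0.5, 0.6, 0.7, 0.8, 0.9], 0.5, lambda layer, a: abs(a - 0.7)))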
@@ -950,6 +952,7 @@ def transform(
calib_iter=100,
auto_alpha_args={"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.1, "shared_criterion": "mean"},
weight_clip=True,
default_alpha=0.5,
):
"""The main entry of smooth quant
:param alpha: Alpha value to balance the quantization difficulty of activation and weight, please refer
@@ -959,8 +962,14 @@
:param op_types: The op typed to be smooth quantized
:param scales_per_op: Not supported now
:param calib_iter: Data size for calibration
:param weight_clip: Whether to clip weight_max when calculating scales.
:param auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning.
By default the search space is 0.0-1.0 with a step size of 0.1.
:param default_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5.
        :return: A FP32 model with the same architecture as the orig model but with different weight which will be
-           benefit to quantization."""
+           benefit to quantization.
+       """
        if not isinstance(self.model, torch.nn.Module):
            logger.warning("smooth quant is ignored since the model is not a torch module")
            return self.model
@@ -976,6 +985,8 @@
alpha = numpy.clip(alpha, 0.0, 1.0)

self.weight_clip = weight_clip
self.default_alpha = default_alpha
self.auto_alpha_args = auto_alpha_args
self.recover()
need_calibration = self._check_need_calibration(alpha, percentile, op_types, scales_per_op, calib_iter)
with torch.no_grad():
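For reference on what the alpha parameter trades off: per the docstring above, it balances quantization difficulty between activations and weights. A minimal NumPy sketch of the standard SmoothQuant per-channel scale formula follows, as an illustration rather than this file's exact implementation.

import numpy as np

def smooth_scales(act_max, weight_max, alpha=0.5, eps=1e-5):
    """Per-input-channel scales s_j = max|X_j|**alpha / max|W_j|**(1 - alpha).
    Activations are divided by s and the corresponding weights multiplied by s,
    so a larger alpha migrates more quantization difficulty onto the weights."""
    act_max = np.maximum(np.abs(act_max), eps)
    weight_max = np.maximum(np.abs(weight_max), eps)
    return act_max**alpha / weight_max ** (1.0 - alpha)

# Channels with large activation outliers get large scales, flattening the activations.
print(smooth_scales(np.array([10.0, 0.5, 3.0]), np.array([0.2, 0.8, 0.4]), alpha=0.5))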
4 changes: 4 additions & 0 deletions neural_compressor/algorithm/smooth_quant.py
@@ -53,6 +53,8 @@ def __init__(self, alpha=0.5):
self.scales_per_op = None
self.tune_cfg = None
self.weight_clip = None
self.auto_alpha_args = None
self.default_alpha = None

def __call__(self, origin_model, q_model, adaptor, dataloader, calib_iter):
"""Return the processed model via SmoothQuant algorithm.
@@ -82,6 +84,8 @@ def __call__(self, origin_model, q_model, adaptor, dataloader, calib_iter):
kwargs["folding"] = self.folding
kwargs["record_max_info"] = True
kwargs["weight_clip"] = self.weight_clip
kwargs["auto_alpha_args"] = self.auto_alpha_args
kwargs["default_alpha"] = self.default_alpha
q_model = adaptor.smooth_quant(
origin_model,
dataloader,
6 changes: 6 additions & 0 deletions neural_compressor/strategy/strategy.py
@@ -951,6 +951,12 @@ def set_param_for_pre_tuning_algos(self, algo_scheduler, config, fp32_model) ->
sq_algo.weight_clip = smooth_quant_args.get(
"weight_clip", True
) # make weight_clipping a default_on option.
sq_algo.auto_alpha_args = smooth_quant_args.get(
"auto_alpha_args", {"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.1, "shared_criterion": "mean"}
) # default alpha search space parameters.
sq_algo.default_alpha = smooth_quant_args.get(
"default_alpha", 0.5
) # default value for alpha in auto-tuning
logger.debug(f"Set smooth quant with alpha {sq_algo.alpha} as the pre-tuning algo.")
algo_scheduler.append_algorithm("pre_quantization", sq_algo)

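At the user level these values are read from smooth_quant_args, as the strategy.py hunk above shows. Here is a hedged end-to-end sketch using the 2.x PostTrainingQuantConfig recipes API; `model` and `calib_dataloader` are placeholders, and the keys mirror the defaults read above.

from neural_compressor import PostTrainingQuantConfig, quantization

conf = PostTrainingQuantConfig(
    recipes={
        "smooth_quant": True,
        "smooth_quant_args": {
            "alpha": "auto",
            "auto_alpha_args": {"alpha_min": 0.5, "alpha_max": 0.9, "alpha_step": 0.1, "shared_criterion": "mean"},
            "default_alpha": 0.7,
        },
    }
)
# `model` and `calib_dataloader` are assumed to exist in the caller's scope.
q_model = quantization.fit(model, conf, calib_dataloader=calib_dataloader)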
30 changes: 30 additions & 0 deletions test/algorithm/test_smooth_quant.py
@@ -1462,6 +1462,36 @@ def forward(self, x):
        sq.transform(alpha="auto", calib_iter=1, folding=True, weight_clip=False)
        assert sq.weight_clip is False

    @classmethod
    def test_sq_auto_alpha_arg(self):
        class Model(torch.nn.Module):
            device = torch.device("cpu")

            def __init__(self):
                super(Model, self).__init__()
                self.fc1 = torch.nn.Linear(3, 4)
                self.norm = LlamaRMSNorm(4)
                self.fc2 = torch.nn.Linear(4, 3)

            def forward(self, x):
                out = self.fc1(x)
                out = self.norm(out)
                out = self.fc2(out)
                return out

        model = Model()

        sq = TorchSmoothQuant(model, self.linear_dl)
        sq.transform(
            alpha="auto",
            calib_iter=1,
            folding=False,
            auto_alpha_args={"alpha_min": 0.5, "alpha_max": 0.9, "alpha_step": 0.1, "shared_criterion": "mean"},
            default_alpha=0.7,
        )
        assert sq.default_alpha == 0.7
        assert sq.auto_alpha_args["alpha_min"] == 0.5


if __name__ == "__main__":
    unittest.main()
