[SmoothQuant] make weight_clipping a default_on option (#1386)
Signed-off-by: Lu, Yintong <[email protected]>
yintong-lu authored Nov 16, 2023
1 parent 35086eb commit 1f4aec3
Showing 7 changed files with 61 additions and 4 deletions.
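In short: the per-channel weight maximum used in the PyTorch adaptor's scale recomputation was previously clamped unconditionally to at least 1e-5; this commit turns that clamp into a weight_clip option that defaults to on and threads it through the ONNX Runtime, PyTorch, and TensorFlow adaptors as well as the TorchSmoothQuant utility. A minimal sketch of the new gating (the helper name and tensors below are illustrative only, not code from the repository):

import torch

def weight_max_for_scales(weight_max: torch.Tensor, weight_clip: bool = True) -> torch.Tensor:
    # The clamp is now applied only when weight_clip is True, which remains the default.
    return weight_max.clamp(min=1e-5) if weight_clip else weight_max

w_max = torch.tensor([0.7, 0.0, 1e-8])                  # zero/tiny channels, e.g. untrained lora_B weights
print(weight_max_for_scales(w_max))                     # -> [0.7, 1e-05, 1e-05]
print(weight_max_for_scales(w_max, weight_clip=False))  # -> [0.7, 0.0, 1e-08]
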
1 change: 1 addition & 0 deletions neural_compressor/adaptor/onnxrt.py
@@ -172,6 +172,7 @@ def smooth_quant(
op_types=["MatMul", "Gemm", "Conv", "FusedConv"],
scales_per_op=True,
record_max_info=False,
weight_clip=True,
):
"""Get augmented model with smooth quant.
17 changes: 13 additions & 4 deletions neural_compressor/adaptor/pytorch.py
@@ -1761,6 +1761,7 @@ def smooth_quant(
scales_per_op=None,
force_re_smooth=False,
record_max_info=False,
weight_clip=True,
):
"""Convert the model by smooth quant.
@@ -1804,7 +1805,9 @@ def smooth_quant(
kwargs["percentile"] = percentile
if scales_per_op is not None:
kwargs["scales_per_op"] = scales_per_op
model._model = self.sq.transform(alpha=alpha, folding=folding, calib_iter=calib_iter, **kwargs)
model._model = self.sq.transform(
alpha=alpha, folding=folding, calib_iter=calib_iter, weight_clip=weight_clip, **kwargs
)
if self.sq.record_max_info:
model.sq_max_info = self.sq.max_value_info
return model
@@ -1833,7 +1836,9 @@ def _apply_pre_optimization(self, model, tune_cfg, recover=False):
absorb_layer = op_name
absorbed_layer = info["absorbed_layer"]
input_minmax = info["input_minmax"]
weight_max = info["weight_max"].clamp(min=1e-5)
weight_max = info["weight_max"]
if self.sq.weight_clip:
weight_max = weight_max.clamp(min=1e-5)
abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
input_power = torch.pow(abs_input_max, alpha)
weight_power = torch.pow(weight_max, 1 - alpha)
@@ -1877,7 +1882,9 @@ def qdq_quantize(self, model, tune_cfg):
alpha = info["alpha"]
absorbed_layer = info["absorbed_layer"]
input_minmax = info["input_minmax"]
weight_max = info["weight_max"].clamp(min=1e-5)
weight_max = info["weight_max"]
if self.sq.weight_clip:
weight_max = weight_max.clamp(min=1e-5)
abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
input_power = torch.pow(abs_input_max, alpha)
weight_power = torch.pow(weight_max, 1 - alpha)
@@ -3279,7 +3286,9 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):
absorbed_layer = info["absorbed_layer"]
input_minmax = info["input_minmax"]
# for peft model,lora_B weights is 0.
weight_max = info["weight_max"].clamp(min=1e-5)
weight_max = info["weight_max"]
if self.sq.weight_clip:
weight_max = weight_max.clamp(min=1e-5)
abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
input_power = torch.pow(abs_input_max, alpha)
weight_power = torch.pow(weight_max, 1 - alpha)
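
All three PyTorch-adaptor sites above rebuild scales from the recorded max info in the same way, which is why a zero lora_B weight needs the clamp. Below is a hedged, self-contained sketch of that recomputation using the same dict keys the diff reads (alpha, input_minmax, weight_max); the values are invented, and the final ratio follows the usual SmoothQuant form s = act_max^alpha / weight_max^(1-alpha), which the hunks themselves do not show.

import torch

info = {
    "alpha": 0.5,
    "input_minmax": [torch.tensor([-1.5, -0.2, -3.0]), torch.tensor([2.0, 0.1, 3.0])],
    "weight_max": torch.tensor([0.8, 0.0, 1.2]),   # channel 1 is zero, e.g. an untrained lora_B
}
weight_clip = True   # stands in for self.sq.weight_clip

weight_max = info["weight_max"]
if weight_clip:
    weight_max = weight_max.clamp(min=1e-5)
abs_input_max = torch.max(torch.abs(info["input_minmax"][0]), torch.abs(info["input_minmax"][1]))
input_power = torch.pow(abs_input_max, info["alpha"])
weight_power = torch.pow(weight_max, 1 - info["alpha"])
print(input_power / weight_power)   # finite everywhere; with weight_clip=False channel 1 becomes inf
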
1 change: 1 addition & 0 deletions neural_compressor/adaptor/tensorflow.py
@@ -1822,6 +1822,7 @@ def smooth_quant(
op_types=["MatMul", "Conv2D"],
scales_per_op=True,
record_max_info=False,
weight_clip=True,
):
"""Convert the model by smooth quant.
5 changes: 5 additions & 0 deletions neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -330,6 +330,7 @@ def __init__(self, model, dataloader=None, example_inputs=None, q_func=None, tra
self.self_absorb_layers = {}
self.absorb_to_layer = {}
self.adjust_alpha_space = False
self.weight_clip = True

def _get_device(self):
"""Get the model device
@@ -577,6 +578,8 @@ def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False):
weights.append(weight)

weight_max_per_channel = torch.max(torch.abs(torch.cat(weights, dim=0)), dim=0)[0]
if self.weight_clip:
weight_max_per_channel = weight_max_per_channel.clamp(min=1e-5)
if self.record_max_info and not tuning:
# the input of layers with same absorb layer is the same.
input_minmax = [self.input_mins[layer_names[0]], self.input_maxes[layer_names[0]]]
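
For context, the _cal_scales hunk above takes the channel-wise absolute maximum over the concatenated weights of every layer sharing an absorb layer, and only then applies the now-optional clamp. A shape-only sketch with arbitrary dimensions (not repository code):

import torch

# Two layers that share one absorb layer, both with 4 input channels.
weights = [torch.randn(8, 4), torch.randn(16, 4)]
weight_max_per_channel = torch.max(torch.abs(torch.cat(weights, dim=0)), dim=0)[0]
weight_clip = True
if weight_clip:
    weight_max_per_channel = weight_max_per_channel.clamp(min=1e-5)
print(weight_max_per_channel.shape)   # torch.Size([4]) -- one maximum per input channel
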
@@ -946,6 +949,7 @@ def transform(
scales_per_op=False,
calib_iter=100,
auto_alpha_args={"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.1, "shared_criterion": "mean"},
weight_clip=True,
):
"""The main entry of smooth quant
:param alpha: Alpha value to balance the quantization difficulty of activation and weight, please refer
@@ -971,6 +975,7 @@

alpha = numpy.clip(alpha, 0.0, 1.0)

self.weight_clip = weight_clip
self.recover()
need_calibration = self._check_need_calibration(alpha, percentile, op_types, scales_per_op, calib_iter)
with torch.no_grad():
2 changes: 2 additions & 0 deletions neural_compressor/algorithm/smooth_quant.py
@@ -52,6 +52,7 @@ def __init__(self, alpha=0.5):
self.op_types = None
self.scales_per_op = None
self.tune_cfg = None
self.weight_clip = None

def __call__(self, origin_model, q_model, adaptor, dataloader, calib_iter):
"""Return the processed model via SmoothQuant algorithm.
@@ -80,6 +81,7 @@ def __call__(self, origin_model, q_model, adaptor, dataloader, calib_iter):
kwargs["scales_per_op"] = self.scales_per_op
kwargs["folding"] = self.folding
kwargs["record_max_info"] = True
kwargs["weight_clip"] = self.weight_clip
q_model = adaptor.smooth_quant(
origin_model,
dataloader,
3 changes: 3 additions & 0 deletions neural_compressor/strategy/strategy.py
@@ -948,6 +948,9 @@ def set_param_for_pre_tuning_algos(self, algo_scheduler, config, fp32_model) ->
if self.framework == "pytorch_ipex":
smooth_quant_args["folding"] = None # will reset it to True if IPEX version < 2.1.
sq_algo.folding = smooth_quant_args["folding"]
sq_algo.weight_clip = smooth_quant_args.get(
"weight_clip", True
) # make weight_clipping a default_on option.
logger.debug(f"Set smooth quant with alpha {sq_algo.alpha} as the pre-tuning algo.")
algo_scheduler.append_algorithm("pre_quantization", sq_algo)

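A hedged sketch of the default-on behaviour added above: the strategy reads weight_clip from the user-supplied smooth_quant_args dict and falls back to True when the key is absent; how that dict is populated from the user config is not shown in this diff.

smooth_quant_args = {"alpha": 0.5}                         # user did not set weight_clip
assert smooth_quant_args.get("weight_clip", True) is True  # clipping stays on by default

smooth_quant_args = {"alpha": 0.5, "weight_clip": False}   # explicit opt-out
assert smooth_quant_args.get("weight_clip", True) is False
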
36 changes: 36 additions & 0 deletions test/algorithm/test_smooth_quant.py
@@ -1427,5 +1427,41 @@ def calib_func(model):
out2 = q_model.model(example_input)[0]


class TestInputConfig(unittest.TestCase):
@classmethod
def setUpClass(self):
class RandDataloader:
def __init__(self):
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3))

self.linear_dl = RandDataloader()

@classmethod
def test_sq_weight_clipping(self):
class Model(torch.nn.Module):
device = torch.device("cpu")

def __init__(self):
super(Model, self).__init__()
self.fc1 = torch.nn.Linear(3, 4)
self.norm = LlamaRMSNorm(4)
self.fc2 = torch.nn.Linear(4, 3)

def forward(self, x):
out = self.fc1(x)
out = self.norm(out)
out = self.fc2(out)
return out

model = Model()

sq = TorchSmoothQuant(model, self.linear_dl)
sq.transform(alpha="auto", calib_iter=1, folding=True, weight_clip=False)
assert sq.weight_clip is False


if __name__ == "__main__":
unittest.main()
