Skip to content

Commit

Permalink
smoothquant alpha auto-tuning (#747)
Browse files Browse the repository at this point in the history
Signed-off-by: Lu, Yintong <[email protected]>
Signed-off-by: wenhuach21 <[email protected]>
Co-authored-by: Wang, Mengni <[email protected]>
Co-authored-by: yiliu30 <[email protected]>
  • Loading branch information
3 people authored Mar 29, 2023
1 parent 12ee49f commit 12c101f
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
parser.add_argument('--sq', action='store_true', default=False, help="whether to use smooth quant")
# parser.add_argument('--calib_num', type=int, default=100, help="calibration num for sq")
parser.add_argument('--model_name_or_path', type=str, default='bigscience/bloom-560m')
parser.add_argument('--alpha', type=float, default=0.5)
parser.add_argument('--alpha', default=0.5, help="Set alpha=auto to use alpha tuning.")
parser.add_argument('--log_frequency', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--kl', action='store_true', default=False, help="whether to use kl divergence for calibration")
Expand All @@ -36,7 +36,7 @@ def evaluate(self, model):
index = 1
for input_ids, label, label_indices in tqdm(self.dataloader):
outputs = model(input_ids)
last_token_logits = outputs[0][:, label_indices, :]
last_token_logits = outputs[0][torch.arange(len(label_indices)), label_indices, :]
pred = last_token_logits.argmax(dim=-1)
total += label.size(0)
hit += (pred == label).sum().item()
Expand Down Expand Up @@ -148,7 +148,7 @@ def eval_func(model):
if args.kl:
op_type_dict = {'linear': {'activation': {'algorithm': ['kl']}}}

conf = PostTrainingQuantConfig(backend='ipex', excluded_precisions=["bf16"],
conf = PostTrainingQuantConfig(quant_level=1, backend='ipex', excluded_precisions=["bf16"],##use basic tuning
recipes=recipes,
op_type_dict=op_type_dict)

Expand Down
3 changes: 3 additions & 0 deletions neural_compressor/adaptor/onnxrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5,
return self.smooth_quant_model
from neural_compressor.adaptor.ox_utils.calibration import ONNXRTAugment
from onnx import numpy_helper
if isinstance(alpha, str):
logger.warning(f"onnx backend only support float alpha, reset alpha to 0.5 ")
alpha = 0.5
black_nodes = []
white_nodes = []
if tune_cfg is not None:
Expand Down
207 changes: 197 additions & 10 deletions neural_compressor/adaptor/torch_utils/smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def model_forward(model, dataloader, iters):
for idx, input in enumerate(dataloader):
if isinstance(input, dict):
out = model(**input)
elif isinstance(input,list) or isinstance(input, tuple):
elif isinstance(input, list) or isinstance(input, tuple):
out = model(*input)
else:
out = model(input)
Expand All @@ -43,6 +43,64 @@ def model_forward(model, dataloader, iters):
break


def quant_dequant_w(m, num_bits=8, scheme='asym'):##TODO take sym as default
if isinstance(m, torch.nn.Linear):
x = m.weight
if scheme == 'sym':
q_min, q_max = -2. ** (num_bits - 1), 2. ** (num_bits - 1) - 1.
scale = torch.abs(torch.max(x, dim=1).values) / (2 ** (num_bits - 1) - 1)
else:
q_min, q_max = 0, 2. ** num_bits - 1.
scale = (torch.max(x, dim=1).values - torch.min(x, dim=1).values) / (2 ** num_bits - 1)
scale = torch.clip(scale, min=1e-5)

if scheme == 'sym':
bias = 0
else:
bias = torch.round(0 - (torch.min(x, dim=1).values) / scale)
bias = bias.unsqueeze(dim=-1)
scale = scale.unsqueeze(dim=-1)
q_x = x / scale + bias
q_x.clamp_(q_min, q_max).round_()
return (q_x - bias) * scale
elif isinstance(m, torch.nn.Conv2d):
x = m.weight
x = torch.permute(x, (0, 2, 3, 1))
x = x.reshape(-1, x.shape[-1])
if scheme == 'sym':
q_min, q_max = -2. ** (num_bits - 1), 2. ** (num_bits - 1) - 1.
scale = torch.abs(torch.max(x, dim=0).values) / (2 ** (num_bits - 1) - 1)
else:
q_min, q_max = 0, 2. ** num_bits - 1.
scale = (torch.max(x, dim=0).values - torch.min(x, dim=0).values) / (2 ** num_bits - 1)
scale = torch.clip(scale, min=1e-5)
if scheme == 'sym':
bias = 0
else:
bias = torch.round(0 - (torch.min(x, dim=0).values) / scale)
bias = bias.unsqueeze(dim=0)
scale = scale.unsqueeze(dim=0)

q_x = x / scale + bias
q_x.clamp_(q_min, q_max).round_()
q_dq_x = (q_x - bias) * scale
q_dq_x = q_dq_x.view(m.weight.shape[0], m.weight.shape[2], m.weight.shape[3], m.weight.shape[1])
q_dq_x = torch.permute(q_dq_x, (0, 3, 1, 2))
return q_dq_x
else:
logger.warning("unsupported layer type, please have a check")


def quant_dequant_x(x, num_bits=8):
q_min, q_max = 0, 2. ** num_bits - 1.
scale = (torch.max(x) - torch.min(x)) / (2 ** num_bits - 1)
scale = torch.clip(scale, min=1e-5)
bias = torch.round(0 - (torch.min(x)) / scale)
q_x = x / scale + bias
q_x.clamp_(q_min, q_max).round_()
return scale * (q_x - bias)


class TorchSmoothQuant:
"""
Fake input channel quantization, for more details please refer to
Expand All @@ -67,7 +125,11 @@ def __init__(self, model, dataloader, traced_model=None):
self.device = device
self.dtype = dtype
self.dataloader = dataloader
self.input_values = {}
self.output_values = {}
self.input_maxes = {}
self.hook_layer_names = []
self.hook_values_handles = []
self.traced_model = traced_model
if self.traced_model == None:
self.traced_model = self.model
Expand Down Expand Up @@ -115,10 +177,30 @@ def save_input_hook(module, inputs, outputs):
input = input.reshape(-1, input.shape[-1])
max_tensor = torch.max(input, dim=0)[0]
self.input_maxes[name].append(max_tensor)
# self.input_values[name] = input
# self.output_values[name] = outputs

return save_input_hook

def _add_observer(self, modules):
def _save_input_output_hook(self, name):
"""
A forward hook to save input and output values of a module
param name: the module name
return: A hook function
"""

def save_input_output_hook(module, inputs, outputs):
input = inputs[0]
# if name in self.input_values:
# self.input_values[name].append(input)
# self.output_values[name].append(outputs)
# else:
self.input_values[name] = [input]##TODO save more,like 8
self.output_values[name] = [outputs]##TODO do not save output

return save_input_output_hook

def _add_observer(self, modules, input_output_modules=None):
"""
:param modules: the modules which the observer will insert to
:return:
Expand All @@ -128,6 +210,11 @@ def _add_observer(self, modules):
hook_func = self._save_input_pc_hook(key)
hook_handle = modules[key].register_forward_hook(hook_func)
self.hook_handles.append(hook_handle)
if input_output_modules:
for key in input_output_modules.keys():
hook_func = self._save_input_output_hook(key)
hook_handle = input_output_modules[key].register_forward_hook(hook_func)
self.hook_values_handles.append(hook_handle)

def _remove_observer(self):
"""
Expand All @@ -136,6 +223,9 @@ def _remove_observer(self):
"""
for hook_handle in self.hook_handles:
hook_handle.remove()
if self.hook_values_handles:
for hook_handle in self.hook_values_handles:
hook_handle.remove()

# ##https://gist.github.com/sailfish009/28b54c8aa6398148a6358b8f03c0b611
# def percentile(t: torch.tensor, q: float):
Expand All @@ -158,7 +248,7 @@ def _remove_observer(self):
# result = t.view(-1).kthvalue(k).values.item()
# return result

def _calibrate(self, absorb_to_layer, calib_iter):
def _calibrate(self, absorb_to_layer, calib_iter, save_input_output=False):
"""
:param absorb_to_layer: A dict,key is the absorb layer, val is a list of the to be smoothed layer
:param calib_iter: Data size for calibration
Expand All @@ -185,8 +275,10 @@ def _calibrate(self, absorb_to_layer, calib_iter):
hook_modules[name] = module
if len(hook_modules) == 0:
return {}

self._add_observer(hook_modules)
hook_modules_input_output = {}
for name in self.hook_layer_names:
hook_modules_input_output[name] = self._get_module(name)
self._add_observer(hook_modules, hook_modules_input_output)
self._dump_min_max(calib_iter=calib_iter)
self._remove_observer()
return self.input_maxes
Expand All @@ -205,6 +297,9 @@ def _dump_min_max(self, calibration_method="min_max", calib_iter=100):
val = torch.stack(val, dim=0)
val = torch.max(torch.abs(val), dim=0)[0] ##FIXME should add abs
self.input_maxes[key] = val
for key in self.input_values.keys():
self.input_values[key] = torch.cat(self.input_values[key], dim=0) ##this may introduce memory issue
self.output_values[key] = torch.cat(self.output_values[key], dim=0)

def _reshape_in_channel_to_last(self, layer_name):
"""
Expand Down Expand Up @@ -363,7 +458,91 @@ def _check_same_hyperparameters(self, percentile, op_types,
else:
return True

def transform(self, alpha=0.5, percentile=99.999, op_types=['Linear', 'Conv2d', 'ConvTranspose2d'],
def auto_tune_alpha(self, input_maxes, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, attn_method='min'):
"""
Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly.
input_maxes:
alpha_min: min value of alpha search space.
alpha_max: max value of alpha search space.
alpha_step: step size of alpha search space.
attn_method: criterion method used on attention ops; currently min, max and mean are supported.
"""
logger.info("enter auto")
import copy
alpha_scale = 100
alpha_values = list(range(round(alpha_min * alpha_scale), round((alpha_max + alpha_step) * alpha_scale),
round(alpha_step * alpha_scale)))
ans_layer2absorb, self.layer_to_absorb, ans = {}, {}, {}
## Searching optimal alphas
for idx, key in enumerate(self.absorb_to_layer):
absorb_to_layer_sample, input_max_op = {}, {}
absorb_key = key
absorb_to_layer_sample[absorb_key] = self.absorb_to_layer[absorb_key]
loss_all_layers = {}
for layer_key in self.absorb_to_layer[absorb_key]:
if layer_key not in self.layer_to_absorb.values():
if layer_key in input_maxes:
self.layer_to_absorb[absorb_key] = layer_key
layer_key_ = self.layer_to_absorb[absorb_key]
input_max_op[layer_key] = input_maxes[layer_key_]
loss_alpha = {}
for alpha in alpha_values:
alpha = alpha / alpha_scale
self.weight_scale_info, self.absorb_scales_info = self._adjust_parameters(absorb_to_layer_sample,
input_max_op, alpha)
input_of_op, output_of_op = self.input_values[layer_key], self.output_values[layer_key]
# if output_of_op.ndim == 3:
# output_of_op = output_of_op[0]
input_of_op_q = quant_dequant_x(input_of_op * self.absorb_scales_info[absorb_key])
layer = self._get_module(layer_key)
weight_qdq = quant_dequant_w(layer)
layer_cp = copy.deepcopy(layer)
layer_cp.weight.data = weight_qdq
output_of_op_q = layer_cp(input_of_op_q)
self.recover()
loss = torch.sum(torch.abs(output_of_op - output_of_op_q) ** 2)
loss_alpha[alpha] = loss
if layer_key not in ans: # Update alpha results
ans[layer_key] = alpha
else:
ans[layer_key] = alpha if loss < loss_alpha[ans[layer_key]] else ans[layer_key]
loss_all_layers[layer_key] = loss_alpha
if absorb_key not in ans_layer2absorb:
ans_layer2absorb[absorb_key] = ans[layer_key]
else:
if attn_method == 'max':
ans_layer2absorb[absorb_key] = max(ans_layer2absorb[absorb_key], ans[layer_key])
elif attn_method == 'min':
ans_layer2absorb[absorb_key] = min(ans_layer2absorb[absorb_key], ans[layer_key])
elif attn_method == 'mean':
pass
if attn_method == 'mean':
mean_loss = {}
for alpha in alpha_values:
alpha = alpha / alpha_scale
mean_loss[alpha] = 0
for key in loss_all_layers.keys():
mean_loss[alpha] += loss_all_layers[key][alpha]
min_alpha = min(mean_loss, key=mean_loss.get)
if len(loss_all_layers) > 1:
ans_layer2absorb[absorb_key] = min_alpha

for idx, key in enumerate(self.absorb_to_layer): # Adjust parameters according to optimal alphas.
absorb_to_layer_sample, input_max_op = {}, {}
absorb_key = key
absorb_to_layer_sample[absorb_key] = self.absorb_to_layer[absorb_key]
layer_key_ = self.layer_to_absorb[absorb_key]
input_max_op[layer_key_] = input_maxes[layer_key_]
if key in ans_layer2absorb:
op_weight_scale, op_absorb_scale = self._adjust_parameters(absorb_to_layer_sample, input_max_op,
alpha=ans_layer2absorb[key])
else:
op_weight_scale, op_absorb_scale = self._adjust_parameters(absorb_to_layer_sample, input_max_op)
self.weight_scale_info.update(op_weight_scale)
self.absorb_scales_info.update(op_absorb_scale)
self.input_values, self.output_values = {}, {}

def transform(self, alpha=0.5, percentile=99.999, op_types=['Linear', 'Conv2d'],
scales_per_op=False, calib_iter=100):
"""
The main entry of smooth quant
Expand All @@ -386,18 +565,26 @@ def transform(self, alpha=0.5, percentile=99.999, op_types=['Linear', 'Conv2d',
self.recover()
self.absorb_to_layer, no_absorb_layers = self._trace(
op_types) ##TODO we need to insert mul layer for no_absorb_layers later
for key in self.absorb_to_layer:
self.hook_layer_names += self.absorb_to_layer[key]
if self.absorb_to_layer == None and no_absorb_layers == None:
logger.warning("sorry, could not trace the model, smooth quant is ignored")
logger.warning("if you are using huggingface model,"
"you could set torchscript to True "
"when loading the model or set the return_dict to False")
return self.model
save_input_output = False
if alpha == "auto":
save_input_output = True

input_maxes = self._calibrate(self.absorb_to_layer, calib_iter)
input_maxes = self._calibrate(self.absorb_to_layer, calib_iter, save_input_output)

self.recover()
self.weight_scale_info, self.absorb_scales_info = self._adjust_parameters(self.absorb_to_layer, input_maxes,
alpha)
if alpha == 'auto':
self.auto_tune_alpha(input_maxes)
else:
self.weight_scale_info, self.absorb_scales_info = self._adjust_parameters(self.absorb_to_layer,
input_maxes, alpha)
return self.model

def recover(self):
Expand Down Expand Up @@ -459,7 +646,7 @@ def __init__(self):
"LayerNorm": "aten::layer_norm",
"BatchNorm2d": "aten::batch_norm",
"GroupNorm": "aten::group_norm",
"InstanceNorm2d": "instance_norm",
"InstanceNorm2d": "aten::instance_norm",
"LlamaRMSNorm": "aten::mul",
"T5LayerNorm": "aten::mul",
}
Expand Down
6 changes: 5 additions & 1 deletion neural_compressor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,11 @@ def smooth_quant_args(val=None):
_check_value("smooth_quant_args", val, dict)
for k, v in val.items():
if k == "alpha":
_check_value("alpha", v, float)
if isinstance(v, str):
assert v == "auto", "the alpha of sq only supports float and 'auto'"
else:
_check_value("alpha", v, float)

return True
else:
return {}
Expand Down
21 changes: 10 additions & 11 deletions neural_compressor/strategy/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,22 +515,21 @@ def apply_recipe_one_by_one(self, tune_cfg):
For recipes only have two options, apply the last one.
For recipes with multiple values. such as alpha of smooth quant, apply it one by one.
"""
from .utils.tuning_sampler import TuningSamplerRegistry
all_registered_samplers = TuningSamplerRegistry.sampler_dict
for recipe_name, recipe_vals in self._tuning_recipes.items():
if recipe_name in FALLBACK_RECIPES_SET and 'recipes_ops' in self.capability and \
len(self.capability['recipes_ops'].get(recipe_name, [])) > 0:
logger.info(f"Applied recipe {recipe_name} with value {recipe_vals[-1]}")
new_tune_cfg = self._fallback_ops(copy.deepcopy(tune_cfg), \
self.capability['recipes_ops'][recipe_name], self.tuning_space)
yield new_tune_cfg
if recipe_name in all_registered_samplers:
recipe_sampler = all_registered_samplers[recipe_name](tuning_space=None,
tuning_order_lst=[],
initial_op_tuning_cfg=copy.deepcopy(tune_cfg),
kwargs={recipe_name: recipe_vals})
for new_tune_cfg in recipe_sampler:
yield new_tune_cfg
if recipe_name == "smooth_quant":
sq_args = {'smooth_quant': True}
if 'recipe_cfgs' not in new_tune_cfg:
new_tune_cfg['recipe_cfgs'] = sq_args
else:
new_tune_cfg['recipe_cfgs'].update(sq_args)
new_tune_cfg['recipe_cfgs'] = sq_args
yield new_tune_cfg

def set_param_for_pre_quantization_algos(self, algo_scheduler, tune_cfg, fp32_model) -> None:
"""Set the parameter for pre-quantization algos, such as smooth quantization.
Expand All @@ -549,9 +548,9 @@ def set_param_for_pre_quantization_algos(self, algo_scheduler, tune_cfg, fp32_mo
if recipe_cfgs and recipe_cfgs.get('smooth_quant', False):
# skip assign alpha to sq first.
# set the alpha to 0.5 by default
# smooth_quant_args = recipe_cfgs.get('smooth_quant_args', {'alpha': 0.5})
smooth_quant_args = recipe_cfgs.get('smooth_quant_args', {'alpha': 0.5})
sq_algo = ALGORITHMS()['smooth_quant']
#sq_algo.alpha = smooth_quant_args['alpha']
sq_algo.alpha = smooth_quant_args['alpha']
#logger.debug(f"Set smooth quant with alpha {smooth_quant_args['alpha']} as the pre-quantization algo.")
algo_scheduler.append_algorithm('pre_quantization', sq_algo)

Expand Down
Loading

0 comments on commit 12c101f

Please sign in to comment.