From 641d42b2ebf873e87aa7d5bb0b2fcd518550022f Mon Sep 17 00:00:00 2001
From: xinhe
Date: Wed, 16 Aug 2023 14:10:52 +0800
Subject: [PATCH] Refactor AWQ algo to enhance memory and computation
 efficiency and support folding=False (#1130)

Refactor AWQ algo to enhance memory and computation efficiency and support folding=False

Signed-off-by: Xin He
---------
Signed-off-by: Xin He
Signed-off-by: Lv, Kaokao
Co-authored-by: Lv, Kaokao
---
 .../scripts/codeScan/pyspelling/inc_dict.txt   |   5 +-
 docs/source/quantization_weight_only.md        |  10 +-
 neural_compressor/adaptor/pytorch.py           |  88 ++--
 neural_compressor/adaptor/torch_utils/awq.py   | 435 ++++++++++++++++++
 .../adaptor/torch_utils/model_wrapper.py       |  34 +-
 .../adaptor/torch_utils/smooth_quant.py        |   8 +-
 neural_compressor/adaptor/torch_utils/teq.py   |  14 +-
 neural_compressor/adaptor/torch_utils/util.py  | 215 ++++++++-
 .../adaptor/torch_utils/weight_only.py         | 363 ++-------
 neural_compressor/model/torch_model.py         |   4 +-
 neural_compressor/utils/pytorch.py             |  47 +-
 .../test_weight_only_adaptor.py                | 140 +++++-
 test/quantization/test.py                      |  31 ++
 .../test_weight_only_quantization.py           |  78 ++--
 14 files changed, 998 insertions(+), 474 deletions(-)
 create mode 100644 neural_compressor/adaptor/torch_utils/awq.py
 create mode 100644 test/quantization/test.py

diff --git a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
index 8314ec2c153..b14314002c9 100644
--- a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
+++ b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
@@ -2659,9 +2659,11 @@ classDef
 bdf
 bmm
 AWQ
+awq
 GPTQ
+gptq
 RTN
-awq
+rtn
 gptq
 percdamp
 Frantar
@@ -2693,6 +2695,7 @@ hostname
 qweight
 qconfig
 TEQ
+teq
 WeightOnlyLinear
 McKinstry
 Migacz
diff --git a/docs/source/quantization_weight_only.md b/docs/source/quantization_weight_only.md
index 4140b51db27..3b2b5fa888c 100644
--- a/docs/source/quantization_weight_only.md
+++ b/docs/source/quantization_weight_only.md
@@ -40,12 +40,18 @@ There are many excellent works for weight only quantization to improve its accur
 |   scheme   |      ['asym', 'sym']      |
 | algorithm  |      ['RTN', 'AWQ']       |
 
+**RTN arguments**:
+|    rtn_args    | default value |                              comments                              |
+|:--------------:|:-------------:|:-------------------------------------------------------------------:|
+| sym_full_range |     False     | Whether to use -2**(bits-1) as the minimum value in the sym scheme, e.g. [-8, 7] instead of [-7, 7] for 4 bits |
+|   return_int   |     False     | Whether to return the compressed model with int data type |
+
 **AWQ arguments**:
 |  awq_args  | default value |                               comments                              |
 |:----------:|:-------------:|:-------------------------------------------------------------------:|
-| auto_scale |      True     |    Whether search for best scales based on activation distribution  |
+| auto_scale |      True     | Whether to search for the best scales based on the activation distribution |
 |  mse_range |      True     |  Whether search for the best clip range from range [0.89, 1.0, 0.01] |
-|  n_blocks  |       5       |  Split the model into n blocks for AWQ search to avoid out-of-memory |
+|   folding  |     False     | If False, insert a mul before linear layers whose scale cannot be absorbed by the previous layer; if True, only apply scales that can be folded |
 
 **Note**: `group_size=-1` indicates the per-channel quantization per output channel. `group_size=[1-N]` indicates splitting the input channel elements per group_size. 
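For quick reference, below is a minimal usage sketch of the recipe knobs documented above, modeled on the unit tests updated later in this patch. It is not part of the diff itself; `fp32_model` and `calib_dataloader` are placeholders for the user's own model and calibration data.

```python
from neural_compressor import quantization, PostTrainingQuantConfig

# Sketch: weight-only quantization driven by the new AWQ/RTN recipe knobs,
# mirroring test_weight_only_adaptor.py in this patch.
conf = PostTrainingQuantConfig(
    approach='weight_only',
    op_type_dict={
        '.*': {  # re.match, applies to all supported op types
            'weight': {
                'bits': 4,          # 1-8 bits
                'group_size': 32,   # -1 means per output channel
                'scheme': 'asym',
                'algorithm': 'AWQ', # 'RTN' is also supported
            },
        },
    },
    recipes={
        # folding=False lets AWQ insert a mul before linears whose scale
        # cannot be absorbed by the previous layer.
        'awq_args': {'auto_scale': True, 'mse_range': True, 'folding': False},
        # rtn_args also apply to the final RTN quantization step of AWQ.
        'rtn_args': {'sym_full_range': False, 'return_int': False},
    },
)
q_model = quantization.fit(fp32_model, conf, calib_dataloader=calib_dataloader)
```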
diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index ddd2b4b67db..8529de2868d 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -4313,6 +4313,8 @@ def quantize(self, tune_cfg, model, dataloader, calib_func=None): else: algorithm = config['weight']['algorithm'] all_algo.add(algorithm) + if len(all_algo): + logger.info(f"All algorithms to do: {all_algo}") if 'GPTQ' in all_algo: q_model._model, gptq_config = self.gptq_quantize( q_model._model, tune_cfg, dataloader @@ -4322,7 +4324,7 @@ def quantize(self, tune_cfg, model, dataloader, calib_func=None): q_model._model = self.teq_quantize(q_model._model, tune_cfg, dataloader, calib_func) if 'AWQ' in all_algo: # includes RTN in AWQ q_model._model = self.awq_quantize(q_model._model, tune_cfg, dataloader, calib_func) - elif 'RTN' in all_algo: + if 'RTN' in all_algo: q_model._model = self.rtn_quantize(q_model._model, tune_cfg) q_model.q_config = copy.deepcopy(self.tune_cfg) @@ -4331,7 +4333,7 @@ def quantize(self, tune_cfg, model, dataloader, calib_func=None): return q_model def rtn_quantize(self, model, tune_cfg): - logger.debug("quantizing with the round-to-nearest algorithm") + logger.info("quantizing with the round-to-nearest algorithm") if 'rtn_args' in self.recipes: sym_full_range = self.recipes['rtn_args'].get('sym_full_range', False) else: @@ -4357,7 +4359,7 @@ def rtn_quantize(self, model, tune_cfg): return model def gptq_quantize(self, model, tune_cfg, dataloader): - logger.debug("quantizing with the GPTQ algorithm") + logger.info("quantizing with the GPTQ algorithm") from .torch_utils.weight_only import gptq_quantize # convert tune_cfg to gptq_quantize's weight config """please refer to weight_config which can be analyzed by user-define API function weight_only.gptq_quantize @@ -4403,7 +4405,7 @@ def gptq_quantize(self, model, tune_cfg, dataloader): return model, quantization_perm def teq_quantize(self, model, tune_cfg, dataloader, calib_func): - logger.debug("quantizing with the TEQ algorithm") + logger.info("quantizing with the TEQ algorithm") from .torch_utils.weight_only import teq_quantize # get example inputs if not provided. if self.example_inputs is None: # pragma: no cover @@ -4490,90 +4492,52 @@ def teq_quantize(self, model, tune_cfg, dataloader, calib_func): return model def awq_quantize(self, model, tune_cfg, dataloader, calib_func): - logger.debug("quantizing with the AWQ algorithm") + logger.info("quantizing with the AWQ algorithm") from .torch_utils.weight_only import awq_quantize # get example inputs if not provided. if self.example_inputs is None: - if dataloader is None: - assert False, "Please provide dataloader or example_inputs for AWQ algorithm." - try: - for idx, (input, label) in enumerate(dataloader): - self.example_inputs = input - break - except: - for idx, input in enumerate(dataloader): - self.example_inputs = input - break + from neural_compressor.adaptor.torch_utils.util import get_example_input + assert dataloader is not None, "datalaoder or example_inputs is required." + self.example_inputs = get_example_input(dataloader) - # get modules that can be absorbed. 
- from .torch_utils.smooth_quant import GraphTrace - tg = GraphTrace() - supported_layers = ['Linear'] - absorb_to_layer, _ = tg.get_absorb_to_layer(model, self.example_inputs, supported_layers) - if absorb_to_layer is None or absorb_to_layer == {}: - logger.warning('No absorb layer is detected, skip AWQ algorithm') - return model - - # got flipped dict from absorb_to_layer dict - flipped_dict = {} - for k, v in absorb_to_layer.items(): - for m in v: - flipped_dict[m] = {'absorb_layer': k} - - # check tune_cfg to skip layers without AWQ config + # build weight_config weight_config = {} - skipped_op_name_set = set() for key, config in tune_cfg['op'].items(): op_name, op_type = key if config['weight']['dtype'] == 'fp32': - if op_name in flipped_dict: - absorb_to_layer.pop(flipped_dict[op_name]['absorb_layer']) - continue + weight_config[op_name] = { + 'bits': -1, # skip quantization + 'group_size': 128, + 'scheme': 'asym', + 'algorithm': 'RTN', + } else: - weight_config[op_name] = {} - weight_config[op_name]['bits'] = config['weight']['bits'] - weight_config[op_name]['group_size'] = config['weight']['group_size'] - weight_config[op_name]['scheme'] = config['weight']['scheme'] - if op_name in flipped_dict: - algorithm = config['weight']['algorithm'] - if algorithm != 'AWQ': - absorb_to_layer.pop(weight_config[op_name]['absorb_layer']) - else: - skipped_op_name_set.add(op_name) - if skipped_op_name_set: - logger.info("{} is skipped by AWQ algorithm".format(skipped_op_name_set)) - - # collect AWQ config from tune_cfg for quantization. - if len(absorb_to_layer) == 0: - logger.warning('No absorb layer needs AWQ algorithim, skip it') - else: - logger.debug("**absorb layer**: **absorbed layers**") - for k, v in absorb_to_layer.items(): - logger.debug(f"{k}: {v}") - logger.info("Absorbed layers with the same absorb layer use the same config") + weight_config[op_name] = config['weight'] if 'awq_args' in self.recipes: auto_scale = self.recipes['awq_args'].get('auto_scale', True) mse_range = self.recipes['awq_args'].get('mse_range', True) - n_blocks = self.recipes['awq_args'].get('n_blocks', 5) + folding = self.recipes['awq_args'].get('folding', False) else: - auto_scale, mse_range = True, True + auto_scale, mse_range, folding = True, True, False if 'rtn_args' in self.recipes: sym_full_range = self.recipes['rtn_args'].get('sym_full_range', False) + return_int = self.recipes['rtn_args'].get('return_int', False) else: - sym_full_range=False + sym_full_range, return_int = False, False calib_sampling_size = tune_cfg.get('calib_sampling_size', 1) model = awq_quantize( model, + bits=-1, # no quantize for op not in weight_config + example_inputs=self.example_inputs, weight_config=weight_config, - absorb_dict=absorb_to_layer, dataloader=dataloader, n_samples=calib_sampling_size, auto_scale=auto_scale, mse_range=mse_range, calib_func=calib_func, - n_blocks=n_blocks, - return_int=False, + folding=folding, + return_int=return_int, sym_full_range=sym_full_range, ) return model diff --git a/neural_compressor/adaptor/torch_utils/awq.py b/neural_compressor/adaptor/torch_utils/awq.py new file mode 100644 index 00000000000..6249efffb73 --- /dev/null +++ b/neural_compressor/adaptor/torch_utils/awq.py @@ -0,0 +1,435 @@ +import torch +import copy +from neural_compressor.adaptor.torch_utils.util import ( + fetch_module, + get_example_input, + get_absorb_layers, + get_module_input_output, + get_hidden_states, + get_block_prefix +) +from .model_wrapper import MulLinear +from ...utils import logger +from .smooth_quant import 
model_forward, set_module +from functools import partial + + +def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={}): + """Get absorbed layer per block. + + Args: + model (torch.nn.Module): input model + example_inputs: example_inputs + + Returns: + block_absorb_dict: dict of absorbed layer per block. eg. {0, [[absorbed_1, xx], [xx]], ...} + """ + block_absorb_dict = {} # record absorbed layer per block + absorb_layer_dict = {} # record absorb layers for absorbed layers + absorb_to_layer, no_absorb_layers = get_absorb_layers( + model, example_inputs, + supported_layers=['Linear'], folding=False + ) + logger.debug(f"The no absorb layers: {no_absorb_layers}") + # skip ops when algorithm is not AWQ + skip_op_set = set() + for k, v in absorb_to_layer.items(): + for vv in v: + if vv in weight_config and (weight_config[vv]['algorithm'] != 'AWQ' or \ + weight_config[vv]['bits'] == -1): + skip_op_set.add(k) + for k in no_absorb_layers: + if k in weight_config and (weight_config[k]['algorithm'] != 'AWQ' or \ + weight_config[k]['bits'] == -1): + skip_op_set.add(k) + for k in skip_op_set: + if k in absorb_to_layer: + absorb_to_layer.pop(k) + if k in no_absorb_layers: + no_absorb_layers.remove(k) + if len(skip_op_set) > 0: + logger.info(f"{skip_op_set} are skipped when running AWQ optimization") + + block_prefix, block_num = get_block_prefix(model) + for i in range(block_num): + block_absorb_dict[i] = [] + block_name = block_prefix + '.' + str(i) + '.' + for k, v in absorb_to_layer.items(): + name_list =tuple(vv for vv in v if block_name in vv) + if len(name_list) > 0: + block_absorb_dict[i].append(name_list) + absorb_layer_dict[name_list] = k + if not folding: + for k in no_absorb_layers: + if block_name in k: + name_list = tuple([k]) + block_absorb_dict[i].append(name_list) + absorb_layer_dict[name_list] = k + logger.debug(f"The absorbed layers per block: {block_absorb_dict}") + logger.debug(f"The absorb_layer_dict: {absorb_layer_dict}") + return block_absorb_dict, absorb_layer_dict + + +@torch.no_grad() +def _get_weight_scale(weight, q_group_size=-1): + org_shape = weight.shape + if q_group_size > 0: + weight = weight.view(-1, q_group_size) + scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True) + scale = scale.view(org_shape) + scale = scale.mean(0) + return scale + + +@torch.no_grad() +def _get_act_scale(input_val): + tmp = [x.abs().view(-1, x.shape[-1]) for x in input_val] + tmp = torch.cat(tmp, dim=0) + return tmp.mean(0) + + +class ActAwareWeightQuant: + """Implementation of Activation-aware Weight quantization (AWQ) algo.""" + def __init__(self, model, example_inputs=None, calib_func=None, dataloader=None, n_samples=128, + bits=4, group_size=32, scheme='asym', sym_full_range=False, weight_config={},): + self.example_inputs = example_inputs + if example_inputs is None: + assert dataloader is not None, "datalaoder or example_inputs is required." + self.example_inputs = get_example_input(dataloader) + # Step 1: get hidden states and kwargs of first block. 
+ self.total_block_args, self.total_block_kwargs = get_hidden_states( + model, dataloader=dataloader, n_samples=n_samples, calib_func=calib_func + ) + # Step 2: get block list and block prefix, number + self.block_prefix, self.block_num = get_block_prefix(model) + self.block_list = fetch_module(model, self.block_prefix) + self.bits = bits + self.group_size = group_size + self.scheme = scheme + self.sym_full_range = sym_full_range + self.weight_config = weight_config + self.model = model + + def quantize(self, auto_scale=True, mse_range=True, folding=False, return_int=False): + """Execute AWQ quantization. + + Args: + auto_scale (bool, optional): whether search scale. Defaults to True. + mse_range (bool, optional): whether search clip range. Defaults to True. + folding (bool, optional): whether only allow update scale when it can be fold + to upper layer. Defaults to False. + return_int (bool, optional): whether return int dtype with WeightOnlyLinear. + Defaults to False. + + Returns: + model: quantized model + """ + # Step 1: get absorbed module list per block, includes self-absorption + # block_absorb_dict is split per block, includes all absorb relationship. + # absorb_layer_dict is the inverse of block_absorb_dict for all blocks + self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_per_block( + self.model, self.example_inputs, + # for only mse_range, folding is useless. + folding = folding if auto_scale else False, + weight_config=self.weight_config, + ) + # process per block + for i, module_list in self.block_absorb_dict.items(): + logger.info(f"Processing block: {i+1}/{self.block_num}") + if len(module_list) == 0: + logger.info(f"No need to process this block.") + continue + # Step 1: fetch all input values of each linear for scale calculation + # use the first linear for QKV tuple + block_name = self.block_prefix + '.' + str(i) + block = fetch_module(self.model, block_name) + module_hook_config = { + v[0].split(block_name + '.')[1]: ['input'] for v in module_list + } + def block_calibration(model): + for args, kwargs in zip(self.total_block_args, self.total_block_kwargs): + model(*args, **kwargs) + input_values = get_module_input_output( + block, module_hook_config, calib_func=block_calibration, + ) + # Step 3: search best scale for linears in one block and apply it + if auto_scale: + scale_info = self.search_scale(block, block_name, module_list, input_values) + # Step 2: update self.total_block_args, self.total_block_kwargs for next block + out_list = self.block_inference(block) + self.update_block_input(out_list) + # Step 4: get input of next block before update scale + # weights of linear is updated by scale + if auto_scale: + self.apply_scale(scale_info) + # Step 5: search best clip range for linears in one block and save to weight_config + if mse_range: + self.search_clip(block_name, module_list, input_values) + # Step 6: apply clip range in weight_config when quantizing model weights + self.apply_quantize_with_clip(return_int) + return self.model + + def search_scale(self, block, block_name, module_list, input_values): + """Search scales per block. + + Args: + block (torch.nn.Module): a block of model + block_name (str): the block name in model. + module_list (dict): contains all linear tuple in current block, + linears in the same tuple shares scale. 
+ input_values (dict): contains all input values of linears in current block + + Returns: + scale_info: a dict that contains input scales of linears in current block + """ + from .weight_only import quant_weight + scale_info = {} + logger.info("Searching best scales with AWQ algorithm") + for module_tuple in module_list: + # Step 1: Initailize quantization configuration. + if module_tuple[0] in self.weight_config: + cur_bits = self.weight_config[module_tuple[0]]['bits'] + cur_group_size = self.weight_config[module_tuple[0]]['group_size'] + cur_scheme = self.weight_config[module_tuple[0]]['scheme'] + else: + cur_bits, cur_group_size, cur_scheme = self.bits, self.group_size, self.scheme + if cur_bits < 0: + continue + logger.info(f"[SCALE] Processing module: {module_tuple}") + # Step 2: update module name in block + module_name_list = [i.split(block_name + '.')[1] for i in module_tuple] + # Step 3: collect w_max and x_max for scale calculation. + weight = torch.cat( + [fetch_module(block, _m).weight for _m in module_name_list], dim=0 + ) + w_max = _get_weight_scale(weight, q_group_size=cur_group_size) + del weight + input_val = input_values[module_name_list[0]]['input'] + x_max = _get_act_scale(input_val) + absorbed_modules = {_m: fetch_module(block, _m) for _m in module_name_list} + # Step 4: collect origin output for MSE and state_dict for recover. + org_stat = {_m: module.state_dict() for _m, module in absorbed_modules.items()} + if len(module_tuple) > 1: + # use block inference for multi-modules + org_out = self.block_inference(block) + else: + module = absorbed_modules[module_name_list[0]] + org_out = self.module_inference(module, input_val) + # Step 5: collect origin output for MSE and state_dict for recover. + best_error = float('inf') + best_scales = None + best_scale_alpha = None + n_grid = 20 + history = [] + # Step 6: set different alpha for scale and compare the MSE loss. + for ratio in range(n_grid): + ratio = ratio * 1 / n_grid + scales = (x_max.pow(ratio) / w_max.pow(1-ratio) + ).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * scales.min()).sqrt() + for name, module in absorbed_modules.items(): + module.weight.data = module.weight.data.mul(scales.view(1, -1)) + module.weight.data = quant_weight( + module.weight.data, + num_bits=cur_bits, + group_size=cur_group_size, + scheme=cur_scheme, + full_range=self.sym_full_range, + ) / scales.view(1, -1) + loss = 0 + if len(module_tuple) > 1: + # use block inference for multi-modules + cur_out = self.block_inference(block) + else: + module = absorbed_modules[module_name_list[0]] + cur_out = self.module_inference(module, input_val) + for out1, out2 in zip(org_out, cur_out): + loss += (out1 - out2).float().pow(2).mean().item() + history.append(loss) + is_best = loss < best_error + if is_best: + best_error = loss + best_scales = scales + best_scale_alpha = ratio + for name, module in absorbed_modules.items(): + module.load_state_dict(org_stat[name]) + # Step 7: record the best scale alpha of each module_tuple + assert best_scales is not None, "Loss is infinity! Cannot find the correct scale." + best_scales = best_scales.view(-1) + assert torch.isnan(best_scales).sum() == 0, best_scales + scales = best_scales.detach() + scale_info[module_tuple] = scales + logger.debug("The loss history of different scale:{}".format(history)) + logger.info("The best scale alpha of {}: {}".format(module_tuple, best_scale_alpha)) + return scale_info + + @torch.no_grad() + def apply_scale(self, scale_info): + """Apply scales to model. 
+ + Args: + scale_info (dict): a dict that contains input scales of linears in current block + """ + for module_tuple, scale in scale_info.items(): + logger.debug(f"apply scale for module: {module_tuple}") + assert module_tuple in self.absorb_layer_dict, "cannot find the absorb module." + absorb_module_name = self.absorb_layer_dict[module_tuple] + absorb_module = fetch_module(self.model, absorb_module_name) + if absorb_module_name == module_tuple[0]: + # Case 1: module is self-absorption + new_module = MulLinear(absorb_module, 1.0 / scale) + new_module._update_linear() + set_module(self.model, absorb_module_name, new_module) + else: + # Case 2: scale is absorbed by other layer + if len(absorb_module.weight.shape) == 1: + absorb_module.weight.div_(scale) # for LayerNorm + else: + absorb_module.weight.div_(scale.view(-1, 1)) + # hasattr is for LlamaRMSNorm + if hasattr(absorb_module, 'bias') and absorb_module.bias is not None: + absorb_module.bias.div_(scale.view(-1)) + for name in module_tuple: + absorbed_module = fetch_module(self.model, name) + absorbed_module.weight.mul_(scale.view(1, -1)) + + def search_clip(self, block_name, module_list, input_values): + """Search best clip range of each linears in current block. + + Args: + block_name (str): block name in model. + module_list (dict): contains all linear tuple in current block, + linears in the same tuple shares scale. + input_values (dict): contains all input values of linears in current block + """ + from .weight_only import quant_weight + logger.info("Searching the best clip range with AWQ algorithm") + for module_tuple in module_list: + input_val = input_values[module_tuple[0].split(block_name + '.')[1]]['input'] + # process linear modules one by one + for module_name in module_tuple: + # Step 1: Initailize quantization configuration. + if module_name in self.weight_config: + cur_bits = self.weight_config[module_name]['bits'] + cur_group_size = self.weight_config[module_name]['group_size'] + cur_scheme = self.weight_config[module_name]['scheme'] + else: + cur_bits, cur_group_size, cur_scheme = self.bits, self.group_size, self.scheme + if cur_bits < 0: + continue + logger.info(f"[CLIP] Processing module: {module_name}") + # Step 2: update module name + module = fetch_module(self.model, module_name) + # Step 3: collect origin output for MSE and state_dict for recover. + org_stat = module.state_dict() + org_out = self.module_inference(module, input_val) + # Step 4: set different clip range for weight and compare the MSE loss. 
+ logger.info("Searching the best clip range with AWQ algorithm") + best_error = float('inf') + best_clip_ratio = None + n_grid = 100 + max_shrink = 0.1 + history = [] + for i_s in range(int(max_shrink * n_grid)): + ratio = (1 - i_s / n_grid) # 1, 0.91-1.0 + # MulLinear can also work with @weight.setter + module.weight.data = quant_weight( + module.weight.data, + num_bits=cur_bits, + group_size=cur_group_size, + scheme=cur_scheme, + full_range=self.sym_full_range, + quantile=ratio, + ) + loss = 0 + cur_out = self.module_inference(module, input_val) + for out1, out2 in zip(org_out, cur_out): + loss += (out1 - out2).float().pow(2).mean().item() + history.append(loss) + is_best = loss < best_error + if is_best: + best_error = loss + best_clip_ratio = ratio + module.load_state_dict(org_stat) + logger.debug("The loss history of different clip range:{}".format(history)) + if module_name not in self.weight_config: + self.weight_config[module_name] = { + 'bits': cur_bits, + 'group_size': cur_group_size, + 'scheme': cur_scheme + } + self.weight_config[module_name]['quantile'] = best_clip_ratio + if isinstance(module, MulLinear): + self.weight_config[module_name+'.linear'] = self.weight_config[module_name] + self.weight_config.pop(module_name) + logger.debug("The best clip ratio for {}:{}".format(module_name, best_clip_ratio)) + + def apply_quantize_with_clip(self, return_int=False): + """Quantize model with clip range. + + Args: + return_int (bool, optional): whether return int dtype with WeightOnlyLinear. + Defaults to False. + """ + # apply quantization and clip + logger.info("Quantizing the AWQ optimized fp32 model") + from .weight_only import rtn_quantize + self.model = rtn_quantize( + self.model, + num_bits=self.bits, + group_size=self.group_size, + scheme=self.scheme, + weight_config=self.weight_config, + return_int=return_int, + sym_full_range=self.sym_full_range, + ) + logger.info("AWQ quantization is done.") + + def update_block_input(self, input_list): + """Update block input for next block inference. + + Args: + input_list (list): A list of previous block outputs to serve as input to the next block. + """ + for i, inp in enumerate(input_list): + if len(self.total_block_args[i]) > 0: + self.total_block_args[i][0] = inp + elif 'hidden_states' in self.total_block_kwargs[i]: + self.total_block_kwargs[i]['hidden_states'] = inp + else: # pragma: no cover + assert False, "cannot find hidden_states position for next block" + + def block_inference(self, model): + """Collect output of block. + + Args: + model (torch.nn.Module): input model. + + Returns: + output(list): a list of block output. + """ + total_out = [] + for args, kwargs in zip(self.total_block_args, self.total_block_kwargs): + out = model(*args, **kwargs) + if isinstance(out, tuple): # pragma: no cover + out = out[0] + total_out.append(out) + return total_out + + def module_inference(self, model, inputs): + """Collect output of module. + + Args: + model (torch.nn.Module): input model. + inputs (list): a list of module input. + + Returns: + output(list): a list of module output. 
+ """ + total_out = [] + for inp in inputs: + out = model(inp) + if isinstance(out, tuple): # pragma: no cover + out = out[0] + total_out.append(out) + return total_out diff --git a/neural_compressor/adaptor/torch_utils/model_wrapper.py b/neural_compressor/adaptor/torch_utils/model_wrapper.py index 93480d3e6c2..1e7bc9f7a37 100644 --- a/neural_compressor/adaptor/torch_utils/model_wrapper.py +++ b/neural_compressor/adaptor/torch_utils/model_wrapper.py @@ -476,27 +476,43 @@ def forward(self, x): return F.linear(x, weight_q, self.orig_layer.bias) -class TEQMulLinear(torch.nn.Module): - """ - Trainable Equivalent Transformation (TEQ): linear wrapper to apply scale to input - """ +class MulLinear(torch.nn.Module): + """Linear wrapper to apply scale to input.""" - def __init__(self, module, input_scale): + def __init__(self, module, input_scale=None): """ A forward hook to save input max of a module :param module: the linear module :param input_scale: scale for input """ - super().__init__() + if input_scale is None: + input_scale = torch.empty(module.in_features) self.register_buffer('input_scale', input_scale) - self.add_module('sq_linear', module) + self.add_module('linear', module) @property def weight(self): - return self.sq_linear.weight + return self.linear.weight + + @weight.setter + def weight(self, weight): + self.linear.weight = weight def forward(self, X): X = torch.mul(X, self.input_scale) - X = self.sq_linear(X) + X = self.linear(X) return X + + def _update_linear(self): + # update linear weight with input_scale + scale = self.input_scale.view(1, self.input_scale.shape[0]) + with torch.no_grad(): + self.linear.weight /= scale + + def _recover_linear(self): + # remove mul and reset sq_linear for ipex inference + scale = self.input_scale.view(1, self.input_scale.shape[0]) + with torch.no_grad(): + self.linear.weight *= scale + diff --git a/neural_compressor/adaptor/torch_utils/smooth_quant.py b/neural_compressor/adaptor/torch_utils/smooth_quant.py index 85746b41e3d..9771edb23f4 100644 --- a/neural_compressor/adaptor/torch_utils/smooth_quant.py +++ b/neural_compressor/adaptor/torch_utils/smooth_quant.py @@ -73,14 +73,14 @@ def model_forward(model, dataloader, iters, device): for idx, (input, label) in enumerate(dataloader): output = forward_wrapper(model, input, device) cnt += 1 - if cnt >= iters: + if iters != -1 and cnt >= iters: break except Exception as e: cnt = 0 for idx, input in enumerate(dataloader): output = forward_wrapper(model, input, device) cnt += 1 - if cnt >= iters: + if iters != -1 and cnt >= iters: break @@ -1208,7 +1208,7 @@ def trace(self, model, dummy_input): traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) except Exception as e: logger.warning(e) - logger.info("Jit trace in GraphTrace failed, absorb layer detection is skipped") + logger.warning("Jit trace in GraphTrace failed, absorb layer detection is skipped") else: try: traced_model = torch.jit.trace(model, dummy_input, strict=False) @@ -1219,7 +1219,7 @@ def trace(self, model, dummy_input): traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) except Exception as e: logger.warning(e) - logger.info("Jit trace in GraphTrace failed, absorb layer detection is skipped") + logger.warning("Jit trace in GraphTrace failed, absorb layer detection is skipped") if orig_device != "cpu": model = model.to(orig_device) return traced_model diff --git a/neural_compressor/adaptor/torch_utils/teq.py b/neural_compressor/adaptor/torch_utils/teq.py index 
8d084c7d50c..3176929776e 100644 --- a/neural_compressor/adaptor/torch_utils/teq.py +++ b/neural_compressor/adaptor/torch_utils/teq.py @@ -28,7 +28,7 @@ from .smooth_quant import GraphTrace, get_module, set_module from .weight_only import quant_weight -from .model_wrapper import TEQLinearFakeQuant, TEQMulLinear +from .model_wrapper import TEQLinearFakeQuant, MulLinear import transformers @@ -139,12 +139,12 @@ def _absorb_scales(self, layer, scale, layer_name=""): """ # for insert mul if not self.folding: # pragma: no cover - if isinstance(layer, TEQMulLinear): - set_module(self.model, layer_name, layer.sq_linear) ##recover + if isinstance(layer, MulLinear): + set_module(self.model, layer_name, layer.linear) ##recover else: - new_module = TEQMulLinear(layer, scale) + new_module = MulLinear(layer, scale) set_module(self.model, layer_name, new_module) - self.weight_config[layer_name + ".sq_linear"] = self.weight_config[layer_name] + self.weight_config[layer_name + ".linear"] = self.weight_config[layer_name] return if isinstance(layer, torch.nn.BatchNorm2d) or isinstance(layer, torch.nn.GroupNorm) or \ @@ -207,8 +207,8 @@ def _scale_layer_weight(self, layer, scale): ##input channel :param scale: The scale to be multiplied :return: """ - if layer.__class__.__name__ == "TEQMulLinear": - layer = layer.sq_linear + if layer.__class__.__name__ == "MulLinear": + layer = layer.linear if layer.__class__.__name__ == "TEQLinearFakeQuant": layer = layer.orig_layer diff --git a/neural_compressor/adaptor/torch_utils/util.py b/neural_compressor/adaptor/torch_utils/util.py index 72335d368b2..409e036dbdf 100644 --- a/neural_compressor/adaptor/torch_utils/util.py +++ b/neural_compressor/adaptor/torch_utils/util.py @@ -21,6 +21,7 @@ import numpy as np from collections import UserDict from packaging.version import Version +from functools import partial from ...utils import logger from ...utils.utility import LazyImport, CpuInfo @@ -962,7 +963,7 @@ def get_op_type_by_name(op_name, quantizable_ops): return pair[1] return None -def collect_weight_info(q_config): +def collect_weight_info(model, q_config): """collect weight info from q_config for dumping into qconfig.json qconfig.json example: @@ -988,12 +989,15 @@ def collect_weight_info(q_config): if config['weight']['dtype'] == 'fp32': weight_info[op_name] = {'dtype': 'fp32'} else: + # fetch module type for MulLinear + module = fetch_module(model, op_name) if level == DEBUG: weight_info[op_name] = { 'dtype': config['weight']['dtype'], 'bits': config['weight']['bits'], 'group_size': config['weight']['group_size'], 'scheme': config['weight']['scheme'], + 'module_type': str(type(module)).split('\'')[1], 'algorithm': config['weight']['algorithm'] } else: @@ -1002,5 +1006,214 @@ def collect_weight_info(q_config): 'bits': config['weight']['bits'], 'group_size': config['weight']['group_size'], 'scheme': config['weight']['scheme'], + 'module_type': str(type(module)).split('\'')[1], } return weight_info + + +def get_module_input_output(model, module_hook_config={}, dataloader=None, iters=-1, + calib_func=None, input_func=None, output_func=None): + """A help function to get input and output tensor of modules in module_name_list. + + Args: + model: torch model. + module_hook_config (dict, optional): required module name for input/output. Defaults to {}. + For example: + module_hook_config = { + 'fc1': ['output'], + 'fc2': ['input', 'output'] + } + dataloader: dataloader for model input. + iters: iterations for inference. 
+ calib_func: a custom inference function to replace dataloader and iters. + input_func: preprocess input for less memory usage + output_func: preprocess output for less memory usage + + Returns: + total_values: recorded input_values, output_values. + for example: + {'fc1': + {'input': [], 'output': []}, + } + + """ + from collections import defaultdict + total_values = defaultdict(defaultdict) + def _save_input_output_hook(name, record_input=False, record_output=False): + """ + A forward hook to save input and output values of a module + param name: the module name + return: A hook function + """ + def _hook(module, inputs, outputs): + if record_input: + input = inputs[0] + if input_func is not None: + input = input_func(input) + if name in total_values and 'input' in total_values[name]: + total_values[name]['input'].append(input) + else: + total_values[name]['input'] = [input] + if record_output: + output = outputs[0] if isinstance(outputs, tuple) else outputs + if output_func is not None: + output = output_func(output) + if input_func is not None: + input = input_func(input) + if name in total_values and 'output' in total_values[name]: + total_values[name]['output'].append(output) + else: + total_values[name]['output'] = [output] + return _hook + + hook_list = [] + for name, module in model.named_modules(): + if name in module_hook_config: + require_list = module_hook_config[name] + logger.debug(f"required hooks {name}: {require_list}") + _hook = _save_input_output_hook( + name, + record_input='input' in require_list, + record_output='output' in require_list, + ) + require_list = module_hook_config[name] + hook_list.append( + module.register_forward_hook(_hook)) + if calib_func: + calib_func(model) + else: + from .smooth_quant import model_forward + model_forward(model, dataloader, iters, device=next(model.parameters()).device) + for h in hook_list: + h.remove() + return total_values + + +def get_absorb_layers(model, example_inputs, supported_layers=['Linear'], folding=False): + """Get absorb_to_layer and no_absorb_layer. + + Args: + model (torch.nn.Module): input model + example_inputs: example_inputs + supported_layers (list, optional): supported_layers. Defaults to ['Linear']. + folding (bool, optional): whether allow self-absorption. Defaults to False. + + Returns: + absorb_to_layer: dict of absorb_to_layer. eg. {absorb, [absorbed_1, xx]} + no_absorb_layers: list of no_absorb_layers + """ + # get modules that can be absorbed. + from .smooth_quant import GraphTrace + tg = GraphTrace() + absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer( + model, example_inputs, supported_layers + ) + if absorb_to_layer is None or absorb_to_layer == {}: + absorb_to_layer = {} + logger.warning('No absorb layer is detected.') + # if no_absorb_layers is None, jit trace failed. 
+ # collect all linears for next step + if no_absorb_layers is None: + no_absorb_layers = [] + op_types = ['Linear'] + for name, module in model.named_modules(): + for op_type in op_types: + if op_type == str(module.__class__.__name__): + no_absorb_layers.append(name) + return absorb_to_layer, no_absorb_layers + + +def get_block_prefix(model): + """get prefix and number of blockes + + Args: + model (torch.nn.Module): input model + + Returns: + block_prefix(str): block_list name in model + block_num(int): number of block in block_list + """ + module_types=[torch.nn.ModuleList] + for n, m in model.named_modules(): + if type(m) in module_types: + block_prefix = n + block_num = len(m) + logger.debug(f"block_prefix: {block_prefix}, block_num: {block_num} ") + break + assert block_num > 0, "block num should't be zero!" + return block_prefix, block_num + + +def calibration(model, dataloader=None, n_samples=128, calib_func=None): + """ Calibration with dataloader or calib_func + + Args: + model (torch.nn.Module): input model + dataloader: dataloader. Defaults to None. + n_samples (int, optional): n_samples. Defaults to 128. + calib_func: calib_func. Defaults to None. + """ + # calibration with dataloader or calib_func + if calib_func is not None: + calib_func(model) + else: + import math + from .smooth_quant import model_forward + batch_size = dataloader.batch_size + iters = int(math.ceil(n_samples / batch_size)) + if n_samples % batch_size != 0: + logger.info("calibration samples increase from {} to {} due to batch_size is {}".format( + n_samples, iters*batch_size, batch_size, + )) + model_forward(model, dataloader, iters, next(model.parameters()).device) + + +def get_hidden_states(model, dataloader=None, n_samples=128, calib_func=None): + """get the input args and kwargs of first block. + + Args: + model (torch.nn.Module): input model + dataloader (dataloader, optional): input dataloader. Defaults to None. + n_samples (int, optional): number samples from dataloader. Defaults to 128. + calib_func (func, optional): a calib func to replace dataloader. Defaults to None. 
+ + Raises: + ValueError: to avoid inference of rest parts in model + + Returns: + total_block_args(list): a list of input args of each batch + total_block_kwargs(list): a list of input kwargs of each batch + """ + # Step 1: replace block_forward to collect block inputs and avoid entire inference + total_block_args = [] + total_block_kwargs = [] + def forward(layer, *args, **kwargs): + # update total_hidden_states, total_block_kwargs, per batch + total_block_args.append(list(args)) + total_block_kwargs.append(kwargs) + raise ValueError + + block_prefix, block_num = get_block_prefix(model) + block_list = fetch_module(model, block_prefix) + first_block = block_list[0] + block_forward_cache = first_block.forward + first_block.forward = partial(forward, first_block) + + # Step 2: replace model_forward to avoid ValueError + model_forward_cache = model.forward + def model_forward(model, *args, **kwargs): + nonlocal model_forward_cache + try: + model_forward_cache(*args, **kwargs) + except ValueError: + pass + model.forward = partial(model_forward, model) + + # Step 3: execute calibration + calibration(model, dataloader=dataloader, n_samples=n_samples, calib_func=calib_func) + logger.info("The hidden_states collection is done.") + + # Step 4: recover model and block forward + model.forward = model_forward_cache + first_block.forward = block_forward_cache + return total_block_args, total_block_kwargs diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py index f9766b664f1..70bf97219de 100644 --- a/neural_compressor/adaptor/torch_utils/weight_only.py +++ b/neural_compressor/adaptor/torch_utils/weight_only.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from copy import deepcopy import math from typing import OrderedDict from .util import set_module @@ -81,7 +82,7 @@ def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_rang # assert num_bits > 1, "symmetric scheme only supports num_bits > 1" maxq = torch.tensor(2 ** (num_bits - 1) - 1).to(weight.device) minq = torch.tensor(-2 ** (num_bits - 1)).to(weight.device) - if num_bits == 1: + if num_bits == 1: # pragma: no cover maxq = torch.tensor(2 ** (num_bits - 1)) minq = torch.tensor(2 ** (num_bits - 1) - 1) max_val = torch.max(weight, 1)[0] @@ -144,6 +145,8 @@ def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0, Returns: output: qdq weight. """ + if num_bits <= 0: + return weight if group_size == -1 or weight.shape[1] < group_size: return qdq_weight_actor(weight, num_bits, scheme=scheme, quantile=quantile, return_int=return_int, full_range=full_range) @@ -258,7 +261,7 @@ def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym", logger.debug(f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, " + \ f"scheme={scheme}, quantile={quantile}") if num_bits <= 0: - logger.info(f"skip {name}") + logger.info(f"Skip {name}") continue weight = m.weight if return_int: @@ -298,117 +301,17 @@ def gptq_quantize(model, weight_config={}, dataloader=None, device=None): logger.info("GPTQ quantizing done.") return fp32_modified_model, gptq_config -def get_module_input_output(model, module_hook_config={}, dataloader=None, iters=-1, - calib_func=None): - """A help function to get input and output tensor of modules in module_name_list. - - Args: - model: torch model. - module_hook_config (dict, optional): required module name for input/output. 
Defaults to {}. - For example: - module_hook_config = { - 'fc1': ['output'], - 'fc2': ['input', 'output'] - } - dataloader: dataloader for model input. - iters: iterations for inference. - calib_func: a custom inference function to replace dataloader and iters. - - Returns: - input_values, output_values: recorded input_values, output_values. - """ - input_values, output_values = {}, {} - def _save_input_output_hook(name): - """ - A forward hook to save input and output values of a module - param name: the module name - return: A hook function - """ - def save_input_hook(module, inputs): - input = inputs[0] - if name in input_values: - try: - input_values[name] = torch.cat((input_values[name], input), 0) - except Exception as e: - logger.error(e) - assert False, "Please unify the input shape for AWQ algorithm calibration." - else: - input_values[name] = input - def save_output_hook(module, inputs, outputs): - if isinstance(outputs, tuple): - outputs = outputs[0] - if name in output_values: - try: - output_values[name] = torch.cat((output_values[name], outputs), 0) - except Exception as e: - logger.error(e) - assert False, "Please unify the input shape for AWQ algorithm calibration." - else: - output_values[name] = outputs - return save_input_hook, save_output_hook - - hook_list = [] - for name, module in model.named_modules(): - if name in module_hook_config: - save_input_hook, save_output_hook = _save_input_output_hook(name) - require_list = module_hook_config[name] - if 'input' in require_list: - hook_list.append( - module.register_forward_pre_hook(save_input_hook)) - if 'output' in require_list: - hook_list.append( - module.register_forward_hook(save_output_hook)) - if calib_func: - calib_func(model) - else: - from .smooth_quant import model_forward - model_forward(model, dataloader, iters, device='cpu') - for h in hook_list: - h.remove() - return input_values, output_values - @torch.no_grad() -def _get_weight_scale(weight, q_group_size=-1): - org_shape = weight.shape - if q_group_size > 0: - weight = weight.view(-1, q_group_size) - scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True) - scale = scale.view(org_shape) - scale = scale.mean(0) - return scale - - -@torch.no_grad() -def _get_act_scale(x): - return x.abs().view(-1, x.shape[-1]).mean(0) - - -def _update_input_with_scale(args, kwargs, scales): - new_args, new_kwargs = args, kwargs - for i, v in enumerate(args): - if isinstance(v, torch.Tensor): - try: - new_args[i] = torch.div(v, scales.view(1, 1, -1)) - except: - new_args[i] = v - for k, v in kwargs.items(): - if isinstance(v, torch.Tensor): - try: - new_kwargs[k] = torch.div(v, scales.view(1, 1, -1)) - except: - new_kwargs[k] = v - return new_args, new_kwargs - - -@torch.no_grad() -def awq_quantize(model, weight_config={}, absorb_dict={}, dataloader=None, n_samples=128, - auto_scale=True, mse_range=True, calib_func=None, n_blocks=5, - return_int=False, sym_full_range=False): +def awq_quantize(model, bits=4, group_size=32, scheme='asym', weight_config={}, + example_inputs=None, dataloader=None, n_samples=128, calib_func=None, + auto_scale=True, mse_range=True, folding=False, return_int=False, + sym_full_range=False): """Quant the model with Activation-aware Weight quantization(AWQ) method. Args: model (torch.nn.Module): torch model. + example_inputs: example_inputs. weight_config (dict, optional): contains all info required by AWQ. Defaults to {}. 
For example, weight_config={ @@ -424,8 +327,9 @@ def awq_quantize(model, weight_config={}, absorb_dict={}, dataloader=None, n_sam For example, absorb_dict = { # 'absorb_layer': absorbed_layer - 'fc1': ['fc2', 'fc3'] - } # in this case, fc2 and fc3 need to share the same scale. + 'fc1': ['fc1', 'fc2', 'fc3'] + } # in this case, fc2 and fc3 need to share the same scale. fc1 is self absorbed. + # self absorb module will replace with MulLinear, which contains torch.mul and module. n_samples: calibration sample number. auto_scale (bool, optional): whether enable scale for salient weight. Defaults to True. mse_range (bool, optional): whether enable clip for weight by checking mse. Defaults to True. @@ -438,233 +342,28 @@ def awq_quantize(model, weight_config={}, absorb_dict={}, dataloader=None, n_sam Returns: model: fake quantized model """ + from .awq import ActAwareWeightQuant assert isinstance(model, torch.nn.Module), "only support torch module" - # collect module names to record their input/output for loss calculation. - module_for_loss_dict = OrderedDict() - # get the upper module if absorbed modules have the same absorb module. - for absorb, absorbed in absorb_dict.items(): - # used as input for absob module - module_for_loss_dict[absorb] = set() - if len(absorbed) > 1: - split_dict = {} - split_absorb = absorb.split('.') - for ab in absorbed: - split_dict[ab] = ab.split('.') - for ab, sp_list in split_dict.items(): - for i in range(len(sp_list)): - if sp_list[i] != split_absorb[i]: - group_name = '.'.join(sp_list[:i+1]) - module_for_loss_dict[absorb].add(group_name) - break - else: - module_for_loss_dict[absorb].add(absorbed[0]) - - def calibration(module_name_list): - if calib_func: - input_values, output_values = get_module_input_output( - model, - module_name_list, - calib_func=calib_func, - ) - else: - batch_size = dataloader.batch_size - iters = int(math.ceil(n_samples / batch_size)) - if n_samples % batch_size != 0: - logger.info("calibration samples increase from {} to {} due to batch_size is {}".format( - n_samples, iters*batch_size, batch_size, - )) - input_values, output_values = get_module_input_output( - model, - module_name_list, - dataloader=dataloader, - iters=iters, - ) - return input_values, output_values - - layer_args = {} - layer_kwargs = {} - # to fetch kwargs which torch hook cannot handle - class Catcher(torch.nn.Module): - def __init__(self, module, name): - super().__init__() - self.module = module - self.module_name = name - - def forward(self, *args, **kwargs): - if self.module_name not in layer_args: - layer_args[self.module_name] = list(args) - layer_kwargs[self.module_name] = kwargs - else: - # to concat different batches - for i, v in enumerate(args): - if isinstance(v, torch.Tensor): - layer_args[self.module_name][i] = \ - torch.cat([layer_args[self.module_name][i], v], dim=0) - for k, v in kwargs.items(): - if isinstance(v, torch.Tensor): - layer_kwargs[self.module_name][k] = \ - torch.cat([layer_kwargs[self.module_name][k], v], dim=0) - return self.module.forward(*args, **kwargs) - - if auto_scale or mse_range: - from .util import fetch_module, set_module - block_num = n_blocks - module_num = math.ceil(len(absorb_dict) / block_num) - logger.info(f"AWQ search splits the model into {block_num} blocks to avoid OOM, " +\ - f"each block contains {module_num} modules") - for idx, (absorb, absorbed) in enumerate(absorb_dict.items()): - # Split module_name_list to avoid OOM when recording output tensors. 
- if idx % module_num == 0: - layer_args = {} - layer_kwargs = {} - part_module_hook_config = {} - logger.info(f"AWQ search calibration round {idx//module_num + 1}") - for id, name in enumerate(module_for_loss_dict): - if idx <= id < (idx + module_num): - part_module_hook_config[name] = ['output'] # fetch input tensor - for i in module_for_loss_dict[name]: - # use Catcher to record input args and kwargs - tmp_module = fetch_module(model, i) - new_module = Catcher(tmp_module, i) - set_module(model, i, new_module) - part_module_hook_config[i] = ['output'] # fetch output tensor - input_values, output_values = calibration(part_module_hook_config) - # recover Catcher - for id, name in enumerate(module_for_loss_dict): - if idx <= id < (idx + module_num): - for i in module_for_loss_dict[name]: - # use Catcher to record input args and kwargs - tmp_module = fetch_module(model, i) - set_module(model, i, tmp_module.module) - - logger.info(f"Processing module: {absorb}:{absorbed}") - weight = torch.cat([fetch_module(model, _m).weight for _m in absorbed], dim=0) - w_max = _get_weight_scale( - weight, q_group_size=weight_config[absorbed[0]]['group_size']) - del weight - x_max = _get_act_scale(output_values[absorb]) - absorbed_modules = {_m: fetch_module(model, _m) for _m in absorbed} - org_stat = {_m: module.state_dict() for _m, module in absorbed_modules.items()} - org_out = {} - for loss_module_name in module_for_loss_dict[absorb]: - out = output_values[loss_module_name] - if isinstance(out, tuple): - out = out[0] - org_out[loss_module_name] = out - for loss_module_name in module_for_loss_dict[absorb]: - blockes = {_m: fetch_module(model, _m) for _m in module_for_loss_dict[absorb]} - - if auto_scale: - logger.info("Searching best scales with AWQ algorithm") - best_error = float('inf') - best_scales = None - best_scale_alpha = None - n_grid = 20 - history = [] - - for ratio in range(n_grid): - ratio = ratio * 1 / n_grid - scales = (x_max.pow(ratio) / w_max.pow(1-ratio) - ).clamp(min=1e-4).view(-1) - scales = scales / (scales.max() * scales.min()).sqrt() - for name, module in absorbed_modules.items(): - module.weight.data = module.weight.data.mul(scales.view(1, -1)) - module.weight.data = quant_weight( - module.weight.data, - num_bits=weight_config[name]['bits'], - group_size=weight_config[name]['group_size'], - scheme=weight_config[name]['scheme'], - ) / scales.view(1, -1) - - loss = 0 - for name, block in blockes.items(): - out = block(*layer_args[name], **layer_kwargs[name]) - if isinstance(out, tuple): - out = out[0] - loss += (org_out[name] - out).float().pow(2).mean().item() # float prevents overflow - history.append(loss) - is_best = loss < best_error - if is_best: - best_error = loss - best_scales = scales - best_scale_alpha = ratio - for name, module in absorbed_modules.items(): - module.load_state_dict(org_stat[name]) - - logger.debug("The loss history of different scale:{}".format(history)) - logger.debug("The best alpha for scale: {}:{}".format(absorb, best_scale_alpha)) - assert best_scales is not None, "Loss is infinity! Cannot find the correct scale." 
- best_scales = best_scales.view(-1) - assert torch.isnan(best_scales).sum() == 0, best_scales - scales = best_scales.detach() - # update absorb model - absorb_module = fetch_module(model, absorb) - if len(absorb_module.weight.shape) == 1: - absorb_module.weight.div_(scales) # for LayerNorm - else: - absorb_module.weight.div_(scales.view(-1, 1)) - if absorb_module.bias is not None: - absorb_module.bias.div_(scales.view(-1)) - # update absorbed model - for name in absorbed: - absorbed_module = fetch_module(model, name) - absorbed_module.weight.mul_(scales.view(1, -1)) - - if mse_range: - logger.info("Searching the best clip range with AWQ algorithm") - best_error = float('inf') - best_clip_ratio = None - n_grid = 100 - max_shrink = 0.1 - history = [] - org_stat = {_m: module.state_dict() for _m, module in absorbed_modules.items()} - - for name, module in absorbed_modules.items(): - for i_s in range(int(max_shrink * n_grid)): - ratio = (1 - i_s / n_grid) # 1, 0.95-0.55 - module.weight.data = quant_weight( - module.weight.data, - num_bits=weight_config[name]['bits'], - group_size=weight_config[name]['group_size'], - scheme=weight_config[name]['scheme'], - quantile=ratio, - ) - - loss = 0 - for n, block in blockes.items(): - if n in name: - # preprocess input with existing scale - if auto_scale: - new_args, new_kwargs = _update_input_with_scale( - layer_args[n], layer_kwargs[n], scales) - else: - new_args, new_kwargs = layer_args[n], layer_kwargs[n] - out = block(*new_args, **new_kwargs) - if isinstance(out, tuple): - out = out[0] - loss += (org_out[n] - out).float().pow(2).mean().item() # float prevents overflow - history.append(loss) - is_best = loss < best_error - if is_best: - best_error = loss - best_clip_ratio = ratio - module.load_state_dict(org_stat[name]) - - logger.debug("The loss history of different clip range:{}".format(history)) - weight_config[name]['quantile'] = best_clip_ratio - logger.debug("The best clip ratio for {}:{}".format(name, best_clip_ratio)) - - # apply quantization and clip - logger.info("Quantizing the AWQ optimized fp32 model") - model = rtn_quantize( + awq = ActAwareWeightQuant( model, - num_bits=-1, - weight_config=weight_config, + example_inputs=example_inputs, + calib_func=calib_func, + dataloader=dataloader, + n_samples=n_samples, + bits=bits, + group_size=group_size, + scheme=scheme, + sym_full_range=sym_full_range, + weight_config=weight_config + ) + qdq_model = awq.quantize( + auto_scale=auto_scale, + mse_range=mse_range, + folding=folding, return_int=return_int, - sym_full_range=sym_full_range, ) - logger.info("AWQ quantization is done.") - return model + return qdq_model + def teq_quantize(model, weight_config={}, absorb_to_layer={}, extra_config={}, dataloader= None, calib_func=None, example_inputs=None): diff --git a/neural_compressor/model/torch_model.py b/neural_compressor/model/torch_model.py index 126140535a6..1bd97eb9961 100644 --- a/neural_compressor/model/torch_model.py +++ b/neural_compressor/model/torch_model.py @@ -316,7 +316,7 @@ def save(self, root=None): if self.q_config['approach'] == 'post_training_weight_only': from ..adaptor.torch_utils.util import collect_weight_info weight_config_path = os.path.join(root, "qconfig.json") - weight_config = collect_weight_info(self.q_config) + weight_config = collect_weight_info(self.model, self.q_config) with open(weight_config_path, 'w') as f: json.dump(weight_config, f, indent = 4) if hasattr(self, 'gptq_config') and self.gptq_config: @@ -424,7 +424,7 @@ def export_compressed_model(self, 
qweight_config_path=None, sym_full_range=False with open(qweight_config_path, 'r') as f: weight_config = json.load(f) else: - weight_config = collect_weight_info(self.q_config) + weight_config = collect_weight_info(self.model, self.q_config) if gptq_config_path is not None: with open(gptq_config_path, 'r') as f: gptq_config = json.load(f) diff --git a/neural_compressor/utils/pytorch.py b/neural_compressor/utils/pytorch.py index 5f0decec34b..b2d9e5ab422 100644 --- a/neural_compressor/utils/pytorch.py +++ b/neural_compressor/utils/pytorch.py @@ -28,6 +28,7 @@ import torch.quantization as tq import yaml import os +import json yaml.SafeLoader.add_constructor('tag:yaml.org,2002:python/tuple', @@ -183,6 +184,41 @@ def _load_int8_orchestration(model, tune_cfg, stat_dict, example_inputs, **kwarg return model +def load_weight_only(checkpoint_dir, model): + """Load model in weight_only mode. + + Args: + checkpoint_dir (dir/file/dict): The folder of checkpoint. 'qconfig.json' and + 'best_model.pt' are needed in This directory. + 'checkpoint' dir is under workspace folder and + workspace folder is define in configure yaml file. + model (object): fp32 model need to do quantization. + + Returns: + (object): quantized model + """ + import neural_compressor # for eval(config['module_type']) + from neural_compressor.adaptor.torch_utils.model_wrapper import MulLinear + weights_file = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)), + 'best_model.pt') + # for weight only quantized model. + weights_only_config_file = os.path.join( + os.path.abspath(os.path.expanduser(checkpoint_dir)),'qconfig.json') + with open(weights_only_config_file, 'r') as f: + weight_only_config = json.load(f) + for op_name, config in weight_only_config.items(): + if config['dtype'] == 'fp32': + continue + if eval(config['module_type']) == MulLinear: + # op should be repleced by MulLinear + module = util.fetch_module(model, op_name) + new_module = MulLinear(module) + util.set_module(model, op_name, new_module) + model.load_state_dict(torch.load(weights_file)) + logger.info('Load weight_only quantized model') + return model + + def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs): """Execute the quantize process on the specified model. @@ -198,6 +234,9 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs): Returns: (object): quantized model """ + weigth_only = kwargs.get('weight_only', False) + if weigth_only: + return load_weight_only(checkpoint_dir, model) if checkpoint_dir is not None: if isinstance(checkpoint_dir, dict): stat_dict = checkpoint_dir @@ -212,14 +251,6 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs): try: weights_file = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)), 'best_model.pt') - # for weight only quantized model. 
- weights_only_config_file = os.path.join( - os.path.abspath(os.path.expanduser(checkpoint_dir)),'qconfig.json') - if os.path.exists(weights_only_config_file): - model.load_state_dict(torch.load(weights_file)) - logger.info('Load weight_only quantized model') - return model - # ------------------------------- try: stat_dict = torch.jit.load(weights_file) logger.info("torch.jit.load is used to recovery the int8 model quantized by INC IPEX backend") diff --git a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py index a4e4cc4823f..cb7b12cabf0 100644 --- a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py +++ b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py @@ -1,4 +1,5 @@ import sys +import copy sys.path.append("./") import os import shutil @@ -7,7 +8,7 @@ import transformers from neural_compressor import quantization, PostTrainingQuantConfig -from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear +from neural_compressor.adaptor.torch_utils.model_wrapper import MulLinear, WeightOnlyLinear class Model(torch.nn.Module): @@ -61,6 +62,9 @@ def setUpClass(self): 'hf-internal-testing/tiny-random-GPTJForCausalLM', torchscript=True, ) + self.gptj_no_jit = transformers.AutoModelForCausalLM.from_pretrained( + 'hf-internal-testing/tiny-random-GPTJForCausalLM', + ) self.gptj.seqlen = 512 self.llm_dataloader = LLMDataLoader() self.lm_input = torch.ones([1, 10], dtype=torch.long) @@ -118,6 +122,10 @@ def test_RTN_quant(self): }, }, }, + recipes={ + # By default, sym_full_range is False and 4 bit sym will only use range [-7,7]. + 'rtn_args': {'return_int': True} + } ) q_model = quantization.fit(model, conf, eval_func=eval_func) out2 = q_model(input) @@ -178,7 +186,7 @@ def test_RTN_quant(self): self.assertFalse(torch.all(out1 == out2)) q_model.save('saved') from neural_compressor.utils.pytorch import load - new_model = load('saved', model) + new_model = load('saved', model, weight_only=True) out1 = new_model(input) self.assertTrue(torch.all(out1 == out2)) @@ -208,6 +216,19 @@ def test_AWQ_quant(self): }, }, op_name_dict={ + '.*3.*':{ # re.match + "weight": { + 'dtype': 'fp32' + }, + }, + '.*4.*':{ # re.match + "weight": { + 'bits': 4, # 1-8 bits + 'group_size': 32, # -1 (per-channel) + 'scheme': 'asym', + 'algorithm': 'RTN', + }, + }, '.*lm_head':{ # re.match "weight": { 'dtype': 'fp32' @@ -215,23 +236,132 @@ def test_AWQ_quant(self): }, }, recipes={ - 'awq_args':{'auto_scale': True, 'mse_range': True, 'n_blocks': 2}, + 'awq_args':{'auto_scale': True, 'mse_range': True, 'folding': False}, }, ) + fp32_model = copy.deepcopy(self.gptj) q_model = quantization.fit( - self.gptj, + fp32_model, conf, calib_dataloader=self.llm_dataloader, ) + q_model.save('saved') input = torch.ones([1, 10], dtype=torch.long) out1 = q_model(input) + from neural_compressor.utils.pytorch import load + fp32_model = copy.deepcopy(self.gptj) + reload_model = load('saved', fp32_model, weight_only=True) + out2 = reload_model(input) q_model.export_compressed_model() - out2 = q_model(input) + out3 = q_model(input) # no idea about the gap at 1e-08, use allclose instead of out1==out2 self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05)) + self.assertTrue(torch.allclose(out1[0], out3[0], atol=1e-05)) self.assertTrue(isinstance(q_model.model.transformer.h[0].mlp.fc_in, WeightOnlyLinear)) self.assertTrue(isinstance(q_model.model.lm_head, torch.nn.Linear)) + conf = PostTrainingQuantConfig( + approach='weight_only', + op_type_dict={ 
+ '.*':{ # re.match + "weight": { + 'bits': 4, # 1-8 bits + 'group_size': 32, # -1 (per-channel) + 'scheme': 'asym', + 'algorithm': 'AWQ', + }, + }, + }, + op_name_dict={ + '.*3.*':{ # re.match + "weight": { + 'dtype': 'fp32' + }, + }, + '.*4.*':{ # re.match + "weight": { + 'bits': 4, # 1-8 bits + 'group_size': 32, # -1 (per-channel) + 'scheme': 'asym', + 'algorithm': 'RTN', + }, + }, + '.*lm_head':{ # re.match + "weight": { + 'dtype': 'fp32' + }, + }, + }, + recipes={ + 'rtn_args': {'return_int': True}, + 'awq_args':{'auto_scale': True, 'mse_range': True, 'folding': False}, + }, + ) + fp32_model = copy.deepcopy(self.gptj) + q_model = quantization.fit( + fp32_model, + conf, + calib_dataloader=self.llm_dataloader, + ) + self.assertTrue(isinstance(q_model.model.transformer.h[0].mlp.fc_out, MulLinear)) + self.assertTrue(isinstance(q_model.model.transformer.h[3].mlp.fc_out, torch.nn.Linear)) + self.assertTrue(isinstance(q_model.model.transformer.h[4].mlp.fc_out, WeightOnlyLinear)) + + + conf = PostTrainingQuantConfig( + approach='weight_only', + op_type_dict={ + '.*':{ # re.match + "weight": { + 'bits': 4, # 1-8 bits + 'group_size': 32, # -1 (per-channel) + 'scheme': 'asym', + 'algorithm': 'AWQ', + }, + }, + }, + ) + fp32_model = copy.deepcopy(self.gptj_no_jit) + q_model = quantization.fit( + fp32_model, + conf, + calib_dataloader=self.llm_dataloader, + ) + self.assertTrue(isinstance(q_model.model.transformer.h[0].mlp.fc_in, MulLinear)) + self.assertTrue(isinstance(q_model.model.transformer.h[0].mlp.fc_out, MulLinear)) + + def test_AWQ_util(self): + from neural_compressor.adaptor.torch_utils.util import get_module_input_output + class DemoModel(torch.nn.Module): + def __init__(self): + super(DemoModel, self).__init__() + self.fc1 = torch.nn.Linear(3, 4) + self.fc2 = torch.nn.Linear(4, 3) + + def forward(self, x): + out = self.fc1(x) + out = self.fc2(out) + return out + + tmp = torch.randn([3, 3]) + class DemoCalibDataloader: + def __init__(self): + self.batch_size = 1 + def __iter__(self): + for i in range(3): + yield tmp + + module_hook_config = { + 'fc1': ['output'], + 'fc2': ['input', 'output'] + } + model = DemoModel() + out = model(tmp) + values = get_module_input_output(model, module_hook_config, DemoCalibDataloader()) + self.assertTrue(torch.allclose(values['fc1']['output'][0], values['fc2']['input'][0])) + self.assertTrue(torch.allclose(values['fc2']['output'][0], out)) + + def test_GPTQ_quant(self): class gptq_inc_loader(object): def __init__(self, nsamples=32): diff --git a/test/quantization/test.py b/test/quantization/test.py new file mode 100644 index 00000000000..af70358381b --- /dev/null +++ b/test/quantization/test.py @@ -0,0 +1,31 @@ +import torch +import transformers + + +model = transformers.AutoModelForCausalLM.from_pretrained( + 'hf-internal-testing/tiny-random-GPTJForCausalLM', + torchscript=True, +) +lm_input = torch.ones([1, 10], dtype=torch.long) + + +class SimpleDataLoader(): + def __init__(self): + self.batch_size = 1 + self.input = torch.randn([1, 32]) + + def __iter__(self): + for i in range(10): + yield torch.ones([1, 10], dtype=torch.long) + + +def calib_func(model): + model(lm_input) + +from neural_compressor.adaptor.torch_utils.awq import _get_hidden_states + +hid = _get_hidden_states(model, calib_func=calib_func) +hid = _get_hidden_states(model, dataloader=SimpleDataLoader()) + +print(hid) + diff --git a/test/quantization/test_weight_only_quantization.py b/test/quantization/test_weight_only_quantization.py index f197ab75381..daf4fb9c53b 100644 --- 
a/test/quantization/test_weight_only_quantization.py
+++ b/test/quantization/test_weight_only_quantization.py
@@ -4,7 +4,9 @@
 import copy
 import torch
-from neural_compressor.adaptor.torch_utils.weight_only import rtn_quantize, awq_quantize, gptq_quantize, teq_quantize
+from neural_compressor.adaptor.torch_utils.weight_only import (
+    rtn_quantize, awq_quantize, gptq_quantize, teq_quantize
+)
 from neural_compressor.adaptor.torch_utils.smooth_quant import GraphTrace
 from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear
 import transformers
@@ -74,48 +76,42 @@ def test_rtn(self):
         model2 = rtn_quantize(fp32_model, weight_config=weight_config, return_int=True)
         self.assertTrue(isinstance(model2.fc1, WeightOnlyLinear))
 
-    def test_awq(self):
-        fp32_model = copy.deepcopy(self.model)
-        weight_config = {
-            # 'op_name': (bit, group_size, sheme)
-            'fc1': {
-                'bits': 8,
-                'group_size': -1,
-                'scheme': 'sym'
-            },
-            'fc2': {
-                'bits': 4,
-                'group_size': 32,
-                'scheme': 'asym'
-            },
-        }
-        absorb_dict = {
-            'fc1': ['fc2']
-        }
-        model1 = awq_quantize(
-            fp32_model,
-            weight_config=weight_config,
-            absorb_dict=absorb_dict,
-            dataloader=self.dataloader,
-            n_samples=128,
-            auto_scale=True,
-            mse_range=True,
-        )
-        self.assertTrue(isinstance(model1.fc1, torch.nn.Linear))
+        example_inputs = torch.ones([1, 10], dtype=torch.long)
+        from neural_compressor.adaptor.torch_utils.awq import ActAwareWeightQuant
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            'facebook/opt-125m', torchscript=True,)
+        class LLMCalibDataloader:
+            def __init__(self):
+                self.batch_size = 1
+            def __iter__(self):
+                for i in range(2):
+                    yield example_inputs
+
+        out1 = model(example_inputs)
+        awq = ActAwareWeightQuant(model, dataloader=LLMCalibDataloader(), bits=8, group_size=-1)
+        qdq_model = awq.quantize()
+        out2 = qdq_model(example_inputs)
+        # the output values are on the order of 4, so use a large atol=0.5
+        self.assertTrue(torch.allclose(out1[0], out2[0], atol=0.5))
+
+        def calib_func(model):
+            for i in range(2):
+                model(self.lm_input)
+        out1 = self.gptj(example_inputs)
+        awq = ActAwareWeightQuant(
+            self.gptj, calib_func=calib_func, example_inputs=self.lm_input,
+            bits=8, group_size=-1)
+        qdq_model = awq.quantize()
+        out2 = qdq_model(example_inputs)
+        self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-2))
+
+        # awq_quantize defaults to 4 bits with group_size=32, so use a larger atol=1e-1
+        qdq_model = awq_quantize(self.gptj, example_inputs=self.lm_input, calib_func=calib_func)
+        out2 = qdq_model(example_inputs)
+        print(out1[0], out2[0])
+        self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-1))
 
-        fp32_model = copy.deepcopy(self.model)
-        model2 = awq_quantize(
-            fp32_model,
-            weight_config=weight_config,
-            absorb_dict=absorb_dict,
-            dataloader=self.dataloader,
-            n_samples=128,
-            auto_scale=True,
-            mse_range=True,
-            return_int=True
-        )
-        self.assertTrue(isinstance(model2.fc1, WeightOnlyLinear))
 
 class TestGPTQWeightOnlyQuant(unittest.TestCase):
     @classmethod