From 2562f29842e3eac4a28d11ca4502376375b893bf Mon Sep 17 00:00:00 2001 From: xinhe Date: Tue, 11 Jul 2023 16:27:46 +0800 Subject: [PATCH] integrate and enhance AWQ algorithm (#1052) * support AWQ algorithm Signed-off-by: Xin He * enhance input Catcher and loss module collection Signed-off-by: Xin He * combine scale and clip Signed-off-by: Xin He * fix bug Signed-off-by: Xin He * fix accuracy issue Signed-off-by: Xin He * fix fallback bug Signed-off-by: Xin He * enhance UT Signed-off-by: Xin He * fix ut Signed-off-by: Xin He * enhance log Signed-off-by: Xin He * add assert Signed-off-by: Xin He * add n_blocks args Signed-off-by: Xin He * enhance log Signed-off-by: Xin He * fix bug Signed-off-by: Xin He * add document Signed-off-by: Xin He * fix spell Signed-off-by: Xin He --------- Signed-off-by: Xin He --- .../scripts/codeScan/pyspelling/inc_dict.txt | 3 + docs/source/quantization_weight_only.md | 46 +- neural_compressor/adaptor/pytorch.py | 109 +++- neural_compressor/adaptor/pytorch_cpu.yaml | 2 +- .../adaptor/torch_utils/smooth_quant.py | 48 +- .../adaptor/torch_utils/weight_only.py | 507 ++++++++++++++++-- .../test_weight_only_adaptor.py | 78 ++- .../test_weight_only_quantization.py | 107 +++- 8 files changed, 768 insertions(+), 132 deletions(-) diff --git a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt index f223b4ffdab..e201580dd53 100644 --- a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt +++ b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt @@ -2658,6 +2658,9 @@ RTN awq gptq percdamp +Frantar +Ji +mose DeQuantize FakeQuant FrameworkModel diff --git a/docs/source/quantization_weight_only.md b/docs/source/quantization_weight_only.md index 0d05f89bb15..f75fc950a48 100644 --- a/docs/source/quantization_weight_only.md +++ b/docs/source/quantization_weight_only.md @@ -18,28 +18,38 @@ Text generation: The most famous application of LLMs is text generation, which Besides, as mentioned in many papers[1][2], activation quantization is the main reason to cause the accuracy drop. So for text generation task, weight only quantization is a preferred option in most cases. +Theoretically, round-to-nearest (RTN) is the mose straightforward way to quantize weight using scale maps. However, when the number of bits is small (e.g. 3), the MSE loss is larger than expected. A group size is introduced to reduce elements using the same scale to improve accuracy. -## Supported Framework Model Matrix +There are many excellent works for weight only quantization to improve its accuracy performance, such as AWQ[3], GPTQ[4]. Neural compressor integrates these popular algorithms in time to help customers leverage them and deploy them to their own tasks. 
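The paragraph above can be made concrete with a small, self-contained sketch (editorial illustration, not part of this patch; the helper name and tensor sizes are hypothetical). It mimics the asymmetric round-to-nearest scheme used in `weight_only.py` and shows why a smaller group size, i.e. fewer elements sharing one scale, typically lowers the quantization error:

```python
import torch

def fake_quant_per_group(weight, num_bits=4, group_size=32):
    # Asymmetric round-to-nearest: one scale/zero-point per group of `group_size` elements.
    orig_shape = weight.shape
    w = weight.reshape(-1, group_size)                      # each row shares one scale
    maxq = 2 ** num_bits - 1
    wmin = torch.minimum(w.min(dim=1, keepdim=True)[0], torch.zeros(1))
    wmax = torch.maximum(w.max(dim=1, keepdim=True)[0], torch.zeros(1))
    scale = (wmax - wmin) / maxq
    zp = torch.round(-wmin / scale)
    q = torch.clamp(torch.round(w / scale) + zp, 0, maxq)
    return (scale * (q - zp)).reshape(orig_shape)

w = torch.randn(64, 64)
for g in (16, 64):
    err = (w - fake_quant_per_group(w, num_bits=3, group_size=g)).pow(2).mean()
    print(f"group_size={g}: MSE={err:.6f}")                 # smaller groups -> usually lower MSE
```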
-| Framework | Weight-only | -| :---: | :---:| -| PyTorch | ✔ | -| ONNX | WIP | +## Supported Framework Model Matrix +| Algorithms/Framework | PyTorch | ONNX | +|:--------------:|:----------:|:----------:| +| RTN | ✔ | ✔ | +| AWQ | ✔ | stay tuned | +| GPTQ | stay tuned | stay tuned | ## Examples - -The quantization capability of weight-only approach is as follows: +### **Quantization Capability**: | Config | Capability | | :---: | :---:| | bits | [1-8] | | group_size | [-1, 1-N] | | scheme | ['asym', 'sym'] | -| algorithm | ['RTN', ] | +| algorithm | ['RTN', 'AWQ'] | + +**AWQ arguments**: +| awq_args | default value | comments | +|:----------:|:-------------:|:-------------------------------------------------------------------:| +| auto_scale | True | Whether search for best scales based on activation distribution | +| mse_range | True | Whether search for the best clip range from range [0.89, 1.0, 0.01] | +| n_blocks | 5 | Split the model into n blocks for AWQ search to avoid out-of-memory | + **Note**: `group_size=-1` indicates the per-channel quantization per output channel. `group_size=[1-N]` indicates splitting the input channel elements per group_size. -The use case code is as follows: +### **User code**: ```python conf = PostTrainingQuantConfig( approach='weight_only', @@ -53,11 +63,11 @@ conf = PostTrainingQuantConfig( }, }, }, - ### AWQ and GPTQ is WIP - # recipes={ - # 'gptq_args':{'percdamp': 0.01}, - # 'awq_args':{'alpha': 'auto', 'clip': True}, - # }, + ### GPTQ is WIP + recipes={ + # 'gptq_args':{'percdamp': 0.01}, + 'awq_args':{'auto_scale': True, 'mse_range': True, 'n_blocks': 5}, + }, ) q_model = quantization.fit(model, conf, eval_func=eval_func) q_model.save('saved_results') @@ -67,6 +77,10 @@ The saved_results folder contains two files: `best_model.pt` and `weight_config. ## Reference -[1]Xiao, Guangxuan, et al. "Smoothquant: Accurate and efficient post-training quantization for large language models." arXiv preprint arXiv:2211.10438 (2022). +[1]. Xiao, Guangxuan, et al. "Smoothquant: Accurate and efficient post-training quantization for large language models." arXiv preprint arXiv:2211.10438 (2022). + +[2]. Wei, Xiuying, et al. "Outlier suppression: Pushing the limit of low-bit transformer language models." arXiv preprint arXiv:2209.13325 (2022). + +[3]. Lin, Ji, et al. "AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration." arXiv preprint arXiv:2306.00978 (2023). -[2]Wei, Xiuying, et al. "Outlier suppression: Pushing the limit of low-bit transformer language models." arXiv preprint arXiv:2209.13325 (2022). +[4]. Frantar, Elias, et al. "Gptq: Accurate post-training quantization for generative pre-trained transformers." arXiv preprint arXiv:2210.17323 (2022). \ No newline at end of file diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index 9bce77d3a83..298299dbaf3 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -4480,14 +4480,14 @@ def __init__(self, framework_specific_info): self.optype_statistics = None @dump_elapsed_time("Pass quantize model") - def quantize(self, tune_cfg, model, dataloader, q_func=None): + def quantize(self, tune_cfg, model, dataloader, calib_func=None): """Execute the quantize process on the specified model. Args: tune_cfg (dict): quantization config. model (object): model need to do quantization. dataloader (object): calibration dataset. - q_func (objext, optional): training function for quantization aware training mode. 
+ calib_func (objext, optional): calibration function for ease-of-use. Returns: (object): quantized model @@ -4510,9 +4510,22 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None): self.tune_cfg["framework"] = "pytorch" assert self.approach=='post_training_weight_only', "Please make sure the approach is weight_only" - q_model._model = self.rtn_quantize(q_model._model, tune_cfg) - q_model._model = self.gptq_quantize(q_model._model, tune_cfg, dataloader) - q_model._model = self.awq_quantize(q_model._model, tune_cfg, dataloader) + all_algo = set() + for key, config in tune_cfg['op'].items(): + op_name, op_type = key + if config['weight']['dtype'] == 'fp32': + continue + else: + algorithm = config['weight']['algorithm'] + all_algo.add(algorithm) + + if 'GPTQ' in all_algo: + q_model._model = self.gptq_quantize(q_model._model, tune_cfg, dataloader) + + if 'AWQ' in all_algo: # includes RTN in AWQ + q_model._model = self.awq_quantize(q_model._model, tune_cfg, dataloader, calib_func) + elif 'RTN' in all_algo: + q_model._model = self.rtn_quantize(q_model._model, tune_cfg) q_model.q_config = copy.deepcopy(self.tune_cfg) q_model.is_quantized = True @@ -4546,13 +4559,89 @@ def gptq_quantize(self, model, tune_cfg, dataloader): # TODO: implementation return model - def awq_quantize(self, model, tune_cfg, dataloader): + def awq_quantize(self, model, tune_cfg, dataloader, calib_func): logger.debug("quantizing with the AWQ algorithm") - # set default value if has args in recipes, else we use function + from .torch_utils.weight_only import awq_quantize + # get example inputs if not provided. + if self.example_inputs is None: + if dataloader is None: + assert False, "Please provide dataloader or example_inputs for AWQ algorithm." + try: + for idx, (input, label) in enumerate(dataloader): + self.example_inputs = input + break + except: + for idx, input in enumerate(dataloader): + self.example_inputs = input + break + + # get modules that can be absorbed. + from .torch_utils.smooth_quant import GraphTrace + tg = GraphTrace() + supported_layers = ['Linear'] + absorb_to_layer, _ = tg.get_absorb_to_layer(model, self.example_inputs, supported_layers) + if absorb_to_layer is None or absorb_to_layer == {}: + logger.warning('No absorb layer is detected, skip AWQ algorithm') + return model + + # got flipped dict from absorb_to_layer dict + flipped_dict = {} + for k, v in absorb_to_layer.items(): + for m in v: + flipped_dict[m] = {'absorb_layer': k} + + # check tune_cfg to skip layers without AWQ config + skipped_op_name_set = set() + for key, config in tune_cfg['op'].items(): + op_name, op_type = key + if config['weight']['dtype'] == 'fp32': + if op_name in flipped_dict: + absorb_to_layer.pop(flipped_dict[op_name]['absorb_layer']) + continue + else: + if op_name in flipped_dict: + flipped_dict[op_name]['bits'] = config['weight']['bits'] + flipped_dict[op_name]['group_size'] = config['weight']['group_size'] + flipped_dict[op_name]['scheme'] = config['weight']['scheme'] + algorithm = config['weight']['algorithm'] + if algorithm != 'AWQ': + if op_name in flipped_dict: + absorb_to_layer.pop(flipped_dict[op_name]['absorb_layer']) + else: + skipped_op_name_set.add(op_name) + if skipped_op_name_set: + logger.info("{} is skipped by AWQ algorithm".format(skipped_op_name_set)) + + # collect AWQ config from tune_cfg for quantization. 
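+ # Illustrative shape of the config collected below (layer names are hypothetical):
+ # weight_config = {'decoder.0.fc2': {'absorb_layer': 'decoder.0.fc1',
+ #                                    'bits': 4, 'group_size': 32, 'scheme': 'sym'}}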
+ weight_config = {} + if len(absorb_to_layer) == 0: + logger.warning('No absorb layer needs AWQ algorithim, skip it') + else: + logger.debug("**absorb layer**: **absorbed layers**") + for k, v in absorb_to_layer.items(): + logger.debug(f"{k}: {v}") + for m in v: + weight_config[m] = flipped_dict[m] + logger.info("Absorbed layers with the same absorb layer use the same config") + if 'awq_args' in self.recipes: - alpha = self.recipes['awq_args'].get('alpha', 'auto') - # AWQ(model, dataloader, w_bit, group_size, alpha='auto', clip=True) - # TODO: implementation + auto_scale = self.recipes['awq_args'].get('auto_scale', True) + mse_range = self.recipes['awq_args'].get('mse_range', True) + n_blocks = self.recipes['awq_args'].get('n_blocks', 5) + else: + auto_scale, mse_range = True, True + calib_sampling_size = tune_cfg.get('calib_sampling_size', 1) + model = awq_quantize( + model, + weight_config=weight_config, + absorb_dict=absorb_to_layer, + dataloader=dataloader, + n_samples=calib_sampling_size, + auto_scale=auto_scale, + mse_range=mse_range, + calib_func=calib_func, + n_blocks=n_blocks, + ) return model def _dump_model_op_stats(self, model, tune_cfg): diff --git a/neural_compressor/adaptor/pytorch_cpu.yaml b/neural_compressor/adaptor/pytorch_cpu.yaml index d62bd63d44c..f4c0416ec8a 100644 --- a/neural_compressor/adaptor/pytorch_cpu.yaml +++ b/neural_compressor/adaptor/pytorch_cpu.yaml @@ -267,7 +267,7 @@ # group_size=-1 means per-channel, others means per-group 'group_size': [32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024], # [1-inf], # 32 'scheme': ['sym', 'asym'], # sym, no ZP - 'algorithm': ['RTN'], # RTN, [RTN, GPTQ, AWQ,] RTN+AWQ+TEQ order + 'algorithm': ['RTN', 'AWQ', 'GPTQ'], # RTN, [RTN, GPTQ, AWQ,] RTN+AWQ+TEQ order }, 'activation': { 'dtype': ['fp32'], diff --git a/neural_compressor/adaptor/torch_utils/smooth_quant.py b/neural_compressor/adaptor/torch_utils/smooth_quant.py index c33b89cc4ed..00a383e4714 100644 --- a/neural_compressor/adaptor/torch_utils/smooth_quant.py +++ b/neural_compressor/adaptor/torch_utils/smooth_quant.py @@ -698,21 +698,7 @@ def transform(self, alpha=0.5, folding=False, percentile=99.999, op_types=['Line self.absorb_to_layer, no_absorb_layers = self._trace( op_types) ##TODO we need to insert mul layer for no_absorb_layers later if self.absorb_to_layer == None and no_absorb_layers == None: - logger.warning("sorry, could not trace the model, smooth quant is skipped") - logger.warning("if you are using huggingface model," - "you could set torchscript to True " - "when loading the model or set the return_dict to False") return self.model - elif self.absorb_to_layer == {}: - logger.warning("could not find any layer to be absorbed") - else: - to_absorb_cnt = 0 - for key, item in self.absorb_to_layer.items(): - to_absorb_cnt += len(item) - - logger.info( - f" {to_absorb_cnt} out of {to_absorb_cnt + len(no_absorb_layers)} " - f"layers could be absorbed in smooth quant") # remove self.self_absorb_layers if it exists in self.absorb_to_layer for k, v in self.absorb_to_layer.items(): @@ -832,6 +818,20 @@ def _trace(self, op_types): tg = GraphTrace() self._get_example_input() absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer(self.traced_model, self.example_inputs, op_types) + if absorb_to_layer == None and no_absorb_layers == None: + logger.warning("sorry, could not trace the model, smooth quant is skipped") + logger.warning("if you are using huggingface model," + "you could set torchscript to True " + "when loading the model or set the return_dict to False") + 
elif absorb_to_layer == {}: + logger.warning("could not find any layer to be absorbed") + else: + to_absorb_cnt = 0 + for key, item in absorb_to_layer.items(): + to_absorb_cnt += len(item) + logger.info( + f" {to_absorb_cnt} out of {to_absorb_cnt + len(no_absorb_layers)} " + f"layers could be absorbed in smooth quant") return absorb_to_layer, no_absorb_layers @@ -884,8 +884,9 @@ def trace(self, model, dummy_input): try: traced_model = torch.jit.trace(model, example_kwarg_inputs=dict(dummy_input), strict=False) traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) - except: - pass + except Exception as e: + logger.warning(e) + logger.info("Jit trace in GraphTrace failed, absorb layer detection is skipped") else: try: traced_model = torch.jit.trace(model, dummy_input, strict=False) @@ -894,8 +895,9 @@ def trace(self, model, dummy_input): try: traced_model = torch.jit.trace(model, dummy_input[0], strict=False) traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) - except: - pass + except Exception as e: + logger.warning(e) + logger.info("Jit trace in GraphTrace failed, absorb layer detection is skipped") return traced_model def get_nodes(self, traced_model, op_types=['Linear']): @@ -970,7 +972,8 @@ def get_absorb_to_layer(self, model, example_input, op_types): no_absorb_layers = [] for index, absorb in enumerate(nodes_prev_absorb): if absorb == None: - no_absorb_layers.append(nodes[index]) + no_absorb_layers.append( + '.'.join(nodes[index].scopeName().split('/')[-1].split('.')[1:])) continue node = nodes[index] layer_name = '.'.join(node.scopeName().split('/')[-1].split('.')[1:]) @@ -981,17 +984,17 @@ def get_absorb_to_layer(self, model, example_input, op_types): absorb_to_layer[absorb_name].append(layer_name) else: absorb_to_layer[absorb_name] = [layer_name] - absorb_to_layer = self.remove_unsupported_layers(model, absorb_to_layer) + absorb_to_layer = self.remove_unsupported_layers(model, absorb_to_layer, no_absorb_layers) return absorb_to_layer, no_absorb_layers - def remove_unsupported_layers(self, model, absorb_to_layer): + def remove_unsupported_layers(self, model, absorb_to_layer, no_absorb_layers): res = {} for key in absorb_to_layer.keys(): - absorb_layer = get_module(model, key) layer_type = absorb_layer.__class__.__name__ if layer_type not in self.supported_torch_module_to_aten.keys(): + no_absorb_layers.extend(absorb_to_layer[key]) continue supported = True for layer_name in absorb_to_layer[key]: @@ -999,6 +1002,7 @@ def remove_unsupported_layers(self, model, absorb_to_layer): layer_type = layer.__class__.__name__ if layer_type not in self.supported_torch_module_to_aten.keys(): supported = False + no_absorb_layers.extend(absorb_to_layer[key]) break if supported: res[key] = absorb_to_layer[key] diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py index cd802ace957..022319b446f 100644 --- a/neural_compressor/adaptor/torch_utils/weight_only.py +++ b/neural_compressor/adaptor/torch_utils/weight_only.py @@ -1,3 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 MIT HAN Lab +# This source code is licensed under the MIT license +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import OrderedDict from ...utils import logger from ...utils.utility import LazyImport @@ -5,16 +27,23 @@ torch = LazyImport("torch") -def qdq_weight_asym(weight, num_bits=4): - """quant and dequant tensor with asym schema - :param weight: input weight - :param num_bits: num_bits - :return: qdq weight +def qdq_weight_asym(weight, num_bits=4, quantile=1.0): + """Quant and dequant tensor with asym schema. + + Args: + weight: input weight + num_bits (int, optional): num_bits. Defaults to 4. + quantile (float, optional): percentile of clip. Defaults to 1.0. + + Returns: + output: qdq weight """ maxq = torch.tensor(2 ** num_bits - 1) zeros = torch.zeros(weight.shape[0], device=weight.device) wmin = torch.minimum(weight.min(1)[0], zeros) wmax = torch.maximum(weight.max(1)[0], zeros) + wmin = wmin * quantile + wmax = wmax * quantile tmp = (wmin == 0) & (wmax == 0) wmin[tmp] = -1 wmax[tmp] = +1 @@ -26,11 +55,16 @@ def qdq_weight_asym(weight, num_bits=4): return scale * (q - zp) -def qdq_weight_sym(weight, num_bits=4): - """quant and dequant tensor with sym schema - :param weight: input weight - :param num_bits: num_bits - :return: qdq weight +def qdq_weight_sym(weight, num_bits=4, quantile=1.0): + """Quant and dequant tensor with sym schema. + + Args: + weight : input weight + num_bits (int, optional): num_bits. Defaults to 4. + quantile (float, optional): percentile of clip. Defaults to 1.0. + + Returns: + output: qdq weight """ # assert num_bits > 1, "symmetric scheme only supports num_bits > 1" maxq = torch.tensor(2 ** (num_bits - 1) - 1).to(weight.device) @@ -40,6 +74,7 @@ def qdq_weight_sym(weight, num_bits=4): minq = torch.tensor(2 ** (num_bits - 1) - 1) wmax = torch.abs(weight).max(1)[0] + wmax = wmax * quantile tmp = (wmax == 0) wmax[tmp] = +1 scale = wmax / ((maxq - minq) / 2) @@ -48,56 +83,80 @@ def qdq_weight_sym(weight, num_bits=4): return scale * q -def qdq_weight_actor(weight, num_bits, scheme): - """quant and dequant tensor per channel - :param weight: input weight - :param num_bits: num_bits - :param scheme: sym or asym - :return: qdq weight +def qdq_weight_actor(weight, num_bits, scheme, quantile=1.0): + """Quant and dequant tensor per channel. + + Args: + weight : input weight + num_bits (int, optional): num_bits. Defaults to 4. + quantile (float, optional): percentile of clip. Defaults to 1.0. + + Returns: + output: qdq weight """ assert num_bits > 0, "num_bits should be larger than 0" if scheme == "sym": - return qdq_weight_sym(weight, num_bits) + return qdq_weight_sym(weight, num_bits, quantile) else: - return qdq_weight_asym(weight, num_bits) - -def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym"): - """quant and dequant tensor with group size - :param weight: input weight - :param num_bits: num_bits - :param group_size: how many elements share one scale/zp - :param scheme: sym or asym - :return: qdq weight + return qdq_weight_asym(weight, num_bits, quantile) + + +def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0): + """Quant and dequant tensor with group size. 
+ + Args: + weight: input weight + num_bits (int, optional): num_bits. Defaults to 4. + group_size (int, optional): how many elements share one scale/zp. Defaults to -1. + scheme (str, optional): sym or asym. Defaults to "asym". + quantile (float, optional): percentile of clip. Defaults to 1.0. + + Returns: + output: qdq weight. """ if group_size == -1 or weight.shape[1] < group_size: - return qdq_weight_actor(weight, num_bits, scheme=scheme) + return qdq_weight_actor(weight, num_bits, scheme=scheme, quantile=quantile) orig_shape = weight.shape if weight.shape[1] % group_size == 0: weight = weight.reshape(-1, group_size) - weight = qdq_weight_actor(weight, num_bits, scheme=scheme) + weight = qdq_weight_actor(weight, num_bits, scheme=scheme, quantile=quantile) weight = weight.reshape(orig_shape) return weight else: split_index = weight.shape[1] // group_size * group_size weight1 = weight[:, :split_index] weight1 = weight1.reshape(-1, group_size) - weight1 = qdq_weight_actor(weight1, num_bits, scheme=scheme) + weight1 = qdq_weight_actor(weight1, num_bits, scheme=scheme, quantile=quantile) weight1 = weight1.reshape(orig_shape[0], split_index) weight2 = weight[:, split_index:] - weight2 = qdq_weight_actor(weight2, num_bits, scheme=scheme) + weight2 = qdq_weight_actor(weight2, num_bits, scheme=scheme, quantile=quantile) weight = torch.cat([weight1, weight2], dim=1) return weight -def rtn_quantize(model, num_bits, group_size=-1, scheme="asym", w_layers_config={}): - """ quant the model with round to nearst method - :param model: torch module - :param num_bits: num bits - :param group_size: how many elements share one scale/zp - :param scheme: sym or asym - :param w_layers_config: specific layer wise configirations {"layer_name":[num_bits,group_size,schema]} - :return: +def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym", quantile=1.0, weight_config={}): + """Quant the model with round to nearst method. + + Args: + model: torch module + num_bits: num bits. Defaults to 4. + group_size (int, optional): how many elements share one scale/zp. Defaults to 32. + scheme (str, optional): sym or asym. Defaults to "asym". + quantile (float, optional): percentile of clip. Defaults to 1.0. + weight_config (dict, optional): specific layer wise configirations. Defaults to {}. + For example, + weight_config={ + 'fc2': + { + 'bits': 4, + 'group_size': 32, + 'scheme': 'sym' + } + } + + Returns: + model: fake quantized torch module """ assert isinstance(model, torch.nn.Module), "only support torch module" assert num_bits > 0, "bit for weight only should large than zero!" 
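The layer-wise override path documented above can be exercised directly. A minimal sketch (editorial, not part of this patch; the two-layer model and the `fc1`/`fc2` names are hypothetical, mirroring the unit test added later in this diff):

```python
import torch
from neural_compressor.adaptor.torch_utils.weight_only import rtn_quantize

class TwoLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(32, 64)
        self.fc2 = torch.nn.Linear(64, 32)

    def forward(self, x):
        return self.fc2(self.fc1(x))

model = TwoLinear()
# Layers listed in weight_config override the global num_bits/group_size/scheme;
# the optional 'quantile' entry clips the min/max range before quantization.
weight_config = {
    "fc2": {"bits": 4, "group_size": 32, "scheme": "asym", "quantile": 0.95},
}
q_model = rtn_quantize(model, num_bits=8, group_size=-1, scheme="sym",
                       weight_config=weight_config)
```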
@@ -106,12 +165,14 @@ def rtn_quantize(model, num_bits, group_size=-1, scheme="asym", w_layers_config= for n, m in model.named_modules(): if m.__class__.__name__ not in supported_layers: continue - if n in w_layers_config: # pragma: no cover - num_bits = w_layers_config[n][0] - group_size = w_layers_config[n][1] - scheme = w_layers_config[n][2] + if n in weight_config: # pragma: no cover + num_bits = weight_config[n]['bits'] + group_size = weight_config[n]['group_size'] + scheme = weight_config[n]['scheme'] + quantile = weight_config[n].get('quantile', 1.0) logger.debug(f"RTN quantized module:{n, m}") - logger.debug(f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, scheme={scheme}") + logger.debug(f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, " + \ + f"scheme={scheme}, quantile={quantile}") if num_bits <= 0: logger.info(f"skip {n}") continue @@ -122,9 +183,367 @@ def rtn_quantize(model, num_bits, group_size=-1, scheme="asym", w_layers_config= weight = weight.reshape(weight.shape[0], -1) else: weight = m.weight - q_weight = quant_weight(weight, num_bits, group_size, scheme) + q_weight = quant_weight(weight, num_bits, group_size, scheme, quantile) if m.__class__.__name__ == "Conv2d": q_weight = q_weight.reshape(orig_shape[1], orig_shape[0], orig_shape[2], orig_shape[3]) q_weight = q_weight.permute(1, 0, 2, 3) m.weight.data.copy_(q_weight) return model + + +def get_module_input_output(model, module_hook_config={}, dataloader=None, iters=-1, + calib_func=None): + """A help function to get input and output tensor of modules in module_name_list. + + Args: + model: torch model. + module_hook_config (dict, optional): required module name for input/output. Defaults to {}. + For example: + module_hook_config = { + 'fc1': ['output'], + 'fc2': ['input', 'output'] + } + dataloader: dataloader for model input. + iters: iterations for inference. + calib_func: a custom inference function to replace dataloader and iters. + + Returns: + input_values, output_values: recorded input_values, output_values. + """ + input_values, output_values = {}, {} + def _save_input_output_hook(name): + """ + A forward hook to save input and output values of a module + param name: the module name + return: A hook function + """ + def save_input_hook(module, inputs): + input = inputs[0] + if name in input_values: + try: + input_values[name] = torch.cat((input_values[name], input), 0) + except Exception as e: + logger.error(e) + assert False, "Please unify the input shape for AWQ algorithm calibration." + else: + input_values[name] = input + def save_output_hook(module, inputs, outputs): + if isinstance(outputs, tuple): + outputs = outputs[0] + if name in output_values: + try: + output_values[name] = torch.cat((output_values[name], outputs), 0) + except Exception as e: + logger.error(e) + assert False, "Please unify the input shape for AWQ algorithm calibration." 
+ else: + output_values[name] = outputs + return save_input_hook, save_output_hook + + hook_list = [] + for name, module in model.named_modules(): + if name in module_hook_config: + save_input_hook, save_output_hook = _save_input_output_hook(name) + require_list = module_hook_config[name] + if 'input' in require_list: + hook_list.append( + module.register_forward_pre_hook(save_input_hook)) + if 'output' in require_list: + hook_list.append( + module.register_forward_hook(save_output_hook)) + if calib_func: + calib_func(model) + else: + from .smooth_quant import model_forward + model_forward(model, dataloader, iters, device='cpu') + for h in hook_list: + h.remove() + return input_values, output_values + + +@torch.no_grad() +def _get_weight_scale(weight, q_group_size=-1): + org_shape = weight.shape + if q_group_size > 0: + weight = weight.view(-1, q_group_size) + scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True) + scale = scale.view(org_shape) + scale = scale.mean(0) + return scale + + +@torch.no_grad() +def _get_act_scale(x): + return x.abs().view(-1, x.shape[-1]).mean(0) + + +def _update_input_with_scale(args, kwargs, scales): + new_args, new_kwargs = args, kwargs + for i, v in enumerate(args): + if isinstance(v, torch.Tensor): + try: + new_args[i] = torch.div(v, scales.view(1, 1, -1)) + except: + new_args[i] = v + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + try: + new_kwargs[k] = torch.div(v, scales.view(1, 1, -1)) + except: + new_kwargs[k] = v + return new_args, new_kwargs + + +@torch.no_grad() +def awq_quantize(model, weight_config={}, absorb_dict={}, dataloader=None, n_samples=128, + auto_scale=True, mse_range=True, calib_func=None, n_blocks=5): + """Quant the model with Activation-aware Weight quantization(AWQ) method. + + Args: + model (torch.nn.Module): torch model. + weight_config (dict, optional): contains all info required by AWQ. Defaults to {}. + For example, + weight_config={ + 'fc2': + { + # 'absorb_layer': 'fc1', + 'bits': 4, + 'group_size': 32, + 'scheme': 'sym' + } + } + absorb_dict (dict, optional): contains all absorb info required by AWQ.. Defaults to {}. + For example, + absorb_dict = { + # 'absorb_layer': absorbed_layer + 'fc1': ['fc2', 'fc3'] + } # in this case, fc2 and fc3 need to share the same scale. + n_samples: calibration sample number. + auto_scale (bool, optional): whether enable scale for salient weight. Defaults to True. + mse_range (bool, optional): whether enable clip for weight by checking mse. Defaults to True. + calib_func: a custom inference function to replace dataloader and iters. + n_blocks: split model into block number to avoid OOM. + + Returns: + model: fake quantized model + """ + assert isinstance(model, torch.nn.Module), "only support torch module" + # collect module names to record their input/output for loss calculation. + module_for_loss_dict = OrderedDict() + # get the upper module if absorbed modules have the same absorb module. 
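+ # Example with hypothetical names: for absorb_dict = {'h.0.ln_1': ['h.0.attn.q_proj', 'h.0.attn.k_proj']},
+ # the shared parent module 'h.0.attn' is recorded, so its output is what the scale/clip search compares against.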
+ for absorb, absorbed in absorb_dict.items(): + # used as input for absob module + module_for_loss_dict[absorb] = set() + if len(absorbed) > 1: + split_dict = {} + split_absorb = absorb.split('.') + for ab in absorbed: + split_dict[ab] = ab.split('.') + for ab, sp_list in split_dict.items(): + for i in range(len(sp_list)): + if sp_list[i] != split_absorb[i]: + group_name = '.'.join(sp_list[:i+1]) + module_for_loss_dict[absorb].add(group_name) + break + else: + module_for_loss_dict[absorb].add(absorbed[0]) + + def calibration(module_name_list): + if calib_func: + input_values, output_values = get_module_input_output( + model, + module_name_list, + calib_func=calib_func, + ) + else: + batch_size = dataloader.batch_size + iters = int(math.ceil(n_samples / batch_size)) + if n_samples % batch_size != 0: + logger.info("calibration samples increase from {} to {} due to batch_size is {}".format( + n_samples, iters*batch_size, batch_size, + )) + input_values, output_values = get_module_input_output( + model, + module_name_list, + dataloader=dataloader, + iters=iters, + ) + return input_values, output_values + + layer_args = {} + layer_kwargs = {} + # to fetch kwargs which torch hook cannot handle + class Catcher(torch.nn.Module): + def __init__(self, module, name): + super().__init__() + self.module = module + self.module_name = name + + def forward(self, *args, **kwargs): + if self.module_name not in layer_args: + layer_args[self.module_name] = list(args) + layer_kwargs[self.module_name] = kwargs + else: + # to concat different batches + for i, v in enumerate(args): + if isinstance(v, torch.Tensor): + layer_args[self.module_name][i] = \ + torch.cat([layer_args[self.module_name][i], v], dim=0) + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + layer_kwargs[self.module_name][k] = \ + torch.cat([layer_kwargs[self.module_name][k], v], dim=0) + return self.module.forward(*args, **kwargs) + + if auto_scale or mse_range: + from .util import fetch_module, set_module + block_num = n_blocks + module_num = math.ceil(len(absorb_dict) / block_num) + logger.info(f"AWQ search splits the model into {block_num} blocks to avoid OOM, " +\ + f"each block contains {module_num} modules") + for idx, (absorb, absorbed) in enumerate(absorb_dict.items()): + # Split module_name_list to avoid OOM when recording output tensors. 
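+ # e.g. 10 absorbable layers with n_blocks=5 gives module_num == 2, so each
+ # calibration round only hooks and caches tensors for 2 absorb groups at a time.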
+ if idx % module_num == 0: + layer_args = {} + layer_kwargs = {} + part_module_hook_config = {} + logger.info(f"AWQ search calibration round {idx//module_num + 1}") + for id, name in enumerate(module_for_loss_dict): + if idx <= id < (idx + module_num): + part_module_hook_config[name] = ['output'] # fetch input tensor + for i in module_for_loss_dict[name]: + # use Catcher to record input args and kwargs + tmp_module = fetch_module(model, i) + new_module = Catcher(tmp_module, i) + set_module(model, i, new_module) + part_module_hook_config[i] = ['output'] # fetch output tensor + input_values, output_values = calibration(part_module_hook_config) + # recover Catcher + for id, name in enumerate(module_for_loss_dict): + if idx <= id < (idx + module_num): + for i in module_for_loss_dict[name]: + # use Catcher to record input args and kwargs + tmp_module = fetch_module(model, i) + set_module(model, i, tmp_module.module) + + logger.info(f"Processing module: {absorb}:{absorbed}") + weight = torch.cat([fetch_module(model, _m).weight for _m in absorbed], dim=0) + w_max = _get_weight_scale( + weight, q_group_size=weight_config[absorbed[0]]['group_size']) + del weight + x_max = _get_act_scale(output_values[absorb]) + absorbed_modules = {_m: fetch_module(model, _m) for _m in absorbed} + org_stat = {_m: module.state_dict() for _m, module in absorbed_modules.items()} + org_out = {} + for loss_module_name in module_for_loss_dict[absorb]: + out = output_values[loss_module_name] + if isinstance(out, tuple): + out = out[0] + org_out[loss_module_name] = out + for loss_module_name in module_for_loss_dict[absorb]: + blockes = {_m: fetch_module(model, _m) for _m in module_for_loss_dict[absorb]} + + if auto_scale: + logger.info("Searching best scales with AWQ algorithm") + best_error = float('inf') + best_scales = None + best_scale_alpha = None + n_grid = 20 + history = [] + + for ratio in range(n_grid): + ratio = ratio * 1 / n_grid + scales = (x_max.pow(ratio) / w_max.pow(1-ratio) + ).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * scales.min()).sqrt() + for name, module in absorbed_modules.items(): + module.weight.data = module.weight.data.mul(scales.view(1, -1)) + module.weight.data = quant_weight( + module.weight.data, + num_bits=weight_config[name]['bits'], + group_size=weight_config[name]['group_size'], + scheme=weight_config[name]['scheme'], + ) / scales.view(1, -1) + + loss = 0 + for name, block in blockes.items(): + out = block(*layer_args[name], **layer_kwargs[name]) + if isinstance(out, tuple): + out = out[0] + loss += (org_out[name] - out).float().pow(2).mean().item() # float prevents overflow + history.append(loss) + is_best = loss < best_error + if is_best: + best_error = loss + best_scales = scales + best_scale_alpha = ratio + for name, module in absorbed_modules.items(): + module.load_state_dict(org_stat[name]) + + logger.debug("The loss history of different scale:{}".format(history)) + logger.debug("The best alpha for scale: {}:{}".format(absorb, best_scale_alpha)) + best_scales = best_scales.view(-1) + assert torch.isnan(best_scales).sum() == 0, best_scales + scales = best_scales.detach() + # update absorb model + absorb_module = fetch_module(model, absorb) + if len(absorb_module.weight.shape) == 1: + absorb_module.weight.div_(scales) # for LayerNorm + else: + absorb_module.weight.div_(scales.view(-1, 1)) + if absorb_module.bias is not None: + absorb_module.bias.div_(scales.view(-1)) + # update absorbed model + for name in absorbed: + absorbed_module = fetch_module(model, name) + 
absorbed_module.weight.mul_(scales.view(1, -1)) + + if mse_range: + logger.info("Searching the best clip range with AWQ algorithm") + best_error = float('inf') + best_clip_ratio = None + n_grid = 100 + max_shrink = 0.1 + history = [] + org_stat = {_m: module.state_dict() for _m, module in absorbed_modules.items()} + + for name, module in absorbed_modules.items(): + for i_s in range(int(max_shrink * n_grid)): + ratio = (1 - i_s / n_grid) # 1, 0.95-0.55 + module.weight.data = quant_weight( + module.weight.data, + num_bits=weight_config[name]['bits'], + group_size=weight_config[name]['group_size'], + scheme=weight_config[name]['scheme'], + quantile=ratio, + ) + + loss = 0 + for n, block in blockes.items(): + if n in name: + # preprocess input with existing scale + if auto_scale: + new_args, new_kwargs = _update_input_with_scale( + layer_args[n], layer_kwargs[n], scales) + else: + new_args, new_kwargs = layer_args[n], layer_kwargs[n] + out = block(*new_args, **new_kwargs) + if isinstance(out, tuple): + out = out[0] + loss += (org_out[n] - out).float().pow(2).mean().item() # float prevents overflow + history.append(loss) + is_best = loss < best_error + if is_best: + best_error = loss + best_clip_ratio = ratio + module.load_state_dict(org_stat[name]) + + logger.debug("The loss history of different clip range:{}".format(history)) + weight_config[name]['quantile'] = best_clip_ratio + logger.debug("The best clip ratio for {}:{}".format(name, best_clip_ratio)) + + # apply quantization and clip + logger.info("Quantizing the AWQ optimized fp32 model") + model = rtn_quantize(model, weight_config=weight_config) + logger.info("AWQ quantization is done.") + return model diff --git a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py index a6edf6a0dce..f2201437592 100644 --- a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py +++ b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py @@ -1,6 +1,7 @@ import shutil import torch import unittest +import transformers from neural_compressor import quantization, PostTrainingQuantConfig @@ -27,19 +28,42 @@ def eval_func(model): output = model(input) return 0.0 +class SimpleDataLoader(): + def __init__(self): + self.batch_size = 1 + + def __iter__(self): + for i in range(2): + yield torch.randn([1, 30]) + + +class LLMDataLoader(): + def __init__(self): + self.batch_size = 1 + + def __iter__(self): + for i in range(2): + yield torch.ones([1, 10], dtype=torch.long) + class TestPytorchWeightOnlyAdaptor(unittest.TestCase): approach = 'weight_only' + @classmethod + def setUpClass(self): + self.dataloader = SimpleDataLoader() + self.gptj = transformers.AutoModelForCausalLM.from_pretrained( + 'hf-internal-testing/tiny-random-GPTJForCausalLM', + torchscript=True, + ) + self.llm_dataloader = LLMDataLoader() + self.lm_input = torch.ones([1, 10], dtype=torch.long) + @classmethod def tearDownClass(self): shutil.rmtree("./saved", ignore_errors=True) shutil.rmtree("runs", ignore_errors=True) - def test_RTN_func(self): - # TODO - pass - def test_RTN_quant(self): input = torch.randn(3,30) model = Model() @@ -65,10 +89,6 @@ def test_RTN_quant(self): }, }, }, - recipes={ - 'gptq_args':{'percdamp': 0.01}, - 'awq_args':{'alpha': 'auto', 'clip': True}, - }, ) q_model = quantization.fit(model, conf, eval_func=eval_func) out2 = q_model(input) @@ -87,10 +107,6 @@ def test_RTN_quant(self): }, }, }, - recipes={ - 'gptq_args':{'percdamp': 0.01}, - 'awq_args':{'alpha': 'auto', 'clip': True}, - }, ) q_model = 
quantization.fit(model, conf, eval_func=eval_func) out2 = q_model(input) @@ -122,10 +138,6 @@ def test_RTN_quant(self): }, }, }, - recipes={ - 'gptq_args':{'percdamp': 0.01}, - 'awq_args':{'alpha': 'auto', 'clip': True}, - }, ) q_model = quantization.fit(model, conf, eval_func=eval_func) out2 = q_model(input) @@ -137,6 +149,40 @@ def test_RTN_quant(self): out1 = new_model(input) self.assertTrue(torch.all(out1 == out2)) + def test_AWQ_quant(self): + input = torch.randn(3,30) + model = Model() + out1 = model(input) + + conf = PostTrainingQuantConfig( + approach='weight_only', + op_type_dict={ + '.*':{ # re.match + "weight": { + 'bits': 4, # 1-8 bits + 'group_size': 32, # -1 (per-channel) + 'scheme': 'sym', + 'algorithm': 'AWQ', + }, + }, + }, + op_name_dict={ + '.*lm_head':{ # re.match + "weight": { + 'dtype': 'fp32' + }, + }, + }, + recipes={ + 'awq_args':{'auto_scale': True, 'mse_range': True, 'n_blocks': 2}, + }, + ) + q_model = quantization.fit( + self.gptj, + conf, + calib_dataloader=self.llm_dataloader, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/quantization/test_weight_only_quantization.py b/test/quantization/test_weight_only_quantization.py index 549a8a7d112..3ad893c39d4 100644 --- a/test/quantization/test_weight_only_quantization.py +++ b/test/quantization/test_weight_only_quantization.py @@ -1,40 +1,101 @@ import unittest import copy import torch -from neural_compressor.adaptor.torch_utils.weight_only import rtn_quantize +from neural_compressor.adaptor.torch_utils.weight_only import rtn_quantize, awq_quantize +from neural_compressor.adaptor.torch_utils.smooth_quant import GraphTrace +import transformers -class TestWeightOnlyQuant(unittest.TestCase): +class Model(torch.nn.Module): + def __init__(self): + super(Model, self).__init__() + self.fc1 = torch.nn.Linear(32, 64) + self.fc2 = torch.nn.Linear(64, 32) + + def forward(self, x): + out = self.fc1(x) + out = self.fc2(out) + return out + + +class SimpleDataLoader(): + def __init__(self): + self.batch_size = 1 + self.input = torch.randn([1, 32]) + + def __iter__(self): + yield self.input + + +class TestAWQWeightOnlyQuant(unittest.TestCase): @classmethod def setUpClass(self): - class Model(torch.nn.Module): - def __init__(self): - super(Model, self).__init__() - self.conv1 = torch.nn.Conv2d(3, 4, 2, 2) - self.act = torch.nn.ReLU6() - self.conv2 = torch.nn.Conv2d(4, 10, 3, 3) - - def forward(self, x): - out = self.conv1(x) - out = self.act(out) - out = self.conv2(out) + x - return out - self.model = Model() + self.dataloader = SimpleDataLoader() + self.example_inputs = torch.randn([1, 32]) + self.gptj = transformers.AutoModelForCausalLM.from_pretrained( + 'hf-internal-testing/tiny-random-GPTJForCausalLM', + torchscript=True, + ) + self.lm_input = torch.ones([1, 10], dtype=torch.long) - @classmethod - def tearDownClass(self): - pass + def test_trace(self): + op_types = ['Linear'] + tg = GraphTrace() + # absorb_to_layer={'absorb_layer': absorbed_layer} + absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer(self.model, self.example_inputs, op_types) + self.assertTrue(len(no_absorb_layers) == 1) + absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer(self.gptj, self.lm_input, op_types) + self.assertTrue(len(no_absorb_layers) == 11) + return absorb_to_layer, no_absorb_layers - def test_conv(self): + def test_rtn(self): fp32_model = copy.deepcopy(self.model) model1 = rtn_quantize(fp32_model, num_bits=3, group_size=-1) - w_layers_config = { + weight_config = { # 'op_name': (bit, group_size, sheme) - 'conv1': 
(8, 128, 'sym'), - 'conv2': (4, 32, 'asym') + 'fc1': { + 'bits': 8, + 'group_size': -1, + 'scheme': 'sym' + }, + 'fc2': { + 'bits': 4, + 'group_size': 32, + 'scheme': 'asym', + 'quantile': 0.95, # not required. + }, + } + model2 = rtn_quantize(fp32_model, weight_config=weight_config) + + + def test_awq(self): + fp32_model = copy.deepcopy(self.model) + weight_config = { + # 'op_name': (bit, group_size, sheme) + 'fc1': { + 'bits': 8, + 'group_size': -1, + 'scheme': 'sym' + }, + 'fc2': { + 'bits': 4, + 'group_size': 32, + 'scheme': 'asym' + }, + } + absorb_dict = { + 'fc1': ['fc2'] } - model2 = rtn_quantize(fp32_model, num_bits=3, group_size=-1, w_layers_config=w_layers_config) + model1 = awq_quantize( + fp32_model, + weight_config=weight_config, + absorb_dict=absorb_dict, + dataloader=self.dataloader, + n_samples=128, + auto_scale=True, + mse_range=True, + ) if __name__ == "__main__":