From 85931587d6fb9fd10d16e5c750dc5fdc519bda73 Mon Sep 17 00:00:00 2001
From: xinhe
Date: Tue, 18 Jul 2023 17:26:52 +0800
Subject: [PATCH] support full range for 4 bit sym -8 (#1083)

Signed-off-by: Xin He
---
 neural_compressor/adaptor/pytorch.py       | 12 ++-
 .../adaptor/torch_utils/weight_only.py     | 97 +++++++++++++------
 neural_compressor/config.py                |  7 ++
 .../test_weight_only_adaptor.py            | 19 ++++
 4 files changed, 107 insertions(+), 28 deletions(-)

diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index 6878c4c05a4..85784256db2 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -4534,6 +4534,10 @@ def quantize(self, tune_cfg, model, dataloader, calib_func=None):
 
     def rtn_quantize(self, model, tune_cfg):
         logger.debug("quantizing with the round-to-nearest algorithm")
+        if 'rtn_args' in self.recipes:
+            full_range = self.recipes['rtn_args'].get('full_range', False)
+        else:
+            full_range = False
         from .torch_utils.weight_only import rtn_quantize
         from .torch_utils.util import fetch_module
         for key, config in tune_cfg['op'].items():
@@ -4548,7 +4552,8 @@ def rtn_quantize(self, model, tune_cfg):
                 if algorithm != 'RTN':
                     continue
                 m = fetch_module(model, op_name)
-                m = rtn_quantize(m, num_bits, group_size, scheme, return_int=False)
+                m = rtn_quantize(m, num_bits, group_size, scheme,
+                                 return_int=False, full_range=full_range)
                 set_module(model, op_name, m)
         return model
 
@@ -4645,6 +4650,10 @@ def awq_quantize(self, model, tune_cfg, dataloader, calib_func):
             n_blocks = self.recipes['awq_args'].get('n_blocks', 5)
         else:
             auto_scale, mse_range = True, True
+        if 'rtn_args' in self.recipes:
+            full_range = self.recipes['rtn_args'].get('full_range', False)
+        else:
+            full_range = False
         calib_sampling_size = tune_cfg.get('calib_sampling_size', 1)
         model = awq_quantize(
             model,
@@ -4657,6 +4666,7 @@ def awq_quantize(self, model, tune_cfg, dataloader, calib_func):
             calib_func=calib_func,
             n_blocks=n_blocks,
             return_int=False,
+            full_range=full_range,
         )
         return model
 
diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py
index 44dff7e43eb..1fb900d6be0 100644
--- a/neural_compressor/adaptor/torch_utils/weight_only.py
+++ b/neural_compressor/adaptor/torch_utils/weight_only.py
@@ -60,7 +60,7 @@ def qdq_weight_asym(weight, num_bits=4, quantile=1.0, return_int=False):
     return scale * (q - zp)
 
 
-def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False):
+def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_range=False):
     """Quant and dequant tensor with sym schema.
 
     Args:
@@ -69,6 +69,11 @@ def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False):
         quantile (float, optional): percentile of clip. Defaults to 1.0.
         return_int (bool, optional): Choose return fp32 or int8/uint8 data.
                                      Defaults to False.
+        full_range (bool, optional): Whether to use the full sym range including -2**(bits-1).
+                    For example, with 4 bits:
+                        scale = amax / 8 if full_range else amax / 7
+                        If True, scale = -scale if abs(min) > abs(max) else scale
+                    Defaults to False.
 
     Returns:
         output: qdq weight
@@ -79,12 +84,20 @@ def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False):
     if num_bits == 1:
         maxq = torch.tensor(2 ** (num_bits - 1))
         minq = torch.tensor(2 ** (num_bits - 1) - 1)
-
-    wmax = torch.abs(weight).max(1)[0]
+    max_val = torch.max(weight, 1)[0]
+    min_val = torch.min(weight, 1)[0]
+    flip_flag = torch.abs(min_val) > torch.abs(max_val)
+    wmax = torch.max(torch.abs(max_val), torch.abs(min_val))
     wmax = wmax * quantile
     tmp = (wmax == 0)
     wmax[tmp] = +1
-    scale = wmax / ((maxq - minq) / 2)
+    if full_range:
+        # use -8, 8 to make sure amax is not changed after fake quant
+        scale = wmax / (-minq)
+        tmp = scale * flip_flag.int()
+        scale -= 2*tmp  # set negative scale with flip_flag
+    else:
+        scale = wmax / maxq
     scale.unsqueeze_(dim=-1)
     q = torch.clamp(torch.round(weight / scale), minq, maxq)
     if return_int:
@@ -92,7 +105,8 @@ def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False):
     return scale * q
 
 
-def qdq_weight_actor(weight, num_bits, scheme, quantile=1.0, return_int=False):
+def qdq_weight_actor(weight, num_bits, scheme, quantile=1.0,
+                     return_int=False, full_range=False):
     """Quant and dequant tensor per channel.
 
     Args:
@@ -101,18 +115,20 @@ def qdq_weight_actor(weight, num_bits, scheme, quantile=1.0, return_int=False):
         quantile (float, optional): percentile of clip. Defaults to 1.0.
         return_int (bool, optional): Choose return fp32 or int8/uint8 data.
                                      Defaults to False.
+        full_range (bool, optional): Whether to use the full sym range including -2**(bits-1).
 
     Returns:
         output: qdq weight
     """
     assert num_bits > 0, "num_bits should be larger than 0"
     if scheme == "sym":
-        return qdq_weight_sym(weight, num_bits, quantile, return_int)
+        return qdq_weight_sym(weight, num_bits, quantile, return_int, full_range)
     else:
         return qdq_weight_asym(weight, num_bits, quantile, return_int)
 
 
-def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0, return_int=False):
+def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0,
+                 return_int=False, full_range=False):
    """Quant and dequant tensor with group size.
 
     Args:
@@ -123,26 +139,32 @@ def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0,
         quantile (float, optional): percentile of clip. Defaults to 1.0.
         return_int (bool, optional): Choose return fp32 or int8/uint8 data.
                                      Defaults to False.
+        full_range (bool, optional): Whether to use the full sym range including -2**(bits-1).
 
     Returns:
         output: qdq weight.
""" if group_size == -1 or weight.shape[1] < group_size: - return qdq_weight_actor(weight, num_bits, scheme=scheme, quantile=quantile, return_int=return_int) + return qdq_weight_actor(weight, num_bits, scheme=scheme, quantile=quantile, + return_int=return_int, full_range=full_range) orig_shape = weight.shape if weight.shape[1] % group_size == 0: weight = weight.reshape(-1, group_size) if return_int: weight, scale, zp = qdq_weight_actor( - weight, num_bits, scheme=scheme, quantile=quantile, return_int=True) + weight, num_bits, scheme=scheme, quantile=quantile, + return_int=True, full_range=full_range + ) weight = weight.reshape(orig_shape) scale = scale.reshape(orig_shape[0], -1) if zp is not None: zp = zp.reshape(orig_shape[0], -1) return weight, scale, zp else: - weight = qdq_weight_actor(weight, num_bits, scheme=scheme, quantile=quantile) + weight = qdq_weight_actor( + weight, num_bits, scheme=scheme, quantile=quantile, full_range=full_range + ) return weight.reshape(orig_shape) else: split_index = weight.shape[1] // group_size * group_size @@ -150,14 +172,20 @@ def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0, weight1 = weight1.reshape(-1, group_size) if return_int: weight1, scale1, zp1 = qdq_weight_actor( - weight1, num_bits, scheme=scheme, quantile=quantile, return_int=True) + weight1, num_bits, scheme=scheme, + quantile=quantile, return_int=True, full_range=full_range + ) else: - weight1 = qdq_weight_actor(weight1, num_bits, scheme=scheme, quantile=quantile) + weight1 = qdq_weight_actor( + weight1, num_bits, scheme=scheme, quantile=quantile, full_range=full_range + ) weight1 = weight1.reshape(orig_shape[0], split_index) weight2 = weight[:, split_index:] if return_int: weight2, scale2, zp2 = qdq_weight_actor( - weight2, num_bits, scheme=scheme, quantile=quantile, return_int=True) + weight2, num_bits, scheme=scheme, + quantile=quantile, return_int=True, full_range=full_range + ) weight = torch.cat([weight1, weight2], dim=1) scale = torch.cat([scale1, scale2], dim=0) if zp2 is not None: @@ -169,13 +197,16 @@ def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0, zp = zp.reshape(orig_shape[0], -1) return weight, scale, zp else: - weight2 = qdq_weight_actor(weight2, num_bits, scheme=scheme, quantile=quantile) + weight2 = qdq_weight_actor( + weight2, num_bits, scheme=scheme, + quantile=quantile, full_range=full_range + ) weight = torch.cat([weight1, weight2], dim=1) return weight def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym", - quantile=1.0, weight_config={}, return_int=False): + quantile=1.0, weight_config={}, return_int=False, full_range=False): """Quant the model with round to nearst method. Args: @@ -196,12 +227,12 @@ def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym", } return_int (bool, optional): Choose return fp32 or int32 model. Defaults to False. + full_range (bool, optional): Choose sym range whether use -2**(bits-1). Returns: model: fake quantized torch module """ assert isinstance(model, torch.nn.Module), "only support torch module" - assert num_bits > 0, "bit for weight only should large than zero!" 
     supported_layers = ['Linear']
     for n, m in model.named_modules():
         if m.__class__.__name__ not in supported_layers:
@@ -211,21 +242,23 @@ def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym",
             group_size = weight_config[n]['group_size']
             scheme = weight_config[n]['scheme']
             quantile = weight_config[n].get('quantile', 1.0)
-        else:
-            # skip when n is not in weight_config
-            if weight_config != {}:
-                continue
         logger.debug(f"RTN quantized module:{n, m}")
-        logger.debug(f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, " + \
-                     f"scheme={scheme}, quantile={quantile}")
+        if scheme == 'sym':
+            logger.debug(f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, " + \
+                         f"scheme={scheme}, quantile={quantile}, full_range={full_range}")
+        else:
+            logger.debug(f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, " + \
+                         f"scheme={scheme}, quantile={quantile}")
         if num_bits <= 0:
             logger.info(f"skip {n}")
             continue
         weight = m.weight
         if return_int:
             from .model_wrapper import WeightOnlyLinear
-            int_weight, scale, zp = quant_weight(weight, num_bits, group_size,
-                                                 scheme, quantile, return_int=True)
+            int_weight, scale, zp = quant_weight(
+                weight, num_bits, group_size, scheme,
+                quantile, return_int=True, full_range=full_range
+            )
             new_module = WeightOnlyLinear(
                 m.in_features, m.out_features, num_bits, group_size
             )
@@ -235,7 +268,9 @@ def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym",
             else:
                 set_module(model, n, new_module)
         else:
-            q_weight = quant_weight(weight, num_bits, group_size, scheme, quantile)
+            q_weight = quant_weight(
+                weight, num_bits, group_size, scheme, quantile, full_range=full_range
+            )
             m.weight.data.copy_(q_weight)
     return model
 
@@ -353,7 +388,8 @@ def _update_input_with_scale(args, kwargs, scales):
 
 @torch.no_grad()
 def awq_quantize(model, weight_config={}, absorb_dict={}, dataloader=None, n_samples=128,
-                 auto_scale=True, mse_range=True, calib_func=None, n_blocks=5, return_int=False):
+                 auto_scale=True, mse_range=True, calib_func=None, n_blocks=5,
+                 return_int=False, full_range=False):
     """Quant the model with Activation-aware Weight quantization(AWQ) method.
 
     Args:
@@ -382,6 +418,7 @@ def awq_quantize(model, weight_config={}, absorb_dict={}, dataloader=None, n_sam
         n_blocks: split model into block number to avoid OOM.
         return_int (bool, optional): Choose return fp32 or int32 model.
                                      Defaults to False.
+        full_range (bool, optional): Whether to use the full sym range including -2**(bits-1).
 
     Returns:
         model: fake quantized model
@@ -603,6 +640,12 @@ def forward(self, *args, **kwargs):
 
     # apply quantization and clip
     logger.info("Quantizing the AWQ optimized fp32 model")
-    model = rtn_quantize(model, weight_config=weight_config, return_int=return_int)
+    model = rtn_quantize(
+        model,
+        num_bits=-1,
+        weight_config=weight_config,
+        return_int=return_int,
+        full_range=full_range,
+    )
     logger.info("AWQ quantization is done.")
     return model
diff --git a/neural_compressor/config.py b/neural_compressor/config.py
index 60771d9c7ba..88aa08e7073 100644
--- a/neural_compressor/config.py
+++ b/neural_compressor/config.py
@@ -847,6 +847,12 @@ def smooth_quant_args(val=None):
             else:
                 return {}
 
+        def rtn_args(val=None):
+            if val is not None:
+                return _check_value("rtn_args", val, dict)
+            else:
+                return {}
+
         def awq_args(val=None):
             if val is not None:
                 return _check_value("awq_args", val, dict)
@@ -932,6 +938,7 @@ def dedicated_qdq_pair(val=None):
             "add_qdq_pair_to_weight": add_qdq_pair_to_weight,
             "optypes_to_exclude_output_quant": optypes_to_exclude_output_quant,
             "dedicated_qdq_pair": dedicated_qdq_pair,
+            "rtn_args": rtn_args,
             "awq_args": awq_args,
             "gptq_args": gptq_args,
         }
diff --git a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py
index b30352e0077..4ed11dc3025 100644
--- a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py
+++ b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py
@@ -86,6 +86,25 @@ def test_RTN_quant(self):
         # sym has clip issue for [-8, 7], set a big atol.
         self.assertTrue(torch.all(torch.isclose(out3, out2, atol=1e-1)))
 
+        model = Model()
+        out1 = model(input)
+
+        conf = PostTrainingQuantConfig(
+            approach='weight_only',
+            recipes={
+                # By default, full_range is False and 4-bit sym only uses the range [-7, 7].
+                'rtn_args': {'full_range': True}
+            }
+        )
+        q_model = quantization.fit(model, conf)
+        out2 = q_model(input)
+        self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1)))
+        self.assertFalse(torch.all(out1 == out2))
+        q_model.convert(weight_only=True)
+        out3 = q_model(input)
+        # sym has clip issue for [-8, 7], set a big atol.
+        self.assertTrue(torch.all(torch.isclose(out3, out2, atol=1e-1)))
+
         model = Model()
         out1 = model(input)
         conf = PostTrainingQuantConfig(
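
As a quick illustration of the 4-bit symmetric scheme that full_range toggles, here is a rough standalone sketch; the helper name and sample values are illustrative only and not part of this patch or of neural_compressor's API. The patch's additional flip_flag sign handling is skipped by choosing a row whose largest-magnitude entry is already negative.

import torch

# Illustrative sketch (assumed helper, not the library's quant_weight API):
# 4-bit symmetric fake quantization with and without the full range [-8, 7].
def fake_quant_sym(w: torch.Tensor, num_bits: int = 4, full_range: bool = False) -> torch.Tensor:
    maxq = 2 ** (num_bits - 1) - 1   # 7 for 4 bits
    minq = -(2 ** (num_bits - 1))    # -8 for 4 bits
    amax = w.abs().max()
    # full_range divides amax by 8 so the -8 code is usable; the default divides by 7,
    # so the effective grid is [-7, 7] and the -8 code is never produced.
    scale = amax / (-minq) if full_range else amax / maxq
    q = torch.clamp(torch.round(w / scale), minq, maxq)
    return scale * q

w = torch.tensor([-0.8, -0.1, 0.3, 0.5])
print(fake_quant_sym(w, full_range=False))  # grid step = 0.8 / 7
print(fake_quant_sym(w, full_range=True))   # grid step = 0.8 / 8; -0.8 lands exactly on the -8 code

At the user level, the new test above shows the intended entry point: recipes={'rtn_args': {'full_range': True}} in PostTrainingQuantConfig with approach='weight_only'.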