Support NF4/FP4 data type in weight-only (#1185)

* Support NF4/FP4 data types in the weight-only RTN & AWQ algorithms; allow tuning the dtype and compressing models in nf4/fp4 mode

Signed-off-by: Xin He <[email protected]>

---------

Signed-off-by: Xin He <[email protected]>
xin3he authored Aug 26, 2023
1 parent ffe47d9 commit 3d11b5e
Showing 15 changed files with 258 additions and 78 deletions.
12 changes: 11 additions & 1 deletion .azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
@@ -2716,4 +2716,14 @@ xgb
xgboost
hpo
HPO
arange
arange
nf
Dettmers
Qlora
llms
NormalFloat
QLoRA
TimDettmers
bitsandbytes
bnb
ccedc
7 changes: 6 additions & 1 deletion docs/source/quantization_weight_only.md
@@ -35,11 +35,14 @@ There are many excellent works for weight only quantization to improve its accur
### **Quantization Capability**:
| Config | Capability |
| :---: | :---:|
| dtype | ['int', 'nf4', 'fp4'] |
| bits | [1-8] |
| group_size | [-1, 1-N] |
| scheme | ['asym', 'sym'] |
| algorithm | ['RTN', 'AWQ', 'GPTQ'] |

Notes: 4-bit NormalFloat (NF4) is proposed in QLoRA[5]. 'fp4' includes [fp4_e2m1](../../neural_compressor/adaptor/torch_utils/weight_only.py#L37) and [fp4_e2m1_bnb](https://github.com/TimDettmers/bitsandbytes/blob/18e827d666fa2b70a12d539ccedc17aa51b2c97c/bitsandbytes/functional.py#L735). By default, fp4 refers to fp4_e2m1_bnb.
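
For intuition, the following is a minimal, self-contained sketch of the lookup-table quantization idea behind NF4/FP4: scale a weight group by its absolute maximum, then snap each value to the nearest entry of a small fixed table. The codebook values and function name below are illustrative stand-ins, not the exact constants or APIs defined in weight_only.py.

```python
import torch

# Illustrative 4-bit codebook (an NF4-style value table; not the library's exact constants).
EXAMPLE_4BIT_CODEBOOK = torch.tensor([
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
])

def fake_quantize_group(weight_group: torch.Tensor) -> torch.Tensor:
    """Quantize-dequantize one weight group against a fixed float codebook."""
    scale = weight_group.abs().max().clamp(min=1e-12)   # per-group absmax scale
    normalized = weight_group / scale                   # map values into [-1, 1]
    # Pick the nearest codebook entry for every element (this index is the 4-bit code).
    codes = (normalized.unsqueeze(-1) - EXAMPLE_4BIT_CODEBOOK).abs().argmin(dim=-1)
    return EXAMPLE_4BIT_CODEBOOK[codes] * scale         # dequantize back to float
```

In the library, the real codebooks and grouping logic live alongside quant_weight in weight_only.py; the sketch only mirrors the rounding rule.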

**RTN arguments**:
| rtn_args | default value | comments |
|:----------:|:-------------:|:-------------------------------------------------------------------:|
@@ -95,7 +98,7 @@ conf = PostTrainingQuantConfig(
},
recipes={
# 'gptq_args':{'percdamp': 0.01, 'actorder':True, 'block_size': 128, 'nsamples': 128, 'use_full_length': False},
'awq_args':{'auto_scale': True, 'mse_range': True, 'n_blocks': 5},
# 'awq_args':{'auto_scale': True, 'mse_range': True},
},
)
q_model = quantization.fit(model, conf, eval_func=eval_func)
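
As a usage illustration of the dtype capability above, here is a hedged sketch of a weight-only config that pins NF4. The op_type_dict fields follow the capability table, but treat the exact values as a sketch rather than a verified recipe.

```python
from neural_compressor import PostTrainingQuantConfig, quantization

conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {                      # regex matched against op names; Linear is the supported type
            "weight": {
                "dtype": "nf4",      # or 'int' / 'fp4', per the capability table
                "bits": 4,           # nf4/fp4 are fixed 4-bit formats
                "group_size": 32,    # -1 means per-channel
                "scheme": "sym",     # nf4/fp4 quantize symmetrically
                "algorithm": "RTN",  # this change covers RTN and AWQ for nf4/fp4
            },
        },
    },
)
# q_model = quantization.fit(model, conf, eval_func=eval_func)
```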
@@ -119,3 +122,5 @@ The saved_results folder contains two files: `best_model.pt` and `qconfig.json`,
[3]. Lin, Ji, et al. "AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration." arXiv preprint arXiv:2306.00978 (2023).

[4]. Frantar, Elias, et al. "Gptq: Accurate post-training quantization for generative pre-trained transformers." arXiv preprint arXiv:2210.17323 (2022).

[5]. Dettmers, Tim, et al. "Qlora: Efficient finetuning of quantized llms." arXiv preprint arXiv:2305.14314 (2023).
@@ -618,7 +618,7 @@ def eval_func_for_nc(model_tuned):
if model_args.int8:
from neural_compressor.utils.pytorch import load
new_model = load(
os.path.abspath(os.path.expanduser(training_args.output_dir)), model)
os.path.abspath(os.path.expanduser(training_args.output_dir)), model, weight_only=True)
else:
new_model = model
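
For context, a short hedged sketch of the restore path this change enables: passing weight_only=True so that load() rebuilds the weight-only quantized model from the saved results directory. The helper name and its arguments are placeholders for this example's own objects.

```python
import os
from neural_compressor.utils.pytorch import load

def load_weight_only_model(output_dir: str, fp32_model):
    """Restore a weight-only quantized model saved by neural_compressor."""
    return load(os.path.abspath(os.path.expanduser(output_dir)), fp32_model, weight_only=True)
```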

12 changes: 11 additions & 1 deletion neural_compressor/adaptor/pytorch.py
@@ -4349,6 +4349,14 @@ def quantize(self, tune_cfg, model, dataloader, calib_func=None):
if config['weight']['dtype'] == 'fp32':
continue
else:
dtype = config['weight']['dtype']
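# nf4/fp4 and their e2m1 variants are fixed 4-bit, symmetric (zero-point-free) formats, so bits and scheme are pinned below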
if dtype in ['nf4', 'fp4', 'fp4_e2m1_bnb', 'fp4_e2m1']:
config['weight']['bits'] = 4
config['weight']['scheme'] = 'sym'
elif dtype in ['int4']:
config['weight']['bits'] = 4
elif dtype in ['int8']:
config['weight']['bits'] = 8
algorithm = config['weight']['algorithm']
all_algo.add(algorithm)
if len(all_algo):
@@ -4385,15 +4393,17 @@ def rtn_quantize(self, model, tune_cfg):
if config['weight']['dtype'] == 'fp32':
continue
else:
dtype = config['weight']['dtype']
num_bits = config['weight']['bits']
group_size = config['weight']['group_size']
scheme = config['weight']['scheme']
group_size = config['weight']['group_size']
algorithm = config['weight']['algorithm']
if algorithm != 'RTN':
continue
m = fetch_module(model, op_name)
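# Quantize-dequantize the op's weights (return_int=False keeps float weights); data_type selects the int or nf4/fp4 codebook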
m = rtn_quantize(m, num_bits, group_size, scheme,
return_int=False,
data_type=dtype,
sym_full_range=sym_full_range,
mse_range=mse_range)
set_module(model, op_name, m)
3 changes: 1 addition & 2 deletions neural_compressor/adaptor/pytorch_cpu.yaml
@@ -262,7 +262,7 @@
weight_only_integer: &cap_weight_only_integer {
'Linear': &cap_weight_only_integer_linear { # only Linear now
'weight': {
'dtype': ['int'], # no need to care uint
'dtype': ['int', 'int4', 'nf4', 'fp4', 'fp4_e2m1_bnb', 'fp4_e2m1'],
'bits': [4, 1, 2, 3, 5, 6, 7, 8], # [1-8], # 4
# group_size=-1 means per-channel, others means per-group
'group_size': [32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024], # [1-inf], # 32
@@ -273,7 +273,6 @@
'dtype': ['fp32'],
},
},
'Conv2d': *cap_weight_only_integer_linear,
}


14 changes: 11 additions & 3 deletions neural_compressor/adaptor/torch_utils/awq.py
@@ -91,7 +91,8 @@ def _get_act_scale(input_val):
class ActAwareWeightQuant:
"""Implementation of Activation-aware Weight quantization (AWQ) algo."""
def __init__(self, model, example_inputs=None, calib_func=None, dataloader=None, n_samples=128,
bits=4, group_size=32, scheme='asym', sym_full_range=False, weight_config={},):
data_type='int', bits=4, group_size=32, scheme='asym', sym_full_range=False,
weight_config={},):
self.example_inputs = example_inputs
if example_inputs is None:
assert dataloader is not None, "dataloader or example_inputs is required."
@@ -103,6 +104,7 @@ def __init__(self, model, example_inputs=None, calib_func=None, dataloader=None,
# Step 2: get block list and block prefix, number
self.block_prefix, self.block_num = get_block_prefix(model)
self.block_list = fetch_module(model, self.block_prefix)
self.data_type = data_type
self.bits = bits
self.group_size = group_size
self.scheme = scheme
@@ -188,11 +190,13 @@ def search_scale(self, block, block_name, module_list, input_values):
for module_tuple in module_list:
# Step 1: Initialize quantization configuration.
if module_tuple[0] in self.weight_config:
cur_dtype = self.weight_config[module_tuple[0]]['dtype']
cur_bits = self.weight_config[module_tuple[0]]['bits']
cur_group_size = self.weight_config[module_tuple[0]]['group_size']
cur_scheme = self.weight_config[module_tuple[0]]['scheme']
else:
cur_bits, cur_group_size, cur_scheme = self.bits, self.group_size, self.scheme
cur_dtype, cur_bits, cur_group_size, cur_scheme = \
self.data_type, self.bits, self.group_size, self.scheme
if cur_bits < 0:
continue
logger.info(f"[SCALE] Processing module: {module_tuple}")
@@ -231,6 +235,7 @@ def search_scale(self, block, block_name, module_list, input_values):
module.weight.data = module.weight.data.mul(scales.view(1, -1))
module.weight.data = quant_weight(
module.weight.data,
data_type=cur_dtype,
num_bits=cur_bits,
group_size=cur_group_size,
scheme=cur_scheme,
@@ -310,11 +315,13 @@ def search_clip(self, block_name, module_list, input_values):
for module_name in module_tuple:
# Step 1: Initialize quantization configuration.
if module_name in self.weight_config:
cur_dtype = self.weight_config[module_name]['dtype']
cur_bits = self.weight_config[module_name]['bits']
cur_group_size = self.weight_config[module_name]['group_size']
cur_scheme = self.weight_config[module_name]['scheme']
else:
cur_bits, cur_group_size, cur_scheme = self.bits, self.group_size, self.scheme
cur_dtype, cur_bits, cur_group_size, cur_scheme = \
self.data_type, self.bits, self.group_size, self.scheme
if cur_bits < 0:
continue
logger.info(f"[CLIP] Processing module: {module_name}")
@@ -335,6 +342,7 @@ def search_clip(self, block_name, module_list, input_values):
# MulLinear can also work with @weight.setter
module.weight.data = quant_weight(
module.weight.data,
data_type=cur_dtype,
num_bits=cur_bits,
group_size=cur_group_size,
scheme=cur_scheme,
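
Both search loops above resolve a per-module configuration and fall back to the algorithm-level defaults; a condensed sketch of that pattern follows (the helper name and defaults are assumptions for illustration, not library API).

```python
def resolve_weight_config(weight_config: dict, module_name: str,
                          default_dtype="int", default_bits=4,
                          default_group_size=32, default_scheme="asym"):
    """Return (dtype, bits, group_size, scheme) for one module.

    Falls back to the AWQ-level defaults when the module has no per-op entry;
    the caller treats a negative bit-width as "skip this module".
    """
    cfg = weight_config.get(module_name)
    if cfg is not None:
        return cfg["dtype"], cfg["bits"], cfg["group_size"], cfg["scheme"]
    return default_dtype, default_bits, default_group_size, default_scheme
```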
15 changes: 14 additions & 1 deletion neural_compressor/adaptor/torch_utils/model_wrapper.py
@@ -194,10 +194,18 @@ def _wrapper_qdq_linear(tmp_model, module_name_list=[]):

class WeightOnlyLinear(torch.nn.Module):
def __init__(self, in_features, out_features, bits, groupsize,
zp=False, bias=False, scale_dtype=torch.float32,
dtype='int', zp=False, bias=False, scale_dtype=torch.float32,
compression_dtype=torch.int32, compression_dim=1,
gptq_perm=False, device='cpu'):
super().__init__()
self.dtype = dtype
if 'int' not in self.dtype: # for nf4, fp4
from neural_compressor.adaptor.torch_utils.weight_only import FLOAT_MAPPING, INT_MAPPING
float_list = FLOAT_MAPPING[self.dtype]
int_list = INT_MAPPING[self.dtype]
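# pair each packed integer code with its float codebook value so recover() can map codes back to floats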
self.int2float_mapping = {}
for k, v in zip(int_list, float_list):
self.int2float_mapping[k] = v
self.device = device
self.in_features = in_features
self.out_features = out_features
@@ -346,6 +354,11 @@ def recover(self):
weight[:, index] = tmp.type(weight_dtype)
if self.compression_dim == 0:
weight = weight.T
if 'int' not in self.dtype:
new_weight = torch.zeros(self.out_features, self.in_features).to(device)
for k, v in self.int2float_mapping.items():
new_weight += torch.where(weight == k, v, 0)
weight = new_weight
# unpack zero_point
if hasattr(self, 'packed_zp'):
zp_dtype = self.compressed_dtype # to avoid overflow when weight-zp
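
To illustrate the decode step added to recover(): for non-integer dtypes, the unpacked integer codes are mapped back to their nf4/fp4 float values through the index-to-float table built in __init__. A simplified, hedged sketch (the function name and shapes are assumptions):

```python
import torch

def decode_codes(codes: torch.Tensor, int2float_mapping: dict) -> torch.Tensor:
    """Replace each integer code with its float codebook value, mirroring the torch.where loop."""
    decoded = torch.zeros(codes.shape, dtype=torch.float32, device=codes.device)
    for code, value in int2float_mapping.items():
        # Add the codebook value where the code matches; elsewhere add 0.
        decoded += torch.where(codes == code, value, 0.0)
    return decoded
```

Accumulating with torch.where works because each position matches exactly one code, so it receives exactly one non-zero contribution.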