Support static_groups options in GPTQ API (#1478)
Signed-off-by: YIYANGCAI <[email protected]>
YIYANGCAI authored Jan 23, 2024
1 parent ab72037 commit 1c426a0
Showing 5 changed files with 38 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/source/quantization_weight_only.md
@@ -86,6 +86,7 @@ Notes:
| pad_max_length | 2048 | Whether to align calibration data to a fixed length. This value should not exceed model's acceptable sequence length. Please refer to model's config json to find out this value.|
| use_max_length | False | Whether to align all calibration data to fixed length, which equals to pad_max_length. |
| block_size | 128 | Execute GPTQ quantization per block, block shape = [$C_{out}$, block_size] |
| static_groups | False | Whether to calculate group-wise quantization parameters in advance. This option mitigates the extra computational cost introduced by act_order. |

**Note:** Neural compressor provides `Unsigned integer for asymmetric quantization` and `Signed integer for symmetric quantization`. Please follow the below section to compress the low bit data type for saving.
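With this option, static_groups can be enabled alongside act_order through the GPTQ recipes. A minimal sketch, assuming the Neural Compressor 2.x weight-only API (`PostTrainingQuantConfig` / `quantization.fit`); the `gptq_args` keys mirror the options table above, while the surrounding values and the `model`/`dataloader` objects are illustrative:

```python
# Sketch: enabling static_groups for GPTQ weight-only quantization.
# Assumes `model` (a loaded torch model) and `dataloader` (calibration data) already exist.
from neural_compressor import PostTrainingQuantConfig, quantization

conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {  # apply to all supported layers
            "weight": {
                "bits": 4,
                "group_size": 128,
                "scheme": "asym",
                "algorithm": "GPTQ",
            },
        },
    },
    recipes={
        "gptq_args": {
            "percdamp": 0.01,
            "act_order": True,
            "block_size": 128,
            "static_groups": True,  # pre-compute group-wise params; cheaper when act_order is on
            "use_max_length": False,
            "pad_max_length": 2048,
        },
    },
)
q_model = quantization.fit(model, conf, calib_dataloader=dataloader)
```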

@@ -78,6 +78,7 @@
this should align with your model config, \
and your dataset builder args: args.pad_max_length')
parser.add_argument('--gptq_debug', action='store_true', help='Whether to use debug model ')
parser.add_argument('--gptq_static_groups', action='store_true', help='Use statically pre-determined groups for quantization')
# ==============code generation args===========
parser.add_argument("--code_generation", action="store_true")
parser.add_argument("--n_samples", default=200, type=int)
@@ -277,7 +278,8 @@ def calib_func(prepared_model):
'block_size': args.gptq_block_size,
'nsamples': args.gptq_nsamples,
'use_max_length': args.gptq_use_max_length,
'pad_max_length': args.gptq_pad_max_length
'pad_max_length': args.gptq_pad_max_length,
'static_groups': args.gptq_static_groups,
}
# GPTQ: use assistive functions to modify calib_dataloader and calib_func
# TEQ: set calib_func=None, use default training func as calib_func
@@ -293,6 +295,7 @@ def calib_func(prepared_model):

# for test on various models, keep the code of directly call gptq_quantize
if args.gptq_debug:

from neural_compressor.adaptor.torch_utils.weight_only import gptq_quantize

gptq_conf = {
@@ -301,6 +304,7 @@ def calib_func(prepared_model):
'group_size': args.woq_group_size, # -1 (per-channel)
'sym': (args.woq_scheme == "sym"),
'act_order': args.gptq_actorder,
'static_groups': args.gptq_static_groups,
}
}
q_model_gptq_debug, gptq_config = gptq_quantize(
1 change: 1 addition & 0 deletions neural_compressor/adaptor/pytorch.py
@@ -4709,6 +4709,7 @@ def gptq_quantize(self, model, tune_cfg, dataloader):
"percdamp": self.recipes["gptq_args"].get("percdamp", 0.01),
"act_order": self.recipes["gptq_args"].get("act_order", False),
"block_size": self.recipes["gptq_args"].get("block_size", True),
"static_groups": self.recipes["gptq_args"].get("static_groups", False),
}
nsamples = self.recipes["gptq_args"].get("nsamples", 128)
use_max_length = self.recipes["gptq_args"].get("use_max_length", False)
35 changes: 30 additions & 5 deletions neural_compressor/adaptor/torch_utils/gptq.py
@@ -232,6 +232,7 @@ def __init__(
self.percdamp_default = 0.01
self.sym_default = False
self.act_order_default = False
self.static_groups_default = False
self.perchannel_default = True
self.mse_default = False
self.check_layer_config()
@@ -406,6 +407,9 @@ def check_layer_config(self):
tmp_weight_config[name]["percdamp"] = self.weight_config.get("pecdamp", self.percdamp_default)
tmp_weight_config[name]["sym"] = self.weight_config.get("sym", self.sym_default)
tmp_weight_config[name]["act_order"] = self.weight_config.get("act_order", self.act_order_default)
tmp_weight_config[name]["static_groups"] = self.weight_config.get(
"static_groups", self.static_groups_default
)
tmp_weight_config[name]["perchannel"] = self.weight_config.get("perchannel", self.perchannel_default)
tmp_weight_config[name]["mse"] = self.weight_config.get("mse", self.mse_default)
self.weight_config = tmp_weight_config
@@ -417,6 +421,9 @@ def check_layer_config(self):
self.weight_config[layer_name]["percdamp"] = config.get("pecdamp", self.percdamp_default)
self.weight_config[layer_name]["sym"] = config.get("sym", self.sym_default)
self.weight_config[layer_name]["act_order"] = config.get("act_order", self.act_order_default)
self.weight_config[layer_name]["static_groups"] = config.get(
"static_groups", self.static_groups_default
)
self.weight_config[layer_name]["perchannel"] = config.get("perchannel", self.perchannel_default)
self.weight_config[layer_name]["mse"] = config.get("mse", self.mse_default)

@@ -631,6 +638,7 @@ def tmp(_, inp, out):
percdamp=weight_config_this_layer["percdamp"],
groupsize=weight_config_this_layer["group_size"],
act_order=weight_config_this_layer["act_order"],
static_groups=weight_config_this_layer["static_groups"],
)
if self.layer_wise:
from ..torch_utils.layer_wise_quant.utils import (
@@ -745,7 +753,7 @@ def add_batch(self, inp, out):
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
self.H += inp.matmul(inp.t()) # H = X*X, which should be a sysm matrix

def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False):
def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, static_groups=False):
# W = self.layer.weight.data.clone()
weight_shape, weight_dtype = W.shape, W.data.dtype
if isinstance(self.layer, nn.Conv2d):
@@ -765,6 +773,17 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F
H[dead, dead] = 1
W[:, dead] = 0 # such channel makes no contribution to quantization computation

# enable static_groups
# calculate the quantization parameters for original group in advance.
if static_groups:
import copy

groups = []
for i in range(0, self.columns, groupsize):
quantizer = copy.deepcopy(self.quantizer)
quantizer.find_params(W[:, i : (i + groupsize)], weight=True)
groups.append(quantizer)

# rearrange considering the diag's value
if act_order:
perm = torch.argsort(torch.diag(H), descending=True)
@@ -801,10 +820,16 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F
d = Hinv1[i, i]

if groupsize != -1:
if (i1 + i) % groupsize == 0:
self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + groupsize)], weight=True)
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
if not static_groups:
if (i1 + i) % groupsize == 0:
self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + groupsize)], weight=True)
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
else:
idx = i1 + i
if act_order:
idx = perm[idx]
self.quantizer = groups[idx // groupsize]

q = quantize(w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq).flatten()
Q1[:, i] = q
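The static_groups path above pre-computes every group's quantization parameters on the original column layout, so that when act_order reorders columns by the Hessian diagonal, each column simply looks up its group's fixed parameters instead of re-deriving them in permuted order. A minimal standalone sketch of that lookup, using plain min/max asymmetric parameters rather than the library's Quantizer class:

```python
# Illustrative sketch of the static_groups lookup; not the library implementation.
import torch


def static_group_params(W, groupsize, bits=4):
    """Pre-compute (scale, zero) for each contiguous column group of W in its original layout."""
    maxq = 2**bits - 1
    params = []
    for i in range(0, W.shape[1], groupsize):
        block = W[:, i : i + groupsize]
        wmin = block.min(dim=1, keepdim=True).values.clamp(max=0)
        wmax = block.max(dim=1, keepdim=True).values.clamp(min=0)
        scale = (wmax - wmin).clamp(min=1e-8) / maxq
        zero = torch.round(-wmin / scale)
        params.append((scale, zero))
    return params


W = torch.randn(8, 16)
groupsize, bits = 4, 4
maxq = 2**bits - 1
groups = static_group_params(W, groupsize, bits)

# act_order-style reordering: columns are processed in a permuted order,
# but each one still uses the parameters of its ORIGINAL group.
perm = torch.argsort(torch.rand(W.shape[1]))  # stand-in for argsort(diag(H), descending=True)
W_perm = W[:, perm]
Q = torch.zeros_like(W_perm)
for col in range(W_perm.shape[1]):
    orig_col = perm[col].item()                  # permuted position -> original column index
    scale, zero = groups[orig_col // groupsize]  # group is defined on the original layout
    q = torch.clamp(torch.round(W_perm[:, col : col + 1] / scale + zero), 0, maxq)
    Q[:, col : col + 1] = scale * (q - zero)     # dequantized column
```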
1 change: 1 addition & 0 deletions test/quantization/test_weight_only_quantization.py
@@ -147,6 +147,7 @@ def __iter__(self):
"sym": False,
"percdamp": 0.01,
"act_order": True,
"static_groups": True,
},
"transformer.h.2.attn.k_proj": {
"wbits": 3,
