add attr to MatMulNBits (#1378)
Signed-off-by: Mengni Wang <[email protected]>
mengniwang95 authored Nov 17, 2023
1 parent f9663d0 commit 7057e3b
Showing 2 changed files with 40 additions and 3 deletions.
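For context, a minimal usage sketch of how the new accuracy_level recipe entry might be supplied from user code. The recipe keys mirror the lookups added in neural_compressor/adaptor/onnxrt.py below; the PostTrainingQuantConfig surface and the model path shown here are assumptions for illustration, not part of this commit.

# Hypothetical usage sketch (not part of this commit): pass accuracy_level
# through the weight-only quantization recipes.
from neural_compressor import PostTrainingQuantConfig, quantization

config = PostTrainingQuantConfig(
    approach="weight_only",
    recipes={
        # accuracy_level: 0 (unset), 1 (fp32), 2 (fp16), 3 (bf16), 4 (int8)
        # compute type of the jblas kernel, per the docstrings below.
        "rtn_args": {"accuracy_level": 4},
    },
)

q_model = quantization.fit(
    "model.onnx",  # placeholder path to an FP32 ONNX model
    config,
)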
11 changes: 10 additions & 1 deletion neural_compressor/adaptor/onnxrt.py
@@ -1666,6 +1666,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
actorder = self.recipes.get("gptq_args", {}).get("actorder", False)
mse = self.recipes.get("gptq_args", {}).get("mse", False)
perchannel = self.recipes.get("gptq_args", {}).get("perchannel", True)
accuracy_level = self.recipes.get("gptq_args", {}).get("accuracy_level", 0)
calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
tmp_model = gptq_quantize(
tmp_model,
@@ -1677,13 +1678,15 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
actorder=actorder,
mse=mse,
perchannel=perchannel,
accuracy_level=accuracy_level,
)
if "AWQ" in algos:
from neural_compressor.adaptor.ox_utils.weight_only import awq_quantize

assert data_loader is not None, "AWQ WOQ algorithm needs to pass 'calib_dataloader' to quantization.fit()"
enable_auto_scale = self.recipes.get("awq_args", {}).get("enable_auto_scale", True)
enable_mse_search = self.recipes.get("awq_args", {}).get("enable_mse_search", True)
accuracy_level = self.recipes.get("awq_args", {}).get("accuracy_level", 0)
calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
tmp_model = awq_quantize(
tmp_model,
@@ -1692,11 +1695,17 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
n_samples=calib_sampling_size,
enable_auto_scale=enable_auto_scale,
enable_mse_search=enable_mse_search,
accuracy_level=accuracy_level,
)
elif "RTN" in algos:
from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize

tmp_model = rtn_quantize(tmp_model, quant_config)
accuracy_level = self.recipes.get("rtn_args", {}).get("accuracy_level", 0)
tmp_model = rtn_quantize(
tmp_model,
quant_config,
accuracy_level=accuracy_level,
)
tmp_model.q_config = copy.deepcopy(quant_config)
self._dump_model_op_stats(tmp_model, tune_cfg)
tmp_model.topological_sort()
32 changes: 30 additions & 2 deletions neural_compressor/adaptor/ox_utils/weight_only.py
@@ -56,7 +56,15 @@ def get_blob_size(group_size, has_zp): # pragma: no cover


def make_matmul_weight_only_node(
node, weight_shape, num_bits, group_size, k_blocks, q_weight, scale, zero_point
node,
weight_shape,
num_bits,
group_size,
k_blocks,
q_weight,
scale,
zero_point,
accuracy_level=0,
): # pragma: no cover
"""Build MatMulFpQ4 node.
@@ -69,6 +77,9 @@ def make_matmul_weight_only_node(
q_weight (array): quantized weight
scale (array): scale
zero_point (array): zero point
accuracy_level (int): accuracy level. Supported values: 0 (unset), 1 (fp32 compute type of jblas kernel),
2 (fp16 compute type of jblas kernel), 3 (bf16 compute type of jblas kernel),
4 (int8 compute type of jblas kernel)
Returns:
matmul_weight_only_node: MatMulFpQ4 or MatMulNBits node
@@ -125,6 +136,9 @@ def make_matmul_weight_only_node(
kwargs["N"] = weight_shape[1]
kwargs["bits"] = num_bits
kwargs["block_size"] = group_size
if accuracy_level > 0:
# requires onnxruntime > 1.16.2
kwargs["accuracy_level"] = accuracy_level

else:
offset = 5 if zero_point is not None else 4
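For illustration, a sketch of the MatMulNBits node that make_matmul_weight_only_node emits when accuracy_level > 0. Tensor names and the K/N values are placeholders, and the packed q_weight/scale/zero_point initializers that the helper also registers are omitted here.

import onnx

# Illustrative only: attribute layout of the emitted MatMulNBits contrib op.
# Input names and K/N values are placeholders, not taken from the diff.
matmul_nbits = onnx.helper.make_node(
    "MatMulNBits",
    inputs=["A", "B_quantized", "B_scales", "B_zero_points"],
    outputs=["Y"],
    name="layer0_MatMul_Q4",
    domain="com.microsoft",
    K=4096,            # weight_shape[0]
    N=4096,            # weight_shape[1]
    bits=4,            # num_bits
    block_size=32,     # group_size
    accuracy_level=4,  # only set when > 0; requires onnxruntime > 1.16.2
)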
@@ -274,6 +288,7 @@ def rtn_quantize(
group_size=32,
scheme="asym",
ratios={},
accuracy_level=0,
):
"""Quant the model with round to nearst method.
Expand All @@ -294,6 +309,9 @@ def rtn_quantize(
group_size (int, optional): how many elements share one scale/zp. Default is 32.
scheme (str, optional): sym or asym. Defaults to "asym".
ratios (dict, optional): percentile of clip. Defaults to {}.
accuracy_level (int): accuracy level. Supported values: 0 (unset), 1 (fp32 compute type of jblas kernel),
2 (fp16 compute type of jblas kernel), 3 (bf16 compute type of jblas kernel),
4 (int8 compute type of jblas kernel)
Returns:
model: fake quantized ONNXModel
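A direct call to rtn_quantize with the new keyword might look like the sketch below; the model path, node name, and weight_config values are illustrative assumptions.

from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize
from neural_compressor.model.onnx_model import ONNXModel

# Illustrative call; node name and config values are placeholders.
model = ONNXModel("model.onnx")
q_model = rtn_quantize(
    model,
    weight_config={"/decoder/layer.0/MatMul": {"bits": 4, "group_size": 32, "scheme": "asym"}},
    accuracy_level=4,  # request the int8 jblas compute kernel
)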
@@ -344,6 +362,7 @@ def rtn_quantize(
q_weight=q_weight.astype("uint8"),
scale=scale,
zero_point=zp if scheme == "asym" else None,
accuracy_level=accuracy_level,
)

model.add_initializers(new_inits)
@@ -664,6 +683,7 @@ def awq_quantize(
n_samples=128,
enable_auto_scale=True,
enable_mse_search=True,
accuracy_level=0,
):
"""Quant the model with Activation-aware Weight quantization(AWQ) method.
@@ -687,6 +707,9 @@ def awq_quantize(
n_samples (int, optional): calibration sample number.
enable_auto_scale (bool, optional): whether to enable scaling for salient weights. Defaults to True.
enable_mse_search (bool, optional): whether to enable weight clipping by checking MSE. Defaults to True.
accuracy_level (int): accuracy level. Supported values: 0 (unset), 1 (fp32 compute type of jblas kernel),
2 (fp16 compute type of jblas kernel), 3 (bf16 compute type of jblas kernel),
4 (int8 compute type of jblas kernel)
Returns:
model: fake quantized ONNXModel
@@ -773,7 +796,7 @@ def awq_quantize(

model.remove_tensors_from_outputs(output_names)
model.model.graph.output.MergeFrom(org_output)
model = rtn_quantize(model, weight_config, num_bits, group_size, scheme, full_ratio)
model = rtn_quantize(model, weight_config, num_bits, group_size, scheme, full_ratio, accuracy_level)
return model


@@ -934,6 +957,7 @@ def gptq_quantize(
actorder=False,
mse=False,
perchannel=True,
accuracy_level=0,
):
"""Quant the model with GPTQ method.
@@ -960,6 +984,9 @@ def gptq_quantize(
actorder (bool, optional): whether to rearrange the Hessian matrix based on its diagonal values.
mse (bool, optional): whether to compute scale and zero point with MSE error.
perchannel (bool, optional): whether to quantize weights per-channel.
accuracy_level (int): accuracy level. Supported values: 0 (unset), 1 (fp32 compute type of jblas kernel),
2 (fp16 compute type of jblas kernel), 3 (bf16 compute type of jblas kernel),
4 (int8 compute type of jblas kernel)
Returns:
model: fake quantized ONNXModel
@@ -1076,6 +1103,7 @@ def gptq_quantize(
q_weight=q_weight.astype("uint8"),
scale=scale,
zero_point=zp if scheme == "asym" else None,
accuracy_level=accuracy_level,
)

model.add_initializers(new_inits)
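To confirm that the new attribute survives into the exported graph, a small inspection sketch (the model path is a placeholder):

import onnx

# List every MatMulNBits node and the accuracy_level it carries, if any.
m = onnx.load("quantized_model.onnx")
for node in m.graph.node:
    if node.op_type == "MatMulNBits":
        attrs = {a.name: onnx.helper.get_attribute_value(a) for a in node.attribute}
        print(node.name, attrs.get("accuracy_level"))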
