From 753783011d2f0fc71abe331e589aff94f63aeac3 Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Thu, 12 Oct 2023 17:02:57 +0800 Subject: [PATCH] Support MatMulFpQ4 for onnxruntime 1.16.0 (#1293) Signed-off-by: Mengni Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- neural_compressor/adaptor/onnxrt.py | 2 +- .../adaptor/ox_utils/weight_only.py | 115 +++++++----------- 2 files changed, 43 insertions(+), 74 deletions(-) diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 583b744dab8..59fb352d4d5 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -1694,7 +1694,7 @@ def _dump_model_op_stats(self, model, tune_cfg): dtype_set = set() for node in model.nodes(): - if node.op_type == "MatMulWithQuantWeight": + if node.op_type == "MatMulFpQ4": optype = "MatMul" else: optype = node.op_type diff --git a/neural_compressor/adaptor/ox_utils/weight_only.py b/neural_compressor/adaptor/ox_utils/weight_only.py index 5f7683341a7..a3d0f05e940 100644 --- a/neural_compressor/adaptor/ox_utils/weight_only.py +++ b/neural_compressor/adaptor/ox_utils/weight_only.py @@ -20,12 +20,14 @@ import logging import math import os +import struct import sys import numpy as np import onnx from onnx import helper, numpy_helper from onnx import onnx_pb as onnx_proto +from packaging.version import Version from neural_compressor.model.model import BaseModel from neural_compressor.model.onnx_model import ONNXModel @@ -33,50 +35,13 @@ ort = LazyImport("onnxruntime") logger = logging.getLogger("neural_compressor") - - -WEIGHT_ONLY_OP_SUPPORTED = False - - -def check_op_support_status(): - """Check whether weight-only op is supported.""" - input_tensor = helper.make_tensor_value_info("input", 1, [1, 32]) - output_tensor = helper.make_tensor_value_info("output", 1, [1, 64]) - initializers = [] - # weight shape (32, 64) - packed_weight = np.random.randint(0, high=16, size=(64, 1, 16), dtype="uint8") - initializers.append(onnx.helper.make_tensor("weight", 2, packed_weight.shape, packed_weight.flatten().tolist())) - scale = np.random.random((64, 1)).astype("float32") - initializers.append(onnx.helper.make_tensor("scale", 1, scale.shape, scale.flatten().tolist())) - - kwargs = {} - kwargs["K"] = 32 - kwargs["N"] = 64 - kwargs["bits"] = 4 - kwargs["block_size"] = 32 - node = onnx.helper.make_node( - "MatMulWithQuantWeight", - inputs=["input", "weight", "scale"], - outputs=["output"], - name="test", - domain="com.microsoft", - **kwargs, - ) - - global WEIGHT_ONLY_OP_SUPPORTED - graph = helper.make_graph([node], "test", [input_tensor], [output_tensor], initializer=initializers) - model = helper.make_model(graph) - try: - ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"]) - WEIGHT_ONLY_OP_SUPPORTED = True - except: - WEIGHT_ONLY_OP_SUPPORTED = False +ONNXRT116_VERSION = Version("1.16.0") def make_matmul_weight_only_node( node, weight_shape, num_bits, group_size, k_blocks, q_weight, scale, zero_point ): # pragma: no cover - """Build MatMulWithQuantWeight node. + """Build MatMulFpQ4 node. Args: node: original matmul node @@ -89,18 +54,32 @@ def make_matmul_weight_only_node( zero_point (array): zero point Returns: - matmul_weight_only_node: MatMulWithQuantWeight node - new_inits: initializers of the MatMulWithQuantWeight node + matmul_weight_only_node: MatMulFpQ4 node + new_inits: initializers of the MatMulFpQ4 node """ - blob_size = group_size // 2 + if zero_point is not None: + blob_size = group_size // 2 + 4 + 1 + offset = 5 + else: + blob_size = group_size // 2 + 4 + offset = 4 + packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8") for i in range(q_weight.shape[0]): - for k in range(0, group_size, 2): - packed[i][k // 2] = q_weight[i][k] | q_weight[i][k + 1] << 4 + bf = struct.pack("f", scale[i]) + packed[i][0] = bf[0] + packed[i][1] = bf[1] + packed[i][2] = bf[2] + packed[i][3] = bf[3] - packed = np.reshape(packed, (-1, k_blocks, blob_size)) - scale = np.reshape(scale, (-1, k_blocks)).astype("float32") + if zero_point is not None: + packed[i][4] = zero_point[i] + + packed[i][offset:] = np.bitwise_or( + q_weight[i][: group_size // 2], np.left_shift(q_weight[i][group_size // 2 :], num_bits) + ) + packed = packed.reshape(-1) q_weight_tensor = onnx.helper.make_tensor( name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)), data_type=2, @@ -108,27 +87,16 @@ def make_matmul_weight_only_node( vals=packed.tobytes(), raw=True, ) - scale_tensor = onnx.helper.make_tensor( - name=node.input[1] + "_scale", data_type=1, dims=scale.shape, vals=scale.tobytes(), raw=True + shape_tensor = onnx.helper.make_tensor( + name=node.input[1] + "_shape", data_type=7, dims=(2,), vals=np.array(weight_shape, dtype="int64") ) - input_names = [node.input[0], q_weight_tensor.name, scale_tensor.name] - new_inits = [q_weight_tensor, scale_tensor] - - if zero_point is not None: - zero_point = np.reshape(zero_point, (-1, k_blocks)).astype("uint8") - zp_tensor = onnx.helper.make_tensor( - name=node.input[1] + "_zp", data_type=2, dims=zero_point.shape, vals=zero_point.tobytes(), raw=True - ) - input_names.append(zp_tensor.name) - new_inits.append(zp_tensor) + input_names = [node.input[0], q_weight_tensor.name, shape_tensor.name] + new_inits = [q_weight_tensor, shape_tensor] kwargs = {} - kwargs["K"] = weight_shape[0] - kwargs["N"] = weight_shape[1] - kwargs["bits"] = num_bits - kwargs["block_size"] = group_size + kwargs["blk_quant_type"] = 1 if zero_point is not None else 0 matmul_weight_only_node = onnx.helper.make_node( - "MatMulWithQuantWeight", + "MatMulFpQ4", inputs=input_names, outputs=node.output, name=node.name + "_Q" + str(num_bits) if node.name else "_Q" + str(num_bits), @@ -260,7 +228,6 @@ def rtn_quantize( Returns: model: fake quantized ONNXModel """ - check_op_support_status() model = model if isinstance(model, BaseModel) else ONNXModel(model) new_nodes = [] remove_nodes = [] @@ -290,8 +257,8 @@ def rtn_quantize( weight = pad_tensor(weight, group_size, k_blocks) - if WEIGHT_ONLY_OP_SUPPORTED and num_bits == 4 and group_size == 32: # pragma: no cover - # currently MatMulWithQuantWeights only support 4 bits and 32 group_size + if Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32: # pragma: no cover + # currently MatMulFpQ4 only support 4 bits and 32 group_size q_weight, scale, zp = quant_tensor( weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1) ) @@ -394,7 +361,9 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, weight = weight.T * scales weight = pad_tensor(weight, group_size, (org_w_shape[0] + group_size - 1) // group_size).T - if WEIGHT_ONLY_OP_SUPPORTED and num_bits == 4 and group_size == 32: # pragma: no cover + if ( + Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32 + ): # pragma: no cover q_weight = qdq_tensor(weight, num_bits, group_size, scheme, "uint") / np.expand_dims( scales, axis=-1 ) @@ -535,8 +504,10 @@ def apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits, g for i_s in range(10): ratio = 1 - i_s / 100 weight = copy.deepcopy(org_weight) - if WEIGHT_ONLY_OP_SUPPORTED and num_bits == 4 and group_size == 32: # pragma: no cover - # currently MatMulWithQuantWeights only support 4 bits and 32 group_size + if ( + Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32 + ): # pragma: no cover + # currently MatMulFpQ4 only support 4 bits and 32 group_size weight = qdq_tensor(weight, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)) else: weight = qdq_tensor(weight, num_bits, group_size, scheme, "int", ratios.get(node.input[1], 1)) @@ -644,7 +615,6 @@ def awq_quantize( Returns: model: fake quantized ONNXModel """ - check_op_support_status() model = model if isinstance(model, BaseModel) else ONNXModel(model) output_dicts = {} full_ratio = {} @@ -918,7 +888,6 @@ def gptq_quantize( Returns: model: fake quantized ONNXModel """ - check_op_support_status() model = model if isinstance(model, BaseModel) else ONNXModel(model) output_dicts = {} @@ -1013,8 +982,8 @@ def gptq_quantize( weight_tensor = model.get_initializer(node.input[1]) init_share_num = model.get_initializer_share_num(node.input[1]) - if WEIGHT_ONLY_OP_SUPPORTED and num_bits == 4 and group_size == 32: # pragma: no cover - # currently MatMulWithQuantWeights only support 4 bits and 32 group_size + if Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32: # pragma: no cover + # currently MatMulFpQ4 only support 4 bits and 32 group_size org_shape = weight.shape k_blocks = (org_shape[0] + group_size - 1) // group_size q_weight = pad_tensor(q_weight, group_size, k_blocks)