From d7d08f7b2c8ac7e81cbb02648cbda6b1fafe906b Mon Sep 17 00:00:00 2001
From: Jing Fang
Date: Tue, 28 May 2024 19:25:50 -0700
Subject: [PATCH 01/36] adding python interface

---
 .../quantization/matmul_4bits_quantizer.py | 79 ++++++++++++++++---
 1 file changed, 69 insertions(+), 10 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
index 11a830dc6d7f5..011ba29f61c5e 100644
--- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
+++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -22,27 +22,32 @@
 from .calibrate import CalibrationDataReader
 from .onnx_model import ONNXModel
-from .quant_utils import attribute_to_kwarg
+from .quant_utils import QuantFormat, attribute_to_kwarg
 
 logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
 class WeightOnlyQuantConfig:
-    def __init__(self, algorithm):
+    def __init__(self, algorithm, quant_format):
        """This is the Base class for Weight Only Quant Configuration.

        Args:
            algorithm:
                weight only quantize algorithm name.
+           quant_format: QuantFormat{QOperator, QDQ}.
+               QOperator format quantizes the model with quantized operators directly.
+               QDQ format quantizes the model by inserting QuantizeLinear/DequantizeLinear on the tensor.
        """
        self.algorithm = algorithm
+       self.quant_format = quant_format


 class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig):
     def __init__(
         self,
         ratios=None,
+        quant_format=QuantFormat.QOperator,
     ):
         """
         This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration.
@@ -51,11 +56,16 @@
         Args:
             ratios:
                 percentile of clip. Defaults to {}.
+            quant_format (QuantFormat{QOperator, QDQ}, optional):
+                QOperator format quantizes the model with quantized operators directly.
+                QDQ format quantizes the model by inserting QuantizeLinear/DequantizeLinear on the tensor.
+                Defaults to QuantFormat.QOperator.
         """
         if ratios is None:
             ratios = {}
         super().__init__(
             algorithm="RTN",
+            quant_format=quant_format,
         )
         self.ratios = ratios

@@ -69,6 +79,7 @@
         actorder=False,
         mse=False,
         perchannel=True,
+        quant_format=QuantFormat.QOperator,
     ):
         """
         This is a class for GPTQ algorithm Weight Only Quant Configuration.
@@ -87,9 +98,14 @@
                 whether get scale and zero point with mse error.
             perchannel (bool, optional):
                 whether quantize weight per-channel.
+            quant_format (QuantFormat{QOperator, QDQ}, optional):
+                QOperator format quantizes the model with quantized operators directly.
+                QDQ format quantizes the model by inserting QuantizeLinear/DequantizeLinear on the tensor.
+                Defaults to QuantFormat.QOperator.
         """
         super().__init__(
             algorithm="GPTQ",
+            quant_format=quant_format,
         )
         self.calibration_data_reader = calibration_data_reader
         self.percdamp = percdamp
@@ -105,6 +121,7 @@
         block_size=128,
         bits=4,
         axis=1,
+        quant_format=QuantFormat.QOperator,
     ):
         """
         This is a class for HQQ algorithm Weight Only Quant Configuration.

         Args:
             block_size (int, optional):
-                channel number in one block to execute a GPTQ quantization iteration.
+                channel number in one block to execute a HQQ quantization iteration.
             bits (int, optional):
                 how many bits to represent weight.
             axis (int, optional):
                 0 or 1. which axis to quantize.
                 https://arxiv.org/pdf/2309.15531.pdf
+            quant_format (QuantFormat{QOperator, QDQ}, optional):
+                QOperator format quantizes the model with quantized operators directly.
+                QDQ format quantizes the model by inserting QuantizeLinear/DequantizeLinear on the tensor.
+                Defaults to QuantFormat.QOperator.
         """
         super().__init__(
             algorithm="HQQ",
+            quant_format=quant_format,
         )
         self.block_size = block_size
         self.bits = bits
@@ -132,8 +154,26 @@
         block_size: int = 128,
         is_symmetric: bool = False,
         accuracy_level: int | None = None,
+        quant_format=QuantFormat.QOperator,
     ):
-        super().__init__(algorithm="DEFAULT")
+        """
+        This is a class for weight only affine quantization configuration.
+
+        Args:
+            block_size (int, optional):
+                channel number in one block to execute an affine quantization iteration.
+            is_symmetric (bool, optional):
+                whether quantize weight symmetrically.
+            accuracy_level (int, optional):
+                Accuracy level of the 4-bit quantized MatMul computation.
+                Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details.
+                (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits)
+            quant_format (QuantFormat{QOperator, QDQ}, optional):
+                QOperator format quantizes the model with quantized operators directly.
+                QDQ format quantizes the model by inserting QuantizeLinear/DequantizeLinear on the tensor.
+                Defaults to QuantFormat.QOperator.
+        """
+        super().__init__(algorithm="DEFAULT", quant_format=quant_format)
         self.block_size = block_size
         self.is_symmetric = is_symmetric
         self.bits = 4
@@ -288,7 +328,7 @@ def quantize_internal(
         return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype)
 
     def quantize(self, node: NodeProto, graph_stack: list[GraphProto]):
-        """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""
+        """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node."""
         if node.op_type != "MatMul":
             return node  # only care about MatMul for now
         import torch
@@ -466,7 +506,13 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto:
 
 
 class MatMul4BitsQuantizer:
-    """Perform 4b quantization of constant MatMul weights"""
+    """
+    Perform 4b quantization of constant MatMul weights.
+    If algo_config.quant_format is QOperator, the quantized weight is stored in a MatMulNBits node, which replaces the
+    MatMul node.
+    If algo_config.quant_format is QDQ, the quantized weight is stored in a DequantizeLinear node. The MatMul node is
+    replaced by the DequantizeLinear + MatMul nodes.
+    """
 
     def __init__(
         self,
@@ -688,6 +734,15 @@ def parse_args():
         default=[],
         help="Specify the nodes to be excluded from quantization with node names",
     )
+    parser.add_argument(
+        "--quant_format",
+        default="QOperator",
+        type=QuantFormat,
+        choices=list(QuantFormat),
+        help="QuantFormat {QOperator, QDQ}. "
+        "QOperator format quantizes the model with quantized operators directly. "
+        "QDQ format quantizes the model by inserting DequantizeLinear before the MatMul.",
+    )
 
     return parser.parse_args()
 
@@ -699,6 +754,7 @@ def parse_args():
 
     input_model_path = args.input_model
     output_model_path = args.output_model
+    quant_format = args.quant_format
 
     if os.path.exists(output_model_path):
         logger.error(f"file {output_model_path} already exists")
@@ -710,15 +766,18 @@ def parse_args():
     model = onnx.load(input_model_path)
     if args.quant_method == "hqq":
-        quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits)
+        quant_config = HQQWeightOnlyQuantConfig(
+            block_size=args.block_size, bits=args.bits, quant_format=quant_format
+        )
     elif args.quant_method == "default":
         quant_config = DefaultWeightOnlyQuantConfig(
-            block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level
+            block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level,
+            quant_format=quant_format
         )
     elif args.quant_method == "rtn":
-        quant_config = RTNWeightOnlyQuantConfig()
+        quant_config = RTNWeightOnlyQuantConfig(quant_format=quant_format)
     elif args.quant_method == "gptq":
-        quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size)
+        quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size, quant_format=quant_format)
     else:
         raise ValueError(f"Unsupported quantization method: {args.quant_method}")
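For reference, the interface added in this patch can be driven as below once the QDQ path is filled in by the later patches in the series — a minimal sketch assuming an onnxruntime build that includes these changes (the model path and block size are illustrative):

```python
import onnx

from onnxruntime.quantization import QuantFormat
from onnxruntime.quantization.matmul_4bits_quantizer import (
    DefaultWeightOnlyQuantConfig,
    MatMul4BitsQuantizer,
)

model = onnx.load("model.onnx")
# quant_format=QuantFormat.QDQ stores the int4 weight in a DequantizeLinear
# initializer feeding a plain MatMul; QuantFormat.QOperator (the default)
# emits a com.microsoft.MatMulNBits node instead.
config = DefaultWeightOnlyQuantConfig(
    block_size=128,
    is_symmetric=False,
    quant_format=QuantFormat.QDQ,
)
quantizer = MatMul4BitsQuantizer(model, algo_config=config)
quantizer.process()
quantizer.model.save_model_to_file("model_int4_qdq.onnx", use_external_data_format=False)
```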
+ "QDQ format quantize the model by inserting DeQuantizeLinear before the MatMul.", + ) return parser.parse_args() @@ -699,6 +754,7 @@ def parse_args(): input_model_path = args.input_model output_model_path = args.output_model + quant_format = args.quant_format if os.path.exists(output_model_path): logger.error(f"file {output_model_path} already exists") @@ -710,15 +766,18 @@ def parse_args(): model = onnx.load(input_model_path) if args.quant_method == "hqq": - quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits) + quant_config = HQQWeightOnlyQuantConfig( + block_size=args.block_size, bits=args.bits, quant_format=quant_format + ) elif args.quant_method == "default": quant_config = DefaultWeightOnlyQuantConfig( - block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level + block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level, + quant_format=quant_format ) elif args.quant_method == "rtn": - quant_config = RTNWeightOnlyQuantConfig() + quant_config = RTNWeightOnlyQuantConfig(quant_format=quant_format) elif args.quant_method == "gptq": - quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size) + quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size, quant_format=quant_format) else: raise ValueError(f"Unsupported quantization method: {args.quant_method}") From 1c687d1989308dd8d0707b8e0ee696c5a89b1d64 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Mon, 17 Jun 2024 13:48:18 -0700 Subject: [PATCH 02/36] restarting --- .../python/tools/quantization/matmul_4bits_quantizer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 011ba29f61c5e..3f5f3981fd964 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -328,7 +328,10 @@ def quantize_internal( return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype) def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): - """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node.""" + """ + If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node. + If QOperator format, return MatMulNbits. If QDQ format, return DeQuantizeLinear + MatMul. 
+ """ if node.op_type != "MatMul": return node # only care about MatMul for now import torch From 7612401cedf32bb484d572342085cab896bfe125 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Mon, 17 Jun 2024 17:15:45 -0700 Subject: [PATCH 03/36] finished default quantizer --- .../quantization/matmul_4bits_quantizer.py | 161 +++++++++++------- 1 file changed, 103 insertions(+), 58 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 3f5f3981fd964..17785aa1dec65 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -18,7 +18,7 @@ from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto from packaging import version -from onnxruntime.capi._pybind_state import quantize_matmul_4bits +from onnxruntime.capi._pybind_state import quantize_matmul_4bits, quantize_qdq_matmul_4bits from .calibrate import CalibrationDataReader from .onnx_model import ONNXModel @@ -327,13 +327,13 @@ def quantize_internal( return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype) - def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]: """ If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node. If QOperator format, return MatMulNbits. If QDQ format, return DeQuantizeLinear + MatMul. """ if node.op_type != "MatMul": - return node # only care about MatMul for now + return [node] # only care about MatMul for now import torch logger.info(f"start to quantize {node.name} ...") @@ -341,12 +341,12 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): b_pb, bs_graph = get_initializer(inputB, graph_stack) if b_pb is None: logger.info("MatMul doesn't have const weight. Skip to quantize") - return node # only care about constant weight + return [node] # only care about constant weight b_array = onnx.numpy_helper.to_array(b_pb) if len(b_array.shape) != 2: logger.info("MatMul weight is not 2D. 
Skip to quantize") - return node # can only process 2-D matrix + return [node] # can only process 2-D matrix b_array_torch = torch.from_numpy(b_array) if torch.cuda.is_available(): b_array_torch = b_array_torch.cuda() @@ -409,7 +409,7 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): logger.info(f"complete quantization of {node.name} ...") - return matmul_q4_node + return [matmul_q4_node] def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: @@ -425,7 +425,7 @@ class DefaultWeightOnlyQuantizer: def __init__(self, config: DefaultWeightOnlyQuantConfig): self.config = config - def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: + def int4_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """4b quantize fp32 weight to a blob""" if len(fp32weight.shape) != 2: @@ -433,42 +433,58 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: rows, cols = fp32weight.shape block_size = self.config.block_size - blob_size = block_size // 2 k_blocks = (rows + block_size - 1) // block_size - padded_rows = k_blocks * block_size - pad_len = padded_rows - rows - if pad_len > 0: - fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant") - - # block wise quantization, each block comes from a single column - packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) - zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") - quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric) + + if self.config.quant_format == QuantFormat.QOperator: + blob_size = block_size // 2 + padded_rows = k_blocks * block_size + pad_len = padded_rows - rows + if pad_len > 0: + fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant") + + # block wise quantization, each block comes from a single column + packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") + zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") + quantize_matmul_4bits( + packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric + ) + else: + packed = np.zeros((rows * cols + 1) // 2, dtype="uint8") + zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8") + quantize_qdq_matmul_4bits( + packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric + ) return (packed, scales, zero_point) - def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]: """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" if node.op_type != "MatMul": - return node # only care about MatMul for now + return [node] # only care about MatMul for now logger.info(f"start to quantize {node.name} ...") inputB = node.input[1] # noqa: N806 B, Bs_graph = get_initializer(inputB, graph_stack) # noqa: N806 if B is None: logger.info("MatMul doesn't have const weight. Skip to quantize") - return node # only care about constant weight + return [node] # only care about constant weight B_array = onnx.numpy_helper.to_array(B) # noqa: N806 if len(B_array.shape) != 2: logger.info("MatMul weight is not 2D. 
Skip to quantize") - return node # can only process 2-D matrix + return [node] # can only process 2-D matrix packed, scales, zero_points = self.int4_block_quant(B_array) - B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 - B_quant.name = B.name + "_Q4" + + if self.config.quant_format == QuantFormat.QOperator: + B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 + B_quant.name = B.name + "_Q4" + else: + # QDQ default UINT4 + B_quant = onnx.helper.make_tensor(B.name + "_DQ_Q4", TensorProto.UINT4, B_array.shape, packed, True) + for input in Bs_graph.input: if input.name == inputB: Bs_graph.input.remove(input) @@ -478,34 +494,63 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: scales_tensor.name = B.name + "_scales" Bs_graph.initializer.extend([B_quant, scales_tensor]) - input_names = [node.input[0], B_quant.name, scales_tensor.name] - if not self.config.is_symmetric: - zp_tensor = onnx.numpy_helper.from_array(zero_points) - zp_tensor.name = B.name + "_zero_points" - Bs_graph.initializer.extend([zp_tensor]) - input_names.append(zp_tensor.name) - - kwargs = {} - rows, cols = B_array.shape - kwargs["K"] = rows - kwargs["N"] = cols - kwargs["bits"] = 4 - kwargs["block_size"] = self.config.block_size - if self.config.accuracy_level is not None: - kwargs["accuracy_level"] = self.config.accuracy_level + output_nodes = [] + + if self.config.quant_format == QuantFormat.QOperator: + input_names = [node.input[0], B_quant.name, scales_tensor.name] + if not self.config.is_symmetric: + zp_tensor = onnx.numpy_helper.from_array(zero_points) + zp_tensor.name = B.name + "_zero_points" + input_names.append(zp_tensor.name) + Bs_graph.initializer.extend([zp_tensor]) + kwargs = {} + rows, cols = B_array.shape + kwargs["K"] = rows + kwargs["N"] = cols + kwargs["bits"] = 4 + kwargs["block_size"] = self.config.block_size + if self.config.accuracy_level is not None: + kwargs["accuracy_level"] = self.config.accuracy_level + + matmul_q4_node = onnx.helper.make_node( + "MatMulNBits", + inputs=input_names, + outputs=[node.output[0]], + name=node.name + "_Q4" if node.name else "", + domain="com.microsoft", + **kwargs, + ) - matmul_q4_node = onnx.helper.make_node( - "MatMulNBits", - inputs=input_names, - outputs=[node.output[0]], - name=node.name + "_Q4" if node.name else "", - domain="com.microsoft", - **kwargs, - ) + output_nodes.append(matmul_q4_node) + else: + dq_input_names = [B_quant.name, scales_tensor.name] + dq_output_names = [B_quant.name + "_output"] + matmul_input_names = [node.input[0], dq_output_names[0]] + matmul_output_names = [node.output[0]] + if not self.config.is_symmetric: + zp_tensor = onnx.helper.make_tensor( + B.name + "_DQ_zero_points", TensorProto.UINT4, scales.shape, zero_points, True + ) + dq_input_names.append(zp_tensor.name) + Bs_graph.initializer.extend([zp_tensor]) + dq_kwargs = {"axis": 0, "block_size": self.config.block_size} + dq_node = onnx.helper.make_node( + "DequantizeLinear", + inputs=dq_input_names, + outputs=dq_output_names, + name=node.name + "_DQ_Q4" if node.name else "", + **dq_kwargs, + ) + matmul_node = onnx.helper.make_node( + "MatMul", + inputs=matmul_input_names, + outputs=matmul_output_names, + name=node.name + "_matmul_Q4" if node.name else "", + ) + output_nodes.extend([dq_node, matmul_node]) logger.info(f"complete quantization of {node.name} ...") - - return matmul_q4_node + return output_nodes class MatMul4BitsQuantizer: @@ -575,15 +620,15 @@ def _process_subgraph(self, graph_stack: list[GraphProto]): node = 
From c5b6175625d27984421ff4f823f575263094cb86 Mon Sep 17 00:00:00 2001
From: Jing Fang
Date: Tue, 18 Jun 2024 11:04:34 -0700
Subject: [PATCH 04/36] only enable quant_format in default quantizer

---
 .../quantization/matmul_4bits_quantizer.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
index 17785aa1dec65..78144158f5fa6 100644
--- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
+++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -61,6 +61,8 @@
                 QDQ format quantizes the model by inserting QuantizeLinear/DequantizeLinear on the tensor.
                 Defaults to QuantFormat.QOperator.
         """
+        assert quant_format == QuantFormat.QOperator, "RTN only supports QOperator format"
+
         if ratios is None:
             ratios = {}
         super().__init__(
             algorithm="RTN",
@@ -103,6 +105,8 @@
                 QDQ format quantizes the model by inserting QuantizeLinear/DequantizeLinear on the tensor.
                 Defaults to QuantFormat.QOperator.
         """
+        assert quant_format == QuantFormat.QOperator, "GPTQ only supports QOperator format"
+
         super().__init__(
             algorithm="GPTQ",
             quant_format=quant_format,
@@ -139,6 +143,8 @@
                 QDQ format quantizes the model by inserting QuantizeLinear/DequantizeLinear on the tensor.
                 Defaults to QuantFormat.QOperator.
         """
+        assert quant_format == QuantFormat.QOperator, "HQQ only supports QOperator format"
+
         super().__init__(
             algorithm="HQQ",
             quant_format=quant_format,
@@ -459,7 +465,10 @@
     def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
-        """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""
+        """
+        If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node.
+        If QOperator format, return MatMulNBits. If QDQ format, return DequantizeLinear + MatMul.
+        """
 
         if node.op_type != "MatMul":
             return [node]  # only care about MatMul for now
@@ -814,7 +823,7 @@
     if args.quant_method == "hqq":
-        quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits, quant_format=quant_format)
+        quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits)
     elif args.quant_method == "default":
         quant_config = DefaultWeightOnlyQuantConfig(
             block_size=args.block_size,
             is_symmetric=args.symmetric,
             accuracy_level=args.accuracy_level,
             quant_format=quant_format,
         )
     elif args.quant_method == "rtn":
-        quant_config = RTNWeightOnlyQuantConfig(quant_format=quant_format)
+        quant_config = RTNWeightOnlyQuantConfig()
     elif args.quant_method == "gptq":
-        quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size, quant_format=quant_format)
+        quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size)
     else:
         raise ValueError(f"Unsupported quantization method: {args.quant_method}")

From d5d9e618470df5f245f3f3806cb11e70ada1ff39 Mon Sep 17 00:00:00 2001
From: Jing Fang
Date: Tue, 18 Jun 2024 17:20:33 -0700
Subject: [PATCH 05/36] added UT and fixed qtype in DQ

---
 .../quantization/matmul_4bits_quantizer.py    |  8 ++--
 .../quantization/test_op_matmul_4bits.py      | 42 ++++++++++++++++---
 2 files changed, 40 insertions(+), 10 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
index 78144158f5fa6..572cf82152db8 100644
--- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
+++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -474,6 +474,7 @@
             return [node]  # only care about MatMul for now
 
         logger.info(f"start to quantize {node.name} ...")
+        qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4
         inputB = node.input[1]  # noqa: N806
         B, Bs_graph = get_initializer(inputB, graph_stack)  # noqa: N806
         if B is None:
@@ -491,8 +492,7 @@
             B_quant = onnx.numpy_helper.from_array(packed)  # noqa: N806
             B_quant.name = B.name + "_Q4"
         else:
-            # QDQ default UINT4
-            B_quant = onnx.helper.make_tensor(B.name + "_DQ_Q4", TensorProto.UINT4, B_array.shape, packed, True)
+            B_quant = onnx.helper.make_tensor(B.name + "_DQ_Q4", qtype, B_array.shape, packed, True)
 
         for input in Bs_graph.input:
             if input.name == inputB:
@@ -537,9 +537,7 @@
             matmul_input_names = [node.input[0], dq_output_names[0]]
             matmul_output_names = [node.output[0]]
             if not self.config.is_symmetric:
-                zp_tensor = onnx.helper.make_tensor(
-                    B.name + "_DQ_zero_points", TensorProto.UINT4, scales.shape, zero_points, True
-                )
+                zp_tensor = onnx.helper.make_tensor(B.name + "_DQ_zero_points", qtype, scales.shape, zero_points, True)
                 dq_input_names.append(zp_tensor.name)
                 Bs_graph.initializer.extend([zp_tensor])
             dq_kwargs = {"axis": 0, "block_size": self.config.block_size}
diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
index 88e5052db4e2e..1adbff348dcc0 100644
--- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
+++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
@@ -14,11 +14,10 @@
 import numpy as np
 import onnx
 from onnx import TensorProto, helper
-from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count
+from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type
 
 from onnxruntime.quantization import quant_utils
 
-
 class TestOpMatMul4Bits(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -116,9 +115,12 @@ def quant_test(
         data_reader: TestDataFeeds,
         block_size: int,
         is_symmetric: bool,
+        quant_format: quant_utils.QuantFormat = quant_utils.QuantFormat.QOperator
     ):
+        use_qdq = quant_format == quant_utils.QuantFormat.QDQ
+        name_prefix = "DQ_MatMul" if use_qdq else "MatMulNBits"
         model_int4_path = str(
-            Path(self._tmp_model_dir.name).joinpath(f"MatMulNBits_{block_size}_{is_symmetric}.onnx").absolute()
+            Path(self._tmp_model_dir.name).joinpath(f"{name_prefix}_{block_size}_{is_symmetric}.onnx").absolute()
         )
 
         # Quantize fp32 model to int4 model
@@ -126,15 +128,25 @@
         model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
         quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
-            block_size=block_size, is_symmetric=is_symmetric
+            block_size=block_size, is_symmetric=is_symmetric, quant_format=quant_format
         )
         quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config)
         quant.process()
         quant.model.save_model_to_file(model_int4_path, False)
 
-        quant_nodes = {"MatMulNBits": 1}
+        quant_nodes = {"DequantizeLinear": 1, "MatMul": 1} if use_qdq else {"MatMulNBits": 1}
         check_op_type_count(self, model_int4_path, **quant_nodes)
 
+        if use_qdq:
+            dq_qtype = TensorProto.INT4 if is_symmetric else TensorProto.UINT4
+            dqnode_io_qtypes = {
+                "DequantizeLinear": [
+                    ["i", 0, dq_qtype],
+                    ["i", 2, dq_qtype],
+                ]
+            }
+            check_qtype_by_node_type(self, model_int4_path, dqnode_io_qtypes)
+
         data_reader.rewind()
 
         try:
@@ -211,6 +223,26 @@ def test_quantize_matmul_int4_offsets(self):
         data_reader = self.input_feeds(1, {"input": [100, 52]})
         self.quant_test(model_fp32_path, data_reader, 32, False)
 
+    @unittest.skipIf(
+        find_spec("onnxruntime.training"), "Skip because training package doesn't have quantize_matmul_4bits"
+    )
+    def test_quantize_matmul_int4_symmetric_qdq(self):
+        np.random.seed(13)
+
+        model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_symmetric.onnx").absolute())
+        self.construct_model_matmul(model_fp32_path, symmetric=True)
+        data_reader = self.input_feeds(1, {"input": [100, 52]})
+        self.quant_test(model_fp32_path, data_reader, 32, True, quant_utils.QuantFormat.QDQ)
+
+    @unittest.skipIf(
+        find_spec("onnxruntime.training"), "Skip because training package doesn't have quantize_matmul_4bits"
+    )
+    def test_quantize_matmul_int4_offsets_qdq(self):
+        model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
+        self.construct_model_matmul(model_fp32_path, symmetric=False)
+        data_reader = self.input_feeds(1, {"input": [100, 52]})
+        self.quant_test(model_fp32_path, data_reader, 32, False, quant_utils.QuantFormat.QDQ)
+
     @unittest.skipIf(
         find_spec("onnxruntime.training"), "Skip because training package doesn't have quantize_matmul_4bits"
     )
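The C++ selector introduced by the next patch matches exactly the pattern the Python changes above emit; here is a sketch of that pattern built directly with onnx.helper (shapes and opset are illustrative, and TensorProto.INT4 requires onnx >= 1.16):

```python
import numpy as np
import onnx
from onnx import TensorProto, helper

K, N, block_size = 64, 32, 32
k_blocks = K // block_size

# Blockwise DequantizeLinear along axis 0 feeding input B of a MatMul:
# the node group the DQMatMul selector is written to recognize.
dq = helper.make_node("DequantizeLinear", ["B_q4", "B_scales"], ["B_dq"],
                      axis=0, block_size=block_size)
mm = helper.make_node("MatMul", ["A", "B_dq"], ["Y"])

graph = helper.make_graph(
    [dq, mm], "dq_matmul",
    [helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, K])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, N])],
    initializer=[
        helper.make_tensor("B_q4", TensorProto.INT4, [K, N],
                           np.zeros(K * N // 2, dtype=np.uint8).tobytes(), raw=True),
        helper.make_tensor("B_scales", TensorProto.FLOAT, [k_blocks, N],
                           np.ones(k_blocks * N, dtype=np.float32)),
    ],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)])
onnx.checker.check_model(model)
```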
From 77b7dd9aa6f3cdae092ee9b08b700f5345f0fe83 Mon Sep 17 00:00:00 2001
From: Jing Fang
Date: Thu, 20 Jun 2024 17:43:30 -0700
Subject: [PATCH 06/36] added dq matmul selectors

---
 .../selectors_actions/qdq_selectors.cc | 68 +++++++++++++++++++
 .../selectors_actions/qdq_selectors.h  | 15 ++++
 2 files changed, 83 insertions(+)

diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index 09705f61c82ce..6944de876c84c 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -15,6 +15,12 @@
 namespace onnxruntime {
 namespace QDQ {
 namespace {
+#if defined(_MSC_VER)
+#define FORCEINLINE __forceinline
+#else
+#define FORCEINLINE __attribute__((always_inline)) inline
+#endif
+
 constexpr bool Is16BitIntType(int32_t data_type) {
   return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16) ||
          (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16);
@@ -25,6 +31,19 @@
 constexpr bool Is4BitIntType(int32_t data_type) {
   return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4) ||
          (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4);
 }
 
+FORCEINLINE bool IsPowerOfTwo(int64_t val) {
+  bool seen_one = val & 1;
+  val >>= 1;
+
+  for (; val; seen_one = val & 1, val >>= 1) {
+    if (seen_one) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
 // adjust for an optional input/output that has an entry but does not exist
 int NumActualValues(const Node& node, bool input) {
   const auto& defs = input ? node.InputDefs() : node.OutputDefs();
@@ -414,6 +433,55 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
   }
 }
 
+bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
+                                      const Node& node,
+                                      const std::vector<const Node*>& dq_nodes,
+                                      const std::vector<const Node*>& q_nodes) const {
+  // MatMul has 1 DQ input and is the second input
+  if (dq_nodes.size() != 1) {
+    return false;
+  }
+
+  auto iter = node.InputNodesBegin();
+  ++iter;
+  if (iter == node.InputNodesEnd() || iter->Index() != dq_nodes[0]->Index()) {
+    return false;
+  }
+
+  // DQ weight/zero points types are int4/uint4, scales/output types are float or float16
+  int32_t dt_weight = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+  int32_t dt_scales = dq_nodes[0]->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type();
+
+  if (dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT &&
+      dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) {
+    return false;
+  }
+
+  if (!Is4BitIntType(dt_weight)) {
+    return false;
+  }
+
+  // DQ is blockwise quantized along axis 0, and block_size must be a power of 2 and >= 16
+  const auto dq_attrs = dq_nodes[0]->GetAttributes();
+
+  if (const auto a_iter = dq_attrs.find("axis");
+      a_iter == dq_attrs.end() || a_iter->second.i() != 0) {
+    return false;
+  }
+
+  const auto a_iter = dq_attrs.find("block_size");
+  if (a_iter == dq_attrs.end()) {
+    return false;
+  }
+
+  auto block_size = a_iter->second.i();
+  if (block_size < 16 || !IsPowerOfTwo(block_size)) {
+    return false;
+  }
+
+  return true;
+}
+
 bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer,
                                  const Node& node,
                                  const std::vector<const Node*>& dq_nodes,
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
index 1a2a620acb480..a8a12d5ff1ac7 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
@@ -204,6 +204,14 @@ class MatMulNodeGroupSelector : public NodeGroupSelector {
   bool allow_4bit_;
 };
 
+// Convert "1 DQ node for input B -> MatMul" to "MatMulNBits"
+class DQMatMulNodeGroupSelector : public NodeGroupSelector {
+ private:
+  bool Check(const GraphViewer& graph_viewer, const Node& node,
+             const std::vector<const Node*>& dq_nodes,
+             const std::vector<const Node*>& q_nodes) const override;
+};
+
 // Input: DQ nodes for A, B and optional C
 // Output: optional Q node for Y
 class GemmNodeGroupSelector : public NodeGroupSelector {
@@ -358,6 +366,13 @@ class MatMulSelector : public BaseSelector {
                                                      allow_16bit, allow_4bit)) {}
 };
 
+// Convert "1 DQ node for input B -> MatMul" to "MatMulNBits"
+class DQMatMulSelector : public BaseSelector {
+ public:
+  DQMatMulSelector(gsl::span<const char*> compatible_providers = {})
+      : BaseSelector(std::make_unique<DQMatMulNodeGroupSelector>(), compatible_providers) {}
+};
+
 // Input: DQ nodes for A, B and optional C
 // Output: optional Q node for Y
 class GemmSelector : public BaseSelector {
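The IsPowerOfTwo helper above scans the bits of block_size one at a time; the conventional constant-time check, sketched here in Python for clarity, relies on n & (n - 1) clearing the lowest set bit (the selector separately enforces block_size >= 16, so the n > 0 guard below is belt-and-braces):

```python
def is_power_of_two(n: int) -> bool:
    # A power of two has exactly one set bit, so clearing the lowest
    # set bit must yield zero.
    return n > 0 and (n & (n - 1)) == 0

assert is_power_of_two(16) and is_power_of_two(128)
assert not any(map(is_power_of_two, (0, 24, 100)))
```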
From 1685070b4d934c3fd9f2e2557a75f83846bcc2d1 Mon Sep 17 00:00:00 2001
From: Jing Fang
Date: Thu, 20 Jun 2024 19:23:04 -0700
Subject: [PATCH 07/36] added selector checks, and init action

---
 .../selectors_actions/qdq_actions.cc           |  6 ++
 .../selectors_actions/qdq_actions.h            |  4 ++
 .../qdq_selector_action_transformer.cc         | 22 +++++++
 .../selectors_actions/qdq_selectors.cc         | 60 +++++++++++++++----
 4 files changed, 81 insertions(+), 11 deletions(-)

diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
index 3d2a81ce7f8cd..8da74017b3a48 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
@@ -273,6 +273,12 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select
   }
 }
 
+Status DQMatMulReplaceWithMatMulNBits::Run(Graph& graph, const NodesToOptimize& selected_nodes) const {
+  // create new node, move existing node args
+  // transpose constant args, and insert to the new node
+  // remove selected nodes
+}
+
 static std::vector<NodeAndMoveInfo> GetGemmMoveInfo(bool does_q_node_exist) {
   NTO::NodeLocation dq_A{NTO::NodeType::kInput, 0};
   NTO::NodeLocation dq_B{NTO::NodeType::kInput, 1};
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
index 8179a030508a5..64871f1ac2fef 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
@@ -76,6 +76,10 @@ struct MatMulReplaceWithQLinear : public Action {
   BinaryReplaceWithQLinear qlinear_matmul_replacer_;
 };
 
+struct DQMatMulReplaceWithMatMulNBits : public Action {
+  Status Run(Graph&, const NodesToOptimize& selected_nodes) const override;
+};
+
 struct GemmReplaceWithQuant : public Action {
   GemmReplaceWithQuant();
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
index 80ead8f8c68d6..1d3e9fbca16d6 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
@@ -228,6 +228,28 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i
 #endif
 }
 
+void DQMatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
+  // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul.
+  // DQ's weight is int4/uint4. DQ's scale is float/float16.
+  // DQ is block-quantized along axis 0, with block_size >= 16 and a power of 2.
+  const std::string action_name{"DQMatMul"};
+
+  std::unique_ptr<Action> action = std::make_unique<DQMatMulReplaceWithMatMulNBits>();
+
+#if !defined(ORT_MINIMAL_BUILD)
+  // TODO: Enable 16-bit types in selector when QLinearMatMul and MatMulInteger support 16-bit.
+  std::unique_ptr<NodeGroupSelector> selector = std::make_unique<DQMatMulNodeGroupSelector>();
+  qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
+                                                         {{"MatMul", {}}},
+                                                         std::move(selector),
+                                                         std::move(action));
+
+#else
+  ORT_UNUSED_PARAMETER(is_int8_allowed);
+  qdq_selector_action_registry.RegisterAction(action_name, std::move(action));
+#endif
+}
+
 void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
   // 3 to 5 nodes. 0=DQ A, 1=DQ B, 2=DQ C(optional), 3=Gemm, 4=Q Y(optional)
   // Replace with QGemm
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index 6944de876c84c..7348084d98fae 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -436,22 +436,27 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
 bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
                                       const Node& node,
                                       const std::vector<const Node*>& dq_nodes,
-                                      const std::vector<const Node*>& q_nodes) const {
-  // MatMul has 1 DQ input and is the second input
-  if (dq_nodes.size() != 1) {
+                                      const std::vector<const Node*>& q_nodes) const {
+  const auto& graph = graph_viewer.GetGraph();
+
+  // MatMul has only 1 DQ input, and the DQ must have 1 output edge which is not a graph output
+  if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) {
     return false;
   }
-
-  auto iter = node.InputNodesBegin();
-  ++iter;
-  if (iter == node.InputNodesEnd() || iter->Index() != dq_nodes[0]->Index()) {
+
+  // DQ must be MatMul's second input
+  auto input_node_iter = node.InputNodesBegin();
+  if (++input_node_iter;
+      input_node_iter == node.InputNodesEnd() || input_node_iter->Index() != dq_nodes[0]->Index()) {
     return false;
   }
 
   // DQ weight/zero points types are int4/uint4, scales/output types are float or float16
-  int32_t dt_weight = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
-  int32_t dt_scales = dq_nodes[0]->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type();
-
+  const auto weight_arg = dq_nodes[0]->InputDefs()[0];
+  const auto scale_arg = dq_nodes[0]->InputDefs()[1];
+  const auto zero_point_arg = dq_nodes[0]->InputDefs().size() == 3 ? dq_nodes[0]->InputDefs()[2] : nullptr;
+  int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type();
+  int32_t dt_scales = scale_arg->TypeAsProto()->tensor_type().elem_type();
   if (dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT &&
       dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) {
     return false;
@@ -463,7 +468,6 @@
   }
 
   // DQ is blockwise quantized along axis 0, and block_size must be a power of 2 and >= 16
   const auto dq_attrs = dq_nodes[0]->GetAttributes();
-
   if (const auto a_iter = dq_attrs.find("axis");
       a_iter == dq_attrs.end() || a_iter->second.i() != 0) {
     return false;
@@ -479,6 +483,40 @@
     return false;
   }
 
+  // weight, scale and zero points (if present) must be constants
+  const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr;
+  const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr;
+  const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr;
+
+  if (!graph_utils::NodeArgIsConstant(graph, *weight_arg) ||
+      !graph_utils::NodeArgIsConstant(graph, *scale_arg) ||
+      !graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto) ||
+      !graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto)) {
+    return false;
+  }
+
+  if (zero_point_arg &&
+      (!graph_utils::NodeArgIsConstant(graph, *zero_point_arg) ||
+       !graph.GetInitializedTensor(zero_point_arg->Name(), zp_tensor_proto))) {
+    return false;
+  }
+
+  // weight, scale and zero points (if present) must have rank 2
+  if (weight_tensor_proto->dims_size() != 2 ||
+      scale_tensor_proto->dims_size() != 2 ||
+      (zp_tensor_proto && zp_tensor_proto->dims_size() != 2)) {
+    return false;
+  }
+
+  // check weight, scale and zero points (if present) shapes
+  if ((weight_tensor_proto->dims()[0] + block_size) / block_size != scale_tensor_proto->dims()[0] ||
+      weight_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1] ||
+      (zp_tensor_proto &&
+       (zp_tensor_proto->dims()[0] != scale_tensor_proto->dims()[0] ||
+        zp_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1]))) {
+    return false;
+  }
+
   return true;
 }
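The shape checks the selector performs reduce to a few ceil-division identities; restated in Python for a K x N int4 weight (illustrative only — the "+ block_size - 1" rounding shown here is what a later patch in this series fixes the C++ to use):

```python
def dq_matmul_shapes_ok(weight_dims, scale_dims, zp_dims, block_size):
    k, n = weight_dims
    k_blocks = (k + block_size - 1) // block_size  # ceil(K / block_size)
    if scale_dims != [k_blocks, n]:
        return False
    # zero points, when present, must match the scales exactly
    return zp_dims is None or zp_dims == scale_dims

assert dq_matmul_shapes_ok([100, 52], [4, 52], None, 32)
assert dq_matmul_shapes_ok([100, 52], [4, 52], [4, 52], 32)
assert not dq_matmul_shapes_ok([100, 52], [3, 52], None, 32)
```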
From 82419a3eaa767ca9c79753d6379f5271c8ae2de9 Mon Sep 17 00:00:00 2001
From: Jing Fang
Date: Fri, 21 Jun 2024 12:14:11 -0700
Subject: [PATCH 08/36] finished attribute insertion

---
 .../selectors_actions/qdq_actions.cc       | 41 +++++++++++++++++++
 .../selectors_actions/qdq_actions.h        | 12 ++++++
 .../selectors_actions/qdq_selectors.cc     |  2 +
 .../optimizer/selectors_actions/actions.h  |  1 -
 4 files changed, 55 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
index 8da74017b3a48..8e26fc41ac4a3 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
@@ -273,10 +273,51 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select
   }
 }
 
+DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level)
+    : accuracy_level_{accuracy_level},
+      domain_{kMSDomain},
+      op_type_{"MatMulNBits"},
+      value_moves_{[]() {
+        NTO::NodeLocation target{NTO::NodeType::kTarget, 0};
+        return std::vector<NodeAndMoveInfo>{
+            MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput),
+            MoveAll(target, ArgType::kOutput)};
+      }()} {
+}
+
+NodeAttributes DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const Graph&, const NodesToOptimize& selected_nodes) const {
+  NodeAttributes extra_attributes;
+
+  const auto* dq_node = selected_nodes.Input(0);
+  auto attrs = dq_node->GetAttributes();
+  const auto* weight_shape = dq_node->InputDefs()[0]->Shape();
+
+  ORT_ENFORCE(weight_shape->dim(0).has_dim_value() && weight_shape->dim(1).has_dim_value(),
+              "Input x of DQ node must have rank 2 shape dimensions");
+
+  utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes);
+  utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes);
+  if (accuracy_level_ > -1) {
+    utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes);
+  }
+  // currently only 4bits is supported. In the future, derive bits from DQ's weight type.
+  utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast<int64_t>(4)), extra_attributes);
+  utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs["block_size"].i()), extra_attributes);
+
+  return extra_attributes;
+}
+
 Status DQMatMulReplaceWithMatMulNBits::Run(Graph& graph, const NodesToOptimize& selected_nodes) const {
   // create new node, move existing node args
   // transpose constant args, and insert to the new node
   // remove selected nodes
+  ORT_RETURN_IF_ERROR(CreateReplacementNode(graph, selected_nodes,
+                                            OpType(runtime_state),
+                                            Domain(runtime_state),
+                                            ExtraAttributes(runtime_state),
+                                            ValueMoves(runtime_state),
+                                            /* only_update_dest_definitions */ false, nullptr));
+  return node_remover_.Run(graph, selected_nodes);
 }
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
index 64871f1ac2fef..be65075318127 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
@@ -76,8 +76,20 @@ struct MatMulReplaceWithQLinear : public Action {
   BinaryReplaceWithQLinear qlinear_matmul_replacer_;
 };
 
+// used together with DQMatMulNodeGroupSelector, which does the sanity check
 struct DQMatMulReplaceWithMatMulNBits : public Action {
+  DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level = -1);
   Status Run(Graph&, const NodesToOptimize& selected_nodes) const override;
+
+ private:
+  NodeAttributes ExtraAttributes(const Graph&, const NodesToOptimize& selected_nodes) const;
+
+  // -1 means not set
+  const int64_t accuracy_level_;
+  const std::string domain_;
+  const std::string op_type_;
+  const std::vector<NodeAndMoveInfo> value_moves_;
+  RemoveNodes node_remover_;
 };
 
 struct GemmReplaceWithQuant : public Action {
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index 7348084d98fae..419a4b07b4420 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -32,6 +32,8 @@
 FORCEINLINE bool IsPowerOfTwo(int64_t val) {
+  if (val < 0) return false;
+
   bool seen_one = val & 1;
   val >>= 1;
diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.h b/onnxruntime/core/optimizer/selectors_actions/actions.h
index 9384bfa7027cd..9d800ffd80636 100644
--- a/onnxruntime/core/optimizer/selectors_actions/actions.h
+++ b/onnxruntime/core/optimizer/selectors_actions/actions.h
@@ -187,5 +187,4 @@ struct ReplaceWithNewFixed : public ReplaceWithNew {
   const NodeAttributes extra_attrs_;
   const std::vector<NodeAndMoveInfo> value_moves_;
 };
-
 }  // namespace onnxruntime

From d5d9e618470df5f245f3f3806cb11e70ada1ff39 Mon Sep 17 00:00:00 2001
From: Jing Fang
Date: Fri, 21 Jun 2024 16:46:56 -0700
Subject: [PATCH 09/36] finished initializer transpose and append to
 replacement node

---
 .../selectors_actions/qdq_actions.cc       | 132 ++++++++++++++++--
 .../selectors_actions/qdq_actions.h        |   3 +
 .../selectors_actions/qdq_selectors.cc     |   4 +-
 .../quantization/matmul_4bits_quantizer.py |  13 +-
 4 files changed, 131 insertions(+), 21 deletions(-)

diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
index 8e26fc41ac4a3..06ea495136640 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
@@ -2,9 +2,11 @@
 // Licensed under the MIT License.
 
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h"
-
 #include "core/optimizer/qdq_transformer/qdq_util.h"
+#include "core/optimizer/initializer.h"
 #include "core/graph/node_attr_utils.h"
+#include "core/mlas/inc/mlas_q4.h"
+
 
 namespace onnxruntime {
 namespace QDQ {
@@ -289,7 +291,7 @@ NodeAttributes DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const Graph&, con
   const auto* dq_node = selected_nodes.Input(0);
-  auto attrs = dq_node->GetAttributes();
+  auto& attrs = dq_node->GetAttributes();
   const auto* weight_shape = dq_node->InputDefs()[0]->Shape();
@@ -302,21 +304,127 @@
   // currently only 4bits is supported. In the future, derive bits from DQ's weight type.
   utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast<int64_t>(4)), extra_attributes);
-  utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs["block_size"].i()), extra_attributes);
+  utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs.at("block_size").i()), extra_attributes);
 
   return extra_attributes;
 }
 
+void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph,
+                                                               const NodesToOptimize& selected_nodes,
+                                                               Node& replacement_node) const {
+  const auto* dq_node = selected_nodes.Input(0);
+  const auto* weight_arg = dq_node->InputDefs()[0];
+  const auto* scale_arg = dq_node->InputDefs()[1];
+  const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr;
+  const auto& attrs = dq_node->GetAttributes();
+
+  const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr;
+  const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr;
+  const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr;
+  graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto);
+  graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto);
+  if (zp_arg) {
+    graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto);
+  }
+
+  auto K = weight_arg->Shape()->dim(0).dim_value();
+  auto N = weight_arg->Shape()->dim(1).dim_value();
+  auto block_size = attrs.at("block_size").i();
+  auto quant_num = (K + block_size - 1) / block_size;
+  auto blob_bytes = (block_size + 1) / 2;
+
+  // Unfortunately iterating the source data is complicated, the data may be in
+  // an external file, a raw buffer, or a repeated field depending on the data
+  // type. UnpackTensor() already contains some of this logic and is closest
+  // to what we need. But it does not handle external data.
+  Initializer weight_src(*weight_tensor_proto, graph.ModelPath());
+  Initializer scale_src(*scale_tensor_proto, graph.ModelPath());
+  std::unique_ptr<Initializer> zp_src_ptr = nullptr;
+  Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
+                         weight_arg->Name() + "_T",
+                         std::vector<int64_t>{N, quant_num, blob_bytes});
+  Initializer scale_dst(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(scale_src.data_type()),
+                        scale_arg->Name() + "_T",
+                        std::vector<int64_t>{N * quant_num});
+  std::unique_ptr<Initializer> zp_dst_ptr = nullptr;
+
+  if (zp_tensor_proto) {
+    zp_src_ptr = std::make_unique<Initializer>(*zp_tensor_proto, graph.ModelPath());
+    zp_dst_ptr = std::make_unique<Initializer>(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
+                                               zp_arg->Name() + "_T",
+                                               std::vector<int64_t>{N * ((quant_num + 1) / 2)});
+  }
+
+  OrtThreadPoolParams to;
+  auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
+                                          concurrency::ThreadPoolType::INTRA_OP);
+
+  if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+    MlasQDQTransposeBlockwiseQuantized<float, 4>(weight_src.data<uint8_t>(),
+                                                 scale_src.data<float>(),
+                                                 zp_src_ptr ? zp_src_ptr->data<uint8_t>() : nullptr,
+                                                 weight_dst.data<uint8_t>(),
+                                                 scale_dst.data<float>(),
+                                                 zp_dst_ptr ? zp_dst_ptr->data<uint8_t>() : nullptr,
+                                                 true,
+                                                 K, N, block_size,
+                                                 tp.get());
+  } else {
+    MlasQDQTransposeBlockwiseQuantized<MLFloat16, 4>(weight_src.data<uint8_t>(),
+                                                     scale_src.data<MLFloat16>(),
+                                                     zp_src_ptr ? zp_src_ptr->data<uint8_t>() : nullptr,
+                                                     weight_dst.data<uint8_t>(),
+                                                     scale_dst.data<MLFloat16>(),
+                                                     zp_dst_ptr ? zp_dst_ptr->data<uint8_t>() : nullptr,
+                                                     true,
+                                                     K, N, block_size,
+                                                     tp.get());
+
+  }
+
+  ONNX_NAMESPACE::TensorProto weight_T_tp;
+  ONNX_NAMESPACE::TensorProto scale_T_tp;
+  std::unique_ptr<ONNX_NAMESPACE::TensorProto> zp_T_tp_ptr = nullptr;
+
+  weight_dst.ToProto(weight_T_tp);
+  scale_dst.ToProto(scale_T_tp);
+  if (zp_dst_ptr) {
+    zp_T_tp_ptr = std::make_unique<ONNX_NAMESPACE::TensorProto>();
+    zp_dst_ptr->ToProto(*zp_T_tp_ptr);
+  }
+
+  auto& input_defs = replacement_node.MutableInputDefs();
+  input_defs.push_back(&graph_utils::AddInitializer(graph, weight_T_tp));
+  replacement_node.MutableInputArgsCount().push_back(1);
+  input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp));
+  replacement_node.MutableInputArgsCount().push_back(1);
+
+  if (zp_T_tp_ptr) {
+    input_defs.push_back(&graph_utils::AddInitializer(graph, *zp_T_tp_ptr));
+    replacement_node.MutableInputArgsCount().push_back(1);
+  }
+}
+
 Status DQMatMulReplaceWithMatMulNBits::Run(Graph& graph, const NodesToOptimize& selected_nodes) const {
-  // create new node, move existing node args
-  // transpose constant args, and insert to the new node
-  // remove selected nodes
-  ORT_RETURN_IF_ERROR(CreateReplacementNode(graph, selected_nodes,
-                                            OpType(runtime_state),
-                                            Domain(runtime_state),
-                                            ExtraAttributes(runtime_state),
-                                            ValueMoves(runtime_state),
-                                            /* only_update_dest_definitions */ false, nullptr));
+  const auto attributes = ExtraAttributes(graph, selected_nodes);
+  const auto& target = selected_nodes.Target();
+
+  // create node. we'll populate the input and output defs via moves
+  auto& replacement = graph.AddNode(target.Name(),
+                                    op_type_,
+                                    target.Description(),
+                                    {},  // input defs
+                                    {},  // output defs
+                                    &attributes,
+                                    domain_);
+
+  const auto& target_provider = target.GetExecutionProviderType();
+  replacement.SetExecutionProviderType(target_provider.empty() ? kCpuExecutionProvider : target_provider);
+
+  ORT_RETURN_IF_ERROR(MoveInputOutput(graph, selected_nodes, replacement, value_moves_, false));
+
+  AddTransposedInitializers(graph, selected_nodes, replacement);
+
   return node_remover_.Run(graph, selected_nodes);
 }
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
index be65075318127..85090509db0c3 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
@@ -84,6 +84,9 @@ struct DQMatMulReplaceWithMatMulNBits : public Action {
  private:
   NodeAttributes ExtraAttributes(const Graph&, const NodesToOptimize& selected_nodes) const;
 
+  // transpose initializers, and add to the MatMulNBits inputs
+  void AddTransposedInitializers(Graph&, const NodesToOptimize& selected_nodes, Node& replacement_node) const;
+
   // -1 means not set
   const int64_t accuracy_level_;
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index 419a4b07b4420..575241af1ad2f 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -469,7 +469,7 @@
   // DQ is blockwise quantized along axis 0, and block_size must be a power of 2 and >= 16
-  const auto dq_attrs = dq_nodes[0]->GetAttributes();
+  const auto& dq_attrs = dq_nodes[0]->GetAttributes();
   if (const auto a_iter = dq_attrs.find("axis");
       a_iter == dq_attrs.end() || a_iter->second.i() != 0) {
     return false;
@@ -511,7 +511,7 @@
   // check weight, scale and zero points (if present) shapes
-  if ((weight_tensor_proto->dims()[0] + block_size) / block_size != scale_tensor_proto->dims()[0] ||
+  if ((weight_tensor_proto->dims()[0] + block_size - 1) / block_size != scale_tensor_proto->dims()[0] ||
       weight_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1] ||
       (zp_tensor_proto &&
        (zp_tensor_proto->dims()[0] != scale_tensor_proto->dims()[0] ||
diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
index 572cf82152db8..40f7b3c272253 100644
--- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
+++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -440,7 +440,6 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.nd
         block_size = self.config.block_size
         k_blocks = (rows + block_size - 1) // block_size
-        scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype)
 
         if self.config.quant_format == QuantFormat.QOperator:
             blob_size = block_size // 2
@@ -452,12 +451,14 @@
             # block wise quantization, each block comes from a single column
             packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
             zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8")
+            scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype)
             quantize_matmul_4bits(
                 packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
             )
         else:
             packed = np.zeros((rows * cols + 1) // 2, dtype="uint8")
             zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
+            scales = np.zeros((k_blocks, cols), dtype=fp32weight.dtype)
             quantize_qdq_matmul_4bits(
                 packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
             )
@@ -490,8 +491,7 @@
         if self.config.quant_format == QuantFormat.QOperator:
-            B_quant = onnx.numpy_helper.from_array(packed)  # noqa: N806
-            B_quant.name = B.name + "_Q4"
+            B_quant = onnx.numpy_helper.from_array(packed, B.name + "_Q4")  # noqa: N806
+            scales_tensor = onnx.numpy_helper.from_array(scales, B.name + "_scales")
         else:
             B_quant = onnx.helper.make_tensor(B.name + "_DQ_Q4", qtype, B_array.shape, packed, True)
+            scales_tensor = onnx.numpy_helper.from_array(scales, B.name + "_DQ_scales")
 
         for input in Bs_graph.input:
             if input.name == inputB:
                 Bs_graph.input.remove(input)
                 break
 
-        scales_tensor = onnx.numpy_helper.from_array(scales)
-        scales_tensor.name = B.name + "_scales"
         Bs_graph.initializer.extend([B_quant, scales_tensor])
 
         output_nodes = []
@@ -508,8 +508,7 @@
         if self.config.quant_format == QuantFormat.QOperator:
             input_names = [node.input[0], B_quant.name, scales_tensor.name]
             if not self.config.is_symmetric:
-                zp_tensor = onnx.numpy_helper.from_array(zero_points)
-                zp_tensor.name = B.name + "_zero_points"
+                zp_tensor = onnx.numpy_helper.from_array(zero_points, B.name + "_zero_points")
                 input_names.append(zp_tensor.name)
                 Bs_graph.initializer.extend([zp_tensor])
             kwargs = {}
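The repacking that AddTransposedInitializers delegates to MlasQDQTransposeBlockwiseQuantized can be pictured on the scales alone: the [k_blocks, N] layout DequantizeLinear consumes becomes the flat N * k_blocks layout MatMulNBits expects (a NumPy sketch; the real routine also repacks the weight nibbles and zero points):

```python
import numpy as np

def transpose_qdq_scales(scales: np.ndarray) -> np.ndarray:
    k_blocks, n = scales.shape  # DequantizeLinear layout: one row per K-block
    return np.ascontiguousarray(scales.T).reshape(n * k_blocks)

s = np.arange(6, dtype=np.float32).reshape(3, 2)  # k_blocks=3, N=2
print(transpose_qdq_scales(s))                    # [0. 2. 4. 1. 3. 5.]
```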
= np.zeros((cols * k_blocks + 1) // 2, dtype="uint8") + scales = np.zeros((k_blocks, cols), dtype=fp32weight.dtype) quantize_qdq_matmul_4bits( packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric ) @@ -489,18 +490,17 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP packed, scales, zero_points = self.int4_block_quant(B_array) if self.config.quant_format == QuantFormat.QOperator: - B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 - B_quant.name = B.name + "_Q4" + B_quant = onnx.numpy_helper.from_array(packed, B.name + "_Q4") # noqa: N806 + scales_tensor = onnx.numpy_helper.from_array(scales, B.name + "_scales") else: B_quant = onnx.helper.make_tensor(B.name + "_DQ_Q4", qtype, B_array.shape, packed, True) + scales_tensor = onnx.numpy_helper.from_array(scales, B.name + "_DQ_scales") for input in Bs_graph.input: if input.name == inputB: Bs_graph.input.remove(input) break - scales_tensor = onnx.numpy_helper.from_array(scales) - scales_tensor.name = B.name + "_scales" Bs_graph.initializer.extend([B_quant, scales_tensor]) output_nodes = [] @@ -508,8 +508,7 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP if self.config.quant_format == QuantFormat.QOperator: input_names = [node.input[0], B_quant.name, scales_tensor.name] if not self.config.is_symmetric: - zp_tensor = onnx.numpy_helper.from_array(zero_points) - zp_tensor.name = B.name + "_zero_points" + zp_tensor = onnx.numpy_helper.from_array(zero_points, B.name + "_zero_points") input_names.append(zp_tensor.name) Bs_graph.initializer.extend([zp_tensor]) kwargs = {} From b2548da9204dfc59530fc48e6416c6f24fff4f5e Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Fri, 21 Jun 2024 16:49:48 -0700 Subject: [PATCH 10/36] change target name generation --- .../qdq_transformer/selectors_actions/qdq_actions.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 06ea495136640..c0bef19443d69 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -341,17 +341,17 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, Initializer scale_src(*scale_tensor_proto, graph.ModelPath()); std::unique_ptr zp_src_ptr = nullptr; Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - weight_arg->Name() + "_T", + graph.GenerateNodeArgName(weight_arg->Name() + "_T"), std::vector{N, quant_num, blob_bytes}); Initializer scale_dst(static_cast(scale_src.data_type()), - scale_arg->Name() + "_T", + graph.GenerateNodeArgName(scale_arg->Name() + "_T"), std::vector{N * quant_num}); std::unique_ptr zp_dst_ptr = nullptr; if (zp_tensor_proto) { zp_src_ptr = std::make_unique(*zp_tensor_proto, graph.ModelPath()); zp_dst_ptr = std::make_unique(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - zp_arg->Name() + "_T", + graph.GenerateNodeArgName(zp_arg->Name() + "_T"), std::vector{N * ((quant_num + 1) / 2)}); } From 3aa704cdf365ecef7a329fb5fa7170d56048ce9c Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Fri, 21 Jun 2024 17:04:54 -0700 Subject: [PATCH 11/36] added selector and action to qdq selector transformer --- .../selectors_actions/qdq_selector_action_transformer.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 1d3e9fbca16d6..192d006965c1b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -234,7 +234,7 @@ void DQMatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. const std::string action_name{"DQMatMul"}; - std::unique_ptr action = std::make_unique(); + std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) // TODO: Enable 16-bit types in selector when QLinearMatMul and MatMulInteger support 16-bit. @@ -305,6 +305,7 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) { MatMulQDQRules(qdq_selector_action_registry, is_int8_allowed); GemmQDQRules(qdq_selector_action_registry); WhereQDQRules(qdq_selector_action_registry); + DQMatMulQDQRules(qdq_selector_action_registry); return qdq_selector_action_registry; } From 2a10834208770603c5d39fce8ae39bc1ceef2181 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Fri, 21 Jun 2024 17:36:35 -0700 Subject: [PATCH 12/36] fixed linting --- .../qdq_transformer/selectors_actions/qdq_actions.cc | 5 ++--- .../qdq_transformer/selectors_actions/qdq_selectors.cc | 6 +++--- .../test/python/quantization/test_op_matmul_4bits.py | 3 ++- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index c0bef19443d69..d127a8c25d0ec 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -296,7 +296,7 @@ NodeAttributes DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const Graph&, con ORT_ENFORCE(weight_shape->dim(0).has_dim_value() && weight_shape->dim(1).has_dim_value(), "Input x of DQ node must have rank 2 shape dimensions"); - + utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes); utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes); if (accuracy_level_ > -1) { @@ -368,7 +368,7 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, zp_dst_ptr ? 
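// [Reviewer note, not part of the patch] The changes in [PATCH 10/36] above
// route the transposed initializers' "<arg>_T" names through
// graph.GenerateNodeArgName() so they cannot collide with NodeArgs already in
// the graph. A standalone sketch of the dedup idea; the counter-suffix scheme
// here is an assumption for illustration, not necessarily ORT's exact policy:
#include <cassert>
#include <string>
#include <unordered_set>

std::string GenerateUniqueName(std::unordered_set<std::string>& taken,
                               const std::string& base) {
  std::string candidate = base;
  // Append an increasing counter until the candidate is unused.
  for (int i = 1; taken.count(candidate) != 0; ++i) {
    candidate = base + "_" + std::to_string(i);
  }
  taken.insert(candidate);
  return candidate;
}

int main() {
  std::unordered_set<std::string> taken{"weight_T"};
  assert(GenerateUniqueName(taken, "weight_T") == "weight_T_1");  // collision resolved
  assert(GenerateUniqueName(taken, "scale_T") == "scale_T");      // free name kept as-is
  return 0;
}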
zp_dst_ptr->data() : nullptr, true, K, N, block_size, - tp.get()); + tp.get()); } else { MlasQDQTransposeBlockwiseQuantized(weight_src.data(), scale_src.data(), @@ -379,7 +379,6 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, true, K, N, block_size, tp.get()); - } ONNX_NAMESPACE::TensorProto weight_T_tp; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 575241af1ad2f..e07349da05a6c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -436,8 +436,8 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, } bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, - const Node& node, - const std::vector& dq_nodes, + const Node& node, + const std::vector& dq_nodes, const std::vector& q_nodes) const { const auto& graph = graph_viewer.GetGraph(); @@ -445,7 +445,7 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) { return false; } - + // DQ must be MatMul's the second input auto input_node_iter = node.InputNodesBegin(); if (++input_node_iter; diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 1adbff348dcc0..fa386eafd5190 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -18,6 +18,7 @@ from onnxruntime.quantization import quant_utils + class TestOpMatMul4Bits(unittest.TestCase): @classmethod def setUpClass(cls): @@ -115,7 +116,7 @@ def quant_test( data_reader: TestDataFeeds, block_size: int, is_symmetric: bool, - quant_format: quant_utils.QuantFormat = quant_utils.QuantFormat.QOperator + quant_format: quant_utils.QuantFormat = quant_utils.QuantFormat.QOperator, ): use_qdq = quant_format == quant_utils.QuantFormat.QDQ name_prefix = "DQ_MatMul" if use_qdq else "MatMulNBits" From b5164873eebd8c16672f85aec09a7640e5f18721 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Mon, 24 Jun 2024 16:40:59 -0700 Subject: [PATCH 13/36] fixed building UT --- .../selectors_actions/qdq_actions.cc | 8 +- .../selectors_actions/qdq_selectors.cc | 1 + onnxruntime/test/common/random_generator.h | 16 + .../optimizer/graph_transform_test_builder.h | 16 - .../qdq_matmulnbits_transformer_test.cc | 313 ++++++++++++++++++ onnxruntime/test/optimizer/qdq_test_utils.h | 2 +- 6 files changed, 337 insertions(+), 19 deletions(-) create mode 100644 onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index d127a8c25d0ec..b27ef92a5e475 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -367,7 +367,9 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, scale_dst.data(), zp_dst_ptr ? 
zp_dst_ptr->data() : nullptr, true, - K, N, block_size, + static_cast(K), + static_cast(N), + static_cast(block_size), tp.get()); } else { MlasQDQTransposeBlockwiseQuantized(weight_src.data(), @@ -377,7 +379,9 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, scale_dst.data(), zp_dst_ptr ? zp_dst_ptr->data() : nullptr, true, - K, N, block_size, + static_cast(K), + static_cast(N), + static_cast(block_size), tp.get()); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index e07349da05a6c..e3f70ec7366ac 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -439,6 +439,7 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, const std::vector& q_nodes) const { + ONNX_UNUSED_PARAMETER(q_nodes); const auto& graph = graph_viewer.GetGraph(); // MatMul has only 1 DQ input and the DQ must has 1 output edge which is not graph output diff --git a/onnxruntime/test/common/random_generator.h b/onnxruntime/test/common/random_generator.h index 9ab4a82463d51..fcce91a45227f 100644 --- a/onnxruntime/test/common/random_generator.h +++ b/onnxruntime/test/common/random_generator.h @@ -108,6 +108,22 @@ class RandomValueGenerator { return val; } + template + typename std::enable_if< + std::is_same_v || std::is_same_v, + std::vector>::type + Uniform(gsl::span dims, TInt4 min, TInt4 max) { + using UnpackedType = typename TInt4::UnpackedType; + std::vector data_int8 = Uniform(dims, min.GetElem(0), max.GetElem(0)); + std::vector data(TInt4::CalcNumInt4Pairs(data_int8.size())); + for (size_t i = 0; i < data_int8.size(); i++) { + size_t r = i >> 1; + size_t c = i & 0x1; + data[r].SetElem(c, data_int8[i]); + } + return data; + } + // Gaussian distribution for float template typename std::enable_if< diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 0282d09f340b2..33c56e6583e6b 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -116,22 +116,6 @@ class ModelTestBuilder { return MakeInput(shape, data); } - template - typename std::enable_if< - std::is_same_v || std::is_same_v, - NodeArg*>::type - MakeInputInt4(const std::vector& shape, typename TInt4::UnpackedType min, typename TInt4::UnpackedType max) { - using UnpackedType = typename TInt4::UnpackedType; - std::vector data_int8 = rand_gen_.Uniform(shape, min, max); - std::vector data(TInt4::CalcNumInt4Pairs(data_int8.size())); - for (size_t i = 0; i < data_int8.size(); i++) { - size_t r = i >> 1; - size_t c = i & 0x1; - data[r].SetElem(c, data_int8[i]); - } - return MakeInput(shape, data); - } - template NodeArg* MakeInput(const std::optional>& shape, std::optional input_name = std::nullopt) { diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc new file mode 100644 index 0000000000000..3198700e3a418 --- /dev/null +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -0,0 +1,313 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
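// [Reviewer note, illustration only] The Uniform<TInt4> overload added to
// RandomValueGenerator above packs two 4-bit elements per byte: element i
// lands in pair i >> 1, nibble i & 0x1. The same packing with plain bytes;
// PackNibbles is a hypothetical stand-in for the Int4x2/SetElem API:
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<uint8_t> PackNibbles(const std::vector<int8_t>& vals) {
  // CalcNumInt4Pairs equivalent: two 4-bit values per byte, rounded up.
  std::vector<uint8_t> packed((vals.size() + 1) / 2, 0);
  for (size_t i = 0; i < vals.size(); ++i) {
    const size_t pair = i >> 1;   // which byte
    const size_t slot = i & 0x1;  // low (0) or high (1) nibble
    const uint8_t nibble = static_cast<uint8_t>(vals[i]) & 0xF;
    packed[pair] |= static_cast<uint8_t>(nibble << (slot << 2));
  }
  return packed;
}

int main() {
  const auto p = PackNibbles({1, 2, 3});
  assert(p.size() == 2 && p[0] == 0x21 && p[1] == 0x03);  // low nibble first
  return 0;
}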
+ +#include + +#include "core/common/span_utils.h" +#include "core/framework/int4.h" +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" + +#include "test/compare_ortvalue.h" +#include "test/test_environment.h" +#include "test/framework/test_utils.h" +#include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" + +#include "gtest/gtest.h" +#include "graph_transform_test_builder.h" + +#include "qdq_test_utils.h" + +#if defined(_MSC_VER) +#pragma warning(disable : 4127) +#endif // #if defined(_MSC_VER) + +struct QDQOpKeys { + const char* quantize_linear; + const char* dequantize_linear; +}; + +constexpr QDQOpKeys GetQDQOpKeys(bool use_contrib_qdq) { + if (use_contrib_qdq) { + return {"com.microsoft.QuantizeLinear", "com.microsoft.DequantizeLinear"}; + } + return {"QuantizeLinear", "DequantizeLinear"}; +} + +namespace onnxruntime { +namespace test { + +#if !defined(DISABLE_CONTRIB_OPS) + +// Input1 Input2 +// | | +// \ DQ +// \ / +// MatMul +// | +// output +template +typename std::enable_if || std::is_same_v, void>::type +RunDQMatMulNotConverted_NonConstDQ(const std::vector& input1_shape, + const std::vector& input2_shape, + const int64_t axis, + const int64_t block_size, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* input2_arg = builder.MakeInput(input2_shape, T(T::min_val, 0), T(T::min_val, 0)); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + + auto scale_shape = std::vector{input2_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + auto scales = builder.rand_gen_.Uniform(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto zero_points = builder.rand_gen_.Uniform(scale_shape, T(0, 0), T(2, 0)); + builder.AddDequantizeLinearNode(input2_arg, scales, zero_points, dq_output, &attrs, use_contrib_qdq); + } else { + builder.AddDequantizeLinearNode(input2_arg, scales, dq_output, &attrs, use_contrib_qdq); + } + + builder.AddNode("MatMul", {input1_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/); +} + +TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_NonConstDQ) { + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, true); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, true); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, false); 
+ RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, true); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, true); +} + +// Input1 +// | +// \ DQ +// \ / +// MatMul +// | +// output +template +void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + NodeArg* weight_arg = nullptr; + + // add DQ + if constexpr (std::is_same_v || std::is_same_v) { + weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + } else { + weight_arg = builder.MakeInitializer(weight_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + + auto* dq_output = builder.MakeIntermediate(); + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + + auto scale_shape = std::vector{weight_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + auto scales = builder.rand_gen_.Uniform(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + std::vector zero_points; + if constexpr (std::is_same_v || std::is_same_v) { + zero_points = builder.rand_gen_.Uniform(scale_shape, T(0, 0), T(2, 0)); + } else { + zero_points = builder.rand_gen_.Uniform(scale_shape, static_cast(0), static_cast(2)); + } + + builder.AddDequantizeLinearNode(weight_arg, scales, zero_points, dq_output, &attrs, use_contrib_qdq); + } else { + builder.AddDequantizeLinearNode(weight_arg, scales, dq_output, &attrs, use_contrib_qdq); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/); +} + +TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_TypeMismatch) { + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); +} + +TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch) { + // block size too small + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 
false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, true); + // block size not 2's power + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, true); + // not axis 0 + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, true); + // not rank 2 + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, true); +} + +// Input1 +// | DQ +// \ / +// MatMul +// | DQ +// \ / +// MatMul +// | +// output +template +typename std::enable_if || std::is_same_v, void>::type +RunDQMatMulConverted(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // add DQ + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + auto scale_shape = std::vector{weight_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + + auto* weight1_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* weight2_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq1_output = 
builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + auto* matmul1_output = builder.MakeIntermediate(); + + auto scales1 = builder.rand_gen_.Uniform(scale_shape, 8.0f, 12.0f); + auto scales2 = builder.rand_gen_.Uniform(scale_shape, 8.0f, 12.0f); + Node* dp1_node = nullptr; + if constexpr (use_zp) { + auto zero_points1 = builder.rand_gen_.Uniform(scale_shape, T(0, 0), T(2, 0)); + auto zero_points2 = builder.rand_gen_.Uniform(scale_shape, T(0, 0), T(2, 0)); + builder.AddDequantizeLinearNode(weight1_arg, scales1, zero_points1, dq1_output, &attrs, use_contrib_qdq); + builder.AddDequantizeLinearNode(weight2_arg, scales2, zero_points2, dq2_output, &attrs, use_contrib_qdq); + } else { + builder.AddDequantizeLinearNode(weight1_arg, scales1, dq1_output, &attrs, use_contrib_qdq); + builder.AddDequantizeLinearNode(weight2_arg, scales2, dq2_output, &attrs, use_contrib_qdq); + } + + builder.AddNode("MatMul", {input_arg, dq1_output}, {matmul1_output}); + builder.AddNode("MatMul", {matmul1_output, dq2_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 2); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/); +} + +TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits) { + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); +} + +#endif // !defined(DISABLE_CONTRIB_OPS) + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index 862408f31f004..52ac2a2541a79 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -517,7 +517,7 @@ GetQDQTestCaseFn BuildQDQSplitTestCase(const std::vector& input_shape, NodeArg* input_arg = nullptr; if constexpr (std::is_same_v || std::is_same_v) { - input_arg = builder.MakeInputInt4(input_shape, InputType::min_val, InputType::max_val); + input_arg = builder.MakeInput(input_shape, InputType(InputType::min_val, 0), InputType(InputType::max_val, 0)); dq_zp = InputType(static_cast(InputType::max_val / 2)); q_zp = OutputType(static_cast(OutputType::max_val / 2)); } else { From 3fe19cbfe3cc2818faaebcd09008f51e066c40b2 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Mon, 24 Jun 2024 23:24:47 -0700 Subject: [PATCH 14/36] fixed non-convert ut --- .../selectors_actions/qdq_selectors.cc | 4 +- .../qdq_matmulnbits_transformer_test.cc | 182 +++++++++++------- 2 files changed, 117 insertions(+), 69 deletions(-) diff --git 
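[Reviewer note] The ShapeMismatch cases earlier in this test file (block_size 8 and 17) pin down the selector's guard that block_size must be a power of two and at least 16. The power-of-two test is a single bit trick; a minimal sketch under those stated constraints:

#include <cassert>
#include <cstdint>

bool IsSupportedBlockSize(int64_t block_size) {
  // A power of two has exactly one set bit, so n & (n - 1) == 0.
  return block_size >= 16 && (block_size & (block_size - 1)) == 0;
}

int main() {
  assert(IsSupportedBlockSize(16) && IsSupportedBlockSize(128));
  assert(!IsSupportedBlockSize(8));   // a power of two, but below the minimum of 16
  assert(!IsSupportedBlockSize(17));  // not a power of two
  return 0;
}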
a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index e3f70ec7366ac..424340aea848f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -448,9 +448,7 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, } // DQ must be MatMul's the second input - auto input_node_iter = node.InputNodesBegin(); - if (++input_node_iter; - input_node_iter == node.InputNodesEnd() || input_node_iter->Index() != dq_nodes[0]->Index()) { + if (node.InputDefs()[1]->Name() != dq_nodes[0]->OutputDefs()[0]->Name()) { return false; } diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index 3198700e3a418..1aaa0ae6f441b 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -58,9 +58,11 @@ RunDQMatMulNotConverted_NonConstDQ(const std::vector& input1_shape, bool use_contrib_qdq) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* input1_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); - auto* input2_arg = builder.MakeInput(input2_shape, T(T::min_val, 0), T(T::min_val, 0)); + auto* input2_arg = builder.MakeInput(input2_shape, T(T::min_val, 0), T(T::max_val, 0)); auto* output_arg = builder.MakeOutput(); + std::string domain = use_contrib_qdq ? kMSDomain : ""; + // add DQ auto* dq_output = builder.MakeIntermediate(); NodeAttributes attrs; @@ -69,12 +71,12 @@ RunDQMatMulNotConverted_NonConstDQ(const std::vector& input1_shape, auto scale_shape = std::vector{input2_shape}; scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; - auto scales = builder.rand_gen_.Uniform(scale_shape, 8.0f, 12.0f); + auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); if constexpr (use_zp) { - auto zero_points = builder.rand_gen_.Uniform(scale_shape, T(0, 0), T(2, 0)); - builder.AddDequantizeLinearNode(input2_arg, scales, zero_points, dq_output, &attrs, use_contrib_qdq); + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {input2_arg, scale_arg, zp_arg}, {dq_output}, domain, &attrs); } else { - builder.AddDequantizeLinearNode(input2_arg, scales, dq_output, &attrs, use_contrib_qdq); + builder.AddNode("DequantizeLinear", {input2_arg, scale_arg}, {dq_output}, domain, &attrs); } builder.AddNode("MatMul", {input1_arg, dq_output}, {output_arg}); @@ -92,20 +94,82 @@ RunDQMatMulNotConverted_NonConstDQ(const std::vector& input1_shape, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, - 19 /*opset_version*/, + 21 /*opset_version*/, 1e-5 /*per_sample_tolerance*/, 1e-5 /*relative_per_sample_tolerance*/); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_NonConstDQ) { + // DQ contrib op schema is not updated to support blocked quantization RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, false); RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, true); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, true); RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, false); RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_NonConstDQ({12, 
37}, {37, 12}, 0, 16, true); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, true); +} + +// Input2 +// | +// DQ / +// \ / +// MatMul +// | +// output +template +typename std::enable_if || std::is_same_v, void>::type +RunDQMatMulNotConverted_FirstDQInput(const std::vector& weight_shape, + const std::vector& input2_shape, + const int64_t axis, + const int64_t block_size, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* input2_arg = builder.MakeInput(input2_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + std::string domain = use_contrib_qdq ? kMSDomain : ""; + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + + auto scale_shape = std::vector{weight_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, domain, &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, domain, &attrs); + } + + builder.AddNode("MatMul", {dq_output, input2_arg}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/); +} + +TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_FirstDQInput) { + // DQ contrib op schema is not updated to support blocked quantization + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, false); } // Input1 @@ -125,6 +189,7 @@ void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); auto* output_arg = builder.MakeOutput(); NodeArg* weight_arg = nullptr; + std::string domain = use_contrib_qdq ? 
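// [Reviewer note, illustration only] FirstDQInput above is the negative
// counterpart of the selector rule simplified in the hunk before it: only a
// DQ feeding MatMul's *second* input (the weight) may fuse into MatMulNBits,
// and the check compares NodeArg names rather than relying on input-edge
// iteration order. A toy sketch of the name-based check; FakeNode is
// hypothetical, ORT's Node/NodeArg types are more involved:
#include <cassert>
#include <string>

struct FakeNode {
  std::string inputs[2];  // names of the input defs
  std::string output;     // name of the single output def
};

bool DQFeedsSecondInput(const FakeNode& matmul, const FakeNode& dq) {
  return matmul.inputs[1] == dq.output;
}

int main() {
  const FakeNode dq{{"weight_q", "weight_scale"}, "weight_dq"};
  const FakeNode matmul_ok{{"activation", "weight_dq"}, "out"};
  const FakeNode matmul_bad{{"weight_dq", "activation"}, "out"};  // DQ on input 0
  assert(DQFeedsSecondInput(matmul_ok, dq));
  assert(!DQFeedsSecondInput(matmul_bad, dq));
  return 0;
}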
kMSDomain : ""; // add DQ if constexpr (std::is_same_v || std::is_same_v) { @@ -142,18 +207,18 @@ void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input auto scale_shape = std::vector{weight_shape}; scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; - auto scales = builder.rand_gen_.Uniform(scale_shape, 8.0f, 12.0f); + auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); if constexpr (use_zp) { - std::vector zero_points; + NodeArg* zp_arg; if constexpr (std::is_same_v || std::is_same_v) { - zero_points = builder.rand_gen_.Uniform(scale_shape, T(0, 0), T(2, 0)); + zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); } else { - zero_points = builder.rand_gen_.Uniform(scale_shape, static_cast(0), static_cast(2)); + zp_arg = builder.MakeInitializer(scale_shape, 0, 2); } - builder.AddDequantizeLinearNode(weight_arg, scales, zero_points, dq_output, &attrs, use_contrib_qdq); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, domain, &attrs); } else { - builder.AddDequantizeLinearNode(weight_arg, scales, dq_output, &attrs, use_contrib_qdq); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, domain, &attrs); } builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); @@ -171,61 +236,46 @@ void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input check_graph, TransformerLevel::Level1, TransformerLevel::Level2, - 19 /*opset_version*/, + 21 /*opset_version*/, 1e-5 /*per_sample_tolerance*/, 1e-5 /*relative_per_sample_tolerance*/); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_TypeMismatch) { + // DQ contrib op schema is not updated to support blocked quantization RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch) { + // DQ contrib op schema is not updated to support blocked quantization // block size too small RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, false); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, false); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, false); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, true); - 
RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, true); // block size not 2's power RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, false); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, false); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, false); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, true); // not axis 0 RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, false); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, false); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, false); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, true); // not rank 2 - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12, 2}, 0, 16, true); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, false); } // Input1 @@ -240,38 +290,41 @@ TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch) { template typename std::enable_if || std::is_same_v, void>::type RunDQMatMulConverted(const std::vector& input1_shape, - const std::vector& weight_shape, + const std::vector& weight1_shape, + const std::vector& weight2_shape, const int64_t axis, const int64_t block_size, bool use_contrib_qdq) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); auto* output_arg = builder.MakeOutput(); + std::string domain = use_contrib_qdq ? 
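// [Reviewer note, illustration only] These tests now build opset-21
// DequantizeLinear nodes directly (axis + block_size attributes) because, as
// the comments above note, the com.microsoft DQ contrib schema does not
// support blocked quantization. A sketch of the reference semantics of
// blocked DQ along axis 0, with the zero point omitted (per the MLAS docs in
// this series, a missing zero point means signed int4 with implicit zp 0):
#include <cassert>
#include <cstdint>
#include <vector>

// Element (r, c) of a [rows, cols] tensor uses scale block (r / block_size, c).
std::vector<float> DequantizeBlocked(const std::vector<int8_t>& x,
                                     const std::vector<float>& scales,
                                     int64_t rows, int64_t cols,
                                     int64_t block_size) {
  std::vector<float> y(x.size());
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      const float scale = scales[static_cast<size_t>((r / block_size) * cols + c)];
      y[static_cast<size_t>(r * cols + c)] =
          static_cast<float>(x[static_cast<size_t>(r * cols + c)]) * scale;
    }
  }
  return y;
}

int main() {
  // 4 rows, 1 col, block_size 2: rows {0,1} use scales[0], rows {2,3} use scales[1].
  const auto y = DequantizeBlocked({1, 2, 3, 4}, {0.5f, 2.0f}, 4, 1, 2);
  assert(y[0] == 0.5f && y[1] == 1.0f && y[2] == 6.0f && y[3] == 8.0f);
  return 0;
}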
kMSDomain : ""; // add DQ NodeAttributes attrs; utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); - auto scale_shape = std::vector{weight_shape}; - scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + auto scale1_shape = std::vector{weight1_shape}; + auto scale2_shape = std::vector{weight2_shape}; + scale1_shape[axis] = (scale1_shape[axis] + block_size - 1) / block_size; + scale2_shape[axis] = (scale2_shape[axis] + block_size - 1) / block_size; - auto* weight1_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); - auto* weight2_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* weight1_arg = builder.MakeInitializer(weight1_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* weight2_arg = builder.MakeInitializer(weight2_shape, T(T::min_val, 0), T(T::max_val, 0)); auto* dq1_output = builder.MakeIntermediate(); auto* dq2_output = builder.MakeIntermediate(); auto* matmul1_output = builder.MakeIntermediate(); - auto scales1 = builder.rand_gen_.Uniform(scale_shape, 8.0f, 12.0f); - auto scales2 = builder.rand_gen_.Uniform(scale_shape, 8.0f, 12.0f); - Node* dp1_node = nullptr; + auto* scales1_arg = builder.MakeInitializer(scale1_shape, 8.0f, 12.0f); + auto* scales2_arg = builder.MakeInitializer(scale2_shape, 8.0f, 12.0f); if constexpr (use_zp) { - auto zero_points1 = builder.rand_gen_.Uniform(scale_shape, T(0, 0), T(2, 0)); - auto zero_points2 = builder.rand_gen_.Uniform(scale_shape, T(0, 0), T(2, 0)); - builder.AddDequantizeLinearNode(weight1_arg, scales1, zero_points1, dq1_output, &attrs, use_contrib_qdq); - builder.AddDequantizeLinearNode(weight2_arg, scales2, zero_points2, dq2_output, &attrs, use_contrib_qdq); + auto* zp1_arg = builder.MakeInitializer(scale1_shape, T(0, 0), T(2, 0)); + auto* zp2_arg = builder.MakeInitializer(scale2_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg, zp1_arg}, {dq1_output}, domain, &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg, zp2_arg}, {dq2_output}, domain, &attrs); } else { - builder.AddDequantizeLinearNode(weight1_arg, scales1, dq1_output, &attrs, use_contrib_qdq); - builder.AddDequantizeLinearNode(weight2_arg, scales2, dq2_output, &attrs, use_contrib_qdq); + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg}, {dq1_output}, domain, &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg}, {dq2_output}, domain, &attrs); } builder.AddNode("MatMul", {input_arg, dq1_output}, {matmul1_output}); @@ -291,20 +344,17 @@ RunDQMatMulConverted(const std::vector& input1_shape, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, - 19 /*opset_version*/, + 21 /*opset_version*/, 1e-5 /*per_sample_tolerance*/, 1e-5 /*relative_per_sample_tolerance*/); } TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits) { - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); - 
RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, true); + // DQ contrib op schema is not updated to support blocked quantization + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, false); } #endif // !defined(DISABLE_CONTRIB_OPS) From 67542a5f3114d5376ed42da5d7cab52dc97f6e61 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Tue, 25 Jun 2024 17:07:55 -0700 Subject: [PATCH 15/36] fixed action calling transpose --- .../qdq_transformer/selectors_actions/qdq_actions.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index b27ef92a5e475..20d6b8553957f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -360,9 +360,9 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, concurrency::ThreadPoolType::INTRA_OP); if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - MlasQDQTransposeBlockwiseQuantized(weight_src.data(), + MlasQDQTransposeBlockwiseQuantized(weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src_ptr ? zp_src_ptr->data() : nullptr, + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, weight_dst.data(), scale_dst.data(), zp_dst_ptr ? zp_dst_ptr->data() : nullptr, @@ -372,9 +372,9 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, static_cast(block_size), tp.get()); } else { - MlasQDQTransposeBlockwiseQuantized(weight_src.data(), + MlasQDQTransposeBlockwiseQuantized(weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src_ptr ? zp_src_ptr->data() : nullptr, + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, weight_dst.data(), scale_dst.data(), zp_dst_ptr ? zp_dst_ptr->data() : nullptr, From 95de135265bd71e5151f0ef5713f25657cf5545d Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Tue, 25 Jun 2024 19:08:56 -0700 Subject: [PATCH 16/36] finished changing quantize --- onnxruntime/core/mlas/inc/mlas_q4.h | 26 ++- onnxruntime/core/mlas/lib/q4_dq.cpp | 188 ++++++++++++------ .../python/onnxruntime_pybind_quant.cc | 4 +- .../qdq_matmulnbits_transformer_test.cc | 8 +- 4 files changed, 146 insertions(+), 80 deletions(-) diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index 898fb23cf3e4f..aec14070ffd55 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -360,12 +360,12 @@ MlasDequantizeBlockwise( ); /** - * @brief Blockwise 2 bits or 4 bits quantization. After quantization, the weights and zero points - * are packed row-wise. In terms of the qbits type, dst and src have the same shape, and - * scales and zero_points have the same shape. - * columns must be multiple of 8 / qbits. + * @brief Blockwise 4 bits quantization. After quantization, the weights and zero points + * are packed row-wise. If zero_points is null, quantized type is int4 with default + * zero point 0, to align with DQ schema. Otherwise, quantized type is uint4. + * In int4/uint4, dst have the same shape as src, and zero_points have the same shape as scales. 
* @tparam Tin - * @tparam qbits number of bits used for quantization, 2 or 4 + * @tparam qbits number of bits used for quantization, only 4 is supported * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns] * @param scales points to the scales matrix, row major * @param zero_points points to the zero_points matrix, row major @@ -376,9 +376,10 @@ MlasDequantizeBlockwise( * @param columns * @param quant_block_size number of elements in a quantize block * @param thread_pool + * @return the quantized type is signed. */ template -void +bool MlasQDQQuantizeBlockwise( const Tin* src, Tin* scales, @@ -395,8 +396,17 @@ MlasQDQQuantizeBlockwise( * @brief Transpose blockwise quantized tensors. The src tensors are row major. src weights and zero * points are packed row-wise. The dst tensors are column major. dst weights and zero points * are packed column-wise. + * dst_weights and dst_zero_points are in uint4. + * If src_weights is int4 and has src_zero_points, src_weights and src_zero_points are + * converted to uint4 by adding 8. + * If src_weights is int4 and no src_zero_points, src_weights is converted to uint4 by adding 8. + * src_zero_points is 0 and dst_zero_points is 8. + * If src_weights is uint4 and has src_zero_points, just transpose. + * If src_weights is uint4 and no src_zero_points, caller must allocate dst_zero_points with + * 0 values. Otherwise exception is thrown. * @tparam Tin - * @tparam qbits number of bits used for quantization, 2 or 4 + * @tparam qbits number of bits used for quantization, only 4 is supported + * @tparam signed_quant true when quantized type is signed, false when quantized type is unsigned * @param src_weights points to the quantized matrix, row major, shape [rows, columns] in qbits type. * In uint8_t type, shape is [rows, columns * qbits / 8]. * @param src_scales points to the scales matrix, row major @@ -410,7 +420,7 @@ MlasQDQQuantizeBlockwise( * @param quant_block_size number of elements in a quantize block * @param thread_pool */ -template +template void MlasQDQTransposeBlockwiseQuantized( const uint8_t* src_weights, diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index 62fe58ca333de..ec7c479236ccd 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -314,14 +314,18 @@ struct Shape2D { }; -template +template struct BitsTraits { static_assert(qbits <= 8, "Only BitsTraits are for small number of bits!"); static constexpr int kBits = qbits; - static constexpr int kMax = (1 << qbits) - 1; - static constexpr int kMid = 1 << (qbits - 1); + static constexpr int kMax = signed_quant ? (1 << (qbits -1)) - 1 : (1 << qbits) - 1; + static constexpr int kMid = signed_quant ? 0 : 1 << (qbits - 1); + static constexpr int kMin = signed_quant ? -(1 << (qbits - 1)) : 0; static constexpr float kMaxFp = static_cast(kMax); + static constexpr float kMinFp = static_cast(kMin); + static constexpr float fullRange = kMaxFp - kMinFp; + static constexpr float halfRange = static_cast(kMid - kMin); // number of qbit elements to pack into whole bytes static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; @@ -331,53 +335,51 @@ struct BitsTraits { /** * @brief Rectify min/max from a set of weights, and convert to scale and zero point - * for quantization - * @tparam ScaleT type of scale, usually floating point of various bits - * @tparam qbits number of int bits used for zero point value + * for quantization. 
+ * @tparam ScaleT type of scale, usually floating point of various bits + * @tparam qbits number of int bits used for zero point value + * @tparam signed_quant output quantized type is signed * @param[in] min * @param[in] max * @param[out] scale * @param[out] zp */ -template +template MLAS_FORCEINLINE void range2scalezp(float min, float max, ScaleT& scale, uint8_t& zp) { - constexpr int zp_max = BitsTraits::kMax; - constexpr float zp_max_fp = BitsTraits::kMaxFp; - min = std::min(min, 0.0f); max = std::max(max, 0.0f); - float scale_f = (max - min) / zp_max; + float scale_f = (max - min) / BitsTraits::fullRange; float zero_point_fp = min; if (scale_f != 0.0f) { - zero_point_fp = 0.f - min / scale_f; + zero_point_fp = BitsTraits::kMinFp - min / scale_f; } - if (zero_point_fp < 0.0f) { - zp = 0; - } else if (zero_point_fp > zp_max_fp) { - zp = zp_max; + if (zero_point_fp < BitsTraits::kMinFp) { + zp = BitsTraits::kMin; + } else if (zero_point_fp > BitsTraits::kMaxFp) { + zp = BitsTraits::kMax; } else { zp = (uint8_t)roundf(zero_point_fp); } scale = ScaleT(scale_f); } -template +/** + * @brief Rectify min/max from a set of symmetric weights, and convert + * to scale for quantization. + */ +template MLAS_FORCEINLINE void range2scale(float min, float max, ScaleT& scale) { - constexpr int mid_v = BitsTraits::kMid; - constexpr float mid_fp = static_cast(-mid_v); - max = fabsf(max) > fabsf(min) ? max : min; - - scale = ScaleT(max / mid_fp); + scale = ScaleT(max / BitsTraits::halfRange); }; @@ -400,7 +402,7 @@ struct BlockwiseQuantizer { static_assert(qbits == 4, "Only 4b block quantization is supported!"); using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; - using ThreadBlk = Shape2D::kPackSize, QuantBlk::kColumn>; + using ThreadBlk = Shape2D::kPackSize, QuantBlk::kColumn>; static MLAS_FORCEINLINE @@ -474,8 +476,8 @@ struct BlockwiseQuantizer { MlasTryBatchParallel( thread_pool, total_thrd_blks, [&](ptrdiff_t block_idx) { - uint8_t zp_bytes[BitsTraits::kPackSize]; - std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); + uint8_t zp_bytes[BitsTraits::kPackSize]; + std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); @@ -490,7 +492,7 @@ struct BlockwiseQuantizer { const int meta_col = c / QuantBlk::kColumn; // compute scale and zero point - for (int kpack = 0; kpack < BitsTraits::kPackSize; kpack++) { + for (int kpack = 0; kpack < BitsTraits::kPackSize; kpack++) { // scan a single block to extract range [min, max] float min = std::numeric_limits::max(); @@ -509,9 +511,9 @@ struct BlockwiseQuantizer { if (row_start < row_end) { const int32_t meta_idx = meta_col * row_blks + meta_row + kpack; if (zero_points == nullptr) { - range2scale(min, max, scales[meta_idx]); + range2scale(min, max, scales[meta_idx]); } else { - range2scalezp(min, max, scales[meta_idx], zp_bytes[kpack]); + range2scalezp(min, max, scales[meta_idx], zp_bytes[kpack]); } } } @@ -533,7 +535,7 @@ struct BlockwiseQuantizer { const float v0 = static_cast(src[i * leadingDimension + j]); const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp), - 0.0f, BitsTraits::kMaxFp); + 0.0f, BitsTraits::kMaxFp); uint8_t vi1 = (uint8_t)zp; if (i + 1 < r_end) { @@ -545,7 +547,7 @@ struct BlockwiseQuantizer { } const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, - BitsTraits::kMaxFp); + 
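// [Reviewer note, worked example] The signed_quant generalization above folds
// int4 and uint4 into one formula: scale = (max - min) / (kMax - kMin) and
// zero_point = round(kMin - min / scale), clamped to [kMin, kMax], after the
// range is widened to include 0. A standalone sketch mirroring range2scalezp:
#include <algorithm>
#include <cassert>
#include <cmath>

void Range2ScaleZp(float min, float max, float kMinFp, float kMaxFp,
                   float& scale, int& zp) {
  min = std::min(min, 0.0f);  // the representable range must include 0
  max = std::max(max, 0.0f);
  scale = (max - min) / (kMaxFp - kMinFp);
  const float zp_fp = (scale != 0.0f) ? kMinFp - min / scale : min;
  zp = static_cast<int>(std::roundf(std::clamp(zp_fp, kMinFp, kMaxFp)));
}

int main() {
  float scale = 0.0f;
  int zp = 0;
  // uint4 (0..15), block range [-1, 2]: scale = 3/15 = 0.2, zp = 0 - (-1)/0.2 = 5.
  Range2ScaleZp(-1.0f, 2.0f, 0.0f, 15.0f, scale, zp);
  assert(std::fabs(scale - 0.2f) < 1e-6f && zp == 5);
  // int4 (-8..7), same range: same scale, zp = -8 + 1/0.2 = -3.
  Range2ScaleZp(-1.0f, 2.0f, -8.0f, 7.0f, scale, zp);
  assert(zp == -3);
  return 0;
}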
BitsTraits::kMaxFp); } // !! 4b specific code @@ -644,14 +646,19 @@ struct BlockwiseQuantizer { * in memory are packed together, which means the packing is along the row. Quantized data * are stored in row major, so the output tensor reserves same shape, in terms of qbits type, * as the input tensor. - * @tparam Tin source data type, e.g. fp32/fp16 - * @tparam qbits number of bits in each quantized element + * If has zero points, quantized type is unsigned. Otherwise, quantized type is signed and the + * zero point is 0. + * The transposed outputs are used by MatMulNBits, so quant type becomes uint4 with default + * zp at 8. + * @tparam Tin source data type, e.g. fp32/fp16 + * @tparam qbits number of bits in each quantized element + * @tparam signed_quant quantized type is signed */ -template +template struct BlockwiseQDQQuantizer; -template -struct BlockwiseQDQQuantizer { +template +struct BlockwiseQDQQuantizer { static MLAS_FORCEINLINE uint8_t GetElem(uint8_t val, int32_t idx) { return (val >> (idx << 2)) & 0xF; @@ -693,8 +700,8 @@ struct BlockwiseQDQQuantizer { static_cast( std::roundf(static_cast(src) * reciprocal_scale) ) + static_cast(zero_point), - 0, - BitsTraits<4>::kMax + BitsTraits<4, signed_quant>::kMin, + BitsTraits<4, signed_quant>::kMax ) ); } @@ -769,6 +776,7 @@ struct BlockwiseQDQQuantizer { MLAS_THREADPOOL* thread_pool ) { + ORT_ENFORCE(zero_points || signed_quant, "Unsigned quant with no zero points is not supported."); // Must avoid multiple thread write to a single byte, which means the starting index // of a thread block must be even. To achieve that, we need to customize the thread // block size based on the parity of columns. @@ -896,15 +904,15 @@ struct BlockwiseQDQQuantizer { // calculate scale and zero point, and store for (int32_t i = 0; i < col_size; i += 2) { - v0_tt = v1_tt = BitsTraits<4>::kMid; + v0_tt = v1_tt = BitsTraits<4, signed_quant>::kMid; if (zero_points) { - range2scalezp(vmin_t[i], vmax_t[i], scale0_tt, v0_tt); - range2scalezp(vmin_t[i + 1], vmax_t[i + 1], scale1_tt, v1_tt); + range2scalezp(vmin_t[i], vmax_t[i], scale0_tt, v0_tt); + range2scalezp(vmin_t[i + 1], vmax_t[i + 1], scale1_tt, v1_tt); zero_points[(scale_idx + i) >> 1] = Pack(v0_tt, v1_tt); } else { - range2scale(vmin_t[i], vmax_t[i], scale0_tt); - range2scale(vmin_t[i + 1], vmax_t[i + 1], scale1_tt); + range2scale(vmin_t[i], vmax_t[i], scale0_tt); + range2scale(vmin_t[i + 1], vmax_t[i + 1], scale1_tt); } scales[scale_idx + i] = scale0_tt; @@ -993,14 +1001,14 @@ struct BlockwiseQDQQuantizer { int32_t col_idx = 0; // leading unailgned zero points if (scale_buffer_idx & 1) { - v0_tt = BitsTraits<4>::kMid; + v0_tt = BitsTraits<4, signed_quant>::kMid; if (zero_points) { - range2scalezp(vmin_t[0], vmax_t[0], scale0_tt, v0_tt); + range2scalezp(vmin_t[0], vmax_t[0], scale0_tt, v0_tt); zero_points[scale_buffer_idx >> 1] = SetElem( v0_tt, 1, zero_points[scale_buffer_idx >> 1] ); } else { - range2scale(vmin_t[0], vmax_t[0], scale0_tt); + range2scale(vmin_t[0], vmax_t[0], scale0_tt); } scales[scale_buffer_idx] = scale0_tt; @@ -1014,14 +1022,16 @@ struct BlockwiseQDQQuantizer { } // aligned zero points for (; scale_buffer_idx < scale_buffer_idx_end - 1; col_idx += 2, scale_buffer_idx += 2) { - v0_tt = v1_tt = BitsTraits<4>::kMid; + v0_tt = v1_tt = BitsTraits<4, signed_quant>::kMid; if (zero_points) { - range2scalezp(vmin_t[col_idx], vmax_t[col_idx], scale0_tt, v0_tt); - range2scalezp(vmin_t[col_idx + 1], vmax_t[col_idx + 1], scale1_tt, v1_tt); + range2scalezp(vmin_t[col_idx], vmax_t[col_idx], 
scale0_tt, v0_tt); + range2scalezp( + vmin_t[col_idx + 1], vmax_t[col_idx + 1], scale1_tt, v1_tt + ); zero_points[scale_buffer_idx >> 1] = Pack(v0_tt, v1_tt); } else { - range2scale(vmin_t[col_idx], vmax_t[col_idx], scale0_tt); - range2scale(vmin_t[col_idx + 1], vmax_t[col_idx + 1], scale1_tt); + range2scale(vmin_t[col_idx], vmax_t[col_idx], scale0_tt); + range2scale(vmin_t[col_idx + 1], vmax_t[col_idx + 1], scale1_tt); } scales[scale_buffer_idx] = scale0_tt; @@ -1037,14 +1047,14 @@ struct BlockwiseQDQQuantizer { } // tailing unaligned elements if (scale_buffer_idx < scale_buffer_idx_end) { - v0_tt = BitsTraits<4>::kMid; + v0_tt = BitsTraits<4, signed_quant>::kMid; if (zero_points) { - range2scalezp(vmin_t[col_idx], vmax_t[col_idx], scale0_tt, v0_tt); + range2scalezp(vmin_t[col_idx], vmax_t[col_idx], scale0_tt, v0_tt); zero_points[scale_buffer_idx >> 1] = SetElem( v0_tt, 0, zero_points[scale_buffer_idx >> 1] ); } else { - range2scale(vmin_t[col_idx], vmax_t[col_idx], scale0_tt); + range2scale(vmin_t[col_idx], vmax_t[col_idx], scale0_tt); } scales[scale_buffer_idx] = scale0_tt; @@ -1745,7 +1755,7 @@ MlasDequantizeBlockwise( ); template -void +bool MlasQDQQuantizeBlockwise( const Tin* src, Tin* scales, @@ -1759,17 +1769,33 @@ MlasQDQQuantizeBlockwise( ) { if (columnwise) { - BlockwiseQDQQuantizer::QuantizeColumnWise( - src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool - ); + if (zero_points) { + BlockwiseQDQQuantizer::QuantizeColumnWise( + src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool + ); + return false; + } else { + BlockwiseQDQQuantizer::QuantizeColumnWise( + src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool + ); + return true; + } } else { - BlockwiseQDQQuantizer::QuantizeRowWise( - src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool - ); + if (zero_points) { + BlockwiseQDQQuantizer::QuantizeRowWise( + src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool + ); + return false; + } else { + BlockwiseQDQQuantizer::QuantizeRowWise( + src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool + ); + return true; + } } } -template void +template bool MlasQDQQuantizeBlockwise( const float* src, float* scales, @@ -1782,7 +1808,7 @@ MlasQDQQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); -template void +template bool MlasQDQQuantizeBlockwise( const MLAS_FP16* src, MLAS_FP16* scales, @@ -1795,7 +1821,7 @@ MlasQDQQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); -template +template void MlasQDQTransposeBlockwiseQuantized( const uint8_t* src_weights, @@ -1812,7 +1838,7 @@ MlasQDQTransposeBlockwiseQuantized( ) { if (columnwise) { - BlockwiseQDQQuantizer::TransposeColumnWiseQuantized( + BlockwiseQDQQuantizer::TransposeColumnWiseQuantized( src_weights, src_scales, src_zero_points, dst_weights, dst_scales, dst_zero_points, rows, columns, quant_block_size, thread_pool ); @@ -1822,7 +1848,7 @@ MlasQDQTransposeBlockwiseQuantized( } template void -MlasQDQTransposeBlockwiseQuantized( +MlasQDQTransposeBlockwiseQuantized( const uint8_t* src_weights, const float* src_scales, const uint8_t* src_zero_points, @@ -1837,7 +1863,37 @@ MlasQDQTransposeBlockwiseQuantized( ); template void -MlasQDQTransposeBlockwiseQuantized( +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int 
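// [Reviewer note] The repetitive blocks above are explicit instantiation
// definitions: the template bodies stay in q4_dq.cpp, so every combination
// callers link against -- (float | MLAS_FP16) scales x (signed | unsigned)
// quantization -- must be stamped out here. A toy version of the pattern:
#include <cassert>

template <typename T, bool signed_quant>
T ToUnsignedOffset(T v) {
  // Signed quantized values shift by half the range when re-encoded unsigned.
  return signed_quant ? v + T{8} : v;
}

// Explicit instantiation definitions: emit code for these signatures in this
// translation unit so other TUs can link without seeing the template body.
template float ToUnsignedOffset<float, true>(float);
template float ToUnsignedOffset<float, false>(float);

int main() {
  assert((ToUnsignedOffset<float, true>(-8.0f)) == 0.0f);
  assert((ToUnsignedOffset<float, false>(3.0f)) == 3.0f);
  return 0;
}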
quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const MLAS_FP16* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + MLAS_FP16* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( const uint8_t* src_weights, const MLAS_FP16* src_scales, const uint8_t* src_zero_points, diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index 5e8e5c1a2a2fc..51a52af1b151e 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -67,7 +67,7 @@ void QuantizeMatMul4BitsBlockwise( } template -void QuantizeQDQMatMul4BitsBlockwise( +bool QuantizeQDQMatMul4BitsBlockwise( py::array_t dst, // shape: [K, N / 2] py::array_t src, // shape: [K, N] py::array_t scale, // shape: [block_per_K, N] @@ -85,7 +85,7 @@ void QuantizeQDQMatMul4BitsBlockwise( py::buffer_info scale_buf = scale.request(); py::buffer_info zp_buf = zero_points.request(); - MlasQDQQuantizeBlockwise( + return MlasQDQQuantizeBlockwise( reinterpret_cast(src_buf.ptr), reinterpret_cast(scale_buf.ptr), is_symmetric ? nullptr : reinterpret_cast(zp_buf.ptr), diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index 1aaa0ae6f441b..c64413d72da5f 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -117,10 +117,10 @@ TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_NonConstDQ) { template typename std::enable_if || std::is_same_v, void>::type RunDQMatMulNotConverted_FirstDQInput(const std::vector& weight_shape, - const std::vector& input2_shape, - const int64_t axis, - const int64_t block_size, - bool use_contrib_qdq) { + const std::vector& input2_shape, + const int64_t axis, + const int64_t block_size, + bool use_contrib_qdq) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); auto* input2_arg = builder.MakeInput(input2_shape, -100.0f, 100.0f); From 0ad7fe4d12a6130e3dca78289edc6ff9885d4a48 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Tue, 25 Jun 2024 21:01:14 -0700 Subject: [PATCH 17/36] finished modifying transpose --- onnxruntime/core/mlas/lib/q4_dq.cpp | 45 +++++++++++++++++++---------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index ec7c479236ccd..fb7079daf9858 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -670,9 +670,14 @@ struct BlockwiseQDQQuantizer { return ((val & 0xF) << shift) | (dst & (~(0xF << shift))); } + template static MLAS_FORCEINLINE uint8_t Pack(uint8_t v0, uint8_t v1) { - return (v0 & 0xF) | ((v1 & 0xF) << 4); + if constexpr (add8) { + return (v0 & 0xF ^ 8) | ((v1 & 0xF ^ 8) << 4); + } else { + return (v0 & 0xF) | ((v1 & 0xF) << 4); + } } // If src is row major, then dst is column major. 
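// Note on the add8 paths in Pack above and Transpose below: XOR-ing a nibble
// with 8 flips its top bit, which converts between the uint4 encoding [0, 15]
// and the two's-complement int4 encoding [-8, 7] of the same quantized value,
// e.g. 0x0 (uint4 0) <-> 0x8 (int4 -8), 0x8 (uint4 8) <-> 0x0 (int4 0),
// 0xF (uint4 15) <-> 0x7 (int4 7). Packing with add8 therefore re-bases an
// unsigned-quantized pair into signed storage without a separate subtraction.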
Transpose: @@ -687,10 +692,16 @@ struct BlockwiseQDQQuantizer { // --> // | dst0: low 4 bit | dst0: high 4 bit | // | dst1: low 4 bit | dst1: high 4 bit | + template static MLAS_FORCEINLINE void Transpose(uint8_t src0, uint8_t src1, uint8_t& dst0, uint8_t& dst1) { - dst0 = (src0 & 0xF) | ((src1 & 0xF) << 4); - dst1 = ((src0 & 0xF0) >> 4) | (src1 & 0xF0); + if constexpr (add8) { + dst0 = ((src0 & 0xF) ^ 8) | (((src1 & 0xF) ^ 8) << 4); + dst1 = (((src0 & 0xF0) ^ 0x80) >> 4) | ((src1 & 0xF0) ^ 0x80); + } else { + dst0 = (src0 & 0xF) | ((src1 & 0xF) << 4); + dst1 = ((src0 & 0xF0) >> 4) | (src1 & 0xF0); + } } static MLAS_FORCEINLINE uint8_t QuantizeV(Tin src, float reciprocal_scale, uint8_t zero_point) @@ -823,6 +834,10 @@ struct BlockwiseQDQQuantizer { MLAS_THREADPOOL* thread_pool ) { + ORT_ENFORCE( + src_zero_points || signed_quant || dst_zero_points, + "Unsigned quant types without zero points must allocate zero points with value 0." + ); // Must avoid multiple thread write to a single byte, which means the starting index // of a thread block must be even. To achieve that, we need to customize the thread // block size based on the parity of columns. @@ -909,7 +924,7 @@ struct BlockwiseQDQQuantizer { if (zero_points) { range2scalezp(vmin_t[i], vmax_t[i], scale0_tt, v0_tt); range2scalezp(vmin_t[i + 1], vmax_t[i + 1], scale1_tt, v1_tt); - zero_points[(scale_idx + i) >> 1] = Pack(v0_tt, v1_tt); + zero_points[(scale_idx + i) >> 1] = Pack(v0_tt, v1_tt); } else { range2scale(vmin_t[i], vmax_t[i], scale0_tt); range2scale(vmin_t[i + 1], vmax_t[i + 1], scale1_tt); @@ -933,7 +948,7 @@ struct BlockwiseQDQQuantizer { for (int32_t i = 0; i < col_size; i += 2) { v0_tt = QuantizeV(src[input_idx_t + i], reciprocal_scale_t[i], zp_t[i]); v1_tt = QuantizeV(src[input_idx_t + i + 1], reciprocal_scale_t[i + 1], zp_t[i + 1]); - dst[(input_idx_t + i) >> 1] = Pack(v0_tt, v1_tt); + dst[(input_idx_t + i) >> 1] = Pack(v0_tt, v1_tt); } } } @@ -1028,7 +1043,7 @@ struct BlockwiseQDQQuantizer { range2scalezp( vmin_t[col_idx + 1], vmax_t[col_idx + 1], scale1_tt, v1_tt ); - zero_points[scale_buffer_idx >> 1] = Pack(v0_tt, v1_tt); + zero_points[scale_buffer_idx >> 1] = Pack(v0_tt, v1_tt); } else { range2scale(vmin_t[col_idx], vmax_t[col_idx], scale0_tt); range2scale(vmin_t[col_idx + 1], vmax_t[col_idx + 1], scale1_tt); @@ -1088,7 +1103,7 @@ struct BlockwiseQDQQuantizer { src[input_idx_t_start + 1], reciprocal_scale_t[col_idx + 1], zp_t[col_idx + 1] ); - dst[input_idx_t_start >> 1] = Pack(v0_tt, v1_tt); + dst[input_idx_t_start >> 1] = Pack(v0_tt, v1_tt); } // tailing unaligned output if (input_idx_t_start < input_idx_t_end) { @@ -1154,7 +1169,7 @@ struct BlockwiseQDQQuantizer { src0_t = src_weights[src_idx]; src1_t = src_weights[src_idx + packed_col_size]; src_idx += packed_col_size + packed_col_size; - Transpose(src0_t, src1_t, dst0_t, dst1_t); + Transpose(src0_t, src1_t, dst0_t, dst1_t); dst_weights[dst_idx] = dst0_t; dst_weights[dst_idx + dstT_num_row] = dst1_t; } @@ -1162,7 +1177,7 @@ struct BlockwiseQDQQuantizer { if (src_idx < src_end_idx) { src0_t = src_weights[src_idx]; src1_t = 0; - Transpose(src0_t, src1_t, dst0_t, dst1_t); + Transpose(src0_t, src1_t, dst0_t, dst1_t); dst_weights[dst_idx] = dst0_t; dst_weights[dst_idx + dstT_num_row] = dst1_t; } @@ -1200,7 +1215,7 @@ struct BlockwiseQDQQuantizer { for (; src_idx < src_end_idx - packed_col_size; ++dst_idx) { src0_t = src_zero_points[src_idx]; src1_t = src_zero_points[src_idx + packed_col_size]; - Transpose(src0_t, src1_t, dst0_t, dst1_t); + Transpose(src0_t, src1_t, 
dst0_t, dst1_t); dst_zero_points[dst_idx] = dst0_t; dst_zero_points[dst_idx + dst_zp_row_num] = dst1_t; src_idx += packed_col_size + packed_col_size; @@ -1209,7 +1224,7 @@ struct BlockwiseQDQQuantizer { if (src_idx < src_end_idx) { src0_t = src_zero_points[src_idx]; src1_t = 0; - Transpose(src0_t, src1_t, dst0_t, dst1_t); + Transpose(src0_t, src1_t, dst0_t, dst1_t); dst_zero_points[dst_idx] = dst0_t; dst_zero_points[dst_idx + dst_zp_row_num] = dst1_t; } @@ -1257,13 +1272,13 @@ struct BlockwiseQDQQuantizer { for (; src_idx < src_end_idx - columns; ++dst_idx) { src0_t = GetElem(src_weights[src_idx >> 1], src_idx & 1); src1_t = GetElem(src_weights[(src_idx + columns) >> 1], (src_idx + columns) & 1); - dst_weights[dst_idx] = (src0_t & 0xf) | ((src1_t & 0xf) << 4); + dst_weights[dst_idx] = Pack(src0_t, src1_t); src_idx += columns + columns; } if (src_idx < src_end_idx) { src0_t = GetElem(src_weights[src_idx >> 1], src_idx & 1); - dst_weights[dst_idx] = src0_t & 0xf; + dst_weights[dst_idx] = Pack(src0_t, 0); } } ); @@ -1298,13 +1313,13 @@ struct BlockwiseQDQQuantizer { for (; src_idx < src_end_idx - columns; ++dst_idx) { src0_t = GetElem(src_zero_points[src_idx >> 1], src_idx & 1); src1_t = GetElem(src_zero_points[(src_idx + columns) >> 1], (src_idx + columns) & 1); - dst_zero_points[dst_idx] = (src0_t & 0xf) | ((src1_t & 0xf) << 4); + dst_zero_points[dst_idx] = Pack(src0_t, src1_t); src_idx += columns + columns; } if (src_idx < src_end_idx) { src0_t = GetElem(src_zero_points[src_idx >> 1], src_idx & 1); - dst_zero_points[dst_idx] = src0_t & 0xf; + dst_zero_points[dst_idx] = Pack(src0_t, 0); } } ); From f81d7bbceddd23a3585868b1b6e1b242d8d22645 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Tue, 25 Jun 2024 21:16:19 -0700 Subject: [PATCH 18/36] updated mlas kernel calling --- onnxruntime/test/mlas/bench/bench_q4dq.cpp | 27 +++++++++++++------ .../test/mlas/unittest/test_blockq4.cpp | 17 +++++++++--- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/onnxruntime/test/mlas/bench/bench_q4dq.cpp b/onnxruntime/test/mlas/bench/bench_q4dq.cpp index 00234ecfd2ce2..9d15c9a6bf994 100644 --- a/onnxruntime/test/mlas/bench/bench_q4dq.cpp +++ b/onnxruntime/test/mlas/bench/bench_q4dq.cpp @@ -69,6 +69,7 @@ static void BM_QDQBlockwiseQuantizer_TransposeColumnwise(benchmark::State& state int N = state.range(1); int quant_block_size = state.range(2); int threads = state.range(3); + bool add8 = state.range(4) != 0; int quant_num_M = (M + quant_block_size - 1) / quant_block_size; int blob_size = (quant_block_size + 1) / 2; size_t scale_size = quant_num_M * N; @@ -87,12 +88,22 @@ static void BM_QDQBlockwiseQuantizer_TransposeColumnwise(benchmark::State& state onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - for (auto _ : state) { - benchmark::DoNotOptimize(dst.data()); - MlasQDQTransposeBlockwiseQuantized( - dst.data(), scales.data(), zero_points.data(), dst_T.data(), scales_T.data(), zero_points_T.data(), - true, M, N, quant_block_size, tp.get()); - benchmark::ClobberMemory(); + if (add8) { + for (auto _ : state) { + benchmark::DoNotOptimize(dst.data()); + MlasQDQTransposeBlockwiseQuantized( + dst.data(), scales.data(), zero_points.data(), dst_T.data(), scales_T.data(), zero_points_T.data(), + true, M, N, quant_block_size, tp.get()); + benchmark::ClobberMemory(); + } + } else { + for (auto _ : state) { + benchmark::DoNotOptimize(dst.data()); + MlasQDQTransposeBlockwiseQuantized( + dst.data(), scales.data(), 
zero_points.data(), dst_T.data(), scales_T.data(), zero_points_T.data(), + true, M, N, quant_block_size, tp.get()); + benchmark::ClobberMemory(); + } } } @@ -113,6 +124,6 @@ BENCHMARK(BM_MlasQuantizeBlockwise) BENCHMARK(BM_QDQBlockwiseQuantizer_TransposeColumnwise) ->UseRealTime() ->Apply([](benchmark::internal::Benchmark* b) { - b->ArgNames({"M", "N", "quant_block_size", "threads"}); - b->ArgsProduct({{1024, 4096}, {4096, 4095}, {64, 128}, {2, 8, 16}}); + b->ArgNames({"M", "N", "quant_block_size", "threads", "add8"}); + b->ArgsProduct({{1024, 4096}, {4096, 4095}, {64, 128}, {2, 8, 16}, {0, 1}}); }); diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp index b466e883059f4..f75002f715154 100644 --- a/onnxruntime/test/mlas/unittest/test_blockq4.cpp +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -127,13 +127,22 @@ class MlasBlockwiseQdqTest : public MlasTestBase { columnwise, rows, columns, columns, threadpool_ptr); if (columnwise) { - MlasQDQQuantizeBlockwise( + bool signed_quant = MlasQDQQuantizeBlockwise( transposed, qdq_scales, qdq_zp, qdq_weights, true, rows, columns, block_size, threadpool_ptr); - MlasQDQTransposeBlockwiseQuantized( - qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, - true, rows, columns, block_size, threadpool_ptr); + ASSERT_EQ(symmetric, signed_quant) << "symmetric quantization should be signed"; + + if (symmetric) { + MlasQDQTransposeBlockwiseQuantized( + qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, + true, rows, columns, block_size, threadpool_ptr); + + } else { + MlasQDQTransposeBlockwiseQuantized( + qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, + true, rows, columns, block_size, threadpool_ptr); + } } for (int c = 0; c < columns; c++) { From 8e8b3144435a6059d379ab1f88e41978b5a6a458 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Tue, 25 Jun 2024 23:26:43 -0700 Subject: [PATCH 19/36] fixed mlas scale calc bug --- onnxruntime/core/mlas/lib/q4_dq.cpp | 11 ++- .../selectors_actions/qdq_actions.cc | 77 +++++++++++++------ 2 files changed, 62 insertions(+), 26 deletions(-) diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index fb7079daf9858..48d801db468f3 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -320,7 +320,7 @@ struct BitsTraits { static constexpr int kBits = qbits; static constexpr int kMax = signed_quant ? (1 << (qbits -1)) - 1 : (1 << qbits) - 1; - static constexpr int kMid = signed_quant ? 0 : 1 << (qbits - 1); + static constexpr int kMid = signed_quant ? 0 : (1 << (qbits - 1)); static constexpr int kMin = signed_quant ? -(1 << (qbits - 1)) : 0; static constexpr float kMaxFp = static_cast(kMax); static constexpr float kMinFp = static_cast(kMin); @@ -360,9 +360,9 @@ range2scalezp(float min, float max, ScaleT& scale, uint8_t& zp) } if (zero_point_fp < BitsTraits::kMinFp) { - zp = BitsTraits::kMin; + zp = static_cast(BitsTraits::kMin); } else if (zero_point_fp > BitsTraits::kMaxFp) { - zp = BitsTraits::kMax; + zp = static_cast(BitsTraits::kMax); } else { zp = (uint8_t)roundf(zero_point_fp); } @@ -379,7 +379,10 @@ void range2scale(float min, float max, ScaleT& scale) { max = fabsf(max) > fabsf(min) ? max : min; - scale = ScaleT(max / BitsTraits::halfRange); + // !!Note: in the quantized space, abs of min -8 > abs of max 7. + // Therefore map the larger half FP space to [-8, 0]. + // Minus sign achieves this purpose. 
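+    // e.g. with 4-bit signed quantization (halfRange = 8):
+    //   min = -10, max = 3  -> scale =  1.25, and q = -8 dequantizes to -10;
+    //   min = -3,  max = 10 -> scale = -1.25, and q = -8 dequantizes to  10.
+    // Either way the larger-magnitude endpoint maps exactly onto q = -8.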
+ scale = ScaleT(-max / BitsTraits::halfRange); }; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 20d6b8553957f..2f32dd4b55415 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -353,6 +353,10 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, zp_dst_ptr = std::make_unique(ONNX_NAMESPACE::TensorProto_DataType_UINT8, graph.GenerateNodeArgName(zp_arg->Name() + "_T"), std::vector{N * ((quant_num + 1) / 2)}); + } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { + zp_dst_ptr = std::make_unique(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName(zp_arg->Name() + "_T"), + std::vector{N * ((quant_num + 1) / 2)}); } OrtThreadPoolParams to; @@ -360,29 +364,58 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, concurrency::ThreadPoolType::INTRA_OP); if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - MlasQDQTransposeBlockwiseQuantized(weight_src.DataAsByteSpan().data(), - scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst_ptr ? zp_dst_ptr->data() : nullptr, - true, - static_cast(K), - static_cast(N), - static_cast(block_size), - tp.get()); + if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { + MlasQDQTransposeBlockwiseQuantized(weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst_ptr ? zp_dst_ptr->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + tp.get()); + } else { + MlasQDQTransposeBlockwiseQuantized(weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst_ptr ? zp_dst_ptr->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + tp.get()); + } } else { - MlasQDQTransposeBlockwiseQuantized(weight_src.DataAsByteSpan().data(), - scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst_ptr ? zp_dst_ptr->data() : nullptr, - true, - static_cast(K), - static_cast(N), - static_cast(block_size), - tp.get()); + if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { + MlasQDQTransposeBlockwiseQuantized(weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst_ptr ? zp_dst_ptr->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + tp.get()); + + } else { + MlasQDQTransposeBlockwiseQuantized(weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst_ptr ? 
zp_dst_ptr->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + tp.get()); + } } ONNX_NAMESPACE::TensorProto weight_T_tp; From 572976203a859ea96ec0665b6e1bd056cea1e1a6 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Tue, 25 Jun 2024 23:40:57 -0700 Subject: [PATCH 20/36] passed UT --- .../optimizer/qdq_transformer/selectors_actions/qdq_actions.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 2f32dd4b55415..dd76b57af8a34 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -355,7 +355,7 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, std::vector{N * ((quant_num + 1) / 2)}); } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { zp_dst_ptr = std::make_unique(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName(zp_arg->Name() + "_T"), + graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"), std::vector{N * ((quant_num + 1) / 2)}); } From c4805c8b643e038a259bb528be72248be8343159 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Wed, 26 Jun 2024 12:39:19 -0700 Subject: [PATCH 21/36] fixed python build --- .../quantization/matmul_4bits_quantizer.py | 52 ++++++++++--------- .../quantization/test_op_matmul_4bits.py | 9 +++- 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 40f7b3c272253..23e4481f47d5d 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -343,8 +343,8 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP import torch logger.info(f"start to quantize {node.name} ...") - inputB = node.input[1] # noqa: N806 - b_pb, bs_graph = get_initializer(inputB, graph_stack) + input_b = node.input[1] + b_pb, bs_graph = get_initializer(input_b, graph_stack) if b_pb is None: logger.info("MatMul doesn't have const weight. Skip to quantize") return [node] # only care about constant weight @@ -383,7 +383,7 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy()) b_quant.name = b_pb.name + "_Q4" for input in bs_graph.input: - if input.name == inputB: + if input.name == input_b: bs_graph.input.remove(input) break @@ -476,43 +476,43 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP logger.info(f"start to quantize {node.name} ...") qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4 - inputB = node.input[1] # noqa: N806 - B, Bs_graph = get_initializer(inputB, graph_stack) # noqa: N806 - if B is None: + input_b = node.input[1] + b_tensor, b_graph = get_initializer(input_b, graph_stack) + if b_tensor is None: logger.info("MatMul doesn't have const weight. Skip to quantize") return [node] # only care about constant weight - B_array = onnx.numpy_helper.to_array(B) # noqa: N806 - if len(B_array.shape) != 2: + b_ndarray = onnx.numpy_helper.to_array(b_tensor) + if len(b_ndarray.shape) != 2: logger.info("MatMul weight is not 2D. 
Skip to quantize") return [node] # can only process 2-D matrix - packed, scales, zero_points = self.int4_block_quant(B_array) + packed, scales, zero_points = self.int4_block_quant(b_ndarray) if self.config.quant_format == QuantFormat.QOperator: - B_quant = onnx.numpy_helper.from_array(packed, B.name + "_Q4") # noqa: N806 - scales_tensor = onnx.numpy_helper.from_array(scales, B.name + "_scales") + b_quant = onnx.numpy_helper.from_array(packed, b_tensor.name + "_Q4") + scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_scales") else: - B_quant = onnx.helper.make_tensor(B.name + "_DQ_Q4", qtype, B_array.shape, packed, True) - scales_tensor = onnx.numpy_helper.from_array(scales, B.name + "_DQ_scales") + b_quant = onnx.helper.make_tensor(b_tensor.name + "_DQ_Q4", qtype, b_ndarray.shape, packed.tobytes(), True) + scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales") - for input in Bs_graph.input: - if input.name == inputB: - Bs_graph.input.remove(input) + for input in b_graph.input: + if input.name == input_b: + b_graph.input.remove(input) break - Bs_graph.initializer.extend([B_quant, scales_tensor]) + b_graph.initializer.extend([b_quant, scales_tensor]) output_nodes = [] if self.config.quant_format == QuantFormat.QOperator: - input_names = [node.input[0], B_quant.name, scales_tensor.name] + input_names = [node.input[0], b_quant.name, scales_tensor.name] if not self.config.is_symmetric: - zp_tensor = onnx.numpy_helper.from_array(zero_points, B.name + "_zero_points") + zp_tensor = onnx.numpy_helper.from_array(zero_points, b_tensor.name + "_zero_points") input_names.append(zp_tensor.name) - Bs_graph.initializer.extend([zp_tensor]) + b_graph.initializer.extend([zp_tensor]) kwargs = {} - rows, cols = B_array.shape + rows, cols = b_ndarray.shape kwargs["K"] = rows kwargs["N"] = cols kwargs["bits"] = 4 @@ -531,14 +531,16 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP output_nodes.append(matmul_q4_node) else: - dq_input_names = [B_quant.name, scales_tensor.name] - dq_output_names = [B_quant.name + "_output"] + dq_input_names = [b_quant.name, scales_tensor.name] + dq_output_names = [b_quant.name + "_output"] matmul_input_names = [node.input[0], dq_output_names[0]] matmul_output_names = [node.output[0]] if not self.config.is_symmetric: - zp_tensor = onnx.helper.make_tensor(B.name + "_DQ_zero_points", qtype, scales.shape, zero_points, True) + zp_tensor = onnx.helper.make_tensor( + b_tensor.name + "_DQ_zero_points", qtype, scales.shape, zero_points.tobytes(), True + ) dq_input_names.append(zp_tensor.name) - Bs_graph.initializer.extend([zp_tensor]) + b_graph.initializer.extend([zp_tensor]) dq_kwargs = {"axis": 0, "block_size": self.config.block_size} dq_node = onnx.helper.make_node( "DequantizeLinear", diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index fa386eafd5190..e43bf656b51cc 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -105,8 +105,9 @@ def make_matmul( [output_tensor], initializer=initializers, ) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model.ir_version = 7 # use stable onnx ir version + # blocked quantization requires DQ op set >= 21 + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)]) + model.ir_version = 10 # use stable onnx ir version onnx.save(model, 
output_model_path) @@ -141,6 +142,10 @@ def quant_test( if use_qdq: dq_qtype = TensorProto.INT4 if is_symmetric else TensorProto.UINT4 dqnode_io_qtypes = { + "DequantizeLinear": [ + ["i", 0, dq_qtype], + ] + } if is_symmetric else { "DequantizeLinear": [ ["i", 0, dq_qtype], ["i", 2, dq_qtype], From 53635fa26d0a40ee0f80c6c586bea9e62fe0c29c Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Wed, 26 Jun 2024 14:24:16 -0700 Subject: [PATCH 22/36] fixing ci --- onnxruntime/core/mlas/lib/q4_dq.cpp | 2 +- .../selectors_actions/qdq_actions.cc | 95 ++++++++++--------- .../selectors_actions/qdq_actions.h | 6 +- .../qdq_selector_action_transformer.cc | 1 - .../selectors_actions/qdq_selectors.h | 2 +- .../qdq_matmulnbits_transformer_test.cc | 5 +- .../quantization/test_op_matmul_4bits.py | 24 +++-- 7 files changed, 73 insertions(+), 62 deletions(-) diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index 48d801db468f3..4a3b3c2502bf6 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -677,7 +677,7 @@ struct BlockwiseQDQQuantizer { static MLAS_FORCEINLINE uint8_t Pack(uint8_t v0, uint8_t v1) { if constexpr (add8) { - return (v0 & 0xF ^ 8) | ((v1 & 0xF ^ 8) << 4); + return ((v0 & 0xF) ^ 8) | (((v1 & 0xF) ^ 8) << 4); } else { return (v0 & 0xF) | ((v1 & 0xF) << 4); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index dd76b57af8a34..193455a3e3693 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -287,7 +287,8 @@ DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_ }()} { } -NodeAttributes DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const Graph&, const NodesToOptimize& selected_nodes) const { +NodeAttributes +DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const Graph&, const NodesToOptimize& selected_nodes) const { NodeAttributes extra_attributes; const auto* dq_node = selected_nodes.Input(0); @@ -365,56 +366,60 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { - MlasQDQTransposeBlockwiseQuantized(weight_src.DataAsByteSpan().data(), - scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst_ptr ? zp_dst_ptr->data() : nullptr, - true, - static_cast(K), - static_cast(N), - static_cast(block_size), - tp.get()); + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst_ptr ? zp_dst_ptr->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + tp.get()); } else { - MlasQDQTransposeBlockwiseQuantized(weight_src.DataAsByteSpan().data(), - scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst_ptr ? zp_dst_ptr->data() : nullptr, - true, - static_cast(K), - static_cast(N), - static_cast(block_size), - tp.get()); + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src_ptr ? 
zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst_ptr ? zp_dst_ptr->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + tp.get()); } } else { if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { - MlasQDQTransposeBlockwiseQuantized(weight_src.DataAsByteSpan().data(), - scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst_ptr ? zp_dst_ptr->data() : nullptr, - true, - static_cast(K), - static_cast(N), - static_cast(block_size), - tp.get()); + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst_ptr ? zp_dst_ptr->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + tp.get()); } else { - MlasQDQTransposeBlockwiseQuantized(weight_src.DataAsByteSpan().data(), - scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst_ptr ? zp_dst_ptr->data() : nullptr, - true, - static_cast(K), - static_cast(N), - static_cast(block_size), - tp.get()); + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst_ptr ? zp_dst_ptr->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + tp.get()); } } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index 85090509db0c3..243f43874cf83 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -3,6 +3,10 @@ #pragma once +#include +#include +#include + #include "core/optimizer/selectors_actions/actions.h" namespace onnxruntime { @@ -78,7 +82,7 @@ struct MatMulReplaceWithQLinear : public Action { // used together with DQMatMulNodeGroupSelector, which does the sanity check struct DQMatMulReplaceWithMatMulNBits : public Action { - DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level = -1); + explicit DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level = -1); Status Run(Graph&, const NodesToOptimize& selected_nodes) const override; private: diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 192d006965c1b..f4a946c477287 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -237,7 +237,6 @@ void DQMatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) - // TODO: Enable 16-bit types in selector when QLinearMatMul and MatMulInteger support 16-bit. 
std::unique_ptr selector = std::make_unique(); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"MatMul", {}}}, diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index a8a12d5ff1ac7..4a337280f454a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -369,7 +369,7 @@ class MatMulSelector : public BaseSelector { // Convert "1 DQ node for input B -> MatMul" to "MatMulNBits" class DQMatMulSelector : public BaseSelector { public: - DQMatMulSelector(gsl::span compatible_providers = {}) + DQMatMulSelector(gsl::span compatible_providers = {}) explicit : BaseSelector(std::make_unique(), compatible_providers) {} }; diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index c64413d72da5f..fb71c8b6744ef 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -13,13 +13,12 @@ #include "test/compare_ortvalue.h" #include "test/test_environment.h" #include "test/framework/test_utils.h" +#include "test/optimizer/qdq_test_utils.h" +#include "test/optimizer/graph_transform_test_builder.h" #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" #include "gtest/gtest.h" -#include "graph_transform_test_builder.h" - -#include "qdq_test_utils.h" #if defined(_MSC_VER) #pragma warning(disable : 4127) diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index e43bf656b51cc..4cc8a0c151d14 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -141,16 +141,20 @@ def quant_test( if use_qdq: dq_qtype = TensorProto.INT4 if is_symmetric else TensorProto.UINT4 - dqnode_io_qtypes = { - "DequantizeLinear": [ - ["i", 0, dq_qtype], - ] - } if is_symmetric else { - "DequantizeLinear": [ - ["i", 0, dq_qtype], - ["i", 2, dq_qtype], - ] - } + dqnode_io_qtypes = ( + { + "DequantizeLinear": [ + ["i", 0, dq_qtype], + ] + } + if is_symmetric + else { + "DequantizeLinear": [ + ["i", 0, dq_qtype], + ["i", 2, dq_qtype], + ] + } + ) check_qtype_by_node_type(self, model_int4_path, dqnode_io_qtypes) data_reader.rewind() From 425f61b3d5844b9db8afeb43f8bde7d3e8223ce3 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Wed, 26 Jun 2024 14:42:59 -0700 Subject: [PATCH 23/36] fixing ci --- .../optimizer/qdq_transformer/selectors_actions/qdq_selectors.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 4a337280f454a..bf5fa1ef5bd81 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -369,7 +369,7 @@ class MatMulSelector : public BaseSelector { // Convert "1 DQ node for input B -> MatMul" to "MatMulNBits" class DQMatMulSelector : public BaseSelector { public: - DQMatMulSelector(gsl::span compatible_providers = {}) explicit + explicit DQMatMulSelector(gsl::span compatible_providers = {}) : 
BaseSelector(std::make_unique(), compatible_providers) {} }; From b188f559b8b1ef051a88733759a55d81498905cb Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Wed, 26 Jun 2024 15:50:52 -0700 Subject: [PATCH 24/36] fixing minimal build --- .../qdq_selector_action_transformer.cc | 1 - .../selectors_actions/qdq_selectors.cc | 15 +++++---------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index f4a946c477287..6b2cb5f67610a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -244,7 +244,6 @@ void DQMatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::move(action)); #else - ORT_UNUSED_PARAMETER(is_int8_allowed); qdq_selector_action_registry.RegisterAction(action_name, std::move(action)); #endif } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 424340aea848f..79d1a9cb517a9 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -485,20 +485,15 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, } // weight, scale and zero points (if exists) must be constants - const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; - const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr; - const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr; + const auto* weight_tensor_proto = graph.GetConstantInitializer(weight_arg->Name(), true); + const auto* scale_tensor_proto = graph.GetConstantInitializer(scale_arg->Name(), true); + const auto* zp_tensor_proto = zero_point_arg ? graph.GetConstantInitializer(zero_point_arg->Name(), true) : nullptr; - if (!graph_utils::NodeArgIsConstant(graph, *weight_arg) || - !graph_utils::NodeArgIsConstant(graph, *scale_arg) || - !graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto) || - !graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto)) { + if (!weight_tensor_proto || !scale_tensor_proto) { return false; } - if (zero_point_arg && - (!graph_utils::NodeArgIsConstant(graph, *zero_point_arg) || - !graph.GetInitializedTensor(zero_point_arg->Name(), zp_tensor_proto))) { + if (zero_point_arg && !zp_tensor_proto) { return false; } From 1d34c27cf2e43d30ff7617bfa0eca2a063cfb523 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Wed, 26 Jun 2024 16:45:55 -0700 Subject: [PATCH 25/36] fixing ci --- onnxruntime/core/mlas/lib/q4_dq.cpp | 54 +---------------------------- 1 file changed, 1 insertion(+), 53 deletions(-) diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index 4a3b3c2502bf6..015d69de68766 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -720,48 +720,6 @@ struct BlockwiseQDQQuantizer { ); } - /** - * @brief Quantize a matrix shape [rows, columns] row-wise. Scales and zero points are calculated. - * Quantized data are packed row-wise based on qbits. Quantized data are stored in row - * major, so the output tensor reserves the shape, in terms output type. - * Thread block is [1, quant_block_size * 2]. 
- * @param src the source matrix, row major: [rows * columns] - * @param scales the scales of quantized blocks, row major layout with shape: - * [rows * ceil(columns / quant_block_size)] - * @param zero_points the zero points of quantized blocks, packed. Same shape as scales - * in terms of output type. In terms of uint8_t, the shape is: - * [ceil(rows * ceil(columns / quant_block_size) * qbits / 8)] - * @param dst the quantized weights, row major: [rows * columns] in terms of - * output type. In terms of uint8_t, the shape is: [ceil(rows * columns * qbits / 8] - * @param rows number of rows in the source matrix - * @param columns number of columns in the source matrix, must satisfy - * ceil(columns / quant_block_size) % 2 == 0, so in each thread block, - * zero points are packed into one byte. - * @param quant_block_size number of elements quantized together. - * @param thread_pool thread pool for parallel processing - */ - static void QuantizeRowWise( - const Tin* src, - Tin* scales, - uint8_t* zero_points, - uint8_t* dst, - int32_t rows, - int32_t columns, - int32_t quant_block_size, - MLAS_THREADPOOL* thread_pool - ) - { - MLAS_UNREFERENCED_PARAMETER(src); - MLAS_UNREFERENCED_PARAMETER(scales); - MLAS_UNREFERENCED_PARAMETER(zero_points); - MLAS_UNREFERENCED_PARAMETER(dst); - MLAS_UNREFERENCED_PARAMETER(rows); - MLAS_UNREFERENCED_PARAMETER(columns); - MLAS_UNREFERENCED_PARAMETER(quant_block_size); - MLAS_UNREFERENCED_PARAMETER(thread_pool); - ORT_THROW("BlockwiseQDQQuantizer::BlockwiseQDQQuantizer is not implemented"); - } - /** * @brief Quantize a matrix shape [rows, columns] column-wise. Scales and zero points are calculated. * Quantized data are packed row-wise based on qbits. Quantized data are stored in row major @@ -1799,17 +1757,7 @@ MlasQDQQuantizeBlockwise( return true; } } else { - if (zero_points) { - BlockwiseQDQQuantizer::QuantizeRowWise( - src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool - ); - return false; - } else { - BlockwiseQDQQuantizer::QuantizeRowWise( - src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool - ); - return true; - } + ORT_THROW("Row-wise MlasQDQQuantizeBlockwise is not implemented"); } } From 75bc4c5a4538cc8eee4e8aa32f5dfce33b05ca6e Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Fri, 28 Jun 2024 10:39:12 -0700 Subject: [PATCH 26/36] change dq matmul tool chain interface for genAI --- .../python/tools/quantization/matmul_4bits_quantizer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 23e4481f47d5d..7fe4d86153ba4 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -577,7 +577,8 @@ def __init__( is_symmetric: bool = False, accuracy_level: int | None = None, nodes_to_exclude=None, - algo_config: WeightOnlyQuantConfig = None, + quant_format=QuantFormat.QOperator, + algo_config: WeightOnlyQuantConfig | None = None, ): if nodes_to_exclude is None: nodes_to_exclude = [] @@ -590,7 +591,8 @@ def __init__( self.node_quantizer = None if algo_config is None: algo_config = DefaultWeightOnlyQuantConfig( - block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level + block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level, + quant_format=quant_format ) self.algo_config = algo_config if algo_config.algorithm == "HQQ": 
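With the change above, a caller can pick the output format when constructing the
quantizer. A minimal usage sketch, assuming the tool's existing process()/save
interface (file names are placeholders):

    import onnx
    from onnxruntime.quantization import QuantFormat
    from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

    model = onnx.load("model_fp32.onnx")  # placeholder input model
    quantizer = MatMul4BitsQuantizer(
        model,
        block_size=128,
        is_symmetric=True,
        # emit DequantizeLinear + MatMul instead of a MatMulNBits node
        quant_format=QuantFormat.QDQ,
    )
    quantizer.process()
    quantizer.model.save_model_to_file("model_int4_qdq.onnx", use_external_data_format=True)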
From f8773d915f13e62ace548301b7fffc6b002c85cb Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Fri, 28 Jun 2024 10:59:06 -0700 Subject: [PATCH 27/36] pass accuracy from session.config --- .../onnxruntime_session_options_config_keys.h | 5 +++++ .../core/optimizer/graph_transformer_utils.cc | 8 ++++++-- .../selectors_actions/qdq_actions.cc | 1 + .../selectors_actions/qdq_actions.h | 3 +-- .../qdq_selector_action_transformer.cc | 15 +++++++++------ .../qdq_selector_action_transformer.h | 4 +++- 6 files changed, 25 insertions(+), 11 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index c32e2a77e8453..17ae649e6f174 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -270,3 +270,8 @@ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed // - "0": Gemm FastMath mode is not enabled. [DEFAULT] // - "1": Gemm FastMath mode is enabled. static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; + +// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option. +// Refer to MatMulNBits op schema for more details. +// If not provided, default is 4. +static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level"; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 4298551aec412..1c498648a90ba 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -423,12 +423,16 @@ InlinedVector> GenerateTransformersForMinimalB const bool qdq_is_int8_allowed = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQIsInt8Allowed, QDQIsInt8Allowed() ? 
"1" : "0") == "1"; - + const int64_t qdq_matmulnbits_accuracy_level = + std::stoi(session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + "4")); // runtime optimizations only support CPU EP now const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; if (!disable_quant_qdq) { - transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context)); + transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, + apply_context, + qdq_matmulnbits_accuracy_level)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 193455a3e3693..4305f94ca3eb2 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -285,6 +285,7 @@ DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_ MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), MoveAll(target, ArgType::kOutput)}; }()} { + ORT_ENFORCE(accuracy_level >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); } NodeAttributes diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index 243f43874cf83..cece2bbcd727c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -82,7 +82,7 @@ struct MatMulReplaceWithQLinear : public Action { // used together with DQMatMulNodeGroupSelector, which does the sanity check struct DQMatMulReplaceWithMatMulNBits : public Action { - explicit DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level = -1); + explicit DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level); Status Run(Graph&, const NodesToOptimize& selected_nodes) const override; private: @@ -91,7 +91,6 @@ struct DQMatMulReplaceWithMatMulNBits : public Action { // transpose initializers, and add to the MatMulNBits inputs void AddTransposedInitializers(Graph&, const NodesToOptimize& selected_nodes, Node& replacement_node) const; - // -1 means not set const int64_t accuracy_level_; const std::string domain_; const std::string op_type_; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 6b2cb5f67610a..9e890df46a24b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -228,13 +228,15 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i #endif } -void DQMatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { +void DQMatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, + int64_t qdq_matmulnbits_accuracy_level) { // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. // DQ's weight is int4/uint4. DQ's scale is float/float16. // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. 
const std::string action_name{"DQMatMul"}; - std::unique_ptr action = std::make_unique(); + std::unique_ptr action = + std::make_unique(qdq_matmulnbits_accuracy_level); #if !defined(ORT_MINIMAL_BUILD) std::unique_ptr selector = std::make_unique(); @@ -291,7 +293,8 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { #endif } -SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) { +SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed, + int64_t qdq_matmulnbits_accuracy_level) { SelectorActionRegistry qdq_selector_action_registry; SplitQDQRules(qdq_selector_action_registry); DropQDQNodesRules(qdq_selector_action_registry); @@ -303,7 +306,7 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) { MatMulQDQRules(qdq_selector_action_registry, is_int8_allowed); GemmQDQRules(qdq_selector_action_registry); WhereQDQRules(qdq_selector_action_registry); - DQMatMulQDQRules(qdq_selector_action_registry); + DQMatMulQDQRules(qdq_selector_action_registry, qdq_matmulnbits_accuracy_level); return qdq_selector_action_registry; } @@ -311,10 +314,10 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) { } // namespace QDQSelectorActionTransformer::QDQSelectorActionTransformer( - bool is_int8_allowed, const SatApplyContextVariant& apply_context) + bool is_int8_allowed, const SatApplyContextVariant& apply_context, int64_t qdq_matmulnbits_accuracy_level) : SelectorActionTransformer{ "QDQSelectorActionTransformer", - CreateSelectorActionRegistry(is_int8_allowed), + CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level), apply_context, // this transformer is only compatible with the CPU and DML EP {kCpuExecutionProvider, kDmlExecutionProvider}} { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index 1780923f3f273..eb44972e4d065 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -21,7 +21,9 @@ Transformer that fuses QDQ and fp32 ops into quantized ops. 
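The DQ + MatMul -> MatMulNBits fusion uses the accuracy level passed in here;
session creation wires it from the kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel
config entry (default "4").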
*/ class QDQSelectorActionTransformer : public SelectorActionTransformer { public: - QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}); + QDQSelectorActionTransformer(bool is_int8_allowed, + const SatApplyContextVariant& apply_context = {}, + int64_t qdq_matmulnbits_accuracy_level = 4); }; } // namespace onnxruntime From d5b032ebb9d45bdf20e859c9effbb775796b40e5 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Fri, 28 Jun 2024 12:05:58 -0700 Subject: [PATCH 28/36] added UT for accuracy level in session config options --- .../selectors_actions/qdq_actions.cc | 2 +- .../quantization/matmul_4bits_quantizer.py | 6 +- .../qdq_matmulnbits_transformer_test.cc | 201 ++++++++++++------ 3 files changed, 141 insertions(+), 68 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 4305f94ca3eb2..2cbb165acf69a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -285,7 +285,7 @@ DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_ MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), MoveAll(target, ArgType::kOutput)}; }()} { - ORT_ENFORCE(accuracy_level >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); + ORT_ENFORCE(accuracy_level >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); } NodeAttributes diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 7fe4d86153ba4..40a4a4d26dc1c 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -591,8 +591,10 @@ def __init__( self.node_quantizer = None if algo_config is None: algo_config = DefaultWeightOnlyQuantConfig( - block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level, - quant_format=quant_format + block_size=block_size, + is_symmetric=is_symmetric, + accuracy_level=accuracy_level, + quant_format=quant_format, ) self.algo_config = algo_config if algo_config.algorithm == "HQQ": diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index fb71c8b6744ef..a668eee550c98 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -9,6 +9,7 @@ #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/compare_ortvalue.h" #include "test/test_environment.h" @@ -54,14 +55,12 @@ RunDQMatMulNotConverted_NonConstDQ(const std::vector& input1_shape, const std::vector& input2_shape, const int64_t axis, const int64_t block_size, - bool use_contrib_qdq) { + int64_t accuracy_level) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* input1_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); auto* input2_arg = builder.MakeInput(input2_shape, T(T::min_val, 0), T(T::max_val, 0)); auto* output_arg = builder.MakeOutput(); - 
std::string domain = use_contrib_qdq ? kMSDomain : ""; - // add DQ auto* dq_output = builder.MakeIntermediate(); NodeAttributes attrs; @@ -73,9 +72,9 @@ RunDQMatMulNotConverted_NonConstDQ(const std::vector& input1_shape, auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); if constexpr (use_zp) { auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); - builder.AddNode("DequantizeLinear", {input2_arg, scale_arg, zp_arg}, {dq_output}, domain, &attrs); + builder.AddNode("DequantizeLinear", {input2_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); } else { - builder.AddNode("DequantizeLinear", {input2_arg, scale_arg}, {dq_output}, domain, &attrs); + builder.AddNode("DequantizeLinear", {input2_arg, scale_arg}, {dq_output}, "", &attrs); } builder.AddNode("MatMul", {input1_arg, dq_output}, {output_arg}); @@ -83,27 +82,49 @@ RunDQMatMulNotConverted_NonConstDQ(const std::vector& input1_shape, auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); - const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); EXPECT_EQ(op_to_count["MatMul"], 1); EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); }; + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 21 /*opset_version*/, 1e-5 /*per_sample_tolerance*/, - 1e-5 /*relative_per_sample_tolerance*/); + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_NonConstDQ) { // DQ contrib op schema is not updated to support blocked quantization - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); } // Input2 @@ -119,14 +140,12 @@ RunDQMatMulNotConverted_FirstDQInput(const std::vector& weight_shape, const std::vector& input2_shape, const 
int64_t axis, const int64_t block_size, - bool use_contrib_qdq) { + int64_t accuracy_level) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); auto* input2_arg = builder.MakeInput(input2_shape, -100.0f, 100.0f); auto* output_arg = builder.MakeOutput(); - std::string domain = use_contrib_qdq ? kMSDomain : ""; - // add DQ auto* dq_output = builder.MakeIntermediate(); NodeAttributes attrs; @@ -138,9 +157,9 @@ RunDQMatMulNotConverted_FirstDQInput(const std::vector& weight_shape, auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); if constexpr (use_zp) { auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); - builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, domain, &attrs); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); } else { - builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, domain, &attrs); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &attrs); } builder.AddNode("MatMul", {dq_output, input2_arg}, {output_arg}); @@ -148,27 +167,49 @@ RunDQMatMulNotConverted_FirstDQInput(const std::vector& weight_shape, auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); - const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); EXPECT_EQ(op_to_count["MatMul"], 1); EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); }; + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 21 /*opset_version*/, 1e-5 /*per_sample_tolerance*/, - 1e-5 /*relative_per_sample_tolerance*/); + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_FirstDQInput) { // DQ contrib op schema is not updated to support blocked quantization - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + 
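  // a negative accuracy level leaves kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel
  // unset, so the transformer falls back to the default level (4)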
RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); } // Input1 @@ -183,12 +224,11 @@ void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input const std::vector& weight_shape, const int64_t axis, const int64_t block_size, - bool use_contrib_qdq) { + int64_t accuracy_level) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); auto* output_arg = builder.MakeOutput(); NodeArg* weight_arg = nullptr; - std::string domain = use_contrib_qdq ? kMSDomain : ""; // add DQ if constexpr (std::is_same_v || std::is_same_v) { @@ -215,9 +255,9 @@ void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input zp_arg = builder.MakeInitializer(scale_shape, 0, 2); } - builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, domain, &attrs); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); } else { - builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, domain, &attrs); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &attrs); } builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); @@ -225,56 +265,66 @@ void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); - const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); EXPECT_EQ(op_to_count["MatMul"], 1); EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); }; + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 21 /*opset_version*/, 1e-5 /*per_sample_tolerance*/, - 1e-5 /*relative_per_sample_tolerance*/); + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_TypeMismatch) { // DQ contrib op schema is not updated to support blocked quantization - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 
0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch) { // DQ contrib op schema is not updated to support blocked quantization // block size too small - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); // block size not 2's power - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); // not axis 0 - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); // not rank 2 - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, false); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, false); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); } // Input1 @@ -293,11 +343,10 @@ RunDQMatMulConverted(const std::vector& input1_shape, const std::vector& weight2_shape, const int64_t axis, const int64_t block_size, - bool use_contrib_qdq) { + int64_t accuracy_level) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); auto* 
output_arg = builder.MakeOutput(); - std::string domain = use_contrib_qdq ? kMSDomain : ""; // add DQ NodeAttributes attrs; @@ -319,11 +368,11 @@ RunDQMatMulConverted(const std::vector& input1_shape, if constexpr (use_zp) { auto* zp1_arg = builder.MakeInitializer(scale1_shape, T(0, 0), T(2, 0)); auto* zp2_arg = builder.MakeInitializer(scale2_shape, T(0, 0), T(2, 0)); - builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg, zp1_arg}, {dq1_output}, domain, &attrs); - builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg, zp2_arg}, {dq2_output}, domain, &attrs); + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg, zp1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg, zp2_arg}, {dq2_output}, "", &attrs); } else { - builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg}, {dq1_output}, domain, &attrs); - builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg}, {dq2_output}, domain, &attrs); + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg}, {dq2_output}, "", &attrs); } builder.AddNode("MatMul", {input_arg, dq1_output}, {matmul1_output}); @@ -332,28 +381,50 @@ RunDQMatMulConverted(const std::vector& input1_shape, auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); - const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); EXPECT_EQ(op_to_count["MatMul"], 0); EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 2); EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); }; + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 21 /*opset_version*/, 1e-5 /*per_sample_tolerance*/, - 1e-5 /*relative_per_sample_tolerance*/); + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); } TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits) { // DQ contrib op schema is not updated to support blocked quantization - RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, false); - RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, false); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 
0, 16, -1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, -1); } #endif // !defined(DISABLE_CONTRIB_OPS) From 2d01a385f9dde38022d76e1095fc1d078e3ef0ed Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Fri, 28 Jun 2024 12:19:41 -0700 Subject: [PATCH 29/36] fixed missing transformer path --- onnxruntime/core/optimizer/graph_transformer_utils.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 1c498648a90ba..4f145b848a24a 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -287,6 +287,9 @@ InlinedVector> GenerateTransformers( onnxruntime::kJsExecutionProvider}; const InlinedHashSet cpu_dml_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kDmlExecutionProvider}; + const int64_t qdq_matmulnbits_accuracy_level = + std::stoi(session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + "4")); #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow(); @@ -300,7 +303,9 @@ InlinedVector> GenerateTransformers( if (!qdq_is_int8_allowed) { transformers.emplace_back(std::make_unique(avx2_precision_mode, cpu_ep)); } - transformers.emplace_back(std::make_unique(qdq_is_int8_allowed)); + transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, + SatApplyContextVariant{}, + qdq_matmulnbits_accuracy_level)); } transformers.emplace_back(std::make_unique(cpu_ep)); From c17a9cf6bff866129aaa0933aa05b457f0bf72d5 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Fri, 5 Jul 2024 19:36:57 -0700 Subject: [PATCH 30/36] resolved comments --- .../core/optimizer/graph_transformer_utils.h | 6 +- .../core/optimizer/graph_transformer_utils.cc | 22 ++-- .../selectors_actions/qdq_actions.cc | 103 +++++++----------- .../selectors_actions/qdq_actions.h | 18 ++- .../qdq_selector_action_transformer.cc | 25 +++-- .../qdq_selector_action_transformer.h | 3 +- .../selectors_actions/qdq_selectors.cc | 29 +---- .../selectors_actions/qdq_selectors.h | 4 +- .../optimizer/selectors_actions/actions.cc | 4 +- .../optimizer/selectors_actions/actions.h | 2 + onnxruntime/core/session/inference_session.cc | 14 ++- 11 files changed, 107 insertions(+), 123 deletions(-) diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index e609745b5e03f..bd12d710f6422 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -49,7 +49,8 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& execution_provider /*required by constant folding*/, - const InlinedHashSet& rules_and_transformers_to_disable = {}); + const InlinedHashSet& rules_and_transformers_to_disable = {}, + concurrency::ThreadPool* intra_op_thread_pool = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) @@ -78,7 +79,8 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, - const InlinedHashSet& 
rules_and_transformers_to_disable = {}); + const InlinedHashSet& rules_and_transformers_to_disable = {}, + concurrency::ThreadPool* intra_op_thread_pool = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 4f145b848a24a..b536705b627b1 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -187,7 +187,8 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ - const InlinedHashSet& rules_and_transformers_to_disable) { + const InlinedHashSet& rules_and_transformers_to_disable, + concurrency::ThreadPool* intra_op_thread_pool) { InlinedVector> transformers; const bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; @@ -288,8 +289,9 @@ InlinedVector> GenerateTransformers( const InlinedHashSet cpu_dml_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kDmlExecutionProvider}; const int64_t qdq_matmulnbits_accuracy_level = - std::stoi(session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, - "4")); + ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + "4")); #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow(); @@ -305,7 +307,8 @@ InlinedVector> GenerateTransformers( } transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, SatApplyContextVariant{}, - qdq_matmulnbits_accuracy_level)); + qdq_matmulnbits_accuracy_level, + intra_op_thread_pool)); } transformers.emplace_back(std::make_unique(cpu_ep)); @@ -414,7 +417,8 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, - const InlinedHashSet& rules_and_transformers_to_disable) { + const InlinedHashSet& rules_and_transformers_to_disable, + concurrency::ThreadPool* intra_op_thread_pool) { InlinedVector> transformers; const bool saving = std::holds_alternative(apply_context); @@ -429,15 +433,17 @@ InlinedVector> GenerateTransformersForMinimalB session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQIsInt8Allowed, QDQIsInt8Allowed() ? 
"1" : "0") == "1"; const int64_t qdq_matmulnbits_accuracy_level = - std::stoi(session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, - "4")); + ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + "4")); // runtime optimizations only support CPU EP now const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; if (!disable_quant_qdq) { transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context, - qdq_matmulnbits_accuracy_level)); + qdq_matmulnbits_accuracy_level, + intra_op_thread_pool)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 2cbb165acf69a..b2f9c79b455f3 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -275,7 +275,8 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select } } -DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level) +DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) : accuracy_level_{accuracy_level}, domain_{kMSDomain}, op_type_{"MatMulNBits"}, @@ -284,26 +285,23 @@ DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_ return std::vector{ MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), MoveAll(target, ArgType::kOutput)}; - }()} { - ORT_ENFORCE(accuracy_level >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); + }()}, + intra_op_thread_pool_{intra_op_thread_pool} { + ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); + ORT_ENFORCE(intra_op_thread_pool_, "Intra op thread pool cannot be null"); } NodeAttributes -DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const Graph&, const NodesToOptimize& selected_nodes) const { +DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const RuntimeState& runtime_state) const { NodeAttributes extra_attributes; - const auto* dq_node = selected_nodes.Input(0); + const auto* dq_node = runtime_state.selected_nodes.Input(0); auto& attrs = dq_node->GetAttributes(); const auto* weight_shape = dq_node->InputDefs()[0]->Shape(); - ORT_ENFORCE(weight_shape->dim(0).has_dim_value() && weight_shape->dim(1).has_dim_value(), - "Input x of DQ node must have rank 2 shape dimensions"); - utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes); utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes); - if (accuracy_level_ > -1) { - utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes); - } + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes); // currently only 4bits is supported. In the future, derive bits from DQ's weight type. 
utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast(4)), extra_attributes); utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs.at("block_size").i()), extra_attributes); @@ -311,9 +309,9 @@ DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const Graph&, const NodesToOptim return extra_attributes; } -void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, - const NodesToOptimize& selected_nodes, - Node& replacement_node) const { +Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, + const NodesToOptimize& selected_nodes, + Node& replacement_node) const { const auto* dq_node = selected_nodes.Input(0); const auto* weight_arg = dq_node->InputDefs()[0]; const auto* scale_arg = dq_node->InputDefs()[1]; @@ -341,86 +339,82 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, // to what we need. But it does not handle external data. Initializer weight_src(*weight_tensor_proto, graph.ModelPath()); Initializer scale_src(*scale_tensor_proto, graph.ModelPath()); - std::unique_ptr zp_src_ptr = nullptr; + std::optional zp_src; Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8, graph.GenerateNodeArgName(weight_arg->Name() + "_T"), std::vector{N, quant_num, blob_bytes}); Initializer scale_dst(static_cast(scale_src.data_type()), graph.GenerateNodeArgName(scale_arg->Name() + "_T"), std::vector{N * quant_num}); - std::unique_ptr zp_dst_ptr = nullptr; + std::optional zp_dst; if (zp_tensor_proto) { - zp_src_ptr = std::make_unique(*zp_tensor_proto, graph.ModelPath()); - zp_dst_ptr = std::make_unique(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName(zp_arg->Name() + "_T"), - std::vector{N * ((quant_num + 1) / 2)}); + zp_src.emplace(Initializer(*zp_tensor_proto, graph.ModelPath())); + zp_dst.emplace(Initializer(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName(zp_arg->Name() + "_T"), + std::vector{N * ((quant_num + 1) / 2)})); } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { - zp_dst_ptr = std::make_unique(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"), - std::vector{N * ((quant_num + 1) / 2)}); + zp_dst.emplace(Initializer(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"), + std::vector{N * ((quant_num + 1) / 2)})); } - OrtThreadPoolParams to; - auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, - concurrency::ThreadPoolType::INTRA_OP); - if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, weight_dst.data(), scale_dst.data(), - zp_dst_ptr ? zp_dst_ptr->data() : nullptr, + zp_dst ? zp_dst->data() : nullptr, true, static_cast(K), static_cast(N), static_cast(block_size), - tp.get()); + intra_op_thread_pool_); } else { MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, weight_dst.data(), scale_dst.data(), - zp_dst_ptr ? zp_dst_ptr->data() : nullptr, + zp_dst ? 
zp_dst->data() : nullptr, true, static_cast(K), static_cast(N), static_cast(block_size), - tp.get()); + intra_op_thread_pool_); } } else { if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, weight_dst.data(), scale_dst.data(), - zp_dst_ptr ? zp_dst_ptr->data() : nullptr, + zp_dst ? zp_dst->data() : nullptr, true, static_cast(K), static_cast(N), static_cast(block_size), - tp.get()); + intra_op_thread_pool_); } else { MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, weight_dst.data(), scale_dst.data(), - zp_dst_ptr ? zp_dst_ptr->data() : nullptr, + zp_dst ? zp_dst->data() : nullptr, true, static_cast(K), static_cast(N), static_cast(block_size), - tp.get()); + intra_op_thread_pool_); } } @@ -428,11 +422,13 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, ONNX_NAMESPACE::TensorProto scale_T_tp; std::unique_ptr zp_T_tp_ptr = nullptr; + // TODO(fajin): external_data to memory location to avoid arena allocation + // https://github.com/microsoft/onnxruntime/pull/12465 weight_dst.ToProto(weight_T_tp); scale_dst.ToProto(scale_T_tp); - if (zp_dst_ptr) { + if (zp_dst) { zp_T_tp_ptr = std::make_unique(); - zp_dst_ptr->ToProto(*zp_T_tp_ptr); + zp_dst->ToProto(*zp_T_tp_ptr); } auto& input_defs = replacement_node.MutableInputDefs(); @@ -445,29 +441,8 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph, input_defs.push_back(&graph_utils::AddInitializer(graph, *zp_T_tp_ptr)); replacement_node.MutableInputArgsCount().push_back(1); } -} - -Status DQMatMulReplaceWithMatMulNBits::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { - const auto attributes = ExtraAttributes(graph, selected_nodes); - const auto& target = selected_nodes.Target(); - - // create node. we'll populate the input and output defs via moves - auto& replacement = graph.AddNode(target.Name(), - op_type_, - target.Description(), - {}, // input defs - {}, // output defs - &attributes, - domain_); - - const auto& target_provider = target.GetExecutionProviderType(); - replacement.SetExecutionProviderType(target_provider.empty() ? 
kCpuExecutionProvider : target_provider); - - ORT_RETURN_IF_ERROR(MoveInputOutput(graph, selected_nodes, replacement, value_moves_, false)); - - AddTransposedInitializers(graph, selected_nodes, replacement); - - return node_remover_.Run(graph, selected_nodes); + + return Status::OK(); } static std::vector GetGemmMoveInfo(bool does_q_node_exist) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index cece2bbcd727c..c73be519871cf 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -81,21 +81,27 @@ struct MatMulReplaceWithQLinear : public Action { }; // used together with DQMatMulNodeGroupSelector, which does the sanity check -struct DQMatMulReplaceWithMatMulNBits : public Action { - explicit DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level); - Status Run(Graph&, const NodesToOptimize& selected_nodes) const override; +struct DQMatMulReplaceWithMatMulNBits : public ReplaceWithNew { + DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool); private: - NodeAttributes ExtraAttributes(const Graph&, const NodesToOptimize& selected_nodes) const; + std::string OpType(const RuntimeState&) const override { return op_type_; } + + std::string Domain(const RuntimeState&) const override { return domain_; } + + NodeAttributes ExtraAttributes(const RuntimeState&) const override; + + std::vector ValueMoves(const RuntimeState&) const override { return value_moves_; } // transpose initializers, and add to the MatMulNBits inputs - void AddTransposedInitializers(Graph&, const NodesToOptimize& selected_nodes, Node& replacement_node) const; + Status ProcessNewNode(Graph&, const NodesToOptimize&, Node&) const override; const int64_t accuracy_level_; const std::string domain_; const std::string op_type_; const std::vector value_moves_; - RemoveNodes node_remover_; + concurrency::ThreadPool* intra_op_thread_pool_; }; struct GemmReplaceWithQuant : public Action { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 9e890df46a24b..0b10f092cc565 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -228,18 +228,20 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i #endif } -void DQMatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, - int64_t qdq_matmulnbits_accuracy_level) { +void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) { // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. // DQ's weight is int4/uint4. DQ's scale is float/float16. // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. 
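  // Side note on the "2's power" requirement: the updated selector checks it
  // without a loop via ((block_size - 1) & block_size). For block_size > 0 this
  // expression is zero exactly when a single bit is set, i.e. exactly when
  // block_size is a power of two.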
const std::string action_name{"DQMatMul"}; std::unique_ptr action = - std::make_unique(qdq_matmulnbits_accuracy_level); + std::make_unique(qdq_matmulnbits_accuracy_level, + intra_op_thread_pool); #if !defined(ORT_MINIMAL_BUILD) - std::unique_ptr selector = std::make_unique(); + std::unique_ptr selector = std::make_unique(); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"MatMul", {}}}, std::move(selector), @@ -294,7 +296,8 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { } SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed, - int64_t qdq_matmulnbits_accuracy_level) { + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) { SelectorActionRegistry qdq_selector_action_registry; SplitQDQRules(qdq_selector_action_registry); DropQDQNodesRules(qdq_selector_action_registry); @@ -306,18 +309,22 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed, MatMulQDQRules(qdq_selector_action_registry, is_int8_allowed); GemmQDQRules(qdq_selector_action_registry); WhereQDQRules(qdq_selector_action_registry); - DQMatMulQDQRules(qdq_selector_action_registry, qdq_matmulnbits_accuracy_level); + DQMatMulToMatMulNBitsRules(qdq_selector_action_registry, + qdq_matmulnbits_accuracy_level, + intra_op_thread_pool); return qdq_selector_action_registry; } } // namespace -QDQSelectorActionTransformer::QDQSelectorActionTransformer( - bool is_int8_allowed, const SatApplyContextVariant& apply_context, int64_t qdq_matmulnbits_accuracy_level) +QDQSelectorActionTransformer::QDQSelectorActionTransformer(bool is_int8_allowed, + const SatApplyContextVariant& apply_context, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) : SelectorActionTransformer{ "QDQSelectorActionTransformer", - CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level), + CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, intra_op_thread_pool), apply_context, // this transformer is only compatible with the CPU and DML EP {kCpuExecutionProvider, kDmlExecutionProvider}} { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index eb44972e4d065..f2aa9c8327eb4 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -23,7 +23,8 @@ class QDQSelectorActionTransformer : public SelectorActionTransformer { public: QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}, - int64_t qdq_matmulnbits_accuracy_level = 4); + int64_t qdq_matmulnbits_accuracy_level = 4, + concurrency::ThreadPool* intra_op_thread_pool = nullptr); }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 79d1a9cb517a9..96dc10a326692 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -15,12 +15,6 @@ namespace onnxruntime { namespace QDQ { namespace { -#if defined(_MSC_VER) -#define FORCEINLINE __forceinline -#else -#define FORCEINLINE __attribute__((always_inline)) 
inline -#endif - constexpr bool Is16BitIntType(int32_t data_type) { return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16) || (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16); @@ -31,21 +25,6 @@ constexpr bool Is4BitIntType(int32_t data_type) { (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4); } -FORCEINLINE bool IsPowerOfTwo(int64_t val) { - if (val < 0) return false; - - bool seen_one = val & 1; - val >>= 1; - - for (; val; seen_one = val & 1, val >>= 1) { - if (seen_one) { - return false; - } - } - - return true; -} - // adjust for an optional input/output that has an entry but does not exist int NumActualValues(const Node& node, bool input) { const auto& defs = input ? node.InputDefs() : node.OutputDefs(); @@ -439,16 +418,16 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, const std::vector& q_nodes) const { - ONNX_UNUSED_PARAMETER(q_nodes); + ORT_UNUSED_PARAMETER(q_nodes); const auto& graph = graph_viewer.GetGraph(); - // MatMul has only 1 DQ input and the DQ must has 1 output edge which is not graph output +// MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) { return false; } // DQ must be MatMul's the second input - if (node.InputDefs()[1]->Name() != dq_nodes[0]->OutputDefs()[0]->Name()) { + if (node.InputDefs()[1] != dq_nodes[0]->OutputDefs()[0]) { return false; } @@ -480,7 +459,7 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, } auto block_size = a_iter->second.i(); - if (block_size < 16 || !IsPowerOfTwo(block_size)) { + if (block_size < 16 || ((block_size - 1) & block_size)) { return false; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index bf5fa1ef5bd81..491a15b62cb03 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -367,9 +367,9 @@ class MatMulSelector : public BaseSelector { }; // Convert "1 DQ node for input B -> MatMul" to "MatMulNBits" -class DQMatMulSelector : public BaseSelector { +class DQMatMulToMatMulNBitsSelector : public BaseSelector { public: - explicit DQMatMulSelector(gsl::span compatible_providers = {}) + explicit DQMatMulToMatMulNBitsSelector(gsl::span compatible_providers = {}) : BaseSelector(std::make_unique(), compatible_providers) {} }; diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.cc b/onnxruntime/core/optimizer/selectors_actions/actions.cc index c8d5acbf66b78..bb4033afedc49 100644 --- a/onnxruntime/core/optimizer/selectors_actions/actions.cc +++ b/onnxruntime/core/optimizer/selectors_actions/actions.cc @@ -102,12 +102,14 @@ static Status CreateReplacementNode(Graph& graph, Status ReplaceWithNew::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { const RuntimeState runtime_state{graph, selected_nodes}; + Node* replacement{}; ORT_RETURN_IF_ERROR(CreateReplacementNode(graph, selected_nodes, OpType(runtime_state), Domain(runtime_state), ExtraAttributes(runtime_state), ValueMoves(runtime_state), - /* only_update_dest_definitions */ false, nullptr)); + /* only_update_dest_definitions */ false, &replacement)); + ORT_RETURN_IF_ERROR(ProcessNewNode(graph, 
selected_nodes, *replacement)); return node_remover_.Run(graph, selected_nodes); } diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.h b/onnxruntime/core/optimizer/selectors_actions/actions.h index 9d800ffd80636..4d5b520cc47cb 100644 --- a/onnxruntime/core/optimizer/selectors_actions/actions.h +++ b/onnxruntime/core/optimizer/selectors_actions/actions.h @@ -158,6 +158,8 @@ struct ReplaceWithNew : public Action { // specifies how the inputs and outputs for the replaced nodes are moved to the new node virtual std::vector ValueMoves(const RuntimeState&) const = 0; + virtual Status ProcessNewNode(Graph&, const NodesToOptimize&, Node&) const { return Status::OK(); } + RemoveNodes node_remover_; }; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 3ef6490a56ded..88146b517fb83 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1613,7 +1613,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) Status ApplyOrtFormatModelRuntimeOptimizations( onnxruntime::Graph& graph, const logging::Logger& logger, const SessionOptions& session_options, - const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep) { + const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep, + concurrency::ThreadPool* intra_op_thread_pool) { bool modified = false; for (int level = static_cast(TransformerLevel::Level2); @@ -1621,7 +1622,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations( ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, - optimizers_to_disable); + optimizers_to_disable, intra_op_thread_pool); for (const auto& transformer : transformers) { ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); @@ -2009,7 +2010,8 @@ common::Status InferenceSession::Initialize() { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); ORT_RETURN_IF_ERROR_SESSIONID_( - ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, cpu_ep)); + ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, cpu_ep, + GetIntraOpThreadPoolToUse())); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } @@ -3171,7 +3173,8 @@ common::Status InferenceSession::AddPredefinedTransformers( if (use_full_build_optimizations) { return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, - optimizers_to_disable_); + optimizers_to_disable_, + GetIntraOpThreadPoolToUse()); } else { const auto sat_context = minimal_build_optimization_handling == @@ -3180,7 +3183,8 @@ common::Status InferenceSession::AddPredefinedTransformers( record_runtime_optimization_produced_op_schema_fn}} : SatApplyContextVariant{SatDirectApplicationContext{}}; return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, - optimizers_to_disable_); + optimizers_to_disable_, + GetIntraOpThreadPoolToUse()); } }(); From 8c2a1212f71ee16459d826585b4b45f9cebd90ef Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Mon, 8 Jul 2024 11:30:54 -0700 Subject: [PATCH 31/36] fix build --- .../selectors_actions/qdq_actions.cc | 44 
+++++++++---------- .../selectors_actions/qdq_actions.h | 1 + .../selectors_actions/qdq_selectors.cc | 2 +- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index b2f9c79b455f3..c1216d4c73ae9 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -339,24 +339,24 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, // to what we need. But it does not handle external data. Initializer weight_src(*weight_tensor_proto, graph.ModelPath()); Initializer scale_src(*scale_tensor_proto, graph.ModelPath()); - std::optional zp_src; + std::optional> zp_src_ptr; Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8, graph.GenerateNodeArgName(weight_arg->Name() + "_T"), std::vector{N, quant_num, blob_bytes}); Initializer scale_dst(static_cast(scale_src.data_type()), graph.GenerateNodeArgName(scale_arg->Name() + "_T"), std::vector{N * quant_num}); - std::optional zp_dst; + std::optional> zp_dst_ptr; if (zp_tensor_proto) { - zp_src.emplace(Initializer(*zp_tensor_proto, graph.ModelPath())); - zp_dst.emplace(Initializer(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName(zp_arg->Name() + "_T"), - std::vector{N * ((quant_num + 1) / 2)})); + zp_src_ptr.emplace(std::make_unique(*zp_tensor_proto, graph.ModelPath())); + zp_dst_ptr.emplace(std::make_unique(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName(zp_arg->Name() + "_T"), + std::vector{N * ((quant_num + 1) / 2)})); } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { - zp_dst.emplace(Initializer(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"), - std::vector{N * ((quant_num + 1) / 2)})); + zp_dst_ptr.emplace(std::make_unique(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"), + std::vector{N * ((quant_num + 1) / 2)})); } if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { @@ -364,10 +364,10 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, weight_dst.data(), scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr, true, static_cast(K), static_cast(N), @@ -377,10 +377,10 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, weight_dst.data(), scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr, true, static_cast(K), static_cast(N), @@ -392,10 +392,10 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, weight_dst.data(), scale_dst.data(), - zp_dst ? 
zp_dst->data() : nullptr, + zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr, true, static_cast(K), static_cast(N), @@ -406,10 +406,10 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, weight_dst.data(), scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr, true, static_cast(K), static_cast(N), @@ -420,15 +420,15 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, ONNX_NAMESPACE::TensorProto weight_T_tp; ONNX_NAMESPACE::TensorProto scale_T_tp; - std::unique_ptr zp_T_tp_ptr = nullptr; + std::optional> zp_T_tp_ptr; // TODO(fajin): external_data to memory location to avoid arena allocation // https://github.com/microsoft/onnxruntime/pull/12465 weight_dst.ToProto(weight_T_tp); scale_dst.ToProto(scale_T_tp); - if (zp_dst) { + if (zp_dst_ptr) { zp_T_tp_ptr = std::make_unique(); - zp_dst->ToProto(*zp_T_tp_ptr); + zp_dst_ptr.value()->ToProto(*zp_T_tp_ptr.value()); } auto& input_defs = replacement_node.MutableInputDefs(); @@ -438,10 +438,10 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, replacement_node.MutableInputArgsCount().push_back(1); if (zp_T_tp_ptr) { - input_defs.push_back(&graph_utils::AddInitializer(graph, *zp_T_tp_ptr)); + input_defs.push_back(&graph_utils::AddInitializer(graph, *zp_T_tp_ptr.value())); replacement_node.MutableInputArgsCount().push_back(1); } - + return Status::OK(); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index c73be519871cf..d80c3f9d183bf 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -8,6 +8,7 @@ #include #include "core/optimizer/selectors_actions/actions.h" +#include "core/platform/threadpool.h" namespace onnxruntime { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 96dc10a326692..692db4eb327b5 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -421,7 +421,7 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, ORT_UNUSED_PARAMETER(q_nodes); const auto& graph = graph_viewer.GetGraph(); -// MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output + // MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) { return false; } From baa9389ff222cb1c36ae00e3063bd3652588812b Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Mon, 8 Jul 2024 15:25:48 -0700 Subject: [PATCH 32/36] fix ut --- .../optimizer/qdq_transformer/selectors_actions/qdq_actions.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index c1216d4c73ae9..657218b44dc91 100644 --- 
a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -288,7 +288,6 @@ DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_ }()}, intra_op_thread_pool_{intra_op_thread_pool} { ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); - ORT_ENFORCE(intra_op_thread_pool_, "Intra op thread pool cannot be null"); } NodeAttributes @@ -312,6 +311,8 @@ DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const RuntimeState& runtime_stat Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, const NodesToOptimize& selected_nodes, Node& replacement_node) const { + ORT_ENFORCE(intra_op_thread_pool_, "Intra op thread pool cannot be null"); + const auto* dq_node = selected_nodes.Input(0); const auto* weight_arg = dq_node->InputDefs()[0]; const auto* scale_arg = dq_node->InputDefs()[1]; From f33dbf9af74ac7649737d6e2760cd0ed131775c6 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Mon, 8 Jul 2024 16:31:51 -0700 Subject: [PATCH 33/36] fixing arm ut --- .../selectors_actions/qdq_actions.cc | 16 ++++++++++------ .../selectors_actions/qdq_actions.h | 1 + 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 657218b44dc91..fe310d2ea453f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -288,6 +288,12 @@ DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_ }()}, intra_op_thread_pool_{intra_op_thread_pool} { ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); + + if (!intra_op_thread_pool) { + OrtThreadPoolParams to; + intra_op_thread_pool_optional_ = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + } } NodeAttributes @@ -311,8 +317,6 @@ DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const RuntimeState& runtime_stat Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, const NodesToOptimize& selected_nodes, Node& replacement_node) const { - ORT_ENFORCE(intra_op_thread_pool_, "Intra op thread pool cannot be null"); - const auto* dq_node = selected_nodes.Input(0); const auto* weight_arg = dq_node->InputDefs()[0]; const auto* scale_arg = dq_node->InputDefs()[1]; @@ -373,7 +377,7 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, static_cast(K), static_cast(N), static_cast(block_size), - intra_op_thread_pool_); + intra_op_thread_pool_ ? intra_op_thread_pool_ : intra_op_thread_pool_optional_.value().get()); } else { MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), @@ -386,7 +390,7 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, static_cast(K), static_cast(N), static_cast(block_size), - intra_op_thread_pool_); + intra_op_thread_pool_ ? intra_op_thread_pool_ : intra_op_thread_pool_optional_.value().get()); } } else { if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { @@ -401,7 +405,7 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, static_cast(K), static_cast(N), static_cast(block_size), - intra_op_thread_pool_); + intra_op_thread_pool_ ? 
intra_op_thread_pool_ : intra_op_thread_pool_optional_.value().get()); } else { MlasQDQTransposeBlockwiseQuantized( @@ -415,7 +419,7 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, static_cast(K), static_cast(N), static_cast(block_size), - intra_op_thread_pool_); + intra_op_thread_pool_ ? intra_op_thread_pool_ : intra_op_thread_pool_optional_.value().get()); } } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index d80c3f9d183bf..52ae745186b53 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -103,6 +103,7 @@ struct DQMatMulReplaceWithMatMulNBits : public ReplaceWithNew { const std::string op_type_; const std::vector value_moves_; concurrency::ThreadPool* intra_op_thread_pool_; + std::optional> intra_op_thread_pool_optional_; }; struct GemmReplaceWithQuant : public Action { From 18e0000162c081c0196e8d18144a6a2a2b3d871a Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Mon, 8 Jul 2024 16:53:39 -0700 Subject: [PATCH 34/36] fix linting --- onnxruntime/core/session/inference_session.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 88146b517fb83..72cafa1034a4c 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2010,8 +2010,8 @@ common::Status InferenceSession::Initialize() { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); ORT_RETURN_IF_ERROR_SESSIONID_( - ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, cpu_ep, - GetIntraOpThreadPoolToUse())); + ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, + cpu_ep, GetIntraOpThreadPoolToUse())); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } From 1ecf5c5c89bcab92841d1d501491a1e7e8c80d30 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Mon, 8 Jul 2024 20:30:10 -0700 Subject: [PATCH 35/36] corrected UT semantics --- .../test/optimizer/qdq_matmulnbits_transformer_test.cc | 8 -------- 1 file changed, 8 deletions(-) diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index a668eee550c98..3d117794104fa 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -417,14 +417,6 @@ TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits) { RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulConverted({12, 12}, {12, 
37}, {37, 12}, 0, 16, -1); } #endif // !defined(DISABLE_CONTRIB_OPS) From f95c3d6ce10744873948c6714d34bbd8c1932fb8 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Fri, 12 Jul 2024 14:24:44 -0700 Subject: [PATCH 36/36] try to fix web ci failure --- include/onnxruntime/core/optimizer/graph_transformer_utils.h | 1 + onnxruntime/core/optimizer/graph_transformer_utils.cc | 1 + .../selectors_actions/qdq_selector_action_transformer.h | 1 + onnxruntime/test/common/random_generator.h | 1 + 4 files changed, 4 insertions(+) diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index bd12d710f6422..0bb5c7432f0a7 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -10,6 +10,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/session_options.h" #include "core/optimizer/graph_transformer.h" +#include "core/platform/threadpool.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/optimizer/rule_based_graph_transformer.h" diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index b536705b627b1..6e5be28f12745 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -13,6 +13,7 @@ #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/platform/threadpool.h" #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index f2aa9c8327eb4..ba636f76d1900 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -5,6 +5,7 @@ #include "core/optimizer/selectors_actions/selector_action_transformer.h" #include "core/mlas/inc/mlas.h" +#include "core/platform/threadpool.h" namespace onnxruntime { diff --git a/onnxruntime/test/common/random_generator.h b/onnxruntime/test/common/random_generator.h index fcce91a45227f..9bc50ce88ef16 100644 --- a/onnxruntime/test/common/random_generator.h +++ b/onnxruntime/test/common/random_generator.h @@ -12,6 +12,7 @@ #include "core/common/common.h" #include "core/common/optional.h" #include "core/common/type_utils.h" +#include "core/framework/int4.h" #include "test/util/include/test_random_seed.h" namespace onnxruntime {
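As a closing usage reference for the series as a whole: a minimal sketch (not taken from any patch above) of how an application could opt into the DQ + MatMul -> MatMulNBits conversion through the public C++ API. The literal key string below is an assumption about what kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel expands to, and "model.onnx" is a placeholder for a model containing int4/uint4 DQ -> MatMul groups:

    #include <onnxruntime_cxx_api.h>

    int main() {
      Ort::Env env;
      Ort::SessionOptions so;
      // Assumed literal for kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel. Per the
      // patches above, valid values are 0-4 (0 leaves the compute type unset;
      // 1/2/3/4 request fp32/fp16/bf16/int8 per the MatMulNBits contrib op spec).
      so.AddConfigEntry("session.qdq_matmulnbits_accuracy_level", "1");
      // Level 2 optimizations run the QDQSelectorActionTransformer extended by this
      // series, so eligible DQ (int4/uint4) + MatMul pairs may become MatMulNBits.
      so.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
      Ort::Session session(env, "model.onnx", so);  // path is ORTCHAR_T*: char* on non-Windows
      return 0;
    }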