From 288340b80824153c0539c526286ec6efbd7b92a8 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Tue, 11 Oct 2022 13:42:47 +0800 Subject: [PATCH] Refactor ox_utils (#1322) --- neural_compressor/adaptor/onnxrt.py | 4 +- .../{onnxrt_mid.py => calibration.py} | 0 .../adaptor/ox_utils/operators/__init__.py | 11 + .../adaptor/ox_utils/operators/activation.py | 78 +++--- .../adaptor/ox_utils/operators/argmax.py | 18 +- .../adaptor/ox_utils/operators/attention.py | 52 ++-- .../ox_utils/operators/base_operator_cast.py | 26 -- .../adaptor/ox_utils/operators/binary_op.py | 68 ++--- .../adaptor/ox_utils/operators/concat.py | 47 ++-- .../adaptor/ox_utils/operators/conv.py | 240 +++++++++--------- .../adaptor/ox_utils/operators/direct_q8.py | 62 +++-- .../ox_utils/operators/embed_layernorm.py | 78 ++---- .../adaptor/ox_utils/operators/gather.py | 68 ++--- .../adaptor/ox_utils/operators/gavgpool.py | 29 ++- .../adaptor/ox_utils/operators/gemm.py | 95 ++++--- .../adaptor/ox_utils/operators/lstm.py | 65 +++-- .../adaptor/ox_utils/operators/matmul.py | 199 +++++++-------- .../adaptor/ox_utils/operators/maxpool.py | 59 ++--- .../operators/{base_operator.py => ops.py} | 142 ++++++----- .../adaptor/ox_utils/operators/pad.py | 52 ++-- .../adaptor/ox_utils/operators/pooling.py | 42 +-- .../ox_utils/operators/qdq_base_operator.py | 33 --- .../adaptor/ox_utils/operators/resize.py | 58 +++-- .../adaptor/ox_utils/operators/split.py | 45 ++-- .../adaptor/ox_utils/quantizer.py | 17 +- .../adaptor/ox_utils/registry.py | 136 ---------- .../onnxrt_adaptor/test_onnxrt_augment.py | 2 +- 27 files changed, 763 insertions(+), 963 deletions(-) rename neural_compressor/adaptor/ox_utils/{onnxrt_mid.py => calibration.py} (100%) delete mode 100644 neural_compressor/adaptor/ox_utils/operators/base_operator_cast.py rename neural_compressor/adaptor/ox_utils/operators/{base_operator.py => ops.py} (61%) delete mode 100644 neural_compressor/adaptor/ox_utils/operators/qdq_base_operator.py delete mode 100644 neural_compressor/adaptor/ox_utils/registry.py diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 6c1d5debbe4..9ca132de7e7 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -324,7 +324,7 @@ def _dump_model_op_stats(self, model): field_names=["Op Type", "Total", "INT8", "BF16", "FP16", "FP32"]).print_stat() def _get_quantize_params(self, model, data_loader, quantize_config, iterations): - from neural_compressor.adaptor.ox_utils.onnxrt_mid import ONNXRTAugment + from neural_compressor.adaptor.ox_utils.calibration import ONNXRTAugment from neural_compressor.model.onnx_model import ONNXModel if not isinstance(model, ONNXModel): model = ONNXModel(model) @@ -346,7 +346,7 @@ def inspect_tensor(self, model, dataloader, op_list=[], quantization_cfg=None): '''The function is used by tune strategy class for dumping tensor info. 
''' - from neural_compressor.adaptor.ox_utils.onnxrt_mid import ONNXRTAugment + from neural_compressor.adaptor.ox_utils.calibration import ONNXRTAugment from neural_compressor.model.onnx_model import ONNXModel from neural_compressor.utils.utility import dump_data_to_local if not isinstance(model, ONNXModel): diff --git a/neural_compressor/adaptor/ox_utils/onnxrt_mid.py b/neural_compressor/adaptor/ox_utils/calibration.py similarity index 100% rename from neural_compressor/adaptor/ox_utils/onnxrt_mid.py rename to neural_compressor/adaptor/ox_utils/calibration.py diff --git a/neural_compressor/adaptor/ox_utils/operators/__init__.py b/neural_compressor/adaptor/ox_utils/operators/__init__.py index 8a37d55342a..da48d428ac4 100644 --- a/neural_compressor/adaptor/ox_utils/operators/__init__.py +++ b/neural_compressor/adaptor/ox_utils/operators/__init__.py @@ -16,3 +16,14 @@ # limitations under the License. # +from os.path import dirname, basename, isfile, join +import glob +from .ops import OPERATORS + +modules = glob.glob(join(dirname(__file__), "*.py")) + +for f in modules: + if isfile(f) and not f.startswith('__') and not f.endswith('__init__.py'): + __import__(basename(f)[:-3], globals(), locals(), level=1) + +__all__ = ["OPERATORS"] \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/activation.py b/neural_compressor/adaptor/ox_utils/operators/activation.py index 44720ae9cad..5339f6834ad 100644 --- a/neural_compressor/adaptor/ox_utils/operators/activation.py +++ b/neural_compressor/adaptor/ox_utils/operators/activation.py @@ -17,28 +17,39 @@ # import onnx -from .base_operator import QuantOperatorBase -from .qdq_base_operator import QDQOperatorBase -from neural_compressor.adaptor.ox_utils.util import QuantizedValueType, \ - attribute_to_kwarg, ms_domain -from onnx import onnx_pb as onnx_proto -from neural_compressor.adaptor.ox_utils.util import QuantizedValue +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain - -class QLinearActivation(QuantOperatorBase): +@op_registry(op_types="LeakyRelu, Sigmoid") +class ActivationOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(ActivationOperator, self).__init__(onnx_quantizer, onnx_node) + + def quantize_check(self): + node = self.node + data_found, _, _, _, _ = self.quantizer._get_quantization_params(node.output[0]) + if not data_found: + return False + return True + + def quantize(self): + node = self.node + super().quantize() + node.name = node.name + "_quant" - def convert(self): + def convert_check(self, convert_format): node = self.node - if node.op_type in ['Relu', 'Clip']: - return - - if len(self.quantizer.model.get_children(node)) == 0 or \ - not node.name.endswith('_quant'): - return - # No assert on op_type as it is controlled by registry - # only try to quantize when given quantization parameters for it + assert convert_format in ['static'], \ + "convert format for {} should be in ['static']".format(node.op_type) + + children = self.quantizer.model.get_children(node) + if len(children) == 0 or not node.name.endswith('_quant'): + return False + return True + + def convert(self, convert_format): + node = self.node + parent = self.quantizer.model.get_parents(node)[0] child = self.quantizer.model.get_children(node)[0] @@ -48,7 +59,7 @@ def convert(self): qlinear_activation_output = child.output[0] kwargs = {} - for attribute in 
node.attribute: + for attribute in node.attribute: # pragma: no cover kwargs.update(attribute_to_kwarg(attribute)) kwargs["domain"] = ms_domain @@ -59,28 +70,21 @@ def convert(self): self.quantizer.new_nodes.append(qlinear_activation_node) self.quantizer.remove_nodes.extend([parent, child, node]) -class QDQRemovableActivation(QDQOperatorBase): +@op_registry(op_types="Relu, Clip") +class RemovableActivationOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(RemovableActivationOperator, self).__init__(onnx_quantizer, onnx_node) - def quantize(self): + def quantize_check(self): node = self.node if node.input[0] not in self.quantizer.quantized_value_map: - return - elif node.output[0] in [i.name for i in self.quantizer.model.model.graph.output]: + return False + return True + + def quantize(self): + node = self.node + if node.output[0] in [i.name for i in self.quantizer.model.model.graph.output]: self.quantizer.dequantize_tensor(node, node.input[0]) else: self.quantizer.model.replace_input_of_all_nodes(node.output[0], node.input[0]) - self.quantizer.remove_nodes.append(node) - -class QDQActivation(QDQOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def quantize(self): - node = self.node - data_found, _, _, _, _ = self.quantizer._get_quantization_params(node.output[0]) - if not data_found: - return - super().quantize() - node.name = node.name + "_quant" + self.quantizer.remove_nodes.append(node) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/argmax.py b/neural_compressor/adaptor/ox_utils/operators/argmax.py index 1bf1b1bcbed..9344498698e 100644 --- a/neural_compressor/adaptor/ox_utils/operators/argmax.py +++ b/neural_compressor/adaptor/ox_utils/operators/argmax.py @@ -16,16 +16,24 @@ # limitations under the License. 
# -from .base_operator import QuantOperatorBase -class QArgMax(QuantOperatorBase): +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator + +@op_registry(op_types="ArgMax") +class ArgMaxOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(ArgMaxOperator, self).__init__(onnx_quantizer, onnx_node) + + def convert_check(self, convert_format): + node = self.node + assert convert_format in ['static'], \ + "convert format for {} should be in ['static']".format(node.op_type) + return True - def convert(self): + def convert(self, convert_format): node = self.node origin_name = node.input[0].split('_argmax_node')[0] if origin_name in self.quantizer.quantized_value_map: node.input[0] = self.quantizer.quantized_value_map[origin_name].q_name - node.name = node.name + '_quant' + node.name = node.name + '_quant' \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/attention.py b/neural_compressor/adaptor/ox_utils/operators/attention.py index 808e923b5e1..687ee9a01c4 100644 --- a/neural_compressor/adaptor/ox_utils/operators/attention.py +++ b/neural_compressor/adaptor/ox_utils/operators/attention.py @@ -17,30 +17,30 @@ # import onnx -from .base_operator import QuantOperatorBase -from .qdq_base_operator import QDQOperatorBase +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain -from onnx import onnx_pb as onnx_proto -''' - Quantize Attention -''' - -class AttentionQuant(QuantOperatorBase): +@op_registry(op_types="Attention") +class AttentionOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(AttentionOperator, self).__init__(onnx_quantizer, onnx_node) - def convert(self): - ''' - parameter node: Attention node. - parameter new_nodes_list: List of new nodes created before processing this node. 
- return: a list of nodes in topological order that represents quantized Attention node - ''' + def quantize(self): node = self.node - assert (node.op_type == "Attention") + self.quantizer.quantize_inputs(node) + node.name = node.name + "_quant" + + def convert_check(self, convert_format): + node = self.node + assert convert_format in ['dynamic', 'static'], \ + "convert format for {} should be in ['dynamic', 'static']".format(node.op_type) + if not node.name.endswith('_quant'): - return + return False + return True + def convert(self, convert_format): + node = self.node parents = self.quantizer.model.get_parents(node) quantized_name = [] scale = [] @@ -65,25 +65,11 @@ def convert(self): inputs.extend([node.input[4] if len(node.input) > 4 else ""]) kwargs = {} - for attribute in node.attribute: + for attribute in node.attribute: # pragma: no cover kwargs.update(attribute_to_kwarg(attribute)) kwargs["domain"] = ms_domain qattention_node = onnx.helper.make_node("QAttention", inputs, node.output, node.name, **kwargs) self.quantizer.new_nodes.append(qattention_node) - self.quantizer.remove_nodes.append(node) - -class QDQAttention(QDQOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def quantize(self): - node = self.node - assert (node.op_type == "Attention") - - if self.quantizer.static: - super().quantize() - else: - self.quantizer.quantize_inputs(node, [0, 1]) - node.name = node.name + "_quant" + self.quantizer.remove_nodes.append(node) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/base_operator_cast.py b/neural_compressor/adaptor/ox_utils/operators/base_operator_cast.py deleted file mode 100644 index c602703db35..00000000000 --- a/neural_compressor/adaptor/ox_utils/operators/base_operator_cast.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright (c) 2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -class CastOperatorBase: - def __init__(self, onnx_quantizer, onnx_node): - self.quantizer = onnx_quantizer - self.node = onnx_node - self.dtype = self.quantizer.config[self.node.name] - - def cast(self): - self.quantizer.dtype_cast(self.node, self.dtype) diff --git a/neural_compressor/adaptor/ox_utils/operators/binary_op.py b/neural_compressor/adaptor/ox_utils/operators/binary_op.py index 80d9bc3b964..3848cd6ee9b 100644 --- a/neural_compressor/adaptor/ox_utils/operators/binary_op.py +++ b/neural_compressor/adaptor/ox_utils/operators/binary_op.py @@ -17,28 +17,49 @@ # import onnx -from .base_operator import QuantOperatorBase -from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain, \ - QuantizedValueType -from onnx import onnx_pb as onnx_proto -from neural_compressor.adaptor.ox_utils.util import QuantizedValue +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain -class QLinearBinaryOp(QuantOperatorBase): +@op_registry(op_types="Add, Mul") +class BinaryOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(BinaryOperator, self).__init__(onnx_quantizer, onnx_node) - def convert(self): + def quantize_check(self): + node = self.node + data_found, _, _, _, _ = self.quantizer._get_quantization_params(node.output[0]) + if not data_found: + return False + if not all([self.quantizer.is_valid_quantize_weight(i) for i in node.input]): + return False + return True + + def quantize(self): + node = self.node + self.quantizer.quantize_inputs(node, initializer_use_weight_qType=False) + if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq': + self.quantizer.quantize_outputs(node) + node.name = node.name + "_quant" + + def convert_check(self, convert_format): + node = self.node + assert convert_format in ['static'], \ + "convert format for {} should be in ['static']".format(node.op_type) + + children = self.quantizer.model.get_children(node) + if len(children) == 0 or not node.name.endswith('_quant'): + return False + return True + + def convert(self, convert_format): node = self.node - if len(self.quantizer.model.get_children(node)) == 0 or \ - not node.name.endswith('_quant'): - return parents = self.quantizer.model.get_parents(node) child = self.quantizer.model.get_children(node)[0] qlinear_binary_math_output = child.output[0] kwargs = {} - for attribute in node.attribute: + for attribute in node.attribute: # pragma: no cover kwargs.update(attribute_to_kwarg(attribute)) kwargs["domain"] = ms_domain @@ -56,25 +77,4 @@ def convert(self): self.quantizer.new_nodes += [qlinear_binary_math_node] self.quantizer.remove_nodes.extend(parents) self.quantizer.remove_nodes.append(child) - self.quantizer.remove_nodes.append(node) - -class QDQBinaryOp(QuantOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def quantize(self): - node = self.node - data_found, _, _, _, _ = self.quantizer._get_quantization_params(node.output[0]) - if not data_found: - return - - if not all([self.quantizer.is_valid_quantize_weight(i) for i in node.input]): - return - - if not all([self.quantizer.is_valid_quantize_weight(i) for i in node.input]): - return - - self.quantizer.quantize_inputs(node, initializer_use_weight_qType=False) - if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq': - self.quantizer.quantize_outputs(node) - 
node.name = node.name + "_quant" + self.quantizer.remove_nodes.append(node) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/concat.py b/neural_compressor/adaptor/ox_utils/operators/concat.py index 01a80c30e60..b633bf22040 100644 --- a/neural_compressor/adaptor/ox_utils/operators/concat.py +++ b/neural_compressor/adaptor/ox_utils/operators/concat.py @@ -17,25 +17,27 @@ # import onnx -from .base_operator import QuantOperatorBase -from neural_compressor.adaptor.ox_utils.util import QuantizedValueType, \ - attribute_to_kwarg, ms_domain -from onnx import onnx_pb as onnx_proto -from neural_compressor.adaptor.ox_utils.util import QuantizedValue -from .qdq_base_operator import QDQOperatorBase +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain -class QDQConcat(QDQOperatorBase): +@op_registry(op_types="Concat") +class ConcatOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(ConcatOperator, self).__init__(onnx_quantizer, onnx_node) - def quantize(self): + def quantize_check(self): node = self.node inits = [i.name for i in self.quantizer.model.initializer()] if all([inp not in self.quantizer.quantized_value_map and inp not in inits \ for inp in node.input]) or \ not all([inp in self.quantizer.quantized_value_map or inp in inits \ for inp in node.input]): - return + return False + return True + + def quantize(self): + node = self.node + inits = [i.name for i in self.quantizer.model.initializer()] for idx, inp in enumerate(node.input): initializer_use_weight_qType = inp not in inits self.quantizer.quantize_inputs(node, [idx], initializer_use_weight_qType) @@ -43,18 +45,23 @@ def quantize(self): self.quantizer.quantize_outputs(node) node.name = node.name + "_quant" -class QLinearConcat(QuantOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def convert(self): + def convert_check(self, convert_format): node = self.node - + assert convert_format in ['static'], \ + "convert format for {} should be in ['static']".format(node.op_type) + parents = self.quantizer.model.get_parents(node) children = self.quantizer.model.get_children(node) if len(children) == 0 or len(parents) == 0 or not node.name.endswith('_quant'): - return + return False + return True + def convert(self, convert_format): + node = self.node + + parents = self.quantizer.model.get_parents(node) + children = self.quantizer.model.get_children(node) + if all([i.op_type == 'DequantizeLinear' for i in parents]) and \ any([i.op_type == 'QuantizeLinear' for i in children]): inputs = [] @@ -81,3 +88,9 @@ def convert(self): self.quantizer.new_nodes += [qlconcat_node] self.quantizer.remove_nodes.append(node) + + def cast(self): # pragma: no cover + node = self.node + if node.input[0] not in [i.tensor_name for i in self.quantizer.new_value_info.values()]: + return + self.quantizer.dtype_cast(self.node, self.dtype) diff --git a/neural_compressor/adaptor/ox_utils/operators/conv.py b/neural_compressor/adaptor/ox_utils/operators/conv.py index 9c825b79905..90b849bd9e6 100644 --- a/neural_compressor/adaptor/ox_utils/operators/conv.py +++ b/neural_compressor/adaptor/ox_utils/operators/conv.py @@ -16,137 +16,19 @@ # limitations under the License. 
# + import onnx -from .base_operator import QuantOperatorBase -from .qdq_base_operator import QDQOperatorBase -from neural_compressor.adaptor.ox_utils.util import find_by_name, \ - QuantizedValueType, attribute_to_kwarg from onnx import onnx_pb as onnx_proto -from neural_compressor.adaptor.ox_utils.util import QuantizedValue - -class ConvInteger(QuantOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def convert(self): - node = self.node - assert node.op_type in ["Conv", "FusedConv"] - - inputs = [] - parents = self.quantizer.model.get_parents(node) - if parents[0].op_type == 'QuantizeLinear': - inputs.append(parents[0].output[0]) - inputs.append(parents[1].input[0]) - inputs.append(parents[0].input[2]) - inputs.append(parents[1].input[2]) - scale_0 = parents[0].input[1] - else: - inputs.append(parents[0].output[0]) - inputs.append(parents[1].input[0]) - inputs.append(parents[0].output[2]) - inputs.append(parents[1].input[2]) - scale_0 = parents[0].output[1] - scale_1 = parents[1].input[1] - # quantize bias if exist - quantized_bias_name = "" - bias_present = False - if len(node.input) == 3: - quantized_bias_name = node.input[2] + "_quantized" - bias_present = True - - conv_integer_output = node.output[0] + "_output_quantized" - - kwargs = {} - for attribute in node.attribute: - if attribute.name == 'activation' and attribute.s in [b'Relu', b'Clip']: - continue - if attribute.name == 'activation_params': - continue - kwargs.update(attribute_to_kwarg(attribute)) - conv_integer_node = onnx.helper.make_node("ConvInteger", - inputs, - [conv_integer_output], - node.name, **kwargs) - self.quantizer.new_nodes.append(conv_integer_node) - - # Add bias add nodes - if bias_present: - conv_integer_output = self.quantizer.get_bias_add_nodes(node, - parents[1].input[0], - conv_integer_output, - quantized_bias_name) - - # Add cast operation to cast convInteger output to float. - cast_op_output = conv_integer_output + "_cast_output" - cast_node = onnx.helper.make_node("Cast", [conv_integer_output], [cast_op_output], - conv_integer_output + "_cast", - to=onnx_proto.TensorProto.FLOAT) - self.quantizer.new_nodes.append(cast_node) - - # Add mul operation to multiply scales of two inputs. - scales_mul_op = node.name + "_scales_mul" - - scales_mul_node = find_by_name(scales_mul_op, self.quantizer.new_nodes) - if scales_mul_node is None: - scales_mul_node = onnx.helper.make_node("Mul", [scale_0, scale_1], - [scales_mul_op + ":0"], scales_mul_op) - self.quantizer.new_nodes.append(scales_mul_node) - - scales_mul_op_output = scales_mul_node.output[0] - - # Add mul operation to multiply mul_scales_op result with output of ConvInteger - # and make the output of this node the same as output of original conv node. 
- output_scale_mul_op = node.name + "_output_scale_mul" - self.quantizer.new_nodes.append(onnx.helper.make_node("Mul", - [cast_op_output, scales_mul_op_output], [node.output[0]], output_scale_mul_op)) - self.quantizer.remove_nodes.extend(parents[1:]) - self.quantizer.remove_nodes.append(node) - - -class QLinearConv(QuantOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.util import find_by_name, attribute_to_kwarg - def convert(self): - node = self.node - assert (node.op_type in ["Conv", "FusedConv"]) - - if len(self.quantizer.model.get_children(node)) == 0 or \ - not node.name.endswith('_quant'): - return - parents = self.quantizer.model.get_parents(node) - child = self.quantizer.model.get_children(node)[0] - qlinear_conv_inputs = [] - for parent in parents[0:2]: - qlinear_conv_inputs.extend(parent.input) - qlinear_conv_inputs.extend(child.input[1:]) - if len(parents) == 3: - qlinear_conv_inputs.append(parents[-1].input[0]) - - qlinear_conv_output = child.output[0] - - kwargs = {} - for attribute in node.attribute: - if attribute.name == 'activation' and attribute.s in [b'Relu', b'Clip']: - continue - if attribute.name == 'activation_params': - continue - kwargs.update(attribute_to_kwarg(attribute)) - qlinear_conv_node = onnx.helper.make_node("QLinearConv", qlinear_conv_inputs, - [qlinear_conv_output], - node.name, **kwargs) - self.quantizer.new_nodes.append(qlinear_conv_node) - self.quantizer.remove_nodes.extend(parents) - self.quantizer.remove_nodes.append(child) - self.quantizer.remove_nodes.append(node) - -class QDQConv(QDQOperatorBase): +@op_registry(op_types="Conv, FusedConv") +class ConvOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(ConvOperator, self).__init__(onnx_quantizer, onnx_node) def quantize(self): node = self.node - assert (node.op_type in ["Conv", "FusedConv"]) if node.op_type == "FusedConv": kwargs = {} for attribute in node.attribute: @@ -173,3 +55,113 @@ def quantize(self): self.quantizer.quantize_bias_tensor(node) node.name = node.name + "_quant" + + def convert_check(self, convert_format): + node = self.node + assert convert_format in ['dynamic', 'static'], \ + 'convert format for {} should be in [dynamic, static]'.format(node.op_type) + return True + + def convert(self, convert_format): + node = self.node + if convert_format == 'dynamic': + inputs = [] + parents = self.quantizer.model.get_parents(node) + if parents[0].op_type == 'QuantizeLinear': + inputs.append(parents[0].output[0]) + inputs.append(parents[1].input[0]) + inputs.append(parents[0].input[2]) + inputs.append(parents[1].input[2]) + scale_0 = parents[0].input[1] + else: + inputs.append(parents[0].output[0]) + inputs.append(parents[1].input[0]) + inputs.append(parents[0].output[2]) + inputs.append(parents[1].input[2]) + scale_0 = parents[0].output[1] + scale_1 = parents[1].input[1] + # quantize bias if exist + quantized_bias_name = "" + bias_present = False + if len(node.input) == 3: + quantized_bias_name = node.input[2] + "_quantized" + bias_present = True + + conv_integer_output = node.output[0] + "_output_quantized" + + kwargs = {} + for attribute in node.attribute: + if attribute.name == 'activation' and attribute.s in [b'Relu', b'Clip']: # pragma: no cover + continue + if attribute.name == 'activation_params': # pragma: no cover + continue + 
kwargs.update(attribute_to_kwarg(attribute)) + conv_integer_node = onnx.helper.make_node("ConvInteger", + inputs, + [conv_integer_output], + node.name, **kwargs) + self.quantizer.new_nodes.append(conv_integer_node) + + # Add bias add nodes + if bias_present: + conv_integer_output = self.quantizer.get_bias_add_nodes(node, + parents[1].input[0], + conv_integer_output, + quantized_bias_name) + + # Add cast operation to cast convInteger output to float. + cast_op_output = conv_integer_output + "_cast_output" + cast_node = onnx.helper.make_node("Cast", [conv_integer_output], [cast_op_output], + conv_integer_output + "_cast", + to=onnx_proto.TensorProto.FLOAT) + self.quantizer.new_nodes.append(cast_node) + + # Add mul operation to multiply scales of two inputs. + scales_mul_op = node.name + "_scales_mul" + + scales_mul_node = find_by_name(scales_mul_op, self.quantizer.new_nodes) + if scales_mul_node is None: + scales_mul_node = onnx.helper.make_node("Mul", [scale_0, scale_1], + [scales_mul_op + ":0"], scales_mul_op) + self.quantizer.new_nodes.append(scales_mul_node) + + scales_mul_op_output = scales_mul_node.output[0] + + # Add mul operation to multiply mul_scales_op result with output of ConvInteger + # and make the output of this node the same as output of original conv node. + output_scale_mul_op = node.name + "_output_scale_mul" + self.quantizer.new_nodes.append(onnx.helper.make_node("Mul", + [cast_op_output, scales_mul_op_output], [node.output[0]], output_scale_mul_op)) + self.quantizer.remove_nodes.extend(parents[1:]) + self.quantizer.remove_nodes.append(node) + elif convert_format == 'static': + if len(self.quantizer.model.get_children(node)) == 0 or \ + not node.name.endswith('_quant'): # pragma: no cover + return + parents = self.quantizer.model.get_parents(node) + child = self.quantizer.model.get_children(node)[0] + qlinear_conv_inputs = [] + for parent in parents[0:2]: + qlinear_conv_inputs.extend(parent.input) + qlinear_conv_inputs.extend(child.input[1:]) + if len(parents) == 3: + qlinear_conv_inputs.append(parents[-1].input[0]) + + qlinear_conv_output = child.output[0] + + kwargs = {} + for attribute in node.attribute: + if attribute.name == 'activation' and attribute.s in [b'Relu', b'Clip']: # pragma: no cover + continue + if attribute.name == 'activation_params': # pragma: no cover + continue + kwargs.update(attribute_to_kwarg(attribute)) + qlinear_conv_node = onnx.helper.make_node("QLinearConv", qlinear_conv_inputs, + [qlinear_conv_output], + node.name, **kwargs) + self.quantizer.new_nodes.append(qlinear_conv_node) + self.quantizer.remove_nodes.extend(parents) + self.quantizer.remove_nodes.append(child) + self.quantizer.remove_nodes.append(node) + + diff --git a/neural_compressor/adaptor/ox_utils/operators/direct_q8.py b/neural_compressor/adaptor/ox_utils/operators/direct_q8.py index fc81c203de0..bcddb43aa3a 100644 --- a/neural_compressor/adaptor/ox_utils/operators/direct_q8.py +++ b/neural_compressor/adaptor/ox_utils/operators/direct_q8.py @@ -16,26 +16,43 @@ # limitations under the License. # -from .qdq_base_operator import QDQOperatorBase -from .base_operator import QuantOperatorBase -from .base_operator_cast import CastOperatorBase -from neural_compressor.adaptor.ox_utils.util import QuantizedValue +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator -# For operators that support 8bits operations directly, and output could -# reuse input[0]'s type, zeropoint, scale; For example,Transpose, Reshape, etc. 
- -class Direct8BitOp(QuantOperatorBase): +@op_registry(op_types="Reshape, Transpose, Squeeze, Unsqueeze") +class Direct8BitOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(Direct8BitOperator, self).__init__(onnx_quantizer, onnx_node) + + def quantize_check(self): + node = self.node + if not self.quantizer.is_valid_quantize_weight(node.input[0]): + return False + return True + + def quantize(self): + node = self.node + self.quantizer.quantize_inputs(self.node, direct_int8=True) + if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq': + self.quantizer.quantize_outputs(self.node, direct_int8=True) + node.name = node.name + "_quant" - def convert(self): + def convert_check(self, convert_format): node = self.node + assert convert_format in ['static'], \ + "convert format for {} should be in ['static']".format(node.op_type) + parents = self.quantizer.model.get_parents(node) children = self.quantizer.model.get_children(node) if (len(children) == 0 and len(parents) == 0) or \ not node.name.endswith('_quant'): - return + return False + return True + def convert(self, convert_format): + node = self.node + + parents = self.quantizer.model.get_parents(node) + children = self.quantizer.model.get_children(node) if any([i.op_type == 'DequantizeLinear' for i in parents]) and \ any([i.op_type == 'QuantizeLinear' for i in children]): for parent in parents: @@ -49,27 +66,20 @@ def convert(self): self.quantizer.model.replace_input_of_all_nodes( child.output[0], node.output[0] + '_quantized') node.output[0] = node.output[0] + '_quantized' - -class QDQDirect8BitOp(QDQOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def quantize(self): + + def cast(self): # pragma: no cover node = self.node - if not self.quantizer.is_valid_quantize_weight(node.input[0]): + if node.input[0] not in [i.tensor_name for i in self.quantizer.new_value_info.values()]: return - self.quantizer.quantize_inputs(self.node, direct_int8=True) - if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq': - self.quantizer.quantize_outputs(self.node, direct_int8=True) - node.name = node.name + "_quant" + self.quantizer.dtype_cast(self.node, self.dtype) -class DirectCast(CastOperatorBase): +@op_registry(op_types="Shape, Loop, Slice") +class DirectCastOperator(Operator): # pragma: no cover def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(DirectCastOperator, self).__init__(onnx_quantizer, onnx_node) def cast(self): node = self.node if node.input[0] not in [i.tensor_name for i in self.quantizer.new_value_info.values()]: return - super().cast() - + self.quantizer.dtype_cast(self.node, self.dtype) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/embed_layernorm.py b/neural_compressor/adaptor/ox_utils/operators/embed_layernorm.py index 98a06c1314b..8c99929a9a1 100644 --- a/neural_compressor/adaptor/ox_utils/operators/embed_layernorm.py +++ b/neural_compressor/adaptor/ox_utils/operators/embed_layernorm.py @@ -17,37 +17,30 @@ # import onnx -from .base_operator import QuantOperatorBase -from .qdq_base_operator import QDQOperatorBase -from onnx import onnx_pb as onnx_proto -from neural_compressor.adaptor.ox_utils.util import QuantizedValueType, \ - attribute_to_kwarg, ms_domain -''' -Quantize EmbedLayerNormalization -''' +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator 
+from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain - -class EmbedLayerNormalizationQuant(QuantOperatorBase): # pragma: no cover +@op_registry(op_types="EmbedLayerNormalization") +class EmbedLayerNormalizationOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(EmbedLayerNormalizationOperator, self).__init__(onnx_quantizer, onnx_node) - def convert(self): + def quantize(self): node = self.node - assert (node.op_type == "EmbedLayerNormalization") + self.quantizer.quantize_inputs(node, [2, 3, 4, 5, 6]) + node.name = node.name + "_quant" + + def convert_check(self, convert_format): + node = self.node + assert convert_format in ['dynamic', 'static'], \ + "convert format for {} should be in ['dynamic', 'static']".format(node.op_type) + if not node.name.endswith('_quant'): - return + return False + return True - ''' - Pre-quantization EmbedLayerNorm inputs: - [0] input_ids (int32) - [1] segment_ids (int32) - [2] word_embedding (float32) - [3] position_embedding (float32) - [4] segment_embedding (float32) - [5] gamma (float32) - [6] beta (float32) - [7] mask (int32) (optional) - ''' + def convert(self, convert_format): + node = self.node parents = [i for i in self.quantizer.model.get_parents(node) \ if i.op_type == 'DequantizeLinear'] @@ -66,29 +59,8 @@ def convert(self): for parent in parents: inputs.append(parent.input[2]) - ''' - Quantized Input Tensor List - [0] input_ids (int32) - [1] segment_ids (int32) - [2] word_embedding (uint8) - [3] position_embedding (uint8) - [4] segment_embedding (uint8) - [5] gamma (uint8) - [6] beta (uint8) - [7] mask (int32) (optional) - [8] word_embedding_scale (float) - [9] position_embedding_scale (float) - [10] segment_embedding_scale (float) - [11] gamma_scale (float) - [12] beta_scale (float) - [13] word_embedding_zero_point (uint8) - [14] position_embedding_zero_point (uint8) - [15] segment_embedding_zero_point (uint8) - [16] gamma_zero_point (uint8) - [17] beta_zero_point (uint8) - ''' kwargs = {} - for attribute in node.attribute: + for attribute in node.attribute: # pragma: no cover kwargs.update(attribute_to_kwarg(attribute)) kwargs["domain"] = ms_domain @@ -96,14 +68,4 @@ def convert(self): inputs, node.output, node.name, **kwargs) self.quantizer.new_nodes.append(qembed_layer_norm_node) - self.quantizer.remove_nodes.extend(parents) - -class QDQEmbedLayerNormalization(QDQOperatorBase): # pragma: no cover - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def quantize(self): - node = self.node - assert (node.op_type == "EmbedLayerNormalization") - self.quantizer.quantize_inputs(node, [2, 3, 4, 5, 6]) - node.name = node.name + "_quant" + self.quantizer.remove_nodes.extend(parents) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/gather.py b/neural_compressor/adaptor/ox_utils/operators/gather.py index 24132d8d7f9..93f98823047 100644 --- a/neural_compressor/adaptor/ox_utils/operators/gather.py +++ b/neural_compressor/adaptor/ox_utils/operators/gather.py @@ -17,30 +17,45 @@ # import onnx -from .base_operator import QuantOperatorBase -from neural_compressor.adaptor.ox_utils.util import QuantizedValueType, attribute_to_kwarg -from onnx import onnx_pb as onnx_proto -from neural_compressor.adaptor.ox_utils.util import QuantizedValue -''' - Quantize Gather -''' +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from 
neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg -class GatherConverter(QuantOperatorBase): +@op_registry(op_types="Gather") +class GatherOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(GatherOperator, self).__init__(onnx_quantizer, onnx_node) - def convert(self): + def quantize_check(self): node = self.node - assert node.op_type in ["Gather"] + if not self.quantizer.is_valid_quantize_weight(node.input[0]): + return False + return True + + def quantize(self): + node = self.node + self.quantizer.quantize_inputs(node, [0]) + if not self.disable_qdq_for_node_output or self.quantizer != 'qdq': + self.quantizer.quantize_outputs(node) + node.name = node.name + "_quant" + + def convert_check(self, convert_format): + node = self.node + assert convert_format in ['dynamic', 'static'], \ + "convert format for {} should be in ['dynamic', 'static']".format(node.op_type) - if len(self.quantizer.model.get_children(node)) == 0 or \ - not node.name.endswith('_quant'): - return + parents = self.quantizer.model.get_parents(node) + children = self.quantizer.model.get_children(node) + if len(children) == 0 or len(parents) == 0 or not node.name.endswith('_quant'): + return False + + return True + + def convert(self, convert_format): + node = self.node + parents = self.quantizer.model.get_parents(node) children = self.quantizer.model.get_children(node) - if len(parents) == 0: - return if any([i.op_type == 'DequantizeLinear' for i in parents]): inputs = [] inputs.append(parents[0].input[0]) @@ -49,7 +64,7 @@ def convert(self): gather_new_output = node.output[0] + "_quantized" kwargs = {} - for attribute in node.attribute: + for attribute in node.attribute: # pragma: no cover kwargs.update(attribute_to_kwarg(attribute)) gather_node = onnx.helper.make_node("Gather", @@ -58,7 +73,7 @@ def convert(self): node.name, **kwargs) self.quantizer.new_nodes.append(gather_node) - if any([i.op_type != 'QuantizeLinear' for i in children]): + if any([i.op_type != 'QuantizeLinear' for i in children]): # pragma: no cover dq_inputs = [] dq_inputs.append(gather_new_output) dq_inputs.extend(parents[0].input[1:]) @@ -74,19 +89,4 @@ def convert(self): for n in self.quantizer.model.get_children(child): self.quantizer.model.replace_node_input(n, child.output[0], gather_new_output) - self.quantizer.remove_nodes.extend([node, parents[0]]) - -class GatherQuant(QuantOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def quantize(self): - node = self.node - assert (node.op_type == "Gather") - - if not self.quantizer.is_valid_quantize_weight(node.input[0]): - return - self.quantizer.quantize_inputs(node, [0]) - if not self.disable_qdq_for_node_output or self.quantizer != 'qdq': - self.quantizer.quantize_outputs(node) - node.name = node.name + "_quant" + self.quantizer.remove_nodes.extend([node, parents[0]]) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/gavgpool.py b/neural_compressor/adaptor/ox_utils/operators/gavgpool.py index 16c95c03582..b4bafcafeae 100644 --- a/neural_compressor/adaptor/ox_utils/operators/gavgpool.py +++ b/neural_compressor/adaptor/ox_utils/operators/gavgpool.py @@ -17,20 +17,27 @@ # import onnx -from .base_operator import QuantOperatorBase -from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain, \ - QuantizedValueType -from neural_compressor.adaptor.ox_utils.util import QuantizedValue -class 
QGlobalAveragePool(QuantOperatorBase): +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain + +@op_registry(op_types="GlobalAveragePool") +class GlobalAveragePoolOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(GlobalAveragePoolOperator, self).__init__(onnx_quantizer, onnx_node) - def convert(self): + def convert_check(self, convert_format): node = self.node - assert (node.op_type == "GlobalAveragePool") + assert convert_format in ['static'], \ + "convert format for {} should be in ['static']".format(node.op_type) + + children = self.quantizer.model.get_children(node) + if len(children) == 0: # pragma: no cover + return False + return True - if len(self.quantizer.model.get_children(node)) == 0: - return + def convert(self, convert_format): + node = self.node + parent = self.quantizer.model.get_parents(node)[0] child = self.quantizer.model.get_children(node)[0] @@ -51,4 +58,4 @@ def convert(self): self.quantizer.new_nodes += [qnode] self.quantizer.remove_nodes.append(child) self.quantizer.remove_nodes.append(parent) - self.quantizer.remove_nodes.append(node) + self.quantizer.remove_nodes.append(node) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/gemm.py b/neural_compressor/adaptor/ox_utils/operators/gemm.py index 37e0a9150ac..65aca2e8a7d 100644 --- a/neural_compressor/adaptor/ox_utils/operators/gemm.py +++ b/neural_compressor/adaptor/ox_utils/operators/gemm.py @@ -17,67 +17,28 @@ # import onnx -from .base_operator import QuantOperatorBase -from .qdq_base_operator import QDQOperatorBase +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator from neural_compressor.adaptor.ox_utils.util import find_by_name, ms_domain, \ attribute_to_kwarg, is_B_transposed - -''' - Used when quantize mode is QuantizationMode.QLinearOps -''' - -class QLinearGemm(QuantOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def convert(self): - node = self.node - assert (node.op_type == "Gemm") - - if len(self.quantizer.model.get_children(node)) == 0 or \ - not node.name.endswith('_quant'): - return - - parents = self.quantizer.model.get_parents(node) - child = self.quantizer.model.get_children(node)[0] - qgemm_output = child.output[0] - qgemm_inputs = [] - for parent in parents[:-1]: - qgemm_inputs.extend(parent.input) - qgemm_inputs.append(parents[-1].input[0]) - qgemm_inputs.extend(child.input[1:]) - - kwargs = {} - for attribute in node.attribute: - if attribute.name != "beta": - kwargs.update(attribute_to_kwarg(attribute)) - kwargs["domain"] = ms_domain - - qgemm_node = onnx.helper.make_node("QGemm", - qgemm_inputs, [qgemm_output], node.name, **kwargs) - - self.quantizer.new_nodes.append(qgemm_node) - self.quantizer.remove_nodes.extend(parents) - self.quantizer.remove_nodes.append(child) - self.quantizer.remove_nodes.append(node) - -class QDQGemm(QDQOperatorBase): +@op_registry(op_types="Gemm") +class GemmOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def quantize(self): + super(GemmOperator, self).__init__(onnx_quantizer, onnx_node) + + def quantize_check(self): node = self.node - assert (node.op_type == "Gemm") - if len(node.input) == 3 and \ not find_by_name(node.input[2], self.quantizer.model.initializer()): from 
neural_compressor.utils import logger logger.warning("Bias of Gemm node '{}' is not constant. " \ "Exclude this node can get better performance.".format(node.name)) if self.quantizer.mode != 'qdq': - return + return False + return True + def quantize(self): + node = self.node self.quantizer.quantize_inputs(node, [0]) if self.per_channel and find_by_name(node.input[1], self.quantizer.model.initializer()): self.quantizer.quantize_weights_per_channel(node, [1], @@ -95,3 +56,39 @@ def quantize(self): if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq': self.quantizer.quantize_outputs(node) node.name = node.name + "_quant" + + def convert_check(self, convert_format): + node = self.node + assert convert_format in ['static'], \ + "convert format for {} should be in ['static']".format(node.op_type) + + children = self.quantizer.model.get_children(node) + if len(children) == 0 or not node.name.endswith('_quant'): + return False + return True + + def convert(self, convert_format): + node = self.node + + parents = self.quantizer.model.get_parents(node) + child = self.quantizer.model.get_children(node)[0] + qgemm_output = child.output[0] + qgemm_inputs = [] + for parent in parents[:-1]: + qgemm_inputs.extend(parent.input) + qgemm_inputs.append(parents[-1].input[0]) + qgemm_inputs.extend(child.input[1:]) + + kwargs = {} + for attribute in node.attribute: + if attribute.name != "beta": + kwargs.update(attribute_to_kwarg(attribute)) + kwargs["domain"] = ms_domain + + qgemm_node = onnx.helper.make_node("QGemm", + qgemm_inputs, [qgemm_output], node.name, **kwargs) + + self.quantizer.new_nodes.append(qgemm_node) + self.quantizer.remove_nodes.extend(parents) + self.quantizer.remove_nodes.append(child) + self.quantizer.remove_nodes.append(node) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/lstm.py b/neural_compressor/adaptor/ox_utils/operators/lstm.py index 18a78dd5010..1d7fa45e6e0 100644 --- a/neural_compressor/adaptor/ox_utils/operators/lstm.py +++ b/neural_compressor/adaptor/ox_utils/operators/lstm.py @@ -17,46 +17,47 @@ # import onnx +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.util import ms_domain, attribute_to_kwarg import numpy -from .base_operator import QuantOperatorBase -from .qdq_base_operator import QDQOperatorBase -from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain, QuantType -from onnx import onnx_pb as onnx_proto -''' - Quantize LSTM -''' - -class LSTMQuant(QuantOperatorBase): # pragma: no cover +@op_registry(op_types="LSTM") +class LSTMOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def convert(self): - ''' - parameter node: LSTM node. - parameter new_nodes_list: List of new nodes created before processing this node. - return: a list of nodes in topological order that represents quantized Attention node. 
- ''' - node = self.node - assert (node.op_type == "LSTM") + super(LSTMOperator, self).__init__(onnx_quantizer, onnx_node) + def quantize(self): + return + + def convert_check(self, convert_format): + node = self.node + assert convert_format in ['dynamic'], \ + "convert format for {} should be in ['dynamic']".format(node.op_type) + if (not self.quantizer.is_valid_quantize_weight(node.input[1]) or - not self.quantizer.is_valid_quantize_weight(node.input[2])): - super().convert() - return + not self.quantizer.is_valid_quantize_weight(node.input[2])): # pragma: no cover + return False model = self.quantizer.model W = model.get_initializer(node.input[1]) R = model.get_initializer(node.input[2]) - if (len(W.dims) != 3 or len(R.dims) != 3): - super().convert() - return + if (len(W.dims) != 3 or len(R.dims) != 3): # pragma: no cover + return False + + return True + + def convert(self, convert_format): + node = self.node + model = self.quantizer.model + W = model.get_initializer(self.node.input[1]) + R = model.get_initializer(self.node.input[2]) + [W_num_dir, W_4_hidden_size, W_input_size] = W.dims [R_num_dir, R_4_hidden_size, R_hidden_size] = R.dims - if self.per_channel: + if self.per_channel: # pragma: no cover del W.dims[0] del R.dims[0] W.dims[0] = W_num_dir * W_4_hidden_size @@ -93,7 +94,7 @@ def convert(self): W_quant_scale = model.get_initializer(quant_input_weight_tuple[2]) R_quant_scale = model.get_initializer(quant_recurrent_weight_tuple[2]) - if self.per_channel: + if self.per_channel: # pragma: no cover W_quant_zp.dims[:] = [W_num_dir, W_4_hidden_size] R_quant_zp.dims[:] = [R_num_dir, R_4_hidden_size] W_quant_scale.dims[:] = [W_num_dir, W_4_hidden_size] @@ -122,12 +123,4 @@ def convert(self): quant_lstm_node = onnx.helper.make_node("DynamicQuantizeLSTM", inputs, node.output, quant_lstm_name, **kwargs) self.quantizer.remove_nodes.append(node) - self.quantizer.new_nodes.append(quant_lstm_node) - -class QDQLSTM(QDQOperatorBase): # pragma: no cover - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def quantize(self): - return - + self.quantizer.new_nodes.append(quant_lstm_node) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/matmul.py b/neural_compressor/adaptor/ox_utils/operators/matmul.py index ad58eaf97a5..988e157e323 100644 --- a/neural_compressor/adaptor/ox_utils/operators/matmul.py +++ b/neural_compressor/adaptor/ox_utils/operators/matmul.py @@ -17,121 +17,17 @@ # import onnx -from .base_operator import QuantOperatorBase -from .qdq_base_operator import QDQOperatorBase -from neural_compressor.adaptor.ox_utils.util import find_by_name, \ - QuantizedValueType +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.util import find_by_name from onnx import onnx_pb as onnx_proto -from neural_compressor.adaptor.ox_utils.util import QuantizedValue -''' - Used when quantize mode is QuantizationMode.IntegerOps. 
-''' - -class MatMulInteger(QuantOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def convert(self): - node = self.node - assert (node.op_type == "MatMul") - - parents = self.quantizer.model.get_parents(node) - - inputs = [] - quantized_name = [] - scale = [] - zp = [] - for parent in parents: - if parent.op_type == 'DequantizeLinear': - quantized_name.append(parent.input[0]) - else: - quantized_name.append(parent.output[0]) - if parent.op_type == 'DynamicQuantizeLinear': - scale.append(parent.output[1]) - zp.append(parent.output[2]) - else: - scale.append(parent.input[1]) - zp.append(parent.input[2]) - inputs.extend(quantized_name) - inputs.extend(zp) - matmul_integer_output = node.output[0] + "_output_quantized" - matmul_integer_node = onnx.helper.make_node("MatMulInteger", - inputs, - [matmul_integer_output], node.name) - self.quantizer.new_nodes.append(matmul_integer_node) - - # Add cast operation to cast matmulInteger output to float. - cast_op_output = matmul_integer_output + "_cast_output" - cast_node = onnx.helper.make_node("Cast", [matmul_integer_output], [cast_op_output], - matmul_integer_output + "_cast", - to=onnx_proto.TensorProto.FLOAT) - self.quantizer.new_nodes.append(cast_node) - - # Add mul operation to multiply scales of two inputs. - scales_mul_op = node.name + "_scales_mul" - - scales_mul_node = find_by_name(scales_mul_op, self.quantizer.new_nodes) - if scales_mul_node is None: - scales_mul_node = onnx.helper.make_node("Mul", [scale[0], scale[1]], - [scales_mul_op + ":0"], scales_mul_op) - self.quantizer.new_nodes.append(scales_mul_node) - - scales_mul_op_output = scales_mul_node.output[0] - - # Add mul operation to multiply mul_scales_op result with output of MatMulInteger - # and make the output of this node the same as output of original matmul node. 
- output_scale_mul_op = node.name + "_output_scale_mul" - self.quantizer.new_nodes.append( - onnx.helper.make_node("Mul", [cast_op_output, scales_mul_op_output], - [node.output[0]], output_scale_mul_op)) - if parents[1].op_type == 'DequantizeLinear': - self.quantizer.remove_nodes.append(parents[1]) - self.quantizer.remove_nodes.append(node) - -''' - Used when quantize mode is QuantizationMode.QLinearOps -''' - - -class QLinearMatMul(QuantOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def convert(self): - node = self.node - assert (node.op_type == "MatMul") - - parents = self.quantizer.model.get_parents(node) - if len(self.quantizer.model.get_children(node)) == 0 or \ - not node.name.endswith('_quant'): - return - child = self.quantizer.model.get_children(node)[0] - - qlinear_matmul_output = child.output[0] - - qlinear_matmul_inputs = [] - for parent in parents: - qlinear_matmul_inputs.extend(parent.input) - qlinear_matmul_inputs.extend(child.input[1:]) - - qlinear_matmul_node = onnx.helper.make_node("QLinearMatMul", - qlinear_matmul_inputs, - [qlinear_matmul_output], - node.name) - self.quantizer.new_nodes.append(qlinear_matmul_node) - self.quantizer.remove_nodes.extend(parents) - self.quantizer.remove_nodes.append(child) - self.quantizer.remove_nodes.append(node) - -class QDQMatMul(QDQOperatorBase): +@op_registry(op_types="MatMul") +class MatMulOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(MatMulOperator, self).__init__(onnx_quantizer, onnx_node) def quantize(self): node = self.node - assert (node.op_type == "MatMul") - self.quantizer.quantize_inputs(node, [0]) if self.per_channel and find_by_name(node.input[1], self.quantizer.model.initializer()): self.quantizer.quantize_weights_per_channel(node, [1], @@ -142,3 +38,88 @@ def quantize(self): if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq': self.quantizer.quantize_outputs(node) node.name = node.name + "_quant" + + def convert_check(self, convert_format): + node = self.node + assert convert_format in ['dynamic', 'static'], \ + "convert format for {} should be in ['dynamic', 'static']".format(node.op_type) + return True + + def convert(self, convert_format): + node = self.node + + if convert_format == 'dynamic': + parents = self.quantizer.model.get_parents(node) + + inputs = [] + quantized_name = [] + scale = [] + zp = [] + for parent in parents: + if parent.op_type == 'DequantizeLinear': + quantized_name.append(parent.input[0]) + else: + quantized_name.append(parent.output[0]) + if parent.op_type == 'DynamicQuantizeLinear': + scale.append(parent.output[1]) + zp.append(parent.output[2]) + else: + scale.append(parent.input[1]) + zp.append(parent.input[2]) + inputs.extend(quantized_name) + inputs.extend(zp) + matmul_integer_output = node.output[0] + "_output_quantized" + matmul_integer_node = onnx.helper.make_node("MatMulInteger", + inputs, + [matmul_integer_output], node.name) + self.quantizer.new_nodes.append(matmul_integer_node) + + # Add cast operation to cast matmulInteger output to float. + cast_op_output = matmul_integer_output + "_cast_output" + cast_node = onnx.helper.make_node("Cast", [matmul_integer_output], [cast_op_output], + matmul_integer_output + "_cast", + to=onnx_proto.TensorProto.FLOAT) + self.quantizer.new_nodes.append(cast_node) + + # Add mul operation to multiply scales of two inputs. 
+ scales_mul_op = node.name + "_scales_mul" + + scales_mul_node = find_by_name(scales_mul_op, self.quantizer.new_nodes) + if scales_mul_node is None: + scales_mul_node = onnx.helper.make_node("Mul", [scale[0], scale[1]], + [scales_mul_op + ":0"], scales_mul_op) + self.quantizer.new_nodes.append(scales_mul_node) + + scales_mul_op_output = scales_mul_node.output[0] + + # Add mul operation to multiply mul_scales_op result with output of MatMulInteger + # and make the output of this node the same as output of original matmul node. + output_scale_mul_op = node.name + "_output_scale_mul" + self.quantizer.new_nodes.append( + onnx.helper.make_node("Mul", [cast_op_output, scales_mul_op_output], + [node.output[0]], output_scale_mul_op)) + if parents[1].op_type == 'DequantizeLinear': + self.quantizer.remove_nodes.append(parents[1]) + self.quantizer.remove_nodes.append(node) + elif convert_format == 'static': + parents = self.quantizer.model.get_parents(node) + if len(self.quantizer.model.get_children(node)) == 0 or \ + not node.name.endswith('_quant'): # pragma: no cover + return + child = self.quantizer.model.get_children(node)[0] + + qlinear_matmul_output = child.output[0] + + qlinear_matmul_inputs = [] + for parent in parents: + qlinear_matmul_inputs.extend(parent.input) + qlinear_matmul_inputs.extend(child.input[1:]) + + qlinear_matmul_node = onnx.helper.make_node("QLinearMatMul", + qlinear_matmul_inputs, + [qlinear_matmul_output], + node.name) + self.quantizer.new_nodes.append(qlinear_matmul_node) + self.quantizer.remove_nodes.extend(parents) + self.quantizer.remove_nodes.append(child) + self.quantizer.remove_nodes.append(node) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/maxpool.py b/neural_compressor/adaptor/ox_utils/operators/maxpool.py index 4f9ee545b7d..393c0210300 100644 --- a/neural_compressor/adaptor/ox_utils/operators/maxpool.py +++ b/neural_compressor/adaptor/ox_utils/operators/maxpool.py @@ -16,32 +16,40 @@ # limitations under the License. 
# -import onnx -from .base_operator import QuantOperatorBase -from .direct_q8 import QDQDirect8BitOp -from neural_compressor.adaptor.ox_utils.util import QuantizedValueType -from onnx import onnx_pb as onnx_proto -from neural_compressor.adaptor.ox_utils.util import QuantizedValue +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator - -class QMaxPool(QuantOperatorBase): +@op_registry(op_types="MaxPool") +class MaxPoolOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(MaxPoolOperator, self).__init__(onnx_quantizer, onnx_node) + + def quantize_check(self): + # if opset version is less than 12, just no change + if self.quantizer.opset_version < 12: # pragma: no cover + return False + return True - def convert(self): + def quantize(self): node = self.node - assert (node.op_type == "MaxPool") + super().quantize() + node.name = node.name + '_quant' - if self.quantizer.opset_version < 12: # pragma: no cover - return + def convert_check(self, convert_format): + node = self.node + assert convert_format in ['static'], \ + "convert format for {} should be in ['static']".format(node.op_type) - if len(self.quantizer.model.get_children(node)) == 0 or \ - not node.name.endswith('_quant'): - return + children = self.quantizer.model.get_children(node) + if len(children) == 0 or not node.name.endswith('_quant'): # pragma: no cover + return False + return True + + def convert(self, convert_format): + node = self.node parent = self.quantizer.model.get_parents(node)[0] children = self.quantizer.model.get_children(node) if parent.op_type != 'DequantizeLinear' or \ - all([i.op_type != 'QuantizeLinear' for i in children]): + all([i.op_type != 'QuantizeLinear' for i in children]): # pragma: no cover return node.input[0] = parent.input[0] node.output[0] = node.output[0] + '_quantized' @@ -52,19 +60,4 @@ def convert(self): self.quantizer.model.replace_node_input(n, child.output[0], node.output[0]) - self.quantizer.remove_nodes.append(parent) - -class QDQMaxPool(QDQDirect8BitOp): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def quantize(self): - node = self.node - assert (node.op_type == "MaxPool") - - # if version is less than 12, just no change - if self.quantizer.opset_version < 12: - return - - # Direct 8bits op - super().quantize() + self.quantizer.remove_nodes.append(parent) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/base_operator.py b/neural_compressor/adaptor/ox_utils/operators/ops.py similarity index 61% rename from neural_compressor/adaptor/ox_utils/operators/base_operator.py rename to neural_compressor/adaptor/ox_utils/operators/ops.py index 9db63e99b31..33d4ecf7c5d 100644 --- a/neural_compressor/adaptor/ox_utils/operators/base_operator.py +++ b/neural_compressor/adaptor/ox_utils/operators/ops.py @@ -1,58 +1,84 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright (c) 2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# - -class QuantOperatorBase: - def __init__(self, onnx_quantizer, onnx_node): - self.quantizer = onnx_quantizer - self.node = onnx_node - self.disable_qdq_for_node_output = True if onnx_node.op_type in \ - onnx_quantizer.op_types_to_exclude_output_quantization else False - self.per_channel = False - self.algorithm = 'minmax' - self.weight_scheme = 'sym' - self.weight_dtype = None - self.activation_dtype = None - self.activation_scheme = 'asym' - if self.node.name in self.quantizer.config: - if self.quantizer.config[self.node.name] != 'fp32': - if 'weight' in self.quantizer.config[self.node.name].keys(): - self.per_channel = self.quantizer.config[self.node.name]\ - ['weight']['granularity'] == 'per_channel' - self.algorithm = self.quantizer.config[self.node.name]\ - ['weight']['algorithm'] - self.weight_scheme = self.quantizer.config[self.node.name]\ - ['weight']['scheme'] - self.weight_dtype = self.quantizer.config[self.node.name]\ - ['weight']['dtype'] - if 'activation' in self.quantizer.config[self.node.name].keys(): - self.activation_dtype = self.quantizer.config[self.node.name]\ - ['activation']['dtype'] - self.activation_scheme = self.quantizer.config[self.node.name]\ - ['activation']['scheme'] - - def convert(self): - ''' - Given a node which does not support quantization(Conv, Matmul, Gather), this method - checks whether the input to this node is quantized and adds a DequantizeLinear node - to dequantize this input back to FP32 - parameter node: Current node - parameter new_nodes_list: List of new nodes created before processing current node - return: List of new nodes created - ''' - return - +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +OPERATORS = {} + +def op_registry(op_types): + '''The class decorator used to register all Operator subclasses. + + Args: + cls (class): The class of register. + ''' + def decorator_op(cls): + assert cls.__name__.endswith( + 'Operator'), "The name of subclass of Operator should end with \'Operator\' substring." 
+ if cls.__name__[:-len('Operator')] in OPERATORS: # pragma: no cover + raise ValueError('Cannot have two operators with the same name.') + for single_op_type in [op_type.strip() for op_type in op_types.split(',')]: + OPERATORS[single_op_type] = cls + return cls + return decorator_op + +class Operator(object): + def __init__(self, onnx_quantizer, onnx_node): + self.quantizer = onnx_quantizer + self.node = onnx_node + if self.node.name in self.quantizer.config: + self.dtype = self.quantizer.config[self.node.name] + self.disable_qdq_for_node_output = True if onnx_node.op_type in \ + onnx_quantizer.op_types_to_exclude_output_quantization else False + self.per_channel = False + self.algorithm = 'minmax' + self.weight_scheme = 'sym' + self.weight_dtype = None + self.activation_dtype = None + self.activation_scheme = 'asym' + if self.node.name in self.quantizer.config: + if self.quantizer.config[self.node.name] != 'fp32': + if 'weight' in self.quantizer.config[self.node.name].keys(): + self.per_channel = self.quantizer.config[self.node.name]\ + ['weight']['granularity'] == 'per_channel' + self.algorithm = self.quantizer.config[self.node.name]\ + ['weight']['algorithm'] + self.weight_scheme = self.quantizer.config[self.node.name]\ + ['weight']['scheme'] + self.weight_dtype = self.quantizer.config[self.node.name]\ + ['weight']['dtype'] + if 'activation' in self.quantizer.config[self.node.name].keys(): + self.activation_dtype = self.quantizer.config[self.node.name]\ + ['activation']['dtype'] + self.activation_scheme = self.quantizer.config[self.node.name]\ + ['activation']['scheme'] + + def quantize_check(self): + return True + + def quantize(self): + node = self.node + self.quantizer.quantize_inputs(node) + if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq': + self.quantizer.quantize_outputs(node) + + def convert_check(self, convert_format): + return True + + def convert(self, convert_format): + return + + def cast(self): # pragma: no cover + self.quantizer.dtype_cast(self.node, self.dtype) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/pad.py b/neural_compressor/adaptor/ox_utils/operators/pad.py index c6fcf730d79..0f0acfcbec7 100644 --- a/neural_compressor/adaptor/ox_utils/operators/pad.py +++ b/neural_compressor/adaptor/ox_utils/operators/pad.py @@ -16,48 +16,41 @@ # limitations under the License. 
# -import numpy import onnx -from neural_compressor.adaptor.ox_utils.util import QuantizedValueType, \ - attribute_to_kwarg -from .base_operator import QuantOperatorBase -from .qdq_base_operator import QDQOperatorBase -from neural_compressor.adaptor.ox_utils.util import QuantizedValue, quantize_nparray +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, quantize_nparray -class QDQPad(QDQOperatorBase): +@op_registry(op_types="Pad") +class PadOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(PadOperator, self).__init__(onnx_quantizer, onnx_node) + + def quantize_check(self): + # if opset version is less than 11, just no change + if self.quantizer.opset_version < 11: # pragma: no cover + return False + return True def quantize(self): node = self.node - assert (node.op_type == "Pad") - - # Only after version 11, it has the optional constant_value - # If input[0] is not quantized, do not quanitize this node - if self.quantizer.opset_version < 11: - return - self.quantizer.quantize_inputs(node, [0]) if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq': self.quantizer.quantize_outputs(node) node.name = node.name + "_quant" -class QPad(QuantOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def convert(self): + def convert_check(self, convert_format): node = self.node - assert (node.op_type == "Pad") + assert convert_format in ['static'], \ + "convert format for {} should be in ['static']".format(node.op_type) - # Only after version 11, it has the optional constant_value - # If input[0] is not quantized, do not quanitize this node - if self.quantizer.opset_version < 11: - return + children = self.quantizer.model.get_children(node) + if len(children) == 0 or not node.name.endswith('_quant'): # pragma: no cover + return False + return True - if len(self.quantizer.model.get_children(node)) == 0 or \ - not node.name.endswith('_quant'): - return + def convert(self, convert_format): + node = self.node + parent = self.quantizer.model.get_parents(node)[0] child = self.quantizer.model.get_children(node)[0] @@ -98,7 +91,6 @@ def convert(self): node.input.extend([parent.input[2]]) # Create an entry for output quantized value - node.input[0] = parent.input[0] node.output[0] = child.output[0] - self.quantizer.remove_nodes.extend([parent, child]) + self.quantizer.remove_nodes.extend([parent, child]) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/pooling.py b/neural_compressor/adaptor/ox_utils/operators/pooling.py index cf553e86e87..bba746129e6 100644 --- a/neural_compressor/adaptor/ox_utils/operators/pooling.py +++ b/neural_compressor/adaptor/ox_utils/operators/pooling.py @@ -17,38 +17,42 @@ # import onnx -from .base_operator import QuantOperatorBase -from .qdq_base_operator import QDQOperatorBase -from neural_compressor.adaptor.ox_utils.util import QuantizedValueType, \ - attribute_to_kwarg, ms_domain -from onnx import onnx_pb as onnx_proto -from neural_compressor.adaptor.ox_utils.util import QuantizedValue +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain -class QDQPool(QDQOperatorBase): +@op_registry(op_types="AveragePool") +class PoolOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - 
super().__init__(onnx_quantizer, onnx_node) + super(PoolOperator, self).__init__(onnx_quantizer, onnx_node) - def quantize(self): + def quantize_check(self): node = self.node if not self.quantizer.is_valid_quantize_weight(node.input[0]): - return + return False + return True - self.quantizer.quantize_inputs(self.node) - if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq': - self.quantizer.quantize_outputs(self.node) + def quantize(self): + node = self.node + super().quantize() node.name = node.name + "_quant" -class QLinearPool(QuantOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def convert(self): + def convert_check(self, convert_format): node = self.node + assert convert_format in ['static'], \ + "convert format for {} should be in ['static']".format(node.op_type) + parents = self.quantizer.model.get_parents(node) children = self.quantizer.model.get_children(node) if len(children) == 0 or len(parents) == 0 or not node.name.endswith('_quant'): - return + return False + return True + + def convert(self, convert_format): + node = self.node + + parents = self.quantizer.model.get_parents(node) + children = self.quantizer.model.get_children(node) if all([i.op_type == 'DequantizeLinear' for i in parents]) and \ any([i.op_type == 'QuantizeLinear' for i in children]): diff --git a/neural_compressor/adaptor/ox_utils/operators/qdq_base_operator.py b/neural_compressor/adaptor/ox_utils/operators/qdq_base_operator.py deleted file mode 100644 index 070d7bcbc45..00000000000 --- a/neural_compressor/adaptor/ox_utils/operators/qdq_base_operator.py +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright (c) 2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import itertools -from .base_operator import QuantOperatorBase - - -class QDQOperatorBase(QuantOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def quantize(self): - node = self.node - - self.quantizer.quantize_inputs(node) - if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq': - self.quantizer.quantize_outputs(node) diff --git a/neural_compressor/adaptor/ox_utils/operators/resize.py b/neural_compressor/adaptor/ox_utils/operators/resize.py index 5585f301350..d5f906f8372 100644 --- a/neural_compressor/adaptor/ox_utils/operators/resize.py +++ b/neural_compressor/adaptor/ox_utils/operators/resize.py @@ -16,21 +16,45 @@ # limitations under the License. 
# -from .direct_q8 import QDQDirect8BitOp, Direct8BitOp +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator -class QResize(Direct8BitOp): +@op_registry(op_types="Resize") +class ResizeOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(ResizeOperator, self).__init__(onnx_quantizer, onnx_node) - def convert(self): + def quantize_check(self): + node = self.node + # if version is less than 11, just keep this node + if self.quantizer.opset_version < 11: + return False + if not self.quantizer.is_valid_quantize_weight(node.input[0]): + return False + return True + + def quantize(self): + node = self.node + self.quantizer.quantize_inputs(node, [0], direct_int8=True) + if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq': + self.quantizer.quantize_outputs(self.node, direct_int8=True) + node.name = node.name + "_quant" + + def convert_check(self, convert_format): + node = self.node + assert convert_format in ['static'], \ + "convert format for {} should be in ['static']".format(node.op_type) + + parents = self.quantizer.model.get_parents(node) + children = self.quantizer.model.get_children(node) + if (len(children) == 0 and len(parents) == 0) or not node.name.endswith('_quant'): + return False + return True + + def convert(self, convert_format): node = self.node - assert node.op_type == "Resize" parents = self.quantizer.model.get_parents(node) children = self.quantizer.model.get_children(node) - if (len(children) == 0 and len(parents) == 0) or \ - not node.name.endswith('_quant'): - return if any([i.op_type == 'DequantizeLinear' for i in parents]) and \ any([i.op_type == 'QuantizeLinear' for i in children]): @@ -46,21 +70,3 @@ def convert(self): child.output[0], node.output[0] + '_quantized') node.output[0] = node.output[0] + '_quantized' -class QDQResize(QDQDirect8BitOp): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def quantize(self): - node = self.node - assert node.op_type == "Resize" - - # if version is less than 11, just keep this node - if self.quantizer.opset_version < 11: - return - - if not self.quantizer.is_valid_quantize_weight(node.input[0]): - return - self.quantizer.quantize_inputs(node, [0], direct_int8=True) - if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq': - self.quantizer.quantize_outputs(self.node, direct_int8=True) - node.name = node.name + "_quant" diff --git a/neural_compressor/adaptor/ox_utils/operators/split.py b/neural_compressor/adaptor/ox_utils/operators/split.py index 0dee3296311..3c6d7d04cef 100644 --- a/neural_compressor/adaptor/ox_utils/operators/split.py +++ b/neural_compressor/adaptor/ox_utils/operators/split.py @@ -17,15 +17,13 @@ # import onnx -from neural_compressor.adaptor.ox_utils.util import QuantizedValueType, \ - attribute_to_kwarg -from .base_operator import QuantOperatorBase -from neural_compressor.adaptor.ox_utils.util import QuantizedValue -from .qdq_base_operator import QDQOperatorBase +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg -class QDQSplit(QDQOperatorBase): +@op_registry(op_types="Split") +class SplitOperator(Operator): def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) + super(SplitOperator, self).__init__(onnx_quantizer, onnx_node) def quantize(self): node = self.node @@ -34,24 +32,29 @@ def 
quantize(self): self.quantizer.quantize_outputs(self.node, direct_int8=True) node.name = node.name + "_quant" -class QSplit(QuantOperatorBase): - def __init__(self, onnx_quantizer, onnx_node): - super().__init__(onnx_quantizer, onnx_node) - - def convert(self): + def convert_check(self, convert_format): node = self.node + assert convert_format in ['static'], \ + "convert format for {} should be in ['static']".format(node.op_type) + parent = self.quantizer.model.get_parents(node)[0] children = self.quantizer.model.get_children(node) if parent.op_type != 'DequantizeLinear' or len(children) == 0 or \ - not node.name.endswith('_quant'): - return + not node.name.endswith('_quant'): # pragma: no cover + return False + return True + + def convert(self, convert_format): + node = self.node + + parent = self.quantizer.model.get_parents(node)[0] kwargs = {} - for attribute in node.attribute: + for attribute in node.attribute: # pragma: no cover kwargs.update(attribute_to_kwarg(attribute)) quantized_input_names = [] quantized_input_names.append(parent.input[0]) - if len(node.input) > 1: + if len(node.input) > 1: # pragma: no cover quantized_input_names.extend(node.input[1:]) outputs = [] for output in node.output: @@ -60,9 +63,9 @@ def convert(self): if child.op_type == 'QuantizeLinear': self.quantizer.remove_nodes.append(child) outputs.append(child.output[0]) - else: + else: # pragma: no cover outputs.append(output) - else: + else: # pragma: no cover outputs.append(output + '_quatized') quantized_node = onnx.helper.make_node(node.op_type, @@ -71,3 +74,9 @@ def convert(self): node.name, **kwargs) self.quantizer.new_nodes.append(quantized_node) self.quantizer.remove_nodes.extend([parent, node]) + + def cast(self): # pragma: no cover + node = self.node + if node.input[0] not in [i.tensor_name for i in self.quantizer.new_value_info.values()]: + return + self.quantizer.dtype_cast(self.node, self.dtype) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/quantizer.py b/neural_compressor/adaptor/ox_utils/quantizer.py index 7e30393f90a..829dff40045 100644 --- a/neural_compressor/adaptor/ox_utils/quantizer.py +++ b/neural_compressor/adaptor/ox_utils/quantizer.py @@ -29,8 +29,6 @@ from onnx import shape_inference from onnxruntime import SessionOptions, InferenceSession, GraphOptimizationLevel -from neural_compressor.adaptor.ox_utils.registry import CreateQDQQuantizer, \ - CreateOpConverter, CreateCaster from neural_compressor.adaptor.ox_utils.util import QuantizedValue, QuantizedInitializer, \ _get_qrange_for_qType, cast_tensor, make_quant_node, make_dquant_node from neural_compressor.adaptor.ox_utils.util import QuantizedValueType @@ -40,6 +38,7 @@ from neural_compressor import options from neural_compressor.utils.utility import CpuInfo from neural_compressor.model.onnx_model import ONNXModel +from neural_compressor.adaptor.ox_utils.operators import OPERATORS logger = logging.getLogger() @@ -209,10 +208,11 @@ def should_cast(self, node): def insert_qdq(self): for node in self.model.nodes(): if self.should_quantize(node): - op_quantizer = CreateQDQQuantizer(self, node) - op_quantizer.quantize() + op_quantizer = OPERATORS[node.op_type](self, node) + if op_quantizer.quantize_check(): + op_quantizer.quantize() elif self.should_cast(node): # pragma: no cover - op_caster = CreateCaster(self, node) + op_caster = OPERATORS[node.op_type](self, node) op_caster.cast() self.model.graph().node.extend(self.new_nodes) self.model.remove_nodes(self.remove_nodes) @@ -236,9 +236,10 @@ def 
convert_qdq_to_operator_oriented(self): for node in self.model.nodes(): if node.op_type not in ['QuantizeLinear', 'DequantizeLinear'] and \ self.should_convert(node): - op_converter = CreateOpConverter(self, node, - self.config[node.name.split('_quant')[0]]['activation']['quant_mode']) - op_converter.convert() + op_converter = OPERATORS[node.op_type](self, node) + mode = self.config[node.name.split('_quant')[0]]['activation']['quant_mode'] + if op_converter.convert_check(mode): + op_converter.convert(mode) self.model.graph().node.extend(self.new_nodes) self.model.remove_nodes(self.remove_nodes) for node, old_input_name, new_input_name in self.replace_input: diff --git a/neural_compressor/adaptor/ox_utils/registry.py b/neural_compressor/adaptor/ox_utils/registry.py deleted file mode 100644 index 6d49e472f69..00000000000 --- a/neural_compressor/adaptor/ox_utils/registry.py +++ /dev/null @@ -1,136 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright (c) 2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from .util import QuantizationMode -from .operators.base_operator import QuantOperatorBase -from .operators.qdq_base_operator import QDQOperatorBase -from .operators.matmul import MatMulInteger, QLinearMatMul, QDQMatMul -from .operators.attention import AttentionQuant, QDQAttention -from .operators.embed_layernorm import EmbedLayerNormalizationQuant, QDQEmbedLayerNormalization -from .operators.gather import GatherConverter, GatherQuant -from .operators.conv import QLinearConv, ConvInteger, QDQConv -from .operators.activation import QLinearActivation, QDQRemovableActivation, QDQActivation -from .operators.binary_op import QLinearBinaryOp, QDQBinaryOp -from .operators.maxpool import QMaxPool, QDQMaxPool -from .operators.gavgpool import QGlobalAveragePool -from .operators.lstm import LSTMQuant, QDQLSTM -from .operators.split import QSplit, QDQSplit -from .operators.concat import QLinearConcat, QDQConcat -from .operators.pad import QPad, QDQPad -from .operators.pooling import QLinearPool, QDQPool -from .operators.direct_q8 import QDQDirect8BitOp, Direct8BitOp, DirectCast -from .operators.base_operator_cast import CastOperatorBase -from .operators.argmax import QArgMax -from .operators.gemm import QLinearGemm, QDQGemm -from .operators.resize import QResize, QDQResize - -CommonOpsRegistry = {"Gather": GatherConverter, \ - "EmbedLayerNormalization": EmbedLayerNormalizationQuant} - -IntegerOpsRegistry = { - "Conv": ConvInteger, - "FusedConv": ConvInteger, - "MatMul": MatMulInteger, - "Attention": AttentionQuant, - "LSTM": LSTMQuant, -} -IntegerOpsRegistry.update(CommonOpsRegistry) - -QLinearOpsRegistry = { - "Conv": QLinearConv, - "Concat": QLinearConcat, - "Attention": AttentionQuant, - "FusedConv": QLinearConv, - "MatMul": QLinearMatMul, - "Add": QLinearBinaryOp, - "Mul": QLinearBinaryOp, - "Relu": QLinearActivation, - "Clip": QLinearActivation, - "LeakyRelu" : QLinearActivation, - "Sigmoid" : QLinearActivation, - "MaxPool": QMaxPool, - "GlobalAveragePool": 
QGlobalAveragePool, - "Split": QSplit, - "Pad": QPad, - "AveragePool" : QLinearPool, - "Reshape": Direct8BitOp, - "Transpose" : Direct8BitOp, - "Squeeze" : Direct8BitOp, - "Unsqueeze" : Direct8BitOp, - "Resize": QResize, - "ArgMax": QArgMax, - "Gemm": QLinearGemm, -} -QLinearOpsRegistry.update(CommonOpsRegistry) - -QDQRegistry = { - "FusedConv": QDQConv, - "Conv": QDQConv, - "Clip": QDQRemovableActivation, - "Relu": QDQRemovableActivation, - "LeakyRelu": QDQActivation, - "Sigmoid": QDQActivation, - "MaxPool": QDQMaxPool, - "MatMul": QDQMatMul, - "Add": QDQBinaryOp, - "Mul": QDQBinaryOp, - "Gather": GatherQuant, - "Attention": QDQAttention, - "LSTM": QDQLSTM, - "Pad": QDQPad, - "Reshape": QDQDirect8BitOp, - "Transpose" : QDQDirect8BitOp, - "Squeeze" : QDQDirect8BitOp, - "AveragePool": QDQPool, - "Unsqueeze" : QDQDirect8BitOp, - "Concat": QDQConcat, - "Split": QDQSplit, - "EmbedLayerNormalization": QDQEmbedLayerNormalization, - "Gemm": QDQGemm, - "Resize": QDQResize, -} - -CastRegistry = { - "Shape": DirectCast, - "Squeeze": DirectCast, - "Unsqueeze": DirectCast, - "Reshape": DirectCast, - "Unsqueeze": DirectCast, - "Transpose": DirectCast, - "Loop": DirectCast, - "Slice": DirectCast, - "Split": DirectCast, - "Concat": DirectCast, - -} - - -def CreateOpConverter(onnx_quantizer, node, quant_mode): - registry = IntegerOpsRegistry if quant_mode == 'dynamic' else QLinearOpsRegistry - if node.op_type in registry.keys(): - return registry[node.op_type](onnx_quantizer, node) - return QuantOperatorBase(onnx_quantizer, node) - -def CreateQDQQuantizer(onnx_quantizer, node): - if node.op_type in QDQRegistry.keys(): - return QDQRegistry[node.op_type](onnx_quantizer, node) - return QDQOperatorBase(onnx_quantizer, node) - -def CreateCaster(onnx_quantizer, node): - if node.op_type in CastRegistry.keys(): - return CastRegistry[node.op_type](onnx_quantizer, node) - return CastOperatorBase(onnx_quantizer, node) diff --git a/test/adaptor/onnxrt_adaptor/test_onnxrt_augment.py b/test/adaptor/onnxrt_adaptor/test_onnxrt_augment.py index 4cbddac9a13..769fd7353a1 100644 --- a/test/adaptor/onnxrt_adaptor/test_onnxrt_augment.py +++ b/test/adaptor/onnxrt_adaptor/test_onnxrt_augment.py @@ -9,7 +9,7 @@ sys.path.append('..') from neural_compressor.experimental.data.datasets.dataset import Dataset -from neural_compressor.adaptor.ox_utils.onnxrt_mid import ONNXRTAugment +from neural_compressor.adaptor.ox_utils.calibration import ONNXRTAugment from neural_compressor.model.onnx_model import ONNXModel from neural_compressor.data import DATASETS, DATALOADERS
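
Usage sketch: the decorator-based registry added in ops.py replaces the
CreateQDQQuantizer/CreateOpConverter/CreateCaster helpers that were deleted with
registry.py. The snippet below is illustrative only and mirrors the code added in
this patch; "Custom" and "CustomOperator" are hypothetical names used purely for
demonstration and are not part of the change.

    from neural_compressor.adaptor.ox_utils.operators import OPERATORS
    from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator

    # Each comma-separated entry in op_types becomes a key in the OPERATORS dict.
    # The class name must end with 'Operator' (enforced by an assert in op_registry),
    # and operators/__init__.py imports every module in the package so the decorators
    # run and populate OPERATORS at import time.
    @op_registry(op_types="Custom")
    class CustomOperator(Operator):
        def __init__(self, onnx_quantizer, onnx_node):
            # The base Operator already supplies default quantize_check/quantize
            # and convert_check/convert, so a subclass only overrides what differs.
            super(CustomOperator, self).__init__(onnx_quantizer, onnx_node)

    # Quantizer-side dispatch (see insert_qdq and convert_qdq_to_operator_oriented
    # in quantizer.py): a single registered class now serves both the QDQ-insertion
    # pass and the operator-oriented conversion pass.
    #
    #   op = OPERATORS[node.op_type](self, node)
    #   if op.quantize_check():
    #       op.quantize()              # QDQ-insertion pass
    #   ...
    #   if op.convert_check(mode):     # mode is 'dynamic' or 'static'
    #       op.convert(mode)           # conversion pass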