Skip to content

Commit

Permalink
Refactor ox_utils (#1322)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuwenzho authored Oct 11, 2022
1 parent d5b1716 commit 288340b
Show file tree
Hide file tree
Showing 27 changed files with 763 additions and 963 deletions.
4 changes: 2 additions & 2 deletions neural_compressor/adaptor/onnxrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def _dump_model_op_stats(self, model):
field_names=["Op Type", "Total", "INT8", "BF16", "FP16", "FP32"]).print_stat()

def _get_quantize_params(self, model, data_loader, quantize_config, iterations):
from neural_compressor.adaptor.ox_utils.onnxrt_mid import ONNXRTAugment
from neural_compressor.adaptor.ox_utils.calibration import ONNXRTAugment
from neural_compressor.model.onnx_model import ONNXModel
if not isinstance(model, ONNXModel):
model = ONNXModel(model)
Expand All @@ -346,7 +346,7 @@ def inspect_tensor(self, model, dataloader, op_list=[],
quantization_cfg=None):
'''The function is used by tune strategy class for dumping tensor info.
'''
from neural_compressor.adaptor.ox_utils.onnxrt_mid import ONNXRTAugment
from neural_compressor.adaptor.ox_utils.calibration import ONNXRTAugment
from neural_compressor.model.onnx_model import ONNXModel
from neural_compressor.utils.utility import dump_data_to_local
if not isinstance(model, ONNXModel):
Expand Down
11 changes: 11 additions & 0 deletions neural_compressor/adaptor/ox_utils/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,14 @@
# limitations under the License.
#

from os.path import dirname, basename, isfile, join
import glob
from .ops import OPERATORS

modules = glob.glob(join(dirname(__file__), "*.py"))

for f in modules:
if isfile(f) and not f.startswith('__') and not f.endswith('__init__.py'):
__import__(basename(f)[:-3], globals(), locals(), level=1)

__all__ = ["OPERATORS"]
78 changes: 41 additions & 37 deletions neural_compressor/adaptor/ox_utils/operators/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,39 @@
#

import onnx
from .base_operator import QuantOperatorBase
from .qdq_base_operator import QDQOperatorBase
from neural_compressor.adaptor.ox_utils.util import QuantizedValueType, \
attribute_to_kwarg, ms_domain
from onnx import onnx_pb as onnx_proto
from neural_compressor.adaptor.ox_utils.util import QuantizedValue
from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator
from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain


class QLinearActivation(QuantOperatorBase):
@op_registry(op_types="LeakyRelu, Sigmoid")
class ActivationOperator(Operator):
def __init__(self, onnx_quantizer, onnx_node):
super().__init__(onnx_quantizer, onnx_node)
super(ActivationOperator, self).__init__(onnx_quantizer, onnx_node)

def quantize_check(self):
node = self.node
data_found, _, _, _, _ = self.quantizer._get_quantization_params(node.output[0])
if not data_found:
return False
return True

def quantize(self):
node = self.node
super().quantize()
node.name = node.name + "_quant"

def convert(self):
def convert_check(self, convert_format):
node = self.node
if node.op_type in ['Relu', 'Clip']:
return

if len(self.quantizer.model.get_children(node)) == 0 or \
not node.name.endswith('_quant'):
return
# No assert on op_type as it is controlled by registry
# only try to quantize when given quantization parameters for it
assert convert_format in ['static'], \
"convert format for {} should be in ['static']".format(node.op_type)

children = self.quantizer.model.get_children(node)
if len(children) == 0 or not node.name.endswith('_quant'):
return False
return True

def convert(self, convert_format):
node = self.node

parent = self.quantizer.model.get_parents(node)[0]
child = self.quantizer.model.get_children(node)[0]

Expand All @@ -48,7 +59,7 @@ def convert(self):

qlinear_activation_output = child.output[0]
kwargs = {}
for attribute in node.attribute:
for attribute in node.attribute: # pragma: no cover
kwargs.update(attribute_to_kwarg(attribute))
kwargs["domain"] = ms_domain

Expand All @@ -59,28 +70,21 @@ def convert(self):
self.quantizer.new_nodes.append(qlinear_activation_node)
self.quantizer.remove_nodes.extend([parent, child, node])

class QDQRemovableActivation(QDQOperatorBase):
@op_registry(op_types="Relu, Clip")
class RemovableActivationOperator(Operator):
def __init__(self, onnx_quantizer, onnx_node):
super().__init__(onnx_quantizer, onnx_node)
super(RemovableActivationOperator, self).__init__(onnx_quantizer, onnx_node)

def quantize(self):
def quantize_check(self):
node = self.node
if node.input[0] not in self.quantizer.quantized_value_map:
return
elif node.output[0] in [i.name for i in self.quantizer.model.model.graph.output]:
return False
return True

def quantize(self):
node = self.node
if node.output[0] in [i.name for i in self.quantizer.model.model.graph.output]:
self.quantizer.dequantize_tensor(node, node.input[0])
else:
self.quantizer.model.replace_input_of_all_nodes(node.output[0], node.input[0])
self.quantizer.remove_nodes.append(node)

class QDQActivation(QDQOperatorBase):
def __init__(self, onnx_quantizer, onnx_node):
super().__init__(onnx_quantizer, onnx_node)

def quantize(self):
node = self.node
data_found, _, _, _, _ = self.quantizer._get_quantization_params(node.output[0])
if not data_found:
return
super().quantize()
node.name = node.name + "_quant"
self.quantizer.remove_nodes.append(node)
18 changes: 13 additions & 5 deletions neural_compressor/adaptor/ox_utils/operators/argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,24 @@
# limitations under the License.
#

from .base_operator import QuantOperatorBase

class QArgMax(QuantOperatorBase):
from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator

@op_registry(op_types="ArgMax")
class ArgMaxOperator(Operator):
def __init__(self, onnx_quantizer, onnx_node):
super().__init__(onnx_quantizer, onnx_node)
super(ArgMaxOperator, self).__init__(onnx_quantizer, onnx_node)

def convert_check(self, convert_format):
node = self.node
assert convert_format in ['static'], \
"convert format for {} should be in ['static']".format(node.op_type)
return True

def convert(self):
def convert(self, convert_format):
node = self.node
origin_name = node.input[0].split('_argmax_node')[0]

if origin_name in self.quantizer.quantized_value_map:
node.input[0] = self.quantizer.quantized_value_map[origin_name].q_name
node.name = node.name + '_quant'
node.name = node.name + '_quant'
52 changes: 19 additions & 33 deletions neural_compressor/adaptor/ox_utils/operators/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,30 @@
#

import onnx
from .base_operator import QuantOperatorBase
from .qdq_base_operator import QDQOperatorBase
from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator
from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain
from onnx import onnx_pb as onnx_proto
'''
Quantize Attention
'''


class AttentionQuant(QuantOperatorBase):
@op_registry(op_types="Attention")
class AttentionOperator(Operator):
def __init__(self, onnx_quantizer, onnx_node):
super().__init__(onnx_quantizer, onnx_node)
super(AttentionOperator, self).__init__(onnx_quantizer, onnx_node)

def convert(self):
'''
parameter node: Attention node.
parameter new_nodes_list: List of new nodes created before processing this node.
return: a list of nodes in topological order that represents quantized Attention node
'''
def quantize(self):
node = self.node
assert (node.op_type == "Attention")
self.quantizer.quantize_inputs(node)
node.name = node.name + "_quant"

def convert_check(self, convert_format):
node = self.node
assert convert_format in ['dynamic', 'static'], \
"convert format for {} should be in ['dynamic', 'static']".format(node.op_type)

if not node.name.endswith('_quant'):
return
return False
return True

def convert(self, convert_format):
node = self.node
parents = self.quantizer.model.get_parents(node)
quantized_name = []
scale = []
Expand All @@ -65,25 +65,11 @@ def convert(self):
inputs.extend([node.input[4] if len(node.input) > 4 else ""])

kwargs = {}
for attribute in node.attribute:
for attribute in node.attribute: # pragma: no cover
kwargs.update(attribute_to_kwarg(attribute))
kwargs["domain"] = ms_domain
qattention_node = onnx.helper.make_node("QAttention", inputs, node.output,
node.name, **kwargs)
self.quantizer.new_nodes.append(qattention_node)

self.quantizer.remove_nodes.append(node)

class QDQAttention(QDQOperatorBase):
def __init__(self, onnx_quantizer, onnx_node):
super().__init__(onnx_quantizer, onnx_node)

def quantize(self):
node = self.node
assert (node.op_type == "Attention")

if self.quantizer.static:
super().quantize()
else:
self.quantizer.quantize_inputs(node, [0, 1])
node.name = node.name + "_quant"
self.quantizer.remove_nodes.append(node)
26 changes: 0 additions & 26 deletions neural_compressor/adaptor/ox_utils/operators/base_operator_cast.py

This file was deleted.

68 changes: 34 additions & 34 deletions neural_compressor/adaptor/ox_utils/operators/binary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,49 @@
#

import onnx
from .base_operator import QuantOperatorBase
from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain, \
QuantizedValueType
from onnx import onnx_pb as onnx_proto
from neural_compressor.adaptor.ox_utils.util import QuantizedValue
from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator
from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain

class QLinearBinaryOp(QuantOperatorBase):
@op_registry(op_types="Add, Mul")
class BinaryOperator(Operator):
def __init__(self, onnx_quantizer, onnx_node):
super().__init__(onnx_quantizer, onnx_node)
super(BinaryOperator, self).__init__(onnx_quantizer, onnx_node)

def convert(self):
def quantize_check(self):
node = self.node
data_found, _, _, _, _ = self.quantizer._get_quantization_params(node.output[0])
if not data_found:
return False
if not all([self.quantizer.is_valid_quantize_weight(i) for i in node.input]):
return False
return True

def quantize(self):
node = self.node
self.quantizer.quantize_inputs(node, initializer_use_weight_qType=False)
if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq':
self.quantizer.quantize_outputs(node)
node.name = node.name + "_quant"

def convert_check(self, convert_format):
node = self.node
assert convert_format in ['static'], \
"convert format for {} should be in ['static']".format(node.op_type)

children = self.quantizer.model.get_children(node)
if len(children) == 0 or not node.name.endswith('_quant'):
return False
return True

def convert(self, convert_format):
node = self.node
if len(self.quantizer.model.get_children(node)) == 0 or \
not node.name.endswith('_quant'):
return
parents = self.quantizer.model.get_parents(node)
child = self.quantizer.model.get_children(node)[0]

qlinear_binary_math_output = child.output[0]

kwargs = {}
for attribute in node.attribute:
for attribute in node.attribute: # pragma: no cover
kwargs.update(attribute_to_kwarg(attribute))
kwargs["domain"] = ms_domain

Expand All @@ -56,25 +77,4 @@ def convert(self):
self.quantizer.new_nodes += [qlinear_binary_math_node]
self.quantizer.remove_nodes.extend(parents)
self.quantizer.remove_nodes.append(child)
self.quantizer.remove_nodes.append(node)

class QDQBinaryOp(QuantOperatorBase):
def __init__(self, onnx_quantizer, onnx_node):
super().__init__(onnx_quantizer, onnx_node)

def quantize(self):
node = self.node
data_found, _, _, _, _ = self.quantizer._get_quantization_params(node.output[0])
if not data_found:
return

if not all([self.quantizer.is_valid_quantize_weight(i) for i in node.input]):
return

if not all([self.quantizer.is_valid_quantize_weight(i) for i in node.input]):
return

self.quantizer.quantize_inputs(node, initializer_use_weight_qType=False)
if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq':
self.quantizer.quantize_outputs(node)
node.name = node.name + "_quant"
self.quantizer.remove_nodes.append(node)
Loading

0 comments on commit 288340b

Please sign in to comment.