Support more ONNX direct INT8 ops (#1151)
Signed-off-by: yuwenzho <[email protected]>
yuwenzho authored Aug 16, 2023
1 parent 641d42b commit b9ce61a
Showing 8 changed files with 676 additions and 126 deletions.
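In short: this commit extends direct INT8 quantization (running an op directly on already-quantized tensors by removing the surrounding QuantizeLinear/DequantizeLinear pairs) to nine more ONNX ops: Abs, Shrink, Sign, Flatten, Expand, Slice, Mod, ReduceMax, and ReduceMin. The adaptor YAML capability tables gain entries for each new op, and three new operator classes (BinaryDirect8BitOperator, ReduceMinMaxOperator, UnaryDirect8BitOperator) implement their quantize and convert passes.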
10 changes: 9 additions & 1 deletion neural_compressor/adaptor/onnxrt.yaml
@@ -414,7 +414,15 @@
'Transpose': *default_static_qlinear_qdq_minmax,
'ArgMax': *default_static_qlinear,
'Resize': *default_static_qlinear_qdq_minmax,

'Abs': *default_static_qlinear_qdq_minmax,
'Shrink': *default_static_qlinear_qdq_minmax,
'Sign': *default_static_qlinear_qdq_minmax,
'Flatten': *default_static_qlinear_qdq_minmax,
'Expand': *default_static_qlinear_qdq_minmax,
'Slice': *default_static_qlinear_qdq_minmax,
'Mod': *default_static_qlinear_qdq_minmax,
'ReduceMax': *default_static_qlinear_qdq_minmax,
'ReduceMin': *default_static_qlinear_qdq_minmax,
},
'dynamic': *ref_1_9_dynamic
}
10 changes: 9 additions & 1 deletion neural_compressor/adaptor/onnxrt_cuda.yaml
@@ -407,7 +407,15 @@
'Transpose': *default_static_qlinear_qdq,
'ArgMax': *default_static_qlinear,
'Resize': *default_static_qlinear_qdq,

'Abs': *default_static_qlinear_qdq,
'Shrink': *default_static_qlinear_qdq,
'Sign': *default_static_qlinear_qdq,
'Flatten': *default_static_qlinear_qdq,
'Expand': *default_static_qlinear_qdq,
'Slice': *default_static_qlinear_qdq,
'Mod': *default_static_qlinear_qdq,
'ReduceMax': *default_static_qlinear_qdq,
'ReduceMin': *default_static_qlinear_qdq,
},
'dynamic': *ref_1_9_dynamic
}
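Both YAML files use the same mechanism: each op name maps to a YAML alias (*default_static_qlinear_qdq_minmax on CPU, *default_static_qlinear_qdq on CUDA) that references a capability template anchored earlier in the same file, so every new op inherits the standard static QLinear/QDQ configuration without repeating it. A generic sketch of the anchor/alias mechanics; the field names here are illustrative, not copied from the files:

```yaml
# Define a reusable capability template once with an anchor (&)...
static_minmax: &static_minmax
  activation: {dtype: ['uint8'], algorithm: ['minmax']}

# ...then alias (*) it for every op that shares that capability.
ops:
  'Abs': *static_minmax
  'Sign': *static_minmax
```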
56 changes: 56 additions & 0 deletions neural_compressor/adaptor/ox_utils/operators/binary_op.py
@@ -88,6 +88,62 @@ def convert(self, convert_format):
        self.quantizer.remove_nodes.append(child)
        self.quantizer.remove_nodes.append(node)

@op_registry(op_types="Mod")
class BinaryDirect8BitOperator(Operator):
"""Binary operator."""

def __init__(self, onnx_quantizer, onnx_node):
"""Initialization."""
super(BinaryDirect8BitOperator, self).__init__(onnx_quantizer, onnx_node)

def quantize_check(self):
"""Check if quantizaion can be done."""
node = self.node
data_found, _, _, _, _ = self.quantizer._get_quantization_params(node.output[0])
if not data_found:
return False
if not all([self.quantizer.is_valid_quantize_weight(i) for i in node.input]):
return False

return True

def quantize(self):
"""Do quantizaion."""
node = self.node
self.quantizer.quantize_inputs(node, initializer_use_weight_qType=False)
if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq':
self.quantizer.quantize_outputs(node)
node.name = node.name + "_quant"

def convert_check(self, convert_format):
"""Check if conversion can be done."""
node = self.node
assert convert_format in ['static'], \
"convert format for {} should be in ['static']".format(node.op_type)

children = self.quantizer.model.get_children(node)
if len(children) == 0 or not node.name.endswith('_quant'):
return False
return True

def convert(self, convert_format):
"""Convert to QOperator format."""
node = self.node
parents = self.quantizer.model.get_parents(node)
children = self.quantizer.model.get_children(node)
if any([i.op_type == 'DequantizeLinear' for i in parents]) and \
any([i.op_type == 'QuantizeLinear' for i in children]):
for idx, parent in enumerate(parents):
if parent.op_type == 'DequantizeLinear':
self.node.input[idx] = parent.input[0]
self.quantizer.remove_nodes.append(parent)
for child in children:
if child.op_type == 'QuantizeLinear':
self.quantizer.remove_nodes.append(child)
self.quantizer.model.replace_input_of_all_nodes(
child.output[0], node.output[0] + '_quantized')
node.output[0] = node.output[0] + '_quantized'

@qop_registry(op_types="QLinearAdd, QLinearMul")
class QBinaryOperator(QOperator):
"""QBinary Operator."""
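The convert() pass above rewires the graph rather than computing anything: when a quantized Mod sits between DequantizeLinear parents and a QuantizeLinear child, those nodes are removed and Mod is reconnected to the raw INT8 tensors, keeping the existing scale and zero-point. A minimal sketch of the before/after pattern built with onnx.helper; tensor and node names are illustrative:

```python
from onnx import helper

# Before conversion: Mod computes in float between DequantizeLinear
# parents and a QuantizeLinear child.
dq_a = helper.make_node("DequantizeLinear", ["a_q", "s", "zp"], ["a"])
dq_b = helper.make_node("DequantizeLinear", ["b_q", "s", "zp"], ["b"])
mod = helper.make_node("Mod", ["a", "b"], ["m"], name="mod_quant")
q_m = helper.make_node("QuantizeLinear", ["m", "s", "zp"], ["m_q"])

# After convert(): dq_a, dq_b, and q_m are deleted; Mod reads the INT8
# inputs directly and its output is renamed "m_quantized", which replaces
# "m_q" as the input of all downstream nodes.
mod_int8 = helper.make_node("Mod", ["a_q", "b_q"], ["m_quantized"], name="mod_quant")
```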
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/ox_utils/operators/direct_q8.py
@@ -18,7 +18,7 @@

from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, qop_registry, QOperator

@op_registry(op_types="Reshape, Transpose, Squeeze, Unsqueeze")
@op_registry(op_types="Reshape, Transpose, Squeeze, Unsqueeze, Flatten, Expand, Slice")
class Direct8BitOperator(Operator):
"""Direct8Bit Operator."""

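Flatten, Expand, and Slice need no new class: they only move or broadcast values, so they join the existing Direct8BitOperator alongside Reshape, Transpose, Squeeze, and Unsqueeze. A quick numpy check (not part of the commit) of why such data-movement ops can safely run on quantized data: slicing commutes with dequantization.

```python
import numpy as np

x = np.random.randn(8).astype(np.float32)
scale, zp = 0.05, 0

# Quantize to int8, then dequantize.
q = np.clip(np.round(x / scale) + zp, -128, 127).astype(np.int8)
dequant = (q.astype(np.float32) - zp) * scale

# Slicing before or after dequantization gives the same result.
assert np.allclose(dequant[2:5], (q[2:5].astype(np.float32) - zp) * scale)
```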
58 changes: 57 additions & 1 deletion neural_compressor/adaptor/ox_utils/operators/reduce.py
@@ -18,11 +18,67 @@

from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator

@op_registry(op_types="ReduceMean, ReduceLogSum, ReduceLogSumExp, ReduceMax, " \
@op_registry(op_types="ReduceMean, ReduceLogSum, ReduceLogSumExp, " \
"ReduceL1, ReduceL2, ReduceProd, ReduceSum, ReduceSumSquare")
class ReduceOperator(Operator):
    """Reduce Operator."""

    def __init__(self, onnx_quantizer, onnx_node):
        """Initialization."""
        super(ReduceOperator, self).__init__(onnx_quantizer, onnx_node)

@op_registry(op_types="ReduceMax, ReduceMin")
class ReduceMinMaxOperator(Operator):
"""ReduceMin and ReduceMax Operator."""

def __init__(self, onnx_quantizer, onnx_node):
"""Initialization."""
super(ReduceMinMaxOperator, self).__init__(onnx_quantizer, onnx_node)

def quantize_check(self):
"""Check if quantizaion can be done."""
node = self.node
if not self.quantizer.is_valid_quantize_weight(node.input[0]):
return False
return True

def quantize(self):
"""Do quantizaion."""
node = self.node
self.quantizer.quantize_inputs(self.node, [0], direct_int8=True)
if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq':
self.quantizer.quantize_outputs(self.node, direct_int8=True)
node.name = node.name + "_quant"

def convert_check(self, convert_format):
"""Check if conversion can be done."""
node = self.node
assert convert_format in ['static'], \
"convert format for {} should be in ['static']".format(node.op_type)

parents = self.quantizer.model.get_parents(node)
children = self.quantizer.model.get_children(node)
if (len(children) == 0 and len(parents) == 0) or \
not node.name.endswith('_quant'):
return False
return True

def convert(self, convert_format):
"""Convert to QOperator format."""
node = self.node

parents = self.quantizer.model.get_parents(node)
children = self.quantizer.model.get_children(node)
if any([i.op_type == 'DequantizeLinear' for i in parents]) and \
any([i.op_type == 'QuantizeLinear' for i in children]):
for parent in parents:
if parent.op_type == 'DequantizeLinear':
self.node.input[0] = parent.input[0]
self.quantizer.remove_nodes.append(parents[0])
break
for child in children:
if child.op_type == 'QuantizeLinear':
self.quantizer.remove_nodes.append(child)
self.quantizer.model.replace_input_of_all_nodes(
child.output[0], node.output[0] + '_quantized')
node.output[0] = node.output[0] + '_quantized'
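ReduceMax and ReduceMin qualify for the same direct INT8 treatment for a simple reason: affine quantization with a positive scale is monotonically non-decreasing, so taking a max or min commutes with quantization. A quick numpy check, not part of the commit:

```python
import numpy as np

x = np.random.randn(16).astype(np.float32)
scale, zp = 0.1, 128

def quant(v):
    # Affine quantization; monotonic non-decreasing whenever scale > 0.
    return np.clip(np.round(v / scale) + zp, 0, 255).astype(np.uint8)

# The max/min of the quantized tensor equals the quantized max/min.
assert quant(x).max() == quant(x.max())
assert quant(x).min() == quant(x.min())
```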
59 changes: 58 additions & 1 deletion neural_compressor/adaptor/ox_utils/operators/unary_op.py
@@ -18,10 +18,67 @@

from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator

@op_registry(op_types="Abs, Exp, Log, Round, Sqrt")
@op_registry(op_types="Exp, Log, Round, Sqrt")
class UnaryOperator(Operator):
    """Unary operator."""

    def __init__(self, onnx_quantizer, onnx_node):
        """Initialization."""
        super(UnaryOperator, self).__init__(onnx_quantizer, onnx_node)


@op_registry(op_types="Abs, Shrink, Sign")
class UnaryDirect8BitOperator(Operator):
"""Unary operator."""

def __init__(self, onnx_quantizer, onnx_node):
"""Initialization."""
super(UnaryDirect8BitOperator, self).__init__(onnx_quantizer, onnx_node)

def quantize_check(self):
"""Check if quantizaion can be done."""
node = self.node
if not self.quantizer.is_valid_quantize_weight(node.input[0]):
return False
return True

def quantize(self):
"""Do quantizaion."""
node = self.node
self.quantizer.quantize_inputs(self.node, [0], direct_int8=True)
if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq':
self.quantizer.quantize_outputs(self.node, direct_int8=True)
node.name = node.name + "_quant"

def convert_check(self, convert_format):
"""Check if conversion can be done."""
node = self.node
assert convert_format in ['static'], \
"convert format for {} should be in ['static']".format(node.op_type)

parents = self.quantizer.model.get_parents(node)
children = self.quantizer.model.get_children(node)
if (len(children) == 0 and len(parents) == 0) or \
not node.name.endswith('_quant'):
return False
return True

def convert(self, convert_format):
"""Convert to QOperator format."""
node = self.node

parents = self.quantizer.model.get_parents(node)
children = self.quantizer.model.get_children(node)
if any([i.op_type == 'DequantizeLinear' for i in parents]) and \
any([i.op_type == 'QuantizeLinear' for i in children]):
for parent in parents:
if parent.op_type == 'DequantizeLinear':
self.node.input[0] = parent.input[0]
self.quantizer.remove_nodes.append(parents[0])
break
for child in children:
if child.op_type == 'QuantizeLinear':
self.quantizer.remove_nodes.append(child)
self.quantizer.model.replace_input_of_all_nodes(
child.output[0], node.output[0] + '_quantized')
node.output[0] = node.output[0] + '_quantized'
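With these registrations in place, an ordinary post-training static quantization run will now consider Abs, Shrink, Sign, and the other newly listed ops as INT8 candidates. A usage sketch against the Neural Compressor 2.x API; model.onnx and calib_loader are placeholders:

```python
from neural_compressor import PostTrainingQuantConfig, quantization

# Static PTQ over an ONNX model; newly supported ops such as Abs, Mod,
# and ReduceMax/ReduceMin become quantization candidates with this commit.
conf = PostTrainingQuantConfig(approach="static")
q_model = quantization.fit("model.onnx", conf, calib_dataloader=calib_loader)
q_model.save("model_int8.onnx")
```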
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/ox_utils/quantizer.py
@@ -600,7 +600,7 @@ def quantize_inputs(self, node, indices=None,
                    find_by_name(zeropoint_name, self.model.initializer()))
                qlinear_node = onnx.helper.make_node("DynamicQuantizeLinear",
                                                     [tensor_name],
-                                                    [tensor_name + "_quantized", scale_name, zeropoint_name],
+                                                    [tensor_name + "_dynamic_quantized", scale_name, zeropoint_name],
                                                     tensor_name + "_QuantizeLinear")
            else:
                scale_name, zp_name, _, _ = \
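The one-line quantizer.py change renames the output tensor of DynamicQuantizeLinear from <tensor>_quantized to <tensor>_dynamic_quantized, presumably so dynamic-path tensor names can no longer collide with the <tensor>_quantized names that the static conversion paths above generate; the diff itself does not state the motivation.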