Enable QDQ-adapted _MklFusedInstanceNorm+[Relu/LeakyRelu] fusion and quantization (#1311)

* Add _MklFusedInstanceNorm related configuration

* Enable QDQ-adapted _MklFusedInstanceNorm+[Relu/LeakyRelu] fusion and quantization

* Add attribute 'reduction_axes' to QIN

* Add dummy mean/variance nodes for _MklFusedInstanceNorm

* Add QIN freeze_value

* Fix QIN freeze_value

* Add performance_only mode for _MklFusedInstanceNorm
ChendaLi-Intel authored Oct 11, 2022
1 parent 1adb453 commit d5b1716
Showing 9 changed files with 249 additions and 23 deletions.
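For context (not stated in the diff itself): _MklFusedInstanceNorm is a fused instance-normalization op, produced here by FuseDecomposedINOptimizer from a decomposed subgraph. A minimal sketch of the kind of graph this commit targets, assuming NHWC layout and a trailing Relu as in the new fusion patterns — the tensor names and eager-style code are illustrative, not the exact subgraph the optimizer matches:

```python
import tensorflow as tf

def decomposed_instance_norm_relu(x, gamma, beta, eps=1e-3):
    # Per-sample, per-channel statistics over the spatial axes (NHWC -> axes [1, 2]).
    mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
    # Decomposed normalization: shows up as Rsqrt/Sub/Mul/Add nodes in the graph.
    normalized = (x - mean) * tf.math.rsqrt(variance + eps)
    # Scale, shift, then the activation that the new QDQ patterns fuse.
    return tf.nn.relu(normalized * gamma + beta)

x = tf.random.normal([1, 8, 8, 4])
y = decomposed_instance_norm_relu(x, gamma=tf.ones([4]), beta=tf.zeros([4]))
print(y.shape)  # (1, 8, 8, 4)
```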
14 changes: 12 additions & 2 deletions neural_compressor/adaptor/inteltensorflow.yaml
@@ -24,7 +24,7 @@

ops: &common_ops
int8: ['Conv2D', 'Conv3D', 'DepthwiseConv2dNative', 'FusedBatchNorm', 'FusedBatchNormV2','FusedBatchNormV3',
- 'MatMul', 'BatchMatMul', 'BatchMatMulV2', 'ConcatV2', 'MaxPool', 'MaxPool3D', 'AvgPool']
+ 'MatMul', 'BatchMatMul', 'BatchMatMulV2', 'ConcatV2', 'MaxPool', 'MaxPool3D', 'AvgPool', '_MklFusedInstanceNorm']
uint8: ['Conv2D', 'Conv3D', 'DepthwiseConv2dNative', 'MatMul', 'BatchMatMul', 'BatchMatMulV2', 'ConcatV2', 'MaxPool', 'MaxPool3D', 'AvgPool']
bf16: ["Conv2D", "Conv2DBackpropFilter", "Conv2DBackpropInput", "Conv3D", "Conv3DBackpropFilterV2", "Conv3DBackpropInputV2",
"DepthwiseConv2dNative", "DepthwiseConv2dNativeBackpropFilter", "DepthwiseConv2dNativeBackpropInput", "GRUBlockCell",
@@ -86,6 +86,14 @@
'algorithm': ['minmax', 'kl']
}
},
+ '_MklFusedInstanceNorm': {
+   'activation': {
+     'dtype': ['int8', 'fp32'],
+     'scheme': ['sym'],
+     'granularity': ['per_tensor'],
+     'algorithm': ['minmax', 'kl']
+   }
+ },
'MatMul': {
'weight': {
'dtype': ['int8'],
@@ -283,7 +291,9 @@
'Dequantize + DepthwiseConv2dNative + Relu + QuantizeV2',
'Dequantize + DepthwiseConv2dNative + Add + Relu6 + QuantizeV2',
'Dequantize + DepthwiseConv2dNative + BiasAdd + QuantizeV2',
- 'Dequantize + FusedBatchNormV3 + Relu + QuantizeV2'
+ 'Dequantize + FusedBatchNormV3 + Relu + QuantizeV2',
+ 'Dequantize + _MklFusedInstanceNorm + Relu + QuantizeV2',
+ 'Dequantize + _MklFusedInstanceNorm + LeakyRelu + QuantizeV2'
]
uint8: [
'Dequantize + Conv2D + BiasAdd + AddN + Relu + QuantizeV2',
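The capability entry added above restricts _MklFusedInstanceNorm activations to per-tensor, symmetric int8 calibrated with minmax or KL. As a generic illustration of what per-tensor symmetric minmax quantization computes (a standalone sketch, not Neural Compressor's internal implementation):

```python
import numpy as np

def quantize_per_tensor_sym_minmax(tensor, num_bits=8):
    # Symmetric scheme: a single scale for the whole tensor, zero point fixed at 0.
    qmax = 2 ** (num_bits - 1) - 1          # 127 for int8
    max_abs = np.max(np.abs(tensor))        # "minmax" calibration statistic
    scale = max_abs / qmax if max_abs else 1.0
    q = np.clip(np.round(tensor / scale), -qmax - 1, qmax).astype(np.int8)
    return q, scale

activations = np.array([-1.5, -0.2, 0.0, 0.7, 2.3], dtype=np.float32)
q, scale = quantize_per_tensor_sym_minmax(activations)
print(q, scale)   # [-83 -11   0  39 127] with scale = 2.3 / 127 ≈ 0.0181
```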
22 changes: 14 additions & 8 deletions neural_compressor/adaptor/tensorflow.py
@@ -39,6 +39,7 @@ class TensorFlowAdaptor(Adaptor):
"Conv3D": "conv3d",
"DepthwiseConv2dNative": "conv2d",
"FusedBatchNormV3": "batchnorm",
"_MklFusedInstanceNorm": "instancenorm",
"MaxPool": "pooling",
"MaxPool3D": "pooling",
"AvgPool": "pooling",
@@ -268,7 +269,7 @@ def evaluate(self, model, dataloader, postprocess=None,
output_postfix = "_fp32.output"
inspect_node_types = ["Conv2D", "DepthwiseConv2dNative", "MaxPool", "AvgPool",
"ConcatV2", "MatMul", "FusedBatchNormV3", "FusedBatchNorm", "BiasAdd",
"Relu", "Relu6", "Dequantize"]
"_MklFusedInstanceNorm", "Relu", "Relu6", "Dequantize"]
fp32_inspect_node_name = []
int8_inspect_node_name = []
q_node_scale = {}
@@ -316,7 +317,7 @@ def evaluate(self, model, dataloader, postprocess=None,
# Inspect weights, bias. Need further optimize
if node.op == "Const" and graph_info[graph_info[node.name].outputs[0]].node.op \
in ["Conv2D", "DepthwiseConv2dNative", "MatMul",
"FusedBatchNormV3", "BiasAdd"]:
"FusedBatchNormV3", "_MklFusedInstanceNorm", "BiasAdd"]:
const_value = tensor_util.MakeNdarray(node.attr.get(
'value').tensor).astype(np.float32)
self.log_histogram(writer, node.name, const_value)
@@ -610,7 +611,7 @@ def _dump_model_op_stats(self, model_graphdef):
'QuantizedMaxPool', 'QuantizedAvgPool',
'QuantizedConcatV2', 'QuantizedMatMul',
'_QuantizedFusedBatchNorm', '_QuantizedMatMul',
- '_QuantizedBatchMatMul']
+ '_QuantizedBatchMatMul', '_QuantizedFusedInstanceNorm']
from tensorflow.python.framework import dtypes

res = {}
@@ -629,6 +630,8 @@ def _dump_model_op_stats(self, model_graphdef):
origin_op_type = possible_int8_res[0].split('Quantized')[-1]
if origin_op_type == 'FusedBatchNorm':
origin_op_type = 'FusedBatchNormV3'
+ if origin_op_type == 'FusedInstanceNorm':
+     origin_op_type = '_MklFusedInstanceNorm'
if origin_op_type == 'Depthwise':
origin_op_type = 'DepthwiseConv2dNative'
if origin_op_type == 'BatchMatMul':
@@ -1597,16 +1600,19 @@ def _one_shot_query(self):
try:
self.cur_config = self._get_specified_version_cfg(content)
if not self.performance_only:
- remove_int8_ops = ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3']
+ remove_int8_ops = ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3',
+                    '_MklFusedInstanceNorm']
for op in remove_int8_ops:
while op in self.cur_config['ops']['int8']:
self.cur_config['ops']['int8'].remove(op)
if self.cur_config.get('capabilities'):
self.cur_config['capabilities']['int8'].pop(op, None)
- pattern = f'Dequantize + {op} + Relu + QuantizeV2'
- if self.cur_config.get('patterns'):
-     while pattern in self.cur_config['patterns']['int8']:
-         self.cur_config['patterns']['int8'].remove(pattern)
+ patterns = [f'Dequantize + {op} + Relu + QuantizeV2',
+             f'Dequantize + {op} + LeakyRelu + QuantizeV2']
+ for pattern in patterns:
+     if self.cur_config.get('patterns'):
+         while pattern in self.cur_config['patterns']['int8']:
+             self.cur_config['patterns']['int8'].remove(pattern)

except Exception as e:
logger.info("Fail to parse {} due to {}.".format(self.cfg, str(e)))
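The _one_shot_query change above extends the existing batch-norm handling: unless performance_only is set, _MklFusedInstanceNorm is stripped from the queried int8 op list along with both of its new fusion patterns (Relu and LeakyRelu). A standalone sketch of that filtering over a trimmed-down config dict — `cur_config` and `performance_only` are local stand-ins for the adaptor's attributes, and the capabilities handling is omitted:

```python
cur_config = {
    'ops': {'int8': ['Conv2D', 'FusedBatchNormV3', '_MklFusedInstanceNorm']},
    'patterns': {'int8': [
        'Dequantize + Conv2D + BiasAdd + QuantizeV2',
        'Dequantize + _MklFusedInstanceNorm + Relu + QuantizeV2',
        'Dequantize + _MklFusedInstanceNorm + LeakyRelu + QuantizeV2',
    ]},
}

performance_only = False
if not performance_only:
    remove_int8_ops = ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3',
                       '_MklFusedInstanceNorm']
    for op in remove_int8_ops:
        while op in cur_config['ops']['int8']:
            cur_config['ops']['int8'].remove(op)
        for pattern in (f'Dequantize + {op} + Relu + QuantizeV2',
                        f'Dequantize + {op} + LeakyRelu + QuantizeV2'):
            while pattern in cur_config['patterns']['int8']:
                cur_config['patterns']['int8'].remove(pattern)

print(cur_config['ops']['int8'])       # ['Conv2D']
print(cur_config['patterns']['int8'])  # only the Conv2D pattern remains
```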
@@ -21,8 +21,9 @@
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import tensor_util
+ from tensorflow.python.framework import dtypes
from tensorflow.python.platform import tf_logging

from neural_compressor.adaptor.tf_utils.quantize_graph_common import QuantizeGraphHelper as helper
from neural_compressor.utils.utility import dump_elapsed_time

class FuseDecomposedINOptimizer():
@@ -235,10 +236,18 @@ def do_transformation(self):

# Mean and variance values will be computed at runtime for fp32 & bf16 input.
# Pass a "dummy" node for mean and variance.
+ mean_variance_dim = tensor_util.MakeNdarray(gamma_op.attr["value"].tensor).shape[-1]
+ dummy_mean_node = \
+     helper.create_constant_node(node.name + '_dummy_mean',
+                                 [0.]*mean_variance_dim, dtypes.float32)
+ dummy_variance_node = \
+     helper.create_constant_node(node.name + '_dummy_variance',
+                                 [1.]*mean_variance_dim, dtypes.float32)
new_fused_instancenorm_op.input.extend([input_data_op.name, gamma_op.name,
- beta_op.name, gamma_op.name,
- gamma_op.name])
+ beta_op.name, dummy_mean_node.name,
+ dummy_variance_node.name])
+ new_ops.append(dummy_mean_node)
+ new_ops.append(dummy_variance_node)
new_ops.append(new_fused_instancenorm_op)

result_graph_def = graph_pb2.GraphDef()
@@ -339,4 +348,4 @@ def get_const_dim_count(node_def):
Number of dimensions for the Const node.
"""
const_value = values_from_const(node_def)
return const_value.ndim
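The do_transformation hunk above stops reusing the gamma constant as a stand-in for mean/variance and instead feeds dedicated dummy constants (zeros for mean, ones for variance, sized to gamma's channel dimension), so the fused op's inputs become (input, gamma, beta, mean, variance). A rough usage sketch of the rewriter pass; the module path and the GraphDef-taking constructor are assumptions based on how the class appears here, and the model path is hypothetical:

```python
import tensorflow as tf
# Assumed module path, mirroring the decomposed-BN rewriter; adjust to the actual location.
from neural_compressor.adaptor.tf_utils.graph_rewriter.generic.fuse_decomposed_in import (
    FuseDecomposedINOptimizer)

# Load a frozen fp32 GraphDef that still contains the decomposed instance-norm subgraph.
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile('model_fp32.pb', 'rb') as f:   # hypothetical path
    graph_def.ParseFromString(f.read())

# Fold the subgraph into _MklFusedInstanceNorm; after this commit the fused node
# carries dummy mean/variance Const inputs instead of reusing gamma.
fused_graph_def = FuseDecomposedINOptimizer(graph_def).do_transformation()
print([n.op for n in fused_graph_def.node if 'InstanceNorm' in n.op])
```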
@@ -228,18 +228,27 @@ def generate_output_graph_ranges(self, max_name_value):
"""
for node_name, value in max_name_value.items():
bn_node_name = node_name.replace('eightbit_requant_range', 'eightbit_quantized_bn')
+ in_node_name = node_name.replace('eightbit_requant_range', 'eightbit_quantized_in')
if not self.graph_info.get(bn_node_name) or \
not bn_node_name.endswith('_eightbit_quantized_bn'):
bn_node_name = None
+ if not self.graph_info.get(in_node_name) or \
+         not in_node_name.endswith('_eightbit_quantized_in'):
+     in_node_name = None
if node_name not in self.graph_info \
- and bn_node_name not in self.graph_info:
+ and bn_node_name not in self.graph_info \
+ and in_node_name not in self.graph_info:
continue

min_node = node_def_pb2.NodeDef()
min_node.op = "Const"
min_node_postfix = "/frozen_min"
- min_node.name = bn_node_name + "/frozen_bn_output_min" if bn_node_name \
-     else node_name + min_node_postfix
+ if bn_node_name:
+     min_node.name = bn_node_name + "/frozen_bn_output_min"
+ elif in_node_name:
+     min_node.name = in_node_name + "/frozen_in_output_min"
+ else:
+     min_node.name = node_name + min_node_postfix
min_node.attr["dtype"].CopyFrom(
attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum))
min_node.attr["value"].CopyFrom(
@@ -250,8 +259,12 @@ def generate_output_graph_ranges(self, max_name_value):
max_node = node_def_pb2.NodeDef()
max_node.op = "Const"
max_node_postfix = "/frozen_max"
- max_node.name = bn_node_name + "/frozen_bn_output_max" if bn_node_name \
-     else node_name + max_node_postfix
+ if bn_node_name:
+     max_node.name = bn_node_name + "/frozen_bn_output_max"
+ elif in_node_name:
+     max_node.name = in_node_name + "/frozen_in_output_max"
+ else:
+     max_node.name = node_name + max_node_postfix
max_node.attr["dtype"].CopyFrom(
attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum))
max_node.attr["value"].CopyFrom(
@@ -270,6 +283,17 @@ def generate_output_graph_ranges(self, max_name_value):
[Helper.node_name_from_input(bn_node_name)],
bn_node_name + '_input8_output_max'
)
+ elif in_node_name:
+     self.cur_graph.replace_const_node(
+         min_node,
+         [Helper.node_name_from_input(in_node_name)],
+         in_node_name + '_input7_output_min'
+     )
+     self.cur_graph.replace_const_node(
+         max_node,
+         [Helper.node_name_from_input(in_node_name)],
+         in_node_name + '_input8_output_max'
+     )
elif not self.itex_mode and node_name in self.cur_graph.parent_frame_details and \
self.cur_graph.parent_frame_details[node_name]: # pragma: no cover
output_node_name = self.graph_info[node_name].outputs[0]
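The freeze_value changes mirror the batch-norm path: a calibrated range entry named `<node>_eightbit_requant_range` is mapped back to the quantized instance-norm node `<node>_eightbit_quantized_in`, and its min/max are frozen into Const nodes that replace that node's `_input7_output_min` / `_input8_output_max` inputs. A small illustration of the name mapping only (the node name is made up):

```python
# Hypothetical calibration entry produced during the sampling run.
node_name = 'style_net/in1_eightbit_requant_range'

in_node_name = node_name.replace('eightbit_requant_range', 'eightbit_quantized_in')
assert in_node_name.endswith('_eightbit_quantized_in')

min_name = in_node_name + '/frozen_in_output_min'
max_name = in_node_name + '/frozen_in_output_max'
print(in_node_name)        # style_net/in1_eightbit_quantized_in
print(min_name, max_name)  # .../frozen_in_output_min .../frozen_in_output_max
```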
@@ -175,7 +175,7 @@ def _check_op_list(self, node_type):
op_list = ("ConcatV2", "Conv2D", "Conv3D", "DepthwiseConv2D", "QuantizeV2", "DepthwiseConv2dNative",
"MaxPool", "MaxPool3D", "FusedBatchNormV3", "Requantize", "RequantizePerChannel", "AvgPool", "Pad",
"CropAndResize", "Dequantize", "Mean", "MatMul", "BatchMatMul",
"BatchMatMulV2", "FakeQuantWithMinMaxVars")
"BatchMatMulV2", "FakeQuantWithMinMaxVars", "_MklFusedInstanceNorm")
return any([node_type.find(i) != -1 for i in op_list])

def _find_relu_node(self, node):
@@ -573,6 +573,8 @@ def _ignore_insert_qdq_pattern(self, matched_node_name):
return True
if "FusedBatchNorm" in self.graph_info[matched_node_name].node.op:
return True
if "_MklFusedInstanceNorm" == self.graph_info[matched_node_name].node.op:
return True
return False


2 changes: 1 addition & 1 deletion neural_compressor/adaptor/tf_utils/graph_util.py
@@ -77,7 +77,7 @@ def _has_positive_input(self, start_node):
return True
elif op_type in ("Conv3D", "Conv2D", "DepthwiseConv2D", "QuantizeV2", "DepthwiseConv2dNative",
"MaxPool", "MaxPool3D", "Requantize", "AvgPool", "Pad", "CropAndResize", "Dequantize",
"Mean", "MatMul", "FusedBatchNormV3"):
"Mean", "MatMul", "FusedBatchNormV3", "_MklFusedInstanceNorm"):
return self._has_positive_input(
self.node_name_details[GraphRewriterHelper.node_name_from_input(
start_node.input[0])].node)
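Adding _MklFusedInstanceNorm to _check_op_list and to the pass-through list in _has_positive_input lets the rewriter keep tracing upstream through an instance-norm node when it checks whether a tensor provably comes from a non-negative source such as Relu. A toy sketch of that recursive walk — the graph dict and helper below are invented for illustration, not the library's API:

```python
# Toy producer map: input -> Relu -> _MklFusedInstanceNorm -> MaxPool.
toy_graph = {
    'pool':  ('MaxPool', 'inorm'),
    'inorm': ('_MklFusedInstanceNorm', 'relu'),
    'relu':  ('Relu', 'input'),
    'input': ('Placeholder', None),
}

# Ops the rewriter is willing to look through (a small subset for illustration).
PASS_THROUGH_OPS = {'MaxPool', 'AvgPool', 'FusedBatchNormV3', '_MklFusedInstanceNorm'}

def has_positive_input(name):
    op, producer = toy_graph[name]
    if op in ('Relu', 'Relu6'):
        return True
    if op in PASS_THROUGH_OPS and producer is not None:
        return has_positive_input(producer)  # keep walking up, as graph_util now does for IN
    return False

print(has_positive_input('pool'))  # True: traces back through the instance norm to the Relu
```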