Enable _FusedBatchNormEx folding (intel#1202)
ChendaLi-Intel authored Sep 1, 2022
1 parent 592758a commit cbbaf98
Showing 6 changed files with 200 additions and 98 deletions.
@@ -34,13 +34,15 @@ class FoldBatchNormNodesOptimizer(GraphRewriterBase):
             ["conv_op", "mean_op", "var_op", "beta_op", "gamma_op"],
         # Order of inputs for FusedBatchNorm.
         "FusedBatchNorm": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"],
-        "FusedBatchNormV3": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"]
+        "FusedBatchNormV3": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"],
+        "_FusedBatchNormEx": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"]
     }
     # Name of the attribute epsilon value is stored in.
     EPSILON_ATTR = {
         "BatchNormWithGlobalNormalization": "variance_epsilon",
         "FusedBatchNorm": "epsilon",
-        "FusedBatchNormV3": "epsilon"
+        "FusedBatchNormV3": "epsilon",
+        "_FusedBatchNormEx": "epsilon"
     }

     def scale_after_normalization(self, node):
@@ -70,8 +72,8 @@ def do_transformation(self):
         scaling into the convolution weights. This function identifies the typical
         pattern of batch normalization subgraphs, and performs the transformation to
         fold the computations down into a simpler form. It currently only spots batch
-        normalization that's performed by the BatchNormWithGlobalNormalization and
-        FusedBatchNorm ops, and will need to be extended in the future to handle the
+        normalization that's performed by the BatchNormWithGlobalNormalization, FusedBatchNorm,
+        FusedBatchNormV3 and _FusedBatchNormEx ops, and will need to be extended in the future to handle the
         newer style.

         Returns:
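For reference, the arithmetic being folded here is the standard inference-mode batch-norm identity: the per-channel scale gamma / sqrt(var + epsilon) is absorbed into the convolution weights, and the remainder becomes a bias. A minimal NumPy sketch of that fold (illustrative names, not code from this file):

    import numpy as np

    def fold_batch_norm(weights, gamma, beta, mean, var, epsilon):
        # weights: conv kernel [kh, kw, in_ch, out_ch]; gamma, beta, mean, var
        # are per-output-channel vectors of shape [out_ch].
        scale = gamma / np.sqrt(var + epsilon)    # per-channel multiplier
        folded_weights = weights * scale          # broadcasts over the out_ch axis
        folded_bias = beta - mean * scale         # becomes the BiasAdd constant
        return folded_weights, folded_bias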
@@ -86,7 +88,7 @@ def do_transformation(self):
         graph_info = cur_graph.parse_graph()
         target_nodes = cur_graph.query_fusion_pattern_nodes(
             [["Conv2D", "DepthwiseConv2dNative"], ("BiasAdd", "Add", "AddV2"),
-             ["BatchNormWithGlobalNormalization", "FusedBatchNorm", "FusedBatchNormV3"]])
+             ["BatchNormWithGlobalNormalization", "FusedBatchNorm", "FusedBatchNormV3", "_FusedBatchNormEx"]])
         for node_combination in target_nodes:
             matched_node = node_combination[:-1]
             has_add_op = True if len(node_combination[-1]) == 3 else False
@@ -96,6 +98,15 @@ def do_transformation(self):
             weights_node = graph_info[Helper.node_name_from_input(weights_node_name)].node
             bn_node = graph_info[Helper.node_name_from_input(matched_node[-1])].node

+            # oneDNN enabled _FusedBatchNormEx only supports num_side_inputs == 0
+            # and Relu/Identity activations.
+            if bn_node.op == "_FusedBatchNormEx":
+                if bn_node.attr["num_side_inputs"].i != 0:
+                    continue
+                if not (bn_node.attr["activation_mode"].s == b"Identity" or
+                        bn_node.attr["activation_mode"].s == b"Relu"):
+                    continue
+
             if weights_node.op != "Const":
                 self.logger.warning("Didn't find expected conv Constant input to '%s', "
                                     "found %s instead. Maybe freeze_graph wasn't "
@@ -219,10 +230,31 @@ def do_transformation(self):
             bias_add_node.attr["T"].CopyFrom(conv_node.attr["T"])
             bias_add_node.attr["data_format"].CopyFrom(conv_node.attr["data_format"])
             bias_add_node.input.extend([conv_node.name, offset_node.name])
+            if bn_node.op == "_FusedBatchNormEx" and bn_node.attr["activation_mode"].s == b"Relu":
+                # Create Relu op which takes Bias-Add as input.
+                #  Conv2D/Depthwise-Conv2D                     Conv2D/Depthwise-Conv2D
+                #            |                                            |
+                #  Bias-Add (originally _FusedBatchNormEx)   <---->    Bias-Add
+                #            |                                            |    \
+                #       <some-node>                                 <some-node>  Relu
+                relu_node = node_def_pb2.NodeDef()
+                relu_node.op = "Relu"
+                relu_node.name = bn_node.name + "_bn_relu"
+                relu_node.attr["T"].CopyFrom(conv_node.attr["T"])
+                relu_node.input.extend([bias_add_node.name])

-            cur_graph.add_node(offset_node, [], [bias_add_node.name])
-            cur_graph.add_node(bias_add_node, conv_node.name,
-                               graph_info[Helper.node_name_from_input(matched_node[-1])].outputs)
+            if bn_node.op == "_FusedBatchNormEx" and bn_node.attr["activation_mode"].s == b"Relu":
+                matchd_node_outputs = graph_info[Helper.node_name_from_input(matched_node[-1])].outputs
+                cur_graph.add_node(offset_node, [], [bias_add_node.name])
+                cur_graph.add_node(bias_add_node, conv_node.name, [relu_node.name])
+                cur_graph.add_node(relu_node, bias_add_node.name, matchd_node_outputs)
+            else:
+                cur_graph.add_node(offset_node, [], [bias_add_node.name])
+                cur_graph.add_node(bias_add_node, conv_node.name,
+                                   graph_info[Helper.node_name_from_input(matched_node[-1])].outputs)
             cur_graph.replace_const_node(scaled_weights_node, [conv_node.name], weights_node_name)

             cur_graph.remove_node(weights_node_name)
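As a sanity check on the rewrite above, and in particular the _FusedBatchNormEx path with a fused Relu, the equivalence can be verified numerically. A self-contained NumPy sketch with made-up shapes (a 1x1 kernel, so the convolution reduces to a channel-wise matmul); none of these names come from the repository:

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((1, 4, 4, 3)).astype(np.float32)   # NHWC input
    w = rng.standard_normal((1, 1, 3, 2)).astype(np.float32)   # 1x1 Conv2D kernel
    gamma = rng.random(2).astype(np.float32) + 0.5
    beta = rng.standard_normal(2).astype(np.float32)
    mean = rng.standard_normal(2).astype(np.float32)
    var = rng.random(2).astype(np.float32) + 0.1
    eps = 1e-3

    def conv1x1(x, w):
        # A 1x1 convolution is a matmul over the channel dimension.
        return np.einsum("nhwc,ijco->nhwo", x, w)

    # Reference path: Conv2D -> inference-mode batch norm -> Relu
    y_ref = gamma * (conv1x1(x, w) - mean) / np.sqrt(var + eps) + beta
    y_ref = np.maximum(y_ref, 0.0)

    # Folded path: scaled weights + BiasAdd, with the Relu appended afterwards
    scale = gamma / np.sqrt(var + eps)
    y_folded = np.maximum(conv1x1(x, w * scale) + (beta - mean * scale), 0.0)

    assert np.allclose(y_ref, y_folded, atol=1e-5)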
@@ -136,7 +136,8 @@ def _find_relu_node(self, node):
         if (node.op in ("Relu", "Relu6", "Elu") or \
             (node.op.find("AndRelu") != -1 and \
              ('alpha' not in node.attr or ('alpha' in node.attr and node.attr['alpha'].f == 0)))) \
-                and self.node_name_mapping[node.input[0]].op.find("FusedBatchNorm") == -1:
+                and (node.op != "Relu" or \
+                     self.node_name_mapping[Helper.node_name_from_input(node.input[0])].op.find("FusedBatchNorm") == -1):
             return True
         elif 'T' in node.attr and node.attr['T'].type in (dtypes.quint8, dtypes.uint8):
             return True
@@ -150,11 +150,17 @@ def apply_the_transform(self):
         if matched_node_name:
             self.output_graph = graph_pb2.GraphDef()
             fusion_name = ''.join(matched_rule)
-            if fusion_name in self.fusion_mapping:
+            bn_node = self.node_name_mapping[matched_node_name[0]].node
+            is_training = bn_node.attr['is_training'].b
+            if fusion_name in self.fusion_mapping and is_training == False:
                 self.fusion_mapping[fusion_name](matched_node_name)
             else:
-                if self.new_api:
-                    self.logger.info("Unknown fusion pattern {}.".format(fusion_name))
+                if is_training == True:
+                    self.logger.info \
+                        ("Skip quantizing the BN node '{}' due to the attr 'is_training == true'." \
+                        .format(bn_node.name))
+                elif self.new_api:
+                    self.logger.info("Unknown fusion pattern {} .".format(fusion_name))
                 if self.remove_redundant_quant_flag:
                     self.input_graph = self.remove_redundant_quantization(self.input_graph)
                 return self.input_graph
@@ -275,7 +275,9 @@ def _find_relu_node(self, node):
         if (node.op in ("Relu", "Relu6") or \
             (node.op.find("AndRelu") != -1 and \
              ('alpha' not in node.attr or ('alpha' in node.attr and node.attr['alpha'].f == 0)))) \
-                and self.node_name_mapping[node.input[0]].node.op.find("FusedBatchNorm") == -1:
+                and (node.op != "Relu" or \
+                     self.node_name_mapping \
+                     [helper.node_name_from_input(node.input[0])].node.op.find("FusedBatchNorm") == -1):
             return True
         elif 'T' in node.attr and node.attr['T'].type in (dtypes.quint8, dtypes.uint8):
             return True
@@ -146,11 +146,17 @@ def apply_the_transform(self):
         if matched_node_name:
             self.output_graph = graph_pb2.GraphDef()
             fusion_name = ''.join(matched_rule)
-            if fusion_name in self.fusion_mapping:
+            bn_node = self.node_name_mapping[matched_node_name[0]].node
+            is_training = bn_node.attr['is_training'].b
+            if fusion_name in self.fusion_mapping and is_training == False:
                 self.fusion_mapping[fusion_name](matched_node_name)
             else:
-                if self.new_api:
-                    self.logger.info("Unknown fusion pattern {}.".format(fusion_name))
+                if is_training == True:
+                    self.logger.info \
+                        ("Skip quantizing the BN node '{}' due to the attr 'is_training == true'." \
+                        .format(bn_node.name))
+                elif self.new_api:
+                    self.logger.info("Unknown fusion pattern {} .".format(fusion_name))
                 if self.remove_redundant_quant_flag:
                     self.input_graph = self.remove_redundant_quantization(self.input_graph)
                 return self.input_graph, []