fix TFOOB model regression (#1254)
Zhiwei35 authored Sep 19, 2022
1 parent 5f0adfe commit fc6f9a7
Showing 1 changed file with 14 additions and 14 deletions.
@@ -108,7 +108,7 @@ def apply_matmul_biasadd_relu_fusion(self, match_node_name):
# workaround for RNN model like LTSM.
if not parent_node.op == 'Const':
self.logger.debug('The weight node of matched_node {} is not Const or Const + Enter, skipped')
- self.exclude_matmul_nodes.append(matched_node.name)
+ self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
return []
else:
@@ -130,7 +130,7 @@ def apply_matmul_biasadd_relu_fusion(self, match_node_name):
weights_content = tensor_util.MakeNdarray(weight_node.attr['value'].tensor)

if np.any(np.isnan(weights_content)):
- self.exclude_matmul_nodes.append(matched_node.name)
+ self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
return []

@@ -141,7 +141,7 @@ def apply_matmul_biasadd_relu_fusion(self, match_node_name):

# If weight node non const, can't insert dummy biasadd to do matmul fusion.
if weight_node.op != 'Const' and len(match_node_name) == 3:
- self.exclude_matmul_nodes.append(matched_node.name)
+ self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
return []

@@ -313,7 +313,7 @@ def apply_matmul_biasadd_fusion(self, match_node_name):
# workaround for RNN model like LTSM.
if not parent_node.op == 'Const':
self.logger.debug('The weight node of matched_node {} is not Const or Const + Enter, skipped')
- self.exclude_matmul_nodes.append(matched_node.name)
+ self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
return []
else:
@@ -334,15 +334,15 @@ def apply_matmul_biasadd_fusion(self, match_node_name):
#TODO Remove below two lines once the TF enabled the QuantizedMatMul while
# transpose_a could be set to True.
if matched_node.node.attr["transpose_a"].b == True:
- self.exclude_matmul_nodes.append(matched_node.name)
+ self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
return []

if weight_node.op == 'Const':
weights_content = tensor_util.MakeNdarray(weight_node.attr['value'].tensor)

if np.any(np.isnan(weights_content)):
- self.exclude_matmul_nodes.append(matched_node.name)
+ self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
return []

@@ -354,7 +354,7 @@ def apply_matmul_biasadd_fusion(self, match_node_name):

# If weight node non const, can't insert dummy biasadd to do matmul fusion.
if weight_node.op != 'Const' and len(match_node_name) == 3:
- self.exclude_matmul_nodes.append(matched_node.name)
+ self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
return []

@@ -373,7 +373,7 @@ def apply_matmul_biasadd_fusion(self, match_node_name):
if len(match_node_name) == 3:
if is_shared_output:
self.output_graph = self.input_graph
- self.exclude_matmul_nodes.append(matched_node.name)
+ self.exclude_matmul_nodes.append(matched_node.node.name)
return []
else:
second_node = self.node_name_mapping[match_node_name[2]].node
@@ -559,7 +559,7 @@ def apply_batchmatmulv2_fusion(self, match_node_name):
# workaround for RNN model like LTSM.
if not parent_node.op == 'Const':
self.logger.debug('The weight node of matched_node {} is not Const or Const + Enter, skipped')
- self.exclude_matmul_nodes.append(matched_node.name)
+ self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
return []
else:
@@ -580,14 +580,14 @@ def apply_batchmatmulv2_fusion(self, match_node_name):

if np.any(np.isnan(weights_content)):
self.output_graph = self.input_graph
- self.exclude_matmul_nodes.append(matched_node.name)
+ self.exclude_matmul_nodes.append(matched_node.node.name)
return []

for i in self.node_name_mapping:
if weight_node.input and not weight_node.input[0].startswith('^') \
and weight_node.name in self.node_name_mapping[i].output:
self.output_graph = self.input_graph
- self.exclude_matmul_nodes.append(matched_node.name)
+ self.exclude_matmul_nodes.append(matched_node.node.name)
return []

q_weights_name, q_weights_min_name, q_weights_max_name = \
@@ -692,7 +692,7 @@ def apply_batchmatmulv2_mul_add_fusion(self, match_node_name):
# workaround for RNN model like LTSM.
if not parent_node.op == 'Const':
self.logger.debug('The weight node of matched_node {} is not Const or Const + Enter, skipped')
- self.exclude_matmul_nodes.append(matched_node.name)
+ self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
return []
else:
@@ -712,15 +712,15 @@ def apply_batchmatmulv2_mul_add_fusion(self, match_node_name):
weights_content = tensor_util.MakeNdarray(weight_node.attr['value'].tensor)

if np.any(np.isnan(weights_content)):
- self.exclude_matmul_nodes.append(matched_node.name)
+ self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
return []

for i in self.node_name_mapping:
if weight_node.input and not weight_node.input[0].startswith('^') \
and weight_node.name in self.node_name_mapping[i].output:
self.output_graph = self.input_graph
- self.exclude_matmul_nodes.append(matched_node.name)
+ self.exclude_matmul_nodes.append(matched_node.node.name)
return []

q_weights_name, q_weights_min_name, q_weights_max_name = \
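
All fourteen changed lines apply the same one-token fix: the excluded node's name is read from the wrapped NodeDef (matched_node.node.name) rather than from the wrapper itself (matched_node.name). Below is a minimal sketch of why that matters, assuming, as the matched_node.node.attr["transpose_a"] access in the diff suggests, that matched_node is a small match-details record that exposes the graph node under .node and has no name field of its own. FakeNodeDef, MatchedNode, and the example node names are hypothetical stand-ins, not the repository's real types.

# Minimal sketch of the fixed pattern (hypothetical stand-in types).
from collections import namedtuple

# Stand-in for a TensorFlow NodeDef: only the fields this sketch needs.
FakeNodeDef = namedtuple('FakeNodeDef', ['name', 'op'])
# Stand-in for the matcher's per-node record: the graph node lives under
# .node, and the record itself has no top-level .name field.
MatchedNode = namedtuple('MatchedNode', ['node', 'output'])

matmul_node = FakeNodeDef(name='model/dense/MatMul', op='MatMul')
matched_node = MatchedNode(node=matmul_node, output=['model/dense/BiasAdd'])

exclude_matmul_nodes = []

# Before the fix: the wrapper has no .name, so this raises AttributeError
# (or, with a different wrapper type, would record the wrong string).
try:
    exclude_matmul_nodes.append(matched_node.name)
except AttributeError:
    pass

# After the fix: read the name from the wrapped NodeDef, consistent with the
# matched_node.node.attr["transpose_a"] access elsewhere in the same methods.
exclude_matmul_nodes.append(matched_node.node.name)

print(exclude_matmul_nodes)  # ['model/dense/MatMul']

Reading the name through .node keeps the exclusion list holding actual graph-node names, and the commit applies that consistently across all four fusion methods it touches.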
