Skip to content

Commit

Permalink
fix regression of TFOOB models (#1150)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhiwei35 authored Aug 16, 2022
1 parent da6cb87 commit a6f7476
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def do_transformation(self):
block_value = [i for i in tensor_util.MakeNdarray(
block_shape_node.attr['value'].tensor).flat]
new_dilation = [1, block_value[0], block_value[1], 1]

# if padding input of SpaceToBatchND can't be directly fetched, we continue
if stob_padding_node.op != 'Const':
continue
padding_value = [i for i in tensor_util.MakeNdarray(
stob_padding_node.attr['value'].tensor).flat]
crops_value = [i for i in tensor_util.MakeNdarray(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ def get_optimized_model(self, itex_mode=False):

self._tmp_graph_def = ConvertNanToRandom(
self._tmp_graph_def).do_transformation()

self._tmp_graph_def = DilatedContraction(
self._tmp_graph_def).do_transformation()
if self.new_api:
self._tmp_graph_def = DilatedContraction(
self._tmp_graph_def).do_transformation()
self._excluded_node_names.extend(excluded_node_names)
self._tmp_graph_def.library.CopyFrom(self.model.graph_def.library)

Expand Down

0 comments on commit a6f7476

Please sign in to comment.