Skip to content

Commit

Permalink
Skip converting resize nodes.
Browse files Browse the repository at this point in the history
  • Loading branch information
toothache committed Mar 8, 2022
1 parent 03e47d8 commit 596f845
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions onnxconverter_common/float16.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ def convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4,
value_info_list.append(new_value_info)
io_casts.add(node_name)

for node in model.graph.node:
if node.op_type == "Resize":
graph_io_to_skip = graph_io_to_skip.union([n for n in node.input[1:] if n])

while queue:
next_level = []
for q in queue:
Expand Down Expand Up @@ -222,8 +226,9 @@ def convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4,
if isinstance(q, onnx_proto.GraphProto):
for n in q.initializer: # TensorProto type
if n.data_type == onnx_proto.TensorProto.FLOAT:
n = convert_tensor_float_to_float16(n, min_positive_val, max_finite_val)
value_info_list.append(make_value_info_from_tensor(n))
if n.name not in graph_io_to_skip:
n = convert_tensor_float_to_float16(n, min_positive_val, max_finite_val)
value_info_list.append(make_value_info_from_tensor(n))
# for all ValueInfoProto with tensor(float) type in input, output and value_info, convert them to
# tensor(float16) except map and seq(map). And save them in value_info_list for further processing
for n in itertools.chain(q.input, q.output, q.value_info):
Expand Down

0 comments on commit 596f845

Please sign in to comment.