From aa2795994aa6f5405ca04355fff5d0a3634d5a19 Mon Sep 17 00:00:00 2001 From: Leo Dong Date: Mon, 23 Dec 2024 13:40:56 -0800 Subject: [PATCH 1/3] [FIRST] Use UUID to avoid inserted cast node name collision. --- onnxconverter_common/float16.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxconverter_common/float16.py b/onnxconverter_common/float16.py index 2731695..36bcbdc 100644 --- a/onnxconverter_common/float16.py +++ b/onnxconverter_common/float16.py @@ -4,6 +4,8 @@ ########################################################################### import itertools +import uuid +import warnings import numpy as np import onnx import packaging.version as pv @@ -239,15 +241,15 @@ def process_node_in_block_list(graph: onnx_proto.GraphProto, global_input_name_d # Todo: global_input_name_dict still not fill value def insert_cast32_before_node(graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict): - for i in range(len(node.input)): + for i, input_name in enumerate(node.input): input_name = node.input[i] for value_info in itertools.chain(graph.value_info, graph.input): if input_name == value_info.name: if value_info.type.tensor_type.elem_type != onnx_proto.TensorProto.FLOAT16: break - cast_output_name = node.name + "_input_cast_" + str(i) + cast_node_name = f"onnxconverter_inserted_cast_{str(uuid.uuid4())}" + cast_output_name = f"{cast_node_name}_output" add_new_value_info(graph, value_info, cast_output_name, onnx_proto.TensorProto.FLOAT) - cast_node_name = node.name + "_input_cast" + str(i) add_cast_node(graph, [input_name], [cast_output_name], cast_node_name, onnx_proto.TensorProto.FLOAT) node.input[i] = cast_output_name break @@ -255,16 +257,16 @@ def insert_cast32_before_node(graph: onnx_proto.GraphProto, node: onnx_proto.Nod # Todo: global_input_name_dict still not fill value def insert_cast16_after_node(graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict): - for i in range(len(node.output)): + for i, output_name in enumerate(node.output): output_name = node.output[i] for value_info in itertools.chain(graph.value_info, graph.output): if output_name == value_info.name: if value_info.type.tensor_type.elem_type != onnx_proto.TensorProto.FLOAT: break - cast_input_name = node.name + "_output_cast_" + str(i) + cast_node_name = f"onnxconverter_inserted_cast_{str(uuid.uuid4())}" + cast_input_name = f"{cast_node_name}_input" add_new_value_info(graph, value_info, cast_input_name, onnx_proto.TensorProto.FLOAT) value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 - cast_node_name = node.name + "_output_cast" + str(i) add_cast_node(graph, [cast_input_name], [output_name], cast_node_name, onnx_proto.TensorProto.FLOAT16) node.output[i] = cast_input_name break From efcfd73f48c2c6518a0c7dc9c2123869f80e428a Mon Sep 17 00:00:00 2001 From: Leo Dong Date: Mon, 23 Dec 2024 13:41:24 -0800 Subject: [PATCH 2/3] Add Cast output to value info block list. --- onnxconverter_common/float16.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxconverter_common/float16.py b/onnxconverter_common/float16.py index 36bcbdc..6513fb3 100644 --- a/onnxconverter_common/float16.py +++ b/onnxconverter_common/float16.py @@ -276,7 +276,8 @@ def insert_cast16_after_node(graph: onnx_proto.GraphProto, node: onnx_proto.Node def process_tensor_in_node(graph: onnx_proto.GraphProto, op_block_list: list, node_block_list: list, min_positive_val, max_finite_val): value_info_block_list = set() # This is for later use, not in this step for node in graph.node: - if (node.op_type in op_block_list) or (node.name in node_block_list): + # NOTE: "Cast" operation cannot change its output type because it is strongly typed. + if (node.op_type in op_block_list) or (node.name in node_block_list) or (node.op_type == "Cast"): # Only need to block the output value_info changing for output_name in node.output: value_info_block_list.add(output_name) From e12afaf13a899899cf841b9bdcf181325fae51ab Mon Sep 17 00:00:00 2001 From: Leo Dong Date: Mon, 23 Dec 2024 13:53:06 -0800 Subject: [PATCH 3/3] Only remove cast pairs with fp32 input types. --- onnxconverter_common/float16.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/onnxconverter_common/float16.py b/onnxconverter_common/float16.py index 6513fb3..3389872 100644 --- a/onnxconverter_common/float16.py +++ b/onnxconverter_common/float16.py @@ -6,14 +6,14 @@ import itertools import uuid import warnings +from typing import Optional + import numpy as np import onnx import packaging.version as pv -import warnings from onnx import helper, numpy_helper from onnx import onnx_pb as onnx_proto - FLOAT32 = 1 FLOAT16 = 10 @@ -522,10 +522,31 @@ def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto): if upstream_node.op_type == 'Constant': cast_node_list.remove(cast_node) - # 4. find the cast(to16) node which downstream is Cast(to32) + # 4. find (cast_to_fp16, cast_to_fp32) pairs where --fp32--> cast_to_fp16 --fp16--> cast_to_fp32. remove_candidate = [] + + name_to_value_info = { + value_info.name: value_info for value_info in itertools.chain(graph_proto.value_info, graph_proto.input) + } + + def get_type(name: str) -> Optional[int]: + if name in name_to_value_info: + return name_to_value_info[name].type + else: + # `name` has no value info. + return None + for cast_node_name, downstream_node in cast_node_downstream_dict.items(): cast_node = name_to_node_dict[cast_node_name] + if len(cast_node.input) != 1: + raise RuntimeError( + f"Cast node {cast_node_name} should have only one input, but has {len(cast_node.input)}." + ) + + input_type = get_type(cast_node.input[0]) + if input_type != onnx_proto.TensorProto.FLOAT: + continue + if isinstance(downstream_node, list): for dn in downstream_node: if dn.op_type == 'Cast' and \ @@ -542,7 +563,8 @@ def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto): cast_node in cast_node_list: remove_candidate.append((cast_node, downstream_node)) - # 5. change the connection of "upstream->cast16->cast32->downstream" to "upstream->downstream" + # 5. change "upstream --fp32--> cast_to_fp16 --fp16--> cast_to_fp32 --fp32--> downstream" to + # "upstream --fp32--> downstream". for cast_node_pair in remove_candidate: first_cast_node = cast_node_pair[0] second_cast_node = cast_node_pair[1]