diff --git a/onnxconverter_common/optimizer.py b/onnxconverter_common/optimizer.py
index e30468a..b939531 100644
--- a/onnxconverter_common/optimizer.py
+++ b/onnxconverter_common/optimizer.py
@@ -789,6 +789,66 @@ def optimize_onnx(onnx_nodes, nchw_inputs=None, inputs=None, outputs=None):
     return _build_onnx_model(node_list)
 
 
+def _remove_unused_initializers(nodes, initializers):
+    """Return only the initializers that are consumed by at least one node."""
+    adjusted_initializers = []
+    nodes_input_set = set()
+    for n_ in nodes:
+        for input_name_ in n_.input:
+            nodes_input_set.add(input_name_)
+
+    for initializers_ in initializers:
+        if initializers_.name in nodes_input_set:
+            adjusted_initializers.append(initializers_)
+
+    return adjusted_initializers
+
+
+def optimize_onnx_graph(onnx_nodes, nchw_inputs=None, inputs=None, outputs=None, initializers=None,
+                        model_value_info=None, model_name='', target_opset=None):
+    """
+    Optimize the onnx graph by several approaches.
+    :param onnx_nodes: the onnx node list of the onnx model
+    :param nchw_inputs: the name list of the inputs needed to be transposed as NCHW
+    :param inputs: the model inputs
+    :param outputs: the model outputs
+    :param initializers: the model initializers
+    :param model_value_info: the model value_info
+    :param model_name: the name of the graph
+    :param target_opset: the target opset of the model, which must be >= 9
+    :return: the optimized onnx graph
+    """
+    if target_opset < 9:
+        raise Exception("optimize_onnx_graph only supports target_opset >= 9, got {}".format(target_opset))
+
+    # When calling ModelComponentContainer's add_initializer(...), nothing is added into the input list.
+    # However, in ONNX, for target opset < 9, initializers should also be the model's (GraphProto) inputs.
+    # Thus, we create ValueInfoProto objects from initializers (type: TensorProto) directly and then add them into the model's input list.
+    extra_inputs = []  # ValueInfoProto list of the initializers
+    for tensor in initializers:
+        # Sometimes (especially when creating optional input values such as RNN's initial hidden state), an initializer
+        # is also one of the original model's inputs, so it has already been added into the container's input list.
+        # If this is the case, we need to skip one iteration to avoid duplicated inputs.
+        if tensor.name in [value_info.name for value_info in inputs]:
+            continue
+
+        # Initializers are always tensors, so we can just call make_tensor_value_info(...)
+        value_info = helper.make_tensor_value_info(tensor.name, tensor.data_type, tensor.dims)
+        extra_inputs.append(value_info)
+
+    nodes = optimize_onnx(onnx_nodes, nchw_inputs=nchw_inputs, inputs=inputs + extra_inputs, outputs=outputs)
+
+    # Create the graph from its main components, dropping initializers that are no longer referenced
+    adjusted_initializers = _remove_unused_initializers(nodes, initializers)
+    graph = helper.make_graph(nodes, model_name, inputs,
+                              outputs, adjusted_initializers)
+
+    # Add extra information related to the graph
+    graph.value_info.extend(model_value_info)
+
+    return graph
+
+
 def optimize_onnx_model(origin_model, nchw_inputs=None):
     """
     the origin model will be updated after the optimization.
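
Below is a minimal usage sketch of the new optimize_onnx_graph entry point, driven from a loaded ModelProto. It is not part of the patch; the file names and the way target_opset is derived are illustrative assumptions only.

    import onnx
    from onnx import helper
    from onnxconverter_common.optimizer import optimize_onnx_graph

    model = onnx.load("model.onnx")  # hypothetical input model
    graph = model.graph

    # Derive the default-domain opset version; optimize_onnx_graph requires >= 9.
    target_opset = next(op.version for op in model.opset_import if op.domain in ('', 'ai.onnx'))

    optimized_graph = optimize_onnx_graph(
        onnx_nodes=list(graph.node),
        nchw_inputs=None,
        inputs=list(graph.input),
        outputs=list(graph.output),
        initializers=list(graph.initializer),
        model_value_info=list(graph.value_info),
        model_name=graph.name,
        target_opset=target_opset)

    # Rewrap the optimized GraphProto into a ModelProto with the original opset imports.
    optimized_model = helper.make_model(optimized_graph, opset_imports=model.opset_import)
    onnx.save(optimized_model, "model_optimized.onnx")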