Generate graph after optimizing onnx model (#34)
* Generate graph after optimizing onnx
jiafatom authored Nov 14, 2019
1 parent c0c226e commit 511a071
Showing 2 changed files with 58 additions and 1 deletion.
onnxconverter_common/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -9,7 +9,7 @@
 This framework performs optimization for ONNX models and
 includes common utilities for ONNX converters.
 """
-__version__ = "1.6.0"
+__version__ = "1.6.1"
 __author__ = "Microsoft"
 __producer__ = "OnnxMLTools"
 __producer_version__ = __version__
onnxconverter_common/optimizer.py (57 changes: 57 additions & 0 deletions)
@@ -789,6 +789,63 @@ def optimize_onnx(onnx_nodes, nchw_inputs=None, inputs=None, outputs=None):
     return _build_onnx_model(node_list)
 
 
+def _remove_unused_initializers(nodes, initializers):
+    adjusted_initializers = []
+    nodes_input_set = set()
+    for n_ in nodes:
+        for input_name_ in n_.input:
+            nodes_input_set.add(input_name_)
+
+    for initializers_ in initializers:
+        if initializers_.name in nodes_input_set:
+            adjusted_initializers.append(initializers_)
+
+    return adjusted_initializers
+
+
+def optimize_onnx_graph(onnx_nodes, nchw_inputs=None, inputs=None, outputs=None, initializers=None,
+                        model_value_info=None, model_name='', target_opset=None):
+    """
+    Optimize the onnx model by several approaches.
+    :param onnx_nodes: the onnx node list of the onnx model
+    :param nchw_inputs: the name list of the inputs that need to be transposed to NCHW
+    :param inputs: the model inputs
+    :param outputs: the model outputs
+    :param initializers: the model initializers
+    :param model_value_info: the model value_info
+    :return: the optimized onnx graph
+    """
+    if target_opset < 9:
+        raise Exception("optimize_onnx_graph only supports target opset >= 9, but target_opset = {}".format(target_opset))
+
+    # When calling ModelComponentContainer's add_initializer(...), nothing is added into the input list.
+    # However, in ONNX, for target opset < 9, initializers should also be the model's (GraphProto) inputs.
+    # Thus, we create ValueInfoProto objects from the initializers (type: TensorProto) directly and then add them into the model's input list.
+    extra_inputs = []  # ValueInfoProto list of the initializers
+    for tensor in initializers:
+        # Sometimes (especially when creating optional input values such as an RNN's initial hidden state), an initializer
+        # is also one of the original model's inputs, so it has already been added into the container's input list. If this is
+        # the case, we skip the iteration to avoid duplicated inputs.
+        if tensor.name in [value_info.name for value_info in inputs]:
+            continue
+
+        # Initializers are always tensors, so we can just call make_tensor_value_info(...)
+        value_info = helper.make_tensor_value_info(tensor.name, tensor.data_type, tensor.dims)
+        extra_inputs.append(value_info)
+
+    nodes = optimize_onnx(onnx_nodes, nchw_inputs=nchw_inputs, inputs=inputs + extra_inputs, outputs=outputs)
+
+    # Create a graph from its main components
+    adjusted_initializers = _remove_unused_initializers(nodes, initializers)
+    graph = helper.make_graph(nodes, model_name, inputs,
+                              outputs, adjusted_initializers)
+
+    # Add extra information related to the graph
+    graph.value_info.extend(model_value_info)
+
+    return graph
+
+
 def optimize_onnx_model(origin_model, nchw_inputs=None):
     """
     the origin model will be updated after the optimization.
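For context, a minimal usage sketch of the new optimize_onnx_graph API introduced in this commit (not part of the change itself): the file names, the way the target opset is read from model.opset_import, and the wrap-up with helper.make_model are illustrative assumptions, not code from this repository.

    # Hedged sketch: pass the pieces of an existing ModelProto to optimize_onnx_graph(...)
    # and wrap the returned GraphProto back into a ModelProto.
    # "model.onnx" and "model_optimized.onnx" are placeholder file names.
    import onnx
    from onnx import helper
    from onnxconverter_common.optimizer import optimize_onnx_graph

    model = onnx.load("model.onnx")
    graph = model.graph
    # Assumes the default ONNX domain entry carries the target opset version.
    opset = next(op.version for op in model.opset_import if op.domain in ("", "ai.onnx"))

    optimized_graph = optimize_onnx_graph(
        onnx_nodes=list(graph.node),
        nchw_inputs=None,
        inputs=list(graph.input),
        outputs=list(graph.output),
        initializers=list(graph.initializer),
        model_value_info=list(graph.value_info),
        model_name=graph.name,
        target_opset=opset)

    optimized_model = helper.make_model(optimized_graph, opset_imports=model.opset_import)
    onnx.save(optimized_model, "model_optimized.onnx")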
