Skip to content

Commit

Permalink
Merge pull request #30 from kuke/refine_unittest
Browse files Browse the repository at this point in the history
Enhance and add new ops & resolve name conflicts
  • Loading branch information
Yibing Liu authored Apr 26, 2018
2 parents 5441d5f + 02d9ed9 commit 486b871
Show file tree
Hide file tree
Showing 7 changed files with 299 additions and 84 deletions.
36 changes: 24 additions & 12 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
import argparse

from fluid.utils import op_io_info
from onnx import helper, checker
import paddle.fluid as fluid

Expand All @@ -28,7 +29,11 @@ def parse_args():
parser.add_argument(
"--fluid_model", required=True, help="Input PaddlePaddle Fluid model.")
parser.add_argument(
"--onnx_model", required=False, help="The path to save ONNX model.")
"--onnx_model", required=True, help="The path to save ONNX model.")
parser.add_argument(
"--to_print_model",
action='store_true',
help="To print converted ONNX model.")
args = parser.parse_args()
return args

Expand Down Expand Up @@ -68,23 +73,14 @@ def convert(args):
for v in feed_target_names
]

# Create outputs
fetch_target_names = [
fetch_target.name for fetch_target in fetch_targets
]
outputs = [
paddle_variable_to_onnx_tensor(v, global_block)
for v in fetch_target_names
]

# Create nodes
for block in inference_program.blocks:
for op in block.ops:
if op.type in ops.node_maker:
# TODO(kuke): deal with the corner case that vars in
# different blocks have the same name
node_proto = ops.node_maker[op.type](operator=op,
scope=inference_scope)
block=block)

if isinstance(node_proto, tuple):
onnx_nodes.extend(list(node_proto))
Expand All @@ -95,6 +91,21 @@ def convert(args):
raise NotImplementedError("OP[%s] is not supported in "
"the converter!" % op.type)

# Create outputs
fetch_target_names = [
fetch_target.name for fetch_target in fetch_targets
]
# Get the new names for outputs if they've renamed in nodes' making
renamed_outputs = op_io_info.get_all_renamed_outputs()
fetch_target_names = [
name if name not in renamed_outputs else renamed_outputs[name]
for name in fetch_target_names
]
outputs = [
paddle_variable_to_onnx_tensor(v, global_block)
for v in fetch_target_names
]

# Make graph
model_name = os.path.basename(args.fluid_model.strip('/')).split('.')[0]
onnx_graph = helper.make_graph(onnx_nodes, model_name, inputs, outputs)
Expand All @@ -106,7 +117,8 @@ def convert(args):
checker.check_model(onnx_model)

# Print model
print("The converted model is:\n{}".format(onnx_model))
if args.to_print_model:
print("The converted model is:\n{}".format(onnx_model))

# Save converted model
if args.onnx_model is not None:
Expand Down
68 changes: 61 additions & 7 deletions fluid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,66 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from compiler.ast import flatten

def get_op_io_info(op):
inputs = dict([(name, op.input(name)) for name in op.input_names])
attrs = dict(
[(name, op.attr(name))
for name in op.attr_names]) if op.attr_names is not None else None
outputs = dict([(name, op.output(name)) for name in op.output_names])

return inputs, attrs, outputs
class OpIOsInfo():
"""Return inputs/outputs information for an operator, and resolve potential
name conflicts in ONNX graph.
"""

def __init__(self):
self._all_renamed_outputs = {}
self._renamed_cnt = 0

def _get_new_name(self, arg):
"""Get the new name for an argument.
"""

self._renamed_cnt += 1
return arg + '@dup_' + str(self._renamed_cnt)

def _rename_input_args(self):
"""Rename input arguments if their previous output arugments have been
renamed.
"""

for in_name in self.inputs:
if self.inputs[in_name][0] in self._all_renamed_outputs:
self.inputs[in_name][0] = self._all_renamed_outputs[self.inputs[
in_name][0]]

def _rename_output_args(self):
"""Rename output arguments if they have same names with the input
arguments.
"""

input_args = flatten(self.inputs.values())
for out_name in self.outputs:
if self.outputs[out_name][0] in input_args:
new_name = self._get_new_name(self.outputs[out_name][0])
self._all_renamed_outputs[self.outputs[out_name][0]] = new_name
self.outputs[out_name][0] = new_name

def get_all_renamed_outputs(self):
"""Get all the renamed outputs in history.
"""

return self._all_renamed_outputs

def __call__(self, op):
self.inputs = dict([(name, op.input(name)) for name in op.input_names])
self.attrs = dict(
[(name, op.attr(name))
for name in op.attr_names]) if op.attr_names is not None else None
self.outputs = dict(
[(name, op.output(name)) for name in op.output_names])

self._rename_input_args()
self._rename_output_args()

return self.inputs, self.attrs, self.outputs


# Instantiate the class to a callable object
op_io_info = OpIOsInfo()
Loading

0 comments on commit 486b871

Please sign in to comment.