Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added header to output and fixed empty docstrings #166

Merged
merged 1 commit into from
Jan 8, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions onnxconverter_common/onnx2py.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,20 @@ def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None,
order_repeated_field(node.attribute, 'name', kwargs.keys())
return node


def make_graph(*args, doc_string=None, **kwargs):
graph = helper.make_graph(*args, doc_string=doc_string, **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks a bug that should be fixed in ONNX repo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arguably empty doc strings don't need to be included, but I'm trying to make the results byte-for-byte identical. I would appreciate it if the onnx helpers let me do this.

if doc_string == '':
graph.doc_string = ''
return graph

# </Helpers> #


clear_field_traced = TracingObject("clear_field", clear_field)
make_external_tensor_traced = TracingObject("make_external_tensor", make_external_tensor)
make_node_traced = TracingObject("make_node", make_node)
make_graph_traced = TracingObject("make_graph", make_graph)
DATA_DIR_TRACED = None


Expand Down Expand Up @@ -198,7 +206,7 @@ def convert_graph(graph):
name = fields.pop("name")
inputs = fields.pop("input", [])
outputs = fields.pop("output", [])
return helper_traced.make_graph(name=name, inputs=inputs, outputs=outputs, **fields, nodes=nodes)
return make_graph_traced(name=name, inputs=inputs, outputs=outputs, **fields, nodes=nodes)


def convert_model(model):
Expand All @@ -223,6 +231,13 @@ class MissingHandlerException(Exception):
pass


FILE_HEADER = '''"""
Run this script to recreate the original onnx model.
Example usage:
python %s.py out_model_path.onnx
"""'''


def convert(model, out_path):
global needed_types, const_dir, const_counter, DATA_DIR_TRACED
needed_types = set()
Expand All @@ -238,7 +253,9 @@ def convert(model, out_path):
DATA_DIR_TRACED = TracingObject("DATA_DIR", const_dir)

model_trace = convert_field(model)
code = "from onnx import helper, numpy_helper, TensorProto"

code = FILE_HEADER % os.path.basename(out_path) + "\n"
code += "\nfrom onnx import helper, numpy_helper, TensorProto\n"
if TracingObject.get_cnt(make_external_tensor_traced):
code += ", external_data_helper"
code += "\n"
Expand All @@ -254,6 +271,7 @@ def convert(model, out_path):
if TracingObject.get_cnt(make_external_tensor_traced):
code += "\n" + inspect.getsource(make_external_tensor)
code += "\n" + inspect.getsource(make_node)
code += "\n" + inspect.getsource(make_graph)
code += "\n" + "model = " + repr(model_trace) + "\n"
code += "\nif __name__ == '__main__' and len(sys.argv) == 2:\n"
code += " _, out_path = sys.argv\n"
Expand All @@ -266,6 +284,7 @@ def convert(model, out_path):
file.write(code)
if needed_types:
raise MissingHandlerException("Missing handler for types: %s" % list(needed_types))
return model_trace


def main():
Expand All @@ -275,7 +294,12 @@ def main():

model = onnx.load(in_path, load_external_data=False)
try:
convert(model, out_path)
model_trace = convert(model, out_path)
if TracingObject.get_py_obj(model_trace).SerializeToString() == model.SerializeToString():
print("\nConversion successful. Converted model is identical.\n")
else:
print("\nWARNING: Conversion succeeded but converted model is not identical. "
"Difference might be trivial.\n")
except MissingHandlerException as e:
print("ERROR:", e)

Expand Down