Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tom/onnx2py #161

Merged
merged 4 commits into from
Dec 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions onnxconverter_common/onnx2py.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###########################################################################

"""
Converts onnx model into model.py file for easy editing. Resulting model.py file uses onnx.helper library to
recreate the original onnx model. Constant tensors with more than 10 elements are saved into .npy
files in location model/const#_tensor_name.npy

Example usage:
python -m onnxconverter_common.onnx2py my_model.onnx my_model.py
"""

import sys
import onnx
import collections
from onnx import helper, numpy_helper, TensorProto
import numpy as np
import os

from .pytracing import TracingObject

needed_types = set()
const_dir = None
const_counter = None

np_traced = TracingObject("np")
helper_traced = TracingObject("helper")
numpy_helper_traced = TracingObject("numpy_helper")
TensorProtoTraced = TracingObject("TensorProto")


def convert_tensor_type(i):
return getattr(TensorProtoTraced, TensorProto.DataType.Name(i))


def convert_field(field):
global needed_types
if isinstance(field, (int, str, float, bytes)):
return field
elif isinstance(field, onnx.GraphProto):
return convert_graph(field)
elif isinstance(field, onnx.ModelProto):
return convert_model(field)
elif isinstance(field, onnx.NodeProto):
return convert_node(field)
elif isinstance(field, onnx.TensorProto):
return convert_tensor(field)
elif isinstance(field, onnx.ValueInfoProto):
return convert_value_info(field)
elif isinstance(field, onnx.OperatorSetIdProto):
return convert_operatorsetid(field)
elif isinstance(field, collections.abc.Iterable):
return list(convert_field(x) for x in field)
else:
# Missing handler needs to be added
t = str(type(field))
needed_types.add(t)
return field


def convert_value_info(val_info):
name = val_info.name
elem_type = convert_tensor_type(val_info.type.tensor_type.elem_type)
kwargs = collections.OrderedDict()

def convert_shape_dim(d):
if d.HasField("dim_value"):
return d.dim_value
if d.HasField("dim_param"):
return d.dim_param
return None

def convert_shape_denotation(d):
if d.HasField("denotation"):
return d.denotation
return None

kwargs["shape"] = [convert_shape_dim(d) for d in val_info.type.tensor_type.shape.dim]
if any(d.HasField("denotation") for d in val_info.type.tensor_type.shape.dim):
kwargs["shape_denotation"] = [convert_shape_denotation(d) for d in val_info.type.tensor_type.shape.dim]

if val_info.HasField("doc_string"):
kwargs["doc_string"].doc_string

return helper_traced.make_tensor_value_info(name, elem_type, **kwargs)


def convert_operatorsetid(opsetid):
domain = opsetid.domain
version = opsetid.version
return helper_traced.make_operatorsetid(domain, version)


def convert_tensor(tensor):
global const_dir, const_counter
np_data = numpy_helper.to_array(tensor)
if np.product(np_data.shape) <= 10:
return numpy_helper_traced.from_array(np_data, name=tensor.name)
os.makedirs(const_dir, exist_ok=True)
name = "const" + str(const_counter)
if tensor.name:
name = name + "_" + tensor.name
for c in '~"#%&*:<>?/\\{|}':
name = name.replace(c, '_')
const_path = "%s/%s.npy" % (const_dir, name)
np.save(const_path, np_data)
const_counter += 1
return numpy_helper_traced.from_array(np_traced.load(const_path), name=tensor.name)


def convert_node(node):
fields = {f[0].name: f[1] for f in node.ListFields()}
attributes = fields.pop("attribute", [])
attrs = {a.name: convert_field(helper.get_attribute_value(a)) for a in attributes}
fields = {f: convert_field(v) for f, v in fields.items()}
op_type = fields.pop("op_type")
if op_type == "Cast" and "to" in attrs:
attrs["to"] = convert_tensor_type(attrs["to"])
inputs = fields.pop("input", [])
outputs = fields.pop("output", [])
return helper_traced.make_node(op_type, inputs=inputs, outputs=outputs, **fields, **attrs)


def convert_graph(graph):
fields = {f[0].name: convert_field(f[1]) for f in graph.ListFields()}
nodes = fields.pop("node", [])
name = fields.pop("name")
inputs = fields.pop("input", [])
outputs = fields.pop("output", [])
return helper_traced.make_graph(nodes, name=name, inputs=inputs, outputs=outputs, **fields)


def convert_model(model):
fields = {f[0].name: convert_field(f[1]) for f in model.ListFields()}
graph = fields.pop("graph")
opset_imports = fields.pop("opset_import", [])
return helper_traced.make_model(graph, opset_imports=opset_imports, **fields)


def clear_directory(path):
for f in os.listdir(path):
if f.endswith(".npy"):
os.remove(os.path.join(path, f))
try:
# Delete if empty
os.rmdir(path)
except OSError:
pass


class MissingHandlerException(Exception):
pass


def convert(model, out_path):
global needed_types, const_dir, const_counter
needed_types = set()
if out_path.endswith(".py"):
out_path = out_path[:-3]
if os.path.exists(out_path):
clear_directory(out_path)
const_dir = out_path
const_counter = 0

model_trace = convert_model(model)
code = "from onnx import helper, numpy_helper, TensorProto\n"
code += "import numpy as np\n"
code += "\n" + "model = " + repr(model_trace) + "\n"
with open(out_path + ".py", "wt") as file:
file.write(code)
if needed_types:
raise MissingHandlerException("Missing handler for types: %s" % list(needed_types))


def main():
_, in_path, out_path = sys.argv
if not out_path.endswith(".py"):
out_path = out_path + ".py"

model = onnx.load(in_path)
try:
convert(model, out_path)
except MissingHandlerException as e:
print("ERROR:", e)

print("Model saved to", out_path)


if __name__ == '__main__':
main()
52 changes: 52 additions & 0 deletions onnxconverter_common/pytracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###########################################################################

import numpy as np


def indent(s):
return "\n".join(" " + line for line in s.split("\n"))


class TracingObject:
"""
Used by onnx2py to mock a module like numpy or onnx.helper and record calls on that module
Ex:
np = TracingObject("np")
x = np.array(np.product([1, 2, 3]), np.int32)
assert repr(x) == "np.array(np.product([1, 2, 3]), np.int32)"
"""
def __init__(self, trace):
self._trace = trace

@staticmethod
def from_repr(o):
return TracingObject(TracingObject.get_repr(o))

@staticmethod
def get_repr(x):
if isinstance(x, np.ndarray):
return "np.array(%r, dtype=np.%s)" % (x.tolist(), x.dtype)
if not isinstance(x, list):
return repr(x)
ls = [TracingObject.get_repr(o) for o in x]
code = "[" + ", ".join(ls) + "]"
if len(code) <= 200:
return code
return "[\n" + "".join(indent(s) + ",\n" for s in ls) + "]"

def __getattr__(self, attr):
return TracingObject(self._trace + "." + attr)

def __call__(self, *args, **kwargs):
arg_s = [TracingObject.get_repr(o) for o in args]
arg_s += [k + "=" + TracingObject.get_repr(o) for k, o in kwargs.items()]
trace = self._trace + "(" + ", ".join(arg_s) + ")"
if len(trace) <= 200:
return TracingObject(trace)
return TracingObject(self._trace + "(\n" + "".join(indent(s) + ",\n" for s in arg_s) + ")")

def __repr__(self):
return self._trace
48 changes: 48 additions & 0 deletions tests/test_onnx2py.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import unittest
import numpy as np
import os
import onnx
import onnxruntime as _ort
import sys

from onnxconverter_common.onnx2py import convert, clear_directory

working_path = os.path.abspath(os.path.dirname(__file__))
tmp_path = os.path.join(working_path, 'temp')
data_path = os.path.join(working_path, 'data')

class Oonnx2PyTests(unittest.TestCase):
def tearDown(self):
for f in os.listdir(tmp_path):
if f.endswith(".py"):
os.remove(os.path.join(tmp_path, f))
folder_path = os.path.join(tmp_path, f[:-3])
if os.path.exists(folder_path):
clear_directory(folder_path)

@unittest.skipIf(sys.version_info < (3, 6), "Requires onnx > 1.3.0")
def test_onnx2py(self):
global model
model_name = 'test_model_1_no_opt'
onnx_model = onnx.load(os.path.join(data_path, model_name + '.onnx'))
sess1 = _ort.InferenceSession(onnx_model.SerializeToString())
np.random.seed(42)
data = np.random.random_sample(size=(1, 1, 512)).astype(np.float32)
expected = sess1.run(["conv1d_1"], {"input_1": data})

os.makedirs(tmp_path, exist_ok=True)
out_path = os.path.join(tmp_path, model_name + '.py')
convert(onnx_model, out_path)
self.assertTrue(os.path.exists(out_path))
local_map = {}
with open(out_path, "rt") as f:
# Creates model called 'model'
exec(f.read(), None, local_map)
model = local_map["model"]
sess2 = _ort.InferenceSession(model.SerializeToString())
actual = sess2.run(["conv1d_1"], {"input_1": data})

np.testing.assert_allclose(expected[0], actual[0])

if __name__ == '__main__':
unittest.main()
21 changes: 21 additions & 0 deletions tests/test_pytracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import unittest
import numpy as np

from onnxconverter_common.pytracing import TracingObject

class TracingTests(unittest.TestCase):
def test_tracing_simple(self):
tracer = TracingObject("x")
actual = repr(tracer.func_call([1, 'A', 3, tracer], tracer))
expected = "x.func_call([1, 'A', 3, x], x)"
self.assertEqual(actual, expected)

def test_tracing_numpy(self):
tracer = TracingObject("helper")
x = np.array([1, 2, 3], dtype=np.int32)
actual = repr(tracer.from_numpy(x))
expected = "helper.from_numpy(np.array([1, 2, 3], dtype=np.int32))"
self.assertEqual(actual, expected)

if __name__ == '__main__':
unittest.main()