def convert_tensor_type(i):
    """Return the traced ``TensorProto.<NAME>`` symbol for the data-type enum value ``i``."""
    return getattr(TensorProtoTraced, TensorProto.DataType.Name(i))


def convert_field(field):
    """Convert one protobuf field value into something whose repr is Python source.

    Plain scalars (int, str, float, bytes) are returned unchanged, known ONNX
    messages are dispatched to their dedicated converter, and any other
    iterable is converted element by element into a list.

    A value with no handler is returned unchanged and its type name is recorded
    in the module-level ``needed_types`` set so ``convert`` can report it.
    """
    if isinstance(field, (int, str, float, bytes)):
        return field
    if isinstance(field, onnx.GraphProto):
        return convert_graph(field)
    if isinstance(field, onnx.ModelProto):
        return convert_model(field)
    if isinstance(field, onnx.NodeProto):
        return convert_node(field)
    if isinstance(field, onnx.TensorProto):
        return convert_tensor(field)
    if isinstance(field, onnx.ValueInfoProto):
        return convert_value_info(field)
    if isinstance(field, onnx.OperatorSetIdProto):
        return convert_operatorsetid(field)
    if isinstance(field, collections.abc.Iterable):
        return [convert_field(x) for x in field]
    # Missing handler needs to be added
    needed_types.add(str(type(field)))
    return field


def convert_value_info(val_info):
    """Trace a ``helper.make_tensor_value_info`` call that recreates ``val_info``.

    Each dimension is emitted as its ``dim_value``, its ``dim_param``, or
    ``None`` when neither is set. ``shape_denotation`` is only emitted when at
    least one dimension carries a denotation, and ``doc_string`` only when the
    value info has one.
    """
    name = val_info.name
    tensor_type = val_info.type.tensor_type
    elem_type = convert_tensor_type(tensor_type.elem_type)
    dims = tensor_type.shape.dim
    kwargs = collections.OrderedDict()

    def convert_shape_dim(d):
        if d.HasField("dim_value"):
            return d.dim_value
        if d.HasField("dim_param"):
            return d.dim_param
        return None

    def convert_shape_denotation(d):
        if d.HasField("denotation"):
            return d.denotation
        return None

    if tensor_type.HasField("shape"):
        kwargs["shape"] = [convert_shape_dim(d) for d in dims]
    else:
        # No shape message at all means "unknown rank"; an empty list would
        # wrongly recreate the value as a scalar.
        kwargs["shape"] = None
    if any(d.HasField("denotation") for d in dims):
        kwargs["shape_denotation"] = [convert_shape_denotation(d) for d in dims]

    if val_info.HasField("doc_string"):
        # The original evaluated kwargs["doc_string"].doc_string, which raised
        # KeyError instead of storing the doc string.
        kwargs["doc_string"] = val_info.doc_string

    return helper_traced.make_tensor_value_info(name, elem_type, **kwargs)


def convert_operatorsetid(opsetid):
    """Trace a ``helper.make_operatorsetid`` call for ``opsetid``'s domain and version."""
    return helper_traced.make_operatorsetid(opsetid.domain, opsetid.version)


def convert_tensor(tensor):
    """Trace a ``numpy_helper.from_array`` call that recreates ``tensor``.

    Tensors with at most 10 elements are inlined as an array literal. Larger
    ones are saved to ``<const_dir>/const<N>[_<tensor name>].npy`` and the
    generated code loads them with ``np.load``; ``const_counter`` is advanced
    for every file written.
    """
    global const_counter
    np_data = numpy_helper.to_array(tensor)
    # ndarray.size is the element count; np.product was removed in NumPy 2.0.
    if np_data.size <= 10:
        return numpy_helper_traced.from_array(np_data, name=tensor.name)
    os.makedirs(const_dir, exist_ok=True)
    name = "const" + str(const_counter)
    if tensor.name:
        # Replace characters that are unsafe in file names with underscores.
        unsafe = str.maketrans(dict.fromkeys('~"#%&*:<>?/\\{|}', "_"))
        name = name + "_" + tensor.name.translate(unsafe)
    const_path = "%s/%s.npy" % (const_dir, name)
    np.save(const_path, np_data)
    const_counter += 1
    return numpy_helper_traced.from_array(np_traced.load(const_path), name=tensor.name)


def convert_node(node):
    """Trace a ``helper.make_node`` call that recreates ``node``.

    Attributes are unpacked with ``helper.get_attribute_value`` and passed as
    keyword arguments next to the remaining set fields. The ``to`` attribute of
    a ``Cast`` node is emitted as a symbolic ``TensorProto`` type rather than a
    bare integer.
    """
    fields = {f[0].name: f[1] for f in node.ListFields()}
    attributes = fields.pop("attribute", [])
    attrs = {a.name: convert_field(helper.get_attribute_value(a)) for a in attributes}
    fields = {f: convert_field(v) for f, v in fields.items()}
    op_type = fields.pop("op_type")
    if op_type == "Cast" and "to" in attrs:
        attrs["to"] = convert_tensor_type(attrs["to"])
    inputs = fields.pop("input", [])
    outputs = fields.pop("output", [])
    return helper_traced.make_node(op_type, inputs=inputs, outputs=outputs, **fields, **attrs)


def convert_graph(graph):
    """Trace a ``helper.make_graph`` call that recreates ``graph`` from its set fields."""
    fields = {f[0].name: convert_field(f[1]) for f in graph.ListFields()}
    nodes = fields.pop("node", [])
    # ListFields() omits unset fields, so an unnamed graph has no "name" entry;
    # fall back to an empty name instead of failing with KeyError.
    name = fields.pop("name", "")
    inputs = fields.pop("input", [])
    outputs = fields.pop("output", [])
    return helper_traced.make_graph(nodes, name=name, inputs=inputs, outputs=outputs, **fields)


def convert_model(model):
    """Trace a ``helper.make_model`` call that recreates ``model`` from its set fields."""
    fields = {f[0].name: convert_field(f[1]) for f in model.ListFields()}
    graph = fields.pop("graph")
    opset_imports = fields.pop("opset_import", [])
    return helper_traced.make_model(graph, opset_imports=opset_imports, **fields)


def clear_directory(path):
    """Delete every ``.npy`` file directly inside ``path``, then remove ``path`` if it is empty."""
    for f in os.listdir(path):
        if f.endswith(".npy"):
            os.remove(os.path.join(path, f))
    try:
        # Delete if empty
        os.rmdir(path)
    except OSError:
        # Directory still holds other files; leaving it in place is intended.
        pass


class MissingHandlerException(Exception):
    """Raised by ``convert`` when the model held field types that have no converter."""


def convert(model, out_path):
    """Write Python source to ``out_path`` that rebuilds ``model`` with ``onnx.helper``.

    ``out_path`` may be given with or without the ``.py`` suffix. The path
    without the suffix is used as the directory for large constant tensors;
    ``.npy`` files left there by an earlier run are removed first.

    Raises:
        MissingHandlerException: after the file has been written, if any field
            type had no converter.
    """
    global needed_types, const_dir, const_counter
    needed_types = set()
    if out_path.endswith(".py"):
        out_path = out_path[:-3]
    # Only a directory can be cleared; os.listdir on a plain file would raise.
    if os.path.isdir(out_path):
        clear_directory(out_path)
    const_dir = out_path
    const_counter = 0

    model_trace = convert_model(model)
    code = "from onnx import helper, numpy_helper, TensorProto\n"
    code += "import numpy as np\n"
    code += "\n" + "model = " + repr(model_trace) + "\n"
    # Python source is UTF-8 by default, so do not depend on the locale encoding.
    with open(out_path + ".py", "wt", encoding="utf-8") as file:
        file.write(code)
    if needed_types:
        raise MissingHandlerException("Missing handler for types: %s" % sorted(needed_types))


def main():
    """Command-line entry point: ``onnx2py <model.onnx> <model.py>``."""
    if len(sys.argv) != 3:
        # Exit with a usage message instead of an opaque unpacking ValueError.
        sys.exit("Usage: python -m onnxconverter_common.onnx2py my_model.onnx my_model.py")
    _, in_path, out_path = sys.argv
    if not out_path.endswith(".py"):
        out_path = out_path + ".py"

    model = onnx.load(in_path)
    try:
        convert(model, out_path)
    except MissingHandlerException as e:
        # The output file was already written before the exception was raised.
        print("ERROR:", e)

    print("Model saved to", out_path)
def indent(s):
    """Return ``s`` with every line prefixed by the indentation string."""
    # NOTE(review): the prefix is a single space as found in this source; the
    # file's whitespace was collapsed, so confirm the intended width.
    return "\n".join(" " + line for line in s.split("\n"))


class TracingObject:
    """
    Used by onnx2py to mock a module like numpy or onnx.helper and record calls on that module
    Ex:
        np = TracingObject("np")
        x = np.array(np.prod([1, 2, 3]), np.int32)
        assert repr(x) == "np.array(np.prod([1, 2, 3]), np.int32)"
    """

    # dtype names whose plain ``np.<name>`` alias is missing in some NumPy
    # releases; the underscore forms exist in every release.
    _DTYPE_ALIASES = {"bool": "bool_", "object": "object_"}

    def __init__(self, trace):
        # The source text this object stands for.
        self._trace = trace

    @staticmethod
    def from_repr(o):
        """Wrap the source representation of ``o`` in a new TracingObject."""
        return TracingObject(TracingObject.get_repr(o))

    @staticmethod
    def _scalar_repr(value):
        """Return source text for a non-list value.

        ``repr`` of a non-finite float is ``nan`` / ``inf`` / ``-inf``, which
        are not names in the generated code, so those are spelled ``float('…')``.
        """
        if isinstance(value, float) and not np.isfinite(value):
            return "float(%r)" % repr(float(value))
        return repr(value)

    @staticmethod
    def _nested_list_repr(value):
        """Return single-line source text for a (possibly nested) list from ``ndarray.tolist()``."""
        if isinstance(value, list):
            return "[" + ", ".join(TracingObject._nested_list_repr(v) for v in value) + "]"
        return TracingObject._scalar_repr(value)

    @staticmethod
    def get_repr(x):
        """Return Python source text that evaluates to ``x``.

        Arrays become ``np.array(<nested list>, dtype=np.<dtype>)``. Lists are
        rendered element by element, on one line when the result is at most
        200 characters and one element per line otherwise. Everything else
        uses ``repr``, except non-finite floats (see ``_scalar_repr``).
        """
        if isinstance(x, np.ndarray):
            dtype_name = str(x.dtype)
            dtype_name = TracingObject._DTYPE_ALIASES.get(dtype_name, dtype_name)
            return "np.array(%s, dtype=np.%s)" % (TracingObject._nested_list_repr(x.tolist()), dtype_name)
        if not isinstance(x, list):
            return TracingObject._scalar_repr(x)
        ls = [TracingObject.get_repr(o) for o in x]
        code = "[" + ", ".join(ls) + "]"
        if len(code) <= 200:
            return code
        return "[\n" + "".join(indent(s) + ",\n" for s in ls) + "]"

    def __getattr__(self, attr):
        """Record the attribute access ``<trace>.<attr>`` as a new TracingObject."""
        # Only reached when normal lookup fails. Dunder probes (copy, pickle,
        # hasattr checks) must see AttributeError rather than a fabricated
        # object, and a missing ``_trace`` must not recurse back in here.
        if attr == "_trace" or (attr.startswith("__") and attr.endswith("__")):
            raise AttributeError(attr)
        return TracingObject(self._trace + "." + attr)

    def __call__(self, *args, **kwargs):
        """Record a call with the given arguments as a new TracingObject.

        The call is rendered on one line when it is at most 200 characters
        long, otherwise with one argument per line.
        """
        arg_s = [TracingObject.get_repr(o) for o in args]
        arg_s += [k + "=" + TracingObject.get_repr(o) for k, o in kwargs.items()]
        trace = self._trace + "(" + ", ".join(arg_s) + ")"
        if len(trace) <= 200:
            return TracingObject(trace)
        return TracingObject(self._trace + "(\n" + "".join(indent(s) + ",\n" for s in arg_s) + ")")

    def __repr__(self):
        return self._trace
class TracingTests(unittest.TestCase):
    """Checks that TracingObject renders recorded calls as Python source text."""

    def test_tracing_simple(self):
        # Ints, strings, lists and other tracers all appear in source form.
        x = TracingObject("x")
        traced_call = x.func_call([1, 'A', 3, x], x)
        self.assertEqual(repr(traced_call), "x.func_call([1, 'A', 3, x], x)")

    def test_tracing_numpy(self):
        # A numpy array argument is rendered as an np.array(...) literal.
        helper = TracingObject("helper")
        values = np.array([1, 2, 3], dtype=np.int32)
        traced_call = helper.from_numpy(values)
        self.assertEqual(
            repr(traced_call),
            "helper.from_numpy(np.array([1, 2, 3], dtype=np.int32))",
        )