diff --git a/onnxconverter_common/onnx2py.py b/onnxconverter_common/onnx2py.py index beacf1a..b4630ca 100644 --- a/onnxconverter_common/onnx2py.py +++ b/onnxconverter_common/onnx2py.py @@ -15,8 +15,9 @@ import sys import onnx import collections +import inspect from collections import OrderedDict -from onnx import helper, numpy_helper, TensorProto +from onnx import helper, numpy_helper, TensorProto, external_data_helper import numpy as np import os @@ -33,6 +34,30 @@ os_traced = TracingObject("os") +# These can be inlined into the output script # + +def clear_field(proto, field): + proto.ClearField(field) + return proto + + +def make_external_tensor(name, data_type, dims, raw_data=None, **kwargs): + tensor = TensorProto() + tensor.data_type = data_type + tensor.name = name + tensor.dims.extend(dims) + if raw_data is not None: + tensor.raw_data = raw_data + external_data_helper.set_external_data(tensor, **kwargs) + return tensor + +# # + + +clear_field_traced = TracingObject("clear_field") +make_external_tensor_traced = TracingObject("make_external_tensor") + + def convert_tensor_type(i): return getattr(TensorProtoTraced, TensorProto.DataType.Name(i)) @@ -90,13 +115,28 @@ def convert_shape_denotation(d): def convert_operatorsetid(opsetid): - domain = opsetid.domain version = opsetid.version - return helper_traced.make_operatorsetid(domain, version) + if opsetid.HasField("domain"): + domain = opsetid.domain + return helper_traced.make_operatorsetid(domain, version) + else: + return clear_field_traced(helper_traced.make_operatorsetid('', version), 'domain') + + +def convert_external_tensor(tensor): + kwargs = OrderedDict() + if tensor.HasField("raw_data"): + kwargs["raw_data"] = tensor.raw_data + if tensor.external_data: + for d in tensor.external_data: + kwargs[d.key] = d.value + return make_external_tensor_traced(tensor.name, tensor.data_type, tensor.dims, **kwargs) def convert_tensor(tensor): global const_dir, const_counter + if tensor.data_location == TensorProto.EXTERNAL: + return convert_external_tensor(tensor) np_data = numpy_helper.to_array(tensor) if np.product(np_data.shape) <= 10: return numpy_helper_traced.from_array(np_data, name=tensor.name) @@ -169,15 +209,24 @@ def convert(model, out_path): const_dir = out_path const_dir_name = os.path.basename(out_path) const_counter = 0 + TracingObject.reset_cnt(clear_field_traced) + TracingObject.reset_cnt(make_external_tensor_traced) model_trace = convert_model(model) - code = "from onnx import helper, numpy_helper, TensorProto\n" + code = "from onnx import helper, numpy_helper, TensorProto" + if TracingObject.get_cnt(make_external_tensor_traced): + code += ", external_data_helper" + code += "\n" code += "import onnx\n" code += "import numpy as np\n" code += "import sys\n" if os.path.exists(const_dir): code += "import os\n" code += "\nDATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), %r)\n" % const_dir_name + if TracingObject.get_cnt(clear_field_traced): + code += "\n" + inspect.getsource(clear_field) + if TracingObject.get_cnt(make_external_tensor_traced): + code += "\n" + inspect.getsource(make_external_tensor) code += "\n" + "model = " + repr(model_trace) + "\n" code += "\nif __name__ == '__main__' and len(sys.argv) == 2:\n" code += " _, out_path = sys.argv\n" @@ -193,7 +242,7 @@ def main(): if not out_path.endswith(".py"): out_path = out_path + ".py" - model = onnx.load(in_path) + model = onnx.load(in_path, load_external_data=False) try: convert(model, out_path) except MissingHandlerException as e: diff --git a/onnxconverter_common/pytracing.py b/onnxconverter_common/pytracing.py index c0a7d16..e5901ba 100644 --- a/onnxconverter_common/pytracing.py +++ b/onnxconverter_common/pytracing.py @@ -20,6 +20,15 @@ class TracingObject: """ def __init__(self, trace): self._trace = trace + self._cnt = 0 + + @staticmethod + def reset_cnt(o): + o._cnt = 0 + + @staticmethod + def get_cnt(o): + return o._cnt @staticmethod def from_repr(o): @@ -38,9 +47,11 @@ def get_repr(x): return "[\n" + "".join(indent(s) + ",\n" for s in ls) + "]" def __getattr__(self, attr): + self._cnt += 1 return TracingObject(self._trace + "." + attr) def __call__(self, *args, **kwargs): + self._cnt += 1 arg_s = [TracingObject.get_repr(o) for o in args] arg_s += [k + "=" + TracingObject.get_repr(o) for k, o in kwargs.items()] trace = self._trace + "(" + ", ".join(arg_s) + ")" diff --git a/tests/test_onnx2py.py b/tests/test_onnx2py.py index c1e47a4..6b56dce 100644 --- a/tests/test_onnx2py.py +++ b/tests/test_onnx2py.py @@ -5,14 +5,13 @@ import onnxruntime as _ort import sys -from onnxconverter_common.onnx2py import convert, clear_directory - working_path = os.path.abspath(os.path.dirname(__file__)) tmp_path = os.path.join(working_path, 'temp') data_path = os.path.join(working_path, 'data') class Oonnx2PyTests(unittest.TestCase): def tearDown(self): + from onnxconverter_common.onnx2py import clear_directory for f in os.listdir(tmp_path): if f.endswith(".py"): os.remove(os.path.join(tmp_path, f)) @@ -22,7 +21,7 @@ def tearDown(self): @unittest.skipIf(sys.version_info < (3, 6), "Requires onnx > 1.3.0") def test_onnx2py(self): - global model + from onnxconverter_common.onnx2py import convert model_name = 'test_model_1_no_opt' onnx_model = onnx.load(os.path.join(data_path, model_name + '.onnx')) sess1 = _ort.InferenceSession(onnx_model.SerializeToString())