Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed support for external data format #164

Merged
merged 1 commit into from
Dec 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 54 additions & 5 deletions onnxconverter_common/onnx2py.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
import sys
import onnx
import collections
import inspect
from collections import OrderedDict
from onnx import helper, numpy_helper, TensorProto
from onnx import helper, numpy_helper, TensorProto, external_data_helper
import numpy as np
import os

Expand All @@ -33,6 +34,30 @@
os_traced = TracingObject("os")


# <Helpers> These can be inlined into the output script #

def clear_field(proto, field):
    """Clear ``field`` on ``proto`` and return the same ``proto`` object.

    Wraps the statement-style ``proto.ClearField(field)`` so that it can be
    used inside a single expression. The source of this helper may be copied
    verbatim into generated scripts, so it must stay self-contained.

    :param proto: object exposing a ``ClearField(name)`` method
    :param field: name of the field to clear
    :return: ``proto`` itself, after the field has been cleared
    """
    proto.ClearField(field)
    return proto


def make_external_tensor(name, data_type, dims, raw_data=None, **kwargs):
    """Build a ``TensorProto`` whose payload is described as external data.

    The source of this helper may be copied verbatim into generated scripts,
    so it only relies on ``TensorProto`` and ``external_data_helper``.

    :param name: tensor name
    :param data_type: value assigned to ``tensor.data_type``
    :param dims: iterable of dimensions appended to ``tensor.dims``
    :param raw_data: optional bytes stored in ``tensor.raw_data``; the field
        is left unset when this is ``None``
    :param kwargs: forwarded unchanged to
        ``external_data_helper.set_external_data``
    :return: the populated ``TensorProto``
    """
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name
    tensor.dims.extend(dims)
    # Only touch raw_data when a value was supplied, so an unset field stays unset.
    if raw_data is not None:
        tensor.raw_data = raw_data
    external_data_helper.set_external_data(tensor, **kwargs)
    return tensor

# </Helpers> #


clear_field_traced = TracingObject("clear_field")
make_external_tensor_traced = TracingObject("make_external_tensor")


def convert_tensor_type(i):
    """Return the traced ``TensorProto`` attribute for data-type value ``i``.

    :param i: numeric ``TensorProto.DataType`` value
    :return: the attribute of ``TensorProtoTraced`` named after that enum value
    """
    type_name = TensorProto.DataType.Name(i)
    return getattr(TensorProtoTraced, type_name)

Expand Down Expand Up @@ -90,13 +115,28 @@ def convert_shape_denotation(d):


def convert_operatorsetid(opsetid):
    """Trace the helper call that recreates an operator-set id.

    When ``opsetid`` has its ``domain`` field set, the traced call passes the
    domain through. When the field is unset, the traced call builds the id
    with an empty domain and then clears the field, so the recreated proto
    also reports ``domain`` as unset.

    :param opsetid: operator-set id proto exposing ``domain``, ``version``
        and ``HasField``
    :return: traced expression that rebuilds ``opsetid``
    """
    version = opsetid.version
    if opsetid.HasField("domain"):
        return helper_traced.make_operatorsetid(opsetid.domain, version)
    # Domain was never set: create with a placeholder, then clear the field.
    return clear_field_traced(helper_traced.make_operatorsetid('', version), 'domain')


def convert_external_tensor(tensor):
    """Trace a ``make_external_tensor`` call that recreates ``tensor``.

    Keyword arguments are collected in order: ``raw_data`` first (only when
    that field is set on ``tensor``), followed by one entry per
    ``tensor.external_data`` item, keyed by its ``key`` with its ``value``.

    :param tensor: tensor proto whose data is stored externally
    :return: traced ``make_external_tensor(...)`` expression
    """
    kwargs = OrderedDict()
    if tensor.HasField("raw_data"):
        kwargs["raw_data"] = tensor.raw_data
    # Iterating an empty repeated field is a no-op, so no emptiness guard is needed.
    for entry in tensor.external_data:
        kwargs[entry.key] = entry.value
    return make_external_tensor_traced(tensor.name, tensor.data_type, tensor.dims, **kwargs)


def convert_tensor(tensor):
global const_dir, const_counter
if tensor.data_location == TensorProto.EXTERNAL:
return convert_external_tensor(tensor)
np_data = numpy_helper.to_array(tensor)
if np.product(np_data.shape) <= 10:
return numpy_helper_traced.from_array(np_data, name=tensor.name)
Expand Down Expand Up @@ -169,15 +209,24 @@ def convert(model, out_path):
const_dir = out_path
const_dir_name = os.path.basename(out_path)
const_counter = 0
TracingObject.reset_cnt(clear_field_traced)
TracingObject.reset_cnt(make_external_tensor_traced)

model_trace = convert_model(model)
code = "from onnx import helper, numpy_helper, TensorProto\n"
code = "from onnx import helper, numpy_helper, TensorProto"
if TracingObject.get_cnt(make_external_tensor_traced):
code += ", external_data_helper"
code += "\n"
code += "import onnx\n"
code += "import numpy as np\n"
code += "import sys\n"
if os.path.exists(const_dir):
code += "import os\n"
code += "\nDATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), %r)\n" % const_dir_name
if TracingObject.get_cnt(clear_field_traced):
code += "\n" + inspect.getsource(clear_field)
if TracingObject.get_cnt(make_external_tensor_traced):
code += "\n" + inspect.getsource(make_external_tensor)
code += "\n" + "model = " + repr(model_trace) + "\n"
code += "\nif __name__ == '__main__' and len(sys.argv) == 2:\n"
code += " _, out_path = sys.argv\n"
Expand All @@ -193,7 +242,7 @@ def main():
if not out_path.endswith(".py"):
out_path = out_path + ".py"

model = onnx.load(in_path)
model = onnx.load(in_path, load_external_data=False)
try:
convert(model, out_path)
except MissingHandlerException as e:
Expand Down
11 changes: 11 additions & 0 deletions onnxconverter_common/pytracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ class TracingObject:
"""
def __init__(self, trace):
    """Store the trace text and start the usage counter at zero.

    :param trace: value kept in ``self._trace``; attribute access appends
        ``"." + attr`` to it, so it is expected to be a string
    """
    self._trace = trace
    # Incremented on every attribute access and call made through this object.
    self._cnt = 0

@staticmethod
def reset_cnt(o):
    """Reset the usage counter of ``o`` to zero.

    Written as a static helper taking the object explicitly, so the name is
    not resolved through ``__getattr__`` on the traced object.

    :param o: object whose ``_cnt`` attribute is set to ``0``
    """
    o._cnt = 0

@staticmethod
def get_cnt(o):
    """Return the usage counter of ``o``.

    :param o: object carrying a ``_cnt`` attribute
    :return: the current value of ``o._cnt``
    """
    return o._cnt

@staticmethod
def from_repr(o):
Expand All @@ -38,9 +47,11 @@ def get_repr(x):
return "[\n" + "".join(indent(s) + ",\n" for s in ls) + "]"

def __getattr__(self, attr):
    """Record an attribute access and return a tracer for ``<trace>.<attr>``.

    Increments this object's usage counter, then returns a new
    ``TracingObject`` whose trace is this object's trace followed by
    ``"."`` and ``attr``.

    :param attr: name of the attribute being looked up
    :return: new ``TracingObject`` for the extended trace
    """
    self._cnt += 1
    return TracingObject(self._trace + "." + attr)

def __call__(self, *args, **kwargs):
self._cnt += 1
arg_s = [TracingObject.get_repr(o) for o in args]
arg_s += [k + "=" + TracingObject.get_repr(o) for k, o in kwargs.items()]
trace = self._trace + "(" + ", ".join(arg_s) + ")"
Expand Down
5 changes: 2 additions & 3 deletions tests/test_onnx2py.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import onnxruntime as _ort
import sys

from onnxconverter_common.onnx2py import convert, clear_directory

working_path = os.path.abspath(os.path.dirname(__file__))
tmp_path = os.path.join(working_path, 'temp')
data_path = os.path.join(working_path, 'data')

class Oonnx2PyTests(unittest.TestCase):
def tearDown(self):
from onnxconverter_common.onnx2py import clear_directory
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure whether this is good practice in general. Is there a special reason for moving it here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

onnx2py requires onnx version > 1.3. The new `from onnx import external_data_helper` will make it crash when imported with lower versions. The Python 3.5 unit tests use onnx 1.3 but are skipped for onnx2py.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally maybe onnx2py would be its own repo since its pip requirements are technically higher than onnxconverter-common's install requirements.

for f in os.listdir(tmp_path):
if f.endswith(".py"):
os.remove(os.path.join(tmp_path, f))
Expand All @@ -22,7 +21,7 @@ def tearDown(self):

@unittest.skipIf(sys.version_info < (3, 6), "Requires onnx > 1.3.0")
def test_onnx2py(self):
global model
from onnxconverter_common.onnx2py import convert
model_name = 'test_model_1_no_opt'
onnx_model = onnx.load(os.path.join(data_path, model_name + '.onnx'))
sess1 = _ort.InferenceSession(onnx_model.SerializeToString())
Expand Down