From 46a72b90d25bb92231c82c48a4c7828bbca167fe Mon Sep 17 00:00:00 2001 From: TomWildenhain-Microsoft <67606533+TomWildenhain-Microsoft@users.noreply.github.com> Date: Thu, 22 Jul 2021 14:54:16 -0700 Subject: [PATCH] Fix onnx2py for seq types (#194) --- onnxconverter_common/onnx2py.py | 27 +++++++++++++++++---------- onnxconverter_common/pytracing.py | 5 ++++- tests/test_pytracing.py | 2 +- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/onnxconverter_common/onnx2py.py b/onnxconverter_common/onnx2py.py index dd7e77c..684fa01 100644 --- a/onnxconverter_common/onnx2py.py +++ b/onnxconverter_common/onnx2py.py @@ -117,7 +117,12 @@ def convert_field(field): def convert_value_info(val_info): name = val_info.name - elem_type = convert_tensor_type(val_info.type.tensor_type.elem_type) + is_sequence_type = val_info.type.HasField('sequence_type') + if is_sequence_type: + tensor_type = val_info.type.sequence_type.elem_type.tensor_type + else: + tensor_type = val_info.type.tensor_type + elem_type = convert_tensor_type(tensor_type.elem_type) kwargs = OrderedDict() def convert_shape_dim(d): @@ -132,18 +137,20 @@ def convert_shape_denotation(d): return d.denotation return None - if val_info.type.tensor_type.HasField("shape"): - kwargs["shape"] = [convert_shape_dim(d) for d in val_info.type.tensor_type.shape.dim] + if tensor_type.HasField("shape"): + kwargs["shape"] = [convert_shape_dim(d) for d in tensor_type.shape.dim] else: kwargs["shape"] = None - if any(d.HasField("denotation") for d in val_info.type.tensor_type.shape.dim): - kwargs["shape_denotation"] = [convert_shape_denotation(d) for d in val_info.type.tensor_type.shape.dim] + if any(d.HasField("denotation") for d in tensor_type.shape.dim): + kwargs["shape_denotation"] = [convert_shape_denotation(d) for d in tensor_type.shape.dim] if val_info.HasField("doc_string"): kwargs["doc_string"].doc_string - helper.make_tensor_value_info - return helper_traced.make_tensor_value_info(name, elem_type, **kwargs) + if is_sequence_type: + return helper_traced.make_sequence_value_info(name, elem_type, **kwargs) + else: + return helper_traced.make_tensor_value_info(name, elem_type, **kwargs) def convert_operatorsetid(opsetid): @@ -173,8 +180,8 @@ def convert_tensor(tensor): if np.product(np_data.shape) <= 10: return numpy_helper_traced.from_array(np_data, name=tensor.name) dtype = np_data.dtype - if dtype == np.object: - np_data = np_data.astype(np.str) + if dtype == object: + np_data = np_data.astype(str) os.makedirs(const_dir, exist_ok=True) name = "const" + str(const_counter) if tensor.name and len(tensor.name) < 100: @@ -186,7 +193,7 @@ def convert_tensor(tensor): np.save(const_path, np_data) data_path = os_traced.path.join(DATA_DIR_TRACED, name + '.npy') const_counter += 1 - np_dtype = getattr(np_traced, str(dtype)) + np_dtype = str(dtype) np_shape = list(np_data.shape) np_array = np_traced.load(data_path).astype(np_dtype).reshape(np_shape) return numpy_helper_traced.from_array(np_array, name=tensor.name) diff --git a/onnxconverter_common/pytracing.py b/onnxconverter_common/pytracing.py index f7727f4..8134ca6 100644 --- a/onnxconverter_common/pytracing.py +++ b/onnxconverter_common/pytracing.py @@ -4,6 +4,7 @@ ########################################################################### from collections import OrderedDict +import math import numpy as np @@ -44,7 +45,9 @@ def from_repr(o): @staticmethod def get_repr(x): if isinstance(x, np.ndarray): - return "np.array(%r, dtype=np.%s)" % (x.tolist(), x.dtype) + return "np.array(%s, dtype='%s')" % (TracingObject.get_repr(x.tolist()), x.dtype) + if isinstance(x, float) and not math.isfinite(x): + return "float('%r')" % x # handle nan/inf/-inf if not isinstance(x, list): return repr(x) ls = [TracingObject.get_repr(o) for o in x] diff --git a/tests/test_pytracing.py b/tests/test_pytracing.py index 9dab0be..e1325bb 100644 --- a/tests/test_pytracing.py +++ b/tests/test_pytracing.py @@ -14,7 +14,7 @@ def test_tracing_numpy(self): tracer = TracingObject("helper") x = np.array([1, 2, 3], dtype=np.int32) actual = repr(tracer.from_numpy(x)) - expected = "helper.from_numpy(np.array([1, 2, 3], dtype=np.int32))" + expected = "helper.from_numpy(np.array([1, 2, 3], dtype='int32'))" self.assertEqual(actual, expected) if __name__ == '__main__':