Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix onnx2py for seq types #194

Merged
merged 1 commit into from
Jul 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions onnxconverter_common/onnx2py.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,12 @@ def convert_field(field):

def convert_value_info(val_info):
name = val_info.name
elem_type = convert_tensor_type(val_info.type.tensor_type.elem_type)
is_sequence_type = val_info.type.HasField('sequence_type')
if is_sequence_type:
tensor_type = val_info.type.sequence_type.elem_type.tensor_type
else:
tensor_type = val_info.type.tensor_type
elem_type = convert_tensor_type(tensor_type.elem_type)
kwargs = OrderedDict()

def convert_shape_dim(d):
Expand All @@ -132,18 +137,20 @@ def convert_shape_denotation(d):
return d.denotation
return None

if val_info.type.tensor_type.HasField("shape"):
kwargs["shape"] = [convert_shape_dim(d) for d in val_info.type.tensor_type.shape.dim]
if tensor_type.HasField("shape"):
kwargs["shape"] = [convert_shape_dim(d) for d in tensor_type.shape.dim]
else:
kwargs["shape"] = None
if any(d.HasField("denotation") for d in val_info.type.tensor_type.shape.dim):
kwargs["shape_denotation"] = [convert_shape_denotation(d) for d in val_info.type.tensor_type.shape.dim]
if any(d.HasField("denotation") for d in tensor_type.shape.dim):
kwargs["shape_denotation"] = [convert_shape_denotation(d) for d in tensor_type.shape.dim]

if val_info.HasField("doc_string"):
kwargs["doc_string"].doc_string

helper.make_tensor_value_info
return helper_traced.make_tensor_value_info(name, elem_type, **kwargs)
if is_sequence_type:
return helper_traced.make_sequence_value_info(name, elem_type, **kwargs)
else:
return helper_traced.make_tensor_value_info(name, elem_type, **kwargs)


def convert_operatorsetid(opsetid):
Expand Down Expand Up @@ -173,8 +180,8 @@ def convert_tensor(tensor):
if np.product(np_data.shape) <= 10:
return numpy_helper_traced.from_array(np_data, name=tensor.name)
dtype = np_data.dtype
if dtype == np.object:
np_data = np_data.astype(np.str)
if dtype == object:
np_data = np_data.astype(str)
os.makedirs(const_dir, exist_ok=True)
name = "const" + str(const_counter)
if tensor.name and len(tensor.name) < 100:
Expand All @@ -186,7 +193,7 @@ def convert_tensor(tensor):
np.save(const_path, np_data)
data_path = os_traced.path.join(DATA_DIR_TRACED, name + '.npy')
const_counter += 1
np_dtype = getattr(np_traced, str(dtype))
np_dtype = str(dtype)
np_shape = list(np_data.shape)
np_array = np_traced.load(data_path).astype(np_dtype).reshape(np_shape)
return numpy_helper_traced.from_array(np_array, name=tensor.name)
Expand Down
5 changes: 4 additions & 1 deletion onnxconverter_common/pytracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
###########################################################################

from collections import OrderedDict
import math
import numpy as np


Expand Down Expand Up @@ -44,7 +45,9 @@ def from_repr(o):
@staticmethod
def get_repr(x):
if isinstance(x, np.ndarray):
return "np.array(%r, dtype=np.%s)" % (x.tolist(), x.dtype)
return "np.array(%s, dtype='%s')" % (TracingObject.get_repr(x.tolist()), x.dtype)
if isinstance(x, float) and not math.isfinite(x):
return "float('%r')" % x # handle nan/inf/-inf
if not isinstance(x, list):
return repr(x)
ls = [TracingObject.get_repr(o) for o in x]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pytracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_tracing_numpy(self):
tracer = TracingObject("helper")
x = np.array([1, 2, 3], dtype=np.int32)
actual = repr(tracer.from_numpy(x))
expected = "helper.from_numpy(np.array([1, 2, 3], dtype=np.int32))"
expected = "helper.from_numpy(np.array([1, 2, 3], dtype='int32'))"
self.assertEqual(actual, expected)

if __name__ == '__main__':
Expand Down