diff --git a/onnxconverter_common/__init__.py b/onnxconverter_common/__init__.py index 26976ee..1d8da6e 100644 --- a/onnxconverter_common/__init__.py +++ b/onnxconverter_common/__init__.py @@ -1,8 +1,7 @@ -# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -# -------------------------------------------------------------------------- +############################################################################### """ The entry point to onnxconverter-common. diff --git a/onnxconverter_common/case_insensitive_dict.py b/onnxconverter_common/case_insensitive_dict.py index f1706c9..41972d4 100644 --- a/onnxconverter_common/case_insensitive_dict.py +++ b/onnxconverter_common/case_insensitive_dict.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +############################################################################### + try: from collections.abc import Mapping, MutableMapping except ImportError: diff --git a/onnxconverter_common/container.py b/onnxconverter_common/container.py index 9036895..9d51675 100644 --- a/onnxconverter_common/container.py +++ b/onnxconverter_common/container.py @@ -1,8 +1,7 @@ -# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -# -------------------------------------------------------------------------- +############################################################################### import six from onnx import helper diff --git a/onnxconverter_common/data_types.py b/onnxconverter_common/data_types.py index b41efd5..91fbe20 100644 --- a/onnxconverter_common/data_types.py +++ b/onnxconverter_common/data_types.py @@ -1,8 +1,7 @@ -# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -# -------------------------------------------------------------------------- +############################################################################### import numbers import onnx diff --git a/onnxconverter_common/decast.py b/onnxconverter_common/decast.py index 2cff08e..e6ba720 100644 --- a/onnxconverter_common/decast.py +++ b/onnxconverter_common/decast.py @@ -1,10 +1,13 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +############################################################################### import sys import onnx from .optimizer import LinkedNode, Solution def remove_cast(lnodes, op_set): - while True: sln = [] for n_ in lnodes: diff --git a/onnxconverter_common/float16.py b/onnxconverter_common/float16.py index e7b68a5..e09a779 100644 --- a/onnxconverter_common/float16.py +++ b/onnxconverter_common/float16.py @@ -1,4 +1,3 @@ -########################################################################### # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. diff --git a/onnxconverter_common/interface.py b/onnxconverter_common/interface.py index 8ac076f..306e2fe 100644 --- a/onnxconverter_common/interface.py +++ b/onnxconverter_common/interface.py @@ -1,8 +1,7 @@ -# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -# -------------------------------------------------------------------------- +############################################################################### # This file defines the interface of the converter internal object for callback, # So the usage of the methods and properties list here will not be affected among the different versions. diff --git a/onnxconverter_common/metadata_props.py b/onnxconverter_common/metadata_props.py index 816b2c8..9d57168 100644 --- a/onnxconverter_common/metadata_props.py +++ b/onnxconverter_common/metadata_props.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +############################################################################### + import warnings from .case_insensitive_dict import CaseInsensitiveDict from onnx import onnx_pb as onnx_proto diff --git a/onnxconverter_common/onnx_ops.py b/onnxconverter_common/onnx_ops.py index bab150b..df8437e 100644 --- a/onnxconverter_common/onnx_ops.py +++ b/onnxconverter_common/onnx_ops.py @@ -1,8 +1,8 @@ -# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -# -------------------------------------------------------------------------- +############################################################################### + # This file contains some high-level APIs for applying operations on variables specified by names. We should try our # best to use those functions because they can produce ONNX operators according to the ONNX version specified in the # `container` argument. Notice that those function behaviors are defined in a way very similar to ONNX-1.2. diff --git a/onnxconverter_common/oopb.py b/onnxconverter_common/oopb.py new file mode 100644 index 0000000..33d0c36 --- /dev/null +++ b/onnxconverter_common/oopb.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +############################################################################### + +import functools +import numpy as np +from onnx import onnx_pb as onnx_proto +from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE +from . import onnx_ops + + +class _OperatorNameContext: + def __init__(self, oopb, basename): + self.basename = basename + self.oopb = oopb + + def __enter__(self): + assert self.oopb.basename is None, "The previous context doesn't quit" + self.oopb.basename = self.basename + return self.oopb + + def __exit__(self, type, value, traceback): + self.oopb.basename = None + + +class OnnxOperatorBuilder: + def __init__(self, container, scope): + self._container = container + self._scope = scope + self.basename = None + self.int32 = onnx_proto.TensorProto.INT32 + self.int64 = onnx_proto.TensorProto.INT64 + self.float = onnx_proto.TensorProto.FLOAT + self.float16 = onnx_proto.TensorProto.FLOAT16 + self.double = onnx_proto.TensorProto.DOUBLE + self.bool = onnx_proto.TensorProto.BOOL + + apply_operations = onnx_ops.__dict__ + for k_, m_ in apply_operations.items(): + if k_.startswith("apply_") and callable(m_): + setattr(self, k_, functools.partial(self.apply_op, m_)) + + def as_default(self, basename): + return _OperatorNameContext(self, basename) + + def _process_inputs(self, inputs, name): + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + ox_inputs = [] + for i_ in inputs: + ox_n = i_ + if isinstance(i_, np.ndarray): + ox_n = self._scope.get_unique_variable_name(name + '_i') + self._container.add_initializer( + ox_n, + NP_TYPE_TO_TENSOR_TYPE[i_.dtype], + i_.shape, + i_.flatten() + ) + elif isinstance(i_, (tuple, list)): + ox_n = self._scope.get_unique_variable_name(name + i_[0]) + self._container.add_initializer( + ox_n, + i_[1], + i_[2].shape, + i_[2].flatten() + ) + elif isinstance(ox_n, str): + pass + else: + raise ValueError('Unknown type for ONNX initializer: {}'.format(type(ox_n))) + ox_inputs.append(ox_n) + + return ox_inputs + + def _generate_name(self, type_or_func, name): + base_name = (self.basename if self.basename else '') + '_' + if name is not None: + long_name = base_name + name + else: + if isinstance(type_or_func, str): + suffix = type_or_func.lower() + else: + suffix = type_or_func.__name__[len('apply_'):] + long_name = base_name + suffix + return long_name + + def add_node(self, op_type, inputs, name=None, outputs=None, op_domain='', op_version=None, **attrs): + if op_version is None: + op_version = self._container.target_opset + name = self._generate_name(op_type, name) + if outputs is None: + ox_output = 1 + else: + ox_output = outputs + if isinstance(ox_output, int): + ox_output = [self._scope.get_unique_variable_name(name + str(i_)) for i_ in range(ox_output)] + elif isinstance(ox_output, (list, tuple)): + pass + else: + raise ValueError('Unknown type for outputs: {}'.format(type(ox_output))) + ox_inputs = self._process_inputs(inputs, name) + self._container.add_node(op_type, ox_inputs, ox_output, op_domain, op_version, + name=self._scope.get_unique_operator_name(name), **attrs) + return ox_output[0] if outputs is None else ox_output + + def apply_op(self, apply_func, inputs, name=None, outputs=None, **attrs): + name = self._generate_name(apply_func, name) + if outputs is None: + ox_output = 1 + else: + ox_output = outputs + if isinstance(ox_output, int): + ox_output = [self._scope.get_unique_variable_name(name + str(i_)) for i_ in range(ox_output)] + elif isinstance(ox_output, (list, tuple)): + pass + else: + raise ValueError('Unknown type for outputs: {}'.format(type(ox_output))) + ox_inputs = self._process_inputs(inputs, name) + apply_func(self._scope, ox_inputs, ox_output, self._container, + operator_name=self._scope.get_unique_operator_name(name), **attrs) + return ox_output[0] if outputs is None else ox_output + + def apply_op_name(self, apply_func_name, inputs, name=None, outputs=None, **attrs): + apply_operations = onnx_ops.__dict__ + apply_func = apply_operations[apply_func_name] + assert apply_func is not None, "{} not implemented in onnx_ops.py.".format(apply_func_name) + return self.apply_op(apply_func, inputs, name, outputs) diff --git a/onnxconverter_common/optimizer.py b/onnxconverter_common/optimizer.py index b939531..bb0433d 100644 --- a/onnxconverter_common/optimizer.py +++ b/onnxconverter_common/optimizer.py @@ -1,8 +1,8 @@ -# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -# -------------------------------------------------------------------------- +############################################################################### + import six import numpy as np import onnx diff --git a/onnxconverter_common/registration.py b/onnxconverter_common/registration.py index 2b14831..8d44698 100644 --- a/onnxconverter_common/registration.py +++ b/onnxconverter_common/registration.py @@ -1,8 +1,7 @@ -# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -# -------------------------------------------------------------------------- +############################################################################### # This dictionary defines the converters which can be invoked in the conversion framework defined in topology.py. A key # in this dictionary is an operator's unique ID (e.g., string and type) while the associated value is the callable diff --git a/onnxconverter_common/shape_calculator.py b/onnxconverter_common/shape_calculator.py index 195f7ec..0932817 100644 --- a/onnxconverter_common/shape_calculator.py +++ b/onnxconverter_common/shape_calculator.py @@ -1,11 +1,12 @@ -# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -# -------------------------------------------------------------------------- +############################################################################### + """ Common functions to convert any learner based on trees. """ + import numpy as np import numbers import six diff --git a/onnxconverter_common/topology.py b/onnxconverter_common/topology.py index ecb03d8..c03975d 100644 --- a/onnxconverter_common/topology.py +++ b/onnxconverter_common/topology.py @@ -1,15 +1,13 @@ -# ------------------------------------------------------------------------- +# coding=utf-8 # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -# -------------------------------------------------------------------------- +############################################################################### import re import warnings from logging import getLogger from distutils.version import StrictVersion -import onnx -from onnx import onnx_pb as onnx_proto from onnx import helper from .metadata_props import add_metadata_props from . import registration @@ -17,7 +15,7 @@ from .data_types import * from .container import ModelComponentContainer from .optimizer import optimize_onnx -from .interface import OperatorBase +from .interface import OperatorBase, ScopeBase class Variable: @@ -110,7 +108,7 @@ def infer_types(self): registration.get_shape_calculator(self.type)(self) -class Scope: +class Scope(ScopeBase): def __init__(self, name, parent_scopes=None, variable_name_set=None, operator_name_set=None, target_opset=None): ''' @@ -786,7 +784,7 @@ def convert_topology(topology, model_name, doc_string, target_opset, targeted_on i += 1 if container.target_opset < op_version: raise RuntimeError(('The specified opset %d is too low to convert this model, ' + - 'which requires at least opset %d.') % (container.target_opset, op_version)) + 'which requires at least opset %d.') % (container.target_opset, op_version)) elif container.target_opset > op_version: getLogger('onnxmltools').warning('The maximum opset needed by this model is only %d.' % op_version) diff --git a/onnxconverter_common/tree_ensemble.py b/onnxconverter_common/tree_ensemble.py index d5487f4..e185e29 100644 --- a/onnxconverter_common/tree_ensemble.py +++ b/onnxconverter_common/tree_ensemble.py @@ -1,11 +1,12 @@ -# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -# -------------------------------------------------------------------------- +############################################################################### + """ Common functions to convert any learner based on trees. """ + from .registration import register_converter diff --git a/onnxconverter_common/utils.py b/onnxconverter_common/utils.py index 62f6d7b..c2b7f88 100644 --- a/onnxconverter_common/utils.py +++ b/onnxconverter_common/utils.py @@ -1,13 +1,12 @@ -# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -# -------------------------------------------------------------------------- +############################################################################### import numbers, six import numpy as np import warnings -from distutils.version import LooseVersion, StrictVersion +from distutils.version import LooseVersion def sparkml_installed(): @@ -20,6 +19,7 @@ def sparkml_installed(): except ImportError: return False + def sklearn_installed(): """ Checks that *scikit-learn* is available. @@ -30,6 +30,7 @@ def sklearn_installed(): except ImportError: return False + def skl2onnx_installed(): """ Checks that *skl2onnx* converter is available. @@ -40,6 +41,7 @@ def skl2onnx_installed(): except ImportError: return False + def coreml_installed(): """ Checks that *coremltools* is available. @@ -140,6 +142,7 @@ def xgboost_installed(): warnings.warn('The converter works for xgboost >= 0.7. Earlier versions might not.') return True + def h2o_installed(): """ Checks that *h2o* is available. @@ -150,6 +153,7 @@ def h2o_installed(): return False return True + def get_producer(): """ Internal helper function to return the producer @@ -203,7 +207,7 @@ def is_string_type(item): if isinstance(item, np.ndarray): return np.issubdtype(item.dtype, np.str_) return isinstance(item, types) - + def cast_list(type, items): return [type(item) for item in items] diff --git a/tests/test_oopb.py b/tests/test_oopb.py new file mode 100644 index 0000000..94904a2 --- /dev/null +++ b/tests/test_oopb.py @@ -0,0 +1,45 @@ +import unittest +import numpy as np +from onnxconverter_common.oopb import OnnxOperatorBuilder +from onnxconverter_common.container import ModelComponentContainer, RawModelContainer +from onnxconverter_common.topology import Topology, convert_topology + + +class _SimpleRawModelContainer(RawModelContainer): + def __init__(self): + super(_SimpleRawModelContainer, self).__init__(None) + + @property + def input_names(self): + return ['input_0', 'input_1'] + + @property + def output_names(self): + return ['m_output'] + + +class OnnxOpTestCase(unittest.TestCase): + def setUp(self): + self.raw_model = _SimpleRawModelContainer() + + def test_apply_op(self): + topo = Topology(self.raw_model) + scope = topo.declare_scope('__ROOT__') + container = ModelComponentContainer(target_opset=7) + + with OnnxOperatorBuilder(container, scope).as_default('node_bn') as oopb: + mul_node = oopb.apply_mul(self.raw_model.input_names) + sub_node = oopb.apply_sub([mul_node] + [np.array([1.0, 2.0])]) + output = oopb.add_node('Add', + [sub_node, ('add_1', oopb.float, np.array([3.0, 4.0]))], + outputs=self.raw_model.output_names) + + self.assertIsInstance(output, list) + self.assertEqual(len(container.nodes), 3) + self.assertEqual(len(container.initializers), 2) + self.assertTrue(container.nodes[0].name.startswith('node_bn')) + + # a fake conversion to check the scope data correctness. + oxml = convert_topology(topo, 'test', "doc_string", 7, None) + self.assertIsNotNone(oxml) + self.assertEqual(len(oxml.graph.node), 0)