Skip to content

Commit

Permalink
A better operator builder for ONNXConverter-Common (#46)
Browse files Browse the repository at this point in the history
* A better operator builder for ONNXConverter-Common

* taking opset=7 for unit test.
  • Loading branch information
wenbingl committed May 14, 2020
1 parent f7f9ca0 commit e559e64
Show file tree
Hide file tree
Showing 17 changed files with 217 additions and 31 deletions.
3 changes: 1 addition & 2 deletions onnxconverter_common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
###############################################################################

"""
The entry point to onnxconverter-common.
Expand Down
5 changes: 5 additions & 0 deletions onnxconverter_common/case_insensitive_dict.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###############################################################################

try:
from collections.abc import Mapping, MutableMapping
except ImportError:
Expand Down
3 changes: 1 addition & 2 deletions onnxconverter_common/container.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
###############################################################################

import six
from onnx import helper
Expand Down
3 changes: 1 addition & 2 deletions onnxconverter_common/data_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
###############################################################################

import numbers
import onnx
Expand Down
5 changes: 4 additions & 1 deletion onnxconverter_common/decast.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###############################################################################
import sys
import onnx
from .optimizer import LinkedNode, Solution


def remove_cast(lnodes, op_set):

while True:
sln = []
for n_ in lnodes:
Expand Down
1 change: 0 additions & 1 deletion onnxconverter_common/float16.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
###########################################################################
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
Expand Down
3 changes: 1 addition & 2 deletions onnxconverter_common/interface.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
###############################################################################
# This file defines the interface of the converter internal object for callback,
# So the usage of the methods and properties list here will not be affected among the different versions.

Expand Down
5 changes: 5 additions & 0 deletions onnxconverter_common/metadata_props.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###############################################################################

import warnings
from .case_insensitive_dict import CaseInsensitiveDict
from onnx import onnx_pb as onnx_proto
Expand Down
4 changes: 2 additions & 2 deletions onnxconverter_common/onnx_ops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
###############################################################################

# This file contains some high-level APIs for applying operations on variables specified by names. We should try our
# best to use those functions because they can produce ONNX operators according to the ONNX version specified in the
# `container` argument. Notice that those function behaviors are defined in a way very similar to ONNX-1.2.
Expand Down
130 changes: 130 additions & 0 deletions onnxconverter_common/oopb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# coding=utf-8
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###############################################################################

import functools
import numpy as np
from onnx import onnx_pb as onnx_proto
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE
from . import onnx_ops


class _OperatorNameContext:
def __init__(self, oopb, basename):
self.basename = basename
self.oopb = oopb

def __enter__(self):
assert self.oopb.basename is None, "The previous context doesn't quit"
self.oopb.basename = self.basename
return self.oopb

def __exit__(self, type, value, traceback):
self.oopb.basename = None


class OnnxOperatorBuilder:
def __init__(self, container, scope):
self._container = container
self._scope = scope
self.basename = None
self.int32 = onnx_proto.TensorProto.INT32
self.int64 = onnx_proto.TensorProto.INT64
self.float = onnx_proto.TensorProto.FLOAT
self.float16 = onnx_proto.TensorProto.FLOAT16
self.double = onnx_proto.TensorProto.DOUBLE
self.bool = onnx_proto.TensorProto.BOOL

apply_operations = onnx_ops.__dict__
for k_, m_ in apply_operations.items():
if k_.startswith("apply_") and callable(m_):
setattr(self, k_, functools.partial(self.apply_op, m_))

def as_default(self, basename):
return _OperatorNameContext(self, basename)

def _process_inputs(self, inputs, name):
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
ox_inputs = []
for i_ in inputs:
ox_n = i_
if isinstance(i_, np.ndarray):
ox_n = self._scope.get_unique_variable_name(name + '_i')
self._container.add_initializer(
ox_n,
NP_TYPE_TO_TENSOR_TYPE[i_.dtype],
i_.shape,
i_.flatten()
)
elif isinstance(i_, (tuple, list)):
ox_n = self._scope.get_unique_variable_name(name + i_[0])
self._container.add_initializer(
ox_n,
i_[1],
i_[2].shape,
i_[2].flatten()
)
elif isinstance(ox_n, str):
pass
else:
raise ValueError('Unknown type for ONNX initializer: {}'.format(type(ox_n)))
ox_inputs.append(ox_n)

return ox_inputs

def _generate_name(self, type_or_func, name):
base_name = (self.basename if self.basename else '') + '_'
if name is not None:
long_name = base_name + name
else:
if isinstance(type_or_func, str):
suffix = type_or_func.lower()
else:
suffix = type_or_func.__name__[len('apply_'):]
long_name = base_name + suffix
return long_name

def add_node(self, op_type, inputs, name=None, outputs=None, op_domain='', op_version=None, **attrs):
if op_version is None:
op_version = self._container.target_opset
name = self._generate_name(op_type, name)
if outputs is None:
ox_output = 1
else:
ox_output = outputs
if isinstance(ox_output, int):
ox_output = [self._scope.get_unique_variable_name(name + str(i_)) for i_ in range(ox_output)]
elif isinstance(ox_output, (list, tuple)):
pass
else:
raise ValueError('Unknown type for outputs: {}'.format(type(ox_output)))
ox_inputs = self._process_inputs(inputs, name)
self._container.add_node(op_type, ox_inputs, ox_output, op_domain, op_version,
name=self._scope.get_unique_operator_name(name), **attrs)
return ox_output[0] if outputs is None else ox_output

def apply_op(self, apply_func, inputs, name=None, outputs=None, **attrs):
name = self._generate_name(apply_func, name)
if outputs is None:
ox_output = 1
else:
ox_output = outputs
if isinstance(ox_output, int):
ox_output = [self._scope.get_unique_variable_name(name + str(i_)) for i_ in range(ox_output)]
elif isinstance(ox_output, (list, tuple)):
pass
else:
raise ValueError('Unknown type for outputs: {}'.format(type(ox_output)))
ox_inputs = self._process_inputs(inputs, name)
apply_func(self._scope, ox_inputs, ox_output, self._container,
operator_name=self._scope.get_unique_operator_name(name), **attrs)
return ox_output[0] if outputs is None else ox_output

def apply_op_name(self, apply_func_name, inputs, name=None, outputs=None, **attrs):
apply_operations = onnx_ops.__dict__
apply_func = apply_operations[apply_func_name]
assert apply_func is not None, "{} not implemented in onnx_ops.py.".format(apply_func_name)
return self.apply_op(apply_func, inputs, name, outputs)
4 changes: 2 additions & 2 deletions onnxconverter_common/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
###############################################################################

import six
import numpy as np
import onnx
Expand Down
3 changes: 1 addition & 2 deletions onnxconverter_common/registration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
###############################################################################

# This dictionary defines the converters which can be invoked in the conversion framework defined in topology.py. A key
# in this dictionary is an operator's unique ID (e.g., string and type) while the associated value is the callable
Expand Down
5 changes: 3 additions & 2 deletions onnxconverter_common/shape_calculator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
###############################################################################

"""
Common functions to convert any learner based on trees.
"""

import numpy as np
import numbers
import six
Expand Down
12 changes: 5 additions & 7 deletions onnxconverter_common/topology.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
# -------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
###############################################################################

import re
import warnings
from logging import getLogger
from distutils.version import StrictVersion
import onnx
from onnx import onnx_pb as onnx_proto
from onnx import helper
from .metadata_props import add_metadata_props
from . import registration
from . import utils
from .data_types import *
from .container import ModelComponentContainer
from .optimizer import optimize_onnx
from .interface import OperatorBase
from .interface import OperatorBase, ScopeBase

OPSET_TO_IR_VERSION = {
1: 3, 2: 3, 3: 3, 4: 3, 5: 3, 6: 3,
Expand Down Expand Up @@ -117,7 +115,7 @@ def infer_types(self):
registration.get_shape_calculator(self.type)(self)


class Scope:
class Scope(ScopeBase):

def __init__(self, name, parent_scopes=None, variable_name_set=None, operator_name_set=None, target_opset=None):
'''
Expand Down Expand Up @@ -794,7 +792,7 @@ def convert_topology(topology, model_name, doc_string, target_opset, targeted_on
i += 1
if container.target_opset < op_version:
raise RuntimeError(('The specified opset %d is too low to convert this model, ' +
'which requires at least opset %d.') % (container.target_opset, op_version))
'which requires at least opset %d.') % (container.target_opset, op_version))
elif container.target_opset > op_version:
getLogger('onnxmltools').warning('The maximum opset needed by this model is only %d.' % op_version)

Expand Down
5 changes: 3 additions & 2 deletions onnxconverter_common/tree_ensemble.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
###############################################################################

"""
Common functions to convert any learner based on trees.
"""

from .registration import register_converter


Expand Down
12 changes: 8 additions & 4 deletions onnxconverter_common/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
###############################################################################

import numbers, six
import numpy as np
import warnings
from distutils.version import LooseVersion, StrictVersion
from distutils.version import LooseVersion


def sparkml_installed():
Expand All @@ -20,6 +19,7 @@ def sparkml_installed():
except ImportError:
return False


def sklearn_installed():
"""
Checks that *scikit-learn* is available.
Expand All @@ -30,6 +30,7 @@ def sklearn_installed():
except ImportError:
return False


def skl2onnx_installed():
"""
Checks that *skl2onnx* converter is available.
Expand All @@ -40,6 +41,7 @@ def skl2onnx_installed():
except ImportError:
return False


def coreml_installed():
"""
Checks that *coremltools* is available.
Expand Down Expand Up @@ -140,6 +142,7 @@ def xgboost_installed():
warnings.warn('The converter works for xgboost >= 0.7. Earlier versions might not.')
return True


def h2o_installed():
"""
Checks that *h2o* is available.
Expand All @@ -150,6 +153,7 @@ def h2o_installed():
return False
return True


def get_producer():
"""
Internal helper function to return the producer
Expand Down Expand Up @@ -203,7 +207,7 @@ def is_string_type(item):
if isinstance(item, np.ndarray):
return np.issubdtype(item.dtype, np.str_)
return isinstance(item, types)


def cast_list(type, items):
return [type(item) for item in items]
Expand Down
Loading

0 comments on commit e559e64

Please sign in to comment.