Commit
Fixes microsoft#17, update clip operator (microsoft#18)
* Fixes microsoft#17, update clip operator (updated in ONNX opset 11)
* handle min or max = None
xadupre authored and jiafatom committed Sep 5, 2019
1 parent 2deefe5 commit f99b024
Showing 10 changed files with 75 additions and 31 deletions.
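
For context on the change itself: through opset 6 the ONNX Clip operator carries min and max as node attributes, while from opset 11 onward they are optional tensor inputs. That schema change is why apply_clip below branches on container.target_opset. A minimal sketch of the two node forms with onnx.helper (the tensor names X, Y, and clip_min are illustrative, not taken from this commit):

    from onnx import helper, TensorProto

    # Opset 6 style: the bounds are attributes baked into the node.
    clip_v6 = helper.make_node('Clip', inputs=['X'], outputs=['Y'],
                               min=0.0, max=6.0)

    # Opset 11 style: the bounds are optional inputs, so each one can be
    # supplied or omitted independently (max is omitted here).
    min_init = helper.make_tensor('clip_min', TensorProto.FLOAT, [1], [0.0])
    clip_v11 = helper.make_node('Clip', inputs=['X', 'clip_min'], outputs=['Y'])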
2 changes: 1 addition & 1 deletion onnxconverter_common/__init__.py
@@ -9,7 +9,7 @@
 This framework performs optimization for ONNX models and
 includes common utilities for ONNX converters.
 """
-__version__ = "1.5.3"
+__version__ = "1.5.4"
 __author__ = "Microsoft"
 __producer__ = "OnnxMLTools"
 __producer_version__ = __version__
2 changes: 1 addition & 1 deletion onnxconverter_common/case_insensitive_dict.py
@@ -42,7 +42,7 @@ def __eq__(self, other):
         return dict(self.lower_key_iteritems()) == dict(other.lower_key_iteritems())
 
     def copy(self):
-        return CaseInsensitiveDict(self._dict.values())
+        return CaseInsensitiveDict(self._dict.values())
 
     def __repr__(self):
         return str(dict(self.items()))
1 change: 0 additions & 1 deletion onnxconverter_common/container.py
@@ -5,7 +5,6 @@
 # --------------------------------------------------------------------------
 
 import six
-import onnx
 from onnx import helper
 from .interface import ModelContainer
 
6 changes: 2 additions & 4 deletions onnxconverter_common/data_types.py
@@ -182,8 +182,7 @@ def to_onnx_type(self):
             onnx_type.map_type.key_type = onnx_proto.TensorProto.STRING
             onnx_type.map_type.value_type.CopyFrom(
                 self.value_type.to_onnx_type())
-        except AttributeError as e:
-            import onnx
+        except AttributeError:
             msg = "ONNX was not compiled with flag ONNX-ML.\n{0}\n{1}"
             msg = msg.format(str(self), str(self.value_type.to_onnx_type()))
             info = [onnx.__version__, str(onnx_type)]
@@ -207,8 +206,7 @@ def to_onnx_type(self):
         try:
             onnx_type.sequence_type.elem_type.CopyFrom(
                 self.element_type.to_onnx_type())
-        except AttributeError as e:
-            import onnx
+        except AttributeError:
             msg = "ONNX was not compiled with flag ONNX-ML.\n{0}\n{1}"
             msg = msg.format(str(self), str(self.element_type.to_onnx_type()))
             info = [onnx.__version__, str(onnx_type)]
1 change: 1 addition & 0 deletions onnxconverter_common/float16.py
@@ -10,6 +10,7 @@
 from onnx import helper
 from onnx import onnx_pb as onnx_proto
 
+
 def _npfloat16_to_int(np_list):
     '''
     Convert numpy float16 to python int.
1 change: 0 additions & 1 deletion onnxconverter_common/metadata_props.py
@@ -1,6 +1,5 @@
 import warnings
 from .case_insensitive_dict import CaseInsensitiveDict
-import onnx
 from onnx import onnx_pb as onnx_proto
 
 KNOWN_METADATA_PROPS = CaseInsensitiveDict({
85 changes: 68 additions & 17 deletions onnxconverter_common/onnx_ops.py
@@ -112,8 +112,10 @@ def apply_batch_norm(scope, input_names, output_names, container, operator_name=
     name = _create_name_or_use_existing_one(scope, 'BatchNormalization', operator_name)
     attrs = {'name': name, 'epsilon': epsilon, 'momentum': momentum}
 
-    if container.target_opset < 9: attrs['spatial'] = spatial
-    if container.target_opset < 7: attrs['is_test'] = is_test
+    if container.target_opset < 9:
+        attrs['spatial'] = spatial
+    if container.target_opset < 7:
+        attrs['is_test'] = is_test
 
     if container.target_opset < 6:
         attrs['consumed_inputs'] = [0] * len(input_names)
@@ -169,20 +171,69 @@ def apply_cast(scope, input_name, output_name, container, operator_name=None, to

 def apply_clip(scope, input_name, output_name, container, operator_name=None, max=None, min=None):
     name = _create_name_or_use_existing_one(scope, 'Clip', operator_name)
 
     attrs = {'name': name}
-    if max is not None:
-        attrs['max'] = float(max)
-    if min is not None:
-        attrs['min'] = float(min)
-
-    if container.target_opset < 6:
-        attrs['consumed_inputs'] = [0]
-        op_version = 1
-    else:
-        op_version = 6
-
-    container.add_node('Clip', input_name, output_name, op_version=op_version, **attrs)
+    if container.target_opset < 11:
+        if max is not None:
+            attrs['max'] = float(max)
+        if min is not None:
+            attrs['min'] = float(min)
+
+        if container.target_opset < 6:
+            attrs['consumed_inputs'] = [0]
+            op_version = 1
+        else:
+            op_version = 6
+
+        container.add_node('Clip', input_name, output_name, op_version=op_version, **attrs)
+    else:
+        op_version = 11
+        if min is None and max is not None:
+            raise RuntimeError("Operator 'Clip': min must be specified if max is.")
+        inputs = [input_name]
+
+        if min is not None:
+            if isinstance(min, (np.ndarray, float, int)):
+                # add initializer
+                if isinstance(min, np.ndarray):
+                    if min.shape != (1, ):
+                        raise RuntimeError("min must be an array of one element.")
+                else:
+                    # container in sklearn-onnx stores the computation type in
+                    # container.dtype.
+                    min = np.array([min], dtype=getattr(
+                        container, 'dtype', np.float32))
+                min_name = scope.get_unique_variable_name('clip_min')
+                container.add_initializer(min_name, getattr(container, 'proto_dtype',
+                                          onnx_proto.TensorProto.FLOAT), [1], [min[0]])
+                min = min_name
+            if isinstance(min, str):
+                inputs.append(min)
+            else:
+                raise RuntimeError("Parameter 'min' must be a string or a float.")
+
+        if max is not None:
+            if min is None:
+                raise RuntimeError("Parameter 'min' must be specified if 'max' is.")
+            if isinstance(max, (np.ndarray, float, int)):
+                # add initializer
+                if isinstance(max, np.ndarray):
+                    if max.shape != (1, ):
+                        raise RuntimeError("max must be an array of one element.")
+                else:
+                    max = np.array([max], dtype=getattr(
+                        container, 'dtype', np.float32))
+                max_name = scope.get_unique_variable_name('clip_max')
+                container.add_initializer(max_name, getattr(container, 'proto_dtype',
+                                          onnx_proto.TensorProto.FLOAT), [1], [max[0]])
+                max = max_name
+            if isinstance(max, str):
+                inputs.append(max)
+            else:
+                raise RuntimeError("Parameter 'max' must be a string or a float.")
+
+        container.add_node('Clip', inputs, output_name, op_version=op_version,
+                           **attrs)
 
 
 def apply_concat(scope, input_names, output_name, container, operator_name=None, axis=0):
@@ -374,9 +425,9 @@ def apply_pad(scope, input_name, output_name, container, operator_name=None, mod


 def apply_parametric_softplus(scope, input_name, output_name, container, operator_name=None, alpha=None, beta=None):
-    if alpha == None:
+    if alpha is None:
         alpha = [1.0]
-    if beta == None:
+    if beta is None:
         beta = [0.]
 
     name = _create_name_or_use_existing_one(scope, 'ParametricSoftplus', operator_name)
@@ -515,9 +566,9 @@ def apply_softmax(scope, input_name, output_name, container, operator_name=None,


 def apply_scaled_tanh(scope, input_name, output_name, container, operator_name=None, alpha=None, beta=None):
-    if alpha == None:
+    if alpha is None:
         alpha = [1.0]
-    if beta == None:
+    if beta is None:
         beta = [1.0]
     if len(alpha) != 1 or len(beta) != 1:
         raise ValueError('alpha and beta must be 1-element lists')
@@ -621,7 +672,7 @@ def apply_tanh(scope, input_name, output_name, container, operator_name=None):


 def apply_thresholded_relu(scope, input_name, output_name, container, operator_name=None, alpha=None):
-    if alpha == None:
+    if alpha is None:
         alpha = [1.0]
 
     name = _create_name_or_use_existing_one(scope, 'ThresholdedRelu', operator_name)
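
A note on the "min must be specified if max is" rule enforced in apply_clip above: opset-11 Clip inputs are positional ([data, min, max]), so clipping only from above would require passing the empty string as a placeholder for the missing min input. A sketch of that alternative, with illustrative tensor names:

    from onnx import helper, TensorProto

    # ONNX marks a skipped optional input with an empty string, so an
    # upper-bound-only Clip at opset 11 could be written as:
    max_init = helper.make_tensor('clip_max', TensorProto.FLOAT, [1], [6.0])
    node = helper.make_node('Clip', inputs=['X', '', 'clip_max'], outputs=['Y'])

The helper raises instead, presumably to keep the emitted graphs uniform across converters.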
4 changes: 2 additions & 2 deletions onnxconverter_common/optimizer.py
@@ -211,7 +211,7 @@ def generate(self):
         onode.doc_string = self.origin.doc_string
         onode.domain = self.origin.domain
         onode.attribute.extend(
-            attr for attr in self.origin.attribute if not attr.name in self.attributes)
+            attr for attr in self.origin.attribute if attr.name not in self.attributes)
         onode.attribute.extend(
             helper.make_attribute(attr.name, self.attributes[attr.name]) for attr in self.attributes)
 
@@ -415,7 +415,7 @@ def apply(self, node_list):
             perm_f = [perm0[idx] for idx in perm1]
             if self.is_useless_transpose(perm_f):
                 node = self.begin  # type: LinkedNode
-                while node != self.end and len(node.successor) >=1:
+                while node != self.end and len(node.successor) >= 1:
                     #if node.broadcast:
                     #    node.reshape_input_for_broadcast(perm0)
                     node = node.successor[0]
1 change: 0 additions & 1 deletion onnxconverter_common/shape_calculator.py
@@ -9,7 +9,6 @@
 import numpy as np
 import numbers
 import six
-from .registration import register_shape_calculator
 from .data_types import Int64TensorType, FloatTensorType, StringTensorType, DictionaryType, SequenceType
 from .utils import check_input_and_output_numbers, check_input_and_output_types
 
3 changes: 0 additions & 3 deletions onnxconverter_common/tree_ensemble.py
@@ -6,9 +6,6 @@
"""
Common functions to convert any learner based on trees.
"""

import numpy as np
import numbers, six
from .registration import register_converter


