Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Add onnx export support for where and greater_scalar operators. #19745

Merged
merged 13 commits into the target branch on
Jan 14, 2021
44 changes: 44 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2861,3 +2861,47 @@ def convert_repeat(node, **kwargs):
]

return nodes

@mx_op.register("_greater_scalar")
def convert_greater_scalar(node, **kwargs):
    """Map MXNet's ``_greater_scalar`` operator to ONNX's ``Greater``.

    The scalar operand is materialized as a constant tensor matching the
    input's shape and dtype (via Shape + ConstantOfShape), compared
    element-wise with ``Greater``, and the boolean result is cast back to
    the input dtype.
    """
    from onnx.helper import make_node, make_tensor
    name, input_nodes, attrs = get_inputs(node, kwargs)

    input_type = kwargs['in_type']
    np_dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type]
    rhs_value = float(attrs.get('scalar'))

    if str(np_dtype).startswith('int'):
        rhs_value = int(rhs_value)
    elif np_dtype == 'float16':
        # ONNX stores float16 tensor data as raw uint16 words, so
        # reinterpret the bits rather than converting the value.
        # pylint: disable=too-many-function-args
        rhs_value = np.float16(rhs_value).view(np.uint16)

    scalar_tensor = make_tensor(name + "_scalar", input_type, [1], [rhs_value])
    return [
        make_node("Shape", [input_nodes[0]], [name + "_shape"]),
        make_node("ConstantOfShape", [name + "_shape"], [name + "_rhs"],
                  value=scalar_tensor),
        make_node("Greater", [input_nodes[0], name + "_rhs"], [name + "_gt"]),
        make_node("Cast", [name + "_gt"], [name], to=input_type, name=name),
    ]


@mx_op.register("where")
def convert_where(node, **kwargs):
    """Map MXNet's where operator to ONNX's Where operator.

    ONNX Where requires a boolean condition tensor while MXNet accepts any
    numeric dtype, so the condition input is cast to bool first.
    """
    from onnx.helper import make_node
    from onnx import TensorProto
    name, input_nodes, _ = get_inputs(node, kwargs)
    condition, then_branch, else_branch = input_nodes[0], input_nodes[1], input_nodes[2]
    cast_node = make_node("Cast", [condition], [name + "_bool"],
                          to=int(TensorProto.BOOL))
    where_node = make_node("Where", [name + "_bool", then_branch, else_branch],
                           [name], name=name)
    return [cast_node, where_node]
22 changes: 22 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,25 @@ def test_onnx_export_contrib_BilinearResize2D(tmp_path, dtype, params):
x = mx.nd.arange(0, 160).reshape((2, 2, 5, 8))
M = def_model('contrib.BilinearResize2D', **params)
op_export_test('contrib_BilinearResize2D', M, [x], tmp_path)


@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
@pytest.mark.parametrize("scalar", [0., 0.1, 0.5, 1., 5, 555.])
def test_onnx_export_greater_scalar(tmp_path, dtype, scalar):
    """Round-trip ``_greater_scalar`` through ONNX export for several dtypes
    and scalar values, comparing MXNet and ONNX runtime outputs."""
    is_int_dtype = 'int' in dtype
    if is_int_dtype:
        # Integer comparison needs an integer scalar and a deterministic ramp.
        scalar = int(scalar)
    x = (mx.nd.arange(0, 12, dtype=dtype).reshape((3, 4))
         if is_int_dtype
         else mx.random.uniform(0, 9999, (5,10), dtype=dtype))
    M = def_model('_internal._greater_scalar', scalar=scalar)
    op_export_test('_internal._greater_scalar', M, [x], tmp_path)


@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
@pytest.mark.parametrize("shape", [(1,1), (3,3), (10,2), (20,30,40)])
def test_onnx_export_where(tmp_path, dtype, shape):
    """Round-trip the ``where`` operator through ONNX export.

    ``x`` is all zeros and ``y`` all ones, so the exported graph's output
    directly reveals which branch each condition element selected.
    """
    M = def_model('where')
    x = mx.nd.zeros(shape, dtype=dtype)
    y = mx.nd.ones(shape, dtype=dtype)
    # randint's upper bound is exclusive: high=1 would make cond all zeros and
    # never exercise the True branch of Where. high=2 yields a 0/1 mix.
    cond = mx.nd.random.randint(low=0, high=2, shape=shape, dtype='int32')
    op_export_test('where', M, [cond, x, y], tmp_path)