diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 57ef546de29f..ea622cbe0532 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2861,3 +2861,47 @@ def convert_repeat(node, **kwargs): ] return nodes + +@mx_op.register("_greater_scalar") +def convert_greater_scalar(node, **kwargs): + """Map MXNet's greater_scalar operator attributes to onnx's Greater + operator and return the created node. + """ + from onnx.helper import make_node, make_tensor + name, input_nodes, attrs = get_inputs(node, kwargs) + + scalar = float(attrs.get('scalar')) + input_type = kwargs['in_type'] + dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type] + + if str(dtype).startswith('int'): + scalar = int(scalar) + else: + if dtype == 'float16': + # when using float16, we must convert it to np.uint16 view first + # pylint: disable=too-many-function-args + scalar = np.float16(scalar).view(np.uint16) + + tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar]) + nodes = [ + make_node("Shape", [input_nodes[0]], [name+"_shape"]), + make_node("ConstantOfShape", [name+"_shape"], [name+"_rhs"], value=tensor_value), + make_node("Greater", [input_nodes[0], name+"_rhs"], [name+"_gt"]), + make_node("Cast", [name+"_gt"], [name], to=input_type, name=name) + ] + return nodes + + +@mx_op.register("where") +def convert_where(node, **kwargs): + """Map MXNet's where operator attributes to onnx's Where + operator and return the created node. + """ + from onnx.helper import make_node + from onnx import TensorProto + name, input_nodes, _ = get_inputs(node, kwargs) + nodes = [ + make_node("Cast", [input_nodes[0]], [name+"_bool"], to=int(TensorProto.BOOL)), + make_node("Where", [name+"_bool", input_nodes[1], input_nodes[2]], [name], name=name) + ] + return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index c17a03bc3276..a06b043b02c1 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -378,3 +378,25 @@ def test_onnx_export_contrib_BilinearResize2D(tmp_path, dtype, params): x = mx.nd.arange(0, 160).reshape((2, 2, 5, 8)) M = def_model('contrib.BilinearResize2D', **params) op_export_test('contrib_BilinearResize2D', M, [x], tmp_path) + + +@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"]) +@pytest.mark.parametrize("scalar", [0., 0.1, 0.5, 1., 5, 555.]) +def test_onnx_export_greater_scalar(tmp_path, dtype, scalar): + if 'int' in dtype: + scalar = int(scalar) + x = mx.nd.arange(0, 12, dtype=dtype).reshape((3, 4)) + else: + x = mx.random.uniform(0, 9999, (5,10), dtype=dtype) + M = def_model('_internal._greater_scalar', scalar=scalar) + op_export_test('_internal._greater_scalar', M, [x], tmp_path) + + +@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"]) +@pytest.mark.parametrize("shape", [(1,1), (3,3), (10,2), (20,30,40)]) +def test_onnx_export_where(tmp_path, dtype, shape): + M = def_model('where') + x = mx.nd.zeros(shape, dtype=dtype) + y = mx.nd.ones(shape, dtype=dtype) + cond = mx.nd.random.randint(low=0, high=1, shape=shape, dtype='int32') + op_export_test('where', M, [cond, x, y], tmp_path)