From 5b1a30aaf65157a5c68e808c12bdfb5594851da0 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Mon, 8 Feb 2021 13:41:49 -0800 Subject: [PATCH 1/2] fix tile --- .../contrib/onnx/mx2onnx/_op_translations.py | 47 +++++++++---------- tests/python-pytest/onnx/test_operators.py | 8 ++++ 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index a2d9a6bf1c2c..84ab9a9b0335 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2262,37 +2262,32 @@ def convert_tile(node, **kwargs): """Map MXNet's Tile operator attributes to onnx's Tile operator and return the created node. """ + from onnx.helper import make_node name, input_nodes, attrs = get_inputs(node, kwargs) - reps_list = convert_string_to_list(attrs["reps"]) - - initializer = kwargs["initializer"] - reps_shape_np = np.array(reps_list, dtype='int64') - data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[reps_shape_np.dtype] - dims = np.shape(reps_shape_np) - - output_shape_name = "reps_attr_tensor" + str(kwargs["idx"]) - tensor_node = onnx.helper.make_tensor_value_info(output_shape_name, data_type, dims) + data = input_nodes[0] + reps = convert_string_to_list(attrs["reps"]) - initializer.append( - onnx.helper.make_tensor( - name=output_shape_name, - data_type=data_type, - dims=dims, - vals=reps_list, - raw=False, - ) - ) + create_tensor([0], name+'_0', kwargs['initializer']), + create_tensor([1], name+'_1', kwargs['initializer']), + create_tensor(reps, name+'_reps', kwargs['initializer'], dtype='int64'), + create_tensor([len(reps)], name+'_reps_len', kwargs['initializer']), - input_nodes.append(output_shape_name) - tile_node = onnx.helper.make_node( - "Tile", - input_nodes, - [name], - name=name - ) + nodes = [ + make_node('Shape', [data], [name+'_data_shape']), + make_node('Shape', [name+'_data_shape'], [name+'_data_dim']), + make_node('Max', [name+'_data_dim', name+'_reps_len'], [name+'_max']), + make_node('Sub', [name+'_max', name+'_data_dim'], [name+'_data_diff']), + make_node('Concat', [name+'_data_diff', name+'_0'], [name+'_concat0_out'], axis=0), + make_node('Pad', [name+'_data_shape', name+'_concat0_out', name+'_1'], [name+'_data_shape_pad']), + make_node('Reshape', [data, name+'_data_shape_pad'], [name+'_data']), + make_node('Sub', [name+'_max', name+'_reps_len'], [name+'_reps_diff']), + make_node('Concat', [name+'_reps_diff', name+'_0'], [name+'_concat1_out'], axis=0), + make_node('Pad', [name+'_reps', name+'_concat1_out', name+'_1'], [name+'_reps_pad']), + make_node('Tile', [name+'_data', name+'_reps_pad'], [name], name=name), + ] - return [tensor_node, tile_node] + return nodes @mx_op.register("broadcast_to") diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 97eae9ff2aca..5c8fe03473e9 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1121,3 +1121,11 @@ def test_onnx_export_argsort(tmp_path, dtype, axis, is_ascend, dtype_i): kwargs['is_ascend'] = is_ascend M = def_model('argsort', axis=axis, dtype=dtype_i, **kwargs) op_export_test('argsort', M, [A], tmp_path) + + +@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) +@pytest.mark.parametrize('reps', [(2, 3), (2, ), (2, 3, 4)]) +def test_onnx_export_tile(tmp_path, dtype, reps): + x = mx.nd.random.normal(0, 100, (5, 6)).astype(dtype) + M = def_model('tile', reps=reps) + op_export_test('tile', M, [x], tmp_path) From dd514d954c652cf02aa22ab526c40c005847ab1c Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Mon, 8 Feb 2021 14:44:43 -0800 Subject: [PATCH 2/2] fix sanity --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 84ab9a9b0335..04df523729ac 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2268,10 +2268,10 @@ def convert_tile(node, **kwargs): data = input_nodes[0] reps = convert_string_to_list(attrs["reps"]) - create_tensor([0], name+'_0', kwargs['initializer']), - create_tensor([1], name+'_1', kwargs['initializer']), - create_tensor(reps, name+'_reps', kwargs['initializer'], dtype='int64'), - create_tensor([len(reps)], name+'_reps_len', kwargs['initializer']), + create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([1], name+'_1', kwargs['initializer']) + create_tensor(reps, name+'_reps', kwargs['initializer'], dtype='int64') + create_tensor([len(reps)], name+'_reps_len', kwargs['initializer']) nodes = [ make_node('Shape', [data], [name+'_data_shape']),