Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] Onnx export support for ROIAlign (#19814)
Browse files Browse the repository at this point in the history
* onnx roi align

* fix

* Update _op_translations.py
  • Loading branch information
Zha0q1 authored Feb 3, 2021
1 parent f7b7acc commit 6dc8edf
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
38 changes: 38 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3643,6 +3643,44 @@ def convert_broadcast_like(node, **kwargs):
return nodes


@mx_op.register('_contrib_ROIAlign')
def convert_contrib_roialign(node, **kwargs):
"""Map MXNet's _contrib_ROIAlign
"""
from onnx.helper import make_node
from onnx import TensorProto
name, input_nodes, attrs = get_inputs(node, kwargs)

pooled_size = convert_string_to_list(str(attrs.get('pooled_size')))
spatial_scale = float(attrs.get('spatial_scale'))
sample_ratio = int(attrs.get('sample_ratio', '0'))
position_sensitive = attrs.get('position_sensitive', 'False')
aligned = attrs.get('aligned', 'False')

if position_sensitive != 'False':
raise NotImplementedError('_contrib_ROIAlign does not currently support \
position_sensitive!=False')
if aligned != 'False':
raise NotImplementedError('_contrib_ROIAlign does not currently support \
aligned!=False')

_ = create_tensor([0], name+'_0', kwargs['initializer'])
_ = create_tensor([1], name+'_1', kwargs['initializer'])
_ = create_tensor([5], name+'_5', kwargs['initializer'])

nodes = [
make_node('Slice', [input_nodes[1], name+'_1', name+'_5', name+'_1'], [name+'_rois']),
make_node('Slice', [input_nodes[1], name+'_0', name+'_1', name+'_1'], [name+'_inds__']),
make_node('Squeeze', [name+'_inds__'], [name+'_inds_'], axes=(1,)),
make_node('Cast', [name+'_inds_'], [name+'_inds'], to=int(TensorProto.INT64)),
make_node('RoiAlign', [input_nodes[0], name+'_rois', name+'_inds'], [name],
mode='avg', output_height=pooled_size[0], output_width=pooled_size[1],
sampling_ratio=sample_ratio, spatial_scale=spatial_scale)
]

return nodes


@mx_op.register("batch_dot")
def convert_batch_dot(node, **kwargs):
"""Map MXNet's batch_dot operator attributes to onnx's operator.
Expand Down
18 changes: 18 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,24 @@ def test_onnx_export_broadcast_like(tmp_path, dtype, lhs_axes, rhs_axes):
M2 = def_model('broadcast_like', lhs_axes=lhs_axes, rhs_axes=rhs_axes)
op_export_test('broadcast_like2', M2, [x, y], tmp_path)


@pytest.mark.parametrize('dtype', ['float32'])
@pytest.mark.parametrize('pooled_size', [(1, 1), (3, 3), (14, 14), (5, 7)])
@pytest.mark.parametrize('spatial_scale', [1, 0.5, 0.0625])
@pytest.mark.parametrize('spatial_ratio', [1, 2, 3, 5])
def test_onnx_export_contrib_ROIAlign(tmp_path, dtype, pooled_size, spatial_scale, spatial_ratio):
data = mx.random.uniform(0, 1, (5, 3, 128, 128)).astype(dtype)
rois = mx.nd.array([[0, 0, 0, 63, 63],
[1, 34, 52, 25, 85],
[2, 50, 50, 100, 100],
[3, 0, 0, 127, 127],
[4, 12, 84, 22, 94],
[0, 0, 0, 1, 1]]).astype(dtype)
M = def_model('contrib.ROIAlign', pooled_size=pooled_size, spatial_scale=spatial_scale,
sample_ratio=spatial_ratio)
op_export_test('_contrib_ROIAlign', M, [data, rois], tmp_path)


@pytest.mark.parametrize('dtype', ['float32', 'float64'])
@pytest.mark.parametrize('transpose_a', [True, False])
@pytest.mark.parametrize('transpose_b', [True, False])
Expand Down

0 comments on commit 6dc8edf

Please sign in to comment.