From ec35b402bf75e1c22fcb12ce633698be4a2b1a99 Mon Sep 17 00:00:00 2001 From: CescMessi Date: Thu, 14 Sep 2023 14:08:29 +0800 Subject: [PATCH] fix roi align symbolic function in onnx opset>=16 (#2428) --- mmdeploy/mmcv/ops/roi_align.py | 46 ++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/mmdeploy/mmcv/ops/roi_align.py b/mmdeploy/mmcv/ops/roi_align.py index 6ee901a047..a511d3904d 100644 --- a/mmdeploy/mmcv/ops/roi_align.py +++ b/mmdeploy/mmcv/ops/roi_align.py @@ -58,23 +58,38 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int], else: from torch.onnx.symbolic_opset9 import _cast_Long from torch.onnx.symbolic_opset11 import add, select - batch_indices = _cast_Long( - g, - g.op( - 'Squeeze', - select( - g, rois, 1, - g.op( - 'Constant', - value_t=torch.tensor([0], dtype=torch.long))), - axes_i=[1]), False) + ir_cfg = get_ir_config(ctx.cfg) + opset_version = ir_cfg.get('opset_version', 11) + if opset_version < 13: + batch_indices = _cast_Long( + g, + g.op( + 'Squeeze', + select( + g, rois, 1, + g.op( + 'Constant', + value_t=torch.tensor([0], dtype=torch.long))), + axes_i=[1]), False) + else: + axes = g.op( + 'Constant', value_t=torch.tensor([1], dtype=torch.long)) + batch_indices = _cast_Long( + g, + g.op( + 'Squeeze', + select( + g, rois, 1, + g.op( + 'Constant', + value_t=torch.tensor([0], dtype=torch.long))), + axes), False) rois = select( g, rois, 1, g.op( 'Constant', value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long))) - ir_cfg = get_ir_config(ctx.cfg) - opset_version = ir_cfg.get('opset_version', 11) + if opset_version < 16: # preprocess rois to make compatible with opset 16- # as for opset 16+, `aligned` get implemented inside onnxruntime. @@ -96,6 +111,10 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int], sampling_ratio_i=sampling_ratio, mode_s=pool_mode) else: + if aligned: + coordinate_transformation_mode = 'half_pixel' + else: + coordinate_transformation_mode = 'output_half_pixel' return g.op( 'RoiAlign', input, @@ -106,4 +125,5 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int], spatial_scale_f=spatial_scale, sampling_ratio_i=sampling_ratio, mode_s=pool_mode, - aligned_i=aligned) + coordinate_transformation_mode_s=coordinate_transformation_mode + )