diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py b/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py index f91ca48ad0..57a007443f 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py @@ -147,6 +147,11 @@ def single_roi_extractor__forward(ctx, device=target_lvls.device) target_lvls = torch.cat((_tmp, _tmp, target_lvls)) for i in range(num_levels): + # use the roi align in torhcvision to accelerate the inference + # roi_align in MMCV is same as torchvision when pool mode is 'avg' + if backend == Backend.TORCHSCRIPT or self.roi_layers[ + i].pool_mode == 'avg': + self.roi_layers[i].use_torchvision = True mask = target_lvls == i inds = mask.nonzero(as_tuple=False).squeeze(1) rois_t = rois[inds]