From eab8cfa2dc89ff05b144b9af4e68e26cd97f6e5c Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Mon, 25 Apr 2022 09:22:03 +0400 Subject: [PATCH] [Fix] NMS for Point RCNN (#1418) * fix nms for point rcnn * add else case --- mmdet3d/models/dense_heads/point_rpn_head.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/mmdet3d/models/dense_heads/point_rpn_head.py b/mmdet3d/models/dense_heads/point_rpn_head.py index 04755942e0..5561ce52db 100644 --- a/mmdet3d/models/dense_heads/point_rpn_head.py +++ b/mmdet3d/models/dense_heads/point_rpn_head.py @@ -3,6 +3,7 @@ from mmcv.runner import BaseModule, force_fp32 from torch import nn as nn +from mmdet3d.core import xywhr2xyxyr from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes, LiDARInstance3DBoxes) from mmdet3d.core.post_processing import nms_bev, nms_normal_bev @@ -320,29 +321,33 @@ def class_agnostic_nms(self, obj_scores, sem_scores, bbox, points, else: raise NotImplementedError('Unsupported bbox type!') - bbox = bbox.tensor[nonempty_box_mask] + bbox = bbox[nonempty_box_mask] if self.test_cfg.score_thr is not None: score_thr = self.test_cfg.score_thr keep = (obj_scores >= score_thr) obj_scores = obj_scores[keep] sem_scores = sem_scores[keep] - bbox = bbox[keep] + bbox = bbox.tensor[keep] if obj_scores.shape[0] > 0: topk = min(nms_cfg.nms_pre, obj_scores.shape[0]) obj_scores_nms, indices = torch.topk(obj_scores, k=topk) - bbox_for_nms = bbox[indices] + bbox_for_nms = xywhr2xyxyr(bbox[indices].bev) sem_scores_nms = sem_scores[indices] - keep = nms_func(bbox_for_nms[:, 0:7], obj_scores_nms, - nms_cfg.iou_thr) + keep = nms_func(bbox_for_nms, obj_scores_nms, nms_cfg.iou_thr) keep = keep[:nms_cfg.nms_post] - bbox_selected = bbox_for_nms[keep] + bbox_selected = bbox.tensor[indices][keep] score_selected = obj_scores_nms[keep] cls_preds = sem_scores_nms[keep] labels = torch.argmax(cls_preds, -1) + else: + bbox_selected = bbox.tensor + score_selected = obj_scores.new_zeros([0]) + labels = obj_scores.new_zeros([0]) + cls_preds = obj_scores.new_zeros([0, sem_scores.shape[-1]]) return bbox_selected, score_selected, labels, cls_preds