diff --git a/mmdet3d/core/__init__.py b/mmdet3d/core/__init__.py index 22c1bd251f..1b1e85619d 100644 --- a/mmdet3d/core/__init__.py +++ b/mmdet3d/core/__init__.py @@ -2,5 +2,6 @@ from .bbox import * # noqa: F401, F403 from .evaluation import * # noqa: F401, F403 from .post_processing import * # noqa: F401, F403 +from .utils import * # noqa: F401, F403 from .visualizer import * # noqa: F401, F403 from .voxel import * # noqa: F401, F403 diff --git a/mmdet3d/core/bbox/coders/__init__.py b/mmdet3d/core/bbox/coders/__init__.py index 78d4e60272..4d2c93d806 100644 --- a/mmdet3d/core/bbox/coders/__init__.py +++ b/mmdet3d/core/bbox/coders/__init__.py @@ -1,9 +1,10 @@ from mmdet.core.bbox import build_bbox_coder from .anchor_free_bbox_coder import AnchorFreeBBoxCoder +from .centerpoint_bbox_coders import CenterPointBBoxCoder from .delta_xyzwhlr_bbox_coder import DeltaXYZWLHRBBoxCoder from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder __all__ = [ 'build_bbox_coder', 'DeltaXYZWLHRBBoxCoder', 'PartialBinBasedBBoxCoder', - 'AnchorFreeBBoxCoder' + 'CenterPointBBoxCoder', 'AnchorFreeBBoxCoder' ] diff --git a/mmdet3d/core/bbox/coders/centerpoint_bbox_coders.py b/mmdet3d/core/bbox/coders/centerpoint_bbox_coders.py new file mode 100644 index 0000000000..8180ff8faf --- /dev/null +++ b/mmdet3d/core/bbox/coders/centerpoint_bbox_coders.py @@ -0,0 +1,227 @@ +import torch + +from mmdet.core.bbox import BaseBBoxCoder +from mmdet.core.bbox.builder import BBOX_CODERS + + +@BBOX_CODERS.register_module() +class CenterPointBBoxCoder(BaseBBoxCoder): + """Bbox coder for CenterPoint. + + Args: + pc_range (list[float]): Range of point cloud. + out_size_factor (int): Downsample factor of the model. + voxel_size (list[float]): Size of voxel. + post_center_range (list[float]): Limit of the center. + Default: None. + max_num (int): Max number to be kept. Default: 100. + score_threshold (float): Threshold to filter boxes based on score. + Default: None. + code_size (int): Code size of bboxes. Default: 9 + """ + + def __init__(self, + pc_range, + out_size_factor, + voxel_size, + post_center_range=None, + max_num=100, + score_threshold=None, + code_size=9): + + self.pc_range = pc_range + self.out_size_factor = out_size_factor + self.voxel_size = voxel_size + self.post_center_range = post_center_range + self.max_num = max_num + self.score_threshold = score_threshold + self.code_size = code_size + + def _gather_feat(self, feats, inds, feat_masks=None): + """Given feats and indexes, returns the gathered feats. + + Args: + feats (torch.Tensor): Features to be transposed and gathered + with the shape of [B, 2, W, H]. + inds (torch.Tensor): Indexes with the shape of [B, N]. + feat_masks (torch.Tensor): Mask of the feats. Default: None. + + Returns: + torch.Tensor: Gathered feats. + """ + dim = feats.size(2) + inds = inds.unsqueeze(2).expand(inds.size(0), inds.size(1), dim) + feats = feats.gather(1, inds) + if feat_masks is not None: + feat_masks = feat_masks.unsqueeze(2).expand_as(feats) + feats = feats[feat_masks] + feats = feats.view(-1, dim) + return feats + + def _topk(self, scores, K=80): + """Get indexes based on scores. + + Args: + scores (torch.Tensor): scores with the shape of [B, N, W, H]. + K (int): Number to be kept. Defaults to 80. + + Returns: + tuple[torch.Tensor] + torch.Tensor: Selected scores with the shape of [B, K]. + torch.Tensor: Selected indexes with the shape of [B, K]. + torch.Tensor: Selected classes with the shape of [B, K]. + torch.Tensor: Selected y coord with the shape of [B, K]. + torch.Tensor: Selected x coord with the shape of [B, K]. + """ + batch, cat, height, width = scores.size() + + topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K) + + topk_inds = topk_inds % (height * width) + topk_ys = (topk_inds.float() / + torch.tensor(width, dtype=torch.float)).int().float() + topk_xs = (topk_inds % width).int().float() + + topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K) + topk_clses = (topk_ind / torch.tensor(K, dtype=torch.float)).int() + topk_inds = self._gather_feat(topk_inds.view(batch, -1, 1), + topk_ind).view(batch, K) + topk_ys = self._gather_feat(topk_ys.view(batch, -1, 1), + topk_ind).view(batch, K) + topk_xs = self._gather_feat(topk_xs.view(batch, -1, 1), + topk_ind).view(batch, K) + + return topk_score, topk_inds, topk_clses, topk_ys, topk_xs + + def _transpose_and_gather_feat(self, feat, ind): + """Given feats and indexes, returns the transposed and gathered feats. + + Args: + feat (torch.Tensor): Features to be transposed and gathered + with the shape of [B, 2, W, H]. + ind (torch.Tensor): Indexes with the shape of [B, N]. + + Returns: + torch.Tensor: Transposed and gathered feats. + """ + feat = feat.permute(0, 2, 3, 1).contiguous() + feat = feat.view(feat.size(0), -1, feat.size(3)) + feat = self._gather_feat(feat, ind) + return feat + + def encode(self): + pass + + def decode(self, + heat, + rot_sine, + rot_cosine, + hei, + dim, + vel, + reg=None, + task_id=-1): + """Decode bboxes. + + Args: + heat (torch.Tensor): Heatmap with the shape of [B, N, W, H]. + rot_sine (torch.Tensor): Sine of rotation with the shape of + [B, 1, W, H]. + rot_cosine (torch.Tensor): Cosine of rotation with the shape of + [B, 1, W, H]. + hei (torch.Tensor): Height of the boxes with the shape + of [B, 1, W, H]. + dim (torch.Tensor): Dim of the boxes with the shape of + [B, 1, W, H]. + vel (torch.Tensor): Velocity with the shape of [B, 1, W, H]. + reg (torch.Tensor): Regression value of the boxes in 2D with + the shape of [B, 2, W, H]. Default: None. + task_id (int): Index of task. Default: -1. + + Returns: + list[dict]: Decoded boxes. + """ + batch, cat, _, _ = heat.size() + + scores, inds, clses, ys, xs = self._topk(heat, K=self.max_num) + + if reg is not None: + reg = self._transpose_and_gather_feat(reg, inds) + reg = reg.view(batch, self.max_num, 2) + xs = xs.view(batch, self.max_num, 1) + reg[:, :, 0:1] + ys = ys.view(batch, self.max_num, 1) + reg[:, :, 1:2] + else: + xs = xs.view(batch, self.max_num, 1) + 0.5 + ys = ys.view(batch, self.max_num, 1) + 0.5 + + # rotation value and direction label + rot_sine = self._transpose_and_gather_feat(rot_sine, inds) + rot_sine = rot_sine.view(batch, self.max_num, 1) + + rot_cosine = self._transpose_and_gather_feat(rot_cosine, inds) + rot_cosine = rot_cosine.view(batch, self.max_num, 1) + rot = torch.atan2(rot_sine, rot_cosine) + + # height in the bev + hei = self._transpose_and_gather_feat(hei, inds) + hei = hei.view(batch, self.max_num, 1) + + # dim of the box + dim = self._transpose_and_gather_feat(dim, inds) + dim = dim.view(batch, self.max_num, 3) + + # class label + clses = clses.view(batch, self.max_num).float() + scores = scores.view(batch, self.max_num) + + xs = xs.view( + batch, self.max_num, + 1) * self.out_size_factor * self.voxel_size[0] + self.pc_range[0] + ys = ys.view( + batch, self.max_num, + 1) * self.out_size_factor * self.voxel_size[1] + self.pc_range[1] + + if vel is None: # KITTI FORMAT + final_box_preds = torch.cat([xs, ys, hei, dim, rot], dim=2) + else: # exist velocity, nuscene format + vel = self._transpose_and_gather_feat(vel, inds) + vel = vel.view(batch, self.max_num, 2) + final_box_preds = torch.cat([xs, ys, hei, dim, rot, vel], dim=2) + + final_scores = scores + final_preds = clses + + # use score threshold + if self.score_threshold is not None: + thresh_mask = final_scores > self.score_threshold + + if self.post_center_range is not None: + self.post_center_range = torch.tensor( + self.post_center_range, device=heat.device) + mask = (final_box_preds[..., :3] >= + self.post_center_range[:3]).all(2) + mask &= (final_box_preds[..., :3] <= + self.post_center_range[3:]).all(2) + + predictions_dicts = [] + for i in range(batch): + cmask = mask[i, :] + if self.score_threshold: + cmask &= thresh_mask[i] + + boxes3d = final_box_preds[i, cmask] + scores = final_scores[i, cmask] + labels = final_preds[i, cmask] + predictions_dict = { + 'bboxes': boxes3d, + 'scores': scores, + 'labels': labels + } + + predictions_dicts.append(predictions_dict) + else: + raise NotImplementedError( + 'Need to reorganize output as a batch, only ' + 'support post_center_range is not None for now!') + + return predictions_dicts diff --git a/mmdet3d/core/post_processing/__init__.py b/mmdet3d/core/post_processing/__init__.py index 1d92f08f63..2f67514add 100644 --- a/mmdet3d/core/post_processing/__init__.py +++ b/mmdet3d/core/post_processing/__init__.py @@ -1,11 +1,11 @@ from mmdet.core.post_processing import (merge_aug_bboxes, merge_aug_masks, merge_aug_proposals, merge_aug_scores, multiclass_nms) -from .box3d_nms import aligned_3d_nms, box3d_multiclass_nms +from .box3d_nms import aligned_3d_nms, box3d_multiclass_nms, circle_nms from .merge_augs import merge_aug_bboxes_3d __all__ = [ 'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes', 'merge_aug_scores', 'merge_aug_masks', 'box3d_multiclass_nms', - 'aligned_3d_nms', 'merge_aug_bboxes_3d' + 'aligned_3d_nms', 'merge_aug_bboxes_3d', 'circle_nms' ] diff --git a/mmdet3d/core/post_processing/box3d_nms.py b/mmdet3d/core/post_processing/box3d_nms.py index c517b156cb..5731f68304 100644 --- a/mmdet3d/core/post_processing/box3d_nms.py +++ b/mmdet3d/core/post_processing/box3d_nms.py @@ -1,3 +1,5 @@ +import numba +import numpy as np import torch from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu @@ -134,3 +136,46 @@ def aligned_3d_nms(boxes, scores, classes, thresh): indices = boxes.new_tensor(pick, dtype=torch.long) return indices + + +@numba.jit(nopython=True) +def circle_nms(dets, thresh, post_max_size=83): + """Circular NMS. + + An object is only counted as positive if no other center + with a higher confidence exists within a radius r using a + bird-eye view distance metric. + + Args: + dets (torch.Tensor): Detection results with the shape of [N, 3]. + thresh (float): Value of threshold. + post_max_size (int): Max number of prediction to be kept. Defaults + to 83 + + Returns: + torch.Tensor: Indexes of the detections to be kept. + """ + x1 = dets[:, 0] + y1 = dets[:, 1] + scores = dets[:, 2] + order = scores.argsort()[::-1].astype(np.int32) # highest->lowest + ndets = dets.shape[0] + suppressed = np.zeros((ndets), dtype=np.int32) + keep = [] + for _i in range(ndets): + i = order[_i] # start with highest score box + if suppressed[ + i] == 1: # if any box have enough iou with this, remove it + continue + keep.append(i) + for _j in range(_i + 1, ndets): + j = order[_j] + if suppressed[j] == 1: + continue + # calculate center distance between i and j box + dist = (x1[i] - x1[j])**2 + (y1[i] - y1[j])**2 + + # ovr = inter / areas[j] + if dist <= thresh: + suppressed[j] = 1 + return keep[:post_max_size] diff --git a/mmdet3d/core/utils/__init__.py b/mmdet3d/core/utils/__init__.py new file mode 100644 index 0000000000..ad936667bd --- /dev/null +++ b/mmdet3d/core/utils/__init__.py @@ -0,0 +1,3 @@ +from .gaussian import draw_heatmap_gaussian, gaussian_2d, gaussian_radius + +__all__ = ['gaussian_2d', 'gaussian_radius', 'draw_heatmap_gaussian'] diff --git a/mmdet3d/core/utils/gaussian.py b/mmdet3d/core/utils/gaussian.py new file mode 100644 index 0000000000..28605f2601 --- /dev/null +++ b/mmdet3d/core/utils/gaussian.py @@ -0,0 +1,85 @@ +import numpy as np +import torch + + +def gaussian_2d(shape, sigma=1): + """Generate gaussian map. + + Args: + shape (list[int]): Shape of the map. + sigma (float): Sigma to generate gaussian map. + Defaults to 1. + + Returns: + np.ndarray: Generated gaussian map. + """ + m, n = [(ss - 1.) / 2. for ss in shape] + y, x = np.ogrid[-m:m + 1, -n:n + 1] + + h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) + h[h < np.finfo(h.dtype).eps * h.max()] = 0 + return h + + +def draw_heatmap_gaussian(heatmap, center, radius, k=1): + """Get gaussian masked heatmap. + + Args: + heatmap (torch.Tensor): Heatmap to be masked. + center (torch.Tensor): Center coord of the heatmap. + radius (int): Radius of gausian. + K (int): Multiple of masked_gaussian. Defaults to 1. + + Returns: + torch.Tensor: Masked heatmap. + """ + diameter = 2 * radius + 1 + gaussian = gaussian_2d((diameter, diameter), sigma=diameter / 6) + + x, y = int(center[0]), int(center[1]) + + height, width = heatmap.shape[0:2] + + left, right = min(x, radius), min(width - x, radius + 1) + top, bottom = min(y, radius), min(height - y, radius + 1) + + masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] + masked_gaussian = torch.from_numpy( + gaussian[radius - top:radius + bottom, + radius - left:radius + right]).to(heatmap.device, + torch.float32) + if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: + torch.max(masked_heatmap, masked_gaussian * k, out=masked_heatmap) + return heatmap + + +def gaussian_radius(det_size, min_overlap=0.5): + """Get radius of gaussian. + + Args: + det_size (tuple[torch.Tensor]): Size of the detection result. + min_overlap (float): Gaussian_overlap. Defaults to 0.5. + + Returns: + torch.Tensor: Computed radius. + """ + height, width = det_size + + a1 = 1 + b1 = (height + width) + c1 = width * height * (1 - min_overlap) / (1 + min_overlap) + sq1 = torch.sqrt(b1**2 - 4 * a1 * c1) + r1 = (b1 + sq1) / 2 + + a2 = 4 + b2 = 2 * (height + width) + c2 = (1 - min_overlap) * width * height + sq2 = torch.sqrt(b2**2 - 4 * a2 * c2) + r2 = (b2 + sq2) / 2 + + a3 = 4 * min_overlap + b3 = -2 * min_overlap * (height + width) + c3 = (min_overlap - 1) * width * height + sq3 = torch.sqrt(b3**2 - 4 * a3 * c3) + r3 = (b3 + sq3) / 2 + return min(r1, r2, r3) diff --git a/mmdet3d/models/roi_heads/bbox_heads/__init__.py b/mmdet3d/models/roi_heads/bbox_heads/__init__.py index 0256706a0c..38a9d4824f 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/__init__.py +++ b/mmdet3d/models/roi_heads/bbox_heads/__init__.py @@ -3,10 +3,11 @@ Shared2FCBBoxHead, Shared4Conv1FCBBoxHead) from .h3d_bbox_head import H3DBboxHead +from .multi_group_head import CenterHead from .parta2_bbox_head import PartA2BboxHead __all__ = [ 'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead', - 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'H3DBboxHead', - 'PartA2BboxHead' + 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'PartA2BboxHead', + 'H3DBboxHead', 'CenterHead' ] diff --git a/mmdet3d/models/roi_heads/bbox_heads/multi_group_head.py b/mmdet3d/models/roi_heads/bbox_heads/multi_group_head.py new file mode 100644 index 0000000000..cc2ce05914 --- /dev/null +++ b/mmdet3d/models/roi_heads/bbox_heads/multi_group_head.py @@ -0,0 +1,809 @@ +import copy +import numpy as np +import torch +from mmcv.cnn import ConvModule, build_conv_layer, kaiming_init +from torch import nn + +from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius, + xywhr2xyxyr) +from mmdet3d.models.utils import clip_sigmoid +from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu +from mmdet.core import build_bbox_coder, multi_apply +from ... import builder +from ...builder import HEADS, build_loss + + +@HEADS.register_module() +class SeparateHead(nn.Module): + """SeparateHead for CenterHead. + + Args: + in_channels (int): Input channels for conv_layer. + heads (dict): Conv information. + head_conv (int): Output channels. + Default: 64. + final_kernal (int): Kernal size for the last conv layer. + Deafult: 1. + init_bias (float): Initial bias. Default: -2.19. + conv_cfg (dict): Config of conv layer. + Default: dict(type='Conv2d') + norm_cfg (dict): Config of norm layer. + Default: dict(type='BN2d'). + bias (str): Type of bias. Default: 'auto'. + """ + + def __init__(self, + in_channels, + heads, + head_conv=64, + final_kernel=1, + init_bias=-2.19, + conv_cfg=dict(type='Conv2d'), + norm_cfg=dict(type='BN2d'), + bias='auto', + **kwargs): + super(SeparateHead, self).__init__() + + self.heads = heads + self.init_bias = init_bias + for head in self.heads: + classes, num_conv = self.heads[head] + + conv_layers = [] + for i in range(num_conv - 1): + conv_layers.append( + ConvModule( + in_channels, + head_conv, + kernel_size=final_kernel, + stride=1, + padding=final_kernel // 2, + bias=bias, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + + conv_layers.append( + build_conv_layer( + conv_cfg, + head_conv, + classes, + kernel_size=final_kernel, + stride=1, + padding=final_kernel // 2, + bias=True)) + conv_layers = nn.Sequential(*conv_layers) + + self.__setattr__(head, conv_layers) + + def init_weights(self): + """Initialize weights.""" + for head in self.heads: + if head == 'heatmap': + self.__getattr__(head)[-1].bias.data.fill_(self.init_bias) + else: + for m in self.__getattr__(head).modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + + def forward(self, x): + """Forward function for SepHead. + + Args: + x (torch.Tensor): Input feature map with the shape of + [B, 512, 128, 128]. + + Returns: + dict[str: torch.Tensor]: contains the following keys: + + -reg (torch.Tensor): 2D regression value with the \ + shape of [B, 2, H, W]. + -height (torch.Tensor): Height value with the \ + shape of [B, 1, H, W]. + -dim (torch.Tensor): Size value with the shape \ + of [B, 3, H, W]. + -rot (torch.Tensor): Rotation value with the \ + shape of [B, 2, H, W]. + -vel (torch.Tensor): Velocity value with the \ + shape of [B, 2, H, W]. + -heatmap (torch.Tensor): Heatmap with the shape of \ + [B, N, H, W]. + """ + ret_dict = dict() + for head in self.heads: + ret_dict[head] = self.__getattr__(head)(x) + + return ret_dict + + +@HEADS.register_module() +class DCNSeperateHead(nn.Module): + r"""DCNSeperateHead for CenterHead. + + .. code-block:: none + /-----> DCN for heatmap task -----> heatmap task. + feature + \-----> DCN for regression tasks -----> regression tasks + + Args: + in_channels (int): Input channels for conv_layer. + heads (dict): Conv information. + dcn_config (dict): Config of dcn layer. + num_cls (int): Output channels. + Default: 64. + final_kernal (int): Kernal size for the last conv layer. + Deafult: 1. + init_bias (float): Initial bias. Default: -2.19. + conv_cfg (dict): Config of conv layer. + Default: dict(type='Conv2d') + norm_cfg (dict): Config of norm layer. + Default: dict(type='BN2d'). + bias (str): Type of bias. Default: 'auto'. + """ # noqa: W605 + + def __init__(self, + in_channels, + num_cls, + heads, + dcn_config, + head_conv=64, + final_kernel=1, + init_bias=-2.19, + conv_cfg=dict(type='Conv2d'), + norm_cfg=dict(type='BN2d'), + bias='auto', + **kwargs): + super(DCNSeperateHead, self).__init__() + if 'heatmap' in heads: + heads.pop('heatmap') + # feature adaptation with dcn + # use separate features for classification / regression + self.feature_adapt_cls = build_conv_layer(dcn_config) + + self.feature_adapt_reg = build_conv_layer(dcn_config) + + # heatmap prediction head + cls_head = [ + ConvModule( + in_channels, + head_conv, + kernel_size=3, + padding=1, + conv_cfg=conv_cfg, + bias=bias, + norm_cfg=norm_cfg), + build_conv_layer( + conv_cfg, + head_conv, + num_cls, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + ] + self.cls_head = nn.Sequential(*cls_head) + self.init_bias = init_bias + # other regression target + self.task_head = SeparateHead( + in_channels, + heads, + head_conv=head_conv, + final_kernel=final_kernel, + bias=bias) + + def init_weights(self): + """Initialize weights.""" + self.cls_head[-1].bias.data.fill_(self.init_bias) + self.task_head.init_weights() + + def forward(self, x): + """Forward function for DCNSepHead. + + Args: + x (torch.Tensor): Input feature map with the shape of + [B, 512, 128, 128]. + + Returns: + dict[str: torch.Tensor]: contains the following keys: + + -reg (torch.Tensor): 2D regression value with the \ + shape of [B, 2, H, W]. + -height (torch.Tensor): Height value with the \ + shape of [B, 1, H, W]. + -dim (torch.Tensor): Size value with the shape \ + of [B, 3, H, W]. + -rot (torch.Tensor): Rotation value with the \ + shape of [B, 2, H, W]. + -vel (torch.Tensor): Velocity value with the \ + shape of [B, 2, H, W]. + -heatmap (torch.Tensor): Heatmap with the shape of \ + [B, N, H, W]. + """ + center_feat = self.feature_adapt_cls(x) + reg_feat = self.feature_adapt_reg(x) + + cls_score = self.cls_head(center_feat) + ret = self.task_head(reg_feat) + ret['heatmap'] = cls_score + + return ret + + +@HEADS.register_module +class CenterHead(nn.Module): + """CenterHead for CenterPoint. + + Args: + mode (str): Mode of the head. Default: '3d'. + in_channels (list[int] | int): Channels of the input feature map. + Default: [128]. + tasks (list[dict]): Task information including class number + and class names. Default: None. + dataset (str): Name of the dataset. Default: 'nuscenes'. + weight (float): Weight for location loss. Default: 0.25. + code_weights (list[int]): Code weights for location loss. Default: []. + common_heads (dict): Conv information for common heads. + Default: dict(). + loss_cls (dict): Config of classification loss function. + Default: dict(type='GaussianFocalLoss', reduction='mean'). + loss_bbox (dict): Config of regression loss function. + Default: dict(type='L1Loss', reduction='none'). + seperate_head (dict): Config of seperate head. Default: dict( + type='SeparateHead', init_bias=-2.19, final_kernel=3) + share_conv_channel (int): Output channels for share_conv_layer. + Default: 64. + num_heatmap_convs (int): Number of conv layers for heatmap conv layer. + Default: 2. + conv_cfg (dict): Config of conv layer. + Default: dict(type='Conv2d') + norm_cfg (dict): Config of norm layer. + Default: dict(type='BN2d'). + bias (str): Type of bias. Default: 'auto'. + """ + + def __init__(self, + in_channels=[128], + tasks=None, + train_cfg=None, + test_cfg=None, + bbox_coder=None, + common_heads=dict(), + loss_cls=dict(type='GaussianFocalLoss', reduction='mean'), + loss_bbox=dict( + type='L1Loss', reduction='none', loss_weight=0.25), + seperate_head=dict( + type='SeparateHead', init_bias=-2.19, final_kernel=3), + share_conv_channel=64, + num_heatmap_convs=2, + conv_cfg=dict(type='Conv2d'), + norm_cfg=dict(type='BN2d'), + bias='auto', + norm_bbox=True): + super(CenterHead, self).__init__() + + num_classes = [len(t['class_names']) for t in tasks] + self.class_names = [t['class_names'] for t in tasks] + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.in_channels = in_channels + self.num_classes = num_classes + self.norm_bbox = norm_bbox + + self.loss_cls = build_loss(loss_cls) + self.loss_bbox = build_loss(loss_bbox) + self.bbox_coder = build_bbox_coder(bbox_coder) + self.num_anchor_per_locs = [n for n in num_classes] + + # a shared convolution + self.shared_conv = ConvModule( + in_channels, + share_conv_channel, + kernel_size=3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=bias) + + self.task_heads = nn.ModuleList() + + for num_cls in num_classes: + heads = copy.deepcopy(common_heads) + heads.update(dict(heatmap=(num_cls, num_heatmap_convs))) + seperate_head.update( + in_channels=share_conv_channel, heads=heads, num_cls=num_cls) + self.task_heads.append(builder.build_head(seperate_head)) + + def init_weights(self): + """Initialize weights.""" + for task_head in self.task_heads: + task_head.init_weights() + + def forward_single(self, x): + """Forward function for CenterPoint. + + Args: + x (torch.Tensor): Input feature map with the shape of + [B, 512, 128, 128]. + + Returns: + list[dict]: Output results for tasks. + """ + ret_dicts = [] + + x = self.shared_conv(x) + + for task in self.task_heads: + ret_dicts.append(task(x)) + + return ret_dicts + + def forward(self, feats): + """Forward pass. + + Args: + feats (list[torch.Tensor]): Multi-level features, e.g., + features produced by FPN. + + Returns: + tuple(list[dict]): Output results for tasks. + """ + return multi_apply(self.forward_single, feats) + + def _gather_feat(self, feat, ind, mask=None): + """Gather feature map. + + Given feature map and index, return indexed feature map. + + Args: + feat (torch.tensor): Feature map with the shape of [B, H*W, 10]. + ind (torch.Tensor): Index of the ground truth boxes with the + shape of [B, max_obj]. + mask (torch.Tensor): Mask of the feature map with the shape + of [B, max_obj]. Default: None. + + Returns: + torch.Tensor: Feature map after gathering with the shape + of [B, max_obj, 10]. + """ + dim = feat.size(2) + ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) + feat = feat.gather(1, ind) + if mask is not None: + mask = mask.unsqueeze(2).expand_as(feat) + feat = feat[mask] + feat = feat.view(-1, dim) + return feat + + def get_targets(self, gt_bboxes_3d, gt_labels_3d): + """Generate targets. + + Args: + gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground + truth gt boxes. + gt_labels_3d (list[torch.Tensor]): Labels of boxes. + + Returns: + Returns: + tuple[list[torch.Tensor]]: Tuple of target including \ + the following results in order. + + - list[torch.Tensor]: Heatmap scores. + - list[torch.Tensor]: Ground truth boxes. + - list[torch.Tensor]: Indexes indicating the \ + position of the valid boxes. + - list[torch.Tensor]: Masks indicating which \ + boxes are valid. + """ + heatmaps, anno_boxes, inds, masks = multi_apply( + self.get_targets_single, gt_bboxes_3d, gt_labels_3d) + # transpose heatmaps, because the dimension of tensors in each task is + # different, we have to use numpy instead of torch to do the transpose. + heatmaps = np.array(heatmaps).transpose(1, 0).tolist() + heatmaps = [torch.stack(hms_) for hms_ in heatmaps] + # transpose anno_boxes + anno_boxes = np.array(anno_boxes).transpose(1, 0).tolist() + anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes] + # transpose inds + inds = np.array(inds).transpose(1, 0).tolist() + inds = [torch.stack(inds_) for inds_ in inds] + # transpose inds + masks = np.array(masks).transpose(1, 0).tolist() + masks = [torch.stack(masks_) for masks_ in masks] + return heatmaps, anno_boxes, inds, masks + + def get_targets_single(self, gt_bboxes_3d, gt_labels_3d): + """Generate training targets for a single sample. + + Args: + gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth gt boxes. + gt_labels_3d (torch.Tensor): Labels of boxes. + + Returns: + tuple[list[torch.Tensor]]: Tuple of target including \ + the following results in order. + + - list[torch.Tensor]: Heatmap scores. + - list[torch.Tensor]: Ground truth boxes. + - list[torch.Tensor]: Indexes indicating the position \ + of the valid boxes. + - list[torch.Tensor]: Masks indicating which boxes \ + are valid. + """ + device = gt_labels_3d.device + gt_bboxes_3d = torch.cat( + (gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]), + dim=1).to(device) + max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg'] + grid_size = torch.tensor(self.train_cfg['grid_size']) + pc_range = torch.tensor(self.train_cfg['point_cloud_range']) + voxel_size = torch.tensor(self.train_cfg['voxel_size']) + + feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor'] + + # reorganize the gt_dict by tasks + task_masks = [] + flag = 0 + for class_name in self.class_names: + task_masks.append([ + torch.where(gt_labels_3d == class_name.index(i) + flag) + for i in class_name + ]) + flag += len(class_name) + + task_boxes = [] + task_classes = [] + flag2 = 0 + for idx, mask in enumerate(task_masks): + task_box = [] + task_class = [] + for m in mask: + task_box.append(gt_bboxes_3d[m]) + # 0 is background for each task, so we need to add 1 here. + task_class.append(gt_labels_3d[m] + 1 - flag2) + task_boxes.append(torch.cat(task_box, axis=0).to(device)) + task_classes.append(torch.cat(task_class).long().to(device)) + flag2 += len(mask) + draw_gaussian = draw_heatmap_gaussian + heatmaps, anno_boxes, inds, masks = [], [], [], [] + + for idx, task_head in enumerate(self.task_heads): + heatmap = gt_bboxes_3d.new_zeros( + (len(self.class_names[idx]), feature_map_size[1], + feature_map_size[0])) + + anno_box = gt_bboxes_3d.new_zeros((max_objs, 10), + dtype=torch.float32) + + ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64) + mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8) + + num_objs = min(task_boxes[idx].shape[0], max_objs) + + for k in range(num_objs): + cls_id = task_classes[idx][k] - 1 + + width = task_boxes[idx][k][3] + length = task_boxes[idx][k][4] + width = width / voxel_size[0] / self.train_cfg[ + 'out_size_factor'] + length = length / voxel_size[1] / self.train_cfg[ + 'out_size_factor'] + + if width > 0 and length > 0: + radius = gaussian_radius( + (length, width), + min_overlap=self.train_cfg['gaussian_overlap']) + radius = max(self.train_cfg['min_radius'], int(radius)) + + # be really careful for the coordinate system of + # your box annotation. + x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][ + 1], task_boxes[idx][k][2] + + coor_x = ( + x - pc_range[0] + ) / voxel_size[0] / self.train_cfg['out_size_factor'] + coor_y = ( + y - pc_range[1] + ) / voxel_size[1] / self.train_cfg['out_size_factor'] + + center = torch.tensor([coor_x, coor_y], + dtype=torch.float32, + device=device) + center_int = center.to(torch.int32) + + # throw out not in range objects to avoid out of array + # area when creating the heatmap + if not (0 <= center_int[0] < feature_map_size[0] + and 0 <= center_int[1] < feature_map_size[1]): + continue + + draw_gaussian(heatmap[cls_id], center_int, radius) + + new_idx = k + x, y = center_int[0], center_int[1] + + assert (y * feature_map_size[0] + x < + feature_map_size[0] * feature_map_size[1]) + + ind[new_idx] = y * feature_map_size[0] + x + mask[new_idx] = 1 + # TODO: support other outdoor dataset + vx, vy = task_boxes[idx][k][7:] + rot = task_boxes[idx][k][6] + box_dim = task_boxes[idx][k][3:6] + if self.norm_bbox: + box_dim = box_dim.log() + anno_box[new_idx] = torch.cat([ + center - torch.tensor([x, y], device=device), + z.unsqueeze(0), box_dim, + torch.sin(rot).unsqueeze(0), + torch.cos(rot).unsqueeze(0), + vx.unsqueeze(0), + vy.unsqueeze(0) + ]) + + heatmaps.append(heatmap) + anno_boxes.append(anno_box) + masks.append(mask) + inds.append(ind) + return heatmaps, anno_boxes, inds, masks + + def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs): + """Loss function for CenterHead. + + Args: + gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground + truth gt boxes. + gt_labels_3d (list[torch.Tensor]): Labels of boxes. + preds_dicts (dict): Output of forward function. + + Returns: + dict[str:torch.Tensor]: Loss of heatmap and bbox of each task. + """ + heatmaps, anno_boxes, inds, masks = self.get_targets( + gt_bboxes_3d, gt_labels_3d) + loss_dict = dict() + for task_id, preds_dict in enumerate(preds_dicts): + # heatmap focal loss + preds_dict[0]['heatmap'] = clip_sigmoid(preds_dict[0]['heatmap']) + num_pos = heatmaps[task_id].eq(1).float().sum().item() + loss_heatmap = self.loss_cls( + preds_dict[0]['heatmap'], + heatmaps[task_id], + avg_factor=max(num_pos, 1)) + target_box = anno_boxes[task_id] + # reconstruct the anno_box from multiple reg heads + preds_dict[0]['anno_box'] = torch.cat( + (preds_dict[0]['reg'], preds_dict[0]['height'], + preds_dict[0]['dim'], preds_dict[0]['rot'], + preds_dict[0]['vel']), + dim=1) + + # Regression loss for dimension, offset, height, rotation + ind = inds[task_id] + num = masks[task_id].float().sum() + pred = preds_dict[0]['anno_box'].permute(0, 2, 3, 1).contiguous() + pred = pred.view(pred.size(0), -1, pred.size(3)) + pred = self._gather_feat(pred, ind) + mask = masks[task_id].unsqueeze(2).expand_as(target_box).float() + isnotnan = (~torch.isnan(target_box)).float() + mask *= isnotnan + + code_weights = self.train_cfg.get('code_weights', None) + bbox_weights = mask * mask.new_tensor(code_weights) + loss_bbox = self.loss_bbox( + pred, target_box, bbox_weights, avg_factor=(num + 1e-4)) + loss_dict[f'task{task_id}.loss_heatmap'] = loss_heatmap + loss_dict[f'task{task_id}.loss_bbox'] = loss_bbox + return loss_dict + + def get_bboxes(self, preds_dicts, img_metas, img=None, rescale=False): + """Generate bboxes from bbox head predictions. + + Args: + preds_dicts (tuple[list[dict]]): Prediction results. + img_metas (list[dict]): Point cloud and image's meta info. + + Returns: + list[dict]: Decoded bbox, scores and labels after nms. + """ + rets = [] + for task_id, preds_dict in enumerate(preds_dicts): + num_class_with_bg = self.num_classes[task_id] + batch_size = preds_dict[0]['heatmap'].shape[0] + batch_heatmap = preds_dict[0]['heatmap'].sigmoid() + + batch_reg = preds_dict[0]['reg'] + batch_hei = preds_dict[0]['height'] + + if self.norm_bbox: + batch_dim = torch.exp(preds_dict[0]['dim']) + else: + batch_dim = preds_dict[0]['dim'] + + batch_rots = preds_dict[0]['rot'][:, 0].unsqueeze(1) + batch_rotc = preds_dict[0]['rot'][:, 1].unsqueeze(1) + + if 'vel' in preds_dict[0]: + batch_vel = preds_dict[0]['vel'] + else: + batch_vel = None + temp = self.bbox_coder.decode( + batch_heatmap, + batch_rots, + batch_rotc, + batch_hei, + batch_dim, + batch_vel, + reg=batch_reg, + task_id=task_id) + assert self.test_cfg['nms_type'] in ['circle', 'rotate'] + batch_reg_preds = [box['bboxes'] for box in temp] + batch_cls_preds = [box['scores'] for box in temp] + batch_cls_labels = [box['labels'] for box in temp] + if self.test_cfg['nms_type'] == 'circle': + ret_task = [] + for i in range(batch_size): + boxes3d = temp[i]['bboxes'] + scores = temp[i]['scores'] + labels = temp[i]['labels'] + centers = boxes3d[:, [0, 1]] + boxes = torch.cat([centers, scores.view(-1, 1)], dim=1) + keep = torch.tensor( + circle_nms( + boxes.detach().cpu().numpy(), + self.test_cfg['min_radius'][task_id], + post_max_size=self.test_cfg['post_max_size']), + dtype=torch.long, + device=boxes.device) + + boxes3d = boxes3d[keep] + scores = scores[keep] + labels = labels[keep] + ret = dict(bboxes=boxes3d, scores=scores, labels=labels) + ret_task.append(ret) + rets.append(ret_task) + else: + rets.append( + self.get_task_detections(num_class_with_bg, + batch_cls_preds, batch_reg_preds, + batch_cls_labels, img_metas)) + + # Merge branches results + num_samples = len(rets[0]) + + ret_list = [] + for i in range(num_samples): + for k in rets[0][i].keys(): + if k == 'bboxes': + bboxes = torch.cat([ret[i][k] for ret in rets]) + bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5 + bboxes = img_metas[i]['box_type_3d']( + bboxes, self.bbox_coder.code_size) + elif k == 'scores': + scores = torch.cat([ret[i][k] for ret in rets]) + elif k == 'labels': + flag = 0 + for j, num_class in enumerate(self.num_classes): + rets[j][i][k] += flag + flag += num_class + labels = torch.cat([ret[i][k] for ret in rets]) + ret_list.append([bboxes, scores, labels]) + return ret_list + + def get_task_detections(self, num_class_with_bg, batch_cls_preds, + batch_reg_preds, batch_cls_labels, img_metas): + """Rotate nms for each task. + + Args: + num_class_with_bg (int): Number of classes for the current task. + batch_cls_preds (list[torch.Tensor]): Prediction score with the + shape of [N]. + batch_reg_preds (list[torch.Tensor]): Prediction bbox with the + shape of [N, 9]. + batch_cls_labels (list[torch.Tensor]): Prediction label with the + shape of [N]. + img_metas (list[dict]): Meta information of each sample. + + Returns: + list[dict[str: torch.Tensor]]: contains the following keys: + + -bboxes (torch.Tensor): Prediction bboxes after nms with the \ + shape of [N, 9]. + -scores (torch.Tensor): Prediction scores after nms with the \ + shape of [N]. + -labels (torch.Tensor): Prediction labels after nms with the \ + shape of [N]. + """ + predictions_dicts = [] + post_center_range = self.test_cfg['post_center_limit_range'] + if len(post_center_range) > 0: + post_center_range = torch.tensor( + post_center_range, + dtype=batch_reg_preds[0].dtype, + device=batch_reg_preds[0].device) + + for i, (box_preds, cls_preds, cls_labels) in enumerate( + zip(batch_reg_preds, batch_cls_preds, batch_cls_labels)): + + # Apply NMS in birdeye view + + # get highest score per prediction, than apply nms + # to remove overlapped box. + if num_class_with_bg == 1: + top_scores = cls_preds.squeeze(-1) + top_labels = torch.zeros( + cls_preds.shape[0], + device=cls_preds.device, + dtype=torch.long) + + else: + top_labels = cls_labels.long() + top_scores = cls_preds.squeeze(-1) + + if self.test_cfg['score_threshold'] > 0.0: + thresh = torch.tensor( + [self.test_cfg['score_threshold']], + device=cls_preds.device).type_as(cls_preds) + top_scores_keep = top_scores >= thresh + top_scores = top_scores.masked_select(top_scores_keep) + + if top_scores.shape[0] != 0: + if self.test_cfg['score_threshold'] > 0.0: + box_preds = box_preds[top_scores_keep] + top_labels = top_labels[top_scores_keep] + + boxes_for_nms = xywhr2xyxyr(img_metas[i]['box_type_3d']( + box_preds[:, :], self.bbox_coder.code_size).bev) + # the nms in 3d detection just remove overlap boxes. + + selected = nms_gpu( + boxes_for_nms, + top_scores, + thresh=self.test_cfg['nms_iou_threshold'], + pre_maxsize=self.test_cfg['nms_pre_max_size'], + post_max_size=self.test_cfg['nms_post_max_size']) + else: + selected = [] + + # if selected is not None: + selected_boxes = box_preds[selected] + selected_labels = top_labels[selected] + selected_scores = top_scores[selected] + + # finally generate predictions. + if selected_boxes.shape[0] != 0: + box_preds = selected_boxes + scores = selected_scores + label_preds = selected_labels + final_box_preds = box_preds + final_scores = scores + final_labels = label_preds + if post_center_range is not None: + mask = (final_box_preds[:, :3] >= + post_center_range[:3]).all(1) + mask &= (final_box_preds[:, :3] <= + post_center_range[3:]).all(1) + predictions_dict = dict( + bboxes=final_box_preds[mask], + scores=final_scores[mask], + labels=final_labels[mask]) + else: + predictions_dict = dict( + bboxes=final_box_preds, + scores=final_scores, + labels=final_labels) + else: + dtype = batch_reg_preds[0].dtype + device = batch_reg_preds[0].device + predictions_dict = dict( + bboxes=torch.zeros([0, self.bbox_coder.code_size], + dtype=dtype, + device=device), + scores=torch.zeros([0], dtype=dtype, device=device), + labels=torch.zeros([0], + dtype=top_labels.dtype, + device=device)) + + predictions_dicts.append(predictions_dict) + return predictions_dicts diff --git a/mmdet3d/models/utils/__init__.py b/mmdet3d/models/utils/__init__.py new file mode 100644 index 0000000000..2206490be1 --- /dev/null +++ b/mmdet3d/models/utils/__init__.py @@ -0,0 +1,3 @@ +from .clip_sigmoid import clip_sigmoid + +__all__ = ['clip_sigmoid'] diff --git a/mmdet3d/models/utils/clip_sigmoid.py b/mmdet3d/models/utils/clip_sigmoid.py new file mode 100644 index 0000000000..5182de1139 --- /dev/null +++ b/mmdet3d/models/utils/clip_sigmoid.py @@ -0,0 +1,16 @@ +import torch + + +def clip_sigmoid(x, eps=1e-4): + """Sigmoid function for input feature. + + Args: + x (torch.Tensor): Input feature map with the shape of [B, N, H, W]. + eps (float): Lower bound of the range to be clamped to. Defaults + to 1e-4. + + Returns: + torch.Tensor: Feature map after sigmoid. + """ + y = torch.clamp(x.sigmoid_(), min=eps, max=1 - eps) + return y diff --git a/mmdet3d/ops/iou3d/iou3d_utils.py b/mmdet3d/ops/iou3d/iou3d_utils.py index 6c256ac53d..6f36019e72 100644 --- a/mmdet3d/ops/iou3d/iou3d_utils.py +++ b/mmdet3d/ops/iou3d/iou3d_utils.py @@ -22,24 +22,32 @@ def boxes_iou_bev(boxes_a, boxes_b): return ans_iou -def nms_gpu(boxes, scores, thresh): - """Non maximum suppression on GPU. +def nms_gpu(boxes, scores, thresh, pre_maxsize=None, post_max_size=None): + """Nms function with gpu implementation. Args: - boxes (torch.Tensor): Input boxes with shape (N, 5). - scores (torch.Tensor): Scores of predicted boxes with shape (N). - thresh (torch.Tensor): Threshold of non maximum suppression. + boxes (torch.Tensor): Input boxes with the shape of [N, 5] + ([x1, y1, x2, y2, ry]). + scores (torch.Tensor): Scores of boxes with the shape of [N]. + thresh (int): Threshold. + pre_maxsize (int): Max size of boxes before nms. Default: None. + post_maxsize (int): Max size of boxes after nms. Default: None. Returns: - torch.Tensor: Remaining indices with scores in descending order. + torch.Tensor: Indexes after nms. """ order = scores.sort(0, descending=True)[1] + if pre_maxsize is not None: + order = order[:pre_maxsize] boxes = boxes[order].contiguous() keep = torch.zeros(boxes.size(0), dtype=torch.long) num_out = iou3d_cuda.nms_gpu(boxes, keep, thresh, boxes.device.index) - return order[keep[:num_out].cuda(boxes.device)].contiguous() + keep = order[keep[:num_out].cuda(boxes.device)].contiguous() + if post_max_size is not None: + keep = keep[:post_max_size] + return keep def nms_normal_gpu(boxes, scores, thresh): diff --git a/tests/test_box_coders.py b/tests/test_bbox_coders.py similarity index 94% rename from tests/test_box_coders.py rename to tests/test_bbox_coders.py index bae37ee3cc..5ba9fbfb28 100644 --- a/tests/test_box_coders.py +++ b/tests/test_bbox_coders.py @@ -323,3 +323,31 @@ def test_anchor_free_box_coder(): assert dir_res_norm.shape == torch.Size([2, 256, 12]) assert dir_res.shape == torch.Size([2, 256, 12]) assert size.shape == torch.Size([2, 256, 3]) + + +def test_centerpoint_bbox_coder(): + bbox_coder_cfg = dict( + type='CenterPointBBoxCoder', + post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + max_num=500, + score_threshold=0.1, + pc_range=[-51.2, -51.2], + out_size_factor=4, + voxel_size=[0.2, 0.2]) + + bbox_coder = build_bbox_coder(bbox_coder_cfg) + + batch_dim = torch.rand([2, 3, 128, 128]) + batch_hei = torch.rand([2, 1, 128, 128]) + batch_hm = torch.rand([2, 2, 128, 128]) + batch_reg = torch.rand([2, 2, 128, 128]) + batch_rotc = torch.rand([2, 1, 128, 128]) + batch_rots = torch.rand([2, 1, 128, 128]) + batch_vel = torch.rand([2, 2, 128, 128]) + + temp = bbox_coder.decode(batch_hm, batch_rots, batch_rotc, batch_hei, + batch_dim, batch_vel, batch_reg, 5) + for i in range(len(temp)): + assert temp[i]['bboxes'].shape == torch.Size([500, 9]) + assert temp[i]['scores'].shape == torch.Size([500]) + assert temp[i]['labels'].shape == torch.Size([500]) diff --git a/tests/test_heads.py b/tests/test_heads.py index aea269e248..561e28e00e 100644 --- a/tests/test_heads.py +++ b/tests/test_heads.py @@ -8,6 +8,7 @@ from mmdet3d.core.bbox import (Box3DMode, DepthInstance3DBoxes, LiDARInstance3DBoxes) from mmdet3d.models.builder import build_head +from mmdet.apis import set_random_seed def _setup_seed(seed): @@ -689,6 +690,199 @@ def test_h3d_head(): assert ret_dict['primitive_sem_matching_loss'] >= 0 +def test_center_head(): + tasks = [ + dict(num_class=1, class_names=['car']), + dict(num_class=2, class_names=['truck', 'construction_vehicle']), + dict(num_class=2, class_names=['bus', 'trailer']), + dict(num_class=1, class_names=['barrier']), + dict(num_class=2, class_names=['motorcycle', 'bicycle']), + dict(num_class=2, class_names=['pedestrian', 'traffic_cone']), + ] + bbox_cfg = dict( + type='CenterPointBBoxCoder', + post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + max_num=500, + score_threshold=0.1, + pc_range=[-51.2, -51.2], + out_size_factor=8, + voxel_size=[0.2, 0.2]) + train_cfg = dict( + grid_size=[1024, 1024, 40], + point_cloud_range=[-51.2, -51.2, -5., 51.2, 51.2, 3.], + voxel_size=[0.1, 0.1, 0.2], + out_size_factor=8, + dense_reg=1, + gaussian_overlap=0.1, + max_objs=500, + code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 1.0, 1.0], + min_radius=2) + test_cfg = dict( + post_center_limit_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + max_per_img=500, + max_pool_nms=False, + min_radius=[4, 12, 10, 1, 0.85, 0.175], + post_max_size=83, + score_threshold=0.1, + pc_range=[-51.2, -51.2], + out_size_factor=8, + voxel_size=[0.2, 0.2], + nms_type='circle') + center_head_cfg = dict( + type='CenterHead', + in_channels=sum([256, 256]), + tasks=tasks, + train_cfg=train_cfg, + test_cfg=test_cfg, + bbox_coder=bbox_cfg, + common_heads=dict( + reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)), + share_conv_channel=64, + norm_bbox=True) + + center_head = build_head(center_head_cfg) + + x = torch.rand([2, 512, 128, 128]) + output = center_head([x]) + for i in range(6): + assert output[i][0]['reg'].shape == torch.Size([2, 2, 128, 128]) + assert output[i][0]['height'].shape == torch.Size([2, 1, 128, 128]) + assert output[i][0]['dim'].shape == torch.Size([2, 3, 128, 128]) + assert output[i][0]['rot'].shape == torch.Size([2, 2, 128, 128]) + assert output[i][0]['vel'].shape == torch.Size([2, 2, 128, 128]) + assert output[i][0]['heatmap'].shape == torch.Size( + [2, tasks[i]['num_class'], 128, 128]) + + # test get_bboxes + img_metas = [ + dict(box_type_3d=LiDARInstance3DBoxes), + dict(box_type_3d=LiDARInstance3DBoxes) + ] + ret_lists = center_head.get_bboxes(output, img_metas) + for ret_list in ret_lists: + assert ret_list[0].tensor.shape[0] <= 500 + assert ret_list[1].shape[0] <= 500 + assert ret_list[2].shape[0] <= 500 + + +def test_dcn_center_head(): + if not torch.cuda.is_available(): + pytest.skip('test requires GPU and CUDA') + set_random_seed(0) + tasks = [ + dict(num_class=1, class_names=['car']), + dict(num_class=2, class_names=['truck', 'construction_vehicle']), + dict(num_class=2, class_names=['bus', 'trailer']), + dict(num_class=1, class_names=['barrier']), + dict(num_class=2, class_names=['motorcycle', 'bicycle']), + dict(num_class=2, class_names=['pedestrian', 'traffic_cone']), + ] + voxel_size = [0.2, 0.2, 8] + dcn_center_head_cfg = dict( + type='CenterHead', + mode='3d', + in_channels=sum([128, 128, 128]), + tasks=[ + dict(num_class=1, class_names=['car']), + dict(num_class=2, class_names=['truck', 'construction_vehicle']), + dict(num_class=2, class_names=['bus', 'trailer']), + dict(num_class=1, class_names=['barrier']), + dict(num_class=2, class_names=['motorcycle', 'bicycle']), + dict(num_class=2, class_names=['pedestrian', 'traffic_cone']), + ], + common_heads={ + 'reg': (2, 2), + 'height': (1, 2), + 'dim': (3, 2), + 'rot': (2, 2), + 'vel': (2, 2) + }, + share_conv_channel=64, + bbox_coder=dict( + type='CenterPointBBoxCoder', + post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + max_num=500, + score_threshold=0.1, + pc_range=[-51.2, -51.2], + out_size_factor=4, + voxel_size=voxel_size[:2], + code_size=9), + seperate_head=dict( + type='DCNSeperateHead', + dcn_config=dict( + type='DCN', + in_channels=64, + out_channels=64, + kernel_size=3, + padding=1, + groups=4, + bias=True), + init_bias=-2.19, + final_kernel=3), + loss_cls=dict(type='GaussianFocalLoss', reduction='mean'), + loss_bbox=dict(type='L1Loss', reduction='none', loss_weight=0.25), + norm_bbox=True) + # model training and testing settings + train_cfg = dict( + grid_size=[512, 512, 1], + point_cloud_range=[-51.2, -51.2, -5., 51.2, 51.2, 3.], + voxel_size=voxel_size, + out_size_factor=4, + dense_reg=1, + gaussian_overlap=0.1, + max_objs=500, + min_radius=2, + code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 1.0, 1.0]) + + test_cfg = dict( + post_center_limit_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + max_per_img=500, + max_pool_nms=False, + min_radius=[4, 12, 10, 1, 0.85, 0.175], + post_max_size=83, + score_threshold=0.1, + pc_range=[-51.2, -51.2], + out_size_factor=4, + voxel_size=voxel_size[:2], + nms_type='circle') + dcn_center_head_cfg.update(train_cfg=train_cfg, test_cfg=test_cfg) + + dcn_center_head = build_head(dcn_center_head_cfg).cuda() + + x = torch.ones([2, 384, 128, 128]).cuda() + output = dcn_center_head([x]) + for i in range(6): + assert output[i][0]['reg'].shape == torch.Size([2, 2, 128, 128]) + assert output[i][0]['height'].shape == torch.Size([2, 1, 128, 128]) + assert output[i][0]['dim'].shape == torch.Size([2, 3, 128, 128]) + assert output[i][0]['rot'].shape == torch.Size([2, 2, 128, 128]) + assert output[i][0]['vel'].shape == torch.Size([2, 2, 128, 128]) + assert output[i][0]['heatmap'].shape == torch.Size( + [2, tasks[i]['num_class'], 128, 128]) + + # Test loss. + gt_bboxes_0 = LiDARInstance3DBoxes(torch.rand([10, 9]).cuda(), box_dim=9) + gt_bboxes_1 = LiDARInstance3DBoxes(torch.rand([20, 9]).cuda(), box_dim=9) + gt_labels_0 = torch.randint(1, 11, [10]).cuda() + gt_labels_1 = torch.randint(1, 11, [20]).cuda() + gt_bboxes_3d = [gt_bboxes_0, gt_bboxes_1] + gt_labels_3d = [gt_labels_0, gt_labels_1] + loss = dcn_center_head.loss(gt_bboxes_3d, gt_labels_3d, output) + loss_sum = torch.sum(torch.stack([item for _, item in loss.items()])) + assert torch.isclose(loss_sum, torch.tensor(21972.1230)) + + # test get_bboxes + img_metas = [ + dict(box_type_3d=LiDARInstance3DBoxes), + dict(box_type_3d=LiDARInstance3DBoxes) + ] + ret_lists = dcn_center_head.get_bboxes(output, img_metas) + for ret_list in ret_lists: + assert ret_list[0].tensor.shape[0] <= 500 + assert ret_list[1].shape[0] <= 500 + assert ret_list[2].shape[0] <= 500 + + def test_ssd3d_head(): if not torch.cuda.is_available(): pytest.skip('test requires GPU and torch+cuda') diff --git a/tests/test_nms.py b/tests/test_nms.py index 28e4b296b1..6802bcc2b2 100644 --- a/tests/test_nms.py +++ b/tests/test_nms.py @@ -1,3 +1,4 @@ +import numpy as np import torch @@ -55,3 +56,19 @@ def test_aligned_3d_nms(): ]) assert torch.all(pick == expected_pick) + + +def test_circle_nms(): + from mmdet3d.core.post_processing import circle_nms + boxes = torch.tensor([[-11.1100, 2.1300, 0.8823], + [-11.2810, 2.2422, 0.8914], + [-10.3966, -0.3198, 0.8643], + [-10.2906, -13.3159, + 0.8401], [5.6518, 9.9791, 0.8271], + [-11.2652, 13.3637, 0.8267], + [4.7768, -13.0409, 0.7810], [5.6621, 9.0422, 0.7753], + [-10.5561, 18.9627, 0.7518], + [-10.5643, 13.2293, 0.7200]]) + keep = circle_nms(boxes.numpy(), 0.175) + expected_keep = [1, 2, 3, 4, 5, 6, 7, 8, 9] + assert np.all(keep == expected_keep) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000..7493dc83e5 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,11 @@ +import torch + +from mmdet3d.core import draw_heatmap_gaussian + + +def test_gaussian(): + heatmap = torch.zeros((128, 128)) + ct_int = torch.tensor([64, 64], dtype=torch.int32) + radius = 2 + draw_heatmap_gaussian(heatmap, ct_int, radius) + assert torch.isclose(torch.sum(heatmap), torch.tensor(4.3505), atol=1e-3)