From 0ac6625d2d05fafe6b0b637540e97d90e54e3596 Mon Sep 17 00:00:00 2001 From: liyinhao Date: Sun, 2 Aug 2020 11:27:01 +0800 Subject: [PATCH] Add centerpoint_rpn. --- mmdet3d/models/necks/centerpoint_fpn.py | 91 +++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 mmdet3d/models/necks/centerpoint_fpn.py diff --git a/mmdet3d/models/necks/centerpoint_fpn.py b/mmdet3d/models/necks/centerpoint_fpn.py new file mode 100644 index 0000000000..88c44c23d9 --- /dev/null +++ b/mmdet3d/models/necks/centerpoint_fpn.py @@ -0,0 +1,91 @@ +import numpy as np +import torch +from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer, + constant_init, is_norm, kaiming_init) +from torch import nn as nn + +from mmdet.models import NECKS + + +@NECKS.register_module() +class CenterPointFPN(nn.Module): + """FPN used in SECOND/PointPillars/PartA2/MVXNet. + + Args: + in_channels (list[int]): Input channels of multi-scale feature maps. + out_channels (list[int]): Output channels of feature maps. + upsample_strides (list[int]): Strides used to upsample + the feature maps. + norm_cfg (dict): Config dict of normalization layers. + upsample_cfg (dict): Config dict of upsample layers. + conv_cfg (dict): Config dict of conv layers. + """ + + def __init__(self, + in_channels=[128, 128, 256], + out_channels=[256, 256, 256], + upsample_strides=[1, 2, 4], + norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), + upsample_cfg=dict(type='deconv', bias=False), + conv_cfg=dict(type='Conv2d', bias=False)): + # if for GroupNorm, + # cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True) + super(CenterPointFPN, self).__init__() + assert len(out_channels) == len(upsample_strides) == len(in_channels) + self.in_channels = in_channels + self.out_channels = out_channels + + deblocks = [] + for i, out_channel in enumerate(out_channels): + stride = upsample_strides[i] + if stride > 1: + upsample_layer = build_upsample_layer( + upsample_cfg, + in_channels=in_channels[i], + out_channels=out_channel, + kernel_size=stride, + stride=upsample_strides[i]) + deblock = nn.Sequential( + upsample_layer, + build_norm_layer(norm_cfg, out_channel)[1], + nn.ReLU(inplace=True)) + else: + stride = np.round(1 / stride).astype(np.int64) + upsample_layer = build_conv_layer( + conv_cfg, + in_channels=in_channels[i], + out_channels=out_channel, + kernel_size=stride, + stride=stride) + deblock = nn.Sequential( + upsample_layer, + build_norm_layer(norm_cfg, out_channel)[1], + nn.ReLU(inplace=True)) + deblocks.append(deblock) + self.deblocks = nn.ModuleList(deblocks) + + def init_weights(self): + """Initialize weights of FPN.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif is_norm(m): + constant_init(m, 1) + + def forward(self, x): + """Forward function. + + Args: + x (torch.Tensor): 4D Tensor in (N, C, H, W) shape. + + Returns: + list[torch.Tensor]: Multi-level feature maps. + """ + assert len(x) == len(self.in_channels) + ups = [deblock(x[i]) for i, deblock in enumerate(self.deblocks)] + + if len(ups) > 1: + out = torch.cat(ups, dim=1) + else: + out = ups[0] + return out