Skip to content

Commit

Permalink
Add centerpoint_rpn.
Browse files Browse the repository at this point in the history
  • Loading branch information
liyinhao committed Aug 2, 2020
1 parent bde41e7 commit 0ac6625
Showing 1 changed file with 91 additions and 0 deletions.
91 changes: 91 additions & 0 deletions mmdet3d/models/necks/centerpoint_fpn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import numpy as np
import torch
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
constant_init, is_norm, kaiming_init)
from torch import nn as nn

from mmdet.models import NECKS


@NECKS.register_module()
class CenterPointFPN(nn.Module):
"""FPN used in SECOND/PointPillars/PartA2/MVXNet.
Args:
in_channels (list[int]): Input channels of multi-scale feature maps.
out_channels (list[int]): Output channels of feature maps.
upsample_strides (list[int]): Strides used to upsample
the feature maps.
norm_cfg (dict): Config dict of normalization layers.
upsample_cfg (dict): Config dict of upsample layers.
conv_cfg (dict): Config dict of conv layers.
"""

def __init__(self,
in_channels=[128, 128, 256],
out_channels=[256, 256, 256],
upsample_strides=[1, 2, 4],
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
upsample_cfg=dict(type='deconv', bias=False),
conv_cfg=dict(type='Conv2d', bias=False)):
# if for GroupNorm,
# cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True)
super(CenterPointFPN, self).__init__()
assert len(out_channels) == len(upsample_strides) == len(in_channels)
self.in_channels = in_channels
self.out_channels = out_channels

deblocks = []
for i, out_channel in enumerate(out_channels):
stride = upsample_strides[i]
if stride > 1:
upsample_layer = build_upsample_layer(
upsample_cfg,
in_channels=in_channels[i],
out_channels=out_channel,
kernel_size=stride,
stride=upsample_strides[i])
deblock = nn.Sequential(
upsample_layer,
build_norm_layer(norm_cfg, out_channel)[1],
nn.ReLU(inplace=True))
else:
stride = np.round(1 / stride).astype(np.int64)
upsample_layer = build_conv_layer(
conv_cfg,
in_channels=in_channels[i],
out_channels=out_channel,
kernel_size=stride,
stride=stride)
deblock = nn.Sequential(
upsample_layer,
build_norm_layer(norm_cfg, out_channel)[1],
nn.ReLU(inplace=True))
deblocks.append(deblock)
self.deblocks = nn.ModuleList(deblocks)

def init_weights(self):
"""Initialize weights of FPN."""
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif is_norm(m):
constant_init(m, 1)

def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): 4D Tensor in (N, C, H, W) shape.
Returns:
list[torch.Tensor]: Multi-level feature maps.
"""
assert len(x) == len(self.in_channels)
ups = [deblock(x[i]) for i, deblock in enumerate(self.deblocks)]

if len(ups) > 1:
out = torch.cat(ups, dim=1)
else:
out = ups[0]
return out

0 comments on commit 0ac6625

Please sign in to comment.