-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
liyinhao
committed
Aug 2, 2020
1 parent
bde41e7
commit 0ac6625
Showing
1 changed file
with
91 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import numpy as np | ||
import torch | ||
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer, | ||
constant_init, is_norm, kaiming_init) | ||
from torch import nn as nn | ||
|
||
from mmdet.models import NECKS | ||
|
||
|
||
@NECKS.register_module() | ||
class CenterPointFPN(nn.Module): | ||
"""FPN used in SECOND/PointPillars/PartA2/MVXNet. | ||
Args: | ||
in_channels (list[int]): Input channels of multi-scale feature maps. | ||
out_channels (list[int]): Output channels of feature maps. | ||
upsample_strides (list[int]): Strides used to upsample | ||
the feature maps. | ||
norm_cfg (dict): Config dict of normalization layers. | ||
upsample_cfg (dict): Config dict of upsample layers. | ||
conv_cfg (dict): Config dict of conv layers. | ||
""" | ||
|
||
def __init__(self, | ||
in_channels=[128, 128, 256], | ||
out_channels=[256, 256, 256], | ||
upsample_strides=[1, 2, 4], | ||
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), | ||
upsample_cfg=dict(type='deconv', bias=False), | ||
conv_cfg=dict(type='Conv2d', bias=False)): | ||
# if for GroupNorm, | ||
# cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True) | ||
super(CenterPointFPN, self).__init__() | ||
assert len(out_channels) == len(upsample_strides) == len(in_channels) | ||
self.in_channels = in_channels | ||
self.out_channels = out_channels | ||
|
||
deblocks = [] | ||
for i, out_channel in enumerate(out_channels): | ||
stride = upsample_strides[i] | ||
if stride > 1: | ||
upsample_layer = build_upsample_layer( | ||
upsample_cfg, | ||
in_channels=in_channels[i], | ||
out_channels=out_channel, | ||
kernel_size=stride, | ||
stride=upsample_strides[i]) | ||
deblock = nn.Sequential( | ||
upsample_layer, | ||
build_norm_layer(norm_cfg, out_channel)[1], | ||
nn.ReLU(inplace=True)) | ||
else: | ||
stride = np.round(1 / stride).astype(np.int64) | ||
upsample_layer = build_conv_layer( | ||
conv_cfg, | ||
in_channels=in_channels[i], | ||
out_channels=out_channel, | ||
kernel_size=stride, | ||
stride=stride) | ||
deblock = nn.Sequential( | ||
upsample_layer, | ||
build_norm_layer(norm_cfg, out_channel)[1], | ||
nn.ReLU(inplace=True)) | ||
deblocks.append(deblock) | ||
self.deblocks = nn.ModuleList(deblocks) | ||
|
||
def init_weights(self): | ||
"""Initialize weights of FPN.""" | ||
for m in self.modules(): | ||
if isinstance(m, nn.Conv2d): | ||
kaiming_init(m) | ||
elif is_norm(m): | ||
constant_init(m, 1) | ||
|
||
def forward(self, x): | ||
"""Forward function. | ||
Args: | ||
x (torch.Tensor): 4D Tensor in (N, C, H, W) shape. | ||
Returns: | ||
list[torch.Tensor]: Multi-level feature maps. | ||
""" | ||
assert len(x) == len(self.in_channels) | ||
ups = [deblock(x[i]) for i, deblock in enumerate(self.deblocks)] | ||
|
||
if len(ups) > 1: | ||
out = torch.cat(ups, dim=1) | ||
else: | ||
out = ups[0] | ||
return out |