Skip to content

Commit

Permalink
Reconstruct centerpoint_rpn.
Browse files Browse the repository at this point in the history
  • Loading branch information
liyinhao committed Aug 2, 2020
1 parent 878b24f commit bde41e7
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 179 deletions.
4 changes: 2 additions & 2 deletions mmdet3d/models/necks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from mmdet.models.necks.fpn import FPN
from .centerpoint_rpn import CenterPointRPN
from .centerpoint_fpn import CenterPointFPN
from .second_fpn import SECONDFPN

__all__ = ['FPN', 'SECONDFPN', 'CenterPointRPN']
__all__ = ['FPN', 'SECONDFPN', 'CenterPointFPN']
166 changes: 0 additions & 166 deletions mmdet3d/models/necks/centerpoint_rpn.py

This file was deleted.

33 changes: 22 additions & 11 deletions tests/test_necks.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
import torch

from mmdet3d.models.builder import build_neck
from mmdet3d.models.builder import build_backbone, build_neck


def test_centerpoint_rpn():
centerpoint_rpn_cfg = dict(
type='CenterPointRPN',
second_cfg = dict(
type='SECOND',
in_channels=64,
out_channels=[64, 128, 256],
layer_nums=[3, 5, 5],
downsample_strides=[2, 2, 2],
downsample_channels=[64, 128, 256],
layer_strides=[2, 2, 2],
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
conv_cfg=dict(type='Conv2d', bias=False))

second = build_backbone(second_cfg)

centerpoint_fpn_cfg = dict(
type='CenterPointFPN',
in_channels=[64, 128, 256],
out_channels=[128, 128, 128],
upsample_strides=[0.5, 1, 2],
upsample_channels=[128, 128, 128],
input_channels=64,
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01))
centerpoint_rpn = build_neck(centerpoint_rpn_cfg)
centerpoint_rpn.init_weights()
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
upsample_cfg=dict(type='deconv', bias=False))

centerpoint_fpn = build_neck(centerpoint_fpn_cfg)

input = torch.rand([4, 64, 512, 512])
output = centerpoint_rpn(input)
sec_output = second(input)
output = centerpoint_fpn(sec_output)
assert output.shape == torch.Size([4, 384, 128, 128])

0 comments on commit bde41e7

Please sign in to comment.