Skip to content

Commit

Permalink
[Feature] Support PointRCNN backbone (#974)
Browse files Browse the repository at this point in the history
* [Refactor] Main code modification for coordinate system refactor (#677)

* [Enhance] Add script for data update (#774)

* Fixed wrong config paths and fixed a bug in test

* Fixed metafile

* Coord sys refactor (main code)

* Update test_waymo_dataset.py

* Manually resolve conflict

* Removed unused lines and fixed imports

* remove coord2box and box2coord

* update dir_limit_offset

* Some minor improvements

* Removed some \s in comments

* Revert a change

* Change Box3DMode to Coord3DMode where points are converted

* Fix points_in_bbox function

* Fix Imvoxelnet config

* Revert adding a line

* Fix rotation bug when batch size is 0

* Keep sign of dir_scores as before

* Fix several comments

* Add a comment

* Fix docstring

* Add data update scripts

* Fix comments

* fix import (#839)

* [Enhance]  refactor  iou_neg_piecewise_sampler.py (#842)

* [Refactor] Main code modification for coordinate system refactor (#677)

* [Enhance] Add script for data update (#774)

* Fixed wrong config paths and fixed a bug in test

* Fixed metafile

* Coord sys refactor (main code)

* Update test_waymo_dataset.py

* Manually resolve conflict

* Removed unused lines and fixed imports

* remove coord2box and box2coord

* update dir_limit_offset

* Some minor improvements

* Removed some \s in comments

* Revert a change

* Change Box3DMode to Coord3DMode where points are converted

* Fix points_in_bbox function

* Fix Imvoxelnet config

* Revert adding a line

* Fix rotation bug when batch size is 0

* Keep sign of dir_scores as before

* Fix several comments

* Add a comment

* Fix docstring

* Add data update scripts

* Fix comments

* fix import

* refactor iou_neg_piecewise_sampler.py

* add docstring

* modify docstring

Co-authored-by: Yezhen Cong <[email protected]>
Co-authored-by: THU17cyz <[email protected]>

* [Feature] Add roipooling cuda ops (#843)

* [Refactor] Main code modification for coordinate system refactor (#677)

* [Enhance] Add script for data update (#774)

* Fixed wrong config paths and fixed a bug in test

* Fixed metafile

* Coord sys refactor (main code)

* Update test_waymo_dataset.py

* Manually resolve conflict

* Removed unused lines and fixed imports

* remove coord2box and box2coord

* update dir_limit_offset

* Some minor improvements

* Removed some \s in comments

* Revert a change

* Change Box3DMode to Coord3DMode where points are converted

* Fix points_in_bbox function

* Fix Imvoxelnet config

* Revert adding a line

* Fix rotation bug when batch size is 0

* Keep sign of dir_scores as before

* Fix several comments

* Add a comment

* Fix docstring

* Add data update scripts

* Fix comments

* fix import

* add roipooling cuda ops

* add roi extractor

* add test_roi_extractor unittest

* Modify setup.py to install roipooling ops

* modify docstring

* remove enlarge bbox in roipoint pooling

* add_roipooling_ops

* modify docstring

Co-authored-by: Yezhen Cong <[email protected]>
Co-authored-by: THU17cyz <[email protected]>

* [Refactor] Refactor code structure and docstrings (#803)

* refactor points_in_boxes

* Merge same functions of three boxes

* More docstring fixes and unify x/y/z size

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Remove None in function param type

* Fix unittest

* Add comments for NMS functions

* Merge methods of Points

* Add unittest

* Add optional and default value

* Fix box conversion and add unittest

* Fix comments

* Add unit test

* Indent

* Fix CI

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Add unit test for box bev

* More unit tests and refine docstrings in box_np_ops

* Fix comment

* Add deprecation warning

* support pointrcnn backbone

* add docstring

* modify docstring

* modify docstring

* modify docstring

* Update pointnet2_fp_neck.py

* add code block

* refine docstring & code

* add unittest on fp_neck

* refine unittest

* refine unittest

* refine unittest

* refine unittest

* refine unittest

* fix docstring

Co-authored-by: Yezhen Cong <[email protected]>
Co-authored-by: Xi Liu <[email protected]>
Co-authored-by: THU17cyz <[email protected]>
Co-authored-by: xiliu8006 <[email protected]>
  • Loading branch information
5 people authored Oct 20, 2021
1 parent 7c62335 commit b8856a1
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 12 deletions.
6 changes: 5 additions & 1 deletion mmdet3d/models/backbones/pointnet2_sa_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ def __init__(self,
self.out_indices = out_indices
assert max(out_indices) < self.num_sa
assert len(num_points) == len(radii) == len(num_samples) == len(
sa_channels) == len(aggregation_channels)
sa_channels)
if aggregation_channels is not None:
assert len(sa_channels) == len(aggregation_channels)
else:
aggregation_channels = [None] * len(sa_channels)

self.SA_modules = nn.ModuleList()
self.aggregation_mlps = nn.ModuleList()
Expand Down
4 changes: 4 additions & 0 deletions mmdet3d/models/dense_heads/anchor3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def __init__(self,
self.assign_per_class = assign_per_class
self.dir_offset = dir_offset
self.dir_limit_offset = dir_limit_offset
import warnings
warnings.warn(
'dir_offset and dir_limit_offset will be depressed and be '
'incorporated into box coder in the future')
self.fp16_enabled = False

# build anchor generator
Expand Down
12 changes: 8 additions & 4 deletions mmdet3d/models/dense_heads/fcos_mono3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,20 @@ def add_sin_difference(boxes1, boxes2):
@staticmethod
def get_direction_target(reg_targets,
dir_offset=0,
dir_limit_offset=0,
dir_limit_offset=0.0,
num_bins=2,
one_hot=True):
"""Encode direction to 0 ~ num_bins-1.
Args:
reg_targets (torch.Tensor): Bbox regression targets.
dir_offset (int): Direction offset.
num_bins (int): Number of bins to divide 2*PI.
one_hot (bool): Whether to encode as one hot.
dir_offset (int, optional): Direction offset. Default to 0.
dir_limit_offset (float, optional): Offset to set the direction
range. Default to 0.0.
num_bins (int, optional): Number of bins to divide 2*PI.
Default to 2.
one_hot (bool, optional): Whether to encode as one hot.
Default to True.
Returns:
torch.Tensor: Encoded direction targets.
Expand Down
1 change: 0 additions & 1 deletion mmdet3d/models/detectors/mvx_two_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def __init__(self,
'key, please consider using init_cfg.')
self.img_roi_head.init_cfg = dict(
type='Pretrained', checkpoint=img_pretrained)

if self.with_pts_backbone:
if pts_pretrained is not None:
warnings.warn('DeprecationWarning: pretrained is a deprecated '
Expand Down
5 changes: 4 additions & 1 deletion mmdet3d/models/necks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from mmdet.models.necks.fpn import FPN
from .dla_neck import DLANeck
from .imvoxel_neck import OutdoorImVoxelNeck
from .pointnet2_fp_neck import PointNetFPNeck
from .second_fpn import SECONDFPN

__all__ = ['FPN', 'SECONDFPN', 'OutdoorImVoxelNeck', 'DLANeck']
__all__ = [
'FPN', 'SECONDFPN', 'OutdoorImVoxelNeck', 'PointNetFPNeck', 'DLANeck'
]
88 changes: 88 additions & 0 deletions mmdet3d/models/necks/pointnet2_fp_neck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from mmcv.runner import BaseModule
from torch import nn as nn

from mmdet3d.ops import PointFPModule
from mmdet.models import NECKS


@NECKS.register_module()
class PointNetFPNeck(BaseModule):
r"""PointNet FP Module used in PointRCNN.
Refer to the `official code <https://github.com/charlesq34/pointnet2>`_.
.. code-block:: none
sa_n ----------------------------------------
|
... --------------------------------- |
| |
sa_1 ------------- | |
| | |
sa_0 -> fp_0 -> fp_module ->fp_1 -> ... -> fp_module -> fp_n
sa_n including sa_xyz (torch.Tensor) and sa_features (torch.Tensor)
fp_n including fp_xyz (torch.Tensor) and fp_features (torch.Tensor)
Args:
fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""

def __init__(self, fp_channels, init_cfg=None):
super(PointNetFPNeck, self).__init__(init_cfg=init_cfg)

self.num_fp = len(fp_channels)
self.FP_modules = nn.ModuleList()
for cur_fp_mlps in fp_channels:
self.FP_modules.append(PointFPModule(mlp_channels=cur_fp_mlps))

def _extract_input(self, feat_dict):
"""Extract inputs from features dictionary.
Args:
feat_dict (dict): Feature dict from backbone, which may contain
the following keys and values:
- sa_xyz (list[torch.Tensor]): Points of each sa module
in shape (N, 3).
- sa_features (list[torch.Tensor]): Output features of
each sa module in shape (N, M).
Returns:
list[torch.Tensor]: Coordinates of multiple levels of points.
list[torch.Tensor]: Features of multiple levels of points.
"""
sa_xyz = feat_dict['sa_xyz']
sa_features = feat_dict['sa_features']
assert len(sa_xyz) == len(sa_features)

return sa_xyz, sa_features

def forward(self, feat_dict):
"""Forward pass.
Args:
feat_dict (dict): Feature dict from backbone.
Returns:
dict[str, torch.Tensor]: Outputs of the Neck.
- fp_xyz (torch.Tensor): The coordinates of fp features.
- fp_features (torch.Tensor): The features from the last
feature propogation layers.
"""
sa_xyz, sa_features = self._extract_input(feat_dict)

fp_feature = sa_features[-1]
fp_xyz = sa_xyz[-1]

for i in range(self.num_fp):
# consume the points in a bottom-up manner
fp_feature = self.FP_modules[i](sa_xyz[-(i + 2)], sa_xyz[-(i + 1)],
sa_features[-(i + 2)], fp_feature)
fp_xyz = sa_xyz[-(i + 2)]

ret = dict(fp_xyz=fp_xyz, fp_features=fp_feature)
return ret
19 changes: 14 additions & 5 deletions mmdet3d/ops/pointnet_modules/point_sa_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def __init__(self,
self.num_point = [num_point]
elif isinstance(num_point, list) or isinstance(num_point, tuple):
self.num_point = num_point
elif num_point is None:
self.num_point = None
else:
raise NotImplementedError('Error type of num_point!')

Expand All @@ -78,8 +80,12 @@ def __init__(self,
self.fps_mod_list = fps_mod
self.fps_sample_range_list = fps_sample_range_list

self.points_sampler = Points_Sampler(self.num_point, self.fps_mod_list,
self.fps_sample_range_list)
if self.num_point is not None:
self.points_sampler = Points_Sampler(self.num_point,
self.fps_mod_list,
self.fps_sample_range_list)
else:
self.points_sampler = None

for i in range(len(radii)):
radius = radii[i]
Expand Down Expand Up @@ -126,9 +132,12 @@ def _sample_points(self, points_xyz, features, indices, target_xyz):
elif target_xyz is not None:
new_xyz = target_xyz.contiguous()
else:
indices = self.points_sampler(points_xyz, features)
new_xyz = gather_points(xyz_flipped, indices).transpose(
1, 2).contiguous() if self.num_point is not None else None
if self.num_point is not None:
indices = self.points_sampler(points_xyz, features)
new_xyz = gather_points(xyz_flipped,
indices).transpose(1, 2).contiguous()
else:
new_xyz = None

return new_xyz, indices

Expand Down
14 changes: 14 additions & 0 deletions tests/test_models/test_common_modules/test_pointnet_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ def test_pointnet_sa_module_msg():
assert new_features.shape == torch.Size([1, 48, 20])
assert inds.shape == torch.Size([1, 20])

# test num_points = None
self = PointSAModuleMSG(
num_point=None,
radii=[0.2, 0.4],
sample_nums=[4, 8],
mlp_channels=[[12, 16], [12, 32]],
norm_cfg=dict(type='BN2d'),
use_xyz=False,
pool_mod='max').cuda()

# test forward
new_xyz, new_features, inds = self(xyz, features)
assert new_features.shape == torch.Size([1, 48, 1])

# length of 'fps_mod' should be same as 'fps_sample_range_list'
with pytest.raises(AssertionError):
PointSAModuleMSG(
Expand Down
33 changes: 33 additions & 0 deletions tests/test_models/test_necks/test_necks.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,39 @@ def test_imvoxel_neck():
assert outputs[0].shape == (1, 256, 248, 216)


def test_fp_neck():
if not torch.cuda.is_available():
pytest.skip()

xyzs = [16384, 4096, 1024, 256, 64]
feat_channels = [1, 96, 256, 512, 1024]
channel_num = 5

sa_xyz = [torch.rand(3, xyzs[i], 3) for i in range(channel_num)]
sa_features = [
torch.rand(3, feat_channels[i], xyzs[i]) for i in range(channel_num)
]

neck_cfg = dict(
type='PointNetFPNeck',
fp_channels=((1536, 512, 512), (768, 512, 512), (608, 256, 256),
(257, 128, 128)))

neck = build_neck(neck_cfg)
neck.init_weights()

if torch.cuda.is_available():
sa_xyz = [x.cuda() for x in sa_xyz]
sa_features = [x.cuda() for x in sa_features]
neck.cuda()

feats_sa = {'sa_xyz': sa_xyz, 'sa_features': sa_features}
outputs = neck(feats_sa)
assert outputs['fp_xyz'].cpu().numpy().shape == (3, 16384, 3)
assert outputs['fp_features'].detach().cpu().numpy().shape == (3, 128,
16384)


def test_dla_neck():

s = 32
Expand Down

0 comments on commit b8856a1

Please sign in to comment.