diff --git a/configs/_base_/models/centerpoint_01voxel_second_secfpn_nus.py b/configs/_base_/models/centerpoint_01voxel_second_secfpn_nus.py new file mode 100644 index 0000000000..b797c5158f --- /dev/null +++ b/configs/_base_/models/centerpoint_01voxel_second_secfpn_nus.py @@ -0,0 +1,84 @@ +voxel_size = [0.1, 0.1, 0.2] +model = dict( + type='CenterPoint', + pts_voxel_layer=dict( + max_num_points=10, voxel_size=voxel_size, max_voxels=(90000, 120000)), + pts_voxel_encoder=dict(type='HardSimpleVFE', num_features=5), + pts_middle_encoder=dict( + type='SparseEncoder', + in_channels=5, + sparse_shape=[41, 1024, 1024], + output_channels=128, + order=('conv', 'norm', 'act'), + encoder_channels=((16, 16, 32), (32, 32, 64), (64, 64, 128), (128, + 128)), + encoder_paddings=((0, 0, 1), (0, 0, 1), (0, 0, [0, 1, 1]), (0, 0)), + block_type='basicblock'), + pts_backbone=dict( + type='SECOND', + in_channels=256, + out_channels=[128, 256], + layer_nums=[5, 5], + layer_strides=[1, 2], + norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), + conv_cfg=dict(type='Conv2d', bias=False)), + pts_neck=dict( + type='SECONDFPN', + in_channels=[128, 256], + out_channels=[256, 256], + upsample_strides=[1, 2], + norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), + upsample_cfg=dict(type='deconv', bias=False), + use_conv_for_no_stride=True), + pts_bbox_head=dict( + type='CenterHead', + in_channels=sum([256, 256]), + tasks=[ + dict(num_class=1, class_names=['car']), + dict(num_class=2, class_names=['truck', 'construction_vehicle']), + dict(num_class=2, class_names=['bus', 'trailer']), + dict(num_class=1, class_names=['barrier']), + dict(num_class=2, class_names=['motorcycle', 'bicycle']), + dict(num_class=2, class_names=['pedestrian', 'traffic_cone']), + ], + common_heads=dict( + reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)), + share_conv_channel=64, + bbox_coder=dict( + type='CenterPointBBoxCoder', + post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + max_num=500, + score_threshold=0.1, + out_size_factor=8, + voxel_size=voxel_size[:2], + code_size=9), + seperate_head=dict( + type='SeparateHead', init_bias=-2.19, final_kernel=3), + loss_cls=dict(type='GaussianFocalLoss', reduction='mean'), + loss_bbox=dict(type='L1Loss', reduction='mean', loss_weight=0.25), + norm_bbox=True)) +# model training and testing settings +train_cfg = dict( + pts=dict( + grid_size=[1024, 1024, 40], + voxel_size=voxel_size, + out_size_factor=8, + dense_reg=1, + gaussian_overlap=0.1, + max_objs=500, + min_radius=2, + code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2])) +test_cfg = dict( + pts=dict( + post_center_limit_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + max_per_img=500, + max_pool_nms=False, + min_radius=[4, 12, 10, 1, 0.85, 0.175], + post_max_size=83, + score_threshold=0.1, + out_size_factor=8, + voxel_size=voxel_size[:2], + nms_type='rotate', + nms_pre_max_size=1000, + nms_post_max_size=83, + nms_iou_threshold=0.2)) diff --git a/configs/_base_/models/centerpoint_02pillar_second_secfpn_nus.py b/configs/_base_/models/centerpoint_02pillar_second_secfpn_nus.py new file mode 100644 index 0000000000..451b209365 --- /dev/null +++ b/configs/_base_/models/centerpoint_02pillar_second_secfpn_nus.py @@ -0,0 +1,84 @@ +voxel_size = [0.2, 0.2, 8] +model = dict( + type='CenterPoint', + pts_voxel_layer=dict( + max_num_points=20, voxel_size=voxel_size, max_voxels=(30000, 40000)), + pts_voxel_encoder=dict( + type='PillarFeatureNet', + in_channels=5, + feat_channels=[64], + with_distance=False, + voxel_size=(0.2, 0.2, 
8), + norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), + legacy=False), + pts_middle_encoder=dict( + type='PointPillarsScatter', in_channels=64, output_shape=(512, 512)), + pts_backbone=dict( + type='SECOND', + in_channels=64, + out_channels=[64, 128, 256], + layer_nums=[3, 5, 5], + layer_strides=[2, 2, 2], + norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), + conv_cfg=dict(type='Conv2d', bias=False)), + pts_neck=dict( + type='SECONDFPN', + in_channels=[64, 128, 256], + out_channels=[128, 128, 128], + upsample_strides=[0.5, 1, 2], + norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), + upsample_cfg=dict(type='deconv', bias=False), + use_conv_for_no_stride=True), + pts_bbox_head=dict( + type='CenterHead', + in_channels=sum([128, 128, 128]), + tasks=[ + dict(num_class=1, class_names=['car']), + dict(num_class=2, class_names=['truck', 'construction_vehicle']), + dict(num_class=2, class_names=['bus', 'trailer']), + dict(num_class=1, class_names=['barrier']), + dict(num_class=2, class_names=['motorcycle', 'bicycle']), + dict(num_class=2, class_names=['pedestrian', 'traffic_cone']), + ], + common_heads=dict( + reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)), + share_conv_channel=64, + bbox_coder=dict( + type='CenterPointBBoxCoder', + post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + max_num=500, + score_threshold=0.1, + out_size_factor=4, + voxel_size=voxel_size[:2], + code_size=9), + seperate_head=dict( + type='SeparateHead', init_bias=-2.19, final_kernel=3), + loss_cls=dict(type='GaussianFocalLoss', reduction='mean'), + loss_bbox=dict(type='L1Loss', reduction='mean', loss_weight=0.25), + norm_bbox=True)) +# model training and testing settings +train_cfg = dict( + pts=dict( + grid_size=[512, 512, 1], + voxel_size=voxel_size, + out_size_factor=4, + dense_reg=1, + gaussian_overlap=0.1, + max_objs=500, + min_radius=2, + code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2])) +test_cfg = dict( + pts=dict( + post_center_limit_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + max_per_img=500, + max_pool_nms=False, + min_radius=[4, 12, 10, 1, 0.85, 0.175], + post_max_size=83, + score_threshold=0.1, + pc_range=[-51.2, -51.2], + out_size_factor=4, + voxel_size=voxel_size[:2], + nms_type='rotate', + nms_pre_max_size=1000, + nms_post_max_size=83, + nms_iou_threshold=0.2)) diff --git a/configs/_base_/schedules/cyclic_20e.py b/configs/_base_/schedules/cyclic_20e.py new file mode 100644 index 0000000000..c7df532525 --- /dev/null +++ b/configs/_base_/schedules/cyclic_20e.py @@ -0,0 +1,24 @@ +# For nuScenes dataset, we usually evaluate the model at the end of training. +# Since the models are trained for 20 epochs by default, we set the evaluation +# interval to 20. Please change the interval accordingly if you do not +# use a default schedule.
+# optimizer +# This schedule is mainly used by models on nuScenes dataset +optimizer = dict(type='AdamW', lr=1e-4, weight_decay=0.01) +# max_norm=10 is better for SECOND +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +lr_config = dict( + policy='cyclic', + target_ratio=(10, 1e-4), + cyclic_times=1, + step_ratio_up=0.4, +) +momentum_config = dict( + policy='cyclic', + target_ratio=(0.85 / 0.95, 1), + cyclic_times=1, + step_ratio_up=0.4, +) + +# runtime settings +total_epochs = 20 diff --git a/configs/centerpoint/README.md b/configs/centerpoint/README.md new file mode 100644 index 0000000000..56a1e06951 --- /dev/null +++ b/configs/centerpoint/README.md @@ -0,0 +1,53 @@ +# Center-based 3D Object Detection and Tracking + +## Introduction + +We implement CenterPoint and provide the results and checkpoints on the nuScenes dataset. + +We follow the style below to name config files, and contributors are advised to follow the same style. +`{xxx}` is a required field and `[yyy]` is optional. + +`{model}`: model type like `centerpoint`. + +`{model setting}`: voxel size and voxel type like `01voxel`, `02pillar`. + +`{backbone}`: backbone type like `second`. + +`{neck}`: neck type like `secfpn`. + +`[dcn]`: whether to use deformable convolution. + +`[circle]`: whether to use circular NMS. + +`[batch_per_gpu x gpu]`: samples per GPU and number of GPUs; 4x8 is used by default. + +`{schedule}`: training schedule; options are 1x, 2x, 20e, etc. 1x and 2x mean 12 and 24 epochs, respectively. 20e is adopted in cascade models and denotes 20 epochs. For 1x/2x, the initial learning rate decays by a factor of 10 at the 8th/16th and 11th/22nd epochs. For 20e, the initial learning rate decays by a factor of 10 at the 16th and 19th epochs. + +`{dataset}`: dataset like nus-3d, kitti-3d, lyft-3d, scannet-3d, sunrgbd-3d. We also indicate the number of classes we are using if there exist multiple settings, e.g., kitti-3d-3class and kitti-3d-car mean training on the KITTI dataset with 3 classes and a single class, respectively.
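+
+For example, the config name `centerpoint_01voxel_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py` reads as `centerpoint` (`{model}`), `01voxel` (`{model setting}`), `second` (`{backbone}`), `secfpn` (`{neck}`), `dcn` (`[dcn]`), `circlenms` (`[circle]`), `4x8` (`[batch_per_gpu x gpu]`), `cyclic_20e` (`{schedule}`) and `nus` (`{dataset}`).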
+``` +@article{yin2020center, + title={Center-based 3d object detection and tracking}, + author={Yin, Tianwei and Zhou, Xingyi and Kr{\"a}henb{\"u}hl, Philipp}, + journal={arXiv preprint arXiv:2006.11275}, + year={2020} +} +``` + +## Results + +### CenterPoint + +|Backbone| Voxel type (voxel size) |Dcn|Circular nms| Mem (GB) | Inf time (fps) | mAP |NDS| Download | +| :---------: |:-----: |:-----: | :------: | :------------: | :----: |:----: | :------: |:------: | +|[SECFPN](./centerpoint_01voxel_second_secfpn_4x8_cyclic_20e_nus.py)|voxel (0.1)|✗|✗|||||| +|[SECFPN](./centerpoint_01voxel_second_secfpn_circlenms_4x8_cyclic_20e_nus.py)|voxel (0.1)|✗|✓|||||| +|[SECFPN](./centerpoint_01voxel_second_secfpn_dcn_4x8_cyclic_20e_nus.py)|voxel (0.1)|✓|✗|||||| +|[SECFPN](./centerpoint_01voxel_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py)|voxel (0.1)|✓|✓|||||| +|[SECFPN](./centerpoint_0075voxel_second_secfpn_4x8_cyclic_20e_nus.py)|voxel (0.075)|✗|✗|||||| +|[SECFPN](./centerpoint_0075voxel_second_secfpn_circlenms_4x8_cyclic_20e_nus.py)|voxel (0.075)|✗|✓|||||| +|[SECFPN](./centerpoint_0075voxel_second_secfpn_dcn_4x8_cyclic_20e_nus.py)|voxel (0.075)|✓|✗|||||| +|[SECFPN](./centerpoint_0075voxel_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py)|voxel (0.075)|✓|✓|||||| +|[SECFPN](./centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus.py)|pillar (0.2)|✗|✗|||||| +|[SECFPN](./centerpoint_02pillar_second_secfpn_circlenms_4x8_cyclic_20e_nus.py)|pillar (0.2)|✗|✓|||||| +|[SECFPN](./centerpoint_02pillar_second_secfpn_dcn_4x8_cyclic_20e_nus.py)|pillar (0.2)|✓|✗|||||| +|[SECFPN](./centerpoint_02pillar_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py)|pillar (0.2)|✓|✓|||||| diff --git a/configs/centerpoint/centerpoint_0075voxel_second_secfpn_4x8_cyclic_20e_nus.py b/configs/centerpoint/centerpoint_0075voxel_second_secfpn_4x8_cyclic_20e_nus.py new file mode 100644 index 0000000000..2dc58e0560 --- /dev/null +++ b/configs/centerpoint/centerpoint_0075voxel_second_secfpn_4x8_cyclic_20e_nus.py @@ -0,0 +1,139 @@ +_base_ = ['./centerpoint_01voxel_second_secfpn_4x8_cyclic_20e_nus.py'] + +# If point cloud range is changed, the models should also change their point +# cloud range accordingly +voxel_size = [0.075, 0.075, 0.2] +point_cloud_range = [-54, -54, -5.0, 54, 54, 3.0] +# For nuScenes we usually do 10-class detection +class_names = [ + 'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', + 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone' +] + +model = dict( + pts_voxel_layer=dict( + voxel_size=voxel_size, point_cloud_range=point_cloud_range), + pts_middle_encoder=dict(sparse_shape=[41, 1440, 1440]), + pts_bbox_head=dict( + bbox_coder=dict( + voxel_size=voxel_size[:2], pc_range=point_cloud_range[:2]))) + +train_cfg = dict( + pts=dict( + grid_size=[1440, 1440, 40], + voxel_size=voxel_size, + point_cloud_range=point_cloud_range)) + +test_cfg = dict( + pts=dict(voxel_size=voxel_size[:2], pc_range=point_cloud_range[:2])) + +dataset_type = 'NuScenesDataset' +data_root = 'data/nuscenes/' +file_client_args = dict(backend='disk') + +db_sampler = dict( + data_root=data_root, + info_path=data_root + 'nuscenes_dbinfos_train.pkl', + rate=1.0, + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict( + car=5, + truck=5, + bus=5, + trailer=5, + construction_vehicle=5, + traffic_cone=5, + barrier=5, + motorcycle=5, + bicycle=5, + pedestrian=5)), + classes=class_names, + sample_groups=dict( + car=2, + truck=3, + construction_vehicle=7, + bus=4, + trailer=6, + barrier=2, + motorcycle=6, + bicycle=6, 
+ pedestrian=2, + traffic_cone=2), + points_loader=dict( + type='LoadPointsFromFile', + load_dim=5, + use_dim=[0, 1, 2, 3, 4], + file_client_args=file_client_args)) + +train_pipeline = [ + dict( + type='LoadPointsFromFile', + load_dim=5, + use_dim=5, + file_client_args=file_client_args), + dict( + type='LoadPointsFromMultiSweeps', + sweeps_num=9, + use_dim=[0, 1, 2, 3, 4], + file_client_args=file_client_args, + pad_empty_sweeps=True, + remove_close=True), + dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), + dict(type='ObjectSample', db_sampler=db_sampler), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.3925, 0.3925], + scale_ratio_range=[0.95, 1.05], + translation_std=[0, 0, 0]), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectNameFilter', classes=class_names), + dict(type='PointShuffle'), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + load_dim=5, + use_dim=5, + file_client_args=file_client_args), + dict( + type='LoadPointsFromMultiSweeps', + sweeps_num=9, + use_dim=[0, 1, 2, 3, 4], + file_client_args=file_client_args, + pad_empty_sweeps=True, + remove_close=True), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D'), + dict( + type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['points']) + ]) +] + +data = dict( + train=dict(dataset=dict(pipeline=train_pipeline)), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/configs/centerpoint/centerpoint_0075voxel_second_secfpn_circlenms_4x8_cyclic_20e_nus.py b/configs/centerpoint/centerpoint_0075voxel_second_secfpn_circlenms_4x8_cyclic_20e_nus.py new file mode 100644 index 0000000000..140d6bc857 --- /dev/null +++ b/configs/centerpoint/centerpoint_0075voxel_second_secfpn_circlenms_4x8_cyclic_20e_nus.py @@ -0,0 +1,3 @@ +_base_ = ['./centerpoint_0075voxel_second_secfpn_4x8_cyclic_20e_nus.py'] + +test_cfg = dict(pts=dict(nms_type='circle')) diff --git a/configs/centerpoint/centerpoint_0075voxel_second_secfpn_dcn_4x8_cyclic_20e_nus.py b/configs/centerpoint/centerpoint_0075voxel_second_secfpn_dcn_4x8_cyclic_20e_nus.py new file mode 100644 index 0000000000..84b32c6908 --- /dev/null +++ b/configs/centerpoint/centerpoint_0075voxel_second_secfpn_dcn_4x8_cyclic_20e_nus.py @@ -0,0 +1,16 @@ +_base_ = ['./centerpoint_0075voxel_second_secfpn_4x8_cyclic_20e_nus.py'] + +model = dict( + pts_bbox_head=dict( + seperate_head=dict( + type='DCNSeperateHead', + dcn_config=dict( + type='DCN', + in_channels=64, + out_channels=64, + kernel_size=3, + padding=1, + groups=4, + bias=True), + init_bias=-2.19, + final_kernel=3))) diff --git a/configs/centerpoint/centerpoint_0075voxel_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py b/configs/centerpoint/centerpoint_0075voxel_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py new file mode 100644 index 0000000000..8937773d6d --- /dev/null +++ 
b/configs/centerpoint/centerpoint_0075voxel_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py @@ -0,0 +1,18 @@ +_base_ = ['./centerpoint_0075voxel_second_secfpn_4x8_cyclic_20e_nus.py'] + +model = dict( + pts_bbox_head=dict( + seperate_head=dict( + type='DCNSeperateHead', + dcn_config=dict( + type='DCN', + in_channels=64, + out_channels=64, + kernel_size=3, + padding=1, + groups=4, + bias=True), + init_bias=-2.19, + final_kernel=3))) + +test_cfg = dict(pts=dict(nms_type='circle')) diff --git a/configs/centerpoint/centerpoint_01voxel_second_secfpn_4x8_cyclic_20e_nus.py b/configs/centerpoint/centerpoint_01voxel_second_secfpn_4x8_cyclic_20e_nus.py new file mode 100644 index 0000000000..8a714f2630 --- /dev/null +++ b/configs/centerpoint/centerpoint_01voxel_second_secfpn_4x8_cyclic_20e_nus.py @@ -0,0 +1,144 @@ +_base_ = [ + '../_base_/datasets/nus-3d.py', + '../_base_/models/centerpoint_01voxel_second_secfpn_nus.py', + '../_base_/schedules/cyclic_20e.py', '../_base_/default_runtime.py' +] + +# If point cloud range is changed, the models should also change their point +# cloud range accordingly +point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0] +# For nuScenes we usually do 10-class detection +class_names = [ + 'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', + 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone' +] + +model = dict( + pts_voxel_layer=dict(point_cloud_range=point_cloud_range), + pts_bbox_head=dict(bbox_coder=dict(pc_range=point_cloud_range[:2]))) +# model training and testing settings +train_cfg = dict(pts=dict(point_cloud_range=point_cloud_range)) +test_cfg = dict(pts=dict(pc_range=point_cloud_range[:2])) + +dataset_type = 'NuScenesDataset' +data_root = 'data/nuscenes/' +file_client_args = dict(backend='disk') + +db_sampler = dict( + data_root=data_root, + info_path=data_root + 'nuscenes_dbinfos_train.pkl', + rate=1.0, + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict( + car=5, + truck=5, + bus=5, + trailer=5, + construction_vehicle=5, + traffic_cone=5, + barrier=5, + motorcycle=5, + bicycle=5, + pedestrian=5)), + classes=class_names, + sample_groups=dict( + car=2, + truck=3, + construction_vehicle=7, + bus=4, + trailer=6, + barrier=2, + motorcycle=6, + bicycle=6, + pedestrian=2, + traffic_cone=2), + points_loader=dict( + type='LoadPointsFromFile', + load_dim=5, + use_dim=[0, 1, 2, 3, 4], + file_client_args=file_client_args)) + +train_pipeline = [ + dict( + type='LoadPointsFromFile', + load_dim=5, + use_dim=5, + file_client_args=file_client_args), + dict( + type='LoadPointsFromMultiSweeps', + sweeps_num=9, + use_dim=[0, 1, 2, 3, 4], + file_client_args=file_client_args, + pad_empty_sweeps=True, + remove_close=True), + dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), + dict(type='ObjectSample', db_sampler=db_sampler), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.3925, 0.3925], + scale_ratio_range=[0.95, 1.05], + translation_std=[0, 0, 0]), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectNameFilter', classes=class_names), + dict(type='PointShuffle'), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + load_dim=5, + use_dim=5, + 
file_client_args=file_client_args), + dict( + type='LoadPointsFromMultiSweeps', + sweeps_num=9, + use_dim=[0, 1, 2, 3, 4], + file_client_args=file_client_args, + pad_empty_sweeps=True, + remove_close=True), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D'), + dict( + type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['points']) + ]) +] + +data = dict( + train=dict( + type='CBGSDataset', + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'nuscenes_infos_train.pkl', + pipeline=train_pipeline, + classes=class_names, + test_mode=False, + use_valid_flag=True, + # we use box_type_3d='LiDAR' in kitti and nuscenes dataset + # and box_type_3d='Depth' in sunrgbd and scannet dataset. + box_type_3d='LiDAR')), + val=dict(pipeline=test_pipeline, classes=class_names), + test=dict(pipeline=test_pipeline, classes=class_names)) diff --git a/configs/centerpoint/centerpoint_01voxel_second_secfpn_circlenms_4x8_cyclic_20e_nus.py b/configs/centerpoint/centerpoint_01voxel_second_secfpn_circlenms_4x8_cyclic_20e_nus.py new file mode 100644 index 0000000000..240f948837 --- /dev/null +++ b/configs/centerpoint/centerpoint_01voxel_second_secfpn_circlenms_4x8_cyclic_20e_nus.py @@ -0,0 +1,3 @@ +_base_ = ['./centerpoint_01voxel_second_secfpn_4x8_cyclic_20e_nus.py'] + +test_cfg = dict(pts=dict(nms_type='circle')) diff --git a/configs/centerpoint/centerpoint_01voxel_second_secfpn_dcn_4x8_cyclic_20e_nus.py b/configs/centerpoint/centerpoint_01voxel_second_secfpn_dcn_4x8_cyclic_20e_nus.py new file mode 100644 index 0000000000..e08aad24b6 --- /dev/null +++ b/configs/centerpoint/centerpoint_01voxel_second_secfpn_dcn_4x8_cyclic_20e_nus.py @@ -0,0 +1,16 @@ +_base_ = ['./centerpoint_01voxel_second_secfpn_4x8_cyclic_20e_nus.py'] + +model = dict( + pts_bbox_head=dict( + seperate_head=dict( + type='DCNSeperateHead', + dcn_config=dict( + type='DCN', + in_channels=64, + out_channels=64, + kernel_size=3, + padding=1, + groups=4, + bias=True), + init_bias=-2.19, + final_kernel=3))) diff --git a/configs/centerpoint/centerpoint_01voxel_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py b/configs/centerpoint/centerpoint_01voxel_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py new file mode 100644 index 0000000000..295f545142 --- /dev/null +++ b/configs/centerpoint/centerpoint_01voxel_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py @@ -0,0 +1,18 @@ +_base_ = ['./centerpoint_01voxel_second_secfpn_4x8_cyclic_20e_nus.py'] + +model = dict( + pts_bbox_head=dict( + seperate_head=dict( + type='DCNSeperateHead', + dcn_config=dict( + type='DCN', + in_channels=64, + out_channels=64, + kernel_size=3, + padding=1, + groups=4, + bias=True), + init_bias=-2.19, + final_kernel=3))) + +test_cfg = dict(pts=dict(nms_type='circle')) diff --git a/configs/centerpoint/centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus.py b/configs/centerpoint/centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus.py new file mode 100644 index 0000000000..e71210906a --- /dev/null +++ b/configs/centerpoint/centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus.py @@ -0,0 +1,143 @@ +_base_ = [ + '../_base_/datasets/nus-3d.py', + '../_base_/models/centerpoint_02pillar_second_secfpn_nus.py', + 
'../_base_/schedules/cyclic_20e.py', '../_base_/default_runtime.py' +] + +# If point cloud range is changed, the models should also change their point +# cloud range accordingly +point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0] +# For nuScenes we usually do 10-class detection +class_names = [ + 'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', + 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone' +] + +model = dict( + pts_voxel_layer=dict(point_cloud_range=point_cloud_range), + pts_voxel_encoder=dict(point_cloud_range=point_cloud_range), + pts_bbox_head=dict(bbox_coder=dict(pc_range=point_cloud_range[:2]))) +# model training and testing settings +train_cfg = dict(pts=dict(point_cloud_range=point_cloud_range)) +test_cfg = dict(pts=dict(pc_range=point_cloud_range[:2])) + +dataset_type = 'NuScenesDataset' +data_root = 'data/nuscenes/' +file_client_args = dict(backend='disk') + +db_sampler = dict( + data_root=data_root, + info_path=data_root + 'nuscenes_dbinfos_train.pkl', + rate=1.0, + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict( + car=5, + truck=5, + bus=5, + trailer=5, + construction_vehicle=5, + traffic_cone=5, + barrier=5, + motorcycle=5, + bicycle=5, + pedestrian=5)), + classes=class_names, + sample_groups=dict( + car=2, + truck=3, + construction_vehicle=7, + bus=4, + trailer=6, + barrier=2, + motorcycle=6, + bicycle=6, + pedestrian=2, + traffic_cone=2), + points_loader=dict( + type='LoadPointsFromFile', + load_dim=5, + use_dim=[0, 1, 2, 3, 4], + file_client_args=file_client_args)) + +train_pipeline = [ + dict( + type='LoadPointsFromFile', + load_dim=5, + use_dim=5, + file_client_args=file_client_args), + dict( + type='LoadPointsFromMultiSweeps', + sweeps_num=9, + use_dim=[0, 1, 2, 3, 4], + file_client_args=file_client_args, + pad_empty_sweeps=True, + remove_close=True), + dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), + dict(type='ObjectSample', db_sampler=db_sampler), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.3925, 0.3925], + scale_ratio_range=[0.95, 1.05], + translation_std=[0, 0, 0]), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectNameFilter', classes=class_names), + dict(type='PointShuffle'), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + load_dim=5, + use_dim=5, + file_client_args=file_client_args), + dict( + type='LoadPointsFromMultiSweeps', + sweeps_num=9, + use_dim=[0, 1, 2, 3, 4], + file_client_args=file_client_args, + pad_empty_sweeps=True, + remove_close=True), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D'), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['points']) + ]) +] + +data = dict( + train=dict( + type='CBGSDataset', + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'nuscenes_infos_train.pkl', + pipeline=train_pipeline, + classes=class_names, + test_mode=False, + use_valid_flag=True, + # we use 
box_type_3d='LiDAR' in kitti and nuscenes dataset + # and box_type_3d='Depth' in sunrgbd and scannet dataset. + box_type_3d='LiDAR')), + val=dict(pipeline=test_pipeline, classes=class_names), + test=dict(pipeline=test_pipeline, classes=class_names)) diff --git a/configs/centerpoint/centerpoint_02pillar_second_secfpn_circlenms_4x8_cyclic_20e_nus.py b/configs/centerpoint/centerpoint_02pillar_second_secfpn_circlenms_4x8_cyclic_20e_nus.py new file mode 100644 index 0000000000..9ae815c31d --- /dev/null +++ b/configs/centerpoint/centerpoint_02pillar_second_secfpn_circlenms_4x8_cyclic_20e_nus.py @@ -0,0 +1,3 @@ +_base_ = ['./centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus.py'] + +test_cfg = dict(pts=dict(nms_type='circle')) diff --git a/configs/centerpoint/centerpoint_02pillar_second_secfpn_dcn_4x8_cyclic_20e_nus.py b/configs/centerpoint/centerpoint_02pillar_second_secfpn_dcn_4x8_cyclic_20e_nus.py new file mode 100644 index 0000000000..d98223b6da --- /dev/null +++ b/configs/centerpoint/centerpoint_02pillar_second_secfpn_dcn_4x8_cyclic_20e_nus.py @@ -0,0 +1,16 @@ +_base_ = ['./centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus.py'] + +model = dict( + pts_bbox_head=dict( + seperate_head=dict( + type='DCNSeperateHead', + dcn_config=dict( + type='DCN', + in_channels=64, + out_channels=64, + kernel_size=3, + padding=1, + groups=4, + bias=True), + init_bias=-2.19, + final_kernel=3))) diff --git a/configs/centerpoint/centerpoint_02pillar_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py b/configs/centerpoint/centerpoint_02pillar_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py new file mode 100644 index 0000000000..5ea666a7cc --- /dev/null +++ b/configs/centerpoint/centerpoint_02pillar_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py @@ -0,0 +1,18 @@ +_base_ = ['./centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus.py'] + +model = dict( + pts_bbox_head=dict( + seperate_head=dict( + type='DCNSeperateHead', + dcn_config=dict( + type='DCN', + in_channels=64, + out_channels=64, + kernel_size=3, + padding=1, + groups=4, + bias=True), + init_bias=-2.19, + final_kernel=3))) + +test_cfg = dict(pts=dict(nms_type='circle')) diff --git a/mmdet3d/models/dense_heads/__init__.py b/mmdet3d/models/dense_heads/__init__.py index a6b91ec7d6..2cf9407677 100644 --- a/mmdet3d/models/dense_heads/__init__.py +++ b/mmdet3d/models/dense_heads/__init__.py @@ -1,5 +1,6 @@ from .anchor3d_head import Anchor3DHead from .base_conv_bbox_head import BaseConvBboxHead +from .centerpoint_head import CenterHead from .free_anchor3d_head import FreeAnchor3DHead from .parta2_rpn_head import PartA2RPNHead from .ssd_3d_head import SSD3DHead @@ -7,5 +8,5 @@ __all__ = [ 'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead', - 'SSD3DHead', 'BaseConvBboxHead' + 'SSD3DHead', 'BaseConvBboxHead', 'CenterHead' ] diff --git a/mmdet3d/models/roi_heads/bbox_heads/multi_group_head.py b/mmdet3d/models/dense_heads/centerpoint_head.py similarity index 99% rename from mmdet3d/models/roi_heads/bbox_heads/multi_group_head.py rename to mmdet3d/models/dense_heads/centerpoint_head.py index cc2ce05914..33eb80f7da 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/multi_group_head.py +++ b/mmdet3d/models/dense_heads/centerpoint_head.py @@ -6,11 +6,11 @@ from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius, xywhr2xyxyr) +from mmdet3d.models import builder +from mmdet3d.models.builder import HEADS, build_loss from mmdet3d.models.utils import clip_sigmoid from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu from mmdet.core import 
build_bbox_coder, multi_apply -from ... import builder -from ...builder import HEADS, build_loss @HEADS.register_module() diff --git a/mmdet3d/models/detectors/__init__.py b/mmdet3d/models/detectors/__init__.py index 63d3645adc..1ee43a9a20 100644 --- a/mmdet3d/models/detectors/__init__.py +++ b/mmdet3d/models/detectors/__init__.py @@ -1,4 +1,5 @@ from .base import Base3DDetector +from .centerpoint import CenterPoint from .dynamic_voxelnet import DynamicVoxelNet from .h3dnet import H3DNet from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN @@ -11,5 +12,5 @@ __all__ = [ 'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXTwoStageDetector', 'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet', - 'SSD3DNet' + 'CenterPoint', 'SSD3DNet' ] diff --git a/mmdet3d/models/detectors/centerpoint.py b/mmdet3d/models/detectors/centerpoint.py new file mode 100644 index 0000000000..f43c8968ed --- /dev/null +++ b/mmdet3d/models/detectors/centerpoint.py @@ -0,0 +1,81 @@ +from mmdet3d.core import bbox3d2result +from mmdet.models import DETECTORS +from .mvx_two_stage import MVXTwoStageDetector + + +@DETECTORS.register_module() +class CenterPoint(MVXTwoStageDetector): + """Center-based 3D object detector.""" + + def __init__(self, + pts_voxel_layer=None, + pts_voxel_encoder=None, + pts_middle_encoder=None, + pts_fusion_layer=None, + img_backbone=None, + pts_backbone=None, + img_neck=None, + pts_neck=None, + pts_bbox_head=None, + img_roi_head=None, + img_rpn_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None): + super(CenterPoint, + self).__init__(pts_voxel_layer, pts_voxel_encoder, + pts_middle_encoder, pts_fusion_layer, + img_backbone, pts_backbone, img_neck, pts_neck, + pts_bbox_head, img_roi_head, img_rpn_head, + train_cfg, test_cfg, pretrained) + + def extract_pts_feat(self, pts, img_feats, img_metas): + """Extract features of points.""" + if not self.with_pts_bbox: + return None + voxels, num_points, coors = self.voxelize(pts) + + voxel_features = self.pts_voxel_encoder(voxels, num_points, coors) + batch_size = coors[-1, 0] + 1 + x = self.pts_middle_encoder(voxel_features, coors, batch_size) + x = self.pts_backbone(x) + if self.with_pts_neck: + x = self.pts_neck(x) + return x + + def forward_pts_train(self, + pts_feats, + gt_bboxes_3d, + gt_labels_3d, + img_metas, + gt_bboxes_ignore=None): + """Forward function for point cloud branch. + + Args: + pts_feats (list[torch.Tensor]): Features of point cloud branch. + gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth + boxes for each sample. + gt_labels_3d (list[torch.Tensor]): Ground truth labels for + boxes of each sample. + img_metas (list[dict]): Meta information of samples. + gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth + boxes to be ignored. Defaults to None. + + Returns: + dict: Losses of each branch.
+ """ + outs = self.pts_bbox_head(pts_feats) + loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs] + losses = self.pts_bbox_head.loss(*loss_inputs) + return losses + + def simple_test_pts(self, x, img_metas, rescale=False): + """Test function of point cloud branch.""" + outs = self.pts_bbox_head(x) + bbox_list = self.pts_bbox_head.get_bboxes( + outs, img_metas, rescale=rescale) + bbox_results = [ + bbox3d2result(bboxes, scores, labels) + for bboxes, scores, labels in bbox_list + ] + return bbox_results diff --git a/mmdet3d/models/roi_heads/bbox_heads/__init__.py b/mmdet3d/models/roi_heads/bbox_heads/__init__.py index 38a9d4824f..1005cf614c 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/__init__.py +++ b/mmdet3d/models/roi_heads/bbox_heads/__init__.py @@ -3,11 +3,10 @@ Shared2FCBBoxHead, Shared4Conv1FCBBoxHead) from .h3d_bbox_head import H3DBboxHead -from .multi_group_head import CenterHead from .parta2_bbox_head import PartA2BboxHead __all__ = [ 'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead', 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'PartA2BboxHead', - 'H3DBboxHead', 'CenterHead' + 'H3DBboxHead' ]