diff --git a/README.md b/README.md
index c6cdb44..88403ad 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
 # ImVoxelNet: Image to Voxels Projection for Monocular and Multi-View General-Purpose 3D Object Detection
 **News**:
- * :fire: October, 2021. Our paper is accepted at [WACV 2022](https://wacv2022.thecvf.com). Stay tuned for 5 times faster and more accurate models for indoor datasets.
+ * :fire: October, 2021. Our paper is accepted at [WACV 2022](https://wacv2022.thecvf.com). We simplify the 3D neck to make indoor models much faster and more accurate. For example, this improves `ScanNet` `mAP` by more than 2%. Please find updated configs in [configs/imvoxelnet/*_fast.py](https://github.com/saic-vul/imvoxelnet/tree/master/configs/imvoxelnet) and [models](https://github.com/saic-vul/imvoxelnet/releases/tag/v1.2).
  * :fire: August, 2021. We adapt center sampling for indoor detection. For example, this improves `ScanNet` `mAP` by more than 5%. Please find updated configs in [configs/imvoxelnet/*_top27.py](https://github.com/saic-vul/imvoxelnet/tree/master/configs/imvoxelnet) and [models](https://github.com/saic-vul/imvoxelnet/releases/tag/v1.1).
  * :fire: July, 2021. We update `ScanNet` image preprocessing both [here](https://github.com/saic-vul/imvoxelnet/pull/21) and in [mmdetection3d](https://github.com/open-mmlab/mmdetection3d/pull/696).
  * :fire: June, 2021. `ImVoxelNet` for `KITTI` is now [supported](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/imvoxelnet) in [mmdetection3d](https://github.com/open-mmlab/mmdetection3d).
@@ -87,14 +87,16 @@ python tools/test.py configs/imvoxelnet/imvoxelnet_kitti.py \
 
 ### Models
 
-| Dataset | Object Classes | Center Sampling | Download |
-|:---------:|:--------------:|:---------------:|:--------:|
-| SUN RGB-D | 37 from Total3dUnderstanding | ✘<br>✔ | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210525_091810.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210525_091810_atlas_total_sunrgbd.log) | [config](configs/imvoxelnet/imvoxelnet_total_sunrgbd.py)<br>[model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_005013.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_005013_imvoxelnet_total_sunrgbd_top27.log) | [config](configs/imvoxelnet/imvoxelnet_total_sunrgbd_top27.py)|
-| SUN RGB-D | 30 from PerspectiveNet | ✘<br>✔ | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210526_072029.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210526_072029_atlas_perspective_sunrgbd.log) | [config](configs/imvoxelnet/imvoxelnet_perspective_sunrgbd.py)<br>[model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_114832.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_114832_imvoxelnet_perspective_sunrgbd_top27.log) | [config](configs/imvoxelnet/imvoxelnet_perspective_sunrgbd_top27.py)|
-| SUN RGB-D | 10 from VoteNet | ✘<br>✔ | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210428_124351.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210428_124351_atlas_sunrgbd.log) | [config](configs/imvoxelnet/imvoxelnet_sunrgbd.py)<br>[model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_112435.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_112435_imvoxelnet_sunrgbd_top27.log) | [config](configs/imvoxelnet/imvoxelnet_sunrgbd_top27.py)|
-| ScanNet | 18 from VoteNet | ✘<br>✔ | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210520_223109.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210520_223109_atlas_scannet.log) | [config](configs/imvoxelnet/imvoxelnet_scannet.py)<br>[model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_070616.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_070616_imvoxelnet_scannet_top27.log) | [config](configs/imvoxelnet/imvoxelnet_scannet_top27.py)|
-| KITTI | Car | ✘ | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210503_214214.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210503_214214_atlas_kitti.log) | [config](configs/imvoxelnet/imvoxelnet_kitti.py) |
-| nuScenes | Car | ✘ | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210505_131108.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210505_131108_atlas_nuscenes.log) | [config](configs/imvoxelnet/imvoxelnet_nuscenes.py) |
+`v2` adds center sampling for the indoor scenario, and `v3` simplifies the 3D neck for the indoor scenario. The differences are discussed in the [v2](https://arxiv.org/abs/2106.01178v2) and [v3](https://arxiv.org/abs/2106.01178v3) preprints.
+
+| Dataset | Object Classes | Version | Download |
+|:---------:|:--------------:|:-------:|:--------:|
+| SUN RGB-D | 37 from<br>Total3dUnderstanding | v1 | mAP@0.15: 41.5<br>v2 | mAP@0.15: 42.7<br>v3 | mAP@0.15: 43.7 | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210525_091810.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210525_091810_atlas_total_sunrgbd.log) | [config](configs/imvoxelnet/imvoxelnet_total_sunrgbd.py)<br>[model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_005013.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_005013_imvoxelnet_total_sunrgbd_top27.log) | [config](configs/imvoxelnet/imvoxelnet_total_sunrgbd_top27.py)<br>[model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_105247.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_105247_imvoxelnet_total_sunrgbd_fast.log) | [config](configs/imvoxelnet/imvoxelnet_total_sunrgbd_fast.py)|
+| SUN RGB-D | 30 from<br>PerspectiveNet | v1 | mAP@0.15: 44.9<br>v2 | mAP@0.15: 47.2<br>v3 | mAP@0.15: 48.7 | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210526_072029.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210526_072029_atlas_perspective_sunrgbd.log) | [config](configs/imvoxelnet/imvoxelnet_perspective_sunrgbd.py)<br>[model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_114832.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_114832_imvoxelnet_perspective_sunrgbd_top27.log) | [config](configs/imvoxelnet/imvoxelnet_perspective_sunrgbd_top27.py)<br>[model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_105254.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_105254_imvoxelnet_perspective_sunrgbd_fast.log) | [config](configs/imvoxelnet/imvoxelnet_perspective_sunrgbd_fast.py)|
+| SUN RGB-D | 10 from VoteNet | v1 | mAP@0.25: 38.8<br>v2 | mAP@0.25: 39.4<br>v3 | mAP@0.25: 40.7 | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210428_124351.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210428_124351_atlas_sunrgbd.log) | [config](configs/imvoxelnet/imvoxelnet_sunrgbd.py)<br>[model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_112435.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_112435_imvoxelnet_sunrgbd_top27.log) | [config](configs/imvoxelnet/imvoxelnet_sunrgbd_top27.py)<br>[model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_105255.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_105255_imvoxelnet_sunrgbd_fast.log) | [config](configs/imvoxelnet/imvoxelnet_sunrgbd_fast.py)|
+| ScanNet | 18 from VoteNet | v1 | mAP@0.25: 40.6<br>v2 | mAP@0.25: 45.7<br>v3 | mAP@0.25: 48.1 | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210520_223109.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210520_223109_atlas_scannet.log) | [config](configs/imvoxelnet/imvoxelnet_scannet.py)<br>[model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_070616.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_070616_imvoxelnet_scannet_top27.log) | [config](configs/imvoxelnet/imvoxelnet_scannet_top27.py)<br>[model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_113826.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_113826_imvoxelnet_scannet_fast.log) | [config](configs/imvoxelnet/imvoxelnet_scannet_fast.py)|
+| KITTI | Car | v1 | AP@0.7: 17.8 | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210503_214214.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210503_214214_atlas_kitti.log) | [config](configs/imvoxelnet/imvoxelnet_kitti.py) |
+| nuScenes | Car | v1 | AP: 51.8 | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210505_131108.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210505_131108_atlas_nuscenes.log) | [config](configs/imvoxelnet/imvoxelnet_nuscenes.py) |
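+
+For example, a downloaded `v3` checkpoint can be evaluated with the standard test script. This is a sketch: the checkpoint path is a placeholder for wherever the file was saved.
+```shell
+python tools/test.py configs/imvoxelnet/imvoxelnet_scannet_fast.py \
+    work_dirs/imvoxelnet_scannet_fast/latest.pth --eval mAP
+```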
 
 ### Example Detections
 
diff --git a/configs/imvoxelnet/imvoxelnet_perspective_sunrgbd_fast.py b/configs/imvoxelnet/imvoxelnet_perspective_sunrgbd_fast.py
new file mode 100644
index 0000000..775d8d0
--- /dev/null
+++ b/configs/imvoxelnet/imvoxelnet_perspective_sunrgbd_fast.py
@@ -0,0 +1,127 @@
+model = dict(
+    type='ImVoxelNet',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=False),
+        norm_eval=True,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=4),
+    neck_3d=dict(
+        type='FastIndoorImVoxelNeck',
+        in_channels=256,
+        out_channels=128,
+        n_blocks=[1, 1, 1]),
+    bbox_head=dict(
+        type='SunRgbdImVoxelHeadV2',
+        n_classes=30,
+        n_channels=128,
+        n_reg_outs=7,
+        n_scales=3,
+        limit=27,
+        centerness_topk=18),
+    n_voxels=(40, 40, 16),
+    voxel_size=(.16, .16, .16))
+train_cfg = dict()
+test_cfg = dict(
+    nms_pre=1000,
+    nms_thr=.15,
+    use_rotate_nms=True,
+    score_thr=.01)
+img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+dataset_type = 'SunRgbdPerspectiveMultiViewDataset'
+data_root = 'data/sunrgbd/'
+class_names = ('recycle_bin', 'cpu', 'paper', 'toilet', 'stool', 'whiteboard', 'coffee_table', 'picture',
+               'keyboard', 'dresser', 'painting', 'bookshelf', 'night_stand', 'endtable', 'drawer', 'sink',
+               'monitor', 'computer', 'cabinet', 'shelf', 'lamp', 'garbage_bin', 'box', 'bed', 'sofa',
+               'sofa_chair', 'pillow', 'desk', 'table', 'chair')
+
+train_pipeline = [
+    dict(type='LoadAnnotations3D'),
+    dict(
+        type='MultiViewPipeline',
+        n_images=1,
+        transforms=[
+            dict(type='LoadImageFromFile'),
+            dict(type='RandomFlip', flip_ratio=0.5),
+            dict(type='Resize', img_scale=[(512, 384), (768, 576)], multiscale_mode='range', keep_ratio=True),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32)]),
+    dict(type='SunRgbdRandomFlip'),
+    dict(type='DefaultFormatBundle3D', class_names=class_names),
+    dict(type='Collect3D', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d'])]
+test_pipeline = [
+    dict(
+        type='MultiViewPipeline',
+        n_images=1,
+        transforms=[
+            dict(type='LoadImageFromFile'),
+            dict(type='Resize', img_scale=(640, 480), keep_ratio=True),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32)]),
+    dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
+    dict(type='Collect3D', keys=['img'])]
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type='RepeatDataset',
+        times=2,
+        dataset=dict(
type=dataset_type, + data_root=data_root, + ann_file=data_root + 'sunrgbd_perspective_infos_train.pkl', + pipeline=train_pipeline, + classes=class_names, + filter_empty_gt=True, + box_type_3d='Depth')), + val=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'sunrgbd_perspective_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth'), + test=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'sunrgbd_perspective_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth')) + +optimizer = dict( + type='AdamW', + lr=0.0001, + weight_decay=0.0001, + paramwise_cfg=dict( + custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)})) +optimizer_config = dict(grad_clip=dict(max_norm=35., norm_type=2)) +lr_config = dict(policy='step', step=[8, 11]) +total_epochs = 12 + +checkpoint_config = dict(interval=1, max_keep_ckpts=1) +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ]) +evaluation = dict(interval=1) +dist_params = dict(backend='nccl') +find_unused_parameters = True # todo: fix number of FPN outputs +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/configs/imvoxelnet/imvoxelnet_scannet_fast.py b/configs/imvoxelnet/imvoxelnet_scannet_fast.py new file mode 100644 index 0000000..9c8cd3b --- /dev/null +++ b/configs/imvoxelnet/imvoxelnet_scannet_fast.py @@ -0,0 +1,131 @@ +model = dict( + type='ImVoxelNet', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + neck_3d=dict( + type='FastIndoorImVoxelNeck', + in_channels=256, + out_channels=128, + n_blocks=[1, 1, 1]), + bbox_head=dict( + type='ScanNetImVoxelHeadV2', + loss_bbox=dict(type='AxisAlignedIoULoss', loss_weight=1.0), + n_classes=18, + n_channels=128, + n_reg_outs=6, + n_scales=3, + limit=27, + centerness_topk=18), + voxel_size=(.16, .16, .16), + n_voxels=(40, 40, 16)) +train_cfg = dict() +test_cfg = dict( + nms_pre=1000, + iou_thr=.25, + score_thr=.01) +img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +dataset_type = 'ScanNetMultiViewDataset' +data_root = 'data/scannet/' +class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', + 'bookshelf', 'picture', 'counter', 'desk', 'curtain', + 'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub', + 'garbagebin') + +train_pipeline = [ + dict(type='LoadAnnotations3D'), + dict( + type='MultiViewPipeline', + n_images=20, + transforms=[ + dict(type='LoadImageFromFile'), + dict(type='Resize', img_scale=(640, 480), keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=(480, 640)) + ]), + dict(type='RandomShiftOrigin', std=(.7, .7, .0)), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict(type='Collect3D', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d']) +] +test_pipeline = [ + dict( + type='MultiViewPipeline', + n_images=50, + transforms=[ + dict(type='LoadImageFromFile'), + dict(type='Resize', img_scale=(640, 480), keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=(480, 640)) + ]), + dict(type='DefaultFormatBundle3D', 
class_names=class_names, with_label=False), + dict(type='Collect3D', keys=['img']) +] +data = dict( + samples_per_gpu=1, + workers_per_gpu=1, + train=dict( + type='RepeatDataset', + times=3, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_train.pkl', + pipeline=train_pipeline, + classes=class_names, + filter_empty_gt=True, + box_type_3d='Depth')), + val=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth'), + test=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth') +) + +optimizer = dict( + type='AdamW', + lr=0.0001, + weight_decay=0.0001, + paramwise_cfg=dict( + custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)})) +optimizer_config = dict(grad_clip=dict(max_norm=35., norm_type=2)) +lr_config = dict(policy='step', step=[8, 11]) +total_epochs = 12 + +checkpoint_config = dict(interval=1, max_keep_ckpts=1) +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ]) +evaluation = dict(interval=1) +dist_params = dict(backend='nccl') +find_unused_parameters = True # todo: fix number of FPN outputs +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] \ No newline at end of file diff --git a/configs/imvoxelnet/imvoxelnet_sunrgbd_fast.py b/configs/imvoxelnet/imvoxelnet_sunrgbd_fast.py new file mode 100644 index 0000000..f0d0eab --- /dev/null +++ b/configs/imvoxelnet/imvoxelnet_sunrgbd_fast.py @@ -0,0 +1,125 @@ +model = dict( + type='ImVoxelNet', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + neck_3d=dict( + type='FastIndoorImVoxelNeck', + in_channels=256, + out_channels=128, + n_blocks=[1, 1, 1]), + bbox_head=dict( + type='SunRgbdImVoxelHeadV2', + n_classes=10, + n_channels=128, + n_reg_outs=7, + n_scales=3, + limit=27, + centerness_topk=18), + n_voxels=(40, 40, 16), + voxel_size=(.16, .16, .16)) +train_cfg = dict() +test_cfg = dict( + nms_pre=1000, + nms_thr=.15, + use_rotate_nms=True, + score_thr=.0) +img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +dataset_type = 'SunRgbdMultiViewDataset' +data_root = 'data/sunrgbd/' +class_names = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser', + 'night_stand', 'bookshelf', 'bathtub') + +train_pipeline = [ + dict(type='LoadAnnotations3D'), + dict( + type='MultiViewPipeline', + n_images=1, + transforms=[ + dict(type='LoadImageFromFile'), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Resize', img_scale=[(512, 384), (768, 576)], multiscale_mode='range', keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32)]), + dict(type='SunRgbdRandomFlip'), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict(type='Collect3D', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d'])] +test_pipeline = [ + dict( + type='MultiViewPipeline', + n_images=1, + transforms=[ + dict(type='LoadImageFromFile'), + dict(type='Resize', img_scale=(640, 480), keep_ratio=True), + 
dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32)]), + dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False), + dict(type='Collect3D', keys=['img'])] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=2, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'sunrgbd_imvoxelnet_infos_train.pkl', + pipeline=train_pipeline, + classes=class_names, + filter_empty_gt=True, + box_type_3d='Depth')), + val=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'sunrgbd_imvoxelnet_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth'), + test=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'sunrgbd_imvoxelnet_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth')) + +optimizer = dict( + type='AdamW', + lr=0.0001, + weight_decay=0.0001, + paramwise_cfg=dict( + custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)})) +optimizer_config = dict(grad_clip=dict(max_norm=35., norm_type=2)) +lr_config = dict(policy='step', step=[8, 11]) +total_epochs = 12 + +checkpoint_config = dict(interval=1, max_keep_ckpts=1) +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ]) +evaluation = dict(interval=1) +dist_params = dict(backend='nccl') +find_unused_parameters = True # todo: fix number of FPN outputs +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/configs/imvoxelnet/imvoxelnet_total_sunrgbd_fast.py b/configs/imvoxelnet/imvoxelnet_total_sunrgbd_fast.py new file mode 100644 index 0000000..4570c58 --- /dev/null +++ b/configs/imvoxelnet/imvoxelnet_total_sunrgbd_fast.py @@ -0,0 +1,134 @@ +model = dict( + type='ImVoxelNet', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='pytorch'), + head_2d=dict( + type='LayoutHead', + n_channels=2048, + linear_size=256, + dropout=.0, + loss_angle=dict(type='SmoothL1Loss', loss_weight=100.), + loss_layout=dict(type='IoU3DLoss', loss_weight=1.)), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + neck_3d=dict( + type='FastIndoorImVoxelNeck', + in_channels=256, + out_channels=128, + n_blocks=[1, 1, 1]), + bbox_head=dict( + type='SunRgbdImVoxelHeadV2', + n_classes=33, + n_channels=128, + n_reg_outs=7, + n_scales=3, + limit=27, + centerness_topk=18), + n_voxels=(40, 40, 16), + voxel_size=(.16, .16, .16)) +train_cfg = dict() +test_cfg = dict( + nms_pre=1000, + nms_thr=.15, + use_rotate_nms=True, + score_thr=.0) +img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +dataset_type = 'SunRgbdTotalMultiViewDataset' +data_root = 'data/sunrgbd/' +class_names = [ + 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', + 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'clothes', 'books', + 'fridge', 'tv', 'paper', 'towel', 'shower_curtain', 'box', 'whiteboard', 'person', 'night_stand', 'toilet', + 'sink', 'lamp', 'bathtub', 'bag' +] + +train_pipeline = [ + dict(type='LoadAnnotations3D'), + dict( + type='MultiViewPipeline', + n_images=1, + transforms=[ + 
dict(type='SunRgbdTotalLoadImageFromFile'), + dict(type='Resize', img_scale=[(512, 384), (768, 576)], multiscale_mode='range', keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32)]), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict(type='Collect3D', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d'])] +test_pipeline = [ + dict( + type='MultiViewPipeline', + n_images=1, + transforms=[ + dict(type='LoadImageFromFile'), + dict(type='Resize', img_scale=(640, 480), keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32)]), + dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False), + dict(type='Collect3D', keys=['img'])] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=1, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'sunrgbd_total_infos_train.pkl', + pipeline=train_pipeline, + classes=class_names, + filter_empty_gt=True, + box_type_3d='Depth')), + val=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'sunrgbd_total_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth'), + test=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'sunrgbd_total_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth')) + +optimizer = dict( + type='AdamW', + lr=0.0001, + weight_decay=0.0001, + paramwise_cfg=dict( + custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)})) +optimizer_config = dict(grad_clip=dict(max_norm=35., norm_type=2)) +lr_config = dict(policy='step', step=[8, 11]) +total_epochs = 12 + +checkpoint_config = dict(interval=1, max_keep_ckpts=1) +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ]) +evaluation = dict(interval=1) +dist_params = dict(backend='nccl') +find_unused_parameters = True # todo: fix number of FPN outputs +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/mmdet3d/models/dense_heads/__init__.py b/mmdet3d/models/dense_heads/__init__.py index ecea98d..ba0075e 100644 --- a/mmdet3d/models/dense_heads/__init__.py +++ b/mmdet3d/models/dense_heads/__init__.py @@ -7,6 +7,7 @@ from .ssd_3d_head import SSD3DHead from .vote_head import VoteHead from .imvoxel_head import SunRgbdImVoxelHead, ScanNetImVoxelHead +from .imvoxel_head_v2 import SunRgbdImVoxelHeadV2, ScanNetImVoxelHeadV2 from .layout_head import LayoutHead __all__ = [ diff --git a/mmdet3d/models/dense_heads/imvoxel_head_v2.py b/mmdet3d/models/dense_heads/imvoxel_head_v2.py index e69de29..2f4d58d 100644 --- a/mmdet3d/models/dense_heads/imvoxel_head_v2.py +++ b/mmdet3d/models/dense_heads/imvoxel_head_v2.py @@ -0,0 +1,566 @@ +import torch +from torch import nn +from mmdet.core import multi_apply, reduce_mean +from mmdet.models.builder import HEADS, build_loss +from mmcv.cnn import Scale, bias_init_with_prob, normal_init + +from mmdet3d.models.detectors.imvoxelnet import get_points +from mmdet3d.core.bbox.structures import rotation_3d_in_axis +from mmdet3d.core.post_processing import aligned_3d_nms, box3d_multiclass_nms + + +class ImVoxelHeadV2(nn.Module): + def __init__(self, + n_classes, + n_channels, + n_reg_outs, + n_scales, + limit, + centerness_topk=-1, + loss_centerness=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0), + 
loss_bbox=dict(type='IoU3DLoss', loss_weight=1.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + train_cfg=None, + test_cfg=None): + super(ImVoxelHeadV2, self).__init__() + self.n_classes = n_classes + self.n_scales = n_scales + self.limit = limit + self.centerness_topk = centerness_topk + self.loss_centerness = build_loss(loss_centerness) + self.loss_bbox = build_loss(loss_bbox) + self.loss_cls = build_loss(loss_cls) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self._init_layers(n_channels, n_reg_outs) + + def _init_layers(self, n_channels, n_reg_outs): + self.centerness_conv = nn.Conv3d(n_channels, 1, 3, padding=1, bias=False) + self.reg_conv = nn.Conv3d(n_channels, n_reg_outs, 3, padding=1, bias=False) + self.cls_conv = nn.Conv3d(n_channels, self.n_classes, 3, padding=1) + self.scales = nn.ModuleList([Scale(1.) for _ in range(self.n_scales)]) + + # Follow AnchorFreeHead.init_weights + def init_weights(self): + normal_init(self.centerness_conv, std=.01) + normal_init(self.reg_conv, std=.01) + normal_init(self.cls_conv, std=.01, bias=bias_init_with_prob(.01)) + + def forward(self, x): + return multi_apply(self.forward_single, x, self.scales) + + def forward_train(self, x, valid, img_metas, gt_bboxes, gt_labels): + loss_inputs = self(x) + (valid, img_metas, gt_bboxes, gt_labels) + losses = self.loss(*loss_inputs) + return losses + + def loss(self, + centernesses, + bbox_preds, + cls_scores, + valid, + img_metas, + gt_bboxes, + gt_labels): + """ + Args: + centernesses (list(Tensor)): Multi-level centernesses + of shape (batch, 1, nx[i], ny[i], nz[i]) + bbox_preds (list(Tensor)): Multi-level xyz min and max distances + of shape (batch, 6, nx[i], ny[i], nz[i]) + cls_scores (list(Tensor)): Multi-level class scores + of shape (batch, n_classes, nx[i], ny[i], nz[i]) + img_metas (list[dict]): Meta information of each image + gt_bboxes (list(BaseInstance3DBoxes)): Ground truth bboxes for each image + gt_labels (list(Tensor)): Ground truth class labels for each image + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + assert len(centernesses[0]) == len(bbox_preds[0]) == len(cls_scores[0]) == \ + len(valid) == len(img_metas) == len(gt_bboxes) == len(gt_labels) + + valids = [] + for x in centernesses: + valids.append(nn.Upsample(size=x.shape[-3:], mode='trilinear')(valid).round().bool()) + + loss_centerness, loss_bbox, loss_cls = [], [], [] + for i in range(len(img_metas)): + img_loss_centerness, img_loss_bbox, img_loss_cls = self._loss_single( + centernesses=[x[i] for x in centernesses], + bbox_preds=[x[i] for x in bbox_preds], + cls_scores=[x[i] for x in cls_scores], + valids=[x[i] for x in valids], + img_meta=img_metas[i], + gt_bboxes=gt_bboxes[i], + gt_labels=gt_labels[i] + ) + loss_centerness.append(img_loss_centerness) + loss_bbox.append(img_loss_bbox) + loss_cls.append(img_loss_cls) + return dict( + loss_centerness=torch.mean(torch.stack(loss_centerness)), + loss_bbox=torch.mean(torch.stack(loss_bbox)), + loss_cls=torch.mean(torch.stack(loss_cls)) + ) + + def _loss_single(self, + centernesses, + bbox_preds, + cls_scores, + valids, + img_meta, + gt_bboxes, + gt_labels): + """ + Args: + centernesses (list(Tensor)): Multi-level centernesses + of shape (1, nx[i], ny[i], nz[i]) + bbox_preds (list(Tensor)): Multi-level xyz min and max distances + of shape (6, nx[i], ny[i], nz[i]) + cls_scores (list(Tensor)): Multi-level class scores + of shape (n_classes, nx[i], ny[i], nz[i]) + img_metas (list[dict]): Meta information + gt_bboxes (BaseInstance3DBoxes): Ground truth bboxes + of shape (n_boxes, 7) + gt_labels (list(Tensor)): Ground truth class labels + of shape (n_boxes,) + + Returns: + tuple(Tensor): 3 losses + """ + featmap_sizes = [featmap.size()[-3:] for featmap in centernesses] + mlvl_points = self.get_points( + featmap_sizes=featmap_sizes, + origin=img_meta['lidar2img']['origin'], + device=gt_bboxes.device + ) + + centerness_targets, bbox_targets, labels = self.get_targets(mlvl_points, gt_bboxes, gt_labels) + + flatten_centerness = [centerness.permute(1, 2, 3, 0).reshape(-1) + for centerness in centernesses] + bbox_pred_size = bbox_preds[0].shape[0] + flatten_bbox_preds = [bbox_pred.permute(1, 2, 3, 0).reshape(-1, bbox_pred_size) + for bbox_pred in bbox_preds] + flatten_cls_scores = [cls_score.permute(1, 2, 3, 0).reshape(-1, self.n_classes) + for cls_score in cls_scores] + flatten_valids = [valid.permute(1, 2, 3, 0).reshape(-1) + for valid in valids] + + flatten_centerness = torch.cat(flatten_centerness) + flatten_bbox_preds = torch.cat(flatten_bbox_preds) + flatten_cls_scores = torch.cat(flatten_cls_scores) + flatten_valids = torch.cat(flatten_valids) + + flatten_centerness_targets = centerness_targets.to(centernesses[0].device) + flatten_bbox_targets = bbox_targets.to(centernesses[0].device) + flatten_labels = labels.to(centernesses[0].device) + flatten_points = torch.cat(mlvl_points) + + # skip background + pos_inds = torch.nonzero(torch.logical_and( + flatten_labels >= 0, + flatten_valids + )).reshape(-1) + n_pos = torch.tensor(len(pos_inds), dtype=torch.float, device=centernesses[0].device) + n_pos = max(reduce_mean(n_pos), 1.) 
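+        # Note: `reduce_mean` averages the positive count across GPUs so that
+        # losses are normalized consistently in distributed training. The guard
+        # below handles images where no voxel is valid: an empty `.sum()` keeps
+        # the classification loss defined (and zero) for such samples.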
+ if torch.any(flatten_valids): + loss_cls = self.loss_cls( + flatten_cls_scores[flatten_valids], + flatten_labels[flatten_valids], + avg_factor=n_pos + ) + else: + loss_cls = flatten_cls_scores[flatten_valids].sum() + pos_centerness = flatten_centerness[pos_inds] + pos_bbox_preds = flatten_bbox_preds[pos_inds] + + if len(pos_inds) > 0: + pos_centerness_targets = flatten_centerness_targets[pos_inds] + pos_bbox_targets = flatten_bbox_targets[pos_inds] + pos_points = flatten_points[pos_inds].to(pos_bbox_preds.device) + loss_centerness = self.loss_centerness( + pos_centerness, pos_centerness_targets, avg_factor=n_pos + ) + loss_bbox = self.loss_bbox( + self._bbox_pred_to_loss(pos_points, pos_bbox_preds), + pos_bbox_targets, + weight=pos_centerness_targets, + avg_factor=pos_centerness_targets.sum() + ) + else: + loss_centerness = pos_centerness.sum() + loss_bbox = pos_bbox_preds.sum() + return loss_centerness, loss_bbox, loss_cls + + @torch.no_grad() + def get_points(self, featmap_sizes, origin, device): + mlvl_points = [] + for i, featmap_size in enumerate(featmap_sizes): + mlvl_points.append(get_points( + n_voxels=torch.tensor(featmap_size), + voxel_size=torch.tensor(self.voxel_size) * (2 ** i), + origin=torch.tensor(origin) + ).reshape(3, -1).transpose(0, 1).to(device)) + return mlvl_points + + def get_bboxes(self, + centernesses, + bbox_preds, + cls_scores, + valid, + img_metas): + assert len(centernesses[0]) == len(bbox_preds[0]) == len(cls_scores[0]) \ + == len(img_metas) + valids = [] + for x in centernesses: + valids.append(nn.Upsample(size=x.shape[-3:], mode='trilinear')(valid).round().bool()) + n_levels = len(centernesses) + result_list = [] + for img_id in range(len(img_metas)): + centerness_list = [ + centernesses[i][img_id].detach() for i in range(n_levels) + ] + bbox_pred_list = [ + bbox_preds[i][img_id].detach() for i in range(n_levels) + ] + cls_score_list = [ + cls_scores[i][img_id].detach() for i in range(n_levels) + ] + valid_list = [ + valids[i][img_id].detach() for i in range(n_levels) + ] + det_bboxes_3d = self._get_bboxes_single( + centerness_list, bbox_pred_list, cls_score_list, valid_list, img_metas[img_id] + ) + result_list.append(det_bboxes_3d) + return result_list + + def _get_bboxes_single(self, + centernesses, + bbox_preds, + cls_scores, + valids, + img_meta): + featmap_sizes = [featmap.size()[-3:] for featmap in centernesses] + mlvl_points = self.get_points( + featmap_sizes=featmap_sizes, + origin=img_meta['lidar2img']['origin'], + device=centernesses[0].device + ) + bbox_pred_size = bbox_preds[0].shape[0] + mlvl_bboxes, mlvl_scores = [], [] + for centerness, bbox_pred, cls_score, valid, points in zip( + centernesses, bbox_preds, cls_scores, valids, mlvl_points + ): + centerness = centerness.permute(1, 2, 3, 0).reshape(-1).sigmoid() + bbox_pred = bbox_pred.permute(1, 2, 3, 0).reshape(-1, bbox_pred_size) + scores = cls_score.permute(1, 2, 3, 0).reshape(-1, self.n_classes).sigmoid() + valid = valid.permute(1, 2, 3, 0).reshape(-1) + scores = scores * centerness[:, None] * valid[:, None] + max_scores, _ = scores.max(dim=1) + + if len(scores) > self.test_cfg.nms_pre > 0: + _, ids = max_scores.topk(self.test_cfg.nms_pre) + bbox_pred = bbox_pred[ids] + scores = scores[ids] + points = points[ids] + + bboxes = self._bbox_pred_to_result(points, bbox_pred) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + + bboxes = torch.cat(mlvl_bboxes) + scores = torch.cat(mlvl_scores) + bboxes, scores, labels = self._nms(bboxes, scores, img_meta) + return bboxes, scores, 
labels + + def forward_single(self, x, scale): + raise NotImplementedError + + def _bbox_pred_to_loss(self, points, bbox_preds): + raise NotImplementedError + + def _bbox_pred_to_result(self, points, bbox_preds): + raise NotImplementedError + + def get_targets(self, points, gt_bboxes, gt_labels): + raise NotImplementedError + + def _nms(self, bboxes, scores, img_meta): + raise NotImplementedError + + +@HEADS.register_module() +class SunRgbdImVoxelHeadV2(ImVoxelHeadV2): + def forward_single(self, x, scale): + reg_final = self.reg_conv(x) + reg_distance = torch.exp(scale(reg_final[:, :6])) + reg_angle = reg_final[:, 6:] + return ( + self.centerness_conv(x), + torch.cat((reg_distance, reg_angle), dim=1), + self.cls_conv(x) + ) + + def _bbox_pred_to_loss(self, points, bbox_preds): + return self._bbox_pred_to_bbox(points, bbox_preds) + + def _bbox_pred_to_result(self, points, bbox_preds): + return self._bbox_pred_to_bbox(points, bbox_preds) + + @torch.no_grad() + def get_targets(self, points, gt_bboxes, gt_labels): + float_max = 1e8 + expanded_scales = [ + points[i].new_tensor(i).expand(len(points[i])).to(gt_labels.device) + for i in range(len(points)) + ] + points = torch.cat(points, dim=0).to(gt_labels.device) + scales = torch.cat(expanded_scales, dim=0) + + # below is based on FCOSHead._get_target_single + n_points = len(points) + n_boxes = len(gt_bboxes) + volumes = gt_bboxes.volume.to(points.device) + volumes = volumes.expand(n_points, n_boxes).contiguous() + gt_bboxes = torch.cat((gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]), dim=1) + gt_bboxes = gt_bboxes.to(points.device).expand(n_points, n_boxes, 7) + expanded_points = points.unsqueeze(1).expand(n_points, n_boxes, 3) + shift = torch.stack(( + expanded_points[..., 0] - gt_bboxes[..., 0], + expanded_points[..., 1] - gt_bboxes[..., 1], + expanded_points[..., 2] - gt_bboxes[..., 2] + ), dim=-1).permute(1, 0, 2) + shift = rotation_3d_in_axis(shift, -gt_bboxes[0, :, 6], axis=2).permute(1, 0, 2) + centers = gt_bboxes[..., :3] + shift + dx_min = centers[..., 0] - gt_bboxes[..., 0] + gt_bboxes[..., 3] / 2 + dx_max = gt_bboxes[..., 0] + gt_bboxes[..., 3] / 2 - centers[..., 0] + dy_min = centers[..., 1] - gt_bboxes[..., 1] + gt_bboxes[..., 4] / 2 + dy_max = gt_bboxes[..., 1] + gt_bboxes[..., 4] / 2 - centers[..., 1] + dz_min = centers[..., 2] - gt_bboxes[..., 2] + gt_bboxes[..., 5] / 2 + dz_max = gt_bboxes[..., 2] + gt_bboxes[..., 5] / 2 - centers[..., 2] + bbox_targets = torch.stack((dx_min, dx_max, dy_min, dy_max, dz_min, dz_max, gt_bboxes[..., 6]), dim=-1) + + # condition1: inside a gt bbox + inside_gt_bbox_mask = bbox_targets[..., :6].min(-1)[0] > 0 # skip angle + + # condition2: positive points per scale >= limit + # calculate positive points per scale + n_pos_points_per_scale = [] + for i in range(self.n_scales): + n_pos_points_per_scale.append(torch.sum(inside_gt_bbox_mask[scales == i], dim=0)) + # find best scale + n_pos_points_per_scale = torch.stack(n_pos_points_per_scale, dim=0) + lower_limit_mask = n_pos_points_per_scale < self.limit + # fix nondeterministic argmax for torch<1.7 + extra = torch.arange(self.n_scales, 0, -1).unsqueeze(1).expand(self.n_scales, n_boxes).to(lower_limit_mask.device) + lower_index = torch.argmax(lower_limit_mask.int() * extra, dim=0) - 1 + lower_index = torch.where(lower_index < 0, torch.zeros_like(lower_index), lower_index) + all_upper_limit_mask = torch.all(torch.logical_not(lower_limit_mask), dim=0) + best_scale = torch.where(all_upper_limit_mask, torch.ones_like(all_upper_limit_mask) * 
self.n_scales - 1, lower_index) + # keep only points with best scale + best_scale = torch.unsqueeze(best_scale, 0).expand(n_points, n_boxes) + scales = torch.unsqueeze(scales, 1).expand(n_points, n_boxes) + inside_best_scale_mask = best_scale == scales + + # condition3: limit topk locations per box by centerness + centerness = compute_centerness(bbox_targets) + centerness = torch.where(inside_gt_bbox_mask, centerness, torch.ones_like(centerness) * -1) + centerness = torch.where(inside_best_scale_mask, centerness, torch.ones_like(centerness) * -1) + top_centerness = torch.topk(centerness, self.centerness_topk + 1, dim=0).values[-1] + inside_top_centerness_mask = centerness > top_centerness.unsqueeze(0) + + # if there are still more than one objects for a location, + # we choose the one with minimal area + volumes = torch.where(inside_gt_bbox_mask, volumes, torch.ones_like(volumes) * float_max) + volumes = torch.where(inside_best_scale_mask, volumes, torch.ones_like(volumes) * float_max) + volumes = torch.where(inside_top_centerness_mask, volumes, torch.ones_like(volumes) * float_max) + min_area, min_area_inds = volumes.min(dim=1) + + labels = gt_labels[min_area_inds] + labels = torch.where(min_area == float_max, torch.ones_like(labels) * -1, labels) + bbox_targets = bbox_targets[range(n_points), min_area_inds] + centerness_targets = compute_centerness(bbox_targets) + + return centerness_targets, gt_bboxes[range(n_points), min_area_inds], labels + + def _nms(self, bboxes, scores, img_meta): + # Add a dummy background class to the end. Nms needs to be fixed in the future. + padding = scores.new_zeros(scores.shape[0], 1) + scores = torch.cat([scores, padding], dim=1) + bboxes_for_nms = torch.stack(( + bboxes[:, 0] - bboxes[:, 3] / 2, + bboxes[:, 1] - bboxes[:, 4] / 2, + bboxes[:, 0] + bboxes[:, 3] / 2, + bboxes[:, 1] + bboxes[:, 4] / 2, + bboxes[:, 6] + ), dim=1) + bboxes, scores, labels, _ = box3d_multiclass_nms( + mlvl_bboxes=bboxes, + mlvl_bboxes_for_nms=bboxes_for_nms, + mlvl_scores=scores, + score_thr=self.test_cfg.score_thr, + max_num=self.test_cfg.nms_pre, + cfg=self.test_cfg, + ) + bboxes = img_meta['box_type_3d'](bboxes, origin=(.5, .5, .5)) + return bboxes, scores, labels + + @staticmethod + def _bbox_pred_to_bbox(points, bbox_pred): + if bbox_pred.shape[0] == 0: + return bbox_pred + + # dx_min, dx_max, dy_min, dy_max, dz_min, dz_max, alpha -> + # x_center, y_center, z_center, w, l, h, alpha + shift = torch.stack(( + (bbox_pred[:, 1] - bbox_pred[:, 0]) / 2, + (bbox_pred[:, 3] - bbox_pred[:, 2]) / 2, + (bbox_pred[:, 5] - bbox_pred[:, 4]) / 2 + ), dim=-1).view(-1, 1, 3) + shift = rotation_3d_in_axis(shift, bbox_pred[:, 6], axis=2)[:, 0, :] + center = points + shift + size = torch.stack(( + bbox_pred[:, 0] + bbox_pred[:, 1], + bbox_pred[:, 2] + bbox_pred[:, 3], + bbox_pred[:, 4] + bbox_pred[:, 5] + ), dim=-1) + return torch.cat((center, size, bbox_pred[:, 6:7]), dim=-1) + + + +@HEADS.register_module() +class ScanNetImVoxelHeadV2(ImVoxelHeadV2): + def forward_single(self, x, scale): + return ( + self.centerness_conv(x), + torch.exp(scale(self.reg_conv(x))), + self.cls_conv(x) + ) + + def _bbox_pred_to_loss(self, points, bbox_preds): + return self._bbox_pred_to_bbox(points, bbox_preds) + + def _bbox_pred_to_result(self, points, bbox_preds): + return self._bbox_pred_to_bbox(points, bbox_preds) + + @torch.no_grad() + def get_targets(self, points, gt_bboxes, gt_labels): + float_max = 1e8 + expanded_scales = [ + points[i].new_tensor(i).expand(len(points[i])).to(gt_labels.device) + for i in 
range(len(points)) + ] + points = torch.cat(points, dim=0).to(gt_labels.device) + scales = torch.cat(expanded_scales, dim=0) + + # below is based on FCOSHead._get_target_single + n_points = len(points) + n_boxes = len(gt_bboxes) + volumes = gt_bboxes.volume.to(points.device) + volumes = volumes.expand(n_points, n_boxes).contiguous() + gt_bboxes = torch.cat((gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:6]), dim=1) + gt_bboxes = gt_bboxes.to(points.device).expand(n_points, n_boxes, 6) + expanded_points = points.unsqueeze(1).expand(n_points, n_boxes, 3) + dx_min = expanded_points[..., 0] - gt_bboxes[..., 0] + gt_bboxes[..., 3] / 2 + dx_max = gt_bboxes[..., 0] + gt_bboxes[..., 3] / 2 - expanded_points[..., 0] + dy_min = expanded_points[..., 1] - gt_bboxes[..., 1] + gt_bboxes[..., 4] / 2 + dy_max = gt_bboxes[..., 1] + gt_bboxes[..., 4] / 2 - expanded_points[..., 1] + dz_min = expanded_points[..., 2] - gt_bboxes[..., 2] + gt_bboxes[..., 5] / 2 + dz_max = gt_bboxes[..., 2] + gt_bboxes[..., 5] / 2 - expanded_points[..., 2] + bbox_targets = torch.stack((dx_min, dx_max, dy_min, dy_max, dz_min, dz_max), dim=-1) + + # condition1: inside a gt bbox + inside_gt_bbox_mask = bbox_targets[..., :6].min(-1)[0] > 0 # skip angle + + # condition2: positive points per scale >= limit + # calculate positive points per scale + n_pos_points_per_scale = [] + for i in range(self.n_scales): + n_pos_points_per_scale.append(torch.sum(inside_gt_bbox_mask[scales == i], dim=0)) + # find best scale + n_pos_points_per_scale = torch.stack(n_pos_points_per_scale, dim=0) + lower_limit_mask = n_pos_points_per_scale < self.limit + # fix nondeterministic argmax for torch<1.7 + extra = torch.arange(self.n_scales, 0, -1).unsqueeze(1).expand(self.n_scales, n_boxes).to( + lower_limit_mask.device) + lower_index = torch.argmax(lower_limit_mask.int() * extra, dim=0) - 1 + lower_index = torch.where(lower_index < 0, torch.zeros_like(lower_index), lower_index) + all_upper_limit_mask = torch.all(torch.logical_not(lower_limit_mask), dim=0) + best_scale = torch.where(all_upper_limit_mask, torch.ones_like(all_upper_limit_mask) * self.n_scales - 1, + lower_index) + # keep only points with best scale + best_scale = torch.unsqueeze(best_scale, 0).expand(n_points, n_boxes) + scales = torch.unsqueeze(scales, 1).expand(n_points, n_boxes) + inside_best_scale_mask = best_scale == scales + + # condition3: limit topk locations per box by centerness + centerness = compute_centerness(bbox_targets) + centerness = torch.where(inside_gt_bbox_mask, centerness, torch.ones_like(centerness) * -1) + centerness = torch.where(inside_best_scale_mask, centerness, torch.ones_like(centerness) * -1) + top_centerness = torch.topk(centerness, self.centerness_topk + 1, dim=0).values[-1] + inside_top_centerness_mask = centerness > top_centerness.unsqueeze(0) + + # if there are still more than one objects for a location, + # we choose the one with minimal area + volumes = torch.where(inside_gt_bbox_mask, volumes, torch.ones_like(volumes) * float_max) + volumes = torch.where(inside_best_scale_mask, volumes, torch.ones_like(volumes) * float_max) + volumes = torch.where(inside_top_centerness_mask, volumes, torch.ones_like(volumes) * float_max) + min_area, min_area_inds = volumes.min(dim=1) + + labels = gt_labels[min_area_inds] + labels = torch.where(min_area == float_max, torch.ones_like(labels) * -1, labels) + bbox_targets = bbox_targets[range(n_points), min_area_inds] + centerness_targets = compute_centerness(bbox_targets) + + return centerness_targets, 
self._bbox_pred_to_bbox(points, bbox_targets), labels + + def _nms(self, bboxes, scores, img_meta): + scores, labels = scores.max(dim=1) + ids = scores > self.test_cfg.score_thr + bboxes = bboxes[ids] + scores = scores[ids] + labels = labels[ids] + ids = aligned_3d_nms(bboxes, scores, labels, self.test_cfg.iou_thr) + bboxes = bboxes[ids] + bboxes = torch.stack(( + (bboxes[:, 0] + bboxes[:, 3]) / 2., + (bboxes[:, 1] + bboxes[:, 4]) / 2., + (bboxes[:, 2] + bboxes[:, 5]) / 2., + bboxes[:, 3] - bboxes[:, 0], + bboxes[:, 4] - bboxes[:, 1], + bboxes[:, 5] - bboxes[:, 2] + ), dim=1) + bboxes = img_meta['box_type_3d'](bboxes, origin=(.5, .5, .5), box_dim=6, with_yaw=False) + return bboxes, scores[ids], labels[ids] + + def _bbox_pred_to_bbox(self, points, bbox_pred): + return torch.stack([ + points[:, 0] - bbox_pred[:, 0], + points[:, 1] - bbox_pred[:, 2], + points[:, 2] - bbox_pred[:, 4], + points[:, 0] + bbox_pred[:, 1], + points[:, 1] + bbox_pred[:, 3], + points[:, 2] + bbox_pred[:, 5] + ], -1) + + +def compute_centerness(bbox_targets): + x_dims = bbox_targets[..., [0, 1]] + y_dims = bbox_targets[..., [2, 3]] + z_dims = bbox_targets[..., [4, 5]] + centerness_targets = x_dims.min(dim=-1)[0] / x_dims.max(dim=-1)[0] * \ + y_dims.min(dim=-1)[0] / y_dims.max(dim=-1)[0] * \ + z_dims.min(dim=-1)[0] / z_dims.max(dim=-1)[0] + # todo: sqrt ? + return torch.sqrt(centerness_targets) diff --git a/mmdet3d/models/necks/__init__.py b/mmdet3d/models/necks/__init__.py index 60c3c27..8adcf7b 100644 --- a/mmdet3d/models/necks/__init__.py +++ b/mmdet3d/models/necks/__init__.py @@ -1,5 +1,5 @@ from mmdet.models.necks.fpn import FPN from .second_fpn import SECONDFPN -from .imvoxelnet import ImVoxelNeck, KittiImVoxelNeck, NuScenesImVoxelNeck +from .imvoxelnet import FastIndoorImVoxelNeck, ImVoxelNeck, KittiImVoxelNeck, NuScenesImVoxelNeck -__all__ = ['FPN', 'SECONDFPN', 'ImVoxelNeck', 'KittiImVoxelNeck', 'NuScenesImVoxelNeck'] +__all__ = ['FPN', 'SECONDFPN', 'FastIndoorImVoxelNeck', 'ImVoxelNeck', 'KittiImVoxelNeck', 'NuScenesImVoxelNeck'] diff --git a/mmdet3d/models/necks/imvoxelnet.py b/mmdet3d/models/necks/imvoxelnet.py index 9192e63..f4ef2c9 100644 --- a/mmdet3d/models/necks/imvoxelnet.py +++ b/mmdet3d/models/necks/imvoxelnet.py @@ -5,6 +5,68 @@ from mmdet.models import NECKS +@NECKS.register_module() +class FastIndoorImVoxelNeck(nn.Module): + def __init__(self, in_channels, n_blocks, out_channels): + super(FastIndoorImVoxelNeck, self).__init__() + self.n_scales = len(n_blocks) + n_channels = in_channels + for i in range(len(n_blocks)): + stride = 1 if i == 0 else 2 + self.__setattr__(f'down_layer_{i}', self._make_layer(stride, n_channels, n_blocks[i])) + n_channels = n_channels * stride + if i > 0: + self.__setattr__(f'up_block_{i}', self._make_up_block(n_channels, n_channels // 2)) + self.__setattr__(f'out_block_{i}', self._make_block(n_channels, out_channels)) + + def forward(self, x): + down_outs = [] + for i in range(self.n_scales): + x = self.__getattr__(f'down_layer_{i}')(x) + down_outs.append(x) + outs = [] + for i in range(self.n_scales - 1, -1, -1): + if i < self.n_scales - 1: + x = self.__getattr__(f'up_block_{i + 1}')(x) + x = down_outs[i] + x + out = self.__getattr__(f'out_block_{i}')(x) + outs.append(out) + return outs[::-1] + + @staticmethod + def _make_layer(stride, n_channels, n_blocks): + blocks = [] + for i in range(n_blocks): + if i == 0 and stride != 1: + blocks.append(BasicBlock3dV2(n_channels, n_channels * 2, stride)) + n_channels = n_channels * 2 + else: + 
blocks.append(BasicBlock3dV2(n_channels, n_channels))
+        return nn.Sequential(*blocks)
+
+    @staticmethod
+    def _make_block(in_channels, out_channels):
+        return nn.Sequential(
+            nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=False),
+            nn.BatchNorm3d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+
+    @staticmethod
+    def _make_up_block(in_channels, out_channels):
+        return nn.Sequential(
+            nn.ConvTranspose3d(in_channels, out_channels, 2, 2, bias=False),
+            nn.BatchNorm3d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(out_channels, out_channels, 3, 1, 1, bias=False),
+            nn.BatchNorm3d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+
+    def init_weights(self):
+        pass
+
+
 @NECKS.register_module()
 class ImVoxelNeck(nn.Module):
     def __init__(self, channels, out_channels, down_layers, up_layers, conditional):
@@ -168,6 +230,36 @@ def forward(self, x):
         return out
 
 
+class BasicBlock3dV2(nn.Module):
+    # Residual 3D block in the style of torchvision's BasicBlock; when
+    # stride != 1, a strided 1x1x1 convolution projects the identity branch
+    # to the downsampled shape.
+    def __init__(self, in_channels, out_channels, stride=1):
+        super(BasicBlock3dV2, self).__init__()
+        self.stride = stride
+        self.conv1 = nn.Conv3d(in_channels, out_channels, 3, stride, 1, bias=False)
+        self.norm1 = nn.BatchNorm3d(out_channels)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv3d(out_channels, out_channels, 3, 1, 1, bias=False)
+        self.norm2 = nn.BatchNorm3d(out_channels)
+        if self.stride != 1:
+            self.downsample = nn.Sequential(
+                nn.Conv3d(in_channels, out_channels, 1, stride, bias=False),
+                nn.BatchNorm3d(out_channels)
+            )
+
+    def forward(self, x):
+        identity = x
+        out = self.conv1(x)
+        out = self.norm1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.norm2(out)
+        if self.stride != 1:
+            identity = self.downsample(x)
+        out += identity
+        out = self.relu(out)
+        return out
+
+
 class ConditionalProjection(nn.Module):
     """ Applies a projected skip connection from the encoder to the decoder
     When condition is False this is a standard projected skip connection
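
A minimal smoke test for the new neck, shown as a sketch rather than part of the patch: it instantiates `FastIndoorImVoxelNeck` with the settings used by the indoor `*_fast.py` configs above and prints the output shapes. It assumes this branch is installed so that the class is importable from `mmdet3d.models.necks`.

```python
import torch
from mmdet3d.models.necks import FastIndoorImVoxelNeck

# Settings from the *_fast.py configs: 256 input channels from the FPN,
# 128 output channels, three scales (n_blocks=[1, 1, 1]).
neck = FastIndoorImVoxelNeck(in_channels=256, out_channels=128, n_blocks=[1, 1, 1])

# Projected voxel volume of shape (batch, C, nx, ny, nz) = (2, 256, 40, 40, 16),
# matching n_voxels=(40, 40, 16) in the configs.
x = torch.rand(2, 256, 40, 40, 16)
outs = neck(x)

# Three output scales, finest first, all with 128 channels:
# (2, 128, 40, 40, 16), (2, 128, 20, 20, 8), (2, 128, 10, 10, 4)
print([tuple(out.shape) for out in outs])
```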