diff --git a/README.md b/README.md
index c6cdb44..88403ad 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
# ImVoxelNet: Image to Voxels Projection for Monocular and Multi-View General-Purpose 3D Object Detection
**News**:
- * :fire: October, 2021. Our paper is accepted at [WACV 2022](https://wacv2022.thecvf.com). Stay tuned for 5 times faster and more accurate models for indoor datasets.
+ * :fire: October, 2021. Our paper is accepted at [WACV 2022](https://wacv2022.thecvf.com). We simplified the 3D neck, making indoor models much faster and more accurate. For example, this improves `ScanNet` `mAP` by more than 2%. Please find the updated configs in [configs/imvoxelnet/*_fast.py](https://github.com/saic-vul/imvoxelnet/tree/master/configs/imvoxelnet) and [models](https://github.com/saic-vul/imvoxelnet/releases/tag/v1.2).
* :fire: August, 2021. We adapt center sampling for indoor detection. For example, this improves `ScanNet` `mAP` by more than 5%. Please find updated configs in [configs/imvoxelnet/*_top27.py](https://github.com/saic-vul/imvoxelnet/tree/master/configs/imvoxelnet) and [models](https://github.com/saic-vul/imvoxelnet/releases/tag/v1.1).
* :fire: July, 2021. We update `ScanNet` image preprocessing both [here](https://github.com/saic-vul/imvoxelnet/pull/21) and in [mmdetection3d](https://github.com/open-mmlab/mmdetection3d/pull/696).
* :fire: June, 2021. `ImVoxelNet` for `KITTI` is now [supported](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/imvoxelnet) in [mmdetection3d](https://github.com/open-mmlab/mmdetection3d).
@@ -87,14 +87,16 @@ python tools/test.py configs/imvoxelnet/imvoxelnet_kitti.py \
### Models
-| Dataset | Object Classes | Center Sampling | Download |
-|:---------:|:--------------:|:---------------:|:--------:|
-| SUN RGB-D | 37 from Total3dUnderstanding | ✘ <br> ✔ | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210525_091810.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210525_091810_atlas_total_sunrgbd.log) | [config](configs/imvoxelnet/imvoxelnet_total_sunrgbd.py) <br> [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_005013.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_005013_imvoxelnet_total_sunrgbd_top27.log) | [config](configs/imvoxelnet/imvoxelnet_total_sunrgbd_top27.py)|
-| SUN RGB-D | 30 from PerspectiveNet | ✘ <br> ✔ | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210526_072029.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210526_072029_atlas_perspective_sunrgbd.log) | [config](configs/imvoxelnet/imvoxelnet_perspective_sunrgbd.py) <br> [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_114832.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_114832_imvoxelnet_perspective_sunrgbd_top27.log) | [config](configs/imvoxelnet/imvoxelnet_perspective_sunrgbd_top27.py)|
-| SUN RGB-D | 10 from VoteNet | ✘ <br> ✔ | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210428_124351.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210428_124351_atlas_sunrgbd.log) | [config](configs/imvoxelnet/imvoxelnet_sunrgbd.py) <br> [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_112435.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_112435_imvoxelnet_sunrgbd_top27.log) | [config](configs/imvoxelnet/imvoxelnet_sunrgbd_top27.py)|
-| ScanNet | 18 from VoteNet | ✘ <br> ✔ | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210520_223109.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210520_223109_atlas_scannet.log) | [config](configs/imvoxelnet/imvoxelnet_scannet.py) <br> [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_070616.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_070616_imvoxelnet_scannet_top27.log) | [config](configs/imvoxelnet/imvoxelnet_scannet_top27.py)|
-| KITTI | Car | ✘ | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210503_214214.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210503_214214_atlas_kitti.log) | [config](configs/imvoxelnet/imvoxelnet_kitti.py) |
-| nuScenes | Car | ✘ | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210505_131108.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210505_131108_atlas_nuscenes.log) | [config](configs/imvoxelnet/imvoxelnet_nuscenes.py) |
+`v2` adds center sampling for the indoor scenario, and `v3` simplifies the 3D neck. The differences are discussed in the [v2](https://arxiv.org/abs/2106.01178v2) and [v3](https://arxiv.org/abs/2106.01178v3) preprints.
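+
+As a usage sketch (the checkpoint path below is a placeholder for a downloaded model), the `_fast` configs plug into the same test command shown above:
+```shell
+python tools/test.py configs/imvoxelnet/imvoxelnet_scannet_fast.py \
+    work_dirs/imvoxelnet_scannet_fast/latest.pth --eval mAP
+```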
+
+| Dataset | Object Classes | Version | Download |
+|:---------:|:--------------:|:-------:|:--------:|
+| SUN RGB-D | 37 from <br> Total3dUnderstanding | v1 | mAP@0.15: 41.5 <br> v2 | mAP@0.15: 42.7 <br> v3 | mAP@0.15: 43.7 | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210525_091810.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210525_091810_atlas_total_sunrgbd.log) | [config](configs/imvoxelnet/imvoxelnet_total_sunrgbd.py) <br> [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_005013.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_005013_imvoxelnet_total_sunrgbd_top27.log) | [config](configs/imvoxelnet/imvoxelnet_total_sunrgbd_top27.py) <br> [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_105247.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_105247_imvoxelnet_total_sunrgbd_fast.log) | [config](configs/imvoxelnet/imvoxelnet_total_sunrgbd_fast.py)|
+| SUN RGB-D | 30 from <br> PerspectiveNet | v1 | mAP@0.15: 44.9 <br> v2 | mAP@0.15: 47.2 <br> v3 | mAP@0.15: 48.7 | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210526_072029.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210526_072029_atlas_perspective_sunrgbd.log) | [config](configs/imvoxelnet/imvoxelnet_perspective_sunrgbd.py) <br> [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_114832.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_114832_imvoxelnet_perspective_sunrgbd_top27.log) | [config](configs/imvoxelnet/imvoxelnet_perspective_sunrgbd_top27.py) <br> [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_105254.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_105254_imvoxelnet_perspective_sunrgbd_fast.log) | [config](configs/imvoxelnet/imvoxelnet_perspective_sunrgbd_fast.py)|
+| SUN RGB-D | 10 from VoteNet | v1 | mAP@0.25: 38.8 <br> v2 | mAP@0.25: 39.4 <br> v3 | mAP@0.25: 40.7 | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210428_124351.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210428_124351_atlas_sunrgbd.log) | [config](configs/imvoxelnet/imvoxelnet_sunrgbd.py) <br> [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_112435.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210809_112435_imvoxelnet_sunrgbd_top27.log) | [config](configs/imvoxelnet/imvoxelnet_sunrgbd_top27.py) <br> [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_105255.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_105255_imvoxelnet_sunrgbd_fast.log) | [config](configs/imvoxelnet/imvoxelnet_sunrgbd_fast.py)|
+| ScanNet | 18 from VoteNet | v1 | mAP@0.25: 40.6 <br> v2 | mAP@0.25: 45.7 <br> v3 | mAP@0.25: 48.1 | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210520_223109.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210520_223109_atlas_scannet.log) | [config](configs/imvoxelnet/imvoxelnet_scannet.py) <br> [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_070616.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.1/20210808_070616_imvoxelnet_scannet_top27.log) | [config](configs/imvoxelnet/imvoxelnet_scannet_top27.py) <br> [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_113826.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.2/20211007_113826_imvoxelnet_scannet_fast.log) | [config](configs/imvoxelnet/imvoxelnet_scannet_fast.py)|
+| KITTI | Car | v1 | AP@0.7: 17.8 | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210503_214214.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210503_214214_atlas_kitti.log) | [config](configs/imvoxelnet/imvoxelnet_kitti.py) |
+| nuScenes | Car | v1 | AP: 51.8 | [model](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210505_131108.pth) | [log](https://github.com/saic-vul/imvoxelnet/releases/download/v1.0/20210505_131108_atlas_nuscenes.log) | [config](configs/imvoxelnet/imvoxelnet_nuscenes.py) |
### Example Detections
diff --git a/configs/imvoxelnet/imvoxelnet_perspective_sunrgbd_fast.py b/configs/imvoxelnet/imvoxelnet_perspective_sunrgbd_fast.py
new file mode 100644
index 0000000..775d8d0
--- /dev/null
+++ b/configs/imvoxelnet/imvoxelnet_perspective_sunrgbd_fast.py
@@ -0,0 +1,127 @@
+model = dict(
+ type='ImVoxelNet',
+ pretrained='torchvision://resnet50',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=1,
+ norm_cfg=dict(type='BN', requires_grad=False),
+ norm_eval=True,
+ style='pytorch'),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ num_outs=4),
+ neck_3d=dict(
+ type='FastIndoorImVoxelNeck',
+ in_channels=256,
+ out_channels=128,
+ n_blocks=[1, 1, 1]),
+ bbox_head=dict(
+ type='SunRgbdImVoxelHeadV2',
+ n_classes=30,
+ n_channels=128,
+ n_reg_outs=7,
+ n_scales=3,
+ limit=27,
+ centerness_topk=18),
+ n_voxels=(40, 40, 16),
+ voxel_size=(.16, .16, .16))
+train_cfg = dict()
+test_cfg = dict(
+ nms_pre=1000,
+ nms_thr=.15,
+ use_rotate_nms=True,
+ score_thr=.01)
+img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+dataset_type = 'SunRgbdPerspectiveMultiViewDataset'
+data_root = 'data/sunrgbd/'
+class_names = ('recycle_bin', 'cpu', 'paper', 'toilet', 'stool', 'whiteboard', 'coffee_table', 'picture',
+ 'keyboard', 'dresser', 'painting', 'bookshelf', 'night_stand', 'endtable', 'drawer', 'sink',
+ 'monitor', 'computer', 'cabinet', 'shelf', 'lamp', 'garbage_bin', 'box', 'bed', 'sofa',
+ 'sofa_chair', 'pillow', 'desk', 'table', 'chair')
+
+train_pipeline = [
+ dict(type='LoadAnnotations3D'),
+ dict(
+ type='MultiViewPipeline',
+ n_images=1,
+ transforms=[
+ dict(type='LoadImageFromFile'),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Resize', img_scale=[(512, 384), (768, 576)], multiscale_mode='range', keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32)]),
+ dict(type='SunRgbdRandomFlip'),
+ dict(type='DefaultFormatBundle3D', class_names=class_names),
+ dict(type='Collect3D', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d'])]
+test_pipeline = [
+ dict(
+ type='MultiViewPipeline',
+ n_images=1,
+ transforms=[
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', img_scale=(640, 480), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32)]),
+ dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
+ dict(type='Collect3D', keys=['img'])]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=2,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file=data_root + 'sunrgbd_perspective_infos_train.pkl',
+ pipeline=train_pipeline,
+ classes=class_names,
+ filter_empty_gt=True,
+ box_type_3d='Depth')),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file=data_root + 'sunrgbd_perspective_infos_val.pkl',
+ pipeline=test_pipeline,
+ classes=class_names,
+ test_mode=True,
+ box_type_3d='Depth'),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file=data_root + 'sunrgbd_perspective_infos_val.pkl',
+ pipeline=test_pipeline,
+ classes=class_names,
+ test_mode=True,
+ box_type_3d='Depth'))
+
+optimizer = dict(
+ type='AdamW',
+ lr=0.0001,
+ weight_decay=0.0001,
+ paramwise_cfg=dict(
+ custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
+optimizer_config = dict(grad_clip=dict(max_norm=35., norm_type=2))
+lr_config = dict(policy='step', step=[8, 11])
+total_epochs = 12
+
+checkpoint_config = dict(interval=1, max_keep_ckpts=1)
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+evaluation = dict(interval=1)
+dist_params = dict(backend='nccl')
+find_unused_parameters = True # todo: fix number of FPN outputs
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/imvoxelnet/imvoxelnet_scannet_fast.py b/configs/imvoxelnet/imvoxelnet_scannet_fast.py
new file mode 100644
index 0000000..9c8cd3b
--- /dev/null
+++ b/configs/imvoxelnet/imvoxelnet_scannet_fast.py
@@ -0,0 +1,131 @@
+model = dict(
+ type='ImVoxelNet',
+ pretrained='torchvision://resnet50',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=1,
+ norm_cfg=dict(type='BN', requires_grad=False),
+ norm_eval=True,
+ style='pytorch'),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ num_outs=4),
+ neck_3d=dict(
+ type='FastIndoorImVoxelNeck',
+ in_channels=256,
+ out_channels=128,
+ n_blocks=[1, 1, 1]),
+ bbox_head=dict(
+ type='ScanNetImVoxelHeadV2',
+ loss_bbox=dict(type='AxisAlignedIoULoss', loss_weight=1.0),
+ n_classes=18,
+ n_channels=128,
+ n_reg_outs=6,
+ n_scales=3,
+ limit=27,
+ centerness_topk=18),
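+    # 40 x 40 x 16 voxels at 0.16 m resolution span a 6.4 x 6.4 x 2.56 m volume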
+ voxel_size=(.16, .16, .16),
+ n_voxels=(40, 40, 16))
+train_cfg = dict()
+test_cfg = dict(
+ nms_pre=1000,
+ iou_thr=.25,
+ score_thr=.01)
+img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+dataset_type = 'ScanNetMultiViewDataset'
+data_root = 'data/scannet/'
+class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
+ 'bookshelf', 'picture', 'counter', 'desk', 'curtain',
+ 'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
+ 'garbagebin')
+
+train_pipeline = [
+ dict(type='LoadAnnotations3D'),
+ dict(
+ type='MultiViewPipeline',
+ n_images=20,
+ transforms=[
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', img_scale=(640, 480), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=(480, 640))
+ ]),
+ dict(type='RandomShiftOrigin', std=(.7, .7, .0)),
+ dict(type='DefaultFormatBundle3D', class_names=class_names),
+ dict(type='Collect3D', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d'])
+]
+test_pipeline = [
+ dict(
+ type='MultiViewPipeline',
+ n_images=50,
+ transforms=[
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', img_scale=(640, 480), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=(480, 640))
+ ]),
+ dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
+ dict(type='Collect3D', keys=['img'])
+]
+data = dict(
+ samples_per_gpu=1,
+ workers_per_gpu=1,
+ train=dict(
+ type='RepeatDataset',
+ times=3,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file=data_root + 'scannet_infos_train.pkl',
+ pipeline=train_pipeline,
+ classes=class_names,
+ filter_empty_gt=True,
+ box_type_3d='Depth')),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file=data_root + 'scannet_infos_val.pkl',
+ pipeline=test_pipeline,
+ classes=class_names,
+ test_mode=True,
+ box_type_3d='Depth'),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file=data_root + 'scannet_infos_val.pkl',
+ pipeline=test_pipeline,
+ classes=class_names,
+ test_mode=True,
+ box_type_3d='Depth')
+)
+
+optimizer = dict(
+ type='AdamW',
+ lr=0.0001,
+ weight_decay=0.0001,
+ paramwise_cfg=dict(
+ custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
+optimizer_config = dict(grad_clip=dict(max_norm=35., norm_type=2))
+lr_config = dict(policy='step', step=[8, 11])
+total_epochs = 12
+
+checkpoint_config = dict(interval=1, max_keep_ckpts=1)
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+evaluation = dict(interval=1)
+dist_params = dict(backend='nccl')
+find_unused_parameters = True # todo: fix number of FPN outputs
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
\ No newline at end of file
diff --git a/configs/imvoxelnet/imvoxelnet_sunrgbd_fast.py b/configs/imvoxelnet/imvoxelnet_sunrgbd_fast.py
new file mode 100644
index 0000000..f0d0eab
--- /dev/null
+++ b/configs/imvoxelnet/imvoxelnet_sunrgbd_fast.py
@@ -0,0 +1,125 @@
+model = dict(
+ type='ImVoxelNet',
+ pretrained='torchvision://resnet50',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=1,
+ norm_cfg=dict(type='BN', requires_grad=False),
+ norm_eval=True,
+ style='pytorch'),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ num_outs=4),
+ neck_3d=dict(
+ type='FastIndoorImVoxelNeck',
+ in_channels=256,
+ out_channels=128,
+ n_blocks=[1, 1, 1]),
+ bbox_head=dict(
+ type='SunRgbdImVoxelHeadV2',
+ n_classes=10,
+ n_channels=128,
+ n_reg_outs=7,
+ n_scales=3,
+ limit=27,
+ centerness_topk=18),
+ n_voxels=(40, 40, 16),
+ voxel_size=(.16, .16, .16))
+train_cfg = dict()
+test_cfg = dict(
+ nms_pre=1000,
+ nms_thr=.15,
+ use_rotate_nms=True,
+ score_thr=.0)
+img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+dataset_type = 'SunRgbdMultiViewDataset'
+data_root = 'data/sunrgbd/'
+class_names = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser',
+ 'night_stand', 'bookshelf', 'bathtub')
+
+train_pipeline = [
+ dict(type='LoadAnnotations3D'),
+ dict(
+ type='MultiViewPipeline',
+ n_images=1,
+ transforms=[
+ dict(type='LoadImageFromFile'),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Resize', img_scale=[(512, 384), (768, 576)], multiscale_mode='range', keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32)]),
+ dict(type='SunRgbdRandomFlip'),
+ dict(type='DefaultFormatBundle3D', class_names=class_names),
+ dict(type='Collect3D', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d'])]
+test_pipeline = [
+ dict(
+ type='MultiViewPipeline',
+ n_images=1,
+ transforms=[
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', img_scale=(640, 480), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32)]),
+ dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
+ dict(type='Collect3D', keys=['img'])]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=2,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file=data_root + 'sunrgbd_imvoxelnet_infos_train.pkl',
+ pipeline=train_pipeline,
+ classes=class_names,
+ filter_empty_gt=True,
+ box_type_3d='Depth')),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file=data_root + 'sunrgbd_imvoxelnet_infos_val.pkl',
+ pipeline=test_pipeline,
+ classes=class_names,
+ test_mode=True,
+ box_type_3d='Depth'),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file=data_root + 'sunrgbd_imvoxelnet_infos_val.pkl',
+ pipeline=test_pipeline,
+ classes=class_names,
+ test_mode=True,
+ box_type_3d='Depth'))
+
+optimizer = dict(
+ type='AdamW',
+ lr=0.0001,
+ weight_decay=0.0001,
+ paramwise_cfg=dict(
+ custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
+optimizer_config = dict(grad_clip=dict(max_norm=35., norm_type=2))
+lr_config = dict(policy='step', step=[8, 11])
+total_epochs = 12
+
+checkpoint_config = dict(interval=1, max_keep_ckpts=1)
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+evaluation = dict(interval=1)
+dist_params = dict(backend='nccl')
+find_unused_parameters = True # todo: fix number of FPN outputs
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/imvoxelnet/imvoxelnet_total_sunrgbd_fast.py b/configs/imvoxelnet/imvoxelnet_total_sunrgbd_fast.py
new file mode 100644
index 0000000..4570c58
--- /dev/null
+++ b/configs/imvoxelnet/imvoxelnet_total_sunrgbd_fast.py
@@ -0,0 +1,134 @@
+model = dict(
+ type='ImVoxelNet',
+ pretrained='torchvision://resnet50',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=1,
+ norm_cfg=dict(type='BN', requires_grad=False),
+ norm_eval=True,
+ style='pytorch'),
+ head_2d=dict(
+ type='LayoutHead',
+ n_channels=2048,
+ linear_size=256,
+ dropout=.0,
+ loss_angle=dict(type='SmoothL1Loss', loss_weight=100.),
+ loss_layout=dict(type='IoU3DLoss', loss_weight=1.)),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ num_outs=4),
+ neck_3d=dict(
+ type='FastIndoorImVoxelNeck',
+ in_channels=256,
+ out_channels=128,
+ n_blocks=[1, 1, 1]),
+ bbox_head=dict(
+ type='SunRgbdImVoxelHeadV2',
+ n_classes=33,
+ n_channels=128,
+ n_reg_outs=7,
+ n_scales=3,
+ limit=27,
+ centerness_topk=18),
+ n_voxels=(40, 40, 16),
+ voxel_size=(.16, .16, .16))
+train_cfg = dict()
+test_cfg = dict(
+ nms_pre=1000,
+ nms_thr=.15,
+ use_rotate_nms=True,
+ score_thr=.0)
+img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+dataset_type = 'SunRgbdTotalMultiViewDataset'
+data_root = 'data/sunrgbd/'
+class_names = [
+ 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter',
+ 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'clothes', 'books',
+ 'fridge', 'tv', 'paper', 'towel', 'shower_curtain', 'box', 'whiteboard', 'person', 'night_stand', 'toilet',
+ 'sink', 'lamp', 'bathtub', 'bag'
+]
+
+train_pipeline = [
+ dict(type='LoadAnnotations3D'),
+ dict(
+ type='MultiViewPipeline',
+ n_images=1,
+ transforms=[
+ dict(type='SunRgbdTotalLoadImageFromFile'),
+ dict(type='Resize', img_scale=[(512, 384), (768, 576)], multiscale_mode='range', keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32)]),
+ dict(type='DefaultFormatBundle3D', class_names=class_names),
+ dict(type='Collect3D', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d'])]
+test_pipeline = [
+ dict(
+ type='MultiViewPipeline',
+ n_images=1,
+ transforms=[
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', img_scale=(640, 480), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32)]),
+ dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
+ dict(type='Collect3D', keys=['img'])]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=1,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file=data_root + 'sunrgbd_total_infos_train.pkl',
+ pipeline=train_pipeline,
+ classes=class_names,
+ filter_empty_gt=True,
+ box_type_3d='Depth')),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file=data_root + 'sunrgbd_total_infos_val.pkl',
+ pipeline=test_pipeline,
+ classes=class_names,
+ test_mode=True,
+ box_type_3d='Depth'),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file=data_root + 'sunrgbd_total_infos_val.pkl',
+ pipeline=test_pipeline,
+ classes=class_names,
+ test_mode=True,
+ box_type_3d='Depth'))
+
+optimizer = dict(
+ type='AdamW',
+ lr=0.0001,
+ weight_decay=0.0001,
+ paramwise_cfg=dict(
+ custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
+optimizer_config = dict(grad_clip=dict(max_norm=35., norm_type=2))
+lr_config = dict(policy='step', step=[8, 11])
+total_epochs = 12
+
+checkpoint_config = dict(interval=1, max_keep_ckpts=1)
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+evaluation = dict(interval=1)
+dist_params = dict(backend='nccl')
+find_unused_parameters = True # todo: fix number of FPN outputs
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/mmdet3d/models/dense_heads/__init__.py b/mmdet3d/models/dense_heads/__init__.py
index ecea98d..ba0075e 100644
--- a/mmdet3d/models/dense_heads/__init__.py
+++ b/mmdet3d/models/dense_heads/__init__.py
@@ -7,6 +7,7 @@
from .ssd_3d_head import SSD3DHead
from .vote_head import VoteHead
from .imvoxel_head import SunRgbdImVoxelHead, ScanNetImVoxelHead
+from .imvoxel_head_v2 import SunRgbdImVoxelHeadV2, ScanNetImVoxelHeadV2
from .layout_head import LayoutHead
__all__ = [
diff --git a/mmdet3d/models/dense_heads/imvoxel_head_v2.py b/mmdet3d/models/dense_heads/imvoxel_head_v2.py
index e69de29..2f4d58d 100644
--- a/mmdet3d/models/dense_heads/imvoxel_head_v2.py
+++ b/mmdet3d/models/dense_heads/imvoxel_head_v2.py
@@ -0,0 +1,566 @@
+import torch
+from torch import nn
+from mmdet.core import multi_apply, reduce_mean
+from mmdet.models.builder import HEADS, build_loss
+from mmcv.cnn import Scale, bias_init_with_prob, normal_init
+
+from mmdet3d.models.detectors.imvoxelnet import get_points
+from mmdet3d.core.bbox.structures import rotation_3d_in_axis
+from mmdet3d.core.post_processing import aligned_3d_nms, box3d_multiclass_nms
+
+
+class ImVoxelHeadV2(nn.Module):
+ def __init__(self,
+ n_classes,
+ n_channels,
+ n_reg_outs,
+ n_scales,
+ limit,
+ centerness_topk=-1,
+ loss_centerness=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ loss_bbox=dict(type='IoU3DLoss', loss_weight=1.0),
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ train_cfg=None,
+ test_cfg=None):
+ super(ImVoxelHeadV2, self).__init__()
+ self.n_classes = n_classes
+ self.n_scales = n_scales
+ self.limit = limit
+ self.centerness_topk = centerness_topk
+ self.loss_centerness = build_loss(loss_centerness)
+ self.loss_bbox = build_loss(loss_bbox)
+ self.loss_cls = build_loss(loss_cls)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self._init_layers(n_channels, n_reg_outs)
+
+ def _init_layers(self, n_channels, n_reg_outs):
+ self.centerness_conv = nn.Conv3d(n_channels, 1, 3, padding=1, bias=False)
+ self.reg_conv = nn.Conv3d(n_channels, n_reg_outs, 3, padding=1, bias=False)
+ self.cls_conv = nn.Conv3d(n_channels, self.n_classes, 3, padding=1)
+ self.scales = nn.ModuleList([Scale(1.) for _ in range(self.n_scales)])
+
+ # Follow AnchorFreeHead.init_weights
+ def init_weights(self):
+ normal_init(self.centerness_conv, std=.01)
+ normal_init(self.reg_conv, std=.01)
+ normal_init(self.cls_conv, std=.01, bias=bias_init_with_prob(.01))
+
+ def forward(self, x):
+ return multi_apply(self.forward_single, x, self.scales)
+
+ def forward_train(self, x, valid, img_metas, gt_bboxes, gt_labels):
+ loss_inputs = self(x) + (valid, img_metas, gt_bboxes, gt_labels)
+ losses = self.loss(*loss_inputs)
+ return losses
+
+ def loss(self,
+ centernesses,
+ bbox_preds,
+ cls_scores,
+ valid,
+ img_metas,
+ gt_bboxes,
+ gt_labels):
+ """
+ Args:
+ centernesses (list(Tensor)): Multi-level centernesses
+ of shape (batch, 1, nx[i], ny[i], nz[i])
+ bbox_preds (list(Tensor)): Multi-level xyz min and max distances
+ of shape (batch, 6, nx[i], ny[i], nz[i])
+ cls_scores (list(Tensor)): Multi-level class scores
+ of shape (batch, n_classes, nx[i], ny[i], nz[i])
+ img_metas (list[dict]): Meta information of each image
+ gt_bboxes (list(BaseInstance3DBoxes)): Ground truth bboxes for each image
+ gt_labels (list(Tensor)): Ground truth class labels for each image
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ assert len(centernesses[0]) == len(bbox_preds[0]) == len(cls_scores[0]) == \
+ len(valid) == len(img_metas) == len(gt_bboxes) == len(gt_labels)
+
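+        # resample the voxel-validity mask to each feature-level resolution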
+ valids = []
+ for x in centernesses:
+ valids.append(nn.Upsample(size=x.shape[-3:], mode='trilinear')(valid).round().bool())
+
+ loss_centerness, loss_bbox, loss_cls = [], [], []
+ for i in range(len(img_metas)):
+ img_loss_centerness, img_loss_bbox, img_loss_cls = self._loss_single(
+ centernesses=[x[i] for x in centernesses],
+ bbox_preds=[x[i] for x in bbox_preds],
+ cls_scores=[x[i] for x in cls_scores],
+ valids=[x[i] for x in valids],
+ img_meta=img_metas[i],
+ gt_bboxes=gt_bboxes[i],
+ gt_labels=gt_labels[i]
+ )
+ loss_centerness.append(img_loss_centerness)
+ loss_bbox.append(img_loss_bbox)
+ loss_cls.append(img_loss_cls)
+ return dict(
+ loss_centerness=torch.mean(torch.stack(loss_centerness)),
+ loss_bbox=torch.mean(torch.stack(loss_bbox)),
+ loss_cls=torch.mean(torch.stack(loss_cls))
+ )
+
+ def _loss_single(self,
+ centernesses,
+ bbox_preds,
+ cls_scores,
+ valids,
+ img_meta,
+ gt_bboxes,
+ gt_labels):
+ """
+ Args:
+ centernesses (list(Tensor)): Multi-level centernesses
+ of shape (1, nx[i], ny[i], nz[i])
+ bbox_preds (list(Tensor)): Multi-level xyz min and max distances
+ of shape (6, nx[i], ny[i], nz[i])
+ cls_scores (list(Tensor)): Multi-level class scores
+ of shape (n_classes, nx[i], ny[i], nz[i])
+ img_metas (list[dict]): Meta information
+ gt_bboxes (BaseInstance3DBoxes): Ground truth bboxes
+ of shape (n_boxes, 7)
+ gt_labels (list(Tensor)): Ground truth class labels
+ of shape (n_boxes,)
+
+ Returns:
+ tuple(Tensor): 3 losses
+ """
+ featmap_sizes = [featmap.size()[-3:] for featmap in centernesses]
+ mlvl_points = self.get_points(
+ featmap_sizes=featmap_sizes,
+ origin=img_meta['lidar2img']['origin'],
+ device=gt_bboxes.device
+ )
+
+ centerness_targets, bbox_targets, labels = self.get_targets(mlvl_points, gt_bboxes, gt_labels)
+
+ flatten_centerness = [centerness.permute(1, 2, 3, 0).reshape(-1)
+ for centerness in centernesses]
+ bbox_pred_size = bbox_preds[0].shape[0]
+ flatten_bbox_preds = [bbox_pred.permute(1, 2, 3, 0).reshape(-1, bbox_pred_size)
+ for bbox_pred in bbox_preds]
+ flatten_cls_scores = [cls_score.permute(1, 2, 3, 0).reshape(-1, self.n_classes)
+ for cls_score in cls_scores]
+ flatten_valids = [valid.permute(1, 2, 3, 0).reshape(-1)
+ for valid in valids]
+
+ flatten_centerness = torch.cat(flatten_centerness)
+ flatten_bbox_preds = torch.cat(flatten_bbox_preds)
+ flatten_cls_scores = torch.cat(flatten_cls_scores)
+ flatten_valids = torch.cat(flatten_valids)
+
+ flatten_centerness_targets = centerness_targets.to(centernesses[0].device)
+ flatten_bbox_targets = bbox_targets.to(centernesses[0].device)
+ flatten_labels = labels.to(centernesses[0].device)
+ flatten_points = torch.cat(mlvl_points)
+
+ # skip background
+ pos_inds = torch.nonzero(torch.logical_and(
+ flatten_labels >= 0,
+ flatten_valids
+ )).reshape(-1)
+ n_pos = torch.tensor(len(pos_inds), dtype=torch.float, device=centernesses[0].device)
+ n_pos = max(reduce_mean(n_pos), 1.)
+ if torch.any(flatten_valids):
+ loss_cls = self.loss_cls(
+ flatten_cls_scores[flatten_valids],
+ flatten_labels[flatten_valids],
+ avg_factor=n_pos
+ )
+ else:
+ loss_cls = flatten_cls_scores[flatten_valids].sum()
+ pos_centerness = flatten_centerness[pos_inds]
+ pos_bbox_preds = flatten_bbox_preds[pos_inds]
+
+ if len(pos_inds) > 0:
+ pos_centerness_targets = flatten_centerness_targets[pos_inds]
+ pos_bbox_targets = flatten_bbox_targets[pos_inds]
+ pos_points = flatten_points[pos_inds].to(pos_bbox_preds.device)
+ loss_centerness = self.loss_centerness(
+ pos_centerness, pos_centerness_targets, avg_factor=n_pos
+ )
+ loss_bbox = self.loss_bbox(
+ self._bbox_pred_to_loss(pos_points, pos_bbox_preds),
+ pos_bbox_targets,
+ weight=pos_centerness_targets,
+ avg_factor=pos_centerness_targets.sum()
+ )
+ else:
+ loss_centerness = pos_centerness.sum()
+ loss_bbox = pos_bbox_preds.sum()
+ return loss_centerness, loss_bbox, loss_cls
+
+ @torch.no_grad()
+ def get_points(self, featmap_sizes, origin, device):
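+        # one 3d point per voxel at every level; the voxel size doubles at each coarser level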
+ mlvl_points = []
+ for i, featmap_size in enumerate(featmap_sizes):
+ mlvl_points.append(get_points(
+ n_voxels=torch.tensor(featmap_size),
+ voxel_size=torch.tensor(self.voxel_size) * (2 ** i),
+ origin=torch.tensor(origin)
+ ).reshape(3, -1).transpose(0, 1).to(device))
+ return mlvl_points
+
+ def get_bboxes(self,
+ centernesses,
+ bbox_preds,
+ cls_scores,
+ valid,
+ img_metas):
+ assert len(centernesses[0]) == len(bbox_preds[0]) == len(cls_scores[0]) \
+ == len(img_metas)
+ valids = []
+ for x in centernesses:
+ valids.append(nn.Upsample(size=x.shape[-3:], mode='trilinear')(valid).round().bool())
+ n_levels = len(centernesses)
+ result_list = []
+ for img_id in range(len(img_metas)):
+ centerness_list = [
+ centernesses[i][img_id].detach() for i in range(n_levels)
+ ]
+ bbox_pred_list = [
+ bbox_preds[i][img_id].detach() for i in range(n_levels)
+ ]
+ cls_score_list = [
+ cls_scores[i][img_id].detach() for i in range(n_levels)
+ ]
+ valid_list = [
+ valids[i][img_id].detach() for i in range(n_levels)
+ ]
+ det_bboxes_3d = self._get_bboxes_single(
+ centerness_list, bbox_pred_list, cls_score_list, valid_list, img_metas[img_id]
+ )
+ result_list.append(det_bboxes_3d)
+ return result_list
+
+ def _get_bboxes_single(self,
+ centernesses,
+ bbox_preds,
+ cls_scores,
+ valids,
+ img_meta):
+ featmap_sizes = [featmap.size()[-3:] for featmap in centernesses]
+ mlvl_points = self.get_points(
+ featmap_sizes=featmap_sizes,
+ origin=img_meta['lidar2img']['origin'],
+ device=centernesses[0].device
+ )
+ bbox_pred_size = bbox_preds[0].shape[0]
+ mlvl_bboxes, mlvl_scores = [], []
+ for centerness, bbox_pred, cls_score, valid, points in zip(
+ centernesses, bbox_preds, cls_scores, valids, mlvl_points
+ ):
+ centerness = centerness.permute(1, 2, 3, 0).reshape(-1).sigmoid()
+ bbox_pred = bbox_pred.permute(1, 2, 3, 0).reshape(-1, bbox_pred_size)
+ scores = cls_score.permute(1, 2, 3, 0).reshape(-1, self.n_classes).sigmoid()
+ valid = valid.permute(1, 2, 3, 0).reshape(-1)
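+            # modulate class scores by centerness and zero out voxels without a valid projection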
+ scores = scores * centerness[:, None] * valid[:, None]
+ max_scores, _ = scores.max(dim=1)
+
+ if len(scores) > self.test_cfg.nms_pre > 0:
+ _, ids = max_scores.topk(self.test_cfg.nms_pre)
+ bbox_pred = bbox_pred[ids]
+ scores = scores[ids]
+ points = points[ids]
+
+ bboxes = self._bbox_pred_to_result(points, bbox_pred)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+
+ bboxes = torch.cat(mlvl_bboxes)
+ scores = torch.cat(mlvl_scores)
+ bboxes, scores, labels = self._nms(bboxes, scores, img_meta)
+ return bboxes, scores, labels
+
+ def forward_single(self, x, scale):
+ raise NotImplementedError
+
+ def _bbox_pred_to_loss(self, points, bbox_preds):
+ raise NotImplementedError
+
+ def _bbox_pred_to_result(self, points, bbox_preds):
+ raise NotImplementedError
+
+ def get_targets(self, points, gt_bboxes, gt_labels):
+ raise NotImplementedError
+
+ def _nms(self, bboxes, scores, img_meta):
+ raise NotImplementedError
+
+
+@HEADS.register_module()
+class SunRgbdImVoxelHeadV2(ImVoxelHeadV2):
+ def forward_single(self, x, scale):
+ reg_final = self.reg_conv(x)
+ reg_distance = torch.exp(scale(reg_final[:, :6]))
+ reg_angle = reg_final[:, 6:]
+ return (
+ self.centerness_conv(x),
+ torch.cat((reg_distance, reg_angle), dim=1),
+ self.cls_conv(x)
+ )
+
+ def _bbox_pred_to_loss(self, points, bbox_preds):
+ return self._bbox_pred_to_bbox(points, bbox_preds)
+
+ def _bbox_pred_to_result(self, points, bbox_preds):
+ return self._bbox_pred_to_bbox(points, bbox_preds)
+
+ @torch.no_grad()
+ def get_targets(self, points, gt_bboxes, gt_labels):
+ float_max = 1e8
+ expanded_scales = [
+ points[i].new_tensor(i).expand(len(points[i])).to(gt_labels.device)
+ for i in range(len(points))
+ ]
+ points = torch.cat(points, dim=0).to(gt_labels.device)
+ scales = torch.cat(expanded_scales, dim=0)
+
+ # below is based on FCOSHead._get_target_single
+ n_points = len(points)
+ n_boxes = len(gt_bboxes)
+ volumes = gt_bboxes.volume.to(points.device)
+ volumes = volumes.expand(n_points, n_boxes).contiguous()
+ gt_bboxes = torch.cat((gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]), dim=1)
+ gt_bboxes = gt_bboxes.to(points.device).expand(n_points, n_boxes, 7)
+ expanded_points = points.unsqueeze(1).expand(n_points, n_boxes, 3)
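+        # rotate point-to-center offsets into each box's yaw frame before measuring distances to the 6 faces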
+ shift = torch.stack((
+ expanded_points[..., 0] - gt_bboxes[..., 0],
+ expanded_points[..., 1] - gt_bboxes[..., 1],
+ expanded_points[..., 2] - gt_bboxes[..., 2]
+ ), dim=-1).permute(1, 0, 2)
+ shift = rotation_3d_in_axis(shift, -gt_bboxes[0, :, 6], axis=2).permute(1, 0, 2)
+ centers = gt_bboxes[..., :3] + shift
+ dx_min = centers[..., 0] - gt_bboxes[..., 0] + gt_bboxes[..., 3] / 2
+ dx_max = gt_bboxes[..., 0] + gt_bboxes[..., 3] / 2 - centers[..., 0]
+ dy_min = centers[..., 1] - gt_bboxes[..., 1] + gt_bboxes[..., 4] / 2
+ dy_max = gt_bboxes[..., 1] + gt_bboxes[..., 4] / 2 - centers[..., 1]
+ dz_min = centers[..., 2] - gt_bboxes[..., 2] + gt_bboxes[..., 5] / 2
+ dz_max = gt_bboxes[..., 2] + gt_bboxes[..., 5] / 2 - centers[..., 2]
+ bbox_targets = torch.stack((dx_min, dx_max, dy_min, dy_max, dz_min, dz_max, gt_bboxes[..., 6]), dim=-1)
+
+ # condition1: inside a gt bbox
+ inside_gt_bbox_mask = bbox_targets[..., :6].min(-1)[0] > 0 # skip angle
+
+ # condition2: positive points per scale >= limit
+ # calculate positive points per scale
+ n_pos_points_per_scale = []
+ for i in range(self.n_scales):
+ n_pos_points_per_scale.append(torch.sum(inside_gt_bbox_mask[scales == i], dim=0))
+ # find best scale
+ n_pos_points_per_scale = torch.stack(n_pos_points_per_scale, dim=0)
+ lower_limit_mask = n_pos_points_per_scale < self.limit
+ # fix nondeterministic argmax for torch<1.7
+ extra = torch.arange(self.n_scales, 0, -1).unsqueeze(1).expand(self.n_scales, n_boxes).to(lower_limit_mask.device)
+ lower_index = torch.argmax(lower_limit_mask.int() * extra, dim=0) - 1
+ lower_index = torch.where(lower_index < 0, torch.zeros_like(lower_index), lower_index)
+ all_upper_limit_mask = torch.all(torch.logical_not(lower_limit_mask), dim=0)
+ best_scale = torch.where(all_upper_limit_mask, torch.ones_like(all_upper_limit_mask) * self.n_scales - 1, lower_index)
+ # keep only points with best scale
+ best_scale = torch.unsqueeze(best_scale, 0).expand(n_points, n_boxes)
+ scales = torch.unsqueeze(scales, 1).expand(n_points, n_boxes)
+ inside_best_scale_mask = best_scale == scales
+
+ # condition3: limit topk locations per box by centerness
+ centerness = compute_centerness(bbox_targets)
+ centerness = torch.where(inside_gt_bbox_mask, centerness, torch.ones_like(centerness) * -1)
+ centerness = torch.where(inside_best_scale_mask, centerness, torch.ones_like(centerness) * -1)
+ top_centerness = torch.topk(centerness, self.centerness_topk + 1, dim=0).values[-1]
+ inside_top_centerness_mask = centerness > top_centerness.unsqueeze(0)
+
+ # if there are still more than one objects for a location,
+ # we choose the one with minimal area
+ volumes = torch.where(inside_gt_bbox_mask, volumes, torch.ones_like(volumes) * float_max)
+ volumes = torch.where(inside_best_scale_mask, volumes, torch.ones_like(volumes) * float_max)
+ volumes = torch.where(inside_top_centerness_mask, volumes, torch.ones_like(volumes) * float_max)
+ min_area, min_area_inds = volumes.min(dim=1)
+
+ labels = gt_labels[min_area_inds]
+ labels = torch.where(min_area == float_max, torch.ones_like(labels) * -1, labels)
+ bbox_targets = bbox_targets[range(n_points), min_area_inds]
+ centerness_targets = compute_centerness(bbox_targets)
+
+ return centerness_targets, gt_bboxes[range(n_points), min_area_inds], labels
+
+ def _nms(self, bboxes, scores, img_meta):
+ # Add a dummy background class to the end. Nms needs to be fixed in the future.
+ padding = scores.new_zeros(scores.shape[0], 1)
+ scores = torch.cat([scores, padding], dim=1)
+ bboxes_for_nms = torch.stack((
+ bboxes[:, 0] - bboxes[:, 3] / 2,
+ bboxes[:, 1] - bboxes[:, 4] / 2,
+ bboxes[:, 0] + bboxes[:, 3] / 2,
+ bboxes[:, 1] + bboxes[:, 4] / 2,
+ bboxes[:, 6]
+ ), dim=1)
+ bboxes, scores, labels, _ = box3d_multiclass_nms(
+ mlvl_bboxes=bboxes,
+ mlvl_bboxes_for_nms=bboxes_for_nms,
+ mlvl_scores=scores,
+ score_thr=self.test_cfg.score_thr,
+ max_num=self.test_cfg.nms_pre,
+ cfg=self.test_cfg,
+ )
+ bboxes = img_meta['box_type_3d'](bboxes, origin=(.5, .5, .5))
+ return bboxes, scores, labels
+
+ @staticmethod
+ def _bbox_pred_to_bbox(points, bbox_pred):
+ if bbox_pred.shape[0] == 0:
+ return bbox_pred
+
+ # dx_min, dx_max, dy_min, dy_max, dz_min, dz_max, alpha ->
+ # x_center, y_center, z_center, w, l, h, alpha
+ shift = torch.stack((
+ (bbox_pred[:, 1] - bbox_pred[:, 0]) / 2,
+ (bbox_pred[:, 3] - bbox_pred[:, 2]) / 2,
+ (bbox_pred[:, 5] - bbox_pred[:, 4]) / 2
+ ), dim=-1).view(-1, 1, 3)
+ shift = rotation_3d_in_axis(shift, bbox_pred[:, 6], axis=2)[:, 0, :]
+ center = points + shift
+ size = torch.stack((
+ bbox_pred[:, 0] + bbox_pred[:, 1],
+ bbox_pred[:, 2] + bbox_pred[:, 3],
+ bbox_pred[:, 4] + bbox_pred[:, 5]
+ ), dim=-1)
+ return torch.cat((center, size, bbox_pred[:, 6:7]), dim=-1)
+
+
+@HEADS.register_module()
+class ScanNetImVoxelHeadV2(ImVoxelHeadV2):
+ def forward_single(self, x, scale):
+ return (
+ self.centerness_conv(x),
+ torch.exp(scale(self.reg_conv(x))),
+ self.cls_conv(x)
+ )
+
+ def _bbox_pred_to_loss(self, points, bbox_preds):
+ return self._bbox_pred_to_bbox(points, bbox_preds)
+
+ def _bbox_pred_to_result(self, points, bbox_preds):
+ return self._bbox_pred_to_bbox(points, bbox_preds)
+
+ @torch.no_grad()
+ def get_targets(self, points, gt_bboxes, gt_labels):
+ float_max = 1e8
+ expanded_scales = [
+ points[i].new_tensor(i).expand(len(points[i])).to(gt_labels.device)
+ for i in range(len(points))
+ ]
+ points = torch.cat(points, dim=0).to(gt_labels.device)
+ scales = torch.cat(expanded_scales, dim=0)
+
+ # below is based on FCOSHead._get_target_single
+ n_points = len(points)
+ n_boxes = len(gt_bboxes)
+ volumes = gt_bboxes.volume.to(points.device)
+ volumes = volumes.expand(n_points, n_boxes).contiguous()
+ gt_bboxes = torch.cat((gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:6]), dim=1)
+ gt_bboxes = gt_bboxes.to(points.device).expand(n_points, n_boxes, 6)
+ expanded_points = points.unsqueeze(1).expand(n_points, n_boxes, 3)
+ dx_min = expanded_points[..., 0] - gt_bboxes[..., 0] + gt_bboxes[..., 3] / 2
+ dx_max = gt_bboxes[..., 0] + gt_bboxes[..., 3] / 2 - expanded_points[..., 0]
+ dy_min = expanded_points[..., 1] - gt_bboxes[..., 1] + gt_bboxes[..., 4] / 2
+ dy_max = gt_bboxes[..., 1] + gt_bboxes[..., 4] / 2 - expanded_points[..., 1]
+ dz_min = expanded_points[..., 2] - gt_bboxes[..., 2] + gt_bboxes[..., 5] / 2
+ dz_max = gt_bboxes[..., 2] + gt_bboxes[..., 5] / 2 - expanded_points[..., 2]
+ bbox_targets = torch.stack((dx_min, dx_max, dy_min, dy_max, dz_min, dz_max), dim=-1)
+
+ # condition1: inside a gt bbox
+        inside_gt_bbox_mask = bbox_targets[..., :6].min(-1)[0] > 0  # all 6 entries are face distances; axis-aligned boxes have no angle
+
+ # condition2: positive points per scale >= limit
+ # calculate positive points per scale
+ n_pos_points_per_scale = []
+ for i in range(self.n_scales):
+ n_pos_points_per_scale.append(torch.sum(inside_gt_bbox_mask[scales == i], dim=0))
+ # find best scale
+ n_pos_points_per_scale = torch.stack(n_pos_points_per_scale, dim=0)
+ lower_limit_mask = n_pos_points_per_scale < self.limit
+ # fix nondeterministic argmax for torch<1.7
+ extra = torch.arange(self.n_scales, 0, -1).unsqueeze(1).expand(self.n_scales, n_boxes).to(
+ lower_limit_mask.device)
+ lower_index = torch.argmax(lower_limit_mask.int() * extra, dim=0) - 1
+ lower_index = torch.where(lower_index < 0, torch.zeros_like(lower_index), lower_index)
+ all_upper_limit_mask = torch.all(torch.logical_not(lower_limit_mask), dim=0)
+ best_scale = torch.where(all_upper_limit_mask, torch.ones_like(all_upper_limit_mask) * self.n_scales - 1,
+ lower_index)
+ # keep only points with best scale
+ best_scale = torch.unsqueeze(best_scale, 0).expand(n_points, n_boxes)
+ scales = torch.unsqueeze(scales, 1).expand(n_points, n_boxes)
+ inside_best_scale_mask = best_scale == scales
+
+ # condition3: limit topk locations per box by centerness
+ centerness = compute_centerness(bbox_targets)
+ centerness = torch.where(inside_gt_bbox_mask, centerness, torch.ones_like(centerness) * -1)
+ centerness = torch.where(inside_best_scale_mask, centerness, torch.ones_like(centerness) * -1)
+ top_centerness = torch.topk(centerness, self.centerness_topk + 1, dim=0).values[-1]
+ inside_top_centerness_mask = centerness > top_centerness.unsqueeze(0)
+
+ # if there are still more than one objects for a location,
+ # we choose the one with minimal area
+ volumes = torch.where(inside_gt_bbox_mask, volumes, torch.ones_like(volumes) * float_max)
+ volumes = torch.where(inside_best_scale_mask, volumes, torch.ones_like(volumes) * float_max)
+ volumes = torch.where(inside_top_centerness_mask, volumes, torch.ones_like(volumes) * float_max)
+ min_area, min_area_inds = volumes.min(dim=1)
+
+ labels = gt_labels[min_area_inds]
+ labels = torch.where(min_area == float_max, torch.ones_like(labels) * -1, labels)
+ bbox_targets = bbox_targets[range(n_points), min_area_inds]
+ centerness_targets = compute_centerness(bbox_targets)
+
+ return centerness_targets, self._bbox_pred_to_bbox(points, bbox_targets), labels
+
+ def _nms(self, bboxes, scores, img_meta):
+ scores, labels = scores.max(dim=1)
+ ids = scores > self.test_cfg.score_thr
+ bboxes = bboxes[ids]
+ scores = scores[ids]
+ labels = labels[ids]
+ ids = aligned_3d_nms(bboxes, scores, labels, self.test_cfg.iou_thr)
+ bboxes = bboxes[ids]
+ bboxes = torch.stack((
+ (bboxes[:, 0] + bboxes[:, 3]) / 2.,
+ (bboxes[:, 1] + bboxes[:, 4]) / 2.,
+ (bboxes[:, 2] + bboxes[:, 5]) / 2.,
+ bboxes[:, 3] - bboxes[:, 0],
+ bboxes[:, 4] - bboxes[:, 1],
+ bboxes[:, 5] - bboxes[:, 2]
+ ), dim=1)
+ bboxes = img_meta['box_type_3d'](bboxes, origin=(.5, .5, .5), box_dim=6, with_yaw=False)
+ return bboxes, scores[ids], labels[ids]
+
+ def _bbox_pred_to_bbox(self, points, bbox_pred):
+ return torch.stack([
+ points[:, 0] - bbox_pred[:, 0],
+ points[:, 1] - bbox_pred[:, 2],
+ points[:, 2] - bbox_pred[:, 4],
+ points[:, 0] + bbox_pred[:, 1],
+ points[:, 1] + bbox_pred[:, 3],
+ points[:, 2] + bbox_pred[:, 5]
+ ], -1)
+
+
+def compute_centerness(bbox_targets):
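+    # 3d extension of the FCOS centerness: per-axis ratio of min to max face distance, combined by product and sqrt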
+ x_dims = bbox_targets[..., [0, 1]]
+ y_dims = bbox_targets[..., [2, 3]]
+ z_dims = bbox_targets[..., [4, 5]]
+ centerness_targets = x_dims.min(dim=-1)[0] / x_dims.max(dim=-1)[0] * \
+ y_dims.min(dim=-1)[0] / y_dims.max(dim=-1)[0] * \
+ z_dims.min(dim=-1)[0] / z_dims.max(dim=-1)[0]
+ # todo: sqrt ?
+ return torch.sqrt(centerness_targets)
diff --git a/mmdet3d/models/necks/__init__.py b/mmdet3d/models/necks/__init__.py
index 60c3c27..8adcf7b 100644
--- a/mmdet3d/models/necks/__init__.py
+++ b/mmdet3d/models/necks/__init__.py
@@ -1,5 +1,5 @@
from mmdet.models.necks.fpn import FPN
from .second_fpn import SECONDFPN
-from .imvoxelnet import ImVoxelNeck, KittiImVoxelNeck, NuScenesImVoxelNeck
+from .imvoxelnet import FastIndoorImVoxelNeck, ImVoxelNeck, KittiImVoxelNeck, NuScenesImVoxelNeck
-__all__ = ['FPN', 'SECONDFPN', 'ImVoxelNeck', 'KittiImVoxelNeck', 'NuScenesImVoxelNeck']
+__all__ = ['FPN', 'SECONDFPN', 'FastIndoorImVoxelNeck', 'ImVoxelNeck', 'KittiImVoxelNeck', 'NuScenesImVoxelNeck']
diff --git a/mmdet3d/models/necks/imvoxelnet.py b/mmdet3d/models/necks/imvoxelnet.py
index 9192e63..f4ef2c9 100644
--- a/mmdet3d/models/necks/imvoxelnet.py
+++ b/mmdet3d/models/necks/imvoxelnet.py
@@ -5,6 +5,68 @@
from mmdet.models import NECKS
+@NECKS.register_module()
+class FastIndoorImVoxelNeck(nn.Module):
+ def __init__(self, in_channels, n_blocks, out_channels):
+ super(FastIndoorImVoxelNeck, self).__init__()
+ self.n_scales = len(n_blocks)
+ n_channels = in_channels
+ for i in range(len(n_blocks)):
+ stride = 1 if i == 0 else 2
+ self.__setattr__(f'down_layer_{i}', self._make_layer(stride, n_channels, n_blocks[i]))
+ n_channels = n_channels * stride
+ if i > 0:
+ self.__setattr__(f'up_block_{i}', self._make_up_block(n_channels, n_channels // 2))
+ self.__setattr__(f'out_block_{i}', self._make_block(n_channels, out_channels))
+
+ def forward(self, x):
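+        # encoder-decoder over the voxel volume: strided stages go down,
+        # transposed-conv blocks go back up with additive skip connections,
+        # and every scale is projected to out_channels for the head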
+ down_outs = []
+ for i in range(self.n_scales):
+ x = self.__getattr__(f'down_layer_{i}')(x)
+ down_outs.append(x)
+ outs = []
+ for i in range(self.n_scales - 1, -1, -1):
+ if i < self.n_scales - 1:
+ x = self.__getattr__(f'up_block_{i + 1}')(x)
+ x = down_outs[i] + x
+ out = self.__getattr__(f'out_block_{i}')(x)
+ outs.append(out)
+ return outs[::-1]
+
+ @staticmethod
+ def _make_layer(stride, n_channels, n_blocks):
+ blocks = []
+ for i in range(n_blocks):
+ if i == 0 and stride != 1:
+ blocks.append(BasicBlock3dV2(n_channels, n_channels * 2, stride))
+ n_channels = n_channels * 2
+ else:
+ blocks.append(BasicBlock3dV2(n_channels, n_channels))
+ return nn.Sequential(*blocks)
+
+ @staticmethod
+ def _make_block(in_channels, out_channels):
+ return nn.Sequential(
+ nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=False),
+ nn.BatchNorm3d(out_channels),
+ nn.ReLU(inplace=True)
+ )
+
+ @staticmethod
+ def _make_up_block(in_channels, out_channels):
+ return nn.Sequential(
+ nn.ConvTranspose3d(in_channels, out_channels, 2, 2, bias=False),
+ nn.BatchNorm3d(out_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv3d(out_channels, out_channels, 3, 1, 1, bias=False),
+ nn.BatchNorm3d(out_channels),
+ nn.ReLU(inplace=True)
+ )
+
+ def init_weights(self):
+ pass
+
+
@NECKS.register_module()
class ImVoxelNeck(nn.Module):
def __init__(self, channels, out_channels, down_layers, up_layers, conditional):
@@ -168,6 +230,36 @@ def forward(self, x):
return out
+class BasicBlock3dV2(nn.Module):
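+    """ResNet-style BasicBlock in 3d.
+    A 1x1x1 projection aligns the identity when stride != 1.
+    """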
+ def __init__(self, in_channels, out_channels, stride=1):
+ super(BasicBlock3dV2, self).__init__()
+ self.stride = stride
+ self.conv1 = nn.Conv3d(in_channels, out_channels, 3, stride, 1, bias=False)
+ self.norm1 = nn.BatchNorm3d(out_channels)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv3d(out_channels, out_channels, 3, 1, 1, bias=False)
+ self.norm2 = nn.BatchNorm3d(out_channels)
+ if self.stride != 1:
+ self.downsample = nn.Sequential(
+ nn.Conv3d(in_channels, out_channels, 1, stride, bias=False),
+ nn.BatchNorm3d(out_channels)
+ )
+
+ def forward(self, x):
+ identity = x
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out = self.norm2(out)
+ if self.stride != 1:
+ identity = self.downsample(x)
+ out += identity
+ out = self.relu(out)
+ return out
+
+
class ConditionalProjection(nn.Module):
""" Applies a projected skip connection from the encoder to the decoder
When condition is False this is a standard projected skip connection