Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] PGD Benchmark on KITTI #1014

Merged
merged 67 commits into from
Nov 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
46d420f
[Refactor] Main code modification for coordinate system refactor (#677)
yezhen17 Jul 22, 2021
6877232
[Enhance] Add script for data update (#774)
yezhen17 Jul 29, 2021
febd2eb
fix import (#839)
yezhen17 Aug 5, 2021
3f64754
[Enhance] refactor iou_neg_piecewise_sampler.py (#842)
xiliu8006 Aug 9, 2021
b2abf1e
[Feature] Add roipooling cuda ops (#843)
xiliu8006 Aug 9, 2021
5a07dfe
[Refactor] Refactor code structure and docstrings (#803)
yezhen17 Aug 11, 2021
3d9268b
[Feature] PointXYZWHLRBBoxCoder (#856)
xiliu8006 Aug 16, 2021
3389237
[Enhance] Change Groupfree3D config (#855)
yezhen17 Aug 13, 2021
fc4bb0c
[Doc] Add tutorials/data_pipeline Chinese version (#827)
wHao-Wu Aug 18, 2021
0f81e49
[Doc] Add Chinese doc for `scannet_det.md` (#836)
yezhen17 Aug 18, 2021
cac2ef8
[Doc] Add Chinese doc for `waymo_det.md` (#859)
yezhen17 Aug 18, 2021
8d0b12a
Remove 2D annotations on Lyft (#867)
Tai-Wang Aug 18, 2021
2375d3c
Add header for files (#869)
DCNSW Aug 19, 2021
63fd399
[fix] fix typos (#872)
xieenze Aug 19, 2021
884b593
Fix 3 unworking configs (#882)
yezhen17 Aug 24, 2021
d688160
[Fix] Fix `index.rst` for Chinese docs (#873)
yezhen17 Aug 24, 2021
a000db5
[Fix] Centerpoint head nested list transpose (#879)
robin-karlsson0 Aug 25, 2021
4e9c992
[Enhance] Update PointFusion (#791)
filaPro Aug 25, 2021
08dae04
[Doc] Add nuscenes_det.md Chinese version (#854)
ZCMax Aug 26, 2021
93de7c2
[Fix] Fix RegNet pretrained weight loading (#889)
yezhen17 Aug 27, 2021
0eb7e71
Fix centerpoint tta (#892)
yezhen17 Aug 30, 2021
00c037a
[Enhance] Add benchmark regression script (#808)
yezhen17 Aug 30, 2021
1e6cdea
Initial commit
yezhen17 Sep 1, 2021
d4b1244
Merge pull request #899 from THU17cyz/coord_sys_tutorial_again
yezhen17 Sep 1, 2021
f095eb6
[Feature] Support DGCNN (v1.0.0.dev0) (#896)
DCNSW Sep 3, 2021
459c637
Change cam rot_3d_in_axis (#906)
yezhen17 Sep 6, 2021
2ae6b55
[Doc] Add coord sys tutorial pic and change links to dev branch (#912)
yezhen17 Sep 7, 2021
fce176f
[Feature] add kitti AP40 evaluation metric (v1.0.0.dev0) (#927)
ZCMax Sep 13, 2021
66f0c07
[Feature] add smoke backbone neck (#939)
ZCMax Sep 15, 2021
0b26a9a
[Refactor] Refactor the transformation from image to camera coordinat…
Tai-Wang Sep 15, 2021
911a333
[Feature] FCOS3D BBox Coder (#940)
Tai-Wang Sep 15, 2021
0899bad
Support PGD BBox Coder
Tai-Wang Sep 22, 2021
b217f7a
Refine docstring
Tai-Wang Sep 22, 2021
d28a8b5
Add uncertain l1 loss and its unit tests
Tai-Wang Sep 22, 2021
2282fca
Merge branch 'uncertain_loss' into pgd_head
Tai-Wang Sep 22, 2021
506f929
[Feature] PGD BBox Coder (#948)
Tai-Wang Sep 22, 2021
38f75f5
PGD Head initialized
Tai-Wang Sep 22, 2021
89be05c
Refactor init methods, fix legacy variable names
Tai-Wang Sep 22, 2021
5be3d11
[Feature] Support Uncertain L1 Loss (#950)
Tai-Wang Sep 22, 2021
4a804bf
[Fix] Fix visualization in KITTI dataset (#956)
ZCMax Sep 22, 2021
038c39d
Refine variable names and docstrings
Tai-Wang Sep 23, 2021
0061d89
Add unit tests and fix some minor bugs
Tai-Wang Sep 24, 2021
0e1f4ed
Refine assertion messages
Tai-Wang Sep 24, 2021
c0a5021
Merge branch 'v1.0.0.dev0' into pgd_head
Tai-Wang Sep 24, 2021
e3690c6
Merge branch 'v1.0.0.dev0' into pgd_head
Tai-Wang Sep 24, 2021
a62e993
Fix typo in the docs_zh-CN
Tai-Wang Sep 24, 2021
efef7e9
Use Pretrain init and remove unused init_cfg in FCOS3D
Tai-Wang Sep 27, 2021
016fc29
Fix the comments for the input_modality in the dataset config
Tai-Wang Sep 29, 2021
b47ad7e
Fix minor bugs in pgd_bbox_coder and incorrect setting for uncertain …
Tai-Wang Oct 26, 2021
414f561
Merge branch 'v1.0.0.dev0' into pgd_head
Tai-Wang Oct 26, 2021
3baa41a
Add explanations for code_weights
Tai-Wang Oct 26, 2021
292a5d0
Add PGD README
Tai-Wang Oct 26, 2021
9658a2b
Adjust the unit test for pgd bbox coder
Tai-Wang Oct 26, 2021
fa78b41
Remove unused codes
Tai-Wang Oct 26, 2021
4e086bc
Add mono3d metric into the gather_models and fix bugs
Tai-Wang Oct 27, 2021
18ee0cc
Update README.md
Tai-Wang Oct 27, 2021
751f57c
Merge branch 'pgd_head' into pgd_kitti
Tai-Wang Oct 27, 2021
7f4ae6a
Update links
Tai-Wang Oct 27, 2021
cb92793
Update links
Tai-Wang Oct 27, 2021
3f77afc
Involve the value assignment of loss_dict into the computing procedure
Tai-Wang Oct 28, 2021
b8a46bc
Fix incorrect loss_depth
Tai-Wang Oct 28, 2021
789ccfa
Merge branch 'pgd_head' into pgd_kitti
Tai-Wang Oct 28, 2021
3237c30
Update README.md
Tai-Wang Oct 31, 2021
3c48a3a
Update README_zh-CN.md
Tai-Wang Oct 31, 2021
b7804dc
Update PGD in the model_zoo.md
Tai-Wang Nov 1, 2021
f5d0104
Update PGD in the model_zoo.md
Tai-Wang Nov 1, 2021
7ec72cf
Update metafiles
Tai-Wang Nov 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions .dev_scripts/gather_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@

Usage:
python gather_models.py ${root_path} ${out_dir}

Example:
python gather_models.py \
work_dirs/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d \
work_dirs/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d

Note that before running the above command, rename the directory with the
config name if you did not use the default directory name, create
a corresponding directory 'pgd' under the above path and put the used config
into it.
"""

import argparse
Expand Down Expand Up @@ -36,16 +46,18 @@
RESULTS_LUT = {
'coco': ['bbox_mAP', 'segm_mAP'],
'nus': ['pts_bbox_NuScenes/NDS', 'NDS'],
'kitti-3d-3class': [
'KITTI/Overall_3D_moderate',
'Overall_3D_moderate',
],
'kitti-3d-3class': ['KITTI/Overall_3D_moderate', 'Overall_3D_moderate'],
'kitti-3d-car': ['KITTI/Car_3D_moderate_strict', 'Car_3D_moderate_strict'],
'lyft': ['score'],
'scannet_seg': ['miou'],
's3dis_seg': ['miou'],
'scannet': ['mAP_0.50'],
'sunrgbd': ['mAP_0.50']
'sunrgbd': ['mAP_0.50'],
'kitti-mono3d': [
'img_bbox/KITTI/Car_3D_AP40_moderate_strict',
'Car_3D_AP40_moderate_strict'
],
'nus-mono3d': ['img_bbox_NuScenes/NDS', 'NDS']
}


Expand Down Expand Up @@ -145,15 +157,13 @@ def main():
# and parse the best performance
model_infos = []
for used_config in used_configs:
exp_dir = osp.join(models_root, used_config)

# get logs
log_json_path = glob.glob(osp.join(exp_dir, '*.log.json'))[0]
log_txt_path = glob.glob(osp.join(exp_dir, '*.log'))[0]
log_json_path = glob.glob(osp.join(models_root, '*.log.json'))[0]
log_txt_path = glob.glob(osp.join(models_root, '*.log'))[0]
model_performance = get_best_results(log_json_path)
final_epoch = model_performance['epoch']
final_model = 'epoch_{}.pth'.format(final_epoch)
model_path = osp.join(exp_dir, final_model)
model_path = osp.join(models_root, final_model)

# skip if the model is still training
if not osp.exists(model_path):
Expand Down Expand Up @@ -182,7 +192,7 @@ def main():
model_name = model['config'].split('/')[-1].rstrip(
'.py') + '_' + model['model_time']
publish_model_path = osp.join(model_publish_dir, model_name)
trained_model_path = osp.join(models_root, model['config'],
trained_model_path = osp.join(models_root,
'epoch_{}.pth'.format(model['epochs']))

# convert model
Expand All @@ -191,11 +201,10 @@ def main():

# copy log
shutil.copy(
osp.join(models_root, model['config'], model['log_json_path']),
osp.join(models_root, model['log_json_path']),
osp.join(model_publish_dir, f'{model_name}.log.json'))
shutil.copy(
osp.join(models_root, model['config'],
model['log_json_path'].rstrip('.json')),
osp.join(models_root, model['log_json_path'].rstrip('.json')),
osp.join(model_publish_dir, f'{model_name}.log'))

# copy config to guarantee reproducibility
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ Support methods
- [x] [PAConv (CVPR'2021)](configs/paconv/README.md)
- [x] [DGCNN (TOG'2019)](configs/dgcnn/README.md)
- [x] [SMOKE (CVPRW'2020)](configs/smoke/README.md)
- [x] [PGD (CoRL'2021)](configs/pgd/README.md)

| | ResNet | ResNeXt | SENet |PointNet++ |DGCNN | HRNet | RegNetX | Res2Net | DLA |
|--------------------|:--------:|:--------:|:--------:|:---------:|:---------:|:-----:|:--------:|:-----:|:---:|
Expand All @@ -127,6 +128,7 @@ Support methods
| PAConv | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | ✗
| DGCNN | ✗ | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗
| SMOKE | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | ✓
| PGD | ✓ | ☐ | ☐ | ✗ | ✗ | ☐ | ☐ | ☐ | ✗

Other features
- [x] [Dynamic Voxelization](configs/dynamic_voxelization/README.md)
Expand Down
2 changes: 2 additions & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ MMDetection3D 是一个基于 PyTorch 的目标检测开源工具箱, 下一代
- [x] [PAConv (CVPR'2021)](configs/paconv/README.md)
- [x] [DGCNN (TOG'2019)](configs/dgcnn/README.md)
- [x] [SMOKE (CVPRW'2020)](configs/smoke/README.md)
- [x] [PGD (CoRL'2021)](configs/pgd/README.md)

| | ResNet | ResNeXt | SENet |PointNet++ |DGCNN | HRNet | RegNetX | Res2Net | DLA |
|--------------------|:--------:|:--------:|:--------:|:---------:|:---------:|:-----:|:--------:|:-----:|:---:|
Expand All @@ -126,6 +127,7 @@ MMDetection3D 是一个基于 PyTorch 的目标检测开源工具箱, 下一代
| PAConv | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | ✗
| DGCNN | ✗ | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗
| SMOKE | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | ✓
| PGD | ✓ | ☐ | ☐ | ✗ | ✗ | ☐ | ☐ | ☐ | ✗

其他特性
- [x] [Dynamic Voxelization](configs/dynamic_voxelization/README.md)
Expand Down
6 changes: 4 additions & 2 deletions configs/_base_/models/fcos3d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
model = dict(
type='FCOSMono3D',
pretrained='open-mmlab://detectron2/resnet101_caffe',
backbone=dict(
type='ResNet',
depth=101,
Expand All @@ -9,7 +8,10 @@
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='caffe'),
style='caffe',
init_cfg=dict(
type='Pretrained',
checkpoint='open-mmlab://detectron2/resnet101_caffe')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
Expand Down
55 changes: 55 additions & 0 deletions configs/_base_/models/pgd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
_base_ = './fcos3d.py'
# model settings
model = dict(
bbox_head=dict(
_delete_=True,
type='PGDHead',
num_classes=10,
in_channels=256,
stacked_convs=2,
feat_channels=256,
use_direction_classifier=True,
diff_rad_by_sin=True,
pred_attrs=True,
pred_velo=True,
pred_bbox2d=True,
pred_keypoints=False,
dir_offset=0.7854, # pi/4
strides=[8, 16, 32, 64, 128],
group_reg_dims=(2, 1, 3, 1, 2), # offset, depth, size, rot, velo
cls_branch=(256, ),
reg_branch=(
(256, ), # offset
(256, ), # depth
(256, ), # size
(256, ), # rot
() # velo
),
dir_branch=(256, ),
attr_branch=(256, ),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
loss_dir=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_attr=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
norm_on_bbox=True,
centerness_on_reg=True,
center_sampling=True,
conv_bias=True,
dcn_on_last_conv=True,
use_depth_classifier=True,
depth_branch=(256, ),
depth_range=(0, 50),
depth_unit=10,
division='uniform',
depth_bins=6,
bbox_coder=dict(type='PGDBBoxCoder', code_size=9)),
test_cfg=dict(nms_pre=1000, nms_thr=0.8, score_thr=0.01, max_per_img=200))
37 changes: 37 additions & 0 deletions configs/pgd/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Probabilistic and Geometric Depth: Detecting Objects in Perspective

## Introduction

<!-- [ALGORITHM] -->

PGD, also can be regarded as FCOS3D++, is a simple yet effective monocular 3D detector. It enhances the FCOS3D baseline by involving local geometric constraints and improving instance depth estimation.

We first release the code and model for KITTI benchmark, which is a good supplement for the original FCOS3D baseline (only supported on nuScenes). Models for nuScenes will be released soon.

For clean implementation, our preliminary release supports base models with proposed local geometric constraints and the probabilistic depth representation. We will involve the geometric graph part in the future.

```
@inproceedings{wang2021pgd,
title={Probabilistic and Geometric Depth: Detecting Objects in Perspective},
author={Wang, Tai and Zhu, Xinge and Pang, Jiangmiao and Lin, Dahua},
booktitle={Conference on Robot Learning (CoRL) 2021},
year={2021}
}
```

## Results

### KITTI

| Backbone | Lr schd | Mem (GB) | Inf time (fps) | mAP_11 / mAP_40 | Download |
| :---------: | :-----: | :------: | :------------: | :----: | :------: |
|[ResNet101](./pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d.py)|4x|9.07||18.33 / 13.23|[model](https://download.openmmlab.com/mmdetection3d/v1.0.0_models/pgd/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d_20211022_102608-8a97533b.pth) &#124; [log](https://download.openmmlab.com/mmdetection3d/v1.0.0_models/pgd/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d_20211022_102608.log.json)

Detailed performance on KITTI 3D detection (3D/BEV) is as follows, evaluated by AP11 and AP40 metric:

| | Easy | Moderate | Hard |
|-------------|:-------------:|:--------------:|:-------------:|
| Car (AP11) | 24.09 / 30.11 | 18.33 / 23.46 | 16.90 / 19.33 |
| Car (AP40) | 19.27 / 26.60 | 13.23 / 18.23 | 10.65 / 15.00 |

Note: mAP represents Car moderate 3D strict AP11 / AP40 results. Because of the limited data for pedestrians and cyclists, the detection performance for these two classes is usually unstable. Therefore, we only list car detection results here. In addition, AP40 is a more recommended metric for reference due to its much better stability.
29 changes: 29 additions & 0 deletions configs/pgd/metafile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
Collections:
- Name: PGD
Metadata:
Training Data: KITTI
Training Techniques:
- SGD
Training Resources: 4x TITAN XP
Architecture:
- PGDHead
Paper:
URL: https://arxiv.org/abs/2107.14160
Title: 'Probabilistic and Geometric Depth: Detecting Objects in Perspective'
README: configs/pgd/README.md
Code:
URL: https://github.com/open-mmlab/mmdetection3d/blob/v1.0.0.dev0/mmdet3d/models/dense_heads/pgd_head.py#17
Version: v1.0.0

Models:
- Name: pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d
In Collection: PGD
Config: configs/pgd/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d.py
Metadata:
Training Memory (GB): 9.1
Results:
- Task: 3D Object Detection
Dataset: KITTI
Metrics:
mAP: 18.33
Weights: https://download.openmmlab.com/mmdetection3d/v1.0.0_models/pgd/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d_20211022_102608-8a97533b.pth
127 changes: 127 additions & 0 deletions configs/pgd/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
_base_ = [
'../_base_/datasets/kitti-mono3d.py', '../_base_/models/pgd.py',
'../_base_/schedules/mmdet_schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
backbone=dict(frozen_stages=0),
neck=dict(start_level=0, num_outs=4),
bbox_head=dict(
num_classes=3,
bbox_code_size=7,
pred_attrs=False,
pred_velo=False,
pred_bbox2d=True,
use_onlyreg_proj=True,
strides=(4, 8, 16, 32),
regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 1e8)),
group_reg_dims=(2, 1, 3, 1, 16,
4), # offset, depth, size, rot, kpts, bbox2d
reg_branch=(
(256, ), # offset
(256, ), # depth
(256, ), # size
(256, ), # rot
(256, ), # kpts
(256, ) # bbox2d
),
centerness_branch=(256, ),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
loss_dir=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
use_depth_classifier=True,
depth_branch=(256, ),
depth_range=(0, 70),
depth_unit=10,
division='uniform',
depth_bins=8,
pred_keypoints=True,
weight_dim=1,
loss_depth=dict(
type='UncertainSmoothL1Loss', alpha=1.0, beta=3.0,
loss_weight=1.0),
bbox_coder=dict(
type='PGDBBoxCoder',
base_depths=((28.01, 16.32), ),
base_dims=((0.8, 1.73, 0.6), (1.76, 1.73, 0.6), (3.9, 1.56, 1.6)),
code_size=7)),
# set weight 1.0 for base 7 dims (offset, depth, size, rot)
# 0.2 for 16-dim keypoint offsets and 1.0 for 4-dim 2D distance targets
train_cfg=dict(code_weight=[
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 1.0, 1.0, 1.0, 1.0
]),
test_cfg=dict(nms_pre=100, nms_thr=0.05, score_thr=0.001, max_per_img=20))

class_names = ['Pedestrian', 'Cyclist', 'Car']
img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFileMono3D'),
dict(
type='LoadAnnotations3D',
with_bbox=True,
with_label=True,
with_attr_label=False,
with_bbox_3d=True,
with_label_3d=True,
with_bbox_depth=True),
dict(type='Resize', img_scale=(1242, 375), keep_ratio=True),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Collect3D',
keys=[
'img', 'gt_bboxes', 'gt_labels', 'gt_bboxes_3d', 'gt_labels_3d',
'centers2d', 'depths'
]),
]
test_pipeline = [
dict(type='LoadImageFromFileMono3D'),
dict(
type='MultiScaleFlipAug',
scale_factor=1.0,
flip=False,
transforms=[
dict(type='RandomFlip3D'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['img']),
])
]
data = dict(
samples_per_gpu=3,
workers_per_gpu=3,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
# optimizer
optimizer = dict(
lr=0.001, paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.))
optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[32, 44])
total_epochs = 48
runner = dict(type='EpochBasedRunner', max_epochs=48)
evaluation = dict(interval=2)
checkpoint_config = dict(interval=8)
4 changes: 2 additions & 2 deletions configs/smoke/metafile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ Collections:
Title: 'SMOKE: Single-Stage Monocular 3D Object Detection via Keypoint Estimation'
README: configs/smoke/README.md
Code:
URL: https://github.com/open-mmlab/mmdetection3d/blob/master/mmdet3d/models/detectors/smoke_mono3d.py#L7
Version: v0.17.1
URL: https://github.com/open-mmlab/mmdetection3d/blob/v1.0.0.dev0/mmdet3d/models/detectors/smoke_mono3d.py#L7
Version: v1.0.0

Models:
- Name: smoke_dla34_pytorch_dlaneck_gn-all_8x4_6x_kitti-mono3d
Expand Down
Loading