Support to train using FP16 #132

Merged 5 commits on Oct 10, 2020
Changes from 2 commits
23 changes: 23 additions & 0 deletions configs/fp16/README.md
@@ -0,0 +1,23 @@
# Mixed Precision Training

## Introduction

We implement mixed precision (FP16) training and apply it to voxel-based detectors (e.g., SECOND and PointPillars).
The results are shown in the tables below.

**Note**: Mixed precision training is currently not supported for PointNet-based methods (e.g., VoteNet); it will be supported in a future release.
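
Under the hood, the `fp16` dict in a config is consumed by mmcv's FP16 utilities: the model is cast to half precision and the loss is scaled before backpropagation so that small gradients do not underflow. The snippet below is a rough, purely illustrative sketch (it is not code added by this PR) and assumes mmcv's `wrap_fp16_model` and `Fp16OptimizerHook`:

```python
# Illustrative sketch only (not part of this PR): how the `fp16` field is
# typically consumed, assuming mmcv's `wrap_fp16_model` and `Fp16OptimizerHook`.
from mmcv.runner import Fp16OptimizerHook, wrap_fp16_model

fp16_cfg = dict(loss_scale=512.)  # the value used in the configs below


def prepare_fp16_training(model, grad_clip=None, distributed=True):
    # Cast eligible modules to half precision; modules that define
    # `fp16_enabled` are flagged so `auto_fp16`/`force_fp32` take effect.
    wrap_fp16_model(model)
    # The hook multiplies the loss by `loss_scale` before backward and
    # unscales the gradients before the optimizer step, preventing small
    # gradients from underflowing in FP16.
    return Fp16OptimizerHook(
        grad_clip=grad_clip, distributed=distributed, **fp16_cfg)
```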

## Results

### SECOND on KITTI dataset
| Backbone |Class| Lr schd | FP32 Mem (GB) | FP16 Mem (GB) | FP32 mAP | FP16 mAP |Download |
| :---------: | :-----: | :------: | :------------: | :----: |:----: | :------: | :------: |
| [SECFPN](./hv_second_secfpn_fp16_6x8_80e_kitti-3d-car.py)| Car |cyclic 80e|5.4|2.9|79.07|78.72||
| [SECFPN](./hv_second_secfpn_fp16_6x8_80e_kitti-3d-3class.py)| 3 Class |cyclic 80e|5.4|2.9|64.41|67.4||

### PointPillars on nuScenes dataset


**Note**: With mixed precision training, we can train PointPillars with a RegNet-400mf backbone on 8 Titan Xp GPUs with a batch size of 2 per GPU; without mixed precision training, this configuration runs out of memory (OOM).
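
The model-side changes in this PR follow a single pattern: a module sets `self.fp16_enabled = False` in `__init__`, its `forward` is decorated with `@auto_fp16(...)` so the listed tensor arguments are cast to FP16 once mixed precision is enabled, and numerically sensitive methods such as `loss` and `voxelize` are decorated with `@force_fp32(...)` so they always run in FP32. A minimal toy module showing the pattern (illustrative only, not from this repository):

```python
# Toy module illustrating the auto_fp16/force_fp32 pattern used in this PR.
# The module, channel sizes, and loss are made up for illustration.
import torch
import torch.nn.functional as F
from mmcv.runner import auto_fp16, force_fp32
from torch import nn


class ToyVoxelBackbone(nn.Module):

    def __init__(self, in_channels=4, out_channels=16):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        # Set to True by mmcv's wrap_fp16_model when FP16 training is on.
        self.fp16_enabled = False

    @auto_fp16(apply_to=('x', ))
    def forward(self, x):
        # `x` is cast to FP16 here when fp16_enabled is True.
        return self.conv(x)

    @force_fp32(apply_to=('feats', ))
    def loss(self, feats, target):
        # The loss is always computed in FP32 for numerical stability.
        return F.mse_loss(feats, target.float())


# Usage sketch: out = ToyVoxelBackbone()(torch.randn(2, 4, 32, 32))
```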
@@ -0,0 +1,4 @@
_base_ = '../pointpillars/hv_pointpillars_fpn_sbn-all_4x8_2x_nus-3d.py'
data = dict(samples_per_gpu=2, workers_per_gpu=2)
# fp16 settings
fp16 = dict(loss_scale=512.)
@@ -0,0 +1,4 @@
_base_ = '../regnet/hv_pointpillars_regnet-400mf_fpn_sbn-all_4x8_2x_nus-3d.py'
data = dict(samples_per_gpu=2, workers_per_gpu=2)
# fp16 settings
fp16 = dict(loss_scale=512.)
@@ -0,0 +1,4 @@
_base_ = '../pointpillars/hv_pointpillars_secfpn_sbn-all_4x8_2x_nus-3d.py'
data = dict(samples_per_gpu=2, workers_per_gpu=2)
# fp16 settings
fp16 = dict(loss_scale=512.)
3 changes: 3 additions & 0 deletions configs/fp16/hv_second_secfpn_6x8_80e_kitti-3d-3class.py
@@ -0,0 +1,3 @@
_base_ = '../second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py'
# fp16 settings
fp16 = dict(loss_scale=512.)
3 changes: 3 additions & 0 deletions configs/fp16/hv_second_secfpn_fp16_6x8_80e_kitti-3d-car.py
@@ -0,0 +1,3 @@
_base_ = '../second/hv_second_secfpn_6x8_80e_kitti-3d-car.py'
# fp16 settings
fp16 = dict(loss_scale=512.)
1 change: 1 addition & 0 deletions mmdet3d/models/backbones/base_pointnet.py
@@ -8,6 +8,7 @@ class BasePointNet(nn.Module, metaclass=ABCMeta):

def __init__(self):
super(BasePointNet, self).__init__()
self.fp16_enabled = False

def init_weights(self, pretrained=None):
"""Initialize the weights of PointNet backbone."""
3 changes: 2 additions & 1 deletion mmdet3d/models/backbones/multi_backbone.py
@@ -1,7 +1,7 @@
import copy
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import load_checkpoint
from mmcv.runner import auto_fp16, load_checkpoint
from torch import nn as nn

from mmdet.models import BACKBONES, build_backbone
@@ -86,6 +86,7 @@ def init_weights(self, pretrained=None):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)

@auto_fp16()
def forward(self, points):
"""Forward pass.

3 changes: 3 additions & 0 deletions mmdet3d/models/backbones/nostem_regnet.py
@@ -1,3 +1,5 @@
from mmcv.runner import auto_fp16

from mmdet.models.backbones import RegNet
from ..builder import BACKBONES

@@ -65,6 +67,7 @@ def _make_stem_layer(self, in_channels, base_channels):
since 3D detector's voxel encoder works like a stem layer."""
return

@auto_fp16()
def forward(self, x):
"""Forward function of backbone.

2 changes: 2 additions & 0 deletions mmdet3d/models/backbones/pointnet2_sa_msg.py
@@ -1,5 +1,6 @@
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16
from torch import nn as nn

from mmdet3d.ops import build_sa_module
@@ -111,6 +112,7 @@ def __init__(self,
bias=True))
sa_in_channel = aggregation_channels[sa_index]

@auto_fp16(apply_to=('points', ))
def forward(self, points):
"""Forward pass.

2 changes: 2 additions & 0 deletions mmdet3d/models/backbones/pointnet2_sa_ssg.py
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import auto_fp16
from torch import nn as nn

from mmdet3d.ops import PointFPModule, build_sa_module
@@ -83,6 +84,7 @@ def __init__(self,
fp_source_channel = cur_fp_mlps[-1]
fp_target_channel = skip_channel_list.pop()

@auto_fp16(apply_to=('points', ))
def forward(self, points):
"""Forward pass.

3 changes: 2 additions & 1 deletion mmdet3d/models/backbones/second.py
@@ -1,5 +1,5 @@
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import load_checkpoint
from mmcv.runner import auto_fp16, load_checkpoint
from torch import nn as nn

from mmdet.models import BACKBONES
@@ -70,6 +70,7 @@ def init_weights(self, pretrained=None):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)

@auto_fp16()
def forward(self, x):
"""Forward function.

3 changes: 3 additions & 0 deletions mmdet3d/models/dense_heads/anchor3d_head.py
@@ -1,6 +1,7 @@
import numpy as np
import torch
from mmcv.cnn import bias_init_with_prob, normal_init
from mmcv.runner import force_fp32
from torch import nn as nn

from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period,
@@ -79,6 +80,7 @@ def __init__(self,
self.assign_per_class = assign_per_class
self.dir_offset = dir_offset
self.dir_limit_offset = dir_limit_offset
self.fp16_enabled = False

# build anchor generator
self.anchor_generator = build_anchor_generator(anchor_generator)
@@ -270,6 +272,7 @@ def add_sin_difference(boxes1, boxes2):
dim=-1)
return boxes1, boxes2

@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
def loss(self,
cls_scores,
bbox_preds,
5 changes: 4 additions & 1 deletion mmdet3d/models/dense_heads/centerpoint_head.py
@@ -2,6 +2,7 @@
import numpy as np
import torch
from mmcv.cnn import ConvModule, build_conv_layer, kaiming_init
from mmcv.runner import force_fp32
from torch import nn

from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
@@ -228,7 +229,7 @@ def forward(self, x):
return ret


@HEADS.register_module
@HEADS.register_module()
class CenterHead(nn.Module):
"""CenterHead for CenterPoint.

@@ -292,6 +293,7 @@ def __init__(self,
self.loss_bbox = build_loss(loss_bbox)
self.bbox_coder = build_bbox_coder(bbox_coder)
self.num_anchor_per_locs = [n for n in num_classes]
self.fp16_enabled = False

# a shared convolution
self.shared_conv = ConvModule(
@@ -548,6 +550,7 @@ def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
inds.append(ind)
return heatmaps, anno_boxes, inds, masks

@force_fp32(apply_to=('preds_dicts', ))
def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
"""Loss function for CenterHead.

2 changes: 2 additions & 0 deletions mmdet3d/models/dense_heads/free_anchor3d_head.py
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet3d.core.bbox import bbox_overlaps_nearest_3d
@@ -38,6 +39,7 @@ def __init__(self,
self.gamma = gamma
self.alpha = alpha

@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
def loss(self,
cls_scores,
bbox_preds,
2 changes: 2 additions & 0 deletions mmdet3d/models/dense_heads/parta2_rpn_head.py
@@ -2,6 +2,7 @@

import numpy as np
import torch
from mmcv.runner import force_fp32

from mmdet3d.core import limit_period, xywhr2xyxyr
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
@@ -81,6 +82,7 @@ def __init__(self,
diff_rad_by_sin, dir_offset, dir_limit_offset,
bbox_coder, loss_cls, loss_bbox, loss_dir)

@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
def loss(self,
cls_scores,
bbox_preds,
2 changes: 2 additions & 0 deletions mmdet3d/models/dense_heads/ssd_3d_head.py
@@ -1,5 +1,6 @@
import torch
from mmcv.ops.nms import batched_nms
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
@@ -108,6 +109,7 @@ def _extract_input(self, feat_dict):

return seed_points, seed_features, seed_indices

@force_fp32(apply_to=('bbox_preds', ))
def loss(self,
bbox_preds,
points,
3 changes: 3 additions & 0 deletions mmdet3d/models/dense_heads/vote_head.py
@@ -1,5 +1,6 @@
import numpy as np
import torch
from mmcv.runner import force_fp32
from torch import nn as nn
from torch.nn import functional as F

@@ -78,6 +79,7 @@ def __init__(self,

self.vote_module = VoteModule(**vote_module_cfg)
self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
self.fp16_enabled = False

# Bbox classification and regression
self.conv_pred = BaseConvBboxHead(
@@ -204,6 +206,7 @@ def forward(self, feat_dict, sample_mod):

return results

@force_fp32(apply_to=('bbox_preds', ))
def loss(self,
bbox_preds,
points,
2 changes: 2 additions & 0 deletions mmdet3d/models/detectors/base.py
@@ -2,6 +2,7 @@
import mmcv
import torch
from mmcv.parallel import DataContainer as DC
from mmcv.runner import auto_fp16
from os import path as osp

from mmdet3d.core import Box3DMode, show_result
@@ -42,6 +43,7 @@ def forward_test(self, points, img_metas, img=None, **kwargs):
else:
return self.aug_test(points, img_metas, img, **kwargs)

@auto_fp16(apply_to=('img', 'points'))
def forward(self, return_loss=True, **kwargs):
"""Calls either forward_train or forward_test depending on whether
return_loss=True.
2 changes: 2 additions & 0 deletions mmdet3d/models/detectors/dynamic_voxelnet.py
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet.models import DETECTORS
@@ -44,6 +45,7 @@ def extract_feat(self, points, img_metas):
return x

@torch.no_grad()
@force_fp32()
def voxelize(self, points):
"""Apply dynamic voxelization to points.

2 changes: 2 additions & 0 deletions mmdet3d/models/detectors/mvx_faster_rcnn.py
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet.models import DETECTORS
@@ -21,6 +22,7 @@ def __init__(self, **kwargs):
super(DynamicMVXFasterRCNN, self).__init__(**kwargs)

@torch.no_grad()
@force_fp32()
def voxelize(self, points):
"""Apply dynamic voxelization to points.

2 changes: 2 additions & 0 deletions mmdet3d/models/detectors/mvx_two_stage.py
@@ -2,6 +2,7 @@
import mmcv
import torch
from mmcv.parallel import DataContainer as DC
from mmcv.runner import force_fp32
from os import path as osp
from torch import nn as nn
from torch.nn import functional as F
@@ -203,6 +204,7 @@ def extract_feat(self, points, img, img_metas):
return (img_feats, pts_feats)

@torch.no_grad()
@force_fp32()
def voxelize(self, points):
"""Apply dynamic voxelization to points.

2 changes: 2 additions & 0 deletions mmdet3d/models/detectors/voxelnet.py
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
@@ -46,6 +47,7 @@ def extract_feat(self, points, img_metas):
return x

@torch.no_grad()
@force_fp32()
def voxelize(self, points):
"""Apply hard voxelization to points."""
voxels, coors, num_points = [], [], []
3 changes: 3 additions & 0 deletions mmdet3d/models/middle_encoders/pillar_scatter.py
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import auto_fp16
from torch import nn

from ..registry import MIDDLE_ENCODERS
@@ -21,7 +22,9 @@ def __init__(self, in_channels, output_shape):
self.ny = output_shape[0]
self.nx = output_shape[1]
self.in_channels = in_channels
self.fp16_enabled = False

@auto_fp16(apply_to=('voxel_features', ))
def forward(self, voxel_features, coors, batch_size=None):
"""Foraward function to scatter features."""
# TODO: rewrite the function in a batch manner
3 changes: 3 additions & 0 deletions mmdet3d/models/middle_encoders/sparse_encoder.py
@@ -1,3 +1,4 @@
from mmcv.runner import auto_fp16
from torch import nn as nn

from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
@@ -49,6 +50,7 @@ def __init__(self,
self.encoder_channels = encoder_channels
self.encoder_paddings = encoder_paddings
self.stage_num = len(self.encoder_channels)
self.fp16_enabled = False
# Spconv initializes all weights on its own

assert isinstance(order, tuple) and len(order) == 3
@@ -90,6 +92,7 @@ def __init__(self,
indice_key='spconv_down2',
conv_type='SparseConv3d')

@auto_fp16(apply_to=('voxel_features', ))
def forward(self, voxel_features, coors, batch_size):
"""Forward of SparseEncoder.

3 changes: 3 additions & 0 deletions mmdet3d/models/middle_encoders/sparse_unet.py
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import auto_fp16
from torch import nn as nn

from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
@@ -51,6 +52,7 @@ def __init__(self,
self.decoder_channels = decoder_channels
self.decoder_paddings = decoder_paddings
self.stage_num = len(self.encoder_channels)
self.fp16_enabled = False
# Spconv initializes all weights on its own

assert isinstance(order, tuple) and len(order) == 3
@@ -91,6 +93,7 @@ def __init__(self,
indice_key='spconv_down2',
conv_type='SparseConv3d')

@auto_fp16(apply_to=('voxel_features', ))
def forward(self, voxel_features, coors, batch_size):
"""Forward of SparseUNet.
