Skip to content

Commit

Permalink
Support to train using FP16 (#132)
Browse files Browse the repository at this point in the history
* Support to train using FP16

* fix type inconsistency error on naive syncBN

* resolve comments

* clean nan check
  • Loading branch information
ZwwWayne authored Oct 10, 2020
1 parent e4320fb commit e67b3f8
Show file tree
Hide file tree
Showing 32 changed files with 175 additions and 27 deletions.
23 changes: 23 additions & 0 deletions configs/fp16/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Mixed Precision Training

## Introduction

We implement mixed precision training and apply it to VoxelNets (e.g., SECOND and PointPillars).
The results are in the following tables.

**Note**: For mixed precision training, we currently do not support PointNet-based methods (e.g., VoteNet).
Mixed precision training for PointNet-based methods will be supported in a future release.

## Results

### SECOND on KITTI dataset
| Backbone |Class| Lr schd | FP32 Mem (GB) | FP16 Mem (GB) | FP32 mAP | FP16 mAP |Download |
| :---------: | :-----: | :------: | :------------: | :----: |:----: | :------: | :------: |
| [SECFPN](./hv_second_secfpn_fp16_6x8_80e_kitti-3d-car.py)| Car |cyclic 80e|5.4|2.9|79.07|78.72||
| [SECFPN](./hv_second_secfpn_6x8_80e_kitti-3d-3class.py)| 3 Class |cyclic 80e|5.4|2.9|64.41|67.4||

### PointPillars on nuScenes dataset


**Note**: With mixed precision training, we can train PointPillars with RegNet-400mf on 8 Titan XP GPUs with a batch size of 2.
Without mixed precision training, this configuration causes an out-of-memory (OOM) error.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = '../pointpillars/hv_pointpillars_fpn_sbn-all_4x8_2x_nus-3d.py'
data = dict(samples_per_gpu=2, workers_per_gpu=2)
# fp16 settings
fp16 = dict(loss_scale=512.)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = '../regnet/hv_pointpillars_regnet-400mf_fpn_sbn-all_4x8_2x_nus-3d.py'
data = dict(samples_per_gpu=2, workers_per_gpu=2)
# fp16 settings
fp16 = dict(loss_scale=512.)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = '../pointpillars/hv_pointpillars_secfpn_sbn-all_4x8_2x_nus-3d.py'
data = dict(samples_per_gpu=2, workers_per_gpu=2)
# fp16 settings
fp16 = dict(loss_scale=512.)
3 changes: 3 additions & 0 deletions configs/fp16/hv_second_secfpn_6x8_80e_kitti-3d-3class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = '../second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py'
# fp16 settings
fp16 = dict(loss_scale=512.)
3 changes: 3 additions & 0 deletions configs/fp16/hv_second_secfpn_fp16_6x8_80e_kitti-3d-car.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = '../second/hv_second_secfpn_6x8_80e_kitti-3d-car.py'
# fp16 settings
fp16 = dict(loss_scale=512.)
1 change: 1 addition & 0 deletions mmdet3d/models/backbones/base_pointnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class BasePointNet(nn.Module, metaclass=ABCMeta):

def __init__(self):
super(BasePointNet, self).__init__()
self.fp16_enabled = False

def init_weights(self, pretrained=None):
"""Initialize the weights of PointNet backbone."""
Expand Down
3 changes: 2 additions & 1 deletion mmdet3d/models/backbones/multi_backbone.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import load_checkpoint
from mmcv.runner import auto_fp16, load_checkpoint
from torch import nn as nn

from mmdet.models import BACKBONES, build_backbone
Expand Down Expand Up @@ -86,6 +86,7 @@ def init_weights(self, pretrained=None):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)

@auto_fp16()
def forward(self, points):
"""Forward pass.
Expand Down
2 changes: 2 additions & 0 deletions mmdet3d/models/backbones/pointnet2_sa_msg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16
from torch import nn as nn

from mmdet3d.ops import build_sa_module
Expand Down Expand Up @@ -111,6 +112,7 @@ def __init__(self,
bias=True))
sa_in_channel = aggregation_channels[sa_index]

@auto_fp16(apply_to=('points', ))
def forward(self, points):
"""Forward pass.
Expand Down
2 changes: 2 additions & 0 deletions mmdet3d/models/backbones/pointnet2_sa_ssg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import auto_fp16
from torch import nn as nn

from mmdet3d.ops import PointFPModule, build_sa_module
Expand Down Expand Up @@ -83,6 +84,7 @@ def __init__(self,
fp_source_channel = cur_fp_mlps[-1]
fp_target_channel = skip_channel_list.pop()

@auto_fp16(apply_to=('points', ))
def forward(self, points):
"""Forward pass.
Expand Down
65 changes: 45 additions & 20 deletions mmdet3d/models/dense_heads/anchor3d_head.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import torch
from mmcv.cnn import bias_init_with_prob, normal_init
from mmcv.runner import force_fp32
from torch import nn as nn

from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period,
Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(self,
self.assign_per_class = assign_per_class
self.dir_offset = dir_offset
self.dir_limit_offset = dir_limit_offset
self.fp16_enabled = False

# build anchor generator
self.anchor_generator = build_anchor_generator(anchor_generator)
Expand Down Expand Up @@ -211,39 +213,61 @@ def loss_single(self, cls_score, bbox_pred, dir_cls_preds, labels,
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
assert labels.max().item() <= self.num_classes
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=num_total_samples)

# regression loss
bbox_pred = bbox_pred.permute(0, 2, 3,
1).reshape(-1, self.box_code_size)
bbox_targets = bbox_targets.reshape(-1, self.box_code_size)
bbox_weights = bbox_weights.reshape(-1, self.box_code_size)
code_weight = self.train_cfg.get('code_weight', None)

if code_weight:
bbox_weights = bbox_weights * bbox_weights.new_tensor(code_weight)
bbox_pred = bbox_pred.permute(0, 2, 3,
1).reshape(-1, self.box_code_size)
if self.diff_rad_by_sin:
bbox_pred, bbox_targets = self.add_sin_difference(
bbox_pred, bbox_targets)
loss_bbox = self.loss_bbox(
bbox_pred,
bbox_targets,
bbox_weights,
avg_factor=num_total_samples)

# direction classification loss
loss_dir = None
bg_class_ind = self.num_classes
pos_inds = ((labels >= 0)
& (labels < bg_class_ind)).nonzero().reshape(-1)
num_pos = len(pos_inds)

pos_bbox_pred = bbox_pred[pos_inds]
pos_bbox_targets = bbox_targets[pos_inds]
pos_bbox_weights = bbox_weights[pos_inds]

# dir loss
if self.use_direction_classifier:
dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).reshape(-1, 2)
dir_targets = dir_targets.reshape(-1)
dir_weights = dir_weights.reshape(-1)
loss_dir = self.loss_dir(
dir_cls_preds,
dir_targets,
dir_weights,
pos_dir_cls_preds = dir_cls_preds[pos_inds]
pos_dir_targets = dir_targets[pos_inds]
pos_dir_weights = dir_weights[pos_inds]

if num_pos > 0:
code_weight = self.train_cfg.get('code_weight', None)
if code_weight:
bbox_weights = bbox_weights * bbox_weights.new_tensor(
code_weight)
if self.diff_rad_by_sin:
pos_bbox_pred, pos_bbox_targets = self.add_sin_difference(
pos_bbox_pred, pos_bbox_targets)
loss_bbox = self.loss_bbox(
pos_bbox_pred,
pos_bbox_targets,
pos_bbox_weights,
avg_factor=num_total_samples)

# direction classification loss
loss_dir = None
if self.use_direction_classifier:
loss_dir = self.loss_dir(
pos_dir_cls_preds,
pos_dir_targets,
pos_dir_weights,
avg_factor=num_total_samples)
else:
loss_bbox = pos_bbox_pred.sum()
if self.use_direction_classifier:
loss_dir = pos_dir_cls_preds.sum()

return loss_cls, loss_bbox, loss_dir

@staticmethod
Expand All @@ -270,6 +294,7 @@ def add_sin_difference(boxes1, boxes2):
dim=-1)
return boxes1, boxes2

@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
def loss(self,
cls_scores,
bbox_preds,
Expand Down
5 changes: 4 additions & 1 deletion mmdet3d/models/dense_heads/centerpoint_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import torch
from mmcv.cnn import ConvModule, build_conv_layer, kaiming_init
from mmcv.runner import force_fp32
from torch import nn

from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
Expand Down Expand Up @@ -228,7 +229,7 @@ def forward(self, x):
return ret


@HEADS.register_module
@HEADS.register_module()
class CenterHead(nn.Module):
"""CenterHead for CenterPoint.
Expand Down Expand Up @@ -292,6 +293,7 @@ def __init__(self,
self.loss_bbox = build_loss(loss_bbox)
self.bbox_coder = build_bbox_coder(bbox_coder)
self.num_anchor_per_locs = [n for n in num_classes]
self.fp16_enabled = False

# a shared convolution
self.shared_conv = ConvModule(
Expand Down Expand Up @@ -548,6 +550,7 @@ def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
inds.append(ind)
return heatmaps, anno_boxes, inds, masks

@force_fp32(apply_to=('preds_dicts'))
def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
"""Loss function for CenterHead.
Expand Down
2 changes: 2 additions & 0 deletions mmdet3d/models/dense_heads/free_anchor3d_head.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet3d.core.bbox import bbox_overlaps_nearest_3d
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(self,
self.gamma = gamma
self.alpha = alpha

@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
def loss(self,
cls_scores,
bbox_preds,
Expand Down
2 changes: 2 additions & 0 deletions mmdet3d/models/dense_heads/parta2_rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import torch
from mmcv.runner import force_fp32

from mmdet3d.core import limit_period, xywhr2xyxyr
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(self,
diff_rad_by_sin, dir_offset, dir_limit_offset,
bbox_coder, loss_cls, loss_bbox, loss_dir)

@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
def loss(self,
cls_scores,
bbox_preds,
Expand Down
2 changes: 2 additions & 0 deletions mmdet3d/models/dense_heads/ssd_3d_head.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from mmcv.ops.nms import batched_nms
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
Expand Down Expand Up @@ -108,6 +109,7 @@ def _extract_input(self, feat_dict):

return seed_points, seed_features, seed_indices

@force_fp32(apply_to=('bbox_preds', ))
def loss(self,
bbox_preds,
points,
Expand Down
3 changes: 3 additions & 0 deletions mmdet3d/models/dense_heads/vote_head.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import torch
from mmcv.runner import force_fp32
from torch import nn as nn
from torch.nn import functional as F

Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(self,

self.vote_module = VoteModule(**vote_module_cfg)
self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
self.fp16_enabled = False

# Bbox classification and regression
self.conv_pred = BaseConvBboxHead(
Expand Down Expand Up @@ -204,6 +206,7 @@ def forward(self, feat_dict, sample_mod):

return results

@force_fp32(apply_to=('bbox_preds', ))
def loss(self,
bbox_preds,
points,
Expand Down
2 changes: 2 additions & 0 deletions mmdet3d/models/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import mmcv
import torch
from mmcv.parallel import DataContainer as DC
from mmcv.runner import auto_fp16
from os import path as osp

from mmdet3d.core import Box3DMode, show_result
Expand Down Expand Up @@ -42,6 +43,7 @@ def forward_test(self, points, img_metas, img=None, **kwargs):
else:
return self.aug_test(points, img_metas, img, **kwargs)

@auto_fp16(apply_to=('img', 'points'))
def forward(self, return_loss=True, **kwargs):
"""Calls either forward_train or forward_test depending on whether
return_loss=True.
Expand Down
2 changes: 2 additions & 0 deletions mmdet3d/models/detectors/dynamic_voxelnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet.models import DETECTORS
Expand Down Expand Up @@ -44,6 +45,7 @@ def extract_feat(self, points, img_metas):
return x

@torch.no_grad()
@force_fp32()
def voxelize(self, points):
"""Apply dynamic voxelization to points.
Expand Down
2 changes: 2 additions & 0 deletions mmdet3d/models/detectors/mvx_faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet.models import DETECTORS
Expand All @@ -21,6 +22,7 @@ def __init__(self, **kwargs):
super(DynamicMVXFasterRCNN, self).__init__(**kwargs)

@torch.no_grad()
@force_fp32()
def voxelize(self, points):
"""Apply dynamic voxelization to points.
Expand Down
2 changes: 2 additions & 0 deletions mmdet3d/models/detectors/mvx_two_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import mmcv
import torch
from mmcv.parallel import DataContainer as DC
from mmcv.runner import force_fp32
from os import path as osp
from torch import nn as nn
from torch.nn import functional as F
Expand Down Expand Up @@ -203,6 +204,7 @@ def extract_feat(self, points, img, img_metas):
return (img_feats, pts_feats)

@torch.no_grad()
@force_fp32()
def voxelize(self, points):
"""Apply dynamic voxelization to points.
Expand Down
2 changes: 2 additions & 0 deletions mmdet3d/models/detectors/voxelnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
Expand Down Expand Up @@ -46,6 +47,7 @@ def extract_feat(self, points, img_metas):
return x

@torch.no_grad()
@force_fp32()
def voxelize(self, points):
"""Apply hard voxelization to points."""
voxels, coors, num_points = [], [], []
Expand Down
3 changes: 3 additions & 0 deletions mmdet3d/models/middle_encoders/pillar_scatter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import auto_fp16
from torch import nn

from ..registry import MIDDLE_ENCODERS
Expand All @@ -21,7 +22,9 @@ def __init__(self, in_channels, output_shape):
self.ny = output_shape[0]
self.nx = output_shape[1]
self.in_channels = in_channels
self.fp16_enabled = False

@auto_fp16(apply_to=('voxel_features', ))
def forward(self, voxel_features, coors, batch_size=None):
"""Foraward function to scatter features."""
# TODO: rewrite the function in a batch manner
Expand Down
Loading

0 comments on commit e67b3f8

Please sign in to comment.