diff --git a/configs/fp16/README.md b/configs/fp16/README.md new file mode 100644 index 0000000000..68bebca919 --- /dev/null +++ b/configs/fp16/README.md @@ -0,0 +1,23 @@ +# Mixed Precision Training + +## Introduction + +We implement mixed precision training and apply it to VoxelNets (e.g., SECOND and PointPillars). +The results are in the following tables. + +**Note**: For mixed precision training, we currently do not support PointNet-based methods (e.g., VoteNet). +Mixed precision training for PointNet-based methods will be supported in a future release. + +## Results + +### SECOND on KITTI dataset +| Backbone |Class| Lr schd | FP32 Mem (GB) | FP16 Mem (GB) | FP32 mAP | FP16 mAP |Download | +| :---------: | :-----: | :------: | :------------: | :----: |:----: | :------: | :------: | +| [SECFPN](./hv_second_secfpn_fp16_6x8_80e_kitti-3d-car.py)| Car |cyclic 80e|5.4|2.9|79.07|78.72|| +| [SECFPN](./hv_second_secfpn_fp16_6x8_80e_kitti-3d-3class.py)| 3 Class |cyclic 80e|5.4|2.9|64.41|67.4|| + +### PointPillars on nuScenes dataset + + +**Note**: With mixed precision training, we can train PointPillars with RegNet-400mf on 8 Titan XP GPUs with batch size of 2. +This will cause OOM error without mixed precision training. diff --git a/configs/fp16/hv_pointpillars_fpn_sbn-all_fp16_2x8_2x_nus-3d.py b/configs/fp16/hv_pointpillars_fpn_sbn-all_fp16_2x8_2x_nus-3d.py new file mode 100644 index 0000000000..d55fd939c3 --- /dev/null +++ b/configs/fp16/hv_pointpillars_fpn_sbn-all_fp16_2x8_2x_nus-3d.py @@ -0,0 +1,4 @@ +_base_ = '../pointpillars/hv_pointpillars_fpn_sbn-all_4x8_2x_nus-3d.py' +data = dict(samples_per_gpu=2, workers_per_gpu=2) +# fp16 settings +fp16 = dict(loss_scale=512.) 
diff --git a/configs/fp16/hv_pointpillars_regnet-400mf_fpn_sbn-all_fp16_2x8_2x_nus-3d.py b/configs/fp16/hv_pointpillars_regnet-400mf_fpn_sbn-all_fp16_2x8_2x_nus-3d.py new file mode 100644 index 0000000000..5baf282c05 --- /dev/null +++ b/configs/fp16/hv_pointpillars_regnet-400mf_fpn_sbn-all_fp16_2x8_2x_nus-3d.py @@ -0,0 +1,4 @@ +_base_ = '../regnet/hv_pointpillars_regnet-400mf_fpn_sbn-all_4x8_2x_nus-3d.py' +data = dict(samples_per_gpu=2, workers_per_gpu=2) +# fp16 settings +fp16 = dict(loss_scale=512.) diff --git a/configs/fp16/hv_pointpillars_secfpn_sbn-all_fp16_2x8_2x_nus-3d.py b/configs/fp16/hv_pointpillars_secfpn_sbn-all_fp16_2x8_2x_nus-3d.py new file mode 100644 index 0000000000..1c269590cf --- /dev/null +++ b/configs/fp16/hv_pointpillars_secfpn_sbn-all_fp16_2x8_2x_nus-3d.py @@ -0,0 +1,4 @@ +_base_ = '../pointpillars/hv_pointpillars_secfpn_sbn-all_4x8_2x_nus-3d.py' +data = dict(samples_per_gpu=2, workers_per_gpu=2) +# fp16 settings +fp16 = dict(loss_scale=512.) diff --git a/configs/fp16/hv_second_secfpn_fp16_6x8_80e_kitti-3d-3class.py b/configs/fp16/hv_second_secfpn_fp16_6x8_80e_kitti-3d-3class.py new file mode 100644 index 0000000000..0632a04842 --- /dev/null +++ b/configs/fp16/hv_second_secfpn_fp16_6x8_80e_kitti-3d-3class.py @@ -0,0 +1,3 @@ +_base_ = '../second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py' +# fp16 settings +fp16 = dict(loss_scale=512.) diff --git a/configs/fp16/hv_second_secfpn_fp16_6x8_80e_kitti-3d-car.py b/configs/fp16/hv_second_secfpn_fp16_6x8_80e_kitti-3d-car.py new file mode 100644 index 0000000000..a2aae1518e --- /dev/null +++ b/configs/fp16/hv_second_secfpn_fp16_6x8_80e_kitti-3d-car.py @@ -0,0 +1,3 @@ +_base_ = '../second/hv_second_secfpn_6x8_80e_kitti-3d-car.py' +# fp16 settings +fp16 = dict(loss_scale=512.) 
diff --git a/mmdet3d/models/backbones/base_pointnet.py b/mmdet3d/models/backbones/base_pointnet.py index 504d9f814b..5330c91543 100644 --- a/mmdet3d/models/backbones/base_pointnet.py +++ b/mmdet3d/models/backbones/base_pointnet.py @@ -8,6 +8,7 @@ class BasePointNet(nn.Module, metaclass=ABCMeta): def __init__(self): super(BasePointNet, self).__init__() + self.fp16_enabled = False def init_weights(self, pretrained=None): """Initialize the weights of PointNet backbone.""" diff --git a/mmdet3d/models/backbones/multi_backbone.py b/mmdet3d/models/backbones/multi_backbone.py index 9e115d1e4f..77180fbfd7 100644 --- a/mmdet3d/models/backbones/multi_backbone.py +++ b/mmdet3d/models/backbones/multi_backbone.py @@ -1,7 +1,7 @@ import copy import torch from mmcv.cnn import ConvModule -from mmcv.runner import load_checkpoint +from mmcv.runner import auto_fp16, load_checkpoint from torch import nn as nn from mmdet.models import BACKBONES, build_backbone @@ -86,6 +86,7 @@ def init_weights(self, pretrained=None): logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) + @auto_fp16() def forward(self, points): """Forward pass. diff --git a/mmdet3d/models/backbones/pointnet2_sa_msg.py b/mmdet3d/models/backbones/pointnet2_sa_msg.py index cbda996e2c..fff9d344c4 100644 --- a/mmdet3d/models/backbones/pointnet2_sa_msg.py +++ b/mmdet3d/models/backbones/pointnet2_sa_msg.py @@ -1,5 +1,6 @@ import torch from mmcv.cnn import ConvModule +from mmcv.runner import auto_fp16 from torch import nn as nn from mmdet3d.ops import build_sa_module @@ -111,6 +112,7 @@ def __init__(self, bias=True)) sa_in_channel = aggregation_channels[sa_index] + @auto_fp16(apply_to=('points', )) def forward(self, points): """Forward pass. 
diff --git a/mmdet3d/models/backbones/pointnet2_sa_ssg.py b/mmdet3d/models/backbones/pointnet2_sa_ssg.py index a215d8eb2a..a5c5d200f7 100644 --- a/mmdet3d/models/backbones/pointnet2_sa_ssg.py +++ b/mmdet3d/models/backbones/pointnet2_sa_ssg.py @@ -1,4 +1,5 @@ import torch +from mmcv.runner import auto_fp16 from torch import nn as nn from mmdet3d.ops import PointFPModule, build_sa_module @@ -83,6 +84,7 @@ def __init__(self, fp_source_channel = cur_fp_mlps[-1] fp_target_channel = skip_channel_list.pop() + @auto_fp16(apply_to=('points', )) def forward(self, points): """Forward pass. diff --git a/mmdet3d/models/dense_heads/anchor3d_head.py b/mmdet3d/models/dense_heads/anchor3d_head.py index 275a0e1e82..eed6f05249 100644 --- a/mmdet3d/models/dense_heads/anchor3d_head.py +++ b/mmdet3d/models/dense_heads/anchor3d_head.py @@ -1,6 +1,7 @@ import numpy as np import torch from mmcv.cnn import bias_init_with_prob, normal_init +from mmcv.runner import force_fp32 from torch import nn as nn from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period, @@ -79,6 +80,7 @@ def __init__(self, self.assign_per_class = assign_per_class self.dir_offset = dir_offset self.dir_limit_offset = dir_limit_offset + self.fp16_enabled = False # build anchor generator self.anchor_generator = build_anchor_generator(anchor_generator) @@ -211,39 +213,61 @@ def loss_single(self, cls_score, bbox_pred, dir_cls_preds, labels, labels = labels.reshape(-1) label_weights = label_weights.reshape(-1) cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.num_classes) + assert labels.max().item() <= self.num_classes loss_cls = self.loss_cls( cls_score, labels, label_weights, avg_factor=num_total_samples) # regression loss + bbox_pred = bbox_pred.permute(0, 2, 3, + 1).reshape(-1, self.box_code_size) bbox_targets = bbox_targets.reshape(-1, self.box_code_size) bbox_weights = bbox_weights.reshape(-1, self.box_code_size) - code_weight = self.train_cfg.get('code_weight', None) - if code_weight: - 
bbox_weights = bbox_weights * bbox_weights.new_tensor(code_weight) - bbox_pred = bbox_pred.permute(0, 2, 3, - 1).reshape(-1, self.box_code_size) - if self.diff_rad_by_sin: - bbox_pred, bbox_targets = self.add_sin_difference( - bbox_pred, bbox_targets) - loss_bbox = self.loss_bbox( - bbox_pred, - bbox_targets, - bbox_weights, - avg_factor=num_total_samples) - - # direction classification loss - loss_dir = None + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().reshape(-1) + num_pos = len(pos_inds) + + pos_bbox_pred = bbox_pred[pos_inds] + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_weights = bbox_weights[pos_inds] + + # dir loss if self.use_direction_classifier: dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).reshape(-1, 2) dir_targets = dir_targets.reshape(-1) dir_weights = dir_weights.reshape(-1) - loss_dir = self.loss_dir( - dir_cls_preds, - dir_targets, - dir_weights, + pos_dir_cls_preds = dir_cls_preds[pos_inds] + pos_dir_targets = dir_targets[pos_inds] + pos_dir_weights = dir_weights[pos_inds] + + if num_pos > 0: + code_weight = self.train_cfg.get('code_weight', None) + if code_weight: + pos_bbox_weights = pos_bbox_weights * \ + pos_bbox_weights.new_tensor(code_weight) + if self.diff_rad_by_sin: + pos_bbox_pred, pos_bbox_targets = self.add_sin_difference( + pos_bbox_pred, pos_bbox_targets) + loss_bbox = self.loss_bbox( + pos_bbox_pred, + pos_bbox_targets, + pos_bbox_weights, avg_factor=num_total_samples) + # direction classification loss + loss_dir = None + if self.use_direction_classifier: + loss_dir = self.loss_dir( + pos_dir_cls_preds, + pos_dir_targets, + pos_dir_weights, + avg_factor=num_total_samples) + else: + loss_bbox = pos_bbox_pred.sum() + if self.use_direction_classifier: + loss_dir = pos_dir_cls_preds.sum() + return loss_cls, loss_bbox, loss_dir @staticmethod @@ -270,6 +294,7 @@ def add_sin_difference(boxes1, boxes2): dim=-1) return boxes1, boxes2 + @force_fp32(apply_to=('cls_scores', 'bbox_preds', 
'dir_cls_preds')) def loss(self, cls_scores, bbox_preds, diff --git a/mmdet3d/models/dense_heads/centerpoint_head.py b/mmdet3d/models/dense_heads/centerpoint_head.py index 33eb80f7da..3e765998bf 100644 --- a/mmdet3d/models/dense_heads/centerpoint_head.py +++ b/mmdet3d/models/dense_heads/centerpoint_head.py @@ -2,6 +2,7 @@ import numpy as np import torch from mmcv.cnn import ConvModule, build_conv_layer, kaiming_init +from mmcv.runner import force_fp32 from torch import nn from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius, @@ -228,7 +229,7 @@ def forward(self, x): return ret -@HEADS.register_module +@HEADS.register_module() class CenterHead(nn.Module): """CenterHead for CenterPoint. @@ -292,6 +293,7 @@ def __init__(self, self.loss_bbox = build_loss(loss_bbox) self.bbox_coder = build_bbox_coder(bbox_coder) self.num_anchor_per_locs = [n for n in num_classes] + self.fp16_enabled = False # a shared convolution self.shared_conv = ConvModule( @@ -548,6 +550,7 @@ def get_targets_single(self, gt_bboxes_3d, gt_labels_3d): inds.append(ind) return heatmaps, anno_boxes, inds, masks + @force_fp32(apply_to=('preds_dicts', )) def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs): """Loss function for CenterHead. 
diff --git a/mmdet3d/models/dense_heads/free_anchor3d_head.py b/mmdet3d/models/dense_heads/free_anchor3d_head.py index 91f8379d7f..633c635901 100644 --- a/mmdet3d/models/dense_heads/free_anchor3d_head.py +++ b/mmdet3d/models/dense_heads/free_anchor3d_head.py @@ -1,4 +1,5 @@ import torch +from mmcv.runner import force_fp32 from torch.nn import functional as F from mmdet3d.core.bbox import bbox_overlaps_nearest_3d @@ -38,6 +39,7 @@ def __init__(self, self.gamma = gamma self.alpha = alpha + @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds')) def loss(self, cls_scores, bbox_preds, diff --git a/mmdet3d/models/dense_heads/parta2_rpn_head.py b/mmdet3d/models/dense_heads/parta2_rpn_head.py index adbbaf4dd7..d45d9a5508 100644 --- a/mmdet3d/models/dense_heads/parta2_rpn_head.py +++ b/mmdet3d/models/dense_heads/parta2_rpn_head.py @@ -2,6 +2,7 @@ import numpy as np import torch +from mmcv.runner import force_fp32 from mmdet3d.core import limit_period, xywhr2xyxyr from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu @@ -81,6 +82,7 @@ def __init__(self, diff_rad_by_sin, dir_offset, dir_limit_offset, bbox_coder, loss_cls, loss_bbox, loss_dir) + @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds')) def loss(self, cls_scores, bbox_preds, diff --git a/mmdet3d/models/dense_heads/ssd_3d_head.py b/mmdet3d/models/dense_heads/ssd_3d_head.py index 8677a3b371..5a8d805aa7 100644 --- a/mmdet3d/models/dense_heads/ssd_3d_head.py +++ b/mmdet3d/models/dense_heads/ssd_3d_head.py @@ -1,5 +1,6 @@ import torch from mmcv.ops.nms import batched_nms +from mmcv.runner import force_fp32 from torch.nn import functional as F from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes, @@ -108,6 +109,7 @@ def _extract_input(self, feat_dict): return seed_points, seed_features, seed_indices + @force_fp32(apply_to=('bbox_preds', )) def loss(self, bbox_preds, points, diff --git a/mmdet3d/models/dense_heads/vote_head.py b/mmdet3d/models/dense_heads/vote_head.py index 
2ad12ef16c..ed687908ac 100644 --- a/mmdet3d/models/dense_heads/vote_head.py +++ b/mmdet3d/models/dense_heads/vote_head.py @@ -1,5 +1,6 @@ import numpy as np import torch +from mmcv.runner import force_fp32 from torch import nn as nn from torch.nn import functional as F @@ -78,6 +79,7 @@ def __init__(self, self.vote_module = VoteModule(**vote_module_cfg) self.vote_aggregation = build_sa_module(vote_aggregation_cfg) + self.fp16_enabled = False # Bbox classification and regression self.conv_pred = BaseConvBboxHead( @@ -204,6 +206,7 @@ def forward(self, feat_dict, sample_mod): return results + @force_fp32(apply_to=('bbox_preds', )) def loss(self, bbox_preds, points, diff --git a/mmdet3d/models/detectors/base.py b/mmdet3d/models/detectors/base.py index e0cdb16ac7..d77fbd0ed7 100644 --- a/mmdet3d/models/detectors/base.py +++ b/mmdet3d/models/detectors/base.py @@ -2,6 +2,7 @@ import mmcv import torch from mmcv.parallel import DataContainer as DC +from mmcv.runner import auto_fp16 from os import path as osp from mmdet3d.core import Box3DMode, show_result @@ -42,6 +43,7 @@ def forward_test(self, points, img_metas, img=None, **kwargs): else: return self.aug_test(points, img_metas, img, **kwargs) + @auto_fp16(apply_to=('img', 'points')) def forward(self, return_loss=True, **kwargs): """Calls either forward_train or forward_test depending on whether return_loss=True. diff --git a/mmdet3d/models/detectors/dynamic_voxelnet.py b/mmdet3d/models/detectors/dynamic_voxelnet.py index 68e8b9df41..a7241ac43e 100644 --- a/mmdet3d/models/detectors/dynamic_voxelnet.py +++ b/mmdet3d/models/detectors/dynamic_voxelnet.py @@ -1,4 +1,5 @@ import torch +from mmcv.runner import force_fp32 from torch.nn import functional as F from mmdet.models import DETECTORS @@ -44,6 +45,7 @@ def extract_feat(self, points, img_metas): return x @torch.no_grad() + @force_fp32() def voxelize(self, points): """Apply dynamic voxelization to points. 
diff --git a/mmdet3d/models/detectors/mvx_faster_rcnn.py b/mmdet3d/models/detectors/mvx_faster_rcnn.py index 097439a17c..40d6ade489 100644 --- a/mmdet3d/models/detectors/mvx_faster_rcnn.py +++ b/mmdet3d/models/detectors/mvx_faster_rcnn.py @@ -1,4 +1,5 @@ import torch +from mmcv.runner import force_fp32 from torch.nn import functional as F from mmdet.models import DETECTORS @@ -21,6 +22,7 @@ def __init__(self, **kwargs): super(DynamicMVXFasterRCNN, self).__init__(**kwargs) @torch.no_grad() + @force_fp32() def voxelize(self, points): """Apply dynamic voxelization to points. diff --git a/mmdet3d/models/detectors/mvx_two_stage.py b/mmdet3d/models/detectors/mvx_two_stage.py index cad160f601..1a160ba2f0 100644 --- a/mmdet3d/models/detectors/mvx_two_stage.py +++ b/mmdet3d/models/detectors/mvx_two_stage.py @@ -2,6 +2,7 @@ import mmcv import torch from mmcv.parallel import DataContainer as DC +from mmcv.runner import force_fp32 from os import path as osp from torch import nn as nn from torch.nn import functional as F @@ -203,6 +204,7 @@ def extract_feat(self, points, img, img_metas): return (img_feats, pts_feats) @torch.no_grad() + @force_fp32() def voxelize(self, points): """Apply dynamic voxelization to points. 
diff --git a/mmdet3d/models/detectors/voxelnet.py b/mmdet3d/models/detectors/voxelnet.py index 67f97e5090..1b4841582c 100644 --- a/mmdet3d/models/detectors/voxelnet.py +++ b/mmdet3d/models/detectors/voxelnet.py @@ -1,4 +1,5 @@ import torch +from mmcv.runner import force_fp32 from torch.nn import functional as F from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d @@ -46,6 +47,7 @@ def extract_feat(self, points, img_metas): return x @torch.no_grad() + @force_fp32() def voxelize(self, points): """Apply hard voxelization to points.""" voxels, coors, num_points = [], [], [] diff --git a/mmdet3d/models/middle_encoders/pillar_scatter.py b/mmdet3d/models/middle_encoders/pillar_scatter.py index e02d61a7a8..0141d13351 100644 --- a/mmdet3d/models/middle_encoders/pillar_scatter.py +++ b/mmdet3d/models/middle_encoders/pillar_scatter.py @@ -1,4 +1,5 @@ import torch +from mmcv.runner import auto_fp16 from torch import nn from ..registry import MIDDLE_ENCODERS @@ -21,7 +22,9 @@ def __init__(self, in_channels, output_shape): self.ny = output_shape[0] self.nx = output_shape[1] self.in_channels = in_channels + self.fp16_enabled = False + @auto_fp16(apply_to=('voxel_features', )) def forward(self, voxel_features, coors, batch_size=None): """Foraward function to scatter features.""" # TODO: rewrite the function in a batch manner diff --git a/mmdet3d/models/middle_encoders/sparse_encoder.py b/mmdet3d/models/middle_encoders/sparse_encoder.py index 00462aad01..c49d47d797 100644 --- a/mmdet3d/models/middle_encoders/sparse_encoder.py +++ b/mmdet3d/models/middle_encoders/sparse_encoder.py @@ -1,3 +1,4 @@ +from mmcv.runner import auto_fp16 from torch import nn as nn from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule @@ -49,6 +50,7 @@ def __init__(self, self.encoder_channels = encoder_channels self.encoder_paddings = encoder_paddings self.stage_num = len(self.encoder_channels) + self.fp16_enabled = False # Spconv init all weight on its own assert isinstance(order, tuple) 
and len(order) == 3 @@ -90,6 +92,7 @@ def __init__(self, indice_key='spconv_down2', conv_type='SparseConv3d') + @auto_fp16(apply_to=('voxel_features', )) def forward(self, voxel_features, coors, batch_size): """Forward of SparseEncoder. diff --git a/mmdet3d/models/middle_encoders/sparse_unet.py b/mmdet3d/models/middle_encoders/sparse_unet.py index 469a878575..201da05757 100644 --- a/mmdet3d/models/middle_encoders/sparse_unet.py +++ b/mmdet3d/models/middle_encoders/sparse_unet.py @@ -1,4 +1,5 @@ import torch +from mmcv.runner import auto_fp16 from torch import nn as nn from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule @@ -51,6 +52,7 @@ def __init__(self, self.decoder_channels = decoder_channels self.decoder_paddings = decoder_paddings self.stage_num = len(self.encoder_channels) + self.fp16_enabled = False # Spconv init all weight on its own assert isinstance(order, tuple) and len(order) == 3 @@ -91,6 +93,7 @@ def __init__(self, indice_key='spconv_down2', conv_type='SparseConv3d') + @auto_fp16(apply_to=('voxel_features', )) def forward(self, voxel_features, coors, batch_size): """Forward of SparseUNet. diff --git a/mmdet3d/models/necks/second_fpn.py b/mmdet3d/models/necks/second_fpn.py index ec9d40ee5c..4acd318aa7 100644 --- a/mmdet3d/models/necks/second_fpn.py +++ b/mmdet3d/models/necks/second_fpn.py @@ -2,6 +2,7 @@ import torch from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer, constant_init, is_norm, kaiming_init) +from mmcv.runner import auto_fp16 from torch import nn as nn from mmdet.models import NECKS @@ -36,6 +37,7 @@ def __init__(self, assert len(out_channels) == len(upsample_strides) == len(in_channels) self.in_channels = in_channels self.out_channels = out_channels + self.fp16_enabled = False deblocks = [] for i, out_channel in enumerate(out_channels): @@ -70,6 +72,7 @@ def init_weights(self): elif is_norm(m): constant_init(m, 1) + @auto_fp16() def forward(self, x): """Forward function. 
diff --git a/mmdet3d/models/voxel_encoders/pillar_encoder.py b/mmdet3d/models/voxel_encoders/pillar_encoder.py index 3411260971..f69a60bf55 100644 --- a/mmdet3d/models/voxel_encoders/pillar_encoder.py +++ b/mmdet3d/models/voxel_encoders/pillar_encoder.py @@ -1,5 +1,6 @@ import torch from mmcv.cnn import build_norm_layer +from mmcv.runner import force_fp32 from torch import nn from mmdet3d.ops import DynamicScatter @@ -58,7 +59,7 @@ def __init__(self, self._with_distance = with_distance self._with_cluster_center = with_cluster_center self._with_voxel_center = with_voxel_center - + self.fp16_enabled = False # Create PillarFeatureNet layers self.in_channels = in_channels feat_channels = [in_channels] + list(feat_channels) @@ -86,6 +87,7 @@ def __init__(self, self.y_offset = self.vy / 2 + point_cloud_range[1] self.point_cloud_range = point_cloud_range + @force_fp32(out_fp16=True) def forward(self, features, num_points, coors): """Forward function. @@ -196,7 +198,7 @@ def __init__(self, point_cloud_range=point_cloud_range, norm_cfg=norm_cfg, mode=mode) - + self.fp16_enabled = False feat_channels = [self.in_channels] + list(feat_channels) pfn_layers = [] # TODO: currently only support one PFNLayer @@ -257,6 +259,7 @@ def map_voxel_center_to_point(self, pts_coors, voxel_mean, voxel_coors): center_per_point = canvas[:, voxel_index.long()].t() return center_per_point + @force_fp32(out_fp16=True) def forward(self, features, coors): """Forward function. 
diff --git a/mmdet3d/models/voxel_encoders/utils.py b/mmdet3d/models/voxel_encoders/utils.py index 280a6a08ea..68f8fdd79d 100644 --- a/mmdet3d/models/voxel_encoders/utils.py +++ b/mmdet3d/models/voxel_encoders/utils.py @@ -1,5 +1,6 @@ import torch from mmcv.cnn import build_norm_layer +from mmcv.runner import auto_fp16 from torch import nn from torch.nn import functional as F @@ -51,6 +52,7 @@ def __init__(self, max_out=True, cat_max=True): super(VFELayer, self).__init__() + self.fp16_enabled = False self.cat_max = cat_max self.max_out = max_out # self.units = int(out_channels / 2) @@ -58,6 +60,7 @@ def __init__(self, self.norm = build_norm_layer(norm_cfg, out_channels)[1] self.linear = nn.Linear(in_channels, out_channels, bias=False) + @auto_fp16(apply_to=('inputs', ), out_fp32=True) def forward(self, inputs): """Forward function. @@ -78,6 +81,7 @@ def forward(self, inputs): """ # [K, T, 7] tensordot [7, units] = [K, T, units] voxel_count = inputs.shape[1] + x = self.linear(inputs) x = self.norm(x.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous() @@ -123,6 +127,7 @@ def __init__(self, mode='max'): super().__init__() + self.fp16_enabled = False self.name = 'PFNLayer' self.last_vfe = last_layer if not self.last_vfe: @@ -135,6 +140,7 @@ def __init__(self, assert mode in ['max', 'avg'] self.mode = mode + @auto_fp16(apply_to=('inputs', ), out_fp32=True) def forward(self, inputs, num_voxels=None, aligned_distance=None): """Forward function. 
diff --git a/mmdet3d/models/voxel_encoders/voxel_encoder.py b/mmdet3d/models/voxel_encoders/voxel_encoder.py index b647eb608f..9dfd59ede6 100644 --- a/mmdet3d/models/voxel_encoders/voxel_encoder.py +++ b/mmdet3d/models/voxel_encoders/voxel_encoder.py @@ -1,5 +1,6 @@ import torch from mmcv.cnn import build_norm_layer +from mmcv.runner import force_fp32 from torch import nn from mmdet3d.ops import DynamicScatter @@ -21,7 +22,9 @@ class HardSimpleVFE(nn.Module): def __init__(self, num_features=4): super(HardSimpleVFE, self).__init__() self.num_features = num_features + self.fp16_enabled = False + @force_fp32(out_fp16=True) def forward(self, features, num_points, coors): """Forward function. @@ -58,8 +61,10 @@ def __init__(self, point_cloud_range=(0, -40, -3, 70.4, 40, 1)): super(DynamicSimpleVFE, self).__init__() self.scatter = DynamicScatter(voxel_size, point_cloud_range, True) + self.fp16_enabled = False @torch.no_grad() + @force_fp32(out_fp16=True) def forward(self, features, coors): """Forward function. @@ -134,6 +139,7 @@ def __init__(self, self._with_cluster_center = with_cluster_center self._with_voxel_center = with_voxel_center self.return_point_feats = return_point_feats + self.fp16_enabled = False # Need pillar (voxel) size and x/y offset in order to calculate offset self.vx = voxel_size[0] @@ -209,6 +215,7 @@ def map_voxel_center_to_point(self, pts_coors, voxel_mean, voxel_coors): center_per_point = voxel_mean[voxel_inds, ...] 
return center_per_point + @force_fp32(out_fp16=True) def forward(self, features, coors, @@ -330,6 +337,7 @@ def __init__(self, self._with_cluster_center = with_cluster_center self._with_voxel_center = with_voxel_center self.return_point_feats = return_point_feats + self.fp16_enabled = False # Need pillar (voxel) size and x/y offset to calculate pillar offset self.vx = voxel_size[0] @@ -372,6 +380,7 @@ def __init__(self, if fusion_layer is not None: self.fusion_layer = builder.build_fusion_layer(fusion_layer) + @force_fp32(out_fp16=True) def forward(self, features, num_points, diff --git a/mmdet3d/ops/furthest_point_sample/points_sampler.py b/mmdet3d/ops/furthest_point_sample/points_sampler.py index 32645be999..4c962517f3 100644 --- a/mmdet3d/ops/furthest_point_sample/points_sampler.py +++ b/mmdet3d/ops/furthest_point_sample/points_sampler.py @@ -1,4 +1,5 @@ import torch +from mmcv.runner import force_fp32 from torch import nn as nn from typing import List @@ -59,7 +60,9 @@ def __init__(self, self.samplers = nn.ModuleList() for fps_mod in fps_mod_list: self.samplers.append(get_sampler_type(fps_mod)()) + self.fp16_enabled = False + @force_fp32() def forward(self, points_xyz, features): """forward. diff --git a/mmdet3d/ops/norm.py b/mmdet3d/ops/norm.py index f8aae10a8c..e9db8fb579 100644 --- a/mmdet3d/ops/norm.py +++ b/mmdet3d/ops/norm.py @@ -1,5 +1,6 @@ import torch from mmcv.cnn import NORM_LAYERS +from mmcv.runner import force_fp32 from torch import distributed as dist from torch import nn as nn from torch.autograd.function import Function @@ -42,10 +43,19 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d): It is slower than `nn.SyncBatchNorm`. 
""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fp16_enabled = False + + # customized normalization layer still needs this decorator + # to force the input to be fp32 and the output to be fp16 + # TODO: make mmcv fp16 utils handle customized norm layers + @force_fp32(out_fp16=True) def forward(self, input): + assert input.dtype == torch.float32, \ + f'input should be in float32 type, got {input.dtype}' if dist.get_world_size() == 1 or not self.training: return super().forward(input) - assert input.shape[0] > 0, 'SyncBN does not support empty inputs' C = input.shape[1] mean = torch.mean(input, dim=[0, 2]) @@ -87,7 +97,17 @@ class NaiveSyncBatchNorm2d(nn.BatchNorm2d): It is slower than `nn.SyncBatchNorm`. """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fp16_enabled = False + + # customized normalization layer still needs this decorator + # to force the input to be fp32 and the output to be fp16 + # TODO: make mmcv fp16 utils handle customized norm layers + @force_fp32(out_fp16=True) def forward(self, input): + assert input.dtype == torch.float32, \ + f'input should be in float32 type, got {input.dtype}' if dist.get_world_size() == 1 or not self.training: return super().forward(input) diff --git a/mmdet3d/ops/pointnet_modules/point_fp_module.py b/mmdet3d/ops/pointnet_modules/point_fp_module.py index cc46ce2a8c..eb9414d1c0 100644 --- a/mmdet3d/ops/pointnet_modules/point_fp_module.py +++ b/mmdet3d/ops/pointnet_modules/point_fp_module.py @@ -1,5 +1,6 @@ import torch from mmcv.cnn import ConvModule +from mmcv.runner import force_fp32 from torch import nn as nn from typing import List @@ -21,7 +22,7 @@ def __init__(self, mlp_channels: List[int], norm_cfg: dict = dict(type='BN2d')): super().__init__() - + self.fp16_enabled = False self.mlps = nn.Sequential() for i in range(len(mlp_channels) - 1): self.mlps.add_module( @@ -34,6 +35,7 @@ def __init__(self, conv_cfg=dict(type='Conv2d'), 
norm_cfg=norm_cfg)) + @force_fp32() def forward(self, target: torch.Tensor, source: torch.Tensor, target_feats: torch.Tensor, source_feats: torch.Tensor) -> torch.Tensor: diff --git a/mmdet3d/ops/pointnet_modules/point_sa_module.py b/mmdet3d/ops/pointnet_modules/point_sa_module.py index b0ebd21ddc..f23e536683 100644 --- a/mmdet3d/ops/pointnet_modules/point_sa_module.py +++ b/mmdet3d/ops/pointnet_modules/point_sa_module.py @@ -145,7 +145,6 @@ def forward( """ new_features_list = [] xyz_flipped = points_xyz.transpose(1, 2).contiguous() - if indices is not None: assert (indices.shape[1] == self.num_point[0]) new_xyz = gather_points(xyz_flipped, indices).transpose(