diff --git a/.dev_scripts/gather_models.py b/.dev_scripts/gather_models.py index 38270a73b8..c4f355db00 100644 --- a/.dev_scripts/gather_models.py +++ b/.dev_scripts/gather_models.py @@ -3,6 +3,16 @@ Usage: python gather_models.py ${root_path} ${out_dir} + +Example: +python gather_models.py \ +work_dirs/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d \ +work_dirs/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d + +Note that before running the above command, rename the directory with the +config name if you did not use the default directory name, create +a corresponding directory 'pgd' under the above path and put the used config +into it. """ import argparse @@ -36,16 +46,18 @@ RESULTS_LUT = { 'coco': ['bbox_mAP', 'segm_mAP'], 'nus': ['pts_bbox_NuScenes/NDS', 'NDS'], - 'kitti-3d-3class': [ - 'KITTI/Overall_3D_moderate', - 'Overall_3D_moderate', - ], + 'kitti-3d-3class': ['KITTI/Overall_3D_moderate', 'Overall_3D_moderate'], 'kitti-3d-car': ['KITTI/Car_3D_moderate_strict', 'Car_3D_moderate_strict'], 'lyft': ['score'], 'scannet_seg': ['miou'], 's3dis_seg': ['miou'], 'scannet': ['mAP_0.50'], - 'sunrgbd': ['mAP_0.50'] + 'sunrgbd': ['mAP_0.50'], + 'kitti-mono3d': [ + 'img_bbox/KITTI/Car_3D_AP40_moderate_strict', + 'Car_3D_AP40_moderate_strict' + ], + 'nus-mono3d': ['img_bbox_NuScenes/NDS', 'NDS'] } @@ -145,15 +157,13 @@ def main(): # and parse the best performance model_infos = [] for used_config in used_configs: - exp_dir = osp.join(models_root, used_config) - # get logs - log_json_path = glob.glob(osp.join(exp_dir, '*.log.json'))[0] - log_txt_path = glob.glob(osp.join(exp_dir, '*.log'))[0] + log_json_path = glob.glob(osp.join(models_root, '*.log.json'))[0] + log_txt_path = glob.glob(osp.join(models_root, '*.log'))[0] model_performance = get_best_results(log_json_path) final_epoch = model_performance['epoch'] final_model = 'epoch_{}.pth'.format(final_epoch) - model_path = osp.join(exp_dir, final_model) + model_path = osp.join(models_root, final_model) # skip if the model is still training if not osp.exists(model_path): @@ -182,7 +192,7 @@ def main(): model_name = model['config'].split('/')[-1].rstrip( '.py') + '_' + model['model_time'] publish_model_path = osp.join(model_publish_dir, model_name) - trained_model_path = osp.join(models_root, model['config'], + trained_model_path = osp.join(models_root, 'epoch_{}.pth'.format(model['epochs'])) # convert model @@ -191,11 +201,10 @@ def main(): # copy log shutil.copy( - osp.join(models_root, model['config'], model['log_json_path']), + osp.join(models_root, model['log_json_path']), osp.join(model_publish_dir, f'{model_name}.log.json')) shutil.copy( - osp.join(models_root, model['config'], - model['log_json_path'].rstrip('.json')), + osp.join(models_root, model['log_json_path'].rstrip('.json')), osp.join(model_publish_dir, f'{model_name}.log')) # copy config to guarantee reproducibility diff --git a/configs/_base_/models/fcos3d.py b/configs/_base_/models/fcos3d.py index a46ed9cd60..be83001d8f 100644 --- a/configs/_base_/models/fcos3d.py +++ b/configs/_base_/models/fcos3d.py @@ -1,6 +1,5 @@ model = dict( type='FCOSMono3D', - pretrained='open-mmlab://detectron2/resnet101_caffe', backbone=dict( type='ResNet', depth=101, @@ -9,7 +8,10 @@ frozen_stages=1, norm_cfg=dict(type='BN', requires_grad=False), norm_eval=True, - style='caffe'), + style='caffe', + init_cfg=dict( + type='Pretrained', + checkpoint='open-mmlab://detectron2/resnet101_caffe')), neck=dict( type='FPN', in_channels=[256, 512, 1024, 2048], diff --git a/configs/_base_/models/pgd.py 
b/configs/_base_/models/pgd.py new file mode 100644 index 0000000000..e63fc1fceb --- /dev/null +++ b/configs/_base_/models/pgd.py @@ -0,0 +1,55 @@ +_base_ = './fcos3d.py' +# model settings +model = dict( + bbox_head=dict( + _delete_=True, + type='PGDHead', + num_classes=10, + in_channels=256, + stacked_convs=2, + feat_channels=256, + use_direction_classifier=True, + diff_rad_by_sin=True, + pred_attrs=True, + pred_velo=True, + pred_bbox2d=True, + pred_keypoints=False, + dir_offset=0.7854, # pi/4 + strides=[8, 16, 32, 64, 128], + group_reg_dims=(2, 1, 3, 1, 2), # offset, depth, size, rot, velo + cls_branch=(256, ), + reg_branch=( + (256, ), # offset + (256, ), # depth + (256, ), # size + (256, ), # rot + () # velo + ), + dir_branch=(256, ), + attr_branch=(256, ), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), + loss_dir=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_attr=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + norm_on_bbox=True, + centerness_on_reg=True, + center_sampling=True, + conv_bias=True, + dcn_on_last_conv=True, + use_depth_classifier=True, + depth_branch=(256, ), + depth_range=(0, 50), + depth_unit=10, + division='uniform', + depth_bins=6, + bbox_coder=dict(type='PGDBBoxCoder', code_size=9)), + test_cfg=dict(nms_pre=1000, nms_thr=0.8, score_thr=0.01, max_per_img=200)) diff --git a/configs/pgd/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d.py b/configs/pgd/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d.py new file mode 100644 index 0000000000..832b34e64d --- /dev/null +++ b/configs/pgd/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d.py @@ -0,0 +1,127 @@ +_base_ = [ + '../_base_/datasets/kitti-mono3d.py', '../_base_/models/pgd.py', + '../_base_/schedules/mmdet_schedule_1x.py', '../_base_/default_runtime.py' +] +# model settings +model = dict( + backbone=dict(frozen_stages=0), + neck=dict(start_level=0, num_outs=4), + bbox_head=dict( + num_classes=3, + bbox_code_size=7, + pred_attrs=False, + pred_velo=False, + pred_bbox2d=True, + use_onlyreg_proj=True, + strides=(4, 8, 16, 32), + regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 1e8)), + group_reg_dims=(2, 1, 3, 1, 16, + 4), # offset, depth, size, rot, kpts, bbox2d + reg_branch=( + (256, ), # offset + (256, ), # depth + (256, ), # size + (256, ), # rot + (256, ), # kpts + (256, ) # bbox2d + ), + centerness_branch=(256, ), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), + loss_dir=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + use_depth_classifier=True, + depth_branch=(256, ), + depth_range=(0, 70), + depth_unit=10, + division='uniform', + depth_bins=8, + pred_keypoints=True, + weight_dim=1, + loss_depth=dict( + type='UncertainSmoothL1Loss', alpha=1.0, beta=3.0, + loss_weight=1.0), + bbox_coder=dict( + type='PGDBBoxCoder', + base_depths=((28.01, 16.32), ), + base_dims=((0.8, 1.73, 0.6), (1.76, 1.73, 0.6), (3.9, 1.56, 1.6)), + code_size=7)), + # set weight 1.0 for base 7 dims (offset, depth, size, rot) + # 0.2 for 16-dim keypoint offsets and 1.0 for 4-dim 2D distance targets + train_cfg=dict(code_weight=[ + 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, + 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 1.0, 1.0, 1.0, 1.0 + ]), + test_cfg=dict(nms_pre=100, nms_thr=0.05, score_thr=0.001, max_per_img=20)) + +class_names = ['Pedestrian', 'Cyclist', 'Car'] +img_norm_cfg = dict( + mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) +train_pipeline = [ + dict(type='LoadImageFromFileMono3D'), + dict( + type='LoadAnnotations3D', + with_bbox=True, + with_label=True, + with_attr_label=False, + with_bbox_3d=True, + with_label_3d=True, + with_bbox_depth=True), + dict(type='Resize', img_scale=(1242, 375), keep_ratio=True), + dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict( + type='Collect3D', + keys=[ + 'img', 'gt_bboxes', 'gt_labels', 'gt_bboxes_3d', 'gt_labels_3d', + 'centers2d', 'depths' + ]), +] +test_pipeline = [ + dict(type='LoadImageFromFileMono3D'), + dict( + type='MultiScaleFlipAug', + scale_factor=1.0, + flip=False, + transforms=[ + dict(type='RandomFlip3D'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=3, + workers_per_gpu=3, + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) +# optimizer +optimizer = dict( + lr=0.001, paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.)) +optimizer_config = dict( + _delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[32, 44]) +total_epochs = 48 +runner = dict(type='EpochBasedRunner', max_epochs=48) +evaluation = dict(interval=2) +checkpoint_config = dict(interval=8) diff --git a/mmdet3d/core/bbox/coders/pgd_bbox_coder.py b/mmdet3d/core/bbox/coders/pgd_bbox_coder.py index 5e42128189..70db450d25 100644 --- a/mmdet3d/core/bbox/coders/pgd_bbox_coder.py +++ b/mmdet3d/core/bbox/coders/pgd_bbox_coder.py @@ -1,4 +1,5 @@ import numpy as np +import torch from torch.nn import functional as F from mmdet.core.bbox.builder import BBOX_CODERS @@ -45,8 +46,9 @@ def decode_2d(self, scale_kpts = scale[3] # 2 dimension of offsets x 8 corners of a 3D bbox bbox[:, self.bbox_code_size:self.bbox_code_size + 16] = \ - scale_kpts(clone_bbox[ - :, self.bbox_code_size:self.bbox_code_size + 16]).float() + torch.tanh(scale_kpts(clone_bbox[ + :, self.bbox_code_size:self.bbox_code_size + 16]).float()) + if pred_bbox2d: scale_bbox2d = scale[-1] # The last four dimensions are offsets to four sides of a 2D bbox diff --git a/mmdet3d/models/dense_heads/__init__.py b/mmdet3d/models/dense_heads/__init__.py index 28a746d356..0a9e50d637 100644 --- a/mmdet3d/models/dense_heads/__init__.py +++ b/mmdet3d/models/dense_heads/__init__.py @@ -8,6 +8,7 @@ from .free_anchor3d_head import FreeAnchor3DHead from .groupfree3d_head import GroupFree3DHead from .parta2_rpn_head import PartA2RPNHead +from .pgd_head import PGDHead from .shape_aware_head import ShapeAwareHead from .smoke_mono3d_head import SMOKEMono3DHead from .ssd_3d_head import SSD3DHead @@ -17,5 +18,5 @@ 'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead', 'SSD3DHead', 'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead', 'BaseMono3DDenseHead', 'AnchorFreeMono3DHead', 
'FCOSMono3DHead', - 'GroupFree3DHead', 'SMOKEMono3DHead' + 'GroupFree3DHead', 'SMOKEMono3DHead', 'PGDHead' ] diff --git a/mmdet3d/models/dense_heads/anchor_free_mono3d_head.py b/mmdet3d/models/dense_heads/anchor_free_mono3d_head.py index 3f91c5e02a..8aff33b324 100644 --- a/mmdet3d/models/dense_heads/anchor_free_mono3d_head.py +++ b/mmdet3d/models/dense_heads/anchor_free_mono3d_head.py @@ -176,13 +176,6 @@ def __init__( self.attr_branch = attr_branch self._init_layers() - if init_cfg is None: - self.init_cfg = dict( - type='Normal', - layer='Conv2d', - std=0.01, - override=dict( - type='Normal', name='conv_cls', std=0.01, bias_prob=0.01)) def _init_layers(self): """Initialize layers of the head.""" @@ -288,8 +281,34 @@ def _init_predictor(self): self.conv_attr = nn.Conv2d(self.attr_branch[-1], self.num_attrs, 1) def init_weights(self): - super().init_weights() + """Initialize weights of the head. + + We currently still use the customized defined init_weights because the + default init of DCN triggered by the init_cfg will init + conv_offset.weight, which mistakenly affects the training stability. + """ + for modules in [self.cls_convs, self.reg_convs, self.conv_cls_prev]: + for m in modules: + if isinstance(m.conv, nn.Conv2d): + normal_init(m.conv, std=0.01) + for conv_reg_prev in self.conv_reg_prevs: + if conv_reg_prev is None: + continue + for m in conv_reg_prev: + if isinstance(m.conv, nn.Conv2d): + normal_init(m.conv, std=0.01) + if self.use_direction_classifier: + for m in self.conv_dir_cls_prev: + if isinstance(m.conv, nn.Conv2d): + normal_init(m.conv, std=0.01) + if self.pred_attrs: + for m in self.conv_attr_prev: + if isinstance(m.conv, nn.Conv2d): + normal_init(m.conv, std=0.01) bias_cls = bias_init_with_prob(0.01) + normal_init(self.conv_cls, std=0.01, bias=bias_cls) + for conv_reg in self.conv_regs: + normal_init(conv_reg, std=0.01) if self.use_direction_classifier: normal_init(self.conv_dir_cls, std=0.01, bias=bias_cls) if self.pred_attrs: diff --git a/mmdet3d/models/dense_heads/fcos_mono3d_head.py b/mmdet3d/models/dense_heads/fcos_mono3d_head.py index 83a7871f9e..ef1b53a02a 100644 --- a/mmdet3d/models/dense_heads/fcos_mono3d_head.py +++ b/mmdet3d/models/dense_heads/fcos_mono3d_head.py @@ -2,7 +2,7 @@ import numpy as np import torch from logging import warning -from mmcv.cnn import Scale +from mmcv.cnn import Scale, normal_init from mmcv.runner import force_fp32 from torch import nn as nn @@ -101,13 +101,6 @@ def __init__(self, self.loss_centerness = build_loss(loss_centerness) bbox_coder['code_size'] = self.bbox_code_size self.bbox_coder = build_bbox_coder(bbox_coder) - if init_cfg is None: - self.init_cfg = dict( - type='Normal', - layer='Conv2d', - std=0.01, - override=dict( - type='Normal', name='conv_cls', std=0.01, bias_prob=0.01)) def _init_layers(self): """Initialize layers of the head.""" @@ -122,6 +115,19 @@ def _init_layers(self): for _ in self.strides ]) + def init_weights(self): + """Initialize weights of the head. + + We currently still use the customized init_weights because the default + init of DCN triggered by the init_cfg will init conv_offset.weight, + which mistakenly affects the training stability. + """ + super().init_weights() + for m in self.conv_centerness_prev: + if isinstance(m.conv, nn.Conv2d): + normal_init(m.conv, std=0.01) + normal_init(self.conv_centerness, std=0.01) + def forward(self, feats): """Forward features from the upstream network. 
diff --git a/mmdet3d/models/dense_heads/pgd_head.py b/mmdet3d/models/dense_heads/pgd_head.py new file mode 100644 index 0000000000..f4a9e87bae --- /dev/null +++ b/mmdet3d/models/dense_heads/pgd_head.py @@ -0,0 +1,1242 @@ +import numpy as np +import torch +from mmcv.cnn import Scale, bias_init_with_prob, normal_init +from mmcv.runner import force_fp32 +from torch import nn as nn +from torch.nn import functional as F + +from mmdet3d.core import box3d_multiclass_nms, xywhr2xyxyr +from mmdet3d.core.bbox import points_cam2img, points_img2cam +from mmdet.core import distance2bbox, multi_apply +from mmdet.models.builder import HEADS, build_loss +from .fcos_mono3d_head import FCOSMono3DHead + + +@HEADS.register_module() +class PGDHead(FCOSMono3DHead): + r"""Anchor-free head used in `PGD `_. + + Args: + use_depth_classifer (bool, optional): Whether to use depth classifier. + Defaults to True. + use_only_reg_proj (bool, optional): Whether to use only direct + regressed depth in the re-projection (to make the network easier + to learn). Defaults to False. + weight_dim (int, optional): Dimension of the location-aware weight + map. Defaults to -1. + weight_branch (tuple[tuple[int]], optional): Feature map channels of + the convolutional branch for weight map. Defaults to ((256, ), ). + depth_branch (tuple[int], optional): Feature map channels of the + branch for probabilistic depth estimation. Defaults to (64, ), + depth_range (tuple[float], optional): Range of depth estimation. + Defaults to (0, 70), + depth_unit (int, optional): Unit of depth range division. Defaults to + 10. + division (str, optional): Depth division method. Options include + 'uniform', 'linear', 'log', 'loguniform'. Defaults to 'uniform'. + depth_bins (int, optional): Discrete bins of depth division. Defaults + to 8. + loss_depth (dict, optional): Depth loss. Defaults to dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0). + loss_bbox2d (dict, optional): Loss for 2D box estimation. Defaults to + dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0). + loss_consistency (dict, optional): Consistency loss. Defaults to + dict(type='GIoULoss', loss_weight=1.0), + pred_velo (bool, optional): Whether to predict velocity. Defaults to + False. + pred_bbox2d (bool, optional): Whether to predict 2D bounding boxes. + Defaults to True. + pred_keypoints (bool, optional): Whether to predict keypoints. + Defaults to False, + bbox_coder (dict, optional): Bounding box coder. Defaults to + dict(type='PGDBBoxCoder', base_depths=((28.01, 16.32), ), + base_dims=((0.8, 1.73, 0.6), (1.76, 1.73, 0.6), (3.9, 1.56, 1.6)), + code_size=7). 
+ """ + + def __init__(self, + use_depth_classifier=True, + use_onlyreg_proj=False, + weight_dim=-1, + weight_branch=((256, ), ), + depth_branch=(64, ), + depth_range=(0, 70), + depth_unit=10, + division='uniform', + depth_bins=8, + loss_depth=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), + loss_bbox2d=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), + loss_consistency=dict(type='GIoULoss', loss_weight=1.0), + pred_velo=False, + pred_bbox2d=True, + pred_keypoints=False, + bbox_coder=dict( + type='PGDBBoxCoder', + base_depths=((28.01, 16.32), ), + base_dims=((0.8, 1.73, 0.6), (1.76, 1.73, 0.6), + (3.9, 1.56, 1.6)), + code_size=7, + depth_range=(0, 70), + depth_unit=10, + division='uniform', + depth_bins=8), + **kwargs): + self.use_depth_classifier = use_depth_classifier + self.use_onlyreg_proj = use_onlyreg_proj + self.depth_branch = depth_branch + self.pred_keypoints = pred_keypoints + self.weight_dim = weight_dim + self.weight_branch = weight_branch + self.weight_out_channels = [] + for weight_branch_channels in weight_branch: + if len(weight_branch_channels) > 0: + self.weight_out_channels.append(weight_branch_channels[-1]) + else: + self.weight_out_channels.append(-1) + self.depth_range = depth_range + self.depth_unit = depth_unit + self.division = division + if self.division == 'uniform': + self.num_depth_cls = int( + (depth_range[1] - depth_range[0]) / depth_unit) + 1 + if self.num_depth_cls != depth_bins: + print('Warning: The number of bins computed from ' + + 'depth_unit is different from given parameter! ' + + 'Depth_unit will be considered with priority in ' + + 'Uniform Division.') + else: + self.num_depth_cls = depth_bins + super().__init__( + pred_bbox2d=pred_bbox2d, bbox_coder=bbox_coder, **kwargs) + self.loss_depth = build_loss(loss_depth) + if self.pred_bbox2d: + self.loss_bbox2d = build_loss(loss_bbox2d) + self.loss_consistency = build_loss(loss_consistency) + if self.pred_keypoints: + self.kpts_start = 9 if self.pred_velo else 7 + + def _init_layers(self): + """Initialize layers of the head.""" + super()._init_layers() + if self.pred_bbox2d: + self.scale_dim += 1 + if self.pred_keypoints: + self.scale_dim += 1 + self.scales = nn.ModuleList([ + nn.ModuleList([Scale(1.0) for _ in range(self.scale_dim)]) + for _ in self.strides + ]) + + def _init_predictor(self): + """Initialize predictor layers of the head.""" + super()._init_predictor() + + if self.use_depth_classifier: + self.conv_depth_cls_prev = self._init_branch( + conv_channels=self.depth_branch, + conv_strides=(1, ) * len(self.depth_branch)) + self.conv_depth_cls = nn.Conv2d(self.depth_branch[-1], + self.num_depth_cls, 1) + # Data-agnostic single param lambda for local depth fusion + self.fuse_lambda = nn.Parameter(torch.tensor(10e-5)) + + if self.weight_dim != -1: + self.conv_weight_prevs = nn.ModuleList() + self.conv_weights = nn.ModuleList() + for i in range(self.weight_dim): + weight_branch_channels = self.weight_branch[i] + weight_out_channel = self.weight_out_channels[i] + if len(weight_branch_channels) > 0: + self.conv_weight_prevs.append( + self._init_branch( + conv_channels=weight_branch_channels, + conv_strides=(1, ) * len(weight_branch_channels))) + self.conv_weights.append( + nn.Conv2d(weight_out_channel, 1, 1)) + else: + self.conv_weight_prevs.append(None) + self.conv_weights.append( + nn.Conv2d(self.feat_channels, 1, 1)) + + def init_weights(self): + """Initialize weights of the head. 
+ + We currently still use the customized defined init_weights because the + default init of DCN triggered by the init_cfg will init + conv_offset.weight, which mistakenly affects the training stability. + """ + super().init_weights() + + bias_cls = bias_init_with_prob(0.01) + if self.use_depth_classifier: + for m in self.conv_depth_cls_prev: + if isinstance(m.conv, nn.Conv2d): + normal_init(m.conv, std=0.01) + normal_init(self.conv_depth_cls, std=0.01, bias=bias_cls) + + if self.weight_dim != -1: + for conv_weight_prev in self.conv_weight_prevs: + if conv_weight_prev is None: + continue + for m in conv_weight_prev: + if isinstance(m.conv, nn.Conv2d): + normal_init(m.conv, std=0.01) + for conv_weight in self.conv_weights: + normal_init(conv_weight, std=0.01) + + def forward(self, feats): + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: + cls_scores (list[Tensor]): Box scores for each scale level, + each is a 4D-tensor, the channel number is + num_points * num_classes. + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level, each is a 4D-tensor, the channel number is + num_points * bbox_code_size. + dir_cls_preds (list[Tensor]): Box scores for direction class + predictions on each scale level, each is a 4D-tensor, + the channel number is num_points * 2. (bin = 2). + weight (list[Tensor]): Location-aware weight maps on each + scale level, each is a 4D-tensor, the channel number is + num_points * 1. + depth_cls_preds (list[Tensor]): Box scores for depth class + predictions on each scale level, each is a 4D-tensor, + the channel number is num_points * self.num_depth_cls. + attr_preds (list[Tensor]): Attribute scores for each scale + level, each is a 4D-tensor, the channel number is + num_points * num_attrs. + centernesses (list[Tensor]): Centerness for each scale level, + each is a 4D-tensor, the channel number is num_points * 1. + """ + return multi_apply(self.forward_single, feats, self.scales, + self.strides) + + def forward_single(self, x, scale, stride): + """Forward features of a single scale level. + + Args: + x (Tensor): FPN feature maps of the specified stride. + scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize + the bbox prediction. + stride (int): The corresponding stride for feature maps, only + used to normalize the bbox prediction when self.norm_on_bbox + is True. + + Returns: + tuple: scores for each class, bbox and direction class + predictions, depth class predictions, location-aware weights, + attribute and centerness predictions of input feature maps. 
+ """ + cls_score, bbox_pred, dir_cls_pred, attr_pred, centerness, cls_feat, \ + reg_feat = super().forward_single(x, scale, stride) + + max_regress_range = stride * self.regress_ranges[0][1] / \ + self.strides[0] + bbox_pred = self.bbox_coder.decode_2d(bbox_pred, scale, stride, + max_regress_range, self.training, + self.pred_keypoints, + self.pred_bbox2d) + + depth_cls_pred = None + if self.use_depth_classifier: + clone_reg_feat = reg_feat.clone() + for conv_depth_cls_prev_layer in self.conv_depth_cls_prev: + clone_reg_feat = conv_depth_cls_prev_layer(clone_reg_feat) + depth_cls_pred = self.conv_depth_cls(clone_reg_feat) + + weight = None + if self.weight_dim != -1: + weight = [] + for i in range(self.weight_dim): + clone_reg_feat = reg_feat.clone() + if len(self.weight_branch[i]) > 0: + for conv_weight_prev_layer in self.conv_weight_prevs[i]: + clone_reg_feat = conv_weight_prev_layer(clone_reg_feat) + weight.append(self.conv_weights[i](clone_reg_feat)) + weight = torch.cat(weight, dim=1) + + return cls_score, bbox_pred, dir_cls_pred, depth_cls_pred, weight, \ + attr_pred, centerness + + def get_proj_bbox2d(self, + bbox_preds, + pos_dir_cls_preds, + labels_3d, + bbox_targets_3d, + pos_points, + pos_inds, + img_metas, + pos_depth_cls_preds=None, + pos_weights=None, + pos_cls_scores=None, + with_kpts=False): + """Decode box predictions and get projected 2D attributes. + + Args: + bbox_preds (list[Tensor]): Box predictions for each scale + level, each is a 4D-tensor, the channel number is + num_points * bbox_code_size. + pos_dir_cls_preds (Tensor): Box scores for direction class + predictions of positive boxes on all the scale levels in shape + (num_pos_points, 2). + labels_3d (list[Tensor]): 3D box category labels for each scale + level, each is a 4D-tensor. + bbox_targets_3d (list[Tensor]): 3D box targets for each scale + level, each is a 4D-tensor, the channel number is + num_points * bbox_code_size. + pos_points (Tensor): Foreground points. + pos_inds (Tensor): Index of foreground points from flattened + tensors. + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + pos_depth_cls_preds (Tensor, optional): Probabilistic depth map of + positive boxes on all the scale levels in shape + (num_pos_points, self.num_depth_cls). Defaults to None. + pos_weights (Tensor, optional): Location-aware weights of positive + boxes in shape (num_pos_points, self.weight_dim). Defaults to + None. + pos_cls_scores (Tensor, optional): Classification scores of + positive boxes in shape (num_pos_points, self.num_classes). + Defaults to None. + with_kpts (bool, optional): Whether to output keypoints targets. + Defaults to False. + + Returns: + tuple[Tensor]: Exterior 2D boxes from projected 3D boxes, + predicted 2D boxes and keypoint targets (if necessary). 
+ """ + views = [np.array(img_meta['cam2img']) for img_meta in img_metas] + num_imgs = len(img_metas) + img_idx = [] + for label in labels_3d: + for idx in range(num_imgs): + img_idx.append( + labels_3d[0].new_ones(int(len(label) / num_imgs)) * idx) + img_idx = torch.cat(img_idx) + pos_img_idx = img_idx[pos_inds] + + flatten_strided_bbox_preds = [] + flatten_strided_bbox2d_preds = [] + flatten_bbox_targets_3d = [] + flatten_strides = [] + + for stride_idx, bbox_pred in enumerate(bbox_preds): + flatten_bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape( + -1, sum(self.group_reg_dims)) + flatten_bbox_pred[:, :2] *= self.strides[stride_idx] + flatten_bbox_pred[:, -4:] *= self.strides[stride_idx] + flatten_strided_bbox_preds.append( + flatten_bbox_pred[:, :self.bbox_coder.bbox_code_size]) + flatten_strided_bbox2d_preds.append(flatten_bbox_pred[:, -4:]) + + bbox_target_3d = bbox_targets_3d[stride_idx].clone() + bbox_target_3d[:, :2] *= self.strides[stride_idx] + bbox_target_3d[:, -4:] *= self.strides[stride_idx] + flatten_bbox_targets_3d.append(bbox_target_3d) + + flatten_stride = flatten_bbox_pred.new_ones( + *flatten_bbox_pred.shape[:-1], 1) * self.strides[stride_idx] + flatten_strides.append(flatten_stride) + + flatten_strided_bbox_preds = torch.cat(flatten_strided_bbox_preds) + flatten_strided_bbox2d_preds = torch.cat(flatten_strided_bbox2d_preds) + flatten_bbox_targets_3d = torch.cat(flatten_bbox_targets_3d) + flatten_strides = torch.cat(flatten_strides) + pos_strided_bbox_preds = flatten_strided_bbox_preds[pos_inds] + pos_strided_bbox2d_preds = flatten_strided_bbox2d_preds[pos_inds] + pos_bbox_targets_3d = flatten_bbox_targets_3d[pos_inds] + pos_strides = flatten_strides[pos_inds] + + pos_decoded_bbox2d_preds = distance2bbox(pos_points, + pos_strided_bbox2d_preds) + + pos_strided_bbox_preds[:, :2] = \ + pos_points - pos_strided_bbox_preds[:, :2] + pos_bbox_targets_3d[:, :2] = \ + pos_points - pos_bbox_targets_3d[:, :2] + + if self.use_depth_classifier and (not self.use_onlyreg_proj): + pos_prob_depth_preds = self.bbox_coder.decode_prob_depth( + pos_depth_cls_preds, self.depth_range, self.depth_unit, + self.division, self.num_depth_cls) + sig_alpha = torch.sigmoid(self.fuse_lambda) + pos_strided_bbox_preds[:, 2] = \ + sig_alpha * pos_strided_bbox_preds.clone()[:, 2] + \ + (1 - sig_alpha) * pos_prob_depth_preds + + box_corners_in_image = pos_strided_bbox_preds.new_zeros( + (*pos_strided_bbox_preds.shape[:-1], 8, 2)) + box_corners_in_image_gt = pos_strided_bbox_preds.new_zeros( + (*pos_strided_bbox_preds.shape[:-1], 8, 2)) + + for idx in range(num_imgs): + mask = (pos_img_idx == idx) + if pos_strided_bbox_preds[mask].shape[0] == 0: + continue + cam2img = torch.eye( + 4, + dtype=pos_strided_bbox_preds.dtype, + device=pos_strided_bbox_preds.device) + view_shape = views[idx].shape + cam2img[:view_shape[0], :view_shape[1]] = \ + pos_strided_bbox_preds.new_tensor(views[idx]) + + centers2d_preds = pos_strided_bbox_preds.clone()[mask, :2] + centers2d_targets = pos_bbox_targets_3d.clone()[mask, :2] + centers3d_targets = points_img2cam(pos_bbox_targets_3d[mask, :3], + views[idx]) + + # use predicted depth to re-project the 2.5D centers + pos_strided_bbox_preds[mask, :3] = points_img2cam( + pos_strided_bbox_preds[mask, :3], views[idx]) + pos_bbox_targets_3d[mask, :3] = centers3d_targets + + # depth fixed when computing re-project 3D bboxes + pos_strided_bbox_preds[mask, 2] = \ + pos_bbox_targets_3d.clone()[mask, 2] + + # decode yaws + if self.use_direction_classifier: + pos_dir_cls_scores = torch.max( 
+ pos_dir_cls_preds[mask], dim=-1)[1] + pos_strided_bbox_preds[mask] = self.bbox_coder.decode_yaw( + pos_strided_bbox_preds[mask], centers2d_preds, + pos_dir_cls_scores, self.dir_offset, cam2img) + pos_bbox_targets_3d[mask, 6] = torch.atan2( + centers2d_targets[:, 0] - cam2img[0, 2], + cam2img[0, 0]) + pos_bbox_targets_3d[mask, 6] + + corners = img_metas[0]['box_type_3d']( + pos_strided_bbox_preds[mask], + box_dim=self.bbox_coder.bbox_code_size, + origin=(0.5, 0.5, 0.5)).corners + box_corners_in_image[mask] = points_cam2img(corners, cam2img) + + corners_gt = img_metas[0]['box_type_3d']( + pos_bbox_targets_3d[mask, :self.bbox_code_size], + box_dim=self.bbox_coder.bbox_code_size, + origin=(0.5, 0.5, 0.5)).corners + box_corners_in_image_gt[mask] = points_cam2img(corners_gt, cam2img) + + minxy = torch.min(box_corners_in_image, dim=1)[0] + maxxy = torch.max(box_corners_in_image, dim=1)[0] + proj_bbox2d_preds = torch.cat([minxy, maxxy], dim=1) + + outputs = (proj_bbox2d_preds, pos_decoded_bbox2d_preds) + + if with_kpts: + norm_strides = pos_strides * self.regress_ranges[0][1] / \ + self.strides[0] + kpts_targets = box_corners_in_image_gt - pos_points[..., None, :] + kpts_targets = kpts_targets.view( + (*pos_strided_bbox_preds.shape[:-1], 16)) + kpts_targets /= norm_strides + + outputs += (kpts_targets, ) + + return outputs + + def get_pos_predictions(self, bbox_preds, dir_cls_preds, depth_cls_preds, + weights, attr_preds, centernesses, pos_inds, + img_metas): + """Flatten predictions and get positive ones. + + Args: + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level, each is a 4D-tensor, the channel number is + num_points * bbox_code_size. + dir_cls_preds (list[Tensor]): Box scores for direction class + predictions on each scale level, each is a 4D-tensor, + the channel number is num_points * 2. (bin = 2) + depth_cls_preds (list[Tensor]): Box scores for direction class + predictions on each scale level, each is a 4D-tensor, + the channel number is num_points * self.num_depth_cls. + attr_preds (list[Tensor]): Attribute scores for each scale level, + each is a 4D-tensor, the channel number is + num_points * num_attrs. + centernesses (list[Tensor]): Centerness for each scale level, each + is a 4D-tensor, the channel number is num_points * 1. + pos_inds (Tensor): Index of foreground points from flattened + tensors. + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + tuple[Tensor]: Box predictions, direction classes, probabilistic + depth maps, location-aware weight maps, attributes and + centerness predictions. 
+ """ + flatten_bbox_preds = [ + bbox_pred.permute(0, 2, 3, 1).reshape(-1, sum(self.group_reg_dims)) + for bbox_pred in bbox_preds + ] + flatten_dir_cls_preds = [ + dir_cls_pred.permute(0, 2, 3, 1).reshape(-1, 2) + for dir_cls_pred in dir_cls_preds + ] + flatten_centerness = [ + centerness.permute(0, 2, 3, 1).reshape(-1) + for centerness in centernesses + ] + flatten_bbox_preds = torch.cat(flatten_bbox_preds) + flatten_dir_cls_preds = torch.cat(flatten_dir_cls_preds) + flatten_centerness = torch.cat(flatten_centerness) + pos_bbox_preds = flatten_bbox_preds[pos_inds] + pos_dir_cls_preds = flatten_dir_cls_preds[pos_inds] + pos_centerness = flatten_centerness[pos_inds] + + pos_depth_cls_preds = None + if self.use_depth_classifier: + flatten_depth_cls_preds = [ + depth_cls_pred.permute(0, 2, 3, + 1).reshape(-1, self.num_depth_cls) + for depth_cls_pred in depth_cls_preds + ] + flatten_depth_cls_preds = torch.cat(flatten_depth_cls_preds) + pos_depth_cls_preds = flatten_depth_cls_preds[pos_inds] + + pos_weights = None + if self.weight_dim != -1: + flatten_weights = [ + weight.permute(0, 2, 3, 1).reshape(-1, self.weight_dim) + for weight in weights + ] + flatten_weights = torch.cat(flatten_weights) + pos_weights = flatten_weights[pos_inds] + + pos_attr_preds = None + if self.pred_attrs: + flatten_attr_preds = [ + attr_pred.permute(0, 2, 3, 1).reshape(-1, self.num_attrs) + for attr_pred in attr_preds + ] + flatten_attr_preds = torch.cat(flatten_attr_preds) + pos_attr_preds = flatten_attr_preds[pos_inds] + + return pos_bbox_preds, pos_dir_cls_preds, pos_depth_cls_preds, \ + pos_weights, pos_attr_preds, pos_centerness + + @force_fp32( + apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds', + 'depth_cls_preds', 'weights', 'attr_preds', 'centernesses')) + def loss(self, + cls_scores, + bbox_preds, + dir_cls_preds, + depth_cls_preds, + weights, + attr_preds, + centernesses, + gt_bboxes, + gt_labels, + gt_bboxes_3d, + gt_labels_3d, + centers2d, + depths, + attr_labels, + img_metas, + gt_bboxes_ignore=None): + """Compute loss of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level, + each is a 4D-tensor, the channel number is + num_points * num_classes. + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level, each is a 4D-tensor, the channel number is + num_points * bbox_code_size. + dir_cls_preds (list[Tensor]): Box scores for direction class + predictions on each scale level, each is a 4D-tensor, + the channel number is num_points * 2. (bin = 2) + depth_cls_preds (list[Tensor]): Box scores for direction class + predictions on each scale level, each is a 4D-tensor, + the channel number is num_points * self.num_depth_cls. + weights (list[Tensor]): Location-aware weights for each scale + level, each is a 4D-tensor, the channel number is + num_points * self.weight_dim. + attr_preds (list[Tensor]): Attribute scores for each scale level, + each is a 4D-tensor, the channel number is + num_points * num_attrs. + centernesses (list[Tensor]): Centerness for each scale level, each + is a 4D-tensor, the channel number is num_points * 1. + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): class indices corresponding to each box + gt_bboxes_3d (list[Tensor]): 3D boxes ground truth with shape of + (num_gts, code_size). + gt_labels_3d (list[Tensor]): same as gt_labels + centers2d (list[Tensor]): 2D centers on the image with shape of + (num_gts, 2). 
+ depths (list[Tensor]): Depth ground truth with shape of + (num_gts, ). + attr_labels (list[Tensor]): Attributes indices of each box. + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (list[Tensor]): specify which bounding boxes can + be ignored when computing the loss. Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert len(cls_scores) == len(bbox_preds) == len(dir_cls_preds) == \ + len(depth_cls_preds) == len(weights) == len(centernesses) == \ + len(attr_preds), 'The length of cls_scores, bbox_preds, ' \ + 'dir_cls_preds, depth_cls_preds, weights, centernesses, and' \ + f'attr_preds: {len(cls_scores)}, {len(bbox_preds)}, ' \ + f'{len(dir_cls_preds)}, {len(depth_cls_preds)}, {len(weights)}' \ + f'{len(centernesses)}, {len(attr_preds)} are inconsistent.' + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype, + bbox_preds[0].device) + labels_3d, bbox_targets_3d, centerness_targets, attr_targets = \ + self.get_targets( + all_level_points, gt_bboxes, gt_labels, gt_bboxes_3d, + gt_labels_3d, centers2d, depths, attr_labels) + + num_imgs = cls_scores[0].size(0) + # flatten cls_scores and targets + flatten_cls_scores = [ + cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels) + for cls_score in cls_scores + ] + flatten_cls_scores = torch.cat(flatten_cls_scores) + flatten_labels_3d = torch.cat(labels_3d) + flatten_bbox_targets_3d = torch.cat(bbox_targets_3d) + flatten_centerness_targets = torch.cat(centerness_targets) + flatten_points = torch.cat( + [points.repeat(num_imgs, 1) for points in all_level_points]) + if self.pred_attrs: + flatten_attr_targets = torch.cat(attr_targets) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((flatten_labels_3d >= 0) + & (flatten_labels_3d < bg_class_ind)).nonzero().reshape(-1) + num_pos = len(pos_inds) + + loss_cls = self.loss_cls( + flatten_cls_scores, + flatten_labels_3d, + avg_factor=num_pos + num_imgs) # avoid num_pos is 0 + + pos_bbox_preds, pos_dir_cls_preds, pos_depth_cls_preds, pos_weights, \ + pos_attr_preds, pos_centerness = self.get_pos_predictions( + bbox_preds, dir_cls_preds, depth_cls_preds, weights, + attr_preds, centernesses, pos_inds, img_metas) + + loss_dict = dict() + + if num_pos > 0: + pos_bbox_targets_3d = flatten_bbox_targets_3d[pos_inds] + pos_centerness_targets = flatten_centerness_targets[pos_inds] + pos_points = flatten_points[pos_inds] + if self.pred_attrs: + pos_attr_targets = flatten_attr_targets[pos_inds] + if self.use_direction_classifier: + pos_dir_cls_targets = self.get_direction_target( + pos_bbox_targets_3d, self.dir_offset, one_hot=False) + + bbox_weights = pos_centerness_targets.new_ones( + len(pos_centerness_targets), sum(self.group_reg_dims)) + equal_weights = pos_centerness_targets.new_ones( + pos_centerness_targets.shape) + code_weight = self.train_cfg.get('code_weight', None) + if code_weight: + assert len(code_weight) == sum(self.group_reg_dims) + bbox_weights = bbox_weights * bbox_weights.new_tensor( + code_weight) + + if self.diff_rad_by_sin: + pos_bbox_preds, pos_bbox_targets_3d = self.add_sin_difference( + pos_bbox_preds, pos_bbox_targets_3d) + + loss_offset = self.loss_bbox( + pos_bbox_preds[:, :2], + pos_bbox_targets_3d[:, :2], + weight=bbox_weights[:, :2], + avg_factor=equal_weights.sum()) + loss_size = self.loss_bbox( + pos_bbox_preds[:, 3:6], 
+ pos_bbox_targets_3d[:, 3:6], + weight=bbox_weights[:, 3:6], + avg_factor=equal_weights.sum()) + loss_rotsin = self.loss_bbox( + pos_bbox_preds[:, 6], + pos_bbox_targets_3d[:, 6], + weight=bbox_weights[:, 6], + avg_factor=equal_weights.sum()) + if self.pred_velo: + loss_dict['loss_velo'] = self.loss_bbox( + pos_bbox_preds[:, 7:9], + pos_bbox_targets_3d[:, 7:9], + weight=bbox_weights[:, 7:9], + avg_factor=equal_weights.sum()) + + proj_bbox2d_inputs = (bbox_preds, pos_dir_cls_preds, labels_3d, + bbox_targets_3d, pos_points, pos_inds, + img_metas) + + # direction classification loss + # TODO: add more check for use_direction_classifier + if self.use_direction_classifier: + loss_dict['loss_dir'] = self.loss_dir( + pos_dir_cls_preds, + pos_dir_cls_targets, + equal_weights, + avg_factor=equal_weights.sum()) + + # init depth loss with the one computed from direct regression + loss_dict['loss_depth'] = self.loss_bbox( + pos_bbox_preds[:, 2], + pos_bbox_targets_3d[:, 2], + weight=bbox_weights[:, 2], + avg_factor=equal_weights.sum()) + # depth classification loss + if self.use_depth_classifier: + pos_prob_depth_preds = self.bbox_coder.decode_prob_depth( + pos_depth_cls_preds, self.depth_range, self.depth_unit, + self.division, self.num_depth_cls) + sig_alpha = torch.sigmoid(self.fuse_lambda) + if self.weight_dim != -1: + loss_fuse_depth = self.loss_depth( + sig_alpha * pos_bbox_preds[:, 2] + + (1 - sig_alpha) * pos_prob_depth_preds, + pos_bbox_targets_3d[:, 2], + sigma=pos_weights[:, 0], + weight=bbox_weights[:, 2], + avg_factor=equal_weights.sum()) + else: + loss_fuse_depth = self.loss_depth( + sig_alpha * pos_bbox_preds[:, 2] + + (1 - sig_alpha) * pos_prob_depth_preds, + pos_bbox_targets_3d[:, 2], + weight=bbox_weights[:, 2], + avg_factor=equal_weights.sum()) + loss_dict['loss_depth'] = loss_fuse_depth + + proj_bbox2d_inputs += (pos_depth_cls_preds, ) + + if self.pred_keypoints: + # use smoothL1 to compute consistency loss for keypoints + # normalize the offsets with strides + proj_bbox2d_preds, pos_decoded_bbox2d_preds, kpts_targets = \ + self.get_proj_bbox2d(*proj_bbox2d_inputs, with_kpts=True) + loss_dict['loss_kpts'] = self.loss_bbox( + pos_bbox_preds[:, self.kpts_start:self.kpts_start + 16], + kpts_targets, + weight=bbox_weights[:, + self.kpts_start:self.kpts_start + 16], + avg_factor=equal_weights.sum()) + + if self.pred_bbox2d: + loss_dict['loss_bbox2d'] = self.loss_bbox2d( + pos_bbox_preds[:, -4:], + pos_bbox_targets_3d[:, -4:], + weight=bbox_weights[:, -4:], + avg_factor=equal_weights.sum()) + if not self.pred_keypoints: + proj_bbox2d_preds, pos_decoded_bbox2d_preds = \ + self.get_proj_bbox2d(*proj_bbox2d_inputs) + loss_dict['loss_consistency'] = self.loss_consistency( + proj_bbox2d_preds, + pos_decoded_bbox2d_preds, + weight=bbox_weights[:, -4:], + avg_factor=equal_weights.sum()) + + loss_centerness = self.loss_centerness(pos_centerness, + pos_centerness_targets) + + # attribute classification loss + if self.pred_attrs: + loss_dict['loss_attr'] = self.loss_attr( + pos_attr_preds, + pos_attr_targets, + pos_centerness_targets, + avg_factor=pos_centerness_targets.sum()) + + else: + # need absolute due to possible negative delta x/y + loss_offset = pos_bbox_preds[:, :2].sum() + loss_size = pos_bbox_preds[:, 3:6].sum() + loss_rotsin = pos_bbox_preds[:, 6].sum() + loss_dict['loss_depth'] = pos_bbox_preds[:, 2].sum() + if self.pred_velo: + loss_dict['loss_velo'] = pos_bbox_preds[:, 7:9].sum() + if self.pred_keypoints: + loss_dict['loss_kpts'] = pos_bbox_preds[:, + self.kpts_start:self. 
kpts_start + 16].sum() + if self.pred_bbox2d: + loss_dict['loss_bbox2d'] = pos_bbox_preds[:, -4:].sum() + loss_dict['loss_consistency'] = pos_bbox_preds[:, -4:].sum() + loss_centerness = pos_centerness.sum() + if self.use_direction_classifier: + loss_dict['loss_dir'] = pos_dir_cls_preds.sum() + if self.use_depth_classifier: + sig_alpha = torch.sigmoid(self.fuse_lambda) + loss_fuse_depth = \ + sig_alpha * pos_bbox_preds[:, 2].sum() + \ + (1 - sig_alpha) * pos_depth_cls_preds.sum() + if self.weight_dim != -1: + loss_fuse_depth *= torch.exp(-pos_weights[:, 0].sum()) + loss_dict['loss_depth'] = loss_fuse_depth + if self.pred_attrs: + loss_dict['loss_attr'] = pos_attr_preds.sum() + + loss_dict.update( + dict( + loss_cls=loss_cls, + loss_offset=loss_offset, + loss_size=loss_size, + loss_rotsin=loss_rotsin, + loss_centerness=loss_centerness)) + + return loss_dict + + @force_fp32( + apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds', + 'depth_cls_preds', 'weights', 'attr_preds', 'centernesses')) + def get_bboxes(self, + cls_scores, + bbox_preds, + dir_cls_preds, + depth_cls_preds, + weights, + attr_preds, + centernesses, + img_metas, + cfg=None, + rescale=None): + """Transform network output for a batch into bbox predictions. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_points * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_points * 4, H, W) + dir_cls_preds (list[Tensor]): Box scores for direction class + predictions on each scale level, each is a 4D-tensor, + the channel number is num_points * 2. (bin = 2) + depth_cls_preds (list[Tensor]): Box scores for depth class + predictions on each scale level, each is a 4D-tensor, + the channel number is num_points * self.num_depth_cls. + weights (list[Tensor]): Location-aware weights for each scale + level, each is a 4D-tensor, the channel number is + num_points * self.weight_dim. + attr_preds (list[Tensor]): Attribute scores for each scale level + Has shape (N, num_points * num_attrs, H, W) + centernesses (list[Tensor]): Centerness for each scale level with + shape (N, num_points * 1, H, W) + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + cfg (mmcv.Config, optional): Test / postprocessing configuration, + if None, test_cfg would be used. Defaults to None. + rescale (bool, optional): If True, return boxes in original image + space. Defaults to None. + + Returns: + list[tuple[Tensor]]: Each item in result_list is a tuple, which + consists of predicted 3D boxes, scores, labels, attributes and + 2D boxes (if necessary). + """ + assert len(cls_scores) == len(bbox_preds) == len(dir_cls_preds) == \ + len(depth_cls_preds) == len(weights) == len(centernesses) == \ + len(attr_preds), 'The length of cls_scores, bbox_preds, ' \ + 'dir_cls_preds, depth_cls_preds, weights, centernesses, and' \ + f'attr_preds: {len(cls_scores)}, {len(bbox_preds)}, ' \ + f'{len(dir_cls_preds)}, {len(depth_cls_preds)}, {len(weights)}' \ + f'{len(centernesses)}, {len(attr_preds)} are inconsistent.'
+ num_levels = len(cls_scores) + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + mlvl_points = self.get_points(featmap_sizes, bbox_preds[0].dtype, + bbox_preds[0].device) + result_list = [] + for img_id in range(len(img_metas)): + cls_score_list = [ + cls_scores[i][img_id].detach() for i in range(num_levels) + ] + bbox_pred_list = [ + bbox_preds[i][img_id].detach() for i in range(num_levels) + ] + if self.use_direction_classifier: + dir_cls_pred_list = [ + dir_cls_preds[i][img_id].detach() + for i in range(num_levels) + ] + else: + dir_cls_pred_list = [ + cls_scores[i][img_id].new_full( + [2, *cls_scores[i][img_id].shape[1:]], 0).detach() + for i in range(num_levels) + ] + if self.use_depth_classifier: + depth_cls_pred_list = [ + depth_cls_preds[i][img_id].detach() + for i in range(num_levels) + ] + else: + depth_cls_pred_list = [ + cls_scores[i][img_id].new_full( + [self.num_depth_cls, *cls_scores[i][img_id].shape[1:]], + 0).detach() for i in range(num_levels) + ] + if self.weight_dim != -1: + weight_list = [ + weights[i][img_id].detach() for i in range(num_levels) + ] + else: + weight_list = [ + cls_scores[i][img_id].new_full( + [1, *cls_scores[i][img_id].shape[1:]], 0).detach() + for i in range(num_levels) + ] + if self.pred_attrs: + attr_pred_list = [ + attr_preds[i][img_id].detach() for i in range(num_levels) + ] + else: + attr_pred_list = [ + cls_scores[i][img_id].new_full( + [self.num_attrs, *cls_scores[i][img_id].shape[1:]], + self.attr_background_label).detach() + for i in range(num_levels) + ] + centerness_pred_list = [ + centernesses[i][img_id].detach() for i in range(num_levels) + ] + input_meta = img_metas[img_id] + det_bboxes = self._get_bboxes_single( + cls_score_list, bbox_pred_list, dir_cls_pred_list, + depth_cls_pred_list, weight_list, attr_pred_list, + centerness_pred_list, mlvl_points, input_meta, cfg, rescale) + result_list.append(det_bboxes) + return result_list + + def _get_bboxes_single(self, + cls_scores, + bbox_preds, + dir_cls_preds, + depth_cls_preds, + weights, + attr_preds, + centernesses, + mlvl_points, + input_meta, + cfg, + rescale=False): + """Transform outputs for a single batch item into bbox predictions. + + Args: + cls_scores (list[Tensor]): Box scores for a single scale level + Has shape (num_points * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for a single scale + level with shape (num_points * bbox_code_size, H, W). + dir_cls_preds (list[Tensor]): Box scores for direction class + predictions on a single scale level with shape + (num_points * 2, H, W) + depth_cls_preds (list[Tensor]): Box scores for probabilistic depth + predictions on a single scale level with shape + (num_points * self.num_depth_cls, H, W) + weights (list[Tensor]): Location-aware weight maps on a single + scale level with shape (num_points * self.weight_dim, H, W). + attr_preds (list[Tensor]): Attribute scores for each scale level + Has shape (N, num_points * num_attrs, H, W) + centernesses (list[Tensor]): Centerness for a single scale level + with shape (num_points, H, W). + mlvl_points (list[Tensor]): Box reference for a single scale level + with shape (num_total_points, 2). + input_meta (dict): Metadata of input image. + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool, optional): If True, return boxes in original image + space. Defaults to False. + + Returns: + tuples[Tensor]: Predicted 3D boxes, scores, labels, attributes and + 2D boxes (if necessary). 
+ """ + view = np.array(input_meta['cam2img']) + scale_factor = input_meta['scale_factor'] + cfg = self.test_cfg if cfg is None else cfg + assert len(cls_scores) == len(bbox_preds) == len(mlvl_points) + mlvl_centers2d = [] + mlvl_bboxes = [] + mlvl_scores = [] + mlvl_dir_scores = [] + mlvl_attr_scores = [] + mlvl_centerness = [] + mlvl_depth_cls_scores = [] + mlvl_depth_uncertainty = [] + mlvl_bboxes2d = None + if self.pred_bbox2d: + mlvl_bboxes2d = [] + + for cls_score, bbox_pred, dir_cls_pred, depth_cls_pred, weight, \ + attr_pred, centerness, points in zip( + cls_scores, bbox_preds, dir_cls_preds, depth_cls_preds, + weights, attr_preds, centernesses, mlvl_points): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + scores = cls_score.permute(1, 2, 0).reshape( + -1, self.cls_out_channels).sigmoid() + dir_cls_pred = dir_cls_pred.permute(1, 2, 0).reshape(-1, 2) + dir_cls_score = torch.max(dir_cls_pred, dim=-1)[1] + depth_cls_pred = depth_cls_pred.permute(1, 2, 0).reshape( + -1, self.num_depth_cls) + depth_cls_score = F.softmax( + depth_cls_pred, dim=-1).topk( + k=2, dim=-1)[0].mean(dim=-1) + if self.weight_dim != -1: + weight = weight.permute(1, 2, 0).reshape(-1, self.weight_dim) + else: + weight = weight.permute(1, 2, 0).reshape(-1, 1) + depth_uncertainty = torch.exp(-weight[:, -1]) + attr_pred = attr_pred.permute(1, 2, 0).reshape(-1, self.num_attrs) + attr_score = torch.max(attr_pred, dim=-1)[1] + centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid() + + bbox_pred = bbox_pred.permute(1, 2, + 0).reshape(-1, + sum(self.group_reg_dims)) + bbox_pred3d = bbox_pred[:, :self.bbox_coder.bbox_code_size] + if self.pred_bbox2d: + bbox_pred2d = bbox_pred[:, -4:] + nms_pre = cfg.get('nms_pre', -1) + if nms_pre > 0 and scores.shape[0] > nms_pre: + merged_scores = scores * centerness[:, None] + if self.use_depth_classifier: + merged_scores *= depth_cls_score[:, None] + if self.weight_dim != -1: + merged_scores *= depth_uncertainty[:, None] + max_scores, _ = merged_scores.max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + points = points[topk_inds, :] + bbox_pred3d = bbox_pred3d[topk_inds, :] + scores = scores[topk_inds, :] + dir_cls_pred = dir_cls_pred[topk_inds, :] + depth_cls_pred = depth_cls_pred[topk_inds, :] + centerness = centerness[topk_inds] + dir_cls_score = dir_cls_score[topk_inds] + depth_cls_score = depth_cls_score[topk_inds] + depth_uncertainty = depth_uncertainty[topk_inds] + attr_score = attr_score[topk_inds] + if self.pred_bbox2d: + bbox_pred2d = bbox_pred2d[topk_inds, :] + # change the offset to actual center predictions + bbox_pred3d[:, :2] = points - bbox_pred3d[:, :2] + if rescale: + bbox_pred3d[:, :2] /= bbox_pred3d[:, :2].new_tensor( + scale_factor) + if self.pred_bbox2d: + bbox_pred2d /= bbox_pred2d.new_tensor(scale_factor) + if self.use_depth_classifier: + prob_depth_pred = self.bbox_coder.decode_prob_depth( + depth_cls_pred, self.depth_range, self.depth_unit, + self.division, self.num_depth_cls) + sig_alpha = torch.sigmoid(self.fuse_lambda) + bbox_pred3d[:, 2] = sig_alpha * bbox_pred3d[:, 2] + \ + (1 - sig_alpha) * prob_depth_pred + pred_center2d = bbox_pred3d[:, :3].clone() + bbox_pred3d[:, :3] = points_img2cam(bbox_pred3d[:, :3], view) + mlvl_centers2d.append(pred_center2d) + mlvl_bboxes.append(bbox_pred3d) + mlvl_scores.append(scores) + mlvl_dir_scores.append(dir_cls_score) + mlvl_depth_cls_scores.append(depth_cls_score) + mlvl_attr_scores.append(attr_score) + mlvl_centerness.append(centerness) + mlvl_depth_uncertainty.append(depth_uncertainty) + if 
self.pred_bbox2d: + bbox_pred2d = distance2bbox( + points, bbox_pred2d, max_shape=input_meta['img_shape']) + mlvl_bboxes2d.append(bbox_pred2d) + + mlvl_centers2d = torch.cat(mlvl_centers2d) + mlvl_bboxes = torch.cat(mlvl_bboxes) + mlvl_dir_scores = torch.cat(mlvl_dir_scores) + if self.pred_bbox2d: + mlvl_bboxes2d = torch.cat(mlvl_bboxes2d) + + # change local yaw to global yaw for 3D nms + cam2img = torch.eye( + 4, dtype=mlvl_centers2d.dtype, device=mlvl_centers2d.device) + cam2img[:view.shape[0], :view.shape[1]] = \ + mlvl_centers2d.new_tensor(view) + mlvl_bboxes = self.bbox_coder.decode_yaw(mlvl_bboxes, mlvl_centers2d, + mlvl_dir_scores, + self.dir_offset, cam2img) + + mlvl_bboxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d']( + mlvl_bboxes, + box_dim=self.bbox_coder.bbox_code_size, + origin=(0.5, 0.5, 0.5)).bev) + + mlvl_scores = torch.cat(mlvl_scores) + padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) + # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 + # BG cat_id: num_class + mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) + mlvl_attr_scores = torch.cat(mlvl_attr_scores) + mlvl_centerness = torch.cat(mlvl_centerness) + # no scale_factors in box3d_multiclass_nms + # Then we multiply it from outside + mlvl_nms_scores = mlvl_scores * mlvl_centerness[:, None] + if self.use_depth_classifier: # multiply the depth confidence + mlvl_depth_cls_scores = torch.cat(mlvl_depth_cls_scores) + mlvl_nms_scores *= mlvl_depth_cls_scores[:, None] + if self.weight_dim != -1: + mlvl_depth_uncertainty = torch.cat(mlvl_depth_uncertainty) + mlvl_nms_scores *= mlvl_depth_uncertainty[:, None] + results = box3d_multiclass_nms(mlvl_bboxes, mlvl_bboxes_for_nms, + mlvl_nms_scores, cfg.score_thr, + cfg.max_per_img, cfg, mlvl_dir_scores, + mlvl_attr_scores, mlvl_bboxes2d) + bboxes, scores, labels, dir_scores, attrs = results[0:5] + attrs = attrs.to(labels.dtype) # change data type to int + bboxes = input_meta['box_type_3d']( + bboxes, + box_dim=self.bbox_coder.bbox_code_size, + origin=(0.5, 0.5, 0.5)) + # Note that the predictions use origin (0.5, 0.5, 0.5) + # Due to the ground truth centers2d are the gravity center of objects + # v0.10.0 fix inplace operation to the input tensor of cam_box3d + # So here we also need to add origin=(0.5, 0.5, 0.5) + if not self.pred_attrs: + attrs = None + + outputs = (bboxes, scores, labels, attrs) + if self.pred_bbox2d: + bboxes2d = results[-1] + bboxes2d = torch.cat([bboxes2d, scores[:, None]], dim=1) + outputs = outputs + (bboxes2d, ) + + return outputs + + def get_targets(self, points, gt_bboxes_list, gt_labels_list, + gt_bboxes_3d_list, gt_labels_3d_list, centers2d_list, + depths_list, attr_labels_list): + """Compute regression, classification and centerss targets for points + in multiple images. + + Args: + points (list[Tensor]): Points of each fpn level, each has shape + (num_points, 2). + gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image, + each has shape (num_gt, 4). + gt_labels_list (list[Tensor]): Ground truth labels of each box, + each has shape (num_gt,). + gt_bboxes_3d_list (list[Tensor]): 3D Ground truth bboxes of each + image, each has shape (num_gt, bbox_code_size). + gt_labels_3d_list (list[Tensor]): 3D Ground truth labels of each + box, each has shape (num_gt,). + centers2d_list (list[Tensor]): Projected 3D centers onto 2D image, + each has shape (num_gt, 2). + depths_list (list[Tensor]): Depth of projected 3D centers onto 2D + image, each has shape (num_gt, 1). 
+ attr_labels_list (list[Tensor]): Attribute labels of each box, + each has shape (num_gt,). + + Returns: + tuple: + concat_lvl_labels (list[Tensor]): Labels of each level. \ + concat_lvl_bbox_targets (list[Tensor]): BBox targets of each \ + level. + """ + assert len(points) == len(self.regress_ranges) + num_levels = len(points) + # expand regress ranges to align with points + expanded_regress_ranges = [ + points[i].new_tensor(self.regress_ranges[i])[None].expand_as( + points[i]) for i in range(num_levels) + ] + # concat all levels points and regress ranges + concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0) + concat_points = torch.cat(points, dim=0) + + # the number of points per img, per lvl + num_points = [center.size(0) for center in points] + + if attr_labels_list is None: + attr_labels_list = [ + gt_labels.new_full(gt_labels.shape, self.attr_background_label) + for gt_labels in gt_labels_list + ] + + # get labels and bbox_targets of each image + _, bbox_targets_list, labels_3d_list, bbox_targets_3d_list, \ + centerness_targets_list, attr_targets_list = multi_apply( + self._get_target_single, + gt_bboxes_list, + gt_labels_list, + gt_bboxes_3d_list, + gt_labels_3d_list, + centers2d_list, + depths_list, + attr_labels_list, + points=concat_points, + regress_ranges=concat_regress_ranges, + num_points_per_lvl=num_points) + + # split to per img, per level + bbox_targets_list = [ + bbox_targets.split(num_points, 0) + for bbox_targets in bbox_targets_list + ] + labels_3d_list = [ + labels_3d.split(num_points, 0) for labels_3d in labels_3d_list + ] + bbox_targets_3d_list = [ + bbox_targets_3d.split(num_points, 0) + for bbox_targets_3d in bbox_targets_3d_list + ] + centerness_targets_list = [ + centerness_targets.split(num_points, 0) + for centerness_targets in centerness_targets_list + ] + attr_targets_list = [ + attr_targets.split(num_points, 0) + for attr_targets in attr_targets_list + ] + + # concat per level image + concat_lvl_labels_3d = [] + concat_lvl_bbox_targets_3d = [] + concat_lvl_centerness_targets = [] + concat_lvl_attr_targets = [] + for i in range(num_levels): + concat_lvl_labels_3d.append( + torch.cat([labels[i] for labels in labels_3d_list])) + concat_lvl_centerness_targets.append( + torch.cat([ + centerness_targets[i] + for centerness_targets in centerness_targets_list + ])) + bbox_targets_3d = torch.cat([ + bbox_targets_3d[i] for bbox_targets_3d in bbox_targets_3d_list + ]) + if self.pred_bbox2d: + bbox_targets = torch.cat( + [bbox_targets[i] for bbox_targets in bbox_targets_list]) + bbox_targets_3d = torch.cat([bbox_targets_3d, bbox_targets], + dim=1) + concat_lvl_attr_targets.append( + torch.cat( + [attr_targets[i] for attr_targets in attr_targets_list])) + if self.norm_on_bbox: + bbox_targets_3d[:, :2] = \ + bbox_targets_3d[:, :2] / self.strides[i] + if self.pred_bbox2d: + bbox_targets_3d[:, -4:] = \ + bbox_targets_3d[:, -4:] / self.strides[i] + concat_lvl_bbox_targets_3d.append(bbox_targets_3d) + return concat_lvl_labels_3d, concat_lvl_bbox_targets_3d, \ + concat_lvl_centerness_targets, concat_lvl_attr_targets diff --git a/tests/test_models/test_heads/test_heads.py b/tests/test_models/test_heads/test_heads.py index 6c31833028..8fb545d4d8 100644 --- a/tests/test_models/test_heads/test_heads.py +++ b/tests/test_models/test_heads/test_heads.py @@ -1319,3 +1319,75 @@ def test_groupfree3d_head(): assert results[0][0].tensor.shape[1] == 7 assert results[0][1].shape[0] >= 0 assert results[0][2].shape[0] >= 0 + + +def test_pgd_head(): + if not 
torch.cuda.is_available(): + pytest.skip('test requires GPU and torch+cuda') + _setup_seed(0) + pgd_head_cfg = _get_head_cfg( + 'pgd/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d.py') + self = build_head(pgd_head_cfg).cuda() + + feats = [ + torch.rand([2, 256, 96, 312], dtype=torch.float32).cuda(), + torch.rand([2, 256, 48, 156], dtype=torch.float32).cuda(), + torch.rand([2, 256, 24, 78], dtype=torch.float32).cuda(), + torch.rand([2, 256, 12, 39], dtype=torch.float32).cuda(), + ] + + # test forward + ret_dict = self(feats) + assert len(ret_dict) == 7 + assert len(ret_dict[0]) == 4 + assert ret_dict[0][0].shape == torch.Size([2, 3, 96, 312]) + + # test loss + gt_bboxes = [ + torch.rand([3, 4], dtype=torch.float32).cuda(), + torch.rand([3, 4], dtype=torch.float32).cuda() + ] + gt_bboxes_3d = CameraInstance3DBoxes( + torch.rand([3, 7], device='cuda'), box_dim=7) + gt_labels = [torch.randint(0, 3, [3], device='cuda') for i in range(2)] + gt_labels_3d = gt_labels + centers2d = [ + torch.rand([3, 2], dtype=torch.float32).cuda(), + torch.rand([3, 2], dtype=torch.float32).cuda() + ] + depths = [ + torch.rand([3], dtype=torch.float32).cuda(), + torch.rand([3], dtype=torch.float32).cuda() + ] + attr_labels = None + img_metas = [ + dict( + img_shape=[384, 1248], + cam2img=[[721.5377, 0.0, 609.5593, 44.85728], + [0.0, 721.5377, 172.854, 0.2163791], + [0.0, 0.0, 1.0, 0.002745884], [0.0, 0.0, 0.0, 1.0]], + scale_factor=np.array([1., 1., 1., 1.], dtype=np.float32), + box_type_3d=CameraInstance3DBoxes) for i in range(2) + ] + losses = self.loss(*ret_dict, gt_bboxes, gt_labels, gt_bboxes_3d, + gt_labels_3d, centers2d, depths, attr_labels, img_metas) + assert losses['loss_cls'] >= 0 + assert losses['loss_offset'] >= 0 + assert losses['loss_depth'] >= 0 + assert losses['loss_size'] >= 0 + assert losses['loss_rotsin'] >= 0 + assert losses['loss_centerness'] >= 0 + assert losses['loss_kpts'] >= 0 + assert losses['loss_bbox2d'] >= 0 + assert losses['loss_consistency'] >= 0 + assert losses['loss_dir'] >= 0 + + # test get_boxes + results = self.get_bboxes(*ret_dict, img_metas) + assert len(results) == 2 + assert len(results[0]) == 5 + assert results[0][0].tensor.shape == torch.Size([20, 7]) + assert results[0][1].shape == torch.Size([20]) + assert results[0][2].shape == torch.Size([20]) + assert results[0][3] is None + assert results[0][4].shape == torch.Size([20, 5]) diff --git a/tests/test_utils/test_bbox_coders.py b/tests/test_utils/test_bbox_coders.py index 3b42d64d86..56a1e6af52 100644 --- a/tests/test_utils/test_bbox_coders.py +++ b/tests/test_utils/test_bbox_coders.py @@ -502,17 +502,17 @@ def test_pgd_bbox_coder(): max_regress_range, training, pred_keypoints, pred_bbox2d) expected_decode_bbox_w2d = torch.tensor( - [[[[0.0206]], [[1.4788]], [[1.3904]], [[1.6013]], [[1.1548]], - [[1.0809]], [[0.9399]], [[13.3856]], [[2.0224]], [[4.8480]], - [[3.0368]], [[1.1424]], [[6.6304]], [[6.9456]], [[10.3072]], - [[4.7216]], [[4.6240]], [[7.1776]], [[4.5568]], [[1.7136]], - [[15.2480]], [[15.1360]], [[6.1152]], [[1.8640]], [[0.5222]], - [[1.1160]], [[0.0794]]], + [[[[0.0206]], [[1.4788]], + [[1.3904]], [[1.6013]], [[1.1548]], [[1.0809]], [[0.9399]], + [[10.9441]], [[2.0117]], [[4.7049]], [[3.0009]], [[1.1405]], + [[6.2752]], [[6.5399]], [[9.0840]], [[4.5892]], [[4.4994]], + [[6.7320]], [[4.4375]], [[1.7071]], [[11.8582]], [[11.8075]], + [[5.8339]], [[1.8640]], [[0.5222]], [[1.1160]], [[0.0794]]], [[[1.7224]], [[0.3360]], [[1.6765]], [[2.3401]], [[1.0384]], - [[1.4355]], [[0.9550]], [[8.3504]], [[2.2432]], 
[[10.9488]], - [[3.3936]], [[15.1488]], [[9.9808]], [[12.6688]], [[2.6336]], - [[0.8000]], [[10.0640]], [[6.3296]], [[4.6416]], [[7.3792]], - [[11.7328]], [[1.9104]], [[11.1984]], [[0.7960]], [[0.6524]], + [[1.4355]], [[0.9550]], [[7.6666]], [[2.2286]], [[9.5089]], + [[3.3436]], [[11.8133]], [[8.8603]], [[10.5508]], [[2.6101]], + [[0.7993]], [[8.9178]], [[6.0188]], [[4.5156]], [[6.8970]], + [[10.0013]], [[1.9014]], [[9.6689]], [[0.7960]], [[0.6524]], [[1.4370]], [[0.8948]]]]) assert torch.allclose(expected_decode_bbox_w2d, decode_bbox_w2d, atol=1e-3)
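The PGD head introduced in this patch couples a directly regressed depth with a probabilistic depth obtained by classifying over discrete depth bins (depth_range, depth_unit, division='uniform', depth_bins) and fusing the two through the learnable fuse_lambda parameter. Below is a minimal, self-contained sketch of that idea for the uniform division used in the configs above; the function names are hypothetical and this is an illustration only, the actual logic lives in PGDBBoxCoder.decode_prob_depth and PGDHead.

import torch
import torch.nn.functional as F


def expected_depth_uniform(depth_logits, depth_range=(0, 70), depth_unit=10):
    # Soft-argmax over uniformly spaced depth bins: bin i corresponds to
    # depth_range[0] + i * depth_unit (0, 10, ..., 70 -> 8 bins here).
    num_bins = int((depth_range[1] - depth_range[0]) / depth_unit) + 1
    assert depth_logits.shape[-1] == num_bins
    bin_values = depth_logits.new_tensor(
        [depth_range[0] + i * depth_unit for i in range(num_bins)])
    prob = F.softmax(depth_logits, dim=-1)
    return (prob * bin_values).sum(dim=-1)


def fuse_depth(direct_depth, prob_depth, fuse_lambda):
    # Sigmoid-gated fusion of regressed and probabilistic depth, mirroring
    # sig_alpha = torch.sigmoid(self.fuse_lambda) in the head above.
    sig_alpha = torch.sigmoid(fuse_lambda)
    return sig_alpha * direct_depth + (1 - sig_alpha) * prob_depth


# Toy usage: 5 locations with 8 depth bins, as in the KITTI config above.
logits = torch.randn(5, 8)
direct = torch.rand(5) * 70
fused = fuse_depth(direct, expected_depth_uniform(logits), torch.tensor(10e-5))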