-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
imvoxelnet.py
136 lines (117 loc) · 4.76 KB
/
imvoxelnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import warnings

import torch

from mmdet3d.core import bbox3d2result, build_anchor_generator
from mmdet3d.models.fusion_layers.point_fusion import point_sample
from mmdet.models import DETECTORS, build_backbone, build_head, build_neck
from mmdet.models.detectors import BaseDetector
@DETECTORS.register_module()
class ImVoxelNet(BaseDetector):
    r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.

    Image-based 3D detector. 2D features from ``backbone`` -> ``neck`` are
    sampled at the image projections of a fixed grid of 3D points (produced
    by ``anchor_generator``). This gives one feature volume per sample, which
    is then processed by ``neck_3d`` and ``bbox_head``.

    Args:
        backbone (dict): Config of the 2D backbone.
        neck (dict): Config of the 2D neck. Only its first output level is
            used.
        neck_3d (dict): Config of the 3D neck applied to the stacked volumes.
        bbox_head (dict): Config of the 3D head. ``train_cfg`` and
            ``test_cfg`` are injected into it before it is built.
        n_voxels (Sequence[int]): Number of voxels along each axis. The
            spatial dims of the sampled volume follow this order. A list or
            a tuple is accepted.
        anchor_generator (dict): Config of the generator whose
            ``grid_anchors`` output provides the 3D sample points (only the
            first 3 columns are used).
        train_cfg (dict, optional): Training config. Defaults to None.
        test_cfg (dict, optional): Testing config. Defaults to None.
        pretrained (str, optional): Deprecated and ignored by this detector.
            A warning is emitted if it is set; use ``init_cfg`` instead.
            Defaults to None.
        init_cfg (dict, optional): Initialization config. Defaults to None.
    """

    def __init__(self,
                 backbone,
                 neck,
                 neck_3d,
                 bbox_head,
                 n_voxels,
                 anchor_generator,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        if pretrained:
            # ``pretrained`` was previously dropped without any notice; tell
            # the user it has no effect here instead of silently ignoring it.
            warnings.warn('DeprecationWarning: pretrained is deprecated and '
                          'ignored by ImVoxelNet, please use "init_cfg" '
                          'instead')
        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck)
        self.neck_3d = build_neck(neck_3d)
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = build_head(bbox_head)
        self.n_voxels = n_voxels
        self.anchor_generator = build_anchor_generator(anchor_generator)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, img, img_metas):
        """Extract 3D features: backbone -> 2D neck -> 3D projection -> 3D neck.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list[dict]): Per-image meta info.
                ``'lidar2img'`` and ``'img_shape'`` are required.
                ``'scale_factor'``, ``'flip'`` and ``'img_crop_offset'`` are
                optional and default to 1, False and 0 respectively.

        Returns:
            torch.Tensor: Output of ``neck_3d`` applied to the stacked
                volumes of shape (N, C, N_x, N_y, N_z).
        """
        x = self.backbone(img)
        # Only the first output level of the 2D neck is sampled.
        x = self.neck(x)[0]
        # Normalize to a list so ``+ [-1]`` below also works when
        # ``n_voxels`` was configured as a tuple.
        n_voxels_rev = list(self.n_voxels)[::-1]
        points = self.anchor_generator.grid_anchors(
            [n_voxels_rev], device=img.device)[0][:, :3]
        volumes = []
        # Iterate over the batch: one 2D feature map and one meta per sample.
        for feature, img_meta in zip(x, img_metas):
            # Only the first two entries of ``scale_factor`` are used.
            img_scale_factor = (
                points.new_tensor(img_meta['scale_factor'][:2])
                if 'scale_factor' in img_meta else 1)
            img_flip = img_meta.get('flip', False)
            img_crop_offset = (
                points.new_tensor(img_meta['img_crop_offset'])
                if 'img_crop_offset' in img_meta else 0)
            volume = point_sample(
                img_meta,
                img_features=feature[None, ...],
                points=points,
                lidar2img_rt=points.new_tensor(img_meta['lidar2img']),
                img_scale_factor=img_scale_factor,
                img_crop_offset=img_crop_offset,
                img_flip=img_flip,
                # All images share the padded batch tensor size.
                img_pad_shape=img.shape[-2:],
                img_shape=img_meta['img_shape'][:2],
                aligned=False)
            # Lay the per-point samples out as (N_z, N_y, N_x, -1). The
            # trailing -1 absorbs the feature channels. Then permute to
            # channels-first (C, N_x, N_y, N_z).
            volumes.append(
                volume.reshape(n_voxels_rev + [-1]).permute(3, 2, 1, 0))
        x = torch.stack(volumes)
        x = self.neck_3d(x)
        return x

    def forward_train(self, img, img_metas, gt_bboxes_3d, gt_labels_3d,
                      **kwargs):
        """Forward of training.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list[dict]): Image metas.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): GT bboxes of
                each batch item.
            gt_labels_3d (list[torch.Tensor]): GT class labels of each batch
                item.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            dict[str, torch.Tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(img, img_metas)
        x = self.bbox_head(x)
        losses = self.bbox_head.loss(*x, gt_bboxes_3d, gt_labels_3d, img_metas)
        return losses

    def forward_test(self, img, img_metas, **kwargs):
        """Forward of testing.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list[dict]): Image metas.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            list[dict]: Predicted 3D boxes, one dict per sample.
        """
        # Test-time augmentation is not supported; always run simple_test.
        return self.simple_test(img, img_metas)

    def simple_test(self, img, img_metas):
        """Test without augmentations.

        Args:
            img (torch.Tensor): Input images of shape (N, C_in, H, W).
            img_metas (list[dict]): Image metas.

        Returns:
            list[dict]: Predicted 3D boxes, one dict per sample, built by
                ``bbox3d2result`` from the head's (bboxes, scores, labels).
        """
        x = self.extract_feat(img, img_metas)
        x = self.bbox_head(x)
        bbox_list = self.bbox_head.get_bboxes(*x, img_metas)
        bbox_results = [
            bbox3d2result(det_bboxes, det_scores, det_labels)
            for det_bboxes, det_scores, det_labels in bbox_list
        ]
        return bbox_results

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with augmentations (not supported).

        Args:
            imgs (list[torch.Tensor]): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas.

        Raises:
            NotImplementedError: Always; ImVoxelNet has no TTA path.
        """
        raise NotImplementedError(
            'ImVoxelNet does not support test-time augmentation (aug_test).')