[Fix] Unify camera poses #653

Merged
merged 7 commits on Jun 30, 2021
Changes from 5 commits
6 changes: 3 additions & 3 deletions configs/imvotenet/imvotenet_stage2_16x8_sunrgbd-3d-10class.py
@@ -193,7 +193,7 @@
type='Collect3D',
keys=[
'img', 'gt_bboxes', 'gt_labels', 'points', 'gt_bboxes_3d',
'gt_labels_3d', 'calib'
'gt_labels_3d'
])
]

@@ -230,7 +230,7 @@
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['img', 'points', 'calib'])
dict(type='Collect3D', keys=['img', 'points'])
]),
]
# construct a pipeline for data and gt loading in show function
@@ -247,7 +247,7 @@
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['img', 'points', 'calib'])
dict(type='Collect3D', keys=['img', 'points'])
]

data = dict(
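Note (illustrative, not part of the config diff): after this change the camera pose enters the pipeline through the dataset rather than the pipeline keys. The dataset precomposes a single depth-to-image matrix and Collect3D carries it in img_metas, which is why 'calib' disappears from the data keys above. A minimal sketch:

# in SUNRGBDDataset.get_data_info (see the dataset diff further down), roughly:
#     depth2img = K @ flip @ Rt.T
#     input_dict['depth2img'] = depth2img
# in the pipeline, only tensors remain as data keys; the matrix travels as a
# default Collect3D meta key
dict(type='Collect3D', keys=['img', 'points'])
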
25 changes: 10 additions & 15 deletions mmdet3d/apis/inference.py
@@ -155,23 +155,25 @@ def inference_multi_modality_detector(model, pcd, image, ann_file):
bbox_fields=[],
mask_fields=[],
seg_fields=[])

# depth map points to image conversion
if box_mode_3d == Box3DMode.DEPTH:
data.update(dict(calib=info['calib']))

data = test_pipeline(data)

# TODO: this code is dataset-specific. Move lidar2img and
# depth2img to .pkl annotations in the future.
# LiDAR to image conversion
if box_mode_3d == Box3DMode.LIDAR:
rect = info['calib']['R0_rect'].astype(np.float32)
Trv2c = info['calib']['Tr_velo_to_cam'].astype(np.float32)
P2 = info['calib']['P2'].astype(np.float32)
lidar2img = P2 @ rect @ Trv2c
data['img_metas'][0].data['lidar2img'] = lidar2img
# Depth to image conversion
elif box_mode_3d == Box3DMode.DEPTH:
data['calib'][0]['Rt'] = data['calib'][0]['Rt'].astype(np.float32)
data['calib'][0]['K'] = data['calib'][0]['K'].astype(np.float32)
rt_mat = info['calib']['Rt']
# follow Coord3DMode.convert_point
rt_mat = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]
]) @ rt_mat.transpose(1, 0)
depth2img = info['calib']['K'] @ rt_mat
data['img_metas'][0].data['depth2img'] = depth2img

data = collate([data], samples_per_gpu=1)
if next(model.parameters()).is_cuda:
@@ -182,9 +184,6 @@ def inference_multi_modality_detector(model, pcd, image, ann_file):
data['img_metas'] = data['img_metas'][0].data
data['points'] = data['points'][0].data
data['img'] = data['img'][0].data
if box_mode_3d == Box3DMode.DEPTH:
data['calib'][0]['Rt'] = data['calib'][0]['Rt'][0].data
data['calib'][0]['K'] = data['calib'][0]['K'][0].data

# forward the model
with torch.no_grad():
@@ -411,17 +410,13 @@ def show_proj_det_result_meshlab(data,
box_mode='lidar',
show=show)
elif box_mode == Box3DMode.DEPTH:
if 'calib' not in data.keys():
raise NotImplementedError(
'camera calibration information is not provided')

show_bboxes = DepthInstance3DBoxes(pred_bboxes, origin=(0.5, 0.5, 0))

show_multi_modality_result(
img,
None,
show_bboxes,
data['calib'][0],
None,
out_dir,
file_name,
box_mode='depth',
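For readers following this change, a self-contained sketch of the two projection matrices the inference code now precomputes; the function names are illustrative, the math is copied from the hunk above (KITTI-style calib for LiDAR, SUN RGB-D-style K/Rt for depth):

import numpy as np

def build_lidar2img(calib):
    rect = calib['R0_rect'].astype(np.float32)           # rectification
    Trv2c = calib['Tr_velo_to_cam'].astype(np.float32)   # LiDAR -> cam
    P2 = calib['P2'].astype(np.float32)                  # 3x4 (or padded 4x4) projection
    return P2 @ rect @ Trv2c

def build_depth2img(calib):
    rt_mat = calib['Rt'].astype(np.float32)              # 3x3 extrinsic
    # depth -> camera axis permutation used by Coord3DMode.convert_point
    flip = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=np.float32)
    return calib['K'].astype(np.float32) @ flip @ rt_mat.transpose(1, 0)
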
7 changes: 0 additions & 7 deletions mmdet3d/core/bbox/structures/coord_3d_mode.py
@@ -227,15 +227,8 @@ def convert_point(point, src, dst, rt_mat=None):
if rt_mat is None:
rt_mat = arr.new_tensor([[0, 0, 1], [-1, 0, 0], [0, -1, 0]])
elif src == Coord3DMode.DEPTH and dst == Coord3DMode.CAM:
# LIDAR-CAM conversion is different from DEPTH-CAM conversion
# because SUNRGB-D camera calibration files are different from
# that of KITTI, and currently we keep this hack
if rt_mat is None:
rt_mat = arr.new_tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
else:
rt_mat = rt_mat.new_tensor(
[[1, 0, 0], [0, 0, -1], [0, 1, 0]]) @ \
rt_mat.transpose(1, 0)
elif src == Coord3DMode.CAM and dst == Coord3DMode.DEPTH:
if rt_mat is None:
rt_mat = arr.new_tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
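With the SUN RGB-D-specific branch removed, a DEPTH to CAM conversion that needs a dataset extrinsic is expected to arrive with that extrinsic already composed into rt_mat by the caller. A minimal sketch of the default path that remains, assuming torch tensors:

import torch

def depth_to_cam_points(points_depth, rt_mat=None):
    # default axis permutation: cam_x = x, cam_y = -z, cam_z = y
    if rt_mat is None:
        rt_mat = points_depth.new_tensor([[1., 0., 0.],
                                          [0., 0., -1.],
                                          [0., 1., 0.]])
    # dataset-specific extrinsics (e.g. SUN RGB-D Rt) are now composed into
    # rt_mat by the caller before invoking convert_point
    return points_depth @ rt_mat.t()

pts = torch.tensor([[1., 2., 3.]])
print(depth_to_cam_points(pts))   # tensor([[ 1., -3.,  2.]])
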
8 changes: 6 additions & 2 deletions mmdet3d/core/bbox/structures/utils.py
@@ -111,12 +111,14 @@ def get_box_type(box_type):
return box_type_3d, box_mode_3d


def points_cam2img(points_3d, proj_mat):
def points_cam2img(points_3d, proj_mat, return_z=False):
"""Project points from camera coordicates to image coordinates.

Args:
points_3d (torch.Tensor): Points in shape (N, 3)
points_3d (torch.Tensor): Points in shape (N, 3).
proj_mat (torch.Tensor): Transformation matrix between coordinates.
return_z (bool, optional): Return third dimension if True.
Defaults to False.

Returns:
torch.Tensor: Points in image coordinates with shape [N, 2].
@@ -141,6 +143,8 @@ def points_cam2img(points_3d, proj_mat):
[points_3d, points_3d.new_ones(*points_shape)], dim=-1)
point_2d = torch.matmul(points_4, proj_mat.t())
point_2d_res = point_2d[..., :2] / point_2d[..., 2:3]
if return_z:
return point_2d_res, point_2d[..., 2]
return point_2d_res


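A usage sketch of the extended signature; the variable names and the identity projection matrix are placeholders:

import torch
from mmdet3d.core.bbox import points_cam2img

corners_cam = torch.rand(8, 3) + torch.tensor([0., 0., 1.])  # keep z positive
proj_mat = torch.eye(4)        # placeholder; 3x3 / 3x4 / 4x4 matrices should work
uv = points_cam2img(corners_cam, proj_mat)                    # (8, 2) pixel coords
uv, depth = points_cam2img(corners_cam, proj_mat, return_z=True)
# `depth` is the z component before the perspective divide, handy for masking
# out corners that fall behind the camera before drawing
valid = depth > 0
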
18 changes: 3 additions & 15 deletions mmdet3d/core/visualizer/image_vis.py
@@ -120,6 +120,7 @@ def draw_lidar_bbox3d_on_img(bboxes3d,
return plot_rect3d_on_img(img, num_bbox, imgfov_pts_2d, color, thickness)


# TODO: remove third parameter in all functions here in favour of img_metas
def draw_depth_bbox3d_on_img(bboxes3d,
raw_img,
calibs,
@@ -137,35 +138,22 @@
color (tuple[int]): The color to draw bboxes. Default: (0, 255, 0).
thickness (int, optional): The thickness of bboxes. Default: 1.
"""
from mmdet3d.core import Coord3DMode
from mmdet3d.core.bbox import points_cam2img
from mmdet3d.models import apply_3d_transformation

img = raw_img.copy()
calibs = copy.deepcopy(calibs)
img_metas = copy.deepcopy(img_metas)
corners_3d = bboxes3d.corners
num_bbox = corners_3d.shape[0]
points_3d = corners_3d.reshape(-1, 3)
assert ('Rt' in calibs.keys() and 'K' in calibs.keys()), \
'Rt and K matrix should be provided as camera caliberation information'
if not isinstance(calibs['Rt'], torch.Tensor):
calibs['Rt'] = torch.from_numpy(np.array(calibs['Rt']))
if not isinstance(calibs['K'], torch.Tensor):
calibs['K'] = torch.from_numpy(np.array(calibs['K']))
calibs['Rt'] = calibs['Rt'].reshape(3, 3).float().cpu()
calibs['K'] = calibs['K'].reshape(3, 3).float().cpu()

# first reverse the data transformations
xyz_depth = apply_3d_transformation(
points_3d, 'DEPTH', img_metas, reverse=True)

# then convert from depth coords to camera coords
xyz_cam = Coord3DMode.convert_point(
xyz_depth, Coord3DMode.DEPTH, Coord3DMode.CAM, rt_mat=calibs['Rt'])

# project to 2d to get image coords (uv)
uv_origin = points_cam2img(xyz_cam, calibs['K'])
uv_origin = points_cam2img(xyz_depth,
xyz_depth.new_tensor(img_metas['depth2img']))
uv_origin = (uv_origin - 1).round()
imgfov_pts_2d = uv_origin[..., :2].reshape(num_bbox, 8, 2).numpy()

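A condensed sketch of the projection path that replaces the Rt/K handling above, assuming img_metas already carries the composed depth2img matrix provided by the dataset:

from mmdet3d.core.bbox import points_cam2img
from mmdet3d.models import apply_3d_transformation

def project_depth_corners(bboxes3d, img_metas):
    # (num_bbox, 8, 3) box corners in (augmented) depth coordinates
    corners_3d = bboxes3d.corners
    num_bbox = corners_3d.shape[0]
    points_3d = corners_3d.reshape(-1, 3)
    # undo the pipeline augmentations so the points match the raw image
    xyz_depth = apply_3d_transformation(
        points_3d, 'DEPTH', img_metas, reverse=True)
    # one matrix now replaces the old Rt (depth -> cam) + K (cam -> img) pair
    uv = points_cam2img(xyz_depth,
                        xyz_depth.new_tensor(img_metas['depth2img']))
    return (uv - 1).round().reshape(num_bbox, 8, 2)
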
3 changes: 2 additions & 1 deletion mmdet3d/datasets/pipelines/formating.py
@@ -100,6 +100,7 @@ class Collect3D(object):
- 'ori_shape': original shape of the image as a tuple (h, w, c)
- 'pad_shape': image shape after padding
- 'lidar2img': transform from lidar to image
- 'depth2img': transform from depth to image
- 'pcd_horizontal_flip': a boolean indicating if point cloud is \
flipped horizontally
- 'pcd_vertical_flip': a boolean indicating if point cloud is \
@@ -134,7 +135,7 @@ class Collect3D(object):
def __init__(self,
keys,
meta_keys=('filename', 'ori_shape', 'img_shape', 'lidar2img',
'pad_shape', 'scale_factor', 'flip',
'depth2img', 'pad_shape', 'scale_factor', 'flip',
'cam_intrinsic', 'pcd_horizontal_flip',
'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d',
'img_norm_cfg', 'rect', 'Trv2c', 'P2', 'pcd_trans',
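A small sketch of the effect on collected samples (keys and shapes are illustrative): 'depth2img' is now picked up automatically as a meta key, so pipelines no longer have to request 'calib' as a data key.

import numpy as np
from mmdet3d.datasets.pipelines import Collect3D

results = dict(
    img=np.zeros((530, 730, 3), dtype=np.float32),
    points=np.zeros((1000, 6), dtype=np.float32),
    depth2img=np.eye(3, dtype=np.float32))
collected = Collect3D(keys=['img', 'points'])(results)
# img_metas (a DataContainer) now includes 'depth2img'; no 'calib' entry
print(collected['img_metas'])
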
15 changes: 10 additions & 5 deletions mmdet3d/datasets/sunrgbd_dataset.py
@@ -101,7 +101,12 @@ def get_data_info(self, index):
input_dict['img_prefix'] = None
input_dict['img_info'] = dict(filename=img_filename)
calib = info['calib']
input_dict['calib'] = calib
rt_mat = calib['Rt']
# follow Coord3DMode.convert_point
rt_mat = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]
]) @ rt_mat.transpose(1, 0)
depth2img = calib['K'] @ rt_mat
input_dict['depth2img'] = depth2img

if not self.test_mode:
annos = self.get_ann_info(index)
@@ -187,8 +192,8 @@ def show(self, results, out_dir, show=True, pipeline=None):
data_info = self.data_infos[i]
pts_path = data_info['pts_path']
file_name = osp.split(pts_path)[-1].split('.')[0]
points, img_metas, img, calib = self._extract_data(
i, pipeline, ['points', 'img_metas', 'img', 'calib'])
points, img_metas, img = self._extract_data(
i, pipeline, ['points', 'img_metas', 'img'])
# scale colors to [0, 255]
points = points.numpy()
points[:, 3:] *= 255
@@ -199,7 +204,7 @@
file_name, show)

# multi-modality visualization
if self.modality['use_camera'] and 'calib' in data_info.keys():
if self.modality['use_camera']:
img = img.numpy()
# need to transpose channel to first dim
img = img.transpose(1, 2, 0)
@@ -211,7 +216,7 @@
img,
gt_bboxes,
pred_bboxes,
calib,
None,
out_dir,
file_name,
box_mode='depth',
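A quick sanity check (not part of the PR) that the composed matrix reproduces the previous two-step projection, using placeholder intrinsics and an identity extrinsic:

import numpy as np

K = np.array([[529.5, 0., 365.], [0., 529.5, 265.], [0., 0., 1.]],
             dtype=np.float32)                 # SUN RGB-D-like intrinsics (example)
Rt = np.eye(3, dtype=np.float32)               # placeholder extrinsic
flip = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=np.float32)
depth2img = K @ flip @ Rt.transpose(1, 0)      # as in get_data_info above

pts = np.random.rand(5, 3).astype(np.float32) + np.array([0., 2., 0.],
                                                          dtype=np.float32)
# old path: depth -> cam via flip @ Rt.T, then cam -> img via K
uv_old = (K @ (flip @ Rt.T @ pts.T)).T
uv_old = uv_old[:, :2] / uv_old[:, 2:3]
# new path: a single precomposed matrix
uv_new = (depth2img @ pts.T).T
uv_new = uv_new[:, :2] / uv_new[:, 2:3]
assert np.allclose(uv_old, uv_new, atol=1e-5)
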
28 changes: 7 additions & 21 deletions mmdet3d/models/detectors/imvotenet.py
@@ -376,7 +376,6 @@ def forward_train(self,
gt_bboxes_ignore=None,
gt_masks=None,
proposals=None,
calib=None,
bboxes_2d=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
@@ -403,8 +402,6 @@
2d bbox, used if the architecture supports a segmentation task.
proposals: override rpn proposals (2d) with custom proposals.
Use when `with_rpn` is False.
calib (dict[str, torch.Tensor]): camera calibration matrices,
Rt and K.
bboxes_2d (list[torch.Tensor]): provided 2d bboxes,
not supported yet.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): 3d gt bboxes.
@@ -450,7 +447,7 @@
self.extract_pts_feat(points)

img_features, masks = self.fusion_layer(img, bboxes_2d, seeds_3d,
img_metas, calib)
img_metas)

inds = sample_valid_seeds(masks, self.num_sampled_seed)
batch_size, img_feat_size = img_features.shape[:2]
@@ -526,7 +523,6 @@ def forward_test(self,
points=None,
img_metas=None,
img=None,
calib=None,
bboxes_2d=None,
**kwargs):
"""Forwarding of test for image branch pretrain or stage 2 train.
@@ -544,9 +540,6 @@
list indicates test-time augmentations and inner Tensor
should have a shape NxCxHxW, which contains all images
in the batch. Defaults to None.
calibs (list[dict[str, torch.Tensor]], optional): camera
calibration matrices, Rt and K.
List indicates test-time augs. Defaults to None.
bboxes_2d (list[list[torch.Tensor]], optional):
Provided 2d bboxes, not supported yet. Defaults to None.

@@ -600,11 +593,10 @@
points[0],
img_metas[0],
img[0],
calibs=calib[0],
bboxes_2d=bboxes_2d[0] if bboxes_2d is not None else None,
**kwargs)
else:
return self.aug_test(points, img_metas, img, calib, bboxes_2d,
return self.aug_test(points, img_metas, img, bboxes_2d,
**kwargs)

def simple_test_img_only(self,
@@ -650,7 +642,6 @@ def simple_test(self,
points=None,
img_metas=None,
img=None,
calibs=None,
bboxes_2d=None,
rescale=False,
**kwargs):
@@ -664,8 +655,6 @@
images in a batch. Defaults to None.
img (torch.Tensor, optional): Should have a shape NxCxHxW,
which contains all images in the batch. Defaults to None.
calibs (dict[str, torch.Tensor], optional): camera
calibration matrices, Rt and K. Defaults to None.
bboxes_2d (list[torch.Tensor], optional):
Provided 2d bboxes, not supported yet. Defaults to None.
rescale (bool, optional): Whether or not rescale bboxes.
@@ -682,7 +671,7 @@
self.extract_pts_feat(points)

img_features, masks = self.fusion_layer(img, bboxes_2d, seeds_3d,
img_metas, calibs)
img_metas)

inds = sample_valid_seeds(masks, self.num_sampled_seed)
batch_size, img_feat_size = img_features.shape[:2]
@@ -753,7 +742,6 @@ def aug_test(self,
points=None,
img_metas=None,
imgs=None,
calibs=None,
bboxes_2d=None,
rescale=False,
**kwargs):
@@ -772,9 +760,6 @@
list indicates test-time augmentations and inner Tensor
should have a shape NxCxHxW, which contains all images
in the batch. Defaults to None.
calibs (list[dict[str, torch.Tensor]], optional): camera
calibration matrices, Rt and K.
List indicates test-time augs. Defaults to None.
bboxes_2d (list[list[torch.Tensor]], optional):
Provided 2d bboxes, not supported yet. Defaults to None.
rescale (bool, optional): Whether or not rescale bboxes.
@@ -788,16 +773,17 @@

# only support aug_test for one sample
aug_bboxes = []
for x, pts_cat, img_meta, bbox_2d, img, calib in zip(
feats, points_cat, img_metas, bboxes_2d, imgs, calibs):
for x, pts_cat, img_meta, bbox_2d, img in zip(feats, points_cat,
img_metas, bboxes_2d,
imgs):

bbox_2d = self.extract_bboxes_2d(
img, img_metas, train=False, bboxes_2d=bbox_2d, **kwargs)

seeds_3d, seed_3d_features, seed_indices = x

img_features, masks = self.fusion_layer(img, bbox_2d, seeds_3d,
img_metas, calib)
img_metas)

inds = sample_valid_seeds(masks, self.num_sampled_seed)
batch_size, img_feat_size = img_features.shape[:2]
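With calib removed from the detector signatures, multi-modality inference only needs the annotation file to supply calibration. A usage sketch; the paths are placeholders and the model-construction helper may be named init_model or init_detector depending on the mmdet3d version:

from mmdet3d.apis import inference_multi_modality_detector, init_model

model = init_model(
    'configs/imvotenet/imvotenet_stage2_16x8_sunrgbd-3d-10class.py',
    'checkpoints/imvotenet_sunrgbd.pth', device='cuda:0')
# calibration is read from the info .pkl and turned into depth2img inside
# inference_multi_modality_detector, so no separate calib argument is passed
result, data = inference_multi_modality_detector(
    model, 'demo/sunrgbd_000017.bin', 'demo/sunrgbd_000017.jpg',
    'demo/sunrgbd_000017_infos.pkl')
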