diff --git a/mmdet3d/core/bbox/structures/cam_box3d.py b/mmdet3d/core/bbox/structures/cam_box3d.py index f87761afa3..0906371ef5 100644 --- a/mmdet3d/core/bbox/structures/cam_box3d.py +++ b/mmdet3d/core/bbox/structures/cam_box3d.py @@ -34,6 +34,41 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): boxes. """ + def __init__(self, + tensor, + box_dim=7, + with_yaw=True, + origin=(0.5, 1.0, 0.5)): + if isinstance(tensor, torch.Tensor): + device = tensor.device + else: + device = torch.device('cpu') + tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device) + if tensor.numel() == 0: + # Use reshape, so we don't end up creating a new tensor that + # does not depend on the inputs (and consequently confuses jit) + tensor = tensor.reshape((0, box_dim)).to( + dtype=torch.float32, device=device) + assert tensor.dim() == 2 and tensor.size(-1) == box_dim, tensor.size() + + if tensor.shape[-1] == 6: + # If the dimension of boxes is 6, we expand box_dim by padding + # 0 as a fake yaw and set with_yaw to False. + assert box_dim == 6 + fake_rot = tensor.new_zeros(tensor.shape[0], 1) + tensor = torch.cat((tensor, fake_rot), dim=-1) + self.box_dim = box_dim + 1 + self.with_yaw = False + else: + self.box_dim = box_dim + self.with_yaw = with_yaw + self.tensor = tensor + + if origin != (0.5, 1.0, 0.5): + dst = self.tensor.new_tensor((0.5, 1.0, 0.5)) + src = self.tensor.new_tensor(origin) + self.tensor[:, :3] += self.tensor[:, 3:6] * (dst - src) + @property def height(self): """torch.Tensor: A vector with height of each box."""