Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mg_head #49

Merged
merged 70 commits into from
Sep 19, 2020
Merged
Show file tree
Hide file tree
Changes from 63 commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
20c0a7f
Add modules.
Aug 3, 2020
3665044
Add test_center_head.
Aug 3, 2020
5901b96
Add docstring.
Aug 3, 2020
80e21df
Change comments.
Aug 4, 2020
efbfa0f
Add dcn_head.
Aug 4, 2020
0d0e4eb
Add doc_string.
Aug 4, 2020
f554c11
Add get_targets.
Aug 5, 2020
62e5803
Can use_get_targets.
Aug 5, 2020
04e2369
get_targets results aligned.
Aug 5, 2020
c914501
Use box_structure.
Aug 5, 2020
f10bade
Use get_targets_single.
Aug 5, 2020
d32e27c
Add docstring.
Aug 5, 2020
f3341ee
Fix dcn_center_head unittest.
Aug 5, 2020
ff3cb5e
Delete unnecessary unittest.
Aug 5, 2020
03084f3
Add docstring.
Aug 5, 2020
4b94c1f
Change format.
Aug 7, 2020
ceee32a
Add circle_nms.
Aug 10, 2020
d06d3ed
Change structure of mg_head.
Aug 10, 2020
a94ea14
Add bbox coder for centerpoint.
Aug 10, 2020
22588e0
Add docstrings.
Aug 10, 2020
ed43874
Add docstrings.
Aug 11, 2020
6dd23ff
Add get_bboxes and unittest.
Aug 11, 2020
3e1da83
Change docstring.
Aug 11, 2020
57d4dd1
Add img_metas.
Aug 12, 2020
f8a801d
Change bbox coder unittest.
Aug 12, 2020
f5273fe
Add task_detections.
Aug 12, 2020
0851793
Change docstring.
Aug 12, 2020
7b4936f
Change circle_nms to cpu.
Aug 13, 2020
a725de2
Change test_nms.
Aug 13, 2020
03b9d93
Change score_th, chang keep to long type.
Aug 14, 2020
9919ddd
Change docstring and unittest.
Aug 14, 2020
e6ba28e
Remove unnecessary things.
Aug 14, 2020
a162ef4
Move gaussian.
Aug 18, 2020
bf76a52
move clip_sigmoid, change dict.
Aug 19, 2020
8fa594b
Change config.
Aug 19, 2020
6020265
Change test_heads.
Aug 20, 2020
c3ba51c
Move weight initialization to init_weights func.
Aug 20, 2020
b09b831
Remove loc_loss_element adn==nd num_postive.
Aug 20, 2020
04b80a0
Change bboxes to the right format.
Aug 20, 2020
d49a7ae
Change loss and bbox order.
Aug 27, 2020
2bcf086
Update test_heads.
Aug 27, 2020
bfb894a
Change loss.
Aug 28, 2020
950ae47
Change names in mg_head, change head unittest.
Aug 30, 2020
ebd3c4c
Remove centerpoint_focal_loss, change docstring.
Aug 30, 2020
f505a79
Merge branch 'master' into add_mg_head
Aug 30, 2020
3e54739
Change topK default to 80.
Sep 1, 2020
d5d8432
Change boxes in test_nms. Change task_boxes defaults to None.
Sep 4, 2020
02010fc
Fix rotate nms bug.
Sep 6, 2020
502c4e7
Change docstring.
Sep 7, 2020
cb00d7b
Add docstring for get_task_detection and loss.
Sep 7, 2020
14087e7
Merge branch 'master' into add_mg_head
Sep 8, 2020
533841b
Remove gaussian funcs, change mg_head.
Sep 8, 2020
13ecc7a
Change gaussianfocalloss to mean.
Sep 8, 2020
0b6b3bf
change centerpoint_bbox_coder '/' to torch.div, fix centerhead unittest.
Sep 9, 2020
d59a641
Change div to '/'
Sep 9, 2020
cba1083
Change order in centerpoint_coder, change names, change dcn layer.
Sep 13, 2020
6208b93
Fix import in __init__
Sep 14, 2020
b9364bc
Add gaussian unittest.
Sep 14, 2020
b9be6e0
Remove np ops in mg_head.
Sep 14, 2020
e22f7f6
Update docstring.
Sep 14, 2020
cfd9875
Merge branch 'master' into add_mg_head
Sep 17, 2020
ab20a01
Fix docstring use config to build head.
Sep 18, 2020
4035e42
Remove **kwargs
Sep 18, 2020
63ad7d8
Remove unnecessary codes, change order of bboxes.
Sep 18, 2020
b1bda22
Remove '\' in args and pdb, change loss_bbox.
Sep 18, 2020
1caf185
Fix test_heads unittest.
Sep 18, 2020
092a64e
Remove unnecessary comments
Sep 19, 2020
9b3f6ec
Change bbox order in rotate nms.
Sep 19, 2020
a300691
Remove unnecessary attributes
Sep 19, 2020
c92e161
Change name, remove float
Sep 19, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mmdet3d/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from .bbox import * # noqa: F401, F403
from .evaluation import * # noqa: F401, F403
from .post_processing import * # noqa: F401, F403
from .utils import * # noqa: F401, F403
from .visualizer import * # noqa: F401, F403
from .voxel import * # noqa: F401, F403
3 changes: 2 additions & 1 deletion mmdet3d/core/bbox/coders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from mmdet.core.bbox import build_bbox_coder
from .anchor_free_bbox_coder import AnchorFreeBBoxCoder
from .centerpoint_bbox_coders import CenterPointBBoxCoder
from .delta_xyzwhlr_bbox_coder import DeltaXYZWLHRBBoxCoder
from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder

__all__ = [
'build_bbox_coder', 'DeltaXYZWLHRBBoxCoder', 'PartialBinBasedBBoxCoder',
'AnchorFreeBBoxCoder'
'CenterPointBBoxCoder', 'AnchorFreeBBoxCoder'
]
227 changes: 227 additions & 0 deletions mmdet3d/core/bbox/coders/centerpoint_bbox_coders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import torch

from mmdet.core.bbox import BaseBBoxCoder
from mmdet.core.bbox.builder import BBOX_CODERS


@BBOX_CODERS.register_module()
class CenterPointBBoxCoder(BaseBBoxCoder):
"""Bbox coder for CenterPoint.

Args:
pc_range (list[float]): Range of point cloud.
out_size_factor (int): Downsample factor of the model.
voxel_size (list[float]): Size of voxel.
post_center_range (list[float]): Limit of the center.
Default: None.
max_num (int): Max number to be kept. Default: 100.
score_threshold (float): Threshold to filter boxes based on score.
Default: None.
code_size (int): Code size of bboxes. Default: 9
"""

def __init__(self,
pc_range,
out_size_factor,
voxel_size,
post_center_range=None,
max_num=100,
score_threshold=None,
code_size=9):

self.pc_range = pc_range
self.out_size_factor = out_size_factor
self.voxel_size = voxel_size
self.post_center_range = post_center_range
self.max_num = max_num
self.score_threshold = score_threshold
self.code_size = code_size

def _gather_feat(self, feats, inds, feat_masks=None):
"""Given feats and indexes, returns the gathered feats.

Args:
feats (torch.Tensor): Features to be transposed and gathered
with the shape of [B, 2, W, H].
inds (torch.Tensor): Indexes with the shape of [B, N].
feat_masks (torch.Tensor): Mask of the feats. Default: None.

Returns:
torch.Tensor: Gathered feats.
"""
dim = feats.size(2)
inds = inds.unsqueeze(2).expand(inds.size(0), inds.size(1), dim)
feats = feats.gather(1, inds)
if feat_masks is not None:
feat_masks = feat_masks.unsqueeze(2).expand_as(feats)
feats = feats[feat_masks]
feats = feats.view(-1, dim)
return feats

def _topk(self, scores, K=80):
"""Get indexes based on scores.

Args:
scores (torch.Tensor): scores with the shape of [B, N, W, H].
K (int): Number to be kept. Defaults to 80.

Returns:
yinchimaoliang marked this conversation as resolved.
Show resolved Hide resolved
tuple[torch.Tensor]
torch.Tensor: Selected scores with the shape of [B, K].
torch.Tensor: Selected indexes with the shape of [B, K].
torch.Tensor: Selected classes with the shape of [B, K].
torch.Tensor: Selected y coord with the shape of [B, K].
torch.Tensor: Selected x coord with the shape of [B, K].
"""
batch, cat, height, width = scores.size()

topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)

topk_inds = topk_inds % (height * width)
topk_ys = (topk_inds.float() /
torch.tensor(width, dtype=torch.float)).int().float()
topk_xs = (topk_inds % width).int().float()

topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
topk_clses = (topk_ind / torch.tensor(K, dtype=torch.float)).int()
topk_inds = self._gather_feat(topk_inds.view(batch, -1, 1),
topk_ind).view(batch, K)
topk_ys = self._gather_feat(topk_ys.view(batch, -1, 1),
topk_ind).view(batch, K)
topk_xs = self._gather_feat(topk_xs.view(batch, -1, 1),
topk_ind).view(batch, K)

return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

def _transpose_and_gather_feat(self, feat, ind):
"""Given feats and indexes, returns the transposed and gathered feats.

Args:
feat (torch.Tensor): Features to be transposed and gathered
with the shape of [B, 2, W, H].
ind (torch.Tensor): Indexes with the shape of [B, N].

Returns:
torch.Tensor: Transposed and gathered feats.
"""
feat = feat.permute(0, 2, 3, 1).contiguous()
feat = feat.view(feat.size(0), -1, feat.size(3))
feat = self._gather_feat(feat, ind)
return feat

def encode(self):
pass

def decode(self,
heat,
rot_sine,
rot_cosine,
hei,
dim,
vel,
reg=None,
task_id=-1):
"""Decode bboxes.

Args:
heat (torch.Tensor): Heatmap with the shape of [B, N, W, H].
rot_sine (torch.Tensor): Sine of rotation with the shape of
[B, 1, W, H].
rot_cosine (torch.Tensor): Cosine of rotation with the shape of
[B, 1, W, H].
hei (torch.Tensor): Height of the boxes with the shape
of [B, 1, W, H].
dim (torch.Tensor): Dim of the boxes with the shape of
[B, 1, W, H].
vel (torch.Tensor): Velocity with the shape of [B, 1, W, H].
reg (torch.Tensor): Regression value of the boxes in 2D with
the shape of [B, 2, W, H]. Default: None.
task_id (int): Index of task. Default: -1.

Returns:
list[dict]: Decoded boxes.
"""
batch, cat, _, _ = heat.size()

scores, inds, clses, ys, xs = self._topk(heat, K=self.max_num)

if reg is not None:
reg = self._transpose_and_gather_feat(reg, inds)
reg = reg.view(batch, self.max_num, 2)
xs = xs.view(batch, self.max_num, 1) + reg[:, :, 0:1]
ys = ys.view(batch, self.max_num, 1) + reg[:, :, 1:2]
else:
xs = xs.view(batch, self.max_num, 1) + 0.5
ys = ys.view(batch, self.max_num, 1) + 0.5

# rotation value and direction label
rot_sine = self._transpose_and_gather_feat(rot_sine, inds)
rot_sine = rot_sine.view(batch, self.max_num, 1)

rot_cosine = self._transpose_and_gather_feat(rot_cosine, inds)
rot_cosine = rot_cosine.view(batch, self.max_num, 1)
rot = torch.atan2(rot_sine, rot_cosine)

# height in the bev
hei = self._transpose_and_gather_feat(hei, inds)
hei = hei.view(batch, self.max_num, 1)

# dim of the box
dim = self._transpose_and_gather_feat(dim, inds)
dim = dim.view(batch, self.max_num, 3)

# class label
clses = clses.view(batch, self.max_num).float()
scores = scores.view(batch, self.max_num)

xs = xs.view(
batch, self.max_num,
1) * self.out_size_factor * self.voxel_size[0] + self.pc_range[0]
ys = ys.view(
batch, self.max_num,
1) * self.out_size_factor * self.voxel_size[1] + self.pc_range[1]

if vel is None: # KITTI FORMAT
final_box_preds = torch.cat([xs, ys, hei, dim, rot], dim=2)
else: # exist velocity, nuscene format
vel = self._transpose_and_gather_feat(vel, inds)
vel = vel.view(batch, self.max_num, 2)
final_box_preds = torch.cat([xs, ys, hei, dim, rot, vel], dim=2)

final_scores = scores
final_preds = clses

# use score threshold
if self.score_threshold is not None:
thresh_mask = final_scores > self.score_threshold

if self.post_center_range is not None:
self.post_center_range = torch.tensor(
self.post_center_range, device=heat.device)
mask = (final_box_preds[..., :3] >=
self.post_center_range[:3]).all(2)
mask &= (final_box_preds[..., :3] <=
self.post_center_range[3:]).all(2)

predictions_dicts = []
for i in range(batch):
cmask = mask[i, :]
if self.score_threshold:
cmask &= thresh_mask[i]

boxes3d = final_box_preds[i, cmask]
scores = final_scores[i, cmask]
labels = final_preds[i, cmask]
predictions_dict = {
'bboxes': boxes3d,
'scores': scores,
'labels': labels
}

predictions_dicts.append(predictions_dict)
else:
raise NotImplementedError(
'Need to reorganize output as a batch, only '
'support post_center_range is not None for now!')

return predictions_dicts
4 changes: 2 additions & 2 deletions mmdet3d/core/post_processing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from mmdet.core.post_processing import (merge_aug_bboxes, merge_aug_masks,
merge_aug_proposals, merge_aug_scores,
multiclass_nms)
from .box3d_nms import aligned_3d_nms, box3d_multiclass_nms
from .box3d_nms import aligned_3d_nms, box3d_multiclass_nms, circle_nms
from .merge_augs import merge_aug_bboxes_3d

__all__ = [
'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes',
'merge_aug_scores', 'merge_aug_masks', 'box3d_multiclass_nms',
'aligned_3d_nms', 'merge_aug_bboxes_3d'
'aligned_3d_nms', 'merge_aug_bboxes_3d', 'circle_nms'
]
45 changes: 45 additions & 0 deletions mmdet3d/core/post_processing/box3d_nms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numba
import numpy as np
import torch

from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
Expand Down Expand Up @@ -134,3 +136,46 @@ def aligned_3d_nms(boxes, scores, classes, thresh):

indices = boxes.new_tensor(pick, dtype=torch.long)
return indices


@numba.jit(nopython=True)
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
def circle_nms(dets, thresh, post_max_size=83):
"""Circular NMS.

An object is only counted as positive if no other center
with a higher confidence exists within a radius r using a
bird-eye view distance metric.

Args:
dets (torch.Tensor): Detection results with the shape of [N, 3].
thresh (float): Value of threshold.
post_max_size (int): Max number of prediction to be kept. Defaults
to 83

Returns:
torch.Tensor: Indexes of the detections to be kept.
"""
x1 = dets[:, 0]
y1 = dets[:, 1]
scores = dets[:, 2]
order = scores.argsort()[::-1].astype(np.int32) # highest->lowest
ndets = dets.shape[0]
suppressed = np.zeros((ndets), dtype=np.int32)
keep = []
for _i in range(ndets):
i = order[_i] # start with highest score box
if suppressed[
i] == 1: # if any box have enough iou with this, remove it
continue
keep.append(i)
for _j in range(_i + 1, ndets):
j = order[_j]
if suppressed[j] == 1:
continue
# calculate center distance between i and j box
dist = (x1[i] - x1[j])**2 + (y1[i] - y1[j])**2

# ovr = inter / areas[j]
if dist <= thresh:
suppressed[j] = 1
return keep[:post_max_size]
3 changes: 3 additions & 0 deletions mmdet3d/core/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .gaussian import draw_heatmap_gaussian, gaussian_2d, gaussian_radius

__all__ = ['gaussian_2d', 'gaussian_radius', 'draw_heatmap_gaussian']
Loading