Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature]Support centerpoint #252

Merged
merged 25 commits into from
Apr 1, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
fac97d3
bump version to v0.4.0
RunningLeon Feb 28, 2022
4944564
[Enhancement] Make rewriter more powerful (#150)
SingleZombie Mar 1, 2022
63a998f
Torchscript support (#159)
AllentDan Mar 7, 2022
61b0146
Merge branch 'master' into dev-v0.4.0
Mar 8, 2022
31402ff
Update supported mmseg models (#181)
RunningLeon Mar 8, 2022
9553b8c
[Features]Support mmdet3d (#103)
VVsssssk Mar 10, 2022
987d48c
[Enhancement] Update pad logic in detection heads (#168)
Mar 14, 2022
00ca2a3
[Enhancement] Additional arguments support for OpenVINO Model Optimiz…
SemyonBevzuk Mar 15, 2022
167400c
[Enhancement] Switch to statically typed Value::Any (#209)
lzhangzz Mar 17, 2022
5e9f0cb
support for centerpoint
Mar 18, 2022
5ec6f73
Merge branch 'dev-v0.4.0' of https://github.com/open-mmlab/mmdeploy i…
Mar 18, 2022
ea0d0cf
[Enhancement] TensorRT DCN support (#205)
Mar 21, 2022
a3323ce
Merge branch 'dev-v0.4.0' of https://github.com/open-mmlab/mmdeploy i…
Mar 21, 2022
1a22a63
add docstring and dcn model support
Mar 22, 2022
896e512
add centerpoint ut and docs
Mar 23, 2022
ac046d9
Merge branch 'master' of https://github.com/open-mmlab/mmdeploy into …
Mar 24, 2022
fd44c6c
Merge branch 'dev-v0.4.0' of https://github.com/open-mmlab/mmdeploy i…
Mar 30, 2022
bc4449b
add config and fix input rank
Mar 30, 2022
66054f8
fix merge error
Mar 30, 2022
28d240d
fix a bug
Mar 30, 2022
b700f38
Merge branch 'dev-v0.4.0' of https://github.com/open-mmlab/mmdeploy i…
Mar 31, 2022
807331e
fix comment
Apr 1, 2022
47b8e6e
[Doc] update benchmark add supported-model-list (#286)
AllentDan Apr 1, 2022
ca68e87
Merge branch 'master' of https://github.com/open-mmlab/mmdeploy into …
Apr 1, 2022
61c0654
fix ut
Apr 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
_base_ = ['./voxel-detection_dynamic.py', '../../_base_/backends/openvino.py']

onnx_config = dict(input_shape=None)

backend_config = dict(model_inputs=[
dict(
opt_shapes=dict(
voxels=[20000, 20, 5], num_points=[20000], coors=[20000, 4]))
])
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
_base_ = ['./voxel-detection_dynamic.py', '../../_base_/backends/tensorrt.py']
backend_config = dict(
common_config=dict(max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
voxels=dict(
min_shape=[5000, 20, 5],
opt_shape=[20000, 20, 5],
max_shape=[30000, 20, 5]),
num_points=dict(
min_shape=[5000], opt_shape=[20000], max_shape=[30000]),
coors=dict(
min_shape=[5000, 4],
opt_shape=[20000, 4],
max_shape=[30000, 4]),
))
])
2 changes: 2 additions & 0 deletions docs/en/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ The table below lists the models that are guaranteed to be exportable to other b
| MSPN | MMPose | ? | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) |
| LiteHRNet | MMPose | ? | Y | Y | N | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) |
| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |


### Note

Expand Down
31 changes: 23 additions & 8 deletions mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.nn import functional as F

from mmdeploy.codebase.base import BaseBackendModel
from mmdeploy.core import RewriterContext
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
get_root_logger, load_config)

Expand Down Expand Up @@ -92,7 +93,8 @@ def forward(self,
'coors': coors
}
outputs = self.wrapper(input_dict)
result = VoxelDetectionModel.post_process(self.model_cfg, outputs,
result = VoxelDetectionModel.post_process(self.model_cfg,
VVsssssk marked this conversation as resolved.
Show resolved Hide resolved
self.deploy_cfg, outputs,
img_metas[i],
self.device)[0]
result_list.append(result)
Expand Down Expand Up @@ -171,6 +173,7 @@ def voxelize(model_cfg: Union[str, mmcv.Config], points: torch.Tensor):

@staticmethod
def post_process(model_cfg: Union[str, mmcv.Config],
deploy_cfg: Union[str, mmcv.Config],
outs: torch.Tensor,
img_metas: Dict,
device: str,
Expand All @@ -179,6 +182,8 @@ def post_process(model_cfg: Union[str, mmcv.Config],

Args:
model_cfg (str | mmcv.Config): The model config.
deploy_cfg (str|mmcv.Config): Deployment config file or loaded
Config object.
outs (torch.Tensor): Output of model's head.
img_metas(Dict): Meta info for pcd.
device (str): A string specifying device type.
Expand All @@ -189,7 +194,13 @@ def post_process(model_cfg: Union[str, mmcv.Config],
from mmdet3d.core import bbox3d2result
from mmdet3d.models.builder import build_head
model_cfg = load_config(model_cfg)[0]
head_cfg = dict(**model_cfg.model['bbox_head'])
deploy_cfg = load_config(deploy_cfg)[0]
if 'bbox_head' in model_cfg.model.keys():
head_cfg = dict(**model_cfg.model['bbox_head'])
elif 'pts_bbox_head' in model_cfg.model.keys():
head_cfg = dict(**model_cfg.model['pts_bbox_head'])
else:
raise
VVsssssk marked this conversation as resolved.
Show resolved Hide resolved
head_cfg['train_cfg'] = None
head_cfg['test_cfg'] = model_cfg.model['test_cfg']
head = build_head(head_cfg)
Expand All @@ -206,12 +217,16 @@ def post_process(model_cfg: Union[str, mmcv.Config],
cls_scores = [outs['scores'].to(device)]
bbox_preds = [outs['bbox_preds'].to(device)]
dir_scores = [outs['dir_scores'].to(device)]
bbox_list = head.get_bboxes(
cls_scores, bbox_preds, dir_scores, img_metas, rescale=False)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
with RewriterContext(
cfg=deploy_cfg,
backend=deploy_cfg.backend_config.type,
opset=deploy_cfg.onnx_config.opset_version):
bbox_list = head.get_bboxes(
cls_scores, bbox_preds, dir_scores, img_metas, rescale=False)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results


Expand Down
2 changes: 2 additions & 0 deletions mmdeploy/codebase/mmdet3d/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import * # noqa: F401,F403
from .centerpoint import * # noqa: F401,F403
from .mvx_two_stage import * # noqa: F401,F403
from .pillar_encode import * # noqa: F401,F403
from .pillar_scatter import * # noqa: F401,F403
from .voxelnet import * # noqa: F401,F403
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmdet3d/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def base3ddetector__forward_test(ctx,
coors,
img_metas=None,
img=None,
rescale=True):
rescale=False):
"""Rewrite this function to run simple_test directly."""
return self.simple_test(voxels, num_points, coors, img_metas, img)

Expand Down
198 changes: 198 additions & 0 deletions mmdeploy/codebase/mmdet3d/models/centerpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet3d.core import circle_nms

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.detectors.centerpoint.CenterPoint.extract_pts_feat')
def centerpoint__extract_pts_feat(ctx, self, voxels, num_points, coors,
img_feats, img_metas):
"""Extract features from points. Rewrite this func to remove voxelize op.

Args:
voxels (torch.Tensor): Point features or raw points in shape (N, M, C).
num_points (torch.Tensor): Number of points in each voxel.
coors (torch.Tensor): Coordinates of each voxel.
img_feats (list[torch.Tensor], optional): Image features used for
multi-modality fusion. Defaults to None.
img_metas (list[dict]): Meta information of samples.

Returns:
torch.Tensor: Points feature.
"""
if not self.with_pts_bbox:
return None

voxel_features = self.pts_voxel_encoder(voxels, num_points, coors)
batch_size = coors[-1, 0] + 1
x = self.pts_middle_encoder(voxel_features, coors, batch_size)
x = self.pts_backbone(x)
if self.with_pts_neck:
x = self.pts_neck(x)
return x


@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.detectors.centerpoint.CenterPoint.simple_test_pts')
def centerpoint__simple_test_pts(ctx, self, x, img_metas, rescale=False):
"""Rewrite this func to format model outputs.

Args:
x (torch.Tensor): Input points feature.
img_metas (list[dict]): Meta information of samples.
rescale (bool): Whether need rescale.

Returns:
List: Result of model.
"""
outs = self.pts_bbox_head(x)
bbox_preds, scores, dir_scores = [], [], []
for task_res in outs:
bbox_preds.append(task_res[0]['reg'])
bbox_preds.append(task_res[0]['height'])
bbox_preds.append(task_res[0]['dim'])
if 'vel' in task_res[0].keys():
bbox_preds.append(task_res[0]['vel'])
scores.append(task_res[0]['heatmap'])
dir_scores.append(task_res[0]['rot'])
bbox_preds = torch.cat(bbox_preds, dim=1)
scores = torch.cat(scores, dim=1)
dir_scores = torch.cat(dir_scores, dim=1)
return scores, bbox_preds, dir_scores


@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.dense_heads.centerpoint_head.CenterHead.get_bboxes')
def centerpoint__get_bbox(ctx,
VVsssssk marked this conversation as resolved.
Show resolved Hide resolved
self,
cls_scores,
bbox_preds,
dir_scores,
img_metas,
img=None,
rescale=False):
"""Rewrite this func to format func inputs.

Args
cls_scores (list[torch.Tensor]): Classification predicts results.
bbox_preds (list[torch.Tensor]): Bbox predicts results.
dir_scores (list[torch.Tensor]): Dir predicts results.
img_metas (list[dict]): Point cloud and image's meta info.
img (torch.Tensor): Input image.
rescale (Bool): Whether need rescale.

Returns:
list[dict]: Decoded bbox, scores and labels after nms.
"""
rets = []
# common_heads = self.task_heads[0].heads
batch_size = 1
scores_range = [0]
bbox_range = [0]
dir_range = [0]
self.test_cfg = self.test_cfg['pts']
for i, task_head in enumerate(self.task_heads):
scores_range.append(scores_range[i] + self.num_classes[i])
bbox_range.append(bbox_range[i] +
8 if 'vel' in task_head.heads.keys() else 6)
dir_range.append(dir_range[i] + 2)
for task_id in range(len(self.num_classes)):
num_class_with_bg = self.num_classes[task_id]

batch_heatmap = cls_scores[
0][:, scores_range[task_id]:scores_range[task_id + 1],
...].sigmoid()

batch_reg = bbox_preds[0][:,
bbox_range[task_id]:bbox_range[task_id] + 2,
...]
batch_hei = bbox_preds[0][:, bbox_range[task_id] +
2:bbox_range[task_id] + 3, ...]

if self.norm_bbox:
batch_dim = torch.exp(bbox_preds[0][:, bbox_range[task_id] +
3:bbox_range[task_id] + 6,
...])
else:
batch_dim = bbox_preds[0][:, bbox_range[task_id] +
3:bbox_range[task_id] + 6, ...]

batch_vel = bbox_preds[0][:, bbox_range[task_id] +
6:bbox_range[task_id + 1], ...]

batch_rots = dir_scores[0][:,
dir_range[task_id]:dir_range[task_id + 1],
...][:, 0].unsqueeze(1)
batch_rotc = dir_scores[0][:,
dir_range[task_id]:dir_range[task_id + 1],
...][:, 1].unsqueeze(1)

# if 'vel' in preds_dict[0]:
VVsssssk marked this conversation as resolved.
Show resolved Hide resolved
# batch_vel = preds_dict[0]['vel']
# else:
# batch_vel = None

temp = self.bbox_coder.decode(
batch_heatmap,
batch_rots,
batch_rotc,
batch_hei,
batch_dim,
batch_vel,
reg=batch_reg,
task_id=task_id)
assert self.test_cfg['nms_type'] in ['circle', 'rotate']
batch_reg_preds = [box['bboxes'] for box in temp]
batch_cls_preds = [box['scores'] for box in temp]
batch_cls_labels = [box['labels'] for box in temp]
if self.test_cfg['nms_type'] == 'circle':
ret_task = []
for i in range(batch_size):
VVsssssk marked this conversation as resolved.
Show resolved Hide resolved
boxes3d = temp[i]['bboxes']
scores = temp[i]['scores']
labels = temp[i]['labels']
centers = boxes3d[:, [0, 1]]
boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
keep = torch.tensor(
circle_nms(
boxes.detach().cpu().numpy(),
self.test_cfg['min_radius'][task_id],
post_max_size=self.test_cfg['post_max_size']),
dtype=torch.long,
device=boxes.device)

boxes3d = boxes3d[keep]
scores = scores[keep]
labels = labels[keep]
ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
ret_task.append(ret)
rets.append(ret_task)
else:
rets.append(
self.get_task_detections(num_class_with_bg, batch_cls_preds,
batch_reg_preds, batch_cls_labels,
img_metas))

# Merge branches results
num_samples = len(rets[0])

ret_list = []
for i in range(num_samples):
for k in rets[0][i].keys():
if k == 'bboxes':
bboxes = torch.cat([ret[i][k] for ret in rets])
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
bboxes = img_metas[i]['box_type_3d'](bboxes,
self.bbox_coder.code_size)
elif k == 'scores':
scores = torch.cat([ret[i][k] for ret in rets])
elif k == 'labels':
flag = 0
for j, num_class in enumerate(self.num_classes):
rets[j][i][k] += flag
flag += num_class
labels = torch.cat([ret[i][k].int() for ret in rets])
ret_list.append([bboxes, scores, labels])
return ret_list
54 changes: 54 additions & 0 deletions mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.simple_test')
def mvxtwostagedetector__simple_test(ctx,
self,
voxels,
num_points,
coors,
img_metas,
img=None,
rescale=False):
"""Rewrite this func to remove voxelize op.

Args:
voxels (torch.Tensor): Point features or raw points in shape (N, M, C).
num_points (torch.Tensor): Number of points in each voxel.
coors (torch.Tensor): Coordinates of each voxel.
img_metas (list[dict]): Meta information of samples.
img (torch.Tensor): Input image.
rescale (Bool): Whether need rescale.

Returns:
list[dict]: Decoded bbox, scores and labels after nms.
"""
_, pts_feats = self.extract_feat(
voxels, num_points, coors, img=img, img_metas=img_metas)
if pts_feats and self.with_pts_bbox:
bbox_pts = self.simple_test_pts(pts_feats, img_metas, rescale=rescale)
return bbox_pts


@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.extract_feat')
def mvxtwostagedetector__extract_feat(ctx, self, voxels, num_points, coors,
img, img_metas):
"""Rewrite this func to remove voxelize op.

Args:
voxels (torch.Tensor): Point features or raw points in shape (N, M, C).
num_points (torch.Tensor): Number of points in each voxel.
coors (torch.Tensor): Coordinates of each voxel.
img (torch.Tensor): Input image.
img_metas (list[dict]): Meta information of samples.

Returns:
tuple(torch.Tensor) : image feature and points feather.
"""
img_feats = self.extract_img_feat(img, img_metas)
pts_feats = self.extract_pts_feat(voxels, num_points, coors, img_feats,
img_metas)
return (img_feats, pts_feats)
4 changes: 2 additions & 2 deletions mmdeploy/codebase/mmdet3d/models/voxelnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def voxelnet__simple_test(ctx,
post process.

Args:
voxels(torch.Tensor): Point features or raw points in shape (N, M, C).
voxels (torch.Tensor): Point features or raw points in shape (N, M, C).
num_points (torch.Tensor): Number of points in each pillar.
coors (torch.Tensor): Coordinates of each voxel.
input_metas (list[dict]): Contain pcd meta info.
Expand All @@ -40,7 +40,7 @@ def voxelnet__extract_feat(ctx,
"""Extract features from points. Rewrite this func to remove voxelize op.

Args:
voxels(torch.Tensor): Point features or raw points in shape (N, M, C).
voxels (torch.Tensor): Point features or raw points in shape (N, M, C).
num_points (torch.Tensor): Number of points in each pillar.
coors (torch.Tensor): Coordinates of each voxel.
input_metas (list[dict]): Contain pcd meta info.
Expand Down
Loading