-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Enhance] Support dataset browsing on all datasets (#367)
* browse_dataset * move visualization config from 'configs/' to 'configs/_base_' * refine * refine * adding argument option & browsing testing split only * support get_loading_pipeline & browsing training split * support get_loading_pipeline & browsing training split * support get_loading_pipeline & browsing training split * add condition of RepeatDataset & support the usage of cfg-options & refine useful_tools.md * add condition of RepeatDataset & support the usage of cfg-options & refine useful_tools.md * enable dataset browsing with empty gt & update docs/useful_tools * support all 3d dataset * fix small typos * fix small bugs Co-authored-by: Wuziyi616 <[email protected]>
- Loading branch information
Showing
4 changed files
with
205 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D, | ||
LoadAnnotations3D, | ||
LoadMultiViewImageFromFiles, | ||
LoadPointsFromFile, | ||
LoadPointsFromMultiSweeps) | ||
from mmdet.datasets.builder import PIPELINES | ||
from mmdet.datasets.pipelines import LoadImageFromFile | ||
|
||
|
||
def get_loading_pipeline(pipeline): | ||
"""Only keep loading image, points and annotations related configuration. | ||
Args: | ||
pipeline (list[dict]): Data pipeline configs. | ||
Returns: | ||
list[dict]: The new pipeline list with only keep | ||
loading image, points and annotations related configuration. | ||
Examples: | ||
>>> pipelines = [ | ||
... dict(type='LoadPointsFromFile', | ||
... coord_type='LIDAR', load_dim=4, use_dim=4), | ||
... dict(type='LoadImageFromFile'), | ||
... dict(type='LoadAnnotations3D', | ||
... with_bbox=True, with_label_3d=True), | ||
... dict(type='Resize', | ||
... img_scale=[(640, 192), (2560, 768)], keep_ratio=True), | ||
... dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5), | ||
... dict(type='PointsRangeFilter', | ||
... point_cloud_range=point_cloud_range), | ||
... dict(type='ObjectRangeFilter', | ||
... point_cloud_range=point_cloud_range), | ||
... dict(type='PointShuffle'), | ||
... dict(type='Normalize', **img_norm_cfg), | ||
... dict(type='Pad', size_divisor=32), | ||
... dict(type='DefaultFormatBundle3D', class_names=class_names), | ||
... dict(type='Collect3D', | ||
... keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d']) | ||
... ] | ||
>>> expected_pipelines = [ | ||
... dict(type='LoadPointsFromFile', | ||
... coord_type='LIDAR', load_dim=4, use_dim=4), | ||
... dict(type='LoadImageFromFile'), | ||
... dict(type='LoadAnnotations3D', | ||
... with_bbox=True, with_label_3d=True), | ||
... dict(type='DefaultFormatBundle3D', class_names=class_names), | ||
... dict(type='Collect3D', | ||
... keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d']) | ||
... ] | ||
>>> assert expected_pipelines ==\ | ||
... get_loading_pipeline(pipelines) | ||
""" | ||
loading_pipeline_cfg = [] | ||
for cfg in pipeline: | ||
obj_cls = PIPELINES.get(cfg['type']) | ||
# TODO: use more elegant way to distinguish loading modules | ||
if obj_cls is not None and obj_cls in ( | ||
LoadImageFromFile, LoadPointsFromFile, LoadAnnotations3D, | ||
LoadMultiViewImageFromFiles, LoadPointsFromMultiSweeps, | ||
DefaultFormatBundle3D, Collect3D): | ||
loading_pipeline_cfg.append(cfg) | ||
assert len(loading_pipeline_cfg) > 0, \ | ||
'The data pipeline in your config file must include ' \ | ||
'loading step.' | ||
return loading_pipeline_cfg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import argparse | ||
from mmcv import Config, DictAction, mkdir_or_exist, track_iter_progress | ||
from os import path as osp | ||
|
||
from mmdet3d.core.bbox import Box3DMode, Coord3DMode | ||
from mmdet3d.core.visualizer.open3d_vis import Visualizer | ||
from mmdet3d.datasets import build_dataset, get_loading_pipeline | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description='Browse a dataset') | ||
parser.add_argument('config', help='train config file path') | ||
parser.add_argument( | ||
'--skip-type', | ||
type=str, | ||
nargs='+', | ||
default=['Normalize'], | ||
help='skip some useless pipeline') | ||
parser.add_argument( | ||
'--output-dir', | ||
default=None, | ||
type=str, | ||
help='If there is no display interface, you can save it') | ||
parser.add_argument( | ||
'--cfg-options', | ||
nargs='+', | ||
action=DictAction, | ||
help='override some settings in the used config, the key-value pair ' | ||
'in xxx=yyy format will be merged into config file. If the value to ' | ||
'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' | ||
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' | ||
'Note that the quotation marks are necessary and that no white space ' | ||
'is allowed.') | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def retrieve_data_cfg(config_path, skip_type, cfg_options): | ||
cfg = Config.fromfile(config_path) | ||
if cfg_options is not None: | ||
cfg.merge_from_dict(cfg_options) | ||
# import modules from string list. | ||
if cfg.get('custom_imports', None): | ||
from mmcv.utils import import_modules_from_strings | ||
import_modules_from_strings(**cfg['custom_imports']) | ||
if cfg.data.train['type'] == 'RepeatDataset': | ||
train_data_cfg = cfg.data.train.dataset | ||
else: | ||
train_data_cfg = cfg.data.train | ||
train_data_cfg['pipeline'] = [ | ||
x for x in train_data_cfg.pipeline if x['type'] not in skip_type | ||
] | ||
|
||
return cfg | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
|
||
if args.output_dir is not None: | ||
mkdir_or_exist(args.output_dir) | ||
|
||
cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options) | ||
if cfg.data.train['type'] == 'RepeatDataset': | ||
cfg.data.train.dataset['pipeline'] = get_loading_pipeline( | ||
cfg.train_pipeline) | ||
else: | ||
cfg.data.train['pipeline'] = get_loading_pipeline(cfg.train_pipeline) | ||
dataset = build_dataset( | ||
cfg.data.train, default_args=dict(filter_empty_gt=False)) | ||
# For RepeatDataset type, the infos are stored in dataset.dataset | ||
if cfg.data.train['type'] == 'RepeatDataset': | ||
dataset = dataset.dataset | ||
data_infos = dataset.data_infos | ||
|
||
for idx, data_info in enumerate(track_iter_progress(data_infos)): | ||
if cfg.dataset_type in ['KittiDataset', 'WaymoDataset']: | ||
pts_path = data_info['point_cloud']['velodyne_path'] | ||
elif cfg.dataset_type in ['ScanNetDataset', 'SUNRGBDDataset']: | ||
pts_path = data_info['pts_path'] | ||
elif cfg.dataset_type in ['NuScenesDataset', 'LyftDataset']: | ||
pts_path = data_info['lidar_path'] | ||
else: | ||
raise NotImplementedError( | ||
f'unsupported dataset type {cfg.dataset_type}') | ||
file_name = osp.splitext(osp.basename(pts_path))[0] | ||
save_path = osp.join(args.output_dir, | ||
f'{file_name}.png') if args.output_dir else None | ||
|
||
example = dataset.prepare_train_data(idx) | ||
points = example['points']._data.numpy() | ||
points = Coord3DMode.convert_point(points, Coord3DMode.LIDAR, | ||
Coord3DMode.DEPTH) | ||
gt_bboxes = dataset.get_ann_info(idx)['gt_bboxes_3d'].tensor | ||
if gt_bboxes is not None: | ||
gt_bboxes = Box3DMode.convert(gt_bboxes, Box3DMode.LIDAR, | ||
Box3DMode.DEPTH) | ||
|
||
vis = Visualizer(points, save_path='./show.png') | ||
vis.add_bboxes(bbox3d=gt_bboxes, bbox_color=(0, 0, 1)) | ||
|
||
vis.show(save_path) | ||
del vis | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |