From 825f47a4791a72b30effa66b7eabca0cb7a36db3 Mon Sep 17 00:00:00 2001 From: Wenhao Wu <79644370+wHao-Wu@users.noreply.github.com> Date: Thu, 8 Apr 2021 19:49:43 +0800 Subject: [PATCH] [Enhance] Support dataset browsing on all datasets (#367) * browse_dataset * move visualization config from 'configs/' to 'configs/_base_' * refine * refine * adding argument option & browsing testing split only * support get_loading_pipeline & browsing training split * support get_loading_pipeline & browsing training split * support get_loading_pipeline & browsing training split * add condition of RepeatDataset & support the usage of cfg-options & refine useful_tools.md * add condition of RepeatDataset & support the usage of cfg-options & refine useful_tools.md * enable dataset browsing with empty gt & update docs/useful_tools * support all 3d dataset * fix small typos * fix small bugs Co-authored-by: Wuziyi616 --- docs/useful_tools.md | 30 ++++++++++ mmdet3d/datasets/__init__.py | 3 +- mmdet3d/datasets/utils.py | 66 +++++++++++++++++++++ tools/misc/browse_dataset.py | 107 +++++++++++++++++++++++++++++++++++ 4 files changed, 205 insertions(+), 1 deletion(-) create mode 100644 mmdet3d/datasets/utils.py create mode 100644 tools/misc/browse_dataset.py diff --git a/docs/useful_tools.md b/docs/useful_tools.md index ea1697994a..fe911ad7dd 100644 --- a/docs/useful_tools.md +++ b/docs/useful_tools.md @@ -49,8 +49,12 @@ time std over epochs is 0.0028 average iter time: 1.1959 s/iter ``` +  + # Visualization +## Results + To see the SUNRGBD, ScanNet or KITTI points and detection results, you can run the following command ```bash @@ -80,6 +84,26 @@ Or you can use 3D visualization software such as the [MeshLab](http://www.meshla **Notice**: The visualization API is a little unstable since we plan to refactor these parts together with MMDetection in the future. +## Dataset + +To browse the KITTI directly without inference, you can run the following command + +```shell +python tools/misc/browse_dataset.py ${CONFIG_FILE} --output-dir ${OUTPUT_DIR} +``` + +Sample config can be found in `configs/_base_/datasets/` folder. + +E.g., + +```shell +python tools/misc/browse_dataset.py configs/_base_/datasets/kitti-3d-3class.py +``` + +**Notice**: Once specifying `--output-dir`, the images of views specified by users will be saved when pressing _ESC_ in open3d window. + +  + # Model Complexity You can use `tools/analysis_tools/get_flops.py` in MMDetection, a script adapted from [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch), to compute the FLOPs and params of a given model. @@ -107,6 +131,8 @@ comparisons, but double check it before you adopt it in technical reports or pap 2. Some operators are not counted into FLOPs like GN and custom operators. Refer to [`mmcv.cnn.get_model_complexity_info()`](https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/utils/flops_counter.py) for details. 3. The FLOPs of two-stage detectors is dependent on the number of proposals. +  + # Model Conversion ## RegNet model to MMDetection @@ -150,6 +176,8 @@ python tools/model_converters/publish_model.py work_dirs/faster_rcnn/latest.pth The final output filename will be `faster_rcnn_r50_fpn_1x_20190801-{hash id}.pth`. +  + # Dataset Conversion `tools/data_converter/` contains tools to convert datasets to other formats. Most of them convert datasets to pickle based info files, like kitti, nuscenes and lyft. Waymo converter is used to reorganize waymo raw data like KITTI style. Users could refer to them for our approach to converting data format. It is also convenient to modify them to use as scripts like nuImages converter. @@ -169,6 +197,8 @@ python -u tools/data_converter/nuimage_converter.py --data-root ${DATA_ROOT} --v More details could be referred to the [doc](https://mmdetection3d.readthedocs.io/en/latest/data_preparation.html) for dataset preparation and [README](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/nuimages/README.md) for nuImages dataset. +  + # Miscellaneous ## Print the entire config diff --git a/mmdet3d/datasets/__init__.py b/mmdet3d/datasets/__init__.py index ea590801ba..b6ef6bce8d 100644 --- a/mmdet3d/datasets/__init__.py +++ b/mmdet3d/datasets/__init__.py @@ -15,6 +15,7 @@ from .scannet_dataset import ScanNetDataset, ScanNetSegDataset from .semantickitti_dataset import SemanticKITTIDataset from .sunrgbd_dataset import SUNRGBDDataset +from .utils import get_loading_pipeline from .waymo_dataset import WaymoDataset __all__ = [ @@ -27,5 +28,5 @@ 'LoadAnnotations3D', 'SUNRGBDDataset', 'ScanNetDataset', 'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset', 'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset', - 'BackgroundPointsFilter', 'VoxelBasedPointSampler' + 'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline' ] diff --git a/mmdet3d/datasets/utils.py b/mmdet3d/datasets/utils.py new file mode 100644 index 0000000000..41c1656f3f --- /dev/null +++ b/mmdet3d/datasets/utils.py @@ -0,0 +1,66 @@ +from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D, + LoadAnnotations3D, + LoadMultiViewImageFromFiles, + LoadPointsFromFile, + LoadPointsFromMultiSweeps) +from mmdet.datasets.builder import PIPELINES +from mmdet.datasets.pipelines import LoadImageFromFile + + +def get_loading_pipeline(pipeline): + """Only keep loading image, points and annotations related configuration. + + Args: + pipeline (list[dict]): Data pipeline configs. + + Returns: + list[dict]: The new pipeline list with only keep + loading image, points and annotations related configuration. + + Examples: + >>> pipelines = [ + ... dict(type='LoadPointsFromFile', + ... coord_type='LIDAR', load_dim=4, use_dim=4), + ... dict(type='LoadImageFromFile'), + ... dict(type='LoadAnnotations3D', + ... with_bbox=True, with_label_3d=True), + ... dict(type='Resize', + ... img_scale=[(640, 192), (2560, 768)], keep_ratio=True), + ... dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5), + ... dict(type='PointsRangeFilter', + ... point_cloud_range=point_cloud_range), + ... dict(type='ObjectRangeFilter', + ... point_cloud_range=point_cloud_range), + ... dict(type='PointShuffle'), + ... dict(type='Normalize', **img_norm_cfg), + ... dict(type='Pad', size_divisor=32), + ... dict(type='DefaultFormatBundle3D', class_names=class_names), + ... dict(type='Collect3D', + ... keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d']) + ... ] + >>> expected_pipelines = [ + ... dict(type='LoadPointsFromFile', + ... coord_type='LIDAR', load_dim=4, use_dim=4), + ... dict(type='LoadImageFromFile'), + ... dict(type='LoadAnnotations3D', + ... with_bbox=True, with_label_3d=True), + ... dict(type='DefaultFormatBundle3D', class_names=class_names), + ... dict(type='Collect3D', + ... keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d']) + ... ] + >>> assert expected_pipelines ==\ + ... get_loading_pipeline(pipelines) + """ + loading_pipeline_cfg = [] + for cfg in pipeline: + obj_cls = PIPELINES.get(cfg['type']) + # TODO: use more elegant way to distinguish loading modules + if obj_cls is not None and obj_cls in ( + LoadImageFromFile, LoadPointsFromFile, LoadAnnotations3D, + LoadMultiViewImageFromFiles, LoadPointsFromMultiSweeps, + DefaultFormatBundle3D, Collect3D): + loading_pipeline_cfg.append(cfg) + assert len(loading_pipeline_cfg) > 0, \ + 'The data pipeline in your config file must include ' \ + 'loading step.' + return loading_pipeline_cfg diff --git a/tools/misc/browse_dataset.py b/tools/misc/browse_dataset.py new file mode 100644 index 0000000000..214e9f6b06 --- /dev/null +++ b/tools/misc/browse_dataset.py @@ -0,0 +1,107 @@ +import argparse +from mmcv import Config, DictAction, mkdir_or_exist, track_iter_progress +from os import path as osp + +from mmdet3d.core.bbox import Box3DMode, Coord3DMode +from mmdet3d.core.visualizer.open3d_vis import Visualizer +from mmdet3d.datasets import build_dataset, get_loading_pipeline + + +def parse_args(): + parser = argparse.ArgumentParser(description='Browse a dataset') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--skip-type', + type=str, + nargs='+', + default=['Normalize'], + help='skip some useless pipeline') + parser.add_argument( + '--output-dir', + default=None, + type=str, + help='If there is no display interface, you can save it') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + return args + + +def retrieve_data_cfg(config_path, skip_type, cfg_options): + cfg = Config.fromfile(config_path) + if cfg_options is not None: + cfg.merge_from_dict(cfg_options) + # import modules from string list. + if cfg.get('custom_imports', None): + from mmcv.utils import import_modules_from_strings + import_modules_from_strings(**cfg['custom_imports']) + if cfg.data.train['type'] == 'RepeatDataset': + train_data_cfg = cfg.data.train.dataset + else: + train_data_cfg = cfg.data.train + train_data_cfg['pipeline'] = [ + x for x in train_data_cfg.pipeline if x['type'] not in skip_type + ] + + return cfg + + +def main(): + args = parse_args() + + if args.output_dir is not None: + mkdir_or_exist(args.output_dir) + + cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options) + if cfg.data.train['type'] == 'RepeatDataset': + cfg.data.train.dataset['pipeline'] = get_loading_pipeline( + cfg.train_pipeline) + else: + cfg.data.train['pipeline'] = get_loading_pipeline(cfg.train_pipeline) + dataset = build_dataset( + cfg.data.train, default_args=dict(filter_empty_gt=False)) + # For RepeatDataset type, the infos are stored in dataset.dataset + if cfg.data.train['type'] == 'RepeatDataset': + dataset = dataset.dataset + data_infos = dataset.data_infos + + for idx, data_info in enumerate(track_iter_progress(data_infos)): + if cfg.dataset_type in ['KittiDataset', 'WaymoDataset']: + pts_path = data_info['point_cloud']['velodyne_path'] + elif cfg.dataset_type in ['ScanNetDataset', 'SUNRGBDDataset']: + pts_path = data_info['pts_path'] + elif cfg.dataset_type in ['NuScenesDataset', 'LyftDataset']: + pts_path = data_info['lidar_path'] + else: + raise NotImplementedError( + f'unsupported dataset type {cfg.dataset_type}') + file_name = osp.splitext(osp.basename(pts_path))[0] + save_path = osp.join(args.output_dir, + f'{file_name}.png') if args.output_dir else None + + example = dataset.prepare_train_data(idx) + points = example['points']._data.numpy() + points = Coord3DMode.convert_point(points, Coord3DMode.LIDAR, + Coord3DMode.DEPTH) + gt_bboxes = dataset.get_ann_info(idx)['gt_bboxes_3d'].tensor + if gt_bboxes is not None: + gt_bboxes = Box3DMode.convert(gt_bboxes, Box3DMode.LIDAR, + Box3DMode.DEPTH) + + vis = Visualizer(points, save_path='./show.png') + vis.add_bboxes(bbox3d=gt_bboxes, bbox_color=(0, 0, 1)) + + vis.show(save_path) + del vis + + +if __name__ == '__main__': + main()