diff --git a/docs/en/useful_tools.md b/docs/en/useful_tools.md index adea03d585..b3ce67a58e 100644 --- a/docs/en/useful_tools.md +++ b/docs/en/useful_tools.md @@ -96,6 +96,12 @@ python tools/misc/browse_dataset.py configs/_base_/datasets/kitti-3d-3class.py - **Notice**: Once specifying `--output-dir`, the images of views specified by users will be saved when pressing `_ESC_` in open3d window. If you don't have a monitor, you can remove the `--online` flag to only save the visualization results and browse them offline. +To verify the data consistency and the effect of data augmentation, you can also add the `--aug` flag to visualize the data after data augmentation using the command below: + +```shell +python tools/misc/browse_dataset.py configs/_base_/datasets/kitti-3d-3class.py --task det --aug --output-dir ${OUTPUT_DIR} --online +``` + If you also want to show 2D images with 3D bounding boxes projected onto them, you need to find a config that supports multi-modality data loading, and then change the `--task` args to `multi_modality-det`. 
An example is showed below ```shell diff --git a/docs/zh_cn/useful_tools.md b/docs/zh_cn/useful_tools.md index 1cc8edf27f..2552dfe582 100644 --- a/docs/zh_cn/useful_tools.md +++ b/docs/zh_cn/useful_tools.md @@ -97,6 +97,12 @@ python tools/misc/browse_dataset.py configs/_base_/datasets/kitti-3d-3class.py - **注意**:一旦指定 `--output-dir` ,当按下 open3d 窗口的 `_ESC_`,用户指定的视图图像将被保存。如果您没有显示器,您可以移除 `--online` 标志,从而仅仅保存可视化结果并且进行离线浏览。 +为了验证数据的一致性和数据增强的效果,您还可以使用以下命令添加 `--aug` 标志来可视化数据增强后的数据: + +```shell +python tools/misc/browse_dataset.py configs/_base_/datasets/kitti-3d-3class.py --task det --aug --output-dir ${OUTPUT_DIR} --online +``` + 如果您还想显示 2D 图像以及投影的 3D 边界框,则需要找到支持多模态数据加载的配置文件,然后将 `--task` 参数更改为 `multi_modality-det`。一个例子如下所示 ```shell diff --git a/tools/misc/browse_dataset.py b/tools/misc/browse_dataset.py index a1782c806f..e5da407178 100644 --- a/tools/misc/browse_dataset.py +++ b/tools/misc/browse_dataset.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import mmcv import numpy as np import warnings -from mmcv import Config, DictAction, mkdir_or_exist, track_iter_progress +from mmcv import Config, DictAction, mkdir_or_exist from os import path as osp +from pathlib import Path from mmdet3d.core.bbox import (Box3DMode, CameraInstance3DBoxes, Coord3DMode, DepthInstance3DBoxes, LiDARInstance3DBoxes) @@ -31,6 +33,10 @@ def parse_args(): type=str, choices=['det', 'seg', 'multi_modality-det', 'mono-det'], help='Determine the visualization method depending on the task.') + parser.add_argument( + '--aug', + action='store_true', + help='Whether to visualize augmented datasets or original dataset.') parser.add_argument( '--online', action='store_true', @@ -50,8 +56,9 @@ def parse_args(): return args -def build_data_cfg(config_path, skip_type, cfg_options): +def build_data_cfg(config_path, skip_type, aug, cfg_options): """Build data config for loading visualization data.""" + cfg = Config.fromfile(config_path) if cfg_options is not None: 
cfg.merge_from_dict(cfg_options) @@ -63,10 +70,17 @@ def build_data_cfg(config_path, skip_type, cfg_options): if cfg.data.train['type'] == 'ConcatDataset': cfg.data.train = cfg.data.train.datasets[0] train_data_cfg = cfg.data.train - # eval_pipeline purely consists of loading functions - # use eval_pipeline for data loading + + if aug: + show_pipeline = cfg.train_pipeline + else: + show_pipeline = cfg.eval_pipeline + for i in range(len(cfg.train_pipeline)): + if cfg.train_pipeline[i]['type'] == 'LoadAnnotations3D': + show_pipeline.insert(i, cfg.train_pipeline[i]) + train_data_cfg['pipeline'] = [ - x for x in cfg.eval_pipeline if x['type'] not in skip_type + x for x in show_pipeline if x['type'] not in skip_type ] return cfg @@ -83,13 +97,14 @@ def to_depth_mode(points, bboxes): return points, bboxes -def show_det_data(idx, dataset, out_dir, filename, show=False): +def show_det_data(input, out_dir, show=False): """Visualize 3D point cloud and 3D bboxes.""" - example = dataset.prepare_train_data(idx) - points = example['points']._data.numpy() - gt_bboxes = dataset.get_ann_info(idx)['gt_bboxes_3d'].tensor - if dataset.box_mode_3d != Box3DMode.DEPTH: + img_metas = input['img_metas']._data + points = input['points']._data.numpy() + gt_bboxes = input['gt_bboxes_3d']._data.tensor + if img_metas['box_mode_3d'] != Box3DMode.DEPTH: points, gt_bboxes = to_depth_mode(points, gt_bboxes) + filename = osp.splitext(osp.basename(img_metas['pts_filename']))[0] show_result( points, gt_bboxes.clone(), @@ -100,42 +115,35 @@ def show_det_data(idx, dataset, out_dir, filename, show=False): snapshot=True) -def show_seg_data(idx, dataset, out_dir, filename, show=False): +def show_seg_data(input, out_dir, show=False): """Visualize 3D point cloud and segmentation mask.""" - example = dataset.prepare_train_data(idx) - points = example['points']._data.numpy() - gt_seg = example['pts_semantic_mask']._data.numpy() + img_metas = input['img_metas']._data + points = input['points']._data.numpy() + 
gt_seg = input['pts_semantic_mask']._data.numpy() + filename = osp.splitext(osp.basename(img_metas['pts_filename']))[0] show_seg_result( points, gt_seg.copy(), None, out_dir, filename, - np.array(dataset.PALETTE), - dataset.ignore_index, + np.array(img_metas['PALETTE']), + img_metas['ignore_index'], show=show, snapshot=True) -def show_proj_bbox_img(idx, - dataset, - out_dir, - filename, - show=False, - is_nus_mono=False): +def show_proj_bbox_img(input, out_dir, show=False, is_nus_mono=False): """Visualize 3D bboxes on 2D image by projection.""" - try: - example = dataset.prepare_train_data(idx) - except AttributeError: # for Mono-3D datasets - example = dataset.prepare_train_img(idx) - gt_bboxes = dataset.get_ann_info(idx)['gt_bboxes_3d'] - img_metas = example['img_metas']._data - img = example['img']._data.numpy() + gt_bboxes = input['gt_bboxes_3d']._data + img_metas = input['img_metas']._data + img = input['img']._data.numpy() # need to transpose channel to first dim img = img.transpose(1, 2, 0) # no 3D gt bboxes, just show img if gt_bboxes.tensor.shape[0] == 0: gt_bboxes = None + filename = Path(img_metas['filename']).name if isinstance(gt_bboxes, DepthInstance3DBoxes): show_multi_modality_result( img, @@ -183,53 +191,34 @@ def main(): if args.output_dir is not None: mkdir_or_exist(args.output_dir) - cfg = build_data_cfg(args.config, args.skip_type, args.cfg_options) + cfg = build_data_cfg(args.config, args.skip_type, args.aug, + args.cfg_options) try: dataset = build_dataset( cfg.data.train, default_args=dict(filter_empty_gt=False)) except TypeError: # seg dataset doesn't have `filter_empty_gt` key dataset = build_dataset(cfg.data.train) - data_infos = dataset.data_infos - dataset_type = cfg.dataset_type + dataset_type = cfg.dataset_type # configure visualization mode vis_task = args.task # 'det', 'seg', 'multi_modality-det', 'mono-det' + progress_bar = mmcv.ProgressBar(len(dataset)) - for idx, data_info in enumerate(track_iter_progress(data_infos)): - if 
dataset_type in ['KittiDataset', 'WaymoDataset']: - data_path = data_info['point_cloud']['velodyne_path'] - elif dataset_type in [ - 'ScanNetDataset', 'SUNRGBDDataset', 'ScanNetSegDataset', - 'S3DISSegDataset', 'S3DISDataset' - ]: - data_path = data_info['pts_path'] - elif dataset_type in ['NuScenesDataset', 'LyftDataset']: - data_path = data_info['lidar_path'] - elif dataset_type in ['NuScenesMonoDataset']: - data_path = data_info['file_name'] - else: - raise NotImplementedError( - f'unsupported dataset type {dataset_type}') - - file_name = osp.splitext(osp.basename(data_path))[0] - + for input in dataset: if vis_task in ['det', 'multi_modality-det']: # show 3D bboxes on 3D point clouds - show_det_data( - idx, dataset, args.output_dir, file_name, show=args.online) + show_det_data(input, args.output_dir, show=args.online) if vis_task in ['multi_modality-det', 'mono-det']: # project 3D bboxes to 2D image show_proj_bbox_img( - idx, - dataset, + input, args.output_dir, - file_name, show=args.online, is_nus_mono=(dataset_type == 'NuScenesMonoDataset')) elif vis_task in ['seg']: # show 3D segmentation mask on 3D point clouds - show_seg_data( - idx, dataset, args.output_dir, file_name, show=args.online) + show_seg_data(input, args.output_dir, show=args.online) + progress_bar.update() if __name__ == '__main__':