From 1f8d889b36fb39cab92b0f3363d1201a38f74835 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Fri, 19 Aug 2022 10:55:41 +0800 Subject: [PATCH] set test_mode for mmdet (#920) * fix * update --- mmdeploy/codebase/base/task.py | 3 ++- mmdeploy/codebase/mmdet/deploy/mmdetection.py | 16 +++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/mmdeploy/codebase/base/task.py b/mmdeploy/codebase/base/task.py index 813c9d93fd..1e27c35f50 100644 --- a/mmdeploy/codebase/base/task.py +++ b/mmdeploy/codebase/base/task.py @@ -68,7 +68,7 @@ def init_pytorch_model(self, def build_dataset(self, dataset_cfg: Union[str, mmcv.Config], dataset_type: str = 'val', - is_sort_dataset: bool = True, + is_sort_dataset: bool = False, **kwargs) -> Dataset: """Build dataset for different codebase. @@ -80,6 +80,7 @@ def build_dataset(self, is_sort_dataset (bool): When 'True', the dataset will be sorted by image shape in ascending order if 'dataset_cfg' contains information about height and width. + Default is `False`. Returns: Dataset: The built dataset. diff --git a/mmdeploy/codebase/mmdet/deploy/mmdetection.py b/mmdeploy/codebase/mmdet/deploy/mmdetection.py index ae02c1c04c..1bcd9ff1c2 100644 --- a/mmdeploy/codebase/mmdet/deploy/mmdetection.py +++ b/mmdeploy/codebase/mmdet/deploy/mmdetection.py @@ -62,9 +62,19 @@ def build_dataset(dataset_cfg: Union[str, mmcv.Config], data_cfg = dataset_cfg.data[dataset_type] samples_per_gpu = dataset_cfg.data.get('samples_per_gpu', 1) - if samples_per_gpu > 1: - # Replace 'ImageToTensor' to 'DefaultFormatBundle' - data_cfg.pipeline = replace_ImageToTensor(data_cfg.pipeline) + if isinstance(data_cfg, dict): + data_cfg.test_mode = True + if samples_per_gpu > 1: + # Replace 'ImageToTensor' to 'DefaultFormatBundle' + data_cfg.pipeline = replace_ImageToTensor(data_cfg.pipeline) + elif isinstance(data_cfg, list): + for ds_cfg in data_cfg: + ds_cfg.test_mode = True + if samples_per_gpu > 1: + ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline) + else: + raise TypeError(f'Unsupported type {type(data_cfg)}') + dataset = build_dataset_mmdet(data_cfg) return dataset