Skip to content

Commit

Permalink
set test_mode for mmdet (#920)
Browse files Browse the repository at this point in the history
* fix

* update
  • Loading branch information
RunningLeon authored Aug 19, 2022
1 parent a6e07da commit 1f8d889
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
3 changes: 2 additions & 1 deletion mmdeploy/codebase/base/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def init_pytorch_model(self,
def build_dataset(self,
dataset_cfg: Union[str, mmcv.Config],
dataset_type: str = 'val',
is_sort_dataset: bool = True,
is_sort_dataset: bool = False,
**kwargs) -> Dataset:
"""Build dataset for different codebase.
Expand All @@ -80,6 +80,7 @@ def build_dataset(self,
is_sort_dataset (bool): When 'True', the dataset will be sorted
by image shape in ascending order if 'dataset_cfg'
contains information about height and width.
Default is `False`.
Returns:
Dataset: The built dataset.
Expand Down
16 changes: 13 additions & 3 deletions mmdeploy/codebase/mmdet/deploy/mmdetection.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,19 @@ def build_dataset(dataset_cfg: Union[str, mmcv.Config],

data_cfg = dataset_cfg.data[dataset_type]
samples_per_gpu = dataset_cfg.data.get('samples_per_gpu', 1)
if samples_per_gpu > 1:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
data_cfg.pipeline = replace_ImageToTensor(data_cfg.pipeline)
if isinstance(data_cfg, dict):
data_cfg.test_mode = True
if samples_per_gpu > 1:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
data_cfg.pipeline = replace_ImageToTensor(data_cfg.pipeline)
elif isinstance(data_cfg, list):
for ds_cfg in data_cfg:
ds_cfg.test_mode = True
if samples_per_gpu > 1:
ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
else:
raise TypeError(f'Unsupported type {type(data_cfg)}')

dataset = build_dataset_mmdet(data_cfg)

return dataset
Expand Down

0 comments on commit 1f8d889

Please sign in to comment.