Skip to content

Commit

Permalink
[Fix] Support MMDet custom dataset (open-mmlab#33)
Browse files Browse the repository at this point in the history
* Fix mmdet classes

* Fix classes in data_cfg
  • Loading branch information
SingleZombie authored Jan 10, 2022
1 parent 37a1b83 commit 3e5c785
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 39 deletions.
40 changes: 29 additions & 11 deletions mmdeploy/codebase/mmdet/deploy/object_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,31 +582,49 @@ class labels of shape [N, num_det].
return [dets, labels]


def get_classes_from_config(model_cfg: Union[str, mmcv.Config], **kwargs):
"""Get class name from config.
def get_classes_from_config(model_cfg: Union[str, mmcv.Config], **kwargs) -> \
List[str]:
"""Get class name from config. The class name is the `classes` field if it
is set in the config, or the classes in `module_dict` of MMDet whose type
is set in the config.
Args:
model_cfg (str | mmcv.Config): Input model config file or
Config object.
Returns:
list[str]: A list of string specifying names of different class.
List[str]: A list of string specifying names of different class.
"""
# load cfg if necessary
model_cfg = load_config(model_cfg)[0]

# For custom dataset
if 'classes' in model_cfg:
return list(model_cfg['classes'])

module_dict = DATASETS.module_dict
data_cfg = model_cfg.data
classes = None
module = None

if 'test' in data_cfg:
module = module_dict[data_cfg.test.type]
elif 'val' in data_cfg:
module = module_dict[data_cfg.val.type]
elif 'train' in data_cfg:
module = module_dict[data_cfg.train.type]
else:
keys = ['test', 'val', 'train']

for key in keys:
if key in data_cfg:
if 'classes' in data_cfg[key]:
classes = list(data_cfg[key]['classes'])
break
elif 'type' in data_cfg[key]:
module = module_dict[data_cfg[key]['type']]
break

if classes is None and module is None:
raise RuntimeError(f'No dataset config found in: {model_cfg}')

return module.CLASSES
if classes is not None:
return classes
else:
return module.CLASSES


def build_object_detection_model(model_files: Sequence[str],
Expand Down
80 changes: 52 additions & 28 deletions tests/test_codebase/test_mmdet/test_object_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,34 +385,58 @@ def partition1_postprocess(self, *args, **kwargs):
assert_forward_results(results, 'PartitionTwoStageModel')


data_cfg1 = mmcv.Config(
dict(
data=dict(
test=dict(type='CocoDataset'),
val=dict(type='CityscapesDataset'),
train=dict(type='CityscapesDataset'))))
data_cfg2 = mmcv.Config(
dict(
data=dict(
val=dict(type='CocoDataset'), train=dict(
type='CityscapesDataset'))))
data_cfg3 = mmcv.Config(dict(data=dict(train=dict(type='CocoDataset'))))
data_cfg4 = mmcv.Config(dict(data=dict(error=dict(type='CocoDataset'))))


@pytest.mark.parametrize('cfg', [data_cfg1, data_cfg2, data_cfg3, data_cfg4])
def test_get_classes_from_cfg(cfg):
from mmdet.datasets import DATASETS

from mmdeploy.codebase.mmdet.deploy.object_detection_model import \
get_classes_from_config

if 'error' in cfg.data:
with pytest.raises(RuntimeError):
get_classes_from_config(cfg)
else:
assert get_classes_from_config(
cfg) == DATASETS.module_dict['CocoDataset'].CLASSES
class TestGetClassesFromCfg:
data_cfg1 = mmcv.Config(
dict(
data=dict(
test=dict(type='CocoDataset'),
val=dict(type='CityscapesDataset'),
train=dict(type='CityscapesDataset'))))

data_cfg2 = mmcv.Config(
dict(
data=dict(
val=dict(type='CocoDataset'),
train=dict(type='CityscapesDataset'))))
data_cfg3 = mmcv.Config(dict(data=dict(train=dict(type='CocoDataset'))))
data_cfg4 = mmcv.Config(dict(data=dict(error=dict(type='CocoDataset'))))

data_cfg_classes_1 = mmcv.Config(
dict(
data=dict(
test=dict(classes=('a')),
val=dict(classes=('b')),
train=dict(classes=('b')))))

data_cfg_classes_2 = mmcv.Config(
dict(data=dict(val=dict(classes=('a')), train=dict(classes=('b')))))
data_cfg_classes_3 = mmcv.Config(
dict(data=dict(train=dict(classes=('a')))))
data_cfg_classes_4 = mmcv.Config(dict(classes=('a')))

@pytest.mark.parametrize('cfg',
[data_cfg1, data_cfg2, data_cfg3, data_cfg4])
def test_get_classes_from_cfg(self, cfg):
from mmdet.datasets import DATASETS
from mmdeploy.codebase.mmdet.deploy.object_detection_model import \
get_classes_from_config

if 'error' in cfg.data:
with pytest.raises(RuntimeError):
get_classes_from_config(cfg)
else:
assert get_classes_from_config(
cfg) == DATASETS.module_dict['CocoDataset'].CLASSES

@pytest.mark.parametrize('cfg', [
data_cfg_classes_1, data_cfg_classes_2, data_cfg_classes_3,
data_cfg_classes_4
])
def test_get_classes_from_custom_cfg(self, cfg):
from mmdeploy.codebase.mmdet.deploy.object_detection_model import \
get_classes_from_config

assert get_classes_from_config(cfg) == ['a']


@backend_checker(Backend.ONNXRUNTIME)
Expand Down

0 comments on commit 3e5c785

Please sign in to comment.