Skip to content

Commit

Permalink
[Enhancement]: Optimize config utils (open-mmlab#263)
Browse files Browse the repository at this point in the history
* Optimize config utils

* Update `get_backend`

* Add assert
  • Loading branch information
SingleZombie authored Dec 8, 2021
1 parent f424fca commit 03c95a1
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 82 deletions.
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def fcn_mask_head__get_seg_masks(ctx, self, mask_pred, det_bboxes, det_labels,
Returns:
Tensor: a mask of shape (N, img_h, img_w).
"""
backend = get_backend(ctx.cfg, 'default')
backend = get_backend(ctx.cfg)
mask_pred = mask_pred.sigmoid()
bboxes = det_bboxes[:, :4]
labels = det_labels
Expand Down
6 changes: 3 additions & 3 deletions mmdeploy/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .config_utils import (cfg_apply_marks, get_backend, get_calib_config,
get_calib_filename, get_codebase,
from .config_utils import (cfg_apply_marks, get_backend, get_backend_config,
get_calib_config, get_calib_filename, get_codebase,
get_codebase_config, get_common_config,
get_input_shape, get_model_inputs, get_onnx_config,
get_partition_config, get_task_type,
Expand All @@ -14,5 +14,5 @@
'get_onnx_config', 'get_partition_config', 'get_calib_config',
'get_calib_filename', 'get_common_config', 'get_model_inputs',
'cfg_apply_marks', 'get_input_shape', 'parse_device_id',
'parse_cuda_device_id', 'get_codebase_config'
'parse_cuda_device_id', 'get_codebase_config', 'get_backend_config'
]
88 changes: 43 additions & 45 deletions mmdeploy/utils/config_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
from typing import Dict, List, Union

import mmcv

Expand Down Expand Up @@ -30,85 +30,82 @@ def _load_config(cfg):
return configs


def get_task_type(deploy_cfg: Union[str, mmcv.Config],
default: str = None) -> Task:
"""Get the task type of the algorithm.
def get_codebase_config(deploy_cfg: Union[str, mmcv.Config]) -> Dict:
"""Get the codebase_config from the config.
Args:
deploy_cfg (str | mmcv.Config): The path or content of config.
default (str): If the "task" field of config is empty, then return
default task type.
Returns:
Task : An enumeration denotes the task type.
Dict : codebase config dict.
"""

codebase_config = get_codebase_config(deploy_cfg)
try:
task = codebase_config['task']
except KeyError:
return default

task = Task.get(task, default)
return task
deploy_cfg = load_config(deploy_cfg)[0]
codebase_config = deploy_cfg.get('codebase_config', {})
return codebase_config


def get_codebase_config(deploy_cfg: Union[str, mmcv.Config]) -> Dict:
"""Get the codebase_config from the config.
def get_task_type(deploy_cfg: Union[str, mmcv.Config]) -> Task:
"""Get the task type of the algorithm.
Args:
deploy_cfg (str | mmcv.Config): The path or content of config.
Returns:
Dict : codebase config dict.
Task : An enumeration denotes the task type.
"""
deploy_cfg = load_config(deploy_cfg)[0]
codebase_config = deploy_cfg.get('codebase_config', {})
return codebase_config

codebase_config = get_codebase_config(deploy_cfg)
assert 'task' in codebase_config, 'The codebase config of deploy config'\
'requires a "task" field'
task = codebase_config['task']
return Task.get(task)


def get_codebase(deploy_cfg: Union[str, mmcv.Config],
default: Optional[str] = None) -> Codebase:
def get_codebase(deploy_cfg: Union[str, mmcv.Config]) -> Codebase:
"""Get the codebase from the config.
Args:
deploy_cfg (str | mmcv.Config): The path or content of config.
default (str): If the "codebase" field of config is empty, then return
default codebase type.
Returns:
Codebase : An enumeration denotes the codebase type.
"""

codebase_config = get_codebase_config(deploy_cfg)
try:
codebase = codebase_config['type']
except KeyError:
return default
assert 'type' in codebase_config, 'The codebase config of deploy config'\
'requires a "type" field'
codebase = codebase_config['type']
return Codebase.get(codebase)


codebase = Codebase.get(codebase, default)
return codebase
def get_backend_config(deploy_cfg: Union[str, mmcv.Config]) -> Dict:
"""Get the backend_config from the config.
Args:
deploy_cfg (str | mmcv.Config): The path or content of config.
def get_backend(deploy_cfg: Union[str, mmcv.Config], default=None) -> Backend:
Returns:
Dict : backend config dict.
"""
deploy_cfg = load_config(deploy_cfg)[0]
backend_config = deploy_cfg.get('backend_config', {})
return backend_config


def get_backend(deploy_cfg: Union[str, mmcv.Config]) -> Backend:
"""Get the backend from the config.
Args:
deploy_cfg (str | mmcv.Config): The path or content of config.
default (str): If the "backend" field of config is empty, then return
default backend type.
Returns:
Backend: An enumeration denotes the backend type.
"""

deploy_cfg = load_config(deploy_cfg)[0]
try:
backend = deploy_cfg['backend_config']['type']
except KeyError:
return default
backend = Backend.get(backend, default)
return backend
backend_config = get_backend_config(deploy_cfg)
assert 'type' in backend_config, 'The backend config of deploy config'\
'requires a "type" field'
backend = backend_config['type']
return Backend.get(backend)


def get_onnx_config(deploy_cfg: Union[str, mmcv.Config]) -> Dict:
Expand All @@ -118,11 +115,12 @@ def get_onnx_config(deploy_cfg: Union[str, mmcv.Config]) -> Dict:
deploy_cfg (str | mmcv.Config): The path or content of config.
Returns:
dict: The config dictionary of onnx parameters
Dict: The config dictionary of onnx parameters
"""

deploy_cfg = load_config(deploy_cfg)[0]
return deploy_cfg['onnx_config']
onnx_config = deploy_cfg.get('onnx_config', {})
return onnx_config


def is_dynamic_batch(deploy_cfg: Union[str, mmcv.Config],
Expand Down
7 changes: 4 additions & 3 deletions mmdeploy/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ class AdvancedEnum(Enum):
"""Define an enumeration class."""

@classmethod
def get(cls, str, a):
def get(cls, value):
"""Get the key through a value."""
for k in cls:
if k.value == str:
if k.value == value:
return k
return a

raise KeyError(f'Cannot get key by value "{value}" of {cls}')


class Task(AdvancedEnum):
Expand Down
56 changes: 27 additions & 29 deletions tests/test_utils/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,6 @@ def test_load_config(self, args):
assert v[0]._cfg_dict == cfg._cfg_dict


class TestGetTaskType:

def test_get_task_type_none(self):
assert util.get_task_type(mmcv.Config(dict())) is None

def test_get_task_type_default(self):
assert util.get_task_type(mmcv.Config(dict()),
Task.SUPER_RESOLUTION) == \
Task.SUPER_RESOLUTION

def test_get_task_type(self):
assert util.get_task_type(correct_deploy_path) == Task.SUPER_RESOLUTION


class TestGetCodebaseConfig:

def test_get_codebase_config_empty(self):
Expand All @@ -81,37 +67,50 @@ def test_get_codebase_config(self):
assert isinstance(codebase_config, dict) and len(codebase_config) > 1


class TestGetTaskType:

def test_get_task_type_none(self):
with pytest.raises(AssertionError):
util.get_task_type(mmcv.Config(dict()))

def test_get_task_type(self):
assert util.get_task_type(correct_deploy_path) == Task.SUPER_RESOLUTION


class TestGetCodebase:

def test_get_codebase_none(self):
assert util.get_codebase(mmcv.Config(dict())) is None

def test_get_codebase_default(self):
assert util.get_codebase(mmcv.Config(dict()),
Codebase.MMEDIT) == Codebase.MMEDIT
with pytest.raises(AssertionError):
util.get_codebase(mmcv.Config(dict()))

def test_get_codebase(self):
assert util.get_codebase(correct_deploy_path) == Codebase.MMEDIT


class TestGetBackendConfig:

def test_get_backend_config_empty(self):
assert util.get_backend_config(mmcv.Config(dict())) == {}

def test_get_backend_config(self):
backend_config = util.get_backend_config(correct_deploy_path)
assert isinstance(backend_config, dict) and len(backend_config) == 1


class TestGetBackend:

def test_get_backend_none(self):
assert util.get_backend(mmcv.Config(dict())) is None

def test_get_backend_default(self):
assert util.get_backend(empty_file_path,
Backend.ONNXRUNTIME) == Backend.ONNXRUNTIME
with pytest.raises(AssertionError):
util.get_backend(mmcv.Config(dict()))

def test_get_backend(self):
assert util.get_backend(correct_deploy_path) == Backend.ONNXRUNTIME


class TestGetOnnxConfig:

def test_get_onnx_config_error(self):
with pytest.raises(Exception):
util.get_onnx_config(empty_file_path)
def test_get_onnx_config_empty(self):
assert util.get_onnx_config(mmcv.Config(dict())) == {}

def test_get_onnx_config(self):
onnx_config = dict(
Expand Down Expand Up @@ -286,9 +285,8 @@ def test_AdvancedEnum():
'Classification', 'ObjectDetection'
]
for k, v in zip(keys, vals):
assert Task.get(v, None) == k
assert Task.get(v) == k
assert k.value == v
assert Task.get('a', Task.TEXT_DETECTION) == Task.TEXT_DETECTION


def test_export_info():
Expand Down
2 changes: 1 addition & 1 deletion tools/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def main():

backend_files = onnx_files
# convert backend
backend = get_backend(deploy_cfg, 'default')
backend = get_backend(deploy_cfg)
if backend == Backend.TENSORRT:
model_params = get_model_inputs(deploy_cfg)
assert len(model_params) == len(onnx_files)
Expand Down

0 comments on commit 03c95a1

Please sign in to comment.