diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py index 21cb107fd7..0adb78c497 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py @@ -26,7 +26,7 @@ def fcn_mask_head__get_seg_masks(ctx, self, mask_pred, det_bboxes, det_labels, Returns: Tensor: a mask of shape (N, img_h, img_w). """ - backend = get_backend(ctx.cfg, 'default') + backend = get_backend(ctx.cfg) mask_pred = mask_pred.sigmoid() bboxes = det_bboxes[:, :4] labels = det_labels diff --git a/mmdeploy/utils/__init__.py b/mmdeploy/utils/__init__.py index c2f170f053..4d977dfd5c 100644 --- a/mmdeploy/utils/__init__.py +++ b/mmdeploy/utils/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .config_utils import (cfg_apply_marks, get_backend, get_calib_config, - get_calib_filename, get_codebase, +from .config_utils import (cfg_apply_marks, get_backend, get_backend_config, + get_calib_config, get_calib_filename, get_codebase, get_codebase_config, get_common_config, get_input_shape, get_model_inputs, get_onnx_config, get_partition_config, get_task_type, @@ -14,5 +14,5 @@ 'get_onnx_config', 'get_partition_config', 'get_calib_config', 'get_calib_filename', 'get_common_config', 'get_model_inputs', 'cfg_apply_marks', 'get_input_shape', 'parse_device_id', - 'parse_cuda_device_id', 'get_codebase_config' + 'parse_cuda_device_id', 'get_codebase_config', 'get_backend_config' ] diff --git a/mmdeploy/utils/config_utils.py b/mmdeploy/utils/config_utils.py index 1ddd4f124a..1d9296c23e 100644 --- a/mmdeploy/utils/config_utils.py +++ b/mmdeploy/utils/config_utils.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional, Union +from typing import Dict, List, Union import mmcv @@ -30,85 +30,82 @@ def _load_config(cfg): return configs -def get_task_type(deploy_cfg: Union[str, mmcv.Config], - default: str = None) -> Task: - """Get the task type of the algorithm. +def get_codebase_config(deploy_cfg: Union[str, mmcv.Config]) -> Dict: + """Get the codebase_config from the config. Args: deploy_cfg (str | mmcv.Config): The path or content of config. - default (str): If the "task" field of config is empty, then return - default task type. Returns: - Task : An enumeration denotes the task type. + Dict : codebase config dict. """ - - codebase_config = get_codebase_config(deploy_cfg) - try: - task = codebase_config['task'] - except KeyError: - return default - - task = Task.get(task, default) - return task + deploy_cfg = load_config(deploy_cfg)[0] + codebase_config = deploy_cfg.get('codebase_config', {}) + return codebase_config -def get_codebase_config(deploy_cfg: Union[str, mmcv.Config]) -> Dict: - """Get the codebase_config from the config. +def get_task_type(deploy_cfg: Union[str, mmcv.Config]) -> Task: + """Get the task type of the algorithm. Args: deploy_cfg (str | mmcv.Config): The path or content of config. Returns: - Dict : codebase config dict. + Task : An enumeration denotes the task type. """ - deploy_cfg = load_config(deploy_cfg)[0] - codebase_config = deploy_cfg.get('codebase_config', {}) - return codebase_config + + codebase_config = get_codebase_config(deploy_cfg) + assert 'task' in codebase_config, 'The codebase config of deploy config'\ + 'requires a "task" field' + task = codebase_config['task'] + return Task.get(task) -def get_codebase(deploy_cfg: Union[str, mmcv.Config], - default: Optional[str] = None) -> Codebase: +def get_codebase(deploy_cfg: Union[str, mmcv.Config]) -> Codebase: """Get the codebase from the config. Args: deploy_cfg (str | mmcv.Config): The path or content of config. - default (str): If the "codebase" field of config is empty, then return - default codebase type. Returns: Codebase : An enumeration denotes the codebase type. """ codebase_config = get_codebase_config(deploy_cfg) - try: - codebase = codebase_config['type'] - except KeyError: - return default + assert 'type' in codebase_config, 'The codebase config of deploy config'\ + 'requires a "type" field' + codebase = codebase_config['type'] + return Codebase.get(codebase) + - codebase = Codebase.get(codebase, default) - return codebase +def get_backend_config(deploy_cfg: Union[str, mmcv.Config]) -> Dict: + """Get the backend_config from the config. + Args: + deploy_cfg (str | mmcv.Config): The path or content of config. -def get_backend(deploy_cfg: Union[str, mmcv.Config], default=None) -> Backend: + Returns: + Dict : backend config dict. + """ + deploy_cfg = load_config(deploy_cfg)[0] + backend_config = deploy_cfg.get('backend_config', {}) + return backend_config + + +def get_backend(deploy_cfg: Union[str, mmcv.Config]) -> Backend: """Get the backend from the config. Args: deploy_cfg (str | mmcv.Config): The path or content of config. - default (str): If the "backend" field of config is empty, then return - default backend type. Returns: Backend: An enumeration denotes the backend type. """ - - deploy_cfg = load_config(deploy_cfg)[0] - try: - backend = deploy_cfg['backend_config']['type'] - except KeyError: - return default - backend = Backend.get(backend, default) - return backend + backend_config = get_backend_config(deploy_cfg) + assert 'type' in backend_config, 'The backend config of deploy config'\ + 'requires a "type" field' + backend = backend_config['type'] + return Backend.get(backend) def get_onnx_config(deploy_cfg: Union[str, mmcv.Config]) -> Dict: @@ -118,11 +115,12 @@ def get_onnx_config(deploy_cfg: Union[str, mmcv.Config]) -> Dict: deploy_cfg (str | mmcv.Config): The path or content of config. Returns: - dict: The config dictionary of onnx parameters + Dict: The config dictionary of onnx parameters """ deploy_cfg = load_config(deploy_cfg)[0] - return deploy_cfg['onnx_config'] + onnx_config = deploy_cfg.get('onnx_config', {}) + return onnx_config def is_dynamic_batch(deploy_cfg: Union[str, mmcv.Config], diff --git a/mmdeploy/utils/constants.py b/mmdeploy/utils/constants.py index 2c759c3ac2..2e3d978cba 100644 --- a/mmdeploy/utils/constants.py +++ b/mmdeploy/utils/constants.py @@ -6,12 +6,13 @@ class AdvancedEnum(Enum): """Define an enumeration class.""" @classmethod - def get(cls, str, a): + def get(cls, value): """Get the key through a value.""" for k in cls: - if k.value == str: + if k.value == value: return k - return a + + raise KeyError(f'Cannot get key by value "{value}" of {cls}') class Task(AdvancedEnum): diff --git a/tests/test_utils/test_util.py b/tests/test_utils/test_util.py index 7210e21b32..a488f09e86 100644 --- a/tests/test_utils/test_util.py +++ b/tests/test_utils/test_util.py @@ -57,20 +57,6 @@ def test_load_config(self, args): assert v[0]._cfg_dict == cfg._cfg_dict -class TestGetTaskType: - - def test_get_task_type_none(self): - assert util.get_task_type(mmcv.Config(dict())) is None - - def test_get_task_type_default(self): - assert util.get_task_type(mmcv.Config(dict()), - Task.SUPER_RESOLUTION) == \ - Task.SUPER_RESOLUTION - - def test_get_task_type(self): - assert util.get_task_type(correct_deploy_path) == Task.SUPER_RESOLUTION - - class TestGetCodebaseConfig: def test_get_codebase_config_empty(self): @@ -81,27 +67,41 @@ def test_get_codebase_config(self): assert isinstance(codebase_config, dict) and len(codebase_config) > 1 +class TestGetTaskType: + + def test_get_task_type_none(self): + with pytest.raises(AssertionError): + util.get_task_type(mmcv.Config(dict())) + + def test_get_task_type(self): + assert util.get_task_type(correct_deploy_path) == Task.SUPER_RESOLUTION + + class TestGetCodebase: def test_get_codebase_none(self): - assert util.get_codebase(mmcv.Config(dict())) is None - - def test_get_codebase_default(self): - assert util.get_codebase(mmcv.Config(dict()), - Codebase.MMEDIT) == Codebase.MMEDIT + with pytest.raises(AssertionError): + util.get_codebase(mmcv.Config(dict())) def test_get_codebase(self): assert util.get_codebase(correct_deploy_path) == Codebase.MMEDIT +class TestGetBackendConfig: + + def test_get_backend_config_empty(self): + assert util.get_backend_config(mmcv.Config(dict())) == {} + + def test_get_backend_config(self): + backend_config = util.get_backend_config(correct_deploy_path) + assert isinstance(backend_config, dict) and len(backend_config) == 1 + + class TestGetBackend: def test_get_backend_none(self): - assert util.get_backend(mmcv.Config(dict())) is None - - def test_get_backend_default(self): - assert util.get_backend(empty_file_path, - Backend.ONNXRUNTIME) == Backend.ONNXRUNTIME + with pytest.raises(AssertionError): + util.get_backend(mmcv.Config(dict())) def test_get_backend(self): assert util.get_backend(correct_deploy_path) == Backend.ONNXRUNTIME @@ -109,9 +109,8 @@ def test_get_backend(self): class TestGetOnnxConfig: - def test_get_onnx_config_error(self): - with pytest.raises(Exception): - util.get_onnx_config(empty_file_path) + def test_get_onnx_config_empty(self): + assert util.get_onnx_config(mmcv.Config(dict())) == {} def test_get_onnx_config(self): onnx_config = dict( @@ -286,9 +285,8 @@ def test_AdvancedEnum(): 'Classification', 'ObjectDetection' ] for k, v in zip(keys, vals): - assert Task.get(v, None) == k + assert Task.get(v) == k assert k.value == v - assert Task.get('a', Task.TEXT_DETECTION) == Task.TEXT_DETECTION def test_export_info(): diff --git a/tools/deploy.py b/tools/deploy.py index 13f1c5bcd1..424f195207 100644 --- a/tools/deploy.py +++ b/tools/deploy.py @@ -172,7 +172,7 @@ def main(): backend_files = onnx_files # convert backend - backend = get_backend(deploy_cfg, 'default') + backend = get_backend(deploy_cfg) if backend == Backend.TENSORRT: model_params = get_model_inputs(deploy_cfg) assert len(model_params) == len(onnx_files)