diff --git a/mmdeploy/backend/sdk/wrapper.py b/mmdeploy/backend/sdk/wrapper.py index 7fa7092215..9273ba647c 100644 --- a/mmdeploy/backend/sdk/wrapper.py +++ b/mmdeploy/backend/sdk/wrapper.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import mmdeploy_python as c_api -from mmdeploy.utils import Backend +from mmdeploy.utils import Backend, parse_device_id, parse_device_type from mmdeploy.utils.timer import TimeCounter from ..base import BACKEND_WRAPPER, BaseWrapper @@ -12,8 +12,11 @@ class SDKWrapper(BaseWrapper): def __init__(self, model_file, task_name, device): super().__init__([]) creator = getattr(c_api, task_name) - # TODO: get device id somewhere - self.handle = creator(model_file, device, 0) + device_id = parse_device_id(device) + device_type = parse_device_type(device) + # sdk does not support -1 device id + device_id = 0 if device_id < 0 else device_id + self.handle = creator(model_file, device_type, device_id) @TimeCounter.count_time(Backend.SDK.value) def invoke(self, imgs): diff --git a/mmdeploy/utils/__init__.py b/mmdeploy/utils/__init__.py index dda45c09d2..5ec291b136 100644 --- a/mmdeploy/utils/__init__.py +++ b/mmdeploy/utils/__init__.py @@ -2,7 +2,7 @@ import importlib from .constants import IR, SDK_TASK_MAP, Backend, Codebase, Task -from .device import parse_cuda_device_id, parse_device_id +from .device import parse_cuda_device_id, parse_device_id, parse_device_type from .env import get_backend_version, get_codebase_version, get_library_version from .utils import get_file_path, get_root_logger, target_wrapper @@ -10,7 +10,7 @@ 'SDK_TASK_MAP', 'IR', 'Backend', 'Codebase', 'Task', 'parse_cuda_device_id', 'get_library_version', 'get_codebase_version', 'get_backend_version', 'parse_device_id', 'get_file_path', - 'get_root_logger', 'target_wrapper' + 'get_root_logger', 'target_wrapper', 'parse_device_type' ] if importlib.util.find_spec('mmcv') is not None: diff --git a/mmdeploy/utils/device.py b/mmdeploy/utils/device.py index 1b980e449c..f3346e31f0 100644 --- a/mmdeploy/utils/device.py +++ b/mmdeploy/utils/device.py @@ -41,3 +41,19 @@ def parse_cuda_device_id(device: str) -> int: match_result.group(2)[1:]) return device_id + + +def parse_device_type(device: str) -> str: + """Parse device type from a string. + + Args: + device (str): The typical style of string specifying cuda device, + e.g.: 'cuda:0', 'cpu', 'npu'. + + Returns: + str: The parsed device type such as 'cuda', 'cpu', 'npu'. + """ + device_type = device + if ':' in device: + device_type = device.split(':')[0] + return device_type diff --git a/tests/regression/mmcls.yml b/tests/regression/mmcls.yml index a6f817f484..13ac34cd19 100644 --- a/tests/regression/mmcls.yml +++ b/tests/regression/mmcls.yml @@ -33,13 +33,13 @@ globals: onnxruntime: pipeline_ort_static_fp32: &pipeline_ort_static_fp32 convert_image: *convert_image - backend_test: *default_backend_test - sdk_config: *sdk_dynamic + backend_test: False deploy_config: configs/mmcls/classification_onnxruntime_static.py pipeline_ort_dynamic_fp32: &pipeline_ort_dynamic_fp32 convert_image: *convert_image - backend_test: False + backend_test: *default_backend_test + sdk_config: *sdk_dynamic deploy_config: configs/mmcls/classification_onnxruntime_dynamic.py diff --git a/tests/test_utils/test_util.py b/tests/test_utils/test_util.py index be273469be..b572b8566b 100644 --- a/tests/test_utils/test_util.py +++ b/tests/test_utils/test_util.py @@ -380,10 +380,12 @@ class TestParseDeviceID: def test_cpu(self): device = 'cpu' assert util.parse_device_id(device) == -1 + assert util.parse_device_type(device) == 'cpu' def test_cuda(self): device = 'cuda' assert util.parse_device_id(device) == 0 + assert util.parse_device_type(device) == 'cuda' def test_cuda10(self): device = 'cuda:10' diff --git a/tools/regression_test.py b/tools/regression_test.py index b13edbd070..f9a836746e 100644 --- a/tools/regression_test.py +++ b/tools/regression_test.py @@ -154,6 +154,9 @@ def get_model_metafile_info(global_info: dict, model_info: dict, # get model metafile info metafile_path = Path(codebase_dir).joinpath(model_info.get('metafile')) + if not metafile_path.exists(): + logger.warning(f'Metafile not exists: {metafile_path}') + return [], '', '' with open(metafile_path) as f: metafile_info = yaml.load(f, Loader=yaml.FullLoader) @@ -985,6 +988,9 @@ def get_backend_result(pipeline_info: dict, model_cfg_path: Path, replace_top_in_pipeline_json(backend_output_path, logger) log_path = gen_log_path(backend_output_path, 'sdk_test.log') + if backend_name == 'onnxruntime': + # sdk only support onnxruntime of cpu + device_type = 'cpu' # sdk test get_backend_fps_metric( deploy_cfg_path=str(sdk_config),