Skip to content

Commit

Permalink
fix regression test (#958)
Browse files Browse the repository at this point in the history
* fix reg

* set sdk wrapper device id

* resolve comment
  • Loading branch information
RunningLeon authored Sep 2, 2022
1 parent 5874f10 commit cbedf1c
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 8 deletions.
9 changes: 6 additions & 3 deletions mmdeploy/backend/sdk/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmdeploy_python as c_api

from mmdeploy.utils import Backend
from mmdeploy.utils import Backend, parse_device_id, parse_device_type
from mmdeploy.utils.timer import TimeCounter
from ..base import BACKEND_WRAPPER, BaseWrapper

Expand All @@ -12,8 +12,11 @@ class SDKWrapper(BaseWrapper):
def __init__(self, model_file, task_name, device):
super().__init__([])
creator = getattr(c_api, task_name)
# TODO: get device id somewhere
self.handle = creator(model_file, device, 0)
device_id = parse_device_id(device)
device_type = parse_device_type(device)
# sdk does not support -1 device id
device_id = 0 if device_id < 0 else device_id
self.handle = creator(model_file, device_type, device_id)

@TimeCounter.count_time(Backend.SDK.value)
def invoke(self, imgs):
Expand Down
4 changes: 2 additions & 2 deletions mmdeploy/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import importlib

from .constants import IR, SDK_TASK_MAP, Backend, Codebase, Task
from .device import parse_cuda_device_id, parse_device_id
from .device import parse_cuda_device_id, parse_device_id, parse_device_type
from .env import get_backend_version, get_codebase_version, get_library_version
from .utils import get_file_path, get_root_logger, target_wrapper

__all__ = [
'SDK_TASK_MAP', 'IR', 'Backend', 'Codebase', 'Task',
'parse_cuda_device_id', 'get_library_version', 'get_codebase_version',
'get_backend_version', 'parse_device_id', 'get_file_path',
'get_root_logger', 'target_wrapper'
'get_root_logger', 'target_wrapper', 'parse_device_type'
]

if importlib.util.find_spec('mmcv') is not None:
Expand Down
16 changes: 16 additions & 0 deletions mmdeploy/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,19 @@ def parse_cuda_device_id(device: str) -> int:
match_result.group(2)[1:])

return device_id


def parse_device_type(device: str) -> str:
"""Parse device type from a string.
Args:
device (str): The typical style of string specifying cuda device,
e.g.: 'cuda:0', 'cpu', 'npu'.
Returns:
str: The parsed device type such as 'cuda', 'cpu', 'npu'.
"""
device_type = device
if ':' in device:
device_type = device.split(':')[0]
return device_type
6 changes: 3 additions & 3 deletions tests/regression/mmcls.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ globals:
onnxruntime:
pipeline_ort_static_fp32: &pipeline_ort_static_fp32
convert_image: *convert_image
backend_test: *default_backend_test
sdk_config: *sdk_dynamic
backend_test: False
deploy_config: configs/mmcls/classification_onnxruntime_static.py

pipeline_ort_dynamic_fp32: &pipeline_ort_dynamic_fp32
convert_image: *convert_image
backend_test: False
backend_test: *default_backend_test
sdk_config: *sdk_dynamic
deploy_config: configs/mmcls/classification_onnxruntime_dynamic.py


Expand Down
2 changes: 2 additions & 0 deletions tests/test_utils/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,12 @@ class TestParseDeviceID:
def test_cpu(self):
device = 'cpu'
assert util.parse_device_id(device) == -1
assert util.parse_device_type(device) == 'cpu'

def test_cuda(self):
device = 'cuda'
assert util.parse_device_id(device) == 0
assert util.parse_device_type(device) == 'cuda'

def test_cuda10(self):
device = 'cuda:10'
Expand Down
6 changes: 6 additions & 0 deletions tools/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def get_model_metafile_info(global_info: dict, model_info: dict,

# get model metafile info
metafile_path = Path(codebase_dir).joinpath(model_info.get('metafile'))
if not metafile_path.exists():
logger.warning(f'Metafile not exists: {metafile_path}')
return [], '', ''
with open(metafile_path) as f:
metafile_info = yaml.load(f, Loader=yaml.FullLoader)

Expand Down Expand Up @@ -985,6 +988,9 @@ def get_backend_result(pipeline_info: dict, model_cfg_path: Path,
replace_top_in_pipeline_json(backend_output_path, logger)

log_path = gen_log_path(backend_output_path, 'sdk_test.log')
if backend_name == 'onnxruntime':
# sdk only support onnxruntime of cpu
device_type = 'cpu'
# sdk test
get_backend_fps_metric(
deploy_cfg_path=str(sdk_config),
Expand Down

0 comments on commit cbedf1c

Please sign in to comment.