From 997d111a6f4ca9624ab3b36717748e6ce002037d Mon Sep 17 00:00:00 2001 From: Semyon Bevzyuk Date: Thu, 13 Jan 2022 10:37:23 +0300 Subject: [PATCH] [Fix] Fixed device_id in tools/test.py for the CPU. (#58) * [Fix] fix bugs for mmcls performance test (#269) * fix bugs for mmcls performance test * fix yapf * add comments of CLASSES attribute * Rewrote the dictionary traversal for new versions of Python. * Fix device_id for cpu * Rewrite parse_device_id and tests * Added None for cpu Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com> --- mmdeploy/backend/openvino/wrapper.py | 13 ++++++++----- mmdeploy/utils/device.py | 22 +++++++++++++--------- tests/test_utils/test_util.py | 24 ++++++++++++++++++++++++ tools/test.py | 5 +++-- 4 files changed, 48 insertions(+), 16 deletions(-) diff --git a/mmdeploy/backend/openvino/wrapper.py b/mmdeploy/backend/openvino/wrapper.py index 32f1ab2f74..589906f345 100644 --- a/mmdeploy/backend/openvino/wrapper.py +++ b/mmdeploy/backend/openvino/wrapper.py @@ -107,11 +107,14 @@ def __process_outputs( name: torch.from_numpy(tensor) for name, tensor in outputs.items() } - for output_name in outputs.keys(): - if '.' in output_name: - new_output_name = output_name.split('.')[0] - outputs[new_output_name] = outputs.pop(output_name) - return outputs + cleaned_outputs = {} + for name, value in outputs.items(): + if '.' in name: + new_output_name = name.split('.')[0] + cleaned_outputs[new_output_name] = value + else: + cleaned_outputs[name] = value + return cleaned_outputs def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: diff --git a/mmdeploy/utils/device.py b/mmdeploy/utils/device.py index a972421c33..1925e6523c 100644 --- a/mmdeploy/utils/device.py +++ b/mmdeploy/utils/device.py @@ -1,23 +1,27 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import torch -def parse_device_id(device: str) -> int: - """Parse cuda device index from a string. +def parse_device_id(device: str) -> Optional[int]: + """Parse device index from a string. Args: - device (str): The typical style of string specifying cuda device, - e.g.: 'cuda:0'. + device (str): The typical style of string specifying device, + e.g.: 'cuda:0', 'cpu'. Returns: - int: The parsed device id, defaults to `0`. + Optional[int]: The return value depends on the type of device. + If device is 'cuda': cuda device index, defaults to `0`. + If device is 'cpu': `-1`. + Otherwise, `None` will be returned. """ if device == 'cpu': return -1 - device_id = 0 - if len(device) >= 6: - device_id = torch.device(device).index - return device_id + if 'cuda' in device: + return parse_cuda_device_id(device) + return None def parse_cuda_device_id(device: str) -> int: diff --git a/tests/test_utils/test_util.py b/tests/test_utils/test_util.py index 5eec87b6e1..e9f5ad33c2 100644 --- a/tests/test_utils/test_util.py +++ b/tests/test_utils/test_util.py @@ -369,6 +369,30 @@ def test_can_get_axes_from_list_with_cfg(self): assert expected_dynamic_axes == dynamic_axes +class TestParseDeviceID: + + def test_cpu(self): + device = 'cpu' + assert util.parse_device_id(device) == -1 + + def test_cuda(self): + device = 'cuda' + assert util.parse_device_id(device) == 0 + + def test_cuda10(self): + device = 'cuda:10' + assert util.parse_device_id(device) == 10 + + def test_incorrect_cuda_device(self): + device = 'cuda_5' + with pytest.raises(RuntimeError): + util.parse_device_id(device) + + def test_incorrect_device(self): + device = 'abcd:1' + assert util.parse_device_id(device) is None + + def test_AdvancedEnum(): keys = [ Task.TEXT_DETECTION, Task.TEXT_RECOGNITION, Task.SEGMENTATION, diff --git a/tools/test.py b/tools/test.py index 88514f967b..c191585afb 100644 --- a/tools/test.py +++ b/tools/test.py @@ -105,7 +105,8 @@ def main(): # load the model of the backend model = task_processor.init_backend_model(args.model) - device_id = parse_device_id(args.device) + is_device_cpu = (args.device == 'cpu') + device_id = None if is_device_cpu else parse_device_id(args.device) model = MMDataParallel(model, device_ids=[device_id]) # The whole dataset test wrapped a MMDataParallel class outside the module. @@ -115,7 +116,7 @@ def main(): if hasattr(model.module, 'CLASSES'): model.CLASSES = model.module.CLASSES if args.speed_test: - with_sync = device_id >= 0 + with_sync = not is_device_cpu output_file = sys.stdout if args.log2file: output_file = args.log2file