Skip to content

Commit

Permalink
[Fix] Fixed device_id in tools/test.py for the CPU. (open-mmlab#58)
Browse files Browse the repository at this point in the history
* [Fix] fix bugs for mmcls performance test (open-mmlab#269)

* fix bugs for mmcls performance test

* fix yapf

* add comments of CLASSES attribute

* Rewrote the dictionary traversal for new versions of Python.

* Fix device_id for cpu

* Rewrite parse_device_id and tests

* Added None for cpu

Co-authored-by: hanrui1sensetime <[email protected]>
  • Loading branch information
SemyonBevzuk and hanrui1sensetime authored Jan 13, 2022
1 parent bb655af commit 997d111
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 16 deletions.
13 changes: 8 additions & 5 deletions mmdeploy/backend/openvino/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,14 @@ def __process_outputs(
name: torch.from_numpy(tensor)
for name, tensor in outputs.items()
}
for output_name in outputs.keys():
if '.' in output_name:
new_output_name = output_name.split('.')[0]
outputs[new_output_name] = outputs.pop(output_name)
return outputs
cleaned_outputs = {}
for name, value in outputs.items():
if '.' in name:
new_output_name = name.split('.')[0]
cleaned_outputs[new_output_name] = value
else:
cleaned_outputs[name] = value
return cleaned_outputs

def forward(self, inputs: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
Expand Down
22 changes: 13 additions & 9 deletions mmdeploy/utils/device.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch


def parse_device_id(device: str) -> int:
"""Parse cuda device index from a string.
def parse_device_id(device: str) -> Optional[int]:
"""Parse device index from a string.
Args:
device (str): The typical style of string specifying cuda device,
e.g.: 'cuda:0'.
device (str): The typical style of string specifying device,
e.g.: 'cuda:0', 'cpu'.
Returns:
int: The parsed device id, defaults to `0`.
Optional[int]: The return value depends on the type of device.
If device is 'cuda': cuda device index, defaults to `0`.
If device is 'cpu': `-1`.
Otherwise, `None` will be returned.
"""
if device == 'cpu':
return -1
device_id = 0
if len(device) >= 6:
device_id = torch.device(device).index
return device_id
if 'cuda' in device:
return parse_cuda_device_id(device)
return None


def parse_cuda_device_id(device: str) -> int:
Expand Down
24 changes: 24 additions & 0 deletions tests/test_utils/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,30 @@ def test_can_get_axes_from_list_with_cfg(self):
assert expected_dynamic_axes == dynamic_axes


class TestParseDeviceID:

def test_cpu(self):
device = 'cpu'
assert util.parse_device_id(device) == -1

def test_cuda(self):
device = 'cuda'
assert util.parse_device_id(device) == 0

def test_cuda10(self):
device = 'cuda:10'
assert util.parse_device_id(device) == 10

def test_incorrect_cuda_device(self):
device = 'cuda_5'
with pytest.raises(RuntimeError):
util.parse_device_id(device)

def test_incorrect_device(self):
device = 'abcd:1'
assert util.parse_device_id(device) is None


def test_AdvancedEnum():
keys = [
Task.TEXT_DETECTION, Task.TEXT_RECOGNITION, Task.SEGMENTATION,
Expand Down
5 changes: 3 additions & 2 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def main():
# load the model of the backend
model = task_processor.init_backend_model(args.model)

device_id = parse_device_id(args.device)
is_device_cpu = (args.device == 'cpu')
device_id = None if is_device_cpu else parse_device_id(args.device)

model = MMDataParallel(model, device_ids=[device_id])
# The whole dataset test wrapped a MMDataParallel class outside the module.
Expand All @@ -115,7 +116,7 @@ def main():
if hasattr(model.module, 'CLASSES'):
model.CLASSES = model.module.CLASSES
if args.speed_test:
with_sync = device_id >= 0
with_sync = not is_device_cpu
output_file = sys.stdout
if args.log2file:
output_file = args.log2file
Expand Down

0 comments on commit 997d111

Please sign in to comment.