Skip to content

Commit

Permalink
[Fix] fix bugs for mmcls performance test (open-mmlab#269)
Browse files Browse the repository at this point in the history
* fix bugs for mmcls performance test

* fix yapf

* add comments of CLASSES attribute
  • Loading branch information
hanrui1sensetime authored Dec 10, 2021
1 parent 54885e5 commit 0f90a0a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
5 changes: 4 additions & 1 deletion mmdeploy/codebase/mmcls/deploy/classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ def forward(self, img: List[torch.Tensor], *args, **kwargs) -> list:
list: A list contains predictions.
"""

input_img = img[0].contiguous()
if isinstance(img, list):
input_img = img[0].contiguous()
else:
input_img = img.contiguous()
outputs = self.forward_test(input_img, *args, **kwargs)

return list(outputs)
Expand Down
5 changes: 4 additions & 1 deletion mmdeploy/codebase/mmcls/deploy/mmclassification.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def single_gpu_test(model: torch.nn.Module,
data_loader: DataLoader,
show: bool = False,
out_dir: Optional[str] = None,
win_name: str = '',
**kwargs) -> List:
"""Run test with single gpu.
Expand All @@ -132,10 +133,12 @@ def single_gpu_test(model: torch.nn.Module,
show (bool): Specifying whether to show plotted results.
Default: False.
out_dir (str): A directory to save results, Default: None.
win_name (str): The name of windows, Default: ''.
Returns:
list: The prediction results.
"""
from mmcls.apis import single_gpu_test
outputs = single_gpu_test(model, data_loader, show, out_dir, **kwargs)
outputs = single_gpu_test(
model, data_loader, show, out_dir, win_name=win_name, **kwargs)
return outputs
6 changes: 6 additions & 0 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ def main():
device_id = parse_device_id(args.device)

model = MMDataParallel(model, device_ids=[0])
# The whole dataset test wrapped a MMDataParallel class outside the module.
# As mmcls.apis.test.py single_gpu_test defined, the MMDataParallel needs
# a 'CLASSES' attribute. So we ensure the MMDataParallel class has the same
# CLASSES attribute as the inside module.
if hasattr(model.module, 'CLASSES'):
model.CLASSES = model.module.CLASSES
if args.speed_test:
with_sync = device_id == 0
output_file = sys.stdout
Expand Down

0 comments on commit 0f90a0a

Please sign in to comment.