From 5a4ae0f92dc041b01ca8360cda4542ab031c8100 Mon Sep 17 00:00:00 2001 From: Hu Di <476658825@qq.com> Date: Wed, 16 Feb 2022 17:02:30 +0800 Subject: [PATCH 1/7] some change to mmcls/models/losses/utils.py:convert_to_one_hot() --- mmcls/models/losses/utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mmcls/models/losses/utils.py b/mmcls/models/losses/utils.py index a33da4dc9a5..40e12bf9b85 100644 --- a/mmcls/models/losses/utils.py +++ b/mmcls/models/losses/utils.py @@ -114,8 +114,5 @@ def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: """ assert (torch.max(targets).item() < classes), 'Class Index must be less than number of classes' - one_hot_targets = torch.zeros((targets.shape[0], classes), - dtype=torch.long, - device=targets.device) - one_hot_targets.scatter_(1, targets.long(), 1) + one_hot_targets = torch.nn.functional.one_hot(targets.long().squeeze(), num_classes=classes) return one_hot_targets From 1b25a9eca7b5ced5e510057192c3c2ca1443cb81 Mon Sep 17 00:00:00 2001 From: Hu Di <476658825@qq.com> Date: Wed, 16 Feb 2022 17:50:39 +0800 Subject: [PATCH 2/7] fixed problem: line too long --- mmcls/models/losses/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmcls/models/losses/utils.py b/mmcls/models/losses/utils.py index 40e12bf9b85..e3689c25ab1 100644 --- a/mmcls/models/losses/utils.py +++ b/mmcls/models/losses/utils.py @@ -114,5 +114,6 @@ def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: """ assert (torch.max(targets).item() < classes), 'Class Index must be less than number of classes' - one_hot_targets = torch.nn.functional.one_hot(targets.long().squeeze(), num_classes=classes) + one_hot_targets = F.one_hot(targets.long().squeeze(), + num_classes=classes) return one_hot_targets From 9e5d168b7d5e9ca5cb2117b8782e7f92e6453e6b Mon Sep 17 00:00:00 2001 From: Hu Di <476658825@qq.com> Date: Wed, 16 Feb 2022 17:55:43 +0800 Subject: [PATCH 3/7] fixed wrong output shape --- mmcls/models/losses/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcls/models/losses/utils.py b/mmcls/models/losses/utils.py index e3689c25ab1..54a6470f797 100644 --- a/mmcls/models/losses/utils.py +++ b/mmcls/models/losses/utils.py @@ -114,6 +114,6 @@ def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: """ assert (torch.max(targets).item() < classes), 'Class Index must be less than number of classes' - one_hot_targets = F.one_hot(targets.long().squeeze(), + one_hot_targets = F.one_hot(targets.long().squeeze(-1), num_classes=classes) return one_hot_targets From fc4496b24dfad0f99f19251557a1ae7efd75510a Mon Sep 17 00:00:00 2001 From: Hu Di <476658825@qq.com> Date: Thu, 17 Feb 2022 11:09:04 +0800 Subject: [PATCH 4/7] fixed lint PEP8 E128 --- mmcls/models/losses/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mmcls/models/losses/utils.py b/mmcls/models/losses/utils.py index 54a6470f797..a1090e6219c 100644 --- a/mmcls/models/losses/utils.py +++ b/mmcls/models/losses/utils.py @@ -114,6 +114,8 @@ def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: """ assert (torch.max(targets).item() < classes), 'Class Index must be less than number of classes' - one_hot_targets = F.one_hot(targets.long().squeeze(-1), - num_classes=classes) + one_hot_targets = F.one_hot( + targets.long().squeeze(-1), + num_classes=classes + ) return one_hot_targets From 7c9aa83569ed4e74f83e9c80890f022028a14878 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <1105212286@qq.com> Date: Fri, 18 Feb 2022 18:42:12 +0800 Subject: [PATCH 5/7] fix lint --- mmcls/models/losses/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mmcls/models/losses/utils.py b/mmcls/models/losses/utils.py index a1090e6219c..a65b68a6590 100644 --- a/mmcls/models/losses/utils.py +++ b/mmcls/models/losses/utils.py @@ -115,7 +115,5 @@ def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: assert (torch.max(targets).item() < classes), 'Class Index must be less than number of classes' one_hot_targets = F.one_hot( - targets.long().squeeze(-1), - num_classes=classes - ) + targets.long().squeeze(-1), num_classes=classes) return one_hot_targets From 0b6e9b132688929aa31d31303bd5369386c2c785 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <1105212286@qq.com> Date: Mon, 21 Feb 2022 14:53:23 +0800 Subject: [PATCH 6/7] fix lint --- tools/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/train.py b/tools/train.py index bceb26de2bf..b40a1df0eb5 100644 --- a/tools/train.py +++ b/tools/train.py @@ -4,6 +4,7 @@ import os import os.path as osp import time +import warnings import mmcv import torch From 68f1542068d3af4db932c97e6a728181432fff0c Mon Sep 17 00:00:00 2001 From: Ezra-Yu <1105212286@qq.com> Date: Thu, 24 Feb 2022 18:10:17 +0800 Subject: [PATCH 7/7] add unit tests --- tests/test_metrics/test_utils.py | 49 ++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 tests/test_metrics/test_utils.py diff --git a/tests/test_metrics/test_utils.py b/tests/test_metrics/test_utils.py new file mode 100644 index 00000000000..962a1f8d764 --- /dev/null +++ b/tests/test_metrics/test_utils.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmcls.models.losses.utils import convert_to_one_hot + + +def ori_convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: + assert (torch.max(targets).item() < + classes), 'Class Index must be less than number of classes' + one_hot_targets = torch.zeros((targets.shape[0], classes), + dtype=torch.long, + device=targets.device) + one_hot_targets.scatter_(1, targets.long(), 1) + return one_hot_targets + + +def test_convert_to_one_hot(): + # label should smaller than classes + targets = torch.tensor([1, 2, 3, 8, 5]) + classes = 5 + with pytest.raises(AssertionError): + _ = convert_to_one_hot(targets, classes) + + # test with original impl + classes = 10 + targets = torch.randint(high=classes, size=(10, 1)) + ori_one_hot_targets = torch.zeros((targets.shape[0], classes), + dtype=torch.long, + device=targets.device) + ori_one_hot_targets.scatter_(1, targets.long(), 1) + one_hot_targets = convert_to_one_hot(targets, classes) + assert torch.equal(ori_one_hot_targets, one_hot_targets) + + +# test cuda version +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_convert_to_one_hot_cuda(): + # test with original impl + classes = 10 + targets = torch.randint(high=classes, size=(10, 1)).cuda() + ori_one_hot_targets = torch.zeros((targets.shape[0], classes), + dtype=torch.long, + device=targets.device) + ori_one_hot_targets.scatter_(1, targets.long(), 1) + one_hot_targets = convert_to_one_hot(targets, classes) + assert torch.equal(ori_one_hot_targets, one_hot_targets) + assert ori_one_hot_targets.device == one_hot_targets.device