From bf9a7a6dd62dae1352de10660829ef55118bfcf5 Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Sun, 24 Apr 2022 10:36:30 +0530 Subject: [PATCH 1/6] Explicit cast to torch.uint8 for bool types on CUDA --- tests/detection/test_map.py | 27 ++++++++++++++++++++++++--- torchmetrics/detection/mean_ap.py | 7 ++++++- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py index 01524eae167..035e7340082 100644 --- a/tests/detection/test_map.py +++ b/tests/detection/test_map.py @@ -58,7 +58,7 @@ scores=torch.Tensor([0.699]), labels=torch.IntTensor([5]), ), # coco image id 133 - ], + ], ], target=[ [ @@ -95,7 +95,7 @@ boxes=torch.Tensor([[13.99, 2.87, 640.00, 421.52]]), labels=torch.IntTensor([5]), ), # coco image id 133 - ], + ], ], ) @@ -133,6 +133,27 @@ ], ) +_inputs3 = Input( + preds = [ + [ + dict( + boxes = torch.tensor([]), + scores = torch.tensor([]), + labels = torch.tensor([]) + ), + ], + ], + target=[ + [ + dict( + boxes=torch.tensor([[1., 2., 3., 4.]]), + scores=torch.tensor([0.8]), + labels=torch.tensor([1]), + ), + ], + ], +) + def _compare_fn(preds, target) -> dict: """Comparison function for map implementation. @@ -283,7 +304,7 @@ def _move_to_gpu(input): @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") @pytest.mark.skipif(_gpu_test_condition, reason="test requires CUDA availability") -@pytest.mark.parametrize("inputs", [_inputs, _inputs2]) +@pytest.mark.parametrize("inputs", [_inputs, _inputs2, _inputs3]) def test_map_gpu(inputs): """Test predictions on single gpu.""" metric = MeanAveragePrecision() diff --git a/torchmetrics/detection/mean_ap.py b/torchmetrics/detection/mean_ap.py index c4ebec228d8..96b5bad79af 100644 --- a/torchmetrics/detection/mean_ap.py +++ b/torchmetrics/detection/mean_ap.py @@ -694,7 +694,12 @@ def __calculate_recall_precision_scores( # different sorting method generates slightly different results. # mergesort is used to be consistent as Matlab implementation. - inds = torch.argsort(det_scores, descending=True) + # Sort in PyTorch does not support bool types on CUDA (yet, 1.11.0) + if det_scores.is_cuda and det_scores.dtype is torch.bool: + # Explicitly cast to uint8 to avoid error for bool inputs on CUDA to argsort + inds = torch.argsort(det_scores.to(torch.uint8), descending=True) + else: + inds = torch.argsort(det_scores, descending=True) det_scores_sorted = det_scores[inds] det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds] From 2f202a5a0fbe5b17ffd1697bff0def50b268b078 Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Sun, 24 Apr 2022 10:37:07 +0530 Subject: [PATCH 2/6] Fix whitespaces --- tests/detection/test_map.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py index 035e7340082..6fa1b56f1da 100644 --- a/tests/detection/test_map.py +++ b/tests/detection/test_map.py @@ -58,7 +58,7 @@ scores=torch.Tensor([0.699]), labels=torch.IntTensor([5]), ), # coco image id 133 - ], + ], ], target=[ [ @@ -95,7 +95,7 @@ boxes=torch.Tensor([[13.99, 2.87, 640.00, 421.52]]), labels=torch.IntTensor([5]), ), # coco image id 133 - ], + ], ], ) From 235ad2b5658aa51cfa2983b83be4c914f0dd393f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 24 Apr 2022 05:09:09 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/detection/test_map.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py index 6fa1b56f1da..9fc3ebfcd7f 100644 --- a/tests/detection/test_map.py +++ b/tests/detection/test_map.py @@ -134,19 +134,15 @@ ) _inputs3 = Input( - preds = [ + preds=[ [ - dict( - boxes = torch.tensor([]), - scores = torch.tensor([]), - labels = torch.tensor([]) - ), + dict(boxes=torch.tensor([]), scores=torch.tensor([]), labels=torch.tensor([])), ], ], target=[ [ dict( - boxes=torch.tensor([[1., 2., 3., 4.]]), + boxes=torch.tensor([[1.0, 2.0, 3.0, 4.0]]), scores=torch.tensor([0.8]), labels=torch.tensor([1]), ), From ab034df96480086cea3edd11f67417d3e2a4b3ea Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Sun, 24 Apr 2022 10:39:52 +0530 Subject: [PATCH 4/6] Update tests/detection/test_map.py --- tests/detection/test_map.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py index 9fc3ebfcd7f..388d77f5a46 100644 --- a/tests/detection/test_map.py +++ b/tests/detection/test_map.py @@ -133,6 +133,8 @@ ], ) +# Test empty preds case, to ensure bool inputs are properly casted to uint8 +# From https://github.com/PyTorchLightning/metrics/issues/981 _inputs3 = Input( preds=[ [ From fed0835624887602f876d052ca2c2f90a6ef99b9 Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Sun, 24 Apr 2022 10:41:33 +0530 Subject: [PATCH 5/6] Add changelog entry --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ab7a66e997..c1162c3155d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed "Sort currently does not support bool dtype on CUDA" error in MAP for empty preds ([#983](https://github.com/PyTorchLightning/metrics/pull/983)) + + - Fixed `BinnedPrecisionRecallCurve` when `thresholds` argument is not provided ([#968](https://github.com/PyTorchLightning/metrics/pull/968)) From 98c517a95fe511c799bacf43a68aceeb076c1c64 Mon Sep 17 00:00:00 2001 From: Jirka Date: Sun, 24 Apr 2022 15:59:17 +0200 Subject: [PATCH 6/6] simple --- torchmetrics/detection/mean_ap.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchmetrics/detection/mean_ap.py b/torchmetrics/detection/mean_ap.py index 96b5bad79af..1ef6f5ae69c 100644 --- a/torchmetrics/detection/mean_ap.py +++ b/torchmetrics/detection/mean_ap.py @@ -695,11 +695,9 @@ def __calculate_recall_precision_scores( # different sorting method generates slightly different results. # mergesort is used to be consistent as Matlab implementation. # Sort in PyTorch does not support bool types on CUDA (yet, 1.11.0) - if det_scores.is_cuda and det_scores.dtype is torch.bool: - # Explicitly cast to uint8 to avoid error for bool inputs on CUDA to argsort - inds = torch.argsort(det_scores.to(torch.uint8), descending=True) - else: - inds = torch.argsort(det_scores, descending=True) + dtype = torch.uint8 if det_scores.is_cuda and det_scores.dtype is torch.bool else det_scores.dtype + # Explicitly cast to uint8 to avoid error for bool inputs on CUDA to argsort + inds = torch.argsort(det_scores.to(dtype), descending=True) det_scores_sorted = det_scores[inds] det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]