diff --git a/CHANGELOG.md b/CHANGELOG.md index f432c7afa26..7b4033b60d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,7 +53,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed corner case in `Iou` metric for single empty prediction tensors ([#2780](https://github.com/Lightning-AI/torchmetrics/pull/2780)) ## [1.4.3] - 2024-10-10 diff --git a/src/torchmetrics/detection/iou.py b/src/torchmetrics/detection/iou.py index 7b4a60200ca..e809b27ce6a 100644 --- a/src/torchmetrics/detection/iou.py +++ b/src/torchmetrics/detection/iou.py @@ -211,7 +211,8 @@ def compute(self) -> dict: """Computes IoU based on inputs passed in to ``update`` previously.""" score = torch.cat([mat[mat != self._invalid_val] for mat in self.iou_matrix], 0).mean() results: Dict[str, Tensor] = {f"{self._iou_type}": score} - + if torch.isnan(score): # if no valid boxes are found + results[f"{self._iou_type}"] = torch.tensor(0.0, device=score.device) if self.class_metrics: gt_labels = dim_zero_cat(self.groundtruth_labels) classes = gt_labels.unique().tolist() if len(gt_labels) > 0 else [] diff --git a/tests/unittests/detection/test_intersection.py b/tests/unittests/detection/test_intersection.py index c42a6763ba9..a028a014ea1 100644 --- a/tests/unittests/detection/test_intersection.py +++ b/tests/unittests/detection/test_intersection.py @@ -63,6 +63,8 @@ def _tv_wrapper_class(preds, target, base_fn, respect_labels, iou_threshold, cla base_name = {tv_ciou: "ciou", tv_diou: "diou", tv_giou: "giou", tv_iou: "iou"}[base_fn] result = {f"{base_name}": score.cpu()} + if torch.isnan(score): + result.update({f"{base_name}": torch.tensor(0.0)}) if class_metrics: for cl in torch.cat(classes).unique().tolist(): class_score, numel = 0, 0 @@ -71,7 +73,6 @@ def _tv_wrapper_class(preds, target, base_fn, respect_labels, iou_threshold, cla class_score += masked_s[masked_s != -1].sum() numel += masked_s[masked_s != -1].numel() result.update({f"{base_name}/cl_{cl}": class_score.cpu() / numel}) - return result @@ -328,6 +329,32 @@ def test_functional_error_on_wrong_input_shape(self, class_metric, functional_me with pytest.raises(ValueError, match="Expected target to be of shape.*"): functional_metric(torch.randn(25, 4), torch.randn(25, 25)) + def test_corner_case_only_one_empty_prediction(self, class_metric, functional_metric, reference_metric): + """Test that the metric does not crash when there is only one empty prediction.""" + target = [ + { + "boxes": torch.tensor([ + [8.0000, 70.0000, 76.0000, 110.0000], + [247.0000, 131.0000, 315.0000, 175.0000], + [361.0000, 177.0000, 395.0000, 203.0000], + ]), + "labels": torch.tensor([0, 0, 0]), + } + ] + preds = [ + { + "boxes": torch.empty(size=(0, 4)), + "labels": torch.tensor([], dtype=torch.int64), + "scores": torch.tensor([]), + } + ] + + metric = class_metric() + metric.update(preds, target) + res = metric.compute() + for val in res.values(): + assert val == torch.tensor(0.0) + def test_corner_case(): """See issue: https://github.com/Lightning-AI/torchmetrics/issues/1921."""