Skip to content

Commit

Permalink
Fix: handle zero division error in binary IoU calculation (#2726)
Browse files Browse the repository at this point in the history
* Fix: Handle zero division error in binary IoU (Jaccard index) calculation
* chlog

---------

Co-authored-by: Jirka Borovec <[email protected]>
(cherry picked from commit 80929b5)
  • Loading branch information
rittik9 authored and Borda committed Sep 11, 2024
1 parent 4c77453 commit f056ad6
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed handling zero division error in binary IoU (Jaccard index) calculation ([#2726](https://github.com/Lightning-AI/torchmetrics/pull/2726))


- Correct the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721))


Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _jaccard_index_reduce(
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
confmat = confmat.float()
if average == "binary":
return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1])
return _safe_divide(confmat[1, 1], (confmat[0, 1] + confmat[1, 0] + confmat[1, 1]), zero_division=zero_division)

ignore_index_cond = ignore_index is not None and 0 <= ignore_index < confmat.shape[0]
multilabel = confmat.ndim == 3
Expand Down
21 changes: 21 additions & 0 deletions tests/unittests/classification/test_jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
MultilabelJaccardIndex,
)
from torchmetrics.functional.classification.jaccard import (
_jaccard_index_reduce,
binary_jaccard_index,
multiclass_jaccard_index,
multilabel_jaccard_index,
Expand Down Expand Up @@ -403,6 +404,26 @@ def test_corner_case():
assert torch.allclose(res, out)


def test_jaccard_index_zero_division():
"""Issue: https://github.com/Lightning-AI/torchmetrics/issues/2658."""
# Test case where all pixels are background (zeros)
confmat = torch.tensor([[4, 0], [0, 0]])

# Test with zero_division=0.0
result = _jaccard_index_reduce(confmat, average="binary", zero_division=0.0)
assert result == 0.0, f"Expected 0.0, but got {result}"

# Test with zero_division=1.0
result = _jaccard_index_reduce(confmat, average="binary", zero_division=1.0)
assert result == 1.0, f"Expected 1.0, but got {result}"

# Test case with some foreground pixels
confmat = torch.tensor([[2, 1], [1, 1]])
result = _jaccard_index_reduce(confmat, average="binary", zero_division=0.0)
expected = 1 / 3
assert torch.isclose(result, torch.tensor(expected)), f"Expected {expected}, but got {result}"


@pytest.mark.parametrize(
("metric", "kwargs"),
[
Expand Down

0 comments on commit f056ad6

Please sign in to comment.