Skip to content

Commit

Permalink
Fix overflow for PSNR metric when used with uint8 input (#2788)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Oct 18, 2024
1 parent 6377aa5 commit 0fd1f96
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed corner case in `Iou` metric for single empty prediction tensors ([#2780](https://github.com/Lightning-AI/torchmetrics/pull/2780))


- Fixed `PSNR` calculation for integer type input images ([#2788](https://github.com/Lightning-AI/torchmetrics/pull/2788))


## [1.4.3] - 2024-10-10

### Fixed
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/image/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def _psnr_update(
Default is None meaning scores will be reduced across all dimensions.
"""
if not preds.is_floating_point():
preds = preds.to(torch.float32)
if not target.is_floating_point():
target = target.to(torch.float32)

if dim is None:
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
num_obs = tensor(target.numel(), device=target.device)
Expand Down
13 changes: 13 additions & 0 deletions tests/unittests/image/test_psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,16 @@ def test_missing_data_range():

with pytest.raises(ValueError, match="The `data_range` must be given when `dim` is not None."):
peak_signal_noise_ratio(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0)


def test_psnr_uint_dtype():
"""Check that automatic casting to float is done for uint dtype.
See issue: https://github.com/Lightning-AI/torchmetrics/issues/2787
"""
preds = torch.randint(0, 255, _input_size, dtype=torch.uint8)
target = torch.randint(0, 255, _input_size, dtype=torch.uint8)
psnr = peak_signal_noise_ratio(preds, target)
prnr2 = peak_signal_noise_ratio(preds.float(), target.float())
assert torch.allclose(psnr, prnr2)

0 comments on commit 0fd1f96

Please sign in to comment.