Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

correct the padding related calculation errors in SSIM #2721

Merged
merged 15 commits into from
Sep 9, 2024
Merged
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Correct the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721))


## [1.4.1] - 2024-08-02
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _multiscale_structural_similarity_index_measure(
>>> preds = rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> _multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
tensor(0.9627)
tensor(0.9628)

"""
_deprecated_root_import_func("multiscale_structural_similarity_index_measure", "image")
Expand Down
24 changes: 12 additions & 12 deletions src/torchmetrics/functional/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,15 @@ def _ssim_update(
dtype = preds.dtype
gauss_kernel_size = [int(3.5 * s + 0.5) * 2 + 1 for s in sigma]

pad_h = (gauss_kernel_size[0] - 1) // 2
pad_w = (gauss_kernel_size[1] - 1) // 2
if gaussian_kernel:
pad_h = (gauss_kernel_size[0] - 1) // 2
pad_w = (gauss_kernel_size[1] - 1) // 2
else:
pad_h = (kernel_size[0] - 1) // 2
pad_w = (kernel_size[1] - 1) // 2

if is_3d:
pad_d = (gauss_kernel_size[2] - 1) // 2
pad_d = (kernel_size[2] - 1) // 2
preds = _reflection_pad_3d(preds, pad_d, pad_w, pad_h)
target = _reflection_pad_3d(target, pad_d, pad_w, pad_h)
if gaussian_kernel:
Expand Down Expand Up @@ -164,25 +168,21 @@ def _ssim_update(

ssim_idx_full_image = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower)

if is_3d:
ssim_idx = ssim_idx_full_image[..., pad_h:-pad_h, pad_w:-pad_w, pad_d:-pad_d]
else:
ssim_idx = ssim_idx_full_image[..., pad_h:-pad_h, pad_w:-pad_w]

if return_contrast_sensitivity:
contrast_sensitivity = upper / lower
if is_3d:
contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w, pad_d:-pad_d]
else:
contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w]
return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1), contrast_sensitivity.reshape(

return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1), contrast_sensitivity.reshape(
contrast_sensitivity.shape[0], -1
).mean(-1)

if return_full_image:
return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1), ssim_idx_full_image
return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1), ssim_idx_full_image

return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1)
return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1)


def _ssim_compute(
Expand Down Expand Up @@ -507,7 +507,7 @@ def multiscale_structural_similarity_index_measure(
>>> preds = rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
tensor(0.9627)
tensor(0.9628)

References:
[1] Multi-Scale Structural Similarity For Image Quality Assessment by Zhou Wang, Eero P. Simoncelli and Alan C.
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/image/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class _MultiScaleStructuralSimilarityIndexMeasure(MultiScaleStructuralSimilarity
>>> target = preds * 0.75
>>> ms_ssim = _MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> ms_ssim(preds, target)
tensor(0.9627)
tensor(0.9628)

"""

Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class MultiScaleStructuralSimilarityIndexMeasure(Metric):
>>> target = preds * 0.75
>>> ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> ms_ssim(preds, target)
tensor(0.9627)
tensor(0.9628)

"""

Expand Down
55 changes: 36 additions & 19 deletions tests/unittests/image/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,16 @@ def test_ssim_invalid_inputs(pred, target, kernel, sigma, match):
structural_similarity_index_measure(pred, target, kernel_size=kernel, sigma=sigma)


def test_ssim_unequal_kernel_size():
@pytest.mark.parametrize(
("sigma", "kernel_size", "result"),
[
((0.25, 0.5), None, torch.tensor(0.20977394)),
((0.5, 0.25), None, torch.tensor(0.13884821)),
(None, (3, 5), torch.tensor(0.05032664)),
(None, (5, 3), torch.tensor(0.03472072)),
],
)
def test_ssim_unequal_kernel_size(sigma, kernel_size, result):
"""Test the case where kernel_size[0] != kernel_size[1]."""
preds = torch.tensor([
[
Expand Down Expand Up @@ -306,24 +315,16 @@ def test_ssim_unequal_kernel_size():
]
]
])
# kernel order matters
assert torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=(0.25, 0.5)),
torch.tensor(0.08869550),
)
assert not torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=(0.5, 0.25)),
torch.tensor(0.08869550),
)

assert torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=(3, 5)),
torch.tensor(0.05131844),
)
assert not torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=(5, 3)),
torch.tensor(0.05131844),
)
if sigma is not None:
assert torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=sigma),
result,
)
else:
assert torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=kernel_size),
result,
)


@pytest.mark.parametrize(
Expand All @@ -341,3 +342,19 @@ def test_full_image_output(preds, target):
assert len(out) == 2
assert out[0].numel() == 1
assert out[1].shape == preds[0].shape


def test_ssim_for_correct_padding():
"""Check that padding is correctly added and removed for SSIM.

See issue: https://github.com/Lightning-AI/torchmetrics/issues/2718

"""
preds = torch.rand([3, 3, 256, 256])
# let the edge of the image be 0
target = preds.clone()
target[:, :, 0, :] = 0
target[:, :, -1, :] = 0
target[:, :, :, 0] = 0
target[:, :, :, -1] = 0
assert structural_similarity_index_measure(preds, target) < 1.0
Loading