From 05df3113d40427db18490c68116aa3341360d746 Mon Sep 17 00:00:00 2001 From: Shaochang Tan <478710209@qq.com> Date: Wed, 4 Sep 2024 14:48:47 +0200 Subject: [PATCH 01/12] correct the padding related calculation errors in SSIM --- src/torchmetrics/functional/image/ssim.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index d89dd828a3e..5e510a97447 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -124,11 +124,15 @@ def _ssim_update( dtype = preds.dtype gauss_kernel_size = [int(3.5 * s + 0.5) * 2 + 1 for s in sigma] - pad_h = (gauss_kernel_size[0] - 1) // 2 - pad_w = (gauss_kernel_size[1] - 1) // 2 + if gaussian_kernel: + pad_h = (gauss_kernel_size[0] - 1) // 2 + pad_w = (gauss_kernel_size[1] - 1) // 2 + else: + pad_h = (kernel_size[0] - 1) // 2 + pad_w = (kernel_size[1] - 1) // 2 if is_3d: - pad_d = (gauss_kernel_size[2] - 1) // 2 + pad_d = (kernel_size[2] - 1) // 2 preds = _reflection_pad_3d(preds, pad_d, pad_w, pad_h) target = _reflection_pad_3d(target, pad_d, pad_w, pad_h) if gaussian_kernel: @@ -164,10 +168,6 @@ def _ssim_update( ssim_idx_full_image = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower) - if is_3d: - ssim_idx = ssim_idx_full_image[..., pad_h:-pad_h, pad_w:-pad_w, pad_d:-pad_d] - else: - ssim_idx = ssim_idx_full_image[..., pad_h:-pad_h, pad_w:-pad_w] if return_contrast_sensitivity: contrast_sensitivity = upper / lower @@ -175,14 +175,15 @@ def _ssim_update( contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w, pad_d:-pad_d] else: contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w] - return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1), contrast_sensitivity.reshape( + + return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1), contrast_sensitivity.reshape( contrast_sensitivity.shape[0], -1 ).mean(-1) if return_full_image: - return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1), ssim_idx_full_image + return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1), ssim_idx_full_image - return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1) + return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1) def _ssim_compute( From bac2527324ee6ed139874606358b26b66a8041ca Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Sep 2024 13:13:35 +0000 Subject: [PATCH 02/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/image/ssim.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index 5e510a97447..86725b0b659 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -168,7 +168,6 @@ def _ssim_update( ssim_idx_full_image = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower) - if return_contrast_sensitivity: contrast_sensitivity = upper / lower if is_3d: From dcfb19bf7b9b79896934ce8c4a621d74f59ffa27 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 5 Sep 2024 11:52:17 +0200 Subject: [PATCH 03/12] fix doctests --- src/torchmetrics/functional/image/_deprecated.py | 2 +- src/torchmetrics/functional/image/ssim.py | 2 +- src/torchmetrics/image/_deprecated.py | 2 +- src/torchmetrics/image/ssim.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/functional/image/_deprecated.py b/src/torchmetrics/functional/image/_deprecated.py index fafc09cabaa..892d07afaa6 100644 --- a/src/torchmetrics/functional/image/_deprecated.py +++ b/src/torchmetrics/functional/image/_deprecated.py @@ -168,7 +168,7 @@ def _multiscale_structural_similarity_index_measure( >>> preds = rand([3, 3, 256, 256]) >>> target = preds * 0.75 >>> _multiscale_structural_similarity_index_measure(preds, target, data_range=1.0) - tensor(0.9627) + tensor(0.9628) """ _deprecated_root_import_func("multiscale_structural_similarity_index_measure", "image") diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index 86725b0b659..c61ef833fe3 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -507,7 +507,7 @@ def multiscale_structural_similarity_index_measure( >>> preds = rand([3, 3, 256, 256]) >>> target = preds * 0.75 >>> multiscale_structural_similarity_index_measure(preds, target, data_range=1.0) - tensor(0.9627) + tensor(0.9628) References: [1] Multi-Scale Structural Similarity For Image Quality Assessment by Zhou Wang, Eero P. Simoncelli and Alan C. diff --git a/src/torchmetrics/image/_deprecated.py b/src/torchmetrics/image/_deprecated.py index 86fb04c7f1c..8b382b89cf7 100644 --- a/src/torchmetrics/image/_deprecated.py +++ b/src/torchmetrics/image/_deprecated.py @@ -44,7 +44,7 @@ class _MultiScaleStructuralSimilarityIndexMeasure(MultiScaleStructuralSimilarity >>> target = preds * 0.75 >>> ms_ssim = _MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) >>> ms_ssim(preds, target) - tensor(0.9627) + tensor(0.9628) """ diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index f0c03a163d6..648f9c26029 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -274,7 +274,7 @@ class MultiScaleStructuralSimilarityIndexMeasure(Metric): >>> target = preds * 0.75 >>> ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) >>> ms_ssim(preds, target) - tensor(0.9627) + tensor(0.9628) """ From c17493507690086b3ea9fadf70bf4673054fa10c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 5 Sep 2024 11:53:26 +0200 Subject: [PATCH 04/12] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b0b0022476..1bd51af0290 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Correct the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721)) ## [1.4.1] - 2024-08-02 From 7b0888d7099d7a1e5bc76d82593e1ae841ab7c89 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 5 Sep 2024 12:06:29 +0200 Subject: [PATCH 05/12] change test --- tests/unittests/image/test_ssim.py | 39 +++++++++++++++--------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index 6b464f2a97b..3c336938449 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -278,7 +278,16 @@ def test_ssim_invalid_inputs(pred, target, kernel, sigma, match): structural_similarity_index_measure(pred, target, kernel_size=kernel, sigma=sigma) -def test_ssim_unequal_kernel_size(): +@pytest.mark.parametrize( + ("sigma", "kernel_size", "result"), + [ + ((0.25, 0.5), None, torch.tensor(0.08869550)), + ((0.5, 0.25), None, torch.tensor(0.08869550)), + (None, (3, 5), torch.tensor(0.05131844)), + (None, (5, 3), torch.tensor(0.05131844)), + ] +) +def test_ssim_unequal_kernel_size(sigma, kernel_size, result): """Test the case where kernel_size[0] != kernel_size[1].""" preds = torch.tensor([ [ @@ -306,24 +315,16 @@ def test_ssim_unequal_kernel_size(): ] ] ]) - # kernel order matters - assert torch.isclose( - structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=(0.25, 0.5)), - torch.tensor(0.08869550), - ) - assert not torch.isclose( - structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=(0.5, 0.25)), - torch.tensor(0.08869550), - ) - - assert torch.isclose( - structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=(3, 5)), - torch.tensor(0.05131844), - ) - assert not torch.isclose( - structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=(5, 3)), - torch.tensor(0.05131844), - ) + if sigma is not None: + assert torch.isclose( + structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=sigma), + result, + ) + else: + assert torch.isclose( + structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=kernel_size), + result, + ) @pytest.mark.parametrize( From 3b7846005a7d2e0dd55d702ef1f7b55c9db23344 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 10:06:55 +0000 Subject: [PATCH 06/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/image/test_ssim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index 3c336938449..853dbde5991 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -285,7 +285,7 @@ def test_ssim_invalid_inputs(pred, target, kernel, sigma, match): ((0.5, 0.25), None, torch.tensor(0.08869550)), (None, (3, 5), torch.tensor(0.05131844)), (None, (5, 3), torch.tensor(0.05131844)), - ] + ], ) def test_ssim_unequal_kernel_size(sigma, kernel_size, result): """Test the case where kernel_size[0] != kernel_size[1].""" From 56797ee6a7cf08774b33cb71c06b29c3ca7fd50d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 5 Sep 2024 12:09:04 +0200 Subject: [PATCH 07/12] add test --- tests/unittests/image/test_ssim.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index 853dbde5991..ee850988358 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -342,3 +342,18 @@ def test_full_image_output(preds, target): assert len(out) == 2 assert out[0].numel() == 1 assert out[1].shape == preds[0].shape + + +def test_ssim_for_correct_padding + """Check that padding is correctly added and removed for SSIM. + + See issue: https://github.com/Lightning-AI/torchmetrics/issues/2718 + """ + preds = torch.rand([3, 3, 256, 256]) + # let the edge of the image be 0 + target = preds.clone() + target[:, :, 0, :] = 0 + target[:, :, -1, :] = 0 + target[:, :, :, 0] = 0 + target[:, :, :, -1] = 0 + assert structural_similarity_index_measure(preds, target) < 1.0 \ No newline at end of file From 23767de4f92cdeb5028ef26d0becfd407c744da9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 10:09:42 +0000 Subject: [PATCH 08/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/image/test_ssim.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index ee850988358..b338de768f0 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -348,6 +348,7 @@ def test_ssim_for_correct_padding """Check that padding is correctly added and removed for SSIM. See issue: https://github.com/Lightning-AI/torchmetrics/issues/2718 + """ preds = torch.rand([3, 3, 256, 256]) # let the edge of the image be 0 @@ -356,4 +357,4 @@ def test_ssim_for_correct_padding target[:, :, -1, :] = 0 target[:, :, :, 0] = 0 target[:, :, :, -1] = 0 - assert structural_similarity_index_measure(preds, target) < 1.0 \ No newline at end of file + assert structural_similarity_index_measure(preds, target) < 1.0 From 2fa7d6d8e770e7b3098ec90864fc50f60a430df4 Mon Sep 17 00:00:00 2001 From: TanShaochang <30321432+petertheprocess@users.noreply.github.com> Date: Thu, 5 Sep 2024 16:45:04 +0200 Subject: [PATCH 09/12] fix syntax in test_ssim.py --- tests/unittests/image/test_ssim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index b338de768f0..2141e347d69 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -344,7 +344,7 @@ def test_full_image_output(preds, target): assert out[1].shape == preds[0].shape -def test_ssim_for_correct_padding +def test_ssim_for_correct_padding: """Check that padding is correctly added and removed for SSIM. See issue: https://github.com/Lightning-AI/torchmetrics/issues/2718 From 23ede4870dd2edac6f4989d9cd047952664e0969 Mon Sep 17 00:00:00 2001 From: TanShaochang <30321432+petertheprocess@users.noreply.github.com> Date: Thu, 5 Sep 2024 16:48:56 +0200 Subject: [PATCH 10/12] fix syntax in test_ssim.py --- tests/unittests/image/test_ssim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index 2141e347d69..4a280d958e3 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -344,7 +344,7 @@ def test_full_image_output(preds, target): assert out[1].shape == preds[0].shape -def test_ssim_for_correct_padding: +def test_ssim_for_correct_padding(): """Check that padding is correctly added and removed for SSIM. See issue: https://github.com/Lightning-AI/torchmetrics/issues/2718 From 622203bd4b77f03e7be2bd303315b6d57dc0e738 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sat, 7 Sep 2024 10:53:08 +0200 Subject: [PATCH 11/12] fix unittests --- tests/unittests/image/test_ssim.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index 4a280d958e3..7d16581bade 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -281,10 +281,10 @@ def test_ssim_invalid_inputs(pred, target, kernel, sigma, match): @pytest.mark.parametrize( ("sigma", "kernel_size", "result"), [ - ((0.25, 0.5), None, torch.tensor(0.08869550)), - ((0.5, 0.25), None, torch.tensor(0.08869550)), - (None, (3, 5), torch.tensor(0.05131844)), - (None, (5, 3), torch.tensor(0.05131844)), + ((0.25, 0.5), None, torch.tensor(0.20977394)), + ((0.5, 0.25), None, torch.tensor(0.13884821)), + (None, (3, 5), torch.tensor(0.05032664)), + (None, (5, 3), torch.tensor(0.03472072)), ], ) def test_ssim_unequal_kernel_size(sigma, kernel_size, result): From dae33e7ec0af90b65315467b04c866754b27e5ac Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sat, 7 Sep 2024 11:22:31 +0200 Subject: [PATCH 12/12] fix tolerance --- tests/unittests/image/test_ssim.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index 7d16581bade..49954f45cd7 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -319,11 +319,13 @@ def test_ssim_unequal_kernel_size(sigma, kernel_size, result): assert torch.isclose( structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=sigma), result, + atol=1e-04, ) else: assert torch.isclose( structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=kernel_size), result, + atol=1e-04, )