From 73d227c18cb949ad8876eb491369b266f417b05b Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Tue, 28 Mar 2023 09:04:24 +0200
Subject: [PATCH 1/9] build(deps): update scikit-learn requirement from
 <1.2.2,>1.0 to >1.0,<1.2.3 in /requirements (#1657)

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 requirements/test.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/test.txt b/requirements/test.txt
index 3248b0f7ff6..4b62db7b8f1 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -14,4 +14,4 @@
 requests <=2.28.2
 fire <=0.5.0
 cloudpickle >1.3, <=2.2.1
-scikit-learn >1.0, <1.2.2
+scikit-learn >1.0, <1.2.3

From 685391d6c314981c536c54392783762ae545466c Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Tue, 28 Mar 2023 08:25:59 +0000
Subject: [PATCH 2/9] build(deps): update netcal requirement from
 <=1.3.3,>1.0.0 to >1.0.0,<=1.3.4 in /requirements (#1658)

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 requirements/classification_test.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/classification_test.txt b/requirements/classification_test.txt
index 0a232c25173..08bbd5047ed 100644
--- a/requirements/classification_test.txt
+++ b/requirements/classification_test.txt
@@ -2,5 +2,5 @@
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
 pandas >1.0.0, <=1.5.3
-netcal >1.0.0, <=1.3.3 # calibration_error
+netcal >1.0.0, <=1.3.4 # calibration_error
 fairlearn # group_fairness

From 7015b947ced62f79b0ca0510a628ceec18085936 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 28 Mar 2023 12:09:51 +0200
Subject: [PATCH 3/9] Allow FID with torch.float64 (#1628)

* license + docstring + test

* chlog

---------

Co-authored-by: Jirka
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
---
 CHANGELOG.md                      |   3 +
 src/torchmetrics/image/fid.py     | 108 +++++++++++++++++++++++++++++-
 tests/unittests/image/test_fid.py |  12 ++++
 3 files changed, 122 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index afefa880f8e..05dfeda60f5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -96,6 +96,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Changed `__iter__` method from raising `NotImplementedError` to `TypeError` by setting to `None` ([#1538](https://github.com/Lightning-AI/metrics/pull/1538))
 
+- Allowed FID with `torch.float64` ([#1628](https://github.com/Lightning-AI/metrics/pull/1628))
+
+
 ### Deprecated
 
 -
diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py
index 52c75bc59b8..469c15a6e9a 100644
--- a/src/torchmetrics/image/fid.py
+++ b/src/torchmetrics/image/fid.py
@@ -19,6 +19,7 @@
 from torch import Tensor
 from torch.autograd import Function
 from torch.nn import Module
+from torch.nn.functional import adaptive_avg_pool2d
 
 from torchmetrics.metric import Metric
 from torchmetrics.utilities import rank_zero_info
@@ -30,11 +31,16 @@
 if _TORCH_FIDELITY_AVAILABLE:
     from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 as _FeatureExtractorInceptionV3
+    from torch_fidelity.helpers import vassert
+    from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x
 else:
 
     class _FeatureExtractorInceptionV3(Module):
         pass
 
+    vassert = None
+    interpolate_bilinear_2d_like_tensorflow1x = None
+
     __doctest_skip__ = ["FrechetInceptionDistance", "FrechetInceptionDistance.plot"]
@@ -59,9 +65,94 @@ def train(self, mode: bool) -> "NoTrainInceptionV3":
         """Force network to always be in evaluation mode."""
         return super().train(False)
 
+    def _torch_fidelity_forward(self, x: Tensor) -> Tensor:
+        """Forward method of inception net.
+
+        Copy of the forward method from this file:
+        https://github.com/toshas/torch-fidelity/blob/master/torch_fidelity/feature_extractor_inceptionv3.py
+        with a single line change regarding the casting of `x` in the beginning.
+
+        Corresponding license file (Apache License, Version 2.0):
+        https://github.com/toshas/torch-fidelity/blob/master/LICENSE.md
+        """
+        vassert(torch.is_tensor(x) and x.dtype == torch.uint8, "Expecting image as torch.Tensor with dtype=torch.uint8")
+        features = {}
+        remaining_features = self.features_list.copy()
+
+        x = x.to(self._dtype) if hasattr(self, "_dtype") else x.to(torch.float)
+        x = interpolate_bilinear_2d_like_tensorflow1x(
+            x,
+            size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
+            align_corners=False,
+        )
+        x = (x - 128) / 128
+
+        x = self.Conv2d_1a_3x3(x)
+        x = self.Conv2d_2a_3x3(x)
+        x = self.Conv2d_2b_3x3(x)
+        x = self.MaxPool_1(x)
+
+        if "64" in remaining_features:
+            features["64"] = adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
+            remaining_features.remove("64")
+            if len(remaining_features) == 0:
+                return tuple(features[a] for a in self.features_list)
+
+        x = self.Conv2d_3b_1x1(x)
+        x = self.Conv2d_4a_3x3(x)
+        x = self.MaxPool_2(x)
+
+        if "192" in remaining_features:
+            features["192"] = adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
+            remaining_features.remove("192")
+            if len(remaining_features) == 0:
+                return tuple(features[a] for a in self.features_list)
+
+        x = self.Mixed_5b(x)
+        x = self.Mixed_5c(x)
+        x = self.Mixed_5d(x)
+        x = self.Mixed_6a(x)
+        x = self.Mixed_6b(x)
+        x = self.Mixed_6c(x)
+        x = self.Mixed_6d(x)
+        x = self.Mixed_6e(x)
+
+        if "768" in remaining_features:
+            features["768"] = adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
+            remaining_features.remove("768")
+            if len(remaining_features) == 0:
+                return tuple(features[a] for a in self.features_list)
+
+        x = self.Mixed_7a(x)
+        x = self.Mixed_7b(x)
+        x = self.Mixed_7c(x)
+        x = self.AvgPool(x)
+        x = torch.flatten(x, 1)
+
+        if "2048" in remaining_features:
+            features["2048"] = x
+            remaining_features.remove("2048")
+            if len(remaining_features) == 0:
+                return tuple(features[a] for a in self.features_list)
+
+        if "logits_unbiased" in remaining_features:
+            x = x.mm(self.fc.weight.T)
+            # N x 1008 (num_classes)
+            features["logits_unbiased"] = x
+            remaining_features.remove("logits_unbiased")
+            if len(remaining_features) == 0:
+                return tuple(features[a] for a in self.features_list)
+
+            x = x + self.fc.bias.unsqueeze(0)
+        else:
+            x = self.fc(x)
+
+        features["logits"] = x
+        return tuple(features[a] for a in self.features_list)
+
     def forward(self, x: Tensor) -> Tensor:
         """Forward pass of neural network with reshaping of output."""
-        out = super().forward(x)
+        out = self._torch_fidelity_forward(x)
         return out[0].reshape(x.shape[0], -1)
@@ -151,6 +242,10 @@ class FrechetInceptionDistance(Metric):
     flag ``real`` determines if the images should update the statistics of the real distribution or the
     fake distribution.
 
+    This metric is known to be unstable in its calculations, and we recommend for the best results using this metric
+    that you calculate using `torch.float64` (default is `torch.float32`) which can be set using the `.set_dtype`
+    method of the metric.
+
     .. note:: using this metric requires you to have ``scipy`` installed. Either install as
         ``pip install torchmetrics[image]`` or ``pip install scipy``
@@ -307,6 +402,17 @@ def reset(self) -> None:
         else:
             super().reset()
 
+    def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":
+        """Transfer all metric state to specific dtype. Special version of standard `type` method.
+
+        Arguments:
+            dst_type (type or string): the desired type.
+        """
+        out = super().set_dtype(dst_type)
+        if isinstance(out.inception, NoTrainInceptionV3):
+            out.inception._dtype = dst_type
+        return out
+
     def plot(
         self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
     ) -> _PLOT_OUT_TYPE:
diff --git a/tests/unittests/image/test_fid.py b/tests/unittests/image/test_fid.py
index 824c1e5c7eb..24b88f5c862 100644
--- a/tests/unittests/image/test_fid.py
+++ b/tests/unittests/image/test_fid.py
@@ -199,3 +199,15 @@ def test_normalize_arg_false():
     metric = FrechetInceptionDistance(normalize=False)
     with pytest.raises(ValueError, match="Expecting image as torch.Tensor with dtype=torch.uint8"):
         metric.update(img, real=True)
+
+
+def test_dtype_transfer_to_submodule():
+    """Test that change in dtype also changes the default inception net."""
+    imgs = torch.randn(1, 3, 256, 256)
+    imgs = ((imgs.clamp(-1, 1) / 2 + 0.5) * 255).to(torch.uint8)
+
+    metric = FrechetInceptionDistance(feature=64)
+    metric.set_dtype(torch.float64)
+
+    out = metric.inception(imgs)
+    assert out.dtype == torch.float64
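
Usage sketch for the patch above (an illustration, not part of the patch; assumes torchmetrics with the image extras installed, and arbitrary shapes and sample counts): the new `set_dtype` override casts both the metric state and the inception submodule, so feature extraction runs in double precision.

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance(feature=64)
fid.set_dtype(torch.float64)  # also sets `_dtype` on the inception submodule

real = torch.randint(0, 255, (8, 3, 299, 299), dtype=torch.uint8)
fake = torch.randint(0, 255, (8, 3, 299, 299), dtype=torch.uint8)
fid.update(real, real=True)
fid.update(fake, real=False)
print(fid.compute())  # statistics are now accumulated in float64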
From 6756b74664cee7c3301e56e3ca46709834927fef Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 28 Mar 2023 12:10:45 +0200
Subject: [PATCH 4/9] Fix precision-recall curve based computations for float
 target (#1642)

---
 CHANGELOG.md                                             | 3 +++
 .../functional/classification/precision_recall_curve.py | 6 +++---
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 05dfeda60f5..2257d0f18ed 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -114,6 +114,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed support in `MetricTracker` for `MultioutputWrapper` and nested structures ([#1608](https://github.com/Lightning-AI/metrics/pull/1608))
 
+- Fix precision-recall curve based computations for float target ([#1642](https://github.com/Lightning-AI/metrics/pull/1642))
+
+
 ## [0.11.4] - 2023-03-10
 
 ### Fixed
diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py
index 3369eeb3a2c..91d296ac83d 100644
--- a/src/torchmetrics/functional/classification/precision_recall_curve.py
+++ b/src/torchmetrics/functional/classification/precision_recall_curve.py
@@ -214,7 +214,7 @@ def _binary_precision_recall_curve_update_vectorized(
     """
     len_t = len(thresholds)
     preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0)).long()  # num_samples x num_thresholds
-    unique_mapping = preds_t + 2 * target.unsqueeze(-1) + 4 * torch.arange(len_t, device=target.device)
+    unique_mapping = preds_t + 2 * target.long().unsqueeze(-1) + 4 * torch.arange(len_t, device=target.device)
     bins = _bincount(unique_mapping.flatten(), minlength=4 * len_t)
     return bins.reshape(len_t, 2, 2)
@@ -469,7 +469,7 @@ def _multiclass_precision_recall_curve_update_vectorized(
     len_t = len(thresholds)
     preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long()
     target_t = torch.nn.functional.one_hot(target, num_classes=num_classes)
-    unique_mapping = preds_t + 2 * target_t.unsqueeze(-1)
+    unique_mapping = preds_t + 2 * target_t.long().unsqueeze(-1)
     unique_mapping += 4 * torch.arange(num_classes, device=preds.device).unsqueeze(0).unsqueeze(-1)
     unique_mapping += 4 * num_classes * torch.arange(len_t, device=preds.device)
     bins = _bincount(unique_mapping.flatten(), minlength=4 * num_classes * len_t)
@@ -714,7 +714,7 @@ def _multilabel_precision_recall_curve_update(
     len_t = len(thresholds)
     # num_samples x num_labels x num_thresholds
     preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long()
-    unique_mapping = preds_t + 2 * target.unsqueeze(-1)
+    unique_mapping = preds_t + 2 * target.long().unsqueeze(-1)
     unique_mapping += 4 * torch.arange(num_labels, device=preds.device).unsqueeze(0).unsqueeze(-1)
     unique_mapping += 4 * num_labels * torch.arange(len_t, device=preds.device)
     unique_mapping = unique_mapping[unique_mapping >= 0]
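
An illustration of the input pattern this fix unlocks (not part of the patch): 0/1 targets stored in a float tensor, which the `_bincount`-based update previously failed on; `thresholds=5` is an arbitrary choice that routes through the fixed vectorized path.

import torch
from torchmetrics.functional.classification import binary_precision_recall_curve

preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0.0, 0.0, 1.0, 1.0])  # float-typed targets, previously problematic
precision, recall, thresholds = binary_precision_recall_curve(preds, target, thresholds=5)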
From ae7a755674ca45dc9d13c9d404cfb724650f8486 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 28 Mar 2023 12:11:53 +0200
Subject: [PATCH 5/9] Fix corner case of calibration error (#1648)

---
 CHANGELOG.md                                              | 3 +++
 .../functional/classification/calibration_error.py        | 8 ++++----
 tests/unittests/classification/test_calibration_error.py  | 7 +++++++
 3 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2257d0f18ed..fa6511d42ae 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -114,6 +114,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed support in `MetricTracker` for `MultioutputWrapper` and nested structures ([#1608](https://github.com/Lightning-AI/metrics/pull/1608))
 
+- Fixed corner case in calibration error for zero confidence input ([#1648](https://github.com/Lightning-AI/metrics/pull/1648))
+
+
 - Fix precision-recall curve based computations for float target ([#1642](https://github.com/Lightning-AI/metrics/pull/1642))
 
diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py
index c81f4bf78cd..674df77cc03 100644
--- a/src/torchmetrics/functional/classification/calibration_error.py
+++ b/src/torchmetrics/functional/classification/calibration_error.py
@@ -40,11 +40,11 @@ def _binning_bucketize(
         tuple with binned accuracy, binned confidence and binned probabilities
     """
     accuracies = accuracies.to(dtype=confidences.dtype)
-    acc_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype)
-    conf_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype)
-    count_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype)
+    acc_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)
+    conf_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)
+    count_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)
 
-    indices = torch.bucketize(confidences, bin_boundaries) - 1
+    indices = torch.bucketize(confidences, bin_boundaries, right=True) - 1
 
     count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))
diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py
index dcd28b25d92..5ebe811f094 100644
--- a/tests/unittests/classification/test_calibration_error.py
+++ b/tests/unittests/classification/test_calibration_error.py
@@ -133,6 +133,13 @@ def test_binary_calibration_error_dtype_gpu(self, input, dtype):
         )
 
 
+def test_binary_with_zero_pred():
+    """Test that metric works with edge case where confidence is zero for a bin."""
+    preds = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.0])
+    target = torch.tensor([0, 0, 1, 1, 1])
+    assert binary_calibration_error(preds, target, n_bins=2, norm="l1") == torch.tensor(0.6)
+
+
 def _netcal_multiclass_calibration_error(preds, target, n_bins, norm, ignore_index):
     preds = preds.numpy()
     target = target.numpy().flatten()
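
A quick sketch of the corner case covered by the new test above (illustration only, not part of the patch): boundary-valued confidences previously mis-indexed the bin buffers; with the widened buffers and `right=True` bucketization the metric returns the value the new test expects.

import torch
from torchmetrics.functional.classification import binary_calibration_error

preds = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.0])  # includes a zero prediction
target = torch.tensor([0, 0, 1, 1, 1])
print(binary_calibration_error(preds, target, n_bins=2, norm="l1"))  # tensor(0.6000)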
From bf4574e1daae0f4f8b302502301a637b1c46e74b Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 28 Mar 2023 12:16:21 +0200
Subject: [PATCH 6/9] Fid: error on too few samples (#1655)

---
 CHANGELOG.md                      |  3 +++
 src/torchmetrics/image/fid.py     |  2 ++
 tests/unittests/image/test_fid.py | 12 ++++++++++++
 3 files changed, 17 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index fa6511d42ae..ba546fa4282 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -96,6 +96,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Changed `__iter__` method from raising `NotImplementedError` to `TypeError` by setting to `None` ([#1538](https://github.com/Lightning-AI/metrics/pull/1538))
 
+- `FID` metric will now raise an error if too few samples are provided ([#1655](https://github.com/Lightning-AI/metrics/pull/1655))
+
+
 - Allowed FID with `torch.float64` ([#1628](https://github.com/Lightning-AI/metrics/pull/1628))
 
diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py
index 469c15a6e9a..99c5579d938 100644
--- a/src/torchmetrics/image/fid.py
+++ b/src/torchmetrics/image/fid.py
@@ -380,6 +380,8 @@ def update(self, imgs: Tensor, real: bool) -> None:
 
     def compute(self) -> Tensor:
         """Calculate FID score based on accumulated extracted features from the two distributions."""
+        if self.real_features_num_samples < 2 or self.fake_features_num_samples < 2:
+            raise RuntimeError("More than one sample is required for both the real and fake distributions to compute FID")
         mean_real = (self.real_features_sum / self.real_features_num_samples).unsqueeze(0)
         mean_fake = (self.fake_features_sum / self.fake_features_num_samples).unsqueeze(0)
diff --git a/tests/unittests/image/test_fid.py b/tests/unittests/image/test_fid.py
index 24b88f5c862..a80523001fe 100644
--- a/tests/unittests/image/test_fid.py
+++ b/tests/unittests/image/test_fid.py
@@ -201,6 +201,18 @@ def test_normalize_arg_false():
     metric.update(img, real=True)
 
 
+def test_not_enough_samples():
+    """Test that an error is raised if not enough samples were provided."""
+    img = torch.randint(0, 255, (1, 3, 299, 299), dtype=torch.uint8)
+    metric = FrechetInceptionDistance()
+    metric.update(img, real=True)
+    metric.update(img, real=False)
+    with pytest.raises(
+        RuntimeError, match="More than one sample is required for both the real and fake distributions to compute FID"
+    ):
+        metric.compute()
+
+
 def test_dtype_transfer_to_submodule():
     """Test that change in dtype also changes the default inception net."""
     imgs = torch.randn(1, 3, 256, 256)
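
A sketch of the new guard (not part of the patch): with only one sample per distribution the feature covariance is undefined, so `compute` now raises instead of returning a meaningless score.

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

metric = FrechetInceptionDistance()
img = torch.randint(0, 255, (1, 3, 299, 299), dtype=torch.uint8)
metric.update(img, real=True)
metric.update(img, real=False)
try:
    metric.compute()
except RuntimeError as err:
    print(err)  # at least two samples are required per distribution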
From 82a5f6dfb0914cc72ec62197c6b15eec796e2e77 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 28 Mar 2023 12:17:22 +0200
Subject: [PATCH 7/9] Enhance use of `data_range` in image metrics (#1606)

---
 CHANGELOG.md                              |  3 +++
 src/torchmetrics/functional/image/psnr.py | 13 ++++++++++---
 src/torchmetrics/functional/image/ssim.py | 22 +++++++++++++++-------
 src/torchmetrics/image/psnr.py            | 14 ++++++++++++--
 src/torchmetrics/image/ssim.py            | 13 +++++++++----
 tests/unittests/image/test_psnr.py        |  5 +++++
 tests/unittests/image/test_ssim.py        | 11 ++++++++---
 7 files changed, 62 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ba546fa4282..fde45018870 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -76,6 +76,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added support for plotting of aggregation metrics through `.plot()` method ([#1485](https://github.com/Lightning-AI/metrics/pull/1485))
 
+- Added support for auto clamping of input for metrics that use the `data_range` argument ([#1606](https://github.com/Lightning-AI/metrics/pull/1606))
+
+
 - Added `ModifiedPanopticQuality` metric to detection package ([#1627](https://github.com/Lightning-AI/metrics/pull/1627))
 
diff --git a/src/torchmetrics/functional/image/psnr.py b/src/torchmetrics/functional/image/psnr.py
index a38c05525bb..35b733d0a30 100644
--- a/src/torchmetrics/functional/image/psnr.py
+++ b/src/torchmetrics/functional/image/psnr.py
@@ -88,7 +88,7 @@ def _psnr_update(
 def peak_signal_noise_ratio(
     preds: Tensor,
     target: Tensor,
-    data_range: Optional[float] = None,
+    data_range: Optional[Union[float, Tuple[float, float]]] = None,
     base: float = 10.0,
     reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
     dim: Optional[Union[int, Tuple[int, ...]]] = None,
@@ -98,8 +98,10 @@ def peak_signal_noise_ratio(
     Args:
         preds: estimated signal
         target: ground truth signal
         data_range:
             the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
             the range is calculated as the difference and input is clamped between the values.
             The ``data_range`` must be given when ``dim`` is not None.
         base: a base of a logarithm to use
         reduction: a method to reduce metric score over labels.
@@ -138,7 +140,12 @@
             raise ValueError("The `data_range` must be given when `dim` is not None.")
 
         data_range = target.max() - target.min()
+    elif isinstance(data_range, tuple):
+        preds = torch.clamp(preds, min=data_range[0], max=data_range[1])
+        target = torch.clamp(target, min=data_range[0], max=data_range[1])
+        data_range = tensor(data_range[1] - data_range[0])
     else:
         data_range = tensor(float(data_range))
+
     sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim)
     return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction)
diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py
index 91f4d9fd385..7dfebf11f8b 100644
--- a/src/torchmetrics/functional/image/ssim.py
+++ b/src/torchmetrics/functional/image/ssim.py
@@ -47,7 +47,7 @@ def _ssim_update(
     gaussian_kernel: bool = True,
     sigma: Union[float, Sequence[float]] = 1.5,
     kernel_size: Union[int, Sequence[int]] = 11,
-    data_range: Optional[float] = None,
+    data_range: Optional[Union[float, Tuple[float, float]]] = None,
     k1: float = 0.01,
     k2: float = 0.03,
     return_full_image: bool = False,
@@ -109,6 +109,10 @@ def _ssim_update(
     if data_range is None:
         data_range = max(preds.max() - preds.min(), target.max() - target.min())
+    elif isinstance(data_range, tuple):
+        preds = torch.clamp(preds, min=data_range[0], max=data_range[1])
+        target = torch.clamp(target, min=data_range[0], max=data_range[1])
+        data_range = data_range[1] - data_range[0]
 
     c1 = pow(k1 * data_range, 2)
     c2 = pow(k2 * data_range, 2)
@@ -202,7 +206,7 @@ def structural_similarity_index_measure(
     sigma: Union[float, Sequence[float]] = 1.5,
     kernel_size: Union[int, Sequence[int]] = 11,
     reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
-    data_range: Optional[float] = None,
+    data_range: Optional[Union[float, Tuple[float, float]]] = None,
     k1: float = 0.01,
     k2: float = 0.03,
     return_full_image: bool = False,
@@ -224,7 +228,9 @@ def structural_similarity_index_measure(
         - ``'sum'``: takes the sum
         - ``'none'`` or ``None``: no reduction will be applied
 
-        data_range: Range of the image. If ``None``, it is determined from the image (max - min)
+        data_range:
+            the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
+            the range is calculated as the difference and input is clamped between the values.
         k1: Parameter of SSIM.
         k2: Parameter of SSIM.
         return_full_image: If true, the full ``ssim`` image is returned as a second argument.
@@ -283,7 +289,7 @@ def _get_normalized_sim_and_cs(
     gaussian_kernel: bool = True,
     sigma: Union[float, Sequence[float]] = 1.5,
     kernel_size: Union[int, Sequence[int]] = 11,
-    data_range: Optional[float] = None,
+    data_range: Optional[Union[float, Tuple[float, float]]] = None,
    k1: float = 0.01,
     k2: float = 0.03,
     normalize: Optional[Literal["relu", "simple"]] = None,
@@ -311,7 +317,7 @@ def _multiscale_ssim_update(
     gaussian_kernel: bool = True,
     sigma: Union[float, Sequence[float]] = 1.5,
     kernel_size: Union[int, Sequence[int]] = 11,
-    data_range: Optional[float] = None,
+    data_range: Optional[Union[float, Tuple[float, float]]] = None,
     k1: float = 0.01,
     k2: float = 0.03,
     betas: Union[Tuple[float, float, float, float, float], Tuple[float, ...]] = (
@@ -436,7 +442,7 @@ def multiscale_structural_similarity_index_measure(
     sigma: Union[float, Sequence[float]] = 1.5,
     kernel_size: Union[int, Sequence[int]] = 11,
     reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
-    data_range: Optional[float] = None,
+    data_range: Optional[Union[float, Tuple[float, float]]] = None,
     k1: float = 0.01,
     k2: float = 0.03,
     betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
@@ -459,7 +465,9 @@ def multiscale_structural_similarity_index_measure(
         - ``'sum'``: takes the sum
         - ``'none'`` or ``None``: no reduction will be applied
 
-        data_range: Range of the image. If ``None``, it is determined from the image (max - min)
+        data_range:
+            the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
+            the range is calculated as the difference and input is clamped between the values.
         k1: Parameter of structural similarity index measure.
         k2: Parameter of structural similarity index measure.
         betas: Exponent parameters for individual similarities and contrastive sensitivities returned by different image
diff --git a/src/torchmetrics/image/psnr.py b/src/torchmetrics/image/psnr.py
index b6e2e1aea87..d26b1b279a6 100644
--- a/src/torchmetrics/image/psnr.py
+++ b/src/torchmetrics/image/psnr.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Any, Optional, Sequence, Tuple, Union
 
 import torch
@@ -46,7 +47,8 @@ class PeakSignalNoiseRatio(Metric):
     Args:
         data_range:
             the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
             the range is calculated as the difference and input is clamped between the values.
             The ``data_range`` must be given when ``dim`` is not None.
         base: a base of a logarithm to use.
         reduction: a method to reduce metric score over labels.
@@ -83,7 +85,7 @@
     def __init__(
         self,
-        data_range: Optional[float] = None,
+        data_range: Optional[Union[float, Tuple[float, float]]] = None,
         base: float = 10.0,
         reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
         dim: Optional[Union[int, Tuple[int, ...]]] = None,
@@ -101,6 +103,7 @@
         self.add_state("sum_squared_error", default=[], dist_reduce_fx="cat")
         self.add_state("total", default=[], dist_reduce_fx="cat")
 
+        self.clamping_fn = None
         if data_range is None:
             if dim is not None:
                 # Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to
             self.data_range = None
             self.add_state("min_target", default=tensor(0.0), dist_reduce_fx=torch.min)
             self.add_state("max_target", default=tensor(0.0), dist_reduce_fx=torch.max)
+        elif isinstance(data_range, tuple):
+            self.add_state("data_range", default=tensor(data_range[1] - data_range[0]), dist_reduce_fx="mean")
+            self.clamping_fn = partial(torch.clamp, min=data_range[0], max=data_range[1])
         else:
             self.add_state("data_range", default=tensor(float(data_range)), dist_reduce_fx="mean")
         self.base = base
@@ -118,6 +124,10 @@
 
     def update(self, preds: Tensor, target: Tensor) -> None:
         """Update state with predictions and targets."""
+        if self.clamping_fn is not None:
+            preds = self.clamping_fn(preds)
+            target = self.clamping_fn(target)
+
         sum_squared_error, n_obs = _psnr_update(preds, target, dim=self.dim)
         if self.dim is None:
             if self.data_range is None:
diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py
index 54bd59bd967..f37f32cc7fb 100644
--- a/src/torchmetrics/image/ssim.py
+++ b/src/torchmetrics/image/ssim.py
@@ -54,7 +54,9 @@ class StructuralSimilarityIndexMeasure(Metric):
     - ``'sum'``: takes the sum
     - ``'none'`` or ``None``: no reduction will be applied
 
-        data_range: Range of the image. If ``None``, it is determined from the image (max - min)
+        data_range:
+            the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
+            the range is calculated as the difference and input is clamped between the values.
         k1: Parameter of SSIM.
         k2: Parameter of SSIM.
         return_full_image: If true, the full ``ssim`` image is returned as a second argument.
@@ -89,7 +91,7 @@
     def __init__(
         self,
         gaussian_kernel: bool = True,
         sigma: Union[float, Sequence[float]] = 1.5,
         kernel_size: Union[int, Sequence[int]] = 11,
         reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
-        data_range: Optional[float] = None,
+        data_range: Optional[Union[float, Tuple[float, float]]] = None,
         k1: float = 0.01,
         k2: float = 0.03,
         return_full_image: bool = False,
@@ -239,7 +241,10 @@ class MultiScaleStructuralSimilarityIndexMeasure(Metric):
     - ``'sum'``: takes the sum
     - ``'none'`` or ``None``: no reduction will be applied
 
-        data_range: Range of the image. If ``None``, it is determined from the image (max - min)
+        data_range:
+            the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
+            the range is calculated as the difference and input is clamped between the values.
+            The ``data_range`` must be given when ``dim`` is not None.
         k1: Parameter of structural similarity index measure.
         k2: Parameter of structural similarity index measure.
         betas: Exponent parameters for individual similarities and contrastive sensitivities returned by different image
@@ -285,7 +290,7 @@
     def __init__(
         self,
         gaussian_kernel: bool = True,
         kernel_size: Union[int, Sequence[int]] = 11,
         sigma: Union[float, Sequence[float]] = 1.5,
         reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
-        data_range: Optional[float] = None,
+        data_range: Optional[Union[float, Tuple[float, float]]] = None,
         k1: float = 0.01,
         k2: float = 0.03,
         betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
diff --git a/tests/unittests/image/test_psnr.py b/tests/unittests/image/test_psnr.py
index 2ca5375e570..5c64a11a3f6 100644
--- a/tests/unittests/image/test_psnr.py
+++ b/tests/unittests/image/test_psnr.py
@@ -60,6 +60,10 @@ def _to_sk_peak_signal_noise_ratio_inputs(value, dim):
 
 def _skimage_psnr(preds, target, data_range, reduction, dim):
+    if isinstance(data_range, tuple):
+        preds = preds.clamp(min=data_range[0], max=data_range[1])
+        target = target.clamp(min=data_range[0], max=data_range[1])
+        data_range = data_range[1] - data_range[0]
     sk_preds_lists = _to_sk_peak_signal_noise_ratio_inputs(preds, dim=dim)
     sk_target_lists = _to_sk_peak_signal_noise_ratio_inputs(target, dim=dim)
     np_reduce_map = {"elementwise_mean": np.mean, "none": np.array, "sum": np.sum}
@@ -84,6 +88,7 @@ def _base_e_sk_psnr(preds, target, data_range, reduction, dim):
         (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", 1),
         (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", (1, 2)),
         (_inputs[2].preds, _inputs[2].target, 5, "sum", (1, 2)),
+        (_inputs[0].preds, _inputs[0].target, (0.0, 1.0), "elementwise_mean", None),
     ],
 )
 @pytest.mark.parametrize(
diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py
index 2170d10d043..7c1d5956ad1 100644
--- a/tests/unittests/image/test_ssim.py
+++ b/tests/unittests/image/test_ssim.py
@@ -65,6 +65,10 @@ def _skimage_ssim(
     gaussian_weights=True,
     reduction_arg="elementwise_mean",
 ):
+    if isinstance(data_range, tuple):
+        preds = preds.clamp(min=data_range[0], max=data_range[1])
+        target = target.clamp(min=data_range[0], max=data_range[1])
+        data_range = data_range[1] - data_range[0]
     if len(preds.shape) == 4:
         c, h, w = preds.shape[-3:]
         sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
@@ -137,17 +141,18 @@ class TestSSIM(MetricTester):
 
     atol = 6e-3
 
+    @pytest.mark.parametrize("data_range", [1.0, (0.1, 1.0)])
     @pytest.mark.parametrize("ddp", [True, False])
-    def test_ssim_sk(self, preds, target, sigma, ddp):
+    def test_ssim_sk(self, preds, target, sigma, data_range, ddp):
         """Test class implementation of metric vs skimage."""
         self.run_class_metric_test(
             ddp,
             preds,
             target,
             StructuralSimilarityIndexMeasure,
-            partial(_skimage_ssim, data_range=1.0, sigma=sigma, kernel_size=None),
+            partial(_skimage_ssim, data_range=data_range, sigma=sigma, kernel_size=None),
             metric_args={
-                "data_range": 1.0,
+                "data_range": data_range,
                 "sigma": sigma,
             },
         )
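
A sketch of the tuple form of `data_range` introduced above (not part of the patch; the random tensors are only illustrative): inputs are clamped to the `(min, max)` interval and the range is taken as their difference, rather than being estimated from the data.

import torch
from torchmetrics.functional import peak_signal_noise_ratio, structural_similarity_index_measure

preds = 1.2 * torch.rand(4, 3, 32, 32)  # may overshoot the nominal [0, 1] range
target = torch.rand(4, 3, 32, 32)
print(peak_signal_noise_ratio(preds, target, data_range=(0.0, 1.0)))
print(structural_similarity_index_measure(preds, target, data_range=(0.0, 1.0)))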
From 1778808280e95189908a9ba88669ce4c08b205ae Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Tue, 28 Mar 2023 11:53:01 +0000
Subject: [PATCH 8/9] build(deps): update netcal requirement from
 <=1.3.4,>1.0.0 to >1.0.0,<=1.3.5 in /requirements (#1661)

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 requirements/classification_test.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/classification_test.txt b/requirements/classification_test.txt
index 08bbd5047ed..2d87aec322b 100644
--- a/requirements/classification_test.txt
+++ b/requirements/classification_test.txt
@@ -2,5 +2,5 @@
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
 pandas >1.0.0, <=1.5.3
-netcal >1.0.0, <=1.3.4 # calibration_error
+netcal >1.0.0, <=1.3.5 # calibration_error
 fairlearn # group_fairness

From 4d5147a083a8972916fbf7f457347fdf10dac7ea Mon Sep 17 00:00:00 2001
From: Bas Veeling
Date: Tue, 28 Mar 2023 14:48:15 +0200
Subject: [PATCH 9/9] Silently failing DDP syncing when initializing Metric
 with jsonargparse (#1651)

* Broken ddp syncing when initializing metric with jsonargparse (LightningCLI)

This happens due to the following combination of factors.

- jsonargparse has a docstring parsing function enabled when installed with
  `pip install jsonargparse[signatures]`
- `torchmetrics.Metric` has a docstring that mentions an optional
  `distributed_available_fn`
- If `distributed_available_fn` is not set in `**kwargs`, `Metric.__init__` sets
  a default pytorch function which enables ddp syncing of metrics
- When a metric is initialized in a yaml config, jsonargparse recognizes the
  `distributed_available_fn` field in the docstring and passes a default value
  of `distributed_available_fn=None`
- Hence, any subclass of `Metric` initialized by jsonargparse has
  `distributed_available_fn = None`. These metrics silently fail to sync across gpus.

---------

Co-authored-by: Nicki Skafte Detlefsen
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
---
 CHANGELOG.md               | 3 +++
 src/torchmetrics/metric.py | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index fde45018870..dcd2e714997 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -120,6 +120,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed support in `MetricTracker` for `MultioutputWrapper` and nested structures ([#1608](https://github.com/Lightning-AI/metrics/pull/1608))
 
+- Fixed integration with `jsonargparse` and `LightningCLI` ([#1651](https://github.com/Lightning-AI/metrics/pull/1651))
+
+
 - Fixed corner case in calibration error for zero confidence input ([#1648](https://github.com/Lightning-AI/metrics/pull/1648))
 
diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index 9a392c8dd5e..0595575a647 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -124,7 +124,7 @@ def __init__(
             f"Expected keyword argument `dist_sync_fn` to be a callable function but got {self.dist_sync_fn}"
         )
 
-        self.distributed_available_fn = kwargs.pop("distributed_available_fn", jit_distributed_available)
+        self.distributed_available_fn = kwargs.pop("distributed_available_fn", None) or jit_distributed_available
         self.sync_on_compute = kwargs.pop("sync_on_compute", True)
         if not isinstance(self.sync_on_compute, bool):
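
Finally, a sketch of the failure mode the last patch closes (illustration only; `BinaryAccuracy` stands in for any `Metric` subclass): jsonargparse with docstring parsing passes `distributed_available_fn=None` explicitly, which previously disabled DDP syncing, and the `or` fallback now restores the default.

import torch
from torchmetrics.classification import BinaryAccuracy

# jsonargparse (with signature parsing enabled) instantiates roughly like this:
metric = BinaryAccuracy(distributed_available_fn=None)
# before the fix: the attribute stayed None and DDP syncing was silently skipped
# after the fix: None falls through to the default `jit_distributed_available`
assert metric.distributed_available_fn is not None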