Commit

Merge branch 'master' into bugfix/pearson_1d_input

Borda authored Mar 28, 2023
2 parents 12d9b12 + 4d5147a · commit 6ab2aad
Showing 15 changed files with 227 additions and 30 deletions.
18 changes: 18 additions & 0 deletions CHANGELOG.md
@@ -76,6 +9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for plotting of aggregation metrics through `.plot()` method ([#1485](https://github.com/Lightning-AI/metrics/pull/1485))


- Added support for auto clamping of input for metrics that use the `data_range` argument ([#1606](https://github.com/Lightning-AI/metrics/pull/1606))


- Added `ModifiedPanopticQuality` metric to detection package ([#1627](https://github.com/Lightning-AI/metrics/pull/1627))


@@ -96,6 +99,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Changed `__iter__` method from raising `NotImplementedError` to `TypeError` by setting to `None` ([#1538](https://github.com/Lightning-AI/metrics/pull/1538))


- `FID` metric will now raise an error if too few samples are provided ([#1655](https://github.com/Lightning-AI/metrics/pull/1655))


- Allowed FID with `torch.float64` ([#1628](https://github.com/Lightning-AI/metrics/pull/1628))


### Deprecated

-
@@ -114,6 +123,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed restrictive check in `PearsonCorrCoef` ([#1649](https://github.com/Lightning-AI/metrics/pull/1649))


- Fixed integration with `jsonargparse` and `LightningCLI` ([#1651](https://github.com/Lightning-AI/metrics/pull/1651))


- Fixed corner case in calibration error for zero confidence input ([#1648](https://github.com/Lightning-AI/metrics/pull/1648))


- Fixed precision-recall curve based computations for float target ([#1642](https://github.com/Lightning-AI/metrics/pull/1642))


## [0.11.4] - 2023-03-10

### Fixed
2 changes: 1 addition & 1 deletion requirements/classification_test.txt
@@ -2,5 +2,5 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

pandas >1.0.0, <=1.5.3
netcal >1.0.0, <=1.3.3 # calibration_error
netcal >1.0.0, <=1.3.5 # calibration_error
fairlearn # group_fairness
2 changes: 1 addition & 1 deletion requirements/test.txt
@@ -14,4 +14,4 @@ requests <=2.28.2
fire <=0.5.0

cloudpickle >1.3, <=2.2.1
scikit-learn >1.0, <1.2.2
scikit-learn >1.0, <1.2.3
@@ -40,11 +40,11 @@ def _binning_bucketize(
tuple with binned accuracy, binned confidence and binned probabilities
"""
accuracies = accuracies.to(dtype=confidences.dtype)
acc_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype)
conf_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype)
count_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype)
acc_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)
conf_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)
count_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)

indices = torch.bucketize(confidences, bin_boundaries) - 1
indices = torch.bucketize(confidences, bin_boundaries, right=True) - 1

count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))

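For context, a minimal sketch of what the two changes above do for the zero-confidence corner case fixed in #1648; the two-bin boundaries below are illustrative, not the library's defaults:

```python
import torch

bin_boundaries = torch.linspace(0, 1, steps=3)   # tensor([0.0000, 0.5000, 1.0000])
confidences = torch.tensor([0.0, 0.3, 1.0])

# Old call: a confidence of exactly 0.0 sits on the first boundary, so after "- 1"
# it gets index -1, which is not a valid bin index for scatter_add_.
old_idx = torch.bucketize(confidences, bin_boundaries) - 1               # tensor([-1, 0, 1])

# New call: with right=True, 0.0 falls into the first bin, and a confidence of exactly
# 1.0 maps to index 2 -- which is why the bins are now allocated with len(bin_boundaries) slots.
new_idx = torch.bucketize(confidences, bin_boundaries, right=True) - 1   # tensor([0, 0, 2])
```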
@@ -214,7 +214,7 @@ def _binary_precision_recall_curve_update_vectorized(
"""
len_t = len(thresholds)
preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0)).long() # num_samples x num_thresholds
unique_mapping = preds_t + 2 * target.unsqueeze(-1) + 4 * torch.arange(len_t, device=target.device)
unique_mapping = preds_t + 2 * target.long().unsqueeze(-1) + 4 * torch.arange(len_t, device=target.device)
bins = _bincount(unique_mapping.flatten(), minlength=4 * len_t)
return bins.reshape(len_t, 2, 2)

@@ -469,7 +469,7 @@ def _multiclass_precision_recall_curve_update_vectorized(
len_t = len(thresholds)
preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long()
target_t = torch.nn.functional.one_hot(target, num_classes=num_classes)
unique_mapping = preds_t + 2 * target_t.unsqueeze(-1)
unique_mapping = preds_t + 2 * target_t.long().unsqueeze(-1)
unique_mapping += 4 * torch.arange(num_classes, device=preds.device).unsqueeze(0).unsqueeze(-1)
unique_mapping += 4 * num_classes * torch.arange(len_t, device=preds.device)
bins = _bincount(unique_mapping.flatten(), minlength=4 * num_classes * len_t)
@@ -714,7 +714,7 @@ def _multilabel_precision_recall_curve_update(
len_t = len(thresholds)
# num_samples x num_labels x num_thresholds
preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long()
unique_mapping = preds_t + 2 * target.unsqueeze(-1)
unique_mapping = preds_t + 2 * target.long().unsqueeze(-1)
unique_mapping += 4 * torch.arange(num_labels, device=preds.device).unsqueeze(0).unsqueeze(-1)
unique_mapping += 4 * num_labels * torch.arange(len_t, device=preds.device)
unique_mapping = unique_mapping[unique_mapping >= 0]
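A hedged usage sketch of what the added `.long()` casts restore (float-valued 0/1 targets working with the binned update path, per #1642); the tensors are made up for illustration, and it is assumed this public entry point routes through the code above:

```python
import torch
from torchmetrics.functional.classification import binary_precision_recall_curve

preds = torch.tensor([0.1, 0.8, 0.4, 0.6])
target = torch.tensor([0.0, 1.0, 0.0, 1.0])   # float 0/1 target: previously made unique_mapping
                                               # a float tensor and broke the bincount step

# An integer `thresholds` selects the binned update path touched in this diff.
precision, recall, thresholds = binary_precision_recall_curve(preds, target, thresholds=5)
```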
13 changes: 10 additions & 3 deletions src/torchmetrics/functional/image/psnr.py
@@ -88,7 +88,7 @@ def _psnr_update(
def peak_signal_noise_ratio(
preds: Tensor,
target: Tensor,
data_range: Optional[float] = None,
data_range: Optional[Union[float, Tuple[float, float]]] = None,
base: float = 10.0,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
dim: Optional[Union[int, Tuple[int, ...]]] = None,
@@ -98,8 +98,10 @@ def peak_signal_noise_ratio(
Args:
preds: estimated signal
target: ground truth signal
data_range: the range of the data. If None, it is determined from the data (max - min).
``data_range`` must be given when ``dim`` is not None.
data_range:
the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
the range is calculated as the difference and input is clamped between the values.
The ``data_range`` must be given when ``dim`` is not None.
base: a base of a logarithm to use
reduction: a method to reduce metric score over labels.
@@ -138,7 +140,12 @@ def peak_signal_noise_ratio(
raise ValueError("The `data_range` must be given when `dim` is not None.")

data_range = target.max() - target.min()
elif isinstance(data_range, tuple):
preds = torch.clamp(preds, min=data_range[0], max=data_range[1])
target = torch.clamp(target, min=data_range[0], max=data_range[1])
data_range = tensor(data_range[1] - data_range[0])
else:
data_range = tensor(float(data_range))

sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim)
return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction)
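
A short usage sketch of the tuple form of ``data_range`` added here (#1606); the shapes and the [0, 1] range are illustrative:

```python
import torch
from torchmetrics.functional import peak_signal_noise_ratio

preds = torch.rand(8, 3, 16, 16) * 1.2    # predictions that overshoot the nominal [0, 1] range
target = torch.rand(8, 3, 16, 16)

# Both inputs are clamped to [0, 1] and the PSNR formula then uses data_range = 1.0 - 0.0.
score = peak_signal_noise_ratio(preds, target, data_range=(0.0, 1.0))
```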
22 changes: 15 additions & 7 deletions src/torchmetrics/functional/image/ssim.py
@@ -47,7 +47,7 @@ def _ssim_update(
gaussian_kernel: bool = True,
sigma: Union[float, Sequence[float]] = 1.5,
kernel_size: Union[int, Sequence[int]] = 11,
data_range: Optional[float] = None,
data_range: Optional[Union[float, Tuple[float, float]]] = None,
k1: float = 0.01,
k2: float = 0.03,
return_full_image: bool = False,
@@ -109,6 +109,10 @@ def _ssim_update(

if data_range is None:
data_range = max(preds.max() - preds.min(), target.max() - target.min())
elif isinstance(data_range, tuple):
preds = torch.clamp(preds, min=data_range[0], max=data_range[1])
target = torch.clamp(target, min=data_range[0], max=data_range[1])
data_range = data_range[1] - data_range[0]

c1 = pow(k1 * data_range, 2)
c2 = pow(k2 * data_range, 2)
@@ -202,7 +206,7 @@ def structural_similarity_index_measure(
sigma: Union[float, Sequence[float]] = 1.5,
kernel_size: Union[int, Sequence[int]] = 11,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
data_range: Optional[float] = None,
data_range: Optional[Union[float, Tuple[float, float]]] = None,
k1: float = 0.01,
k2: float = 0.03,
return_full_image: bool = False,
@@ -224,7 +228,9 @@ def structural_similarity_index_measure(
- ``'sum'``: takes the sum
- ``'none'`` or ``None``: no reduction will be applied
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
data_range:
the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
the range is calculated as the difference and input is clamped between the values.
k1: Parameter of SSIM.
k2: Parameter of SSIM.
return_full_image: If true, the full ``ssim`` image is returned as a second argument.
@@ -283,7 +289,7 @@ def _get_normalized_sim_and_cs(
gaussian_kernel: bool = True,
sigma: Union[float, Sequence[float]] = 1.5,
kernel_size: Union[int, Sequence[int]] = 11,
data_range: Optional[float] = None,
data_range: Optional[Union[float, Tuple[float, float]]] = None,
k1: float = 0.01,
k2: float = 0.03,
normalize: Optional[Literal["relu", "simple"]] = None,
@@ -311,7 +317,7 @@ def _multiscale_ssim_update(
gaussian_kernel: bool = True,
sigma: Union[float, Sequence[float]] = 1.5,
kernel_size: Union[int, Sequence[int]] = 11,
data_range: Optional[float] = None,
data_range: Optional[Union[float, Tuple[float, float]]] = None,
k1: float = 0.01,
k2: float = 0.03,
betas: Union[Tuple[float, float, float, float, float], Tuple[float, ...]] = (
@@ -436,7 +442,7 @@ def multiscale_structural_similarity_index_measure(
sigma: Union[float, Sequence[float]] = 1.5,
kernel_size: Union[int, Sequence[int]] = 11,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
data_range: Optional[float] = None,
data_range: Optional[Union[float, Tuple[float, float]]] = None,
k1: float = 0.01,
k2: float = 0.03,
betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
@@ -459,7 +465,9 @@ def multiscale_structural_similarity_index_measure(
- ``'sum'``: takes the sum
- ``'none'`` or ``None``: no reduction will be applied
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
data_range:
the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
the range is calculated as the difference and input is clamped between the values.
k1: Parameter of structural similarity index measure.
k2: Parameter of structural similarity index measure.
betas: Exponent parameters for individual similarities and contrastive sensitivies returned by different image
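The same tuple-valued ``data_range`` applies across the SSIM family; a small sketch assuming 8-bit style images:

```python
import torch
from torchmetrics.functional import structural_similarity_index_measure

preds = torch.randint(0, 300, (4, 3, 32, 32)).float()    # deliberately contains values above 255
target = torch.randint(0, 256, (4, 3, 32, 32)).float()

# Inputs are clamped to [0, 255] and the SSIM constants are derived from data_range = 255.0.
score = structural_similarity_index_measure(preds, target, data_range=(0.0, 255.0))
```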
110 changes: 109 additions & 1 deletion src/torchmetrics/image/fid.py
@@ -19,6 +19,7 @@
from torch import Tensor
from torch.autograd import Function
from torch.nn import Module
from torch.nn.functional import adaptive_avg_pool2d

from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_info
@@ -30,11 +31,16 @@

if _TORCH_FIDELITY_AVAILABLE:
from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 as _FeatureExtractorInceptionV3
from torch_fidelity.helpers import vassert
from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x
else:

class _FeatureExtractorInceptionV3(Module):
pass

vassert = None
interpolate_bilinear_2d_like_tensorflow1x = None

__doctest_skip__ = ["FrechetInceptionDistance", "FrechetInceptionDistance.plot"]


@@ -59,9 +65,94 @@ def train(self, mode: bool) -> "NoTrainInceptionV3":
"""Force network to always be in evaluation mode."""
return super().train(False)

def _torch_fidelity_forward(self, x: Tensor) -> Tensor:
"""Forward method of inception net.
Copy of the forward method from this file:
https://github.com/toshas/torch-fidelity/blob/master/torch_fidelity/feature_extractor_inceptionv3.py
with a single line change regarding the casting of `x` in the beginning.
Corresponding license file (Apache License, Version 2.0):
https://github.com/toshas/torch-fidelity/blob/master/LICENSE.md
"""
vassert(torch.is_tensor(x) and x.dtype == torch.uint8, "Expecting image as torch.Tensor with dtype=torch.uint8")
features = {}
remaining_features = self.features_list.copy()

x = x.to(self._dtype) if hasattr(self, "_dtype") else x.to(torch.float)
x = interpolate_bilinear_2d_like_tensorflow1x(
x,
size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
align_corners=False,
)
x = (x - 128) / 128

x = self.Conv2d_1a_3x3(x)
x = self.Conv2d_2a_3x3(x)
x = self.Conv2d_2b_3x3(x)
x = self.MaxPool_1(x)

if "64" in remaining_features:
features["64"] = adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
remaining_features.remove("64")
if len(remaining_features) == 0:
return tuple(features[a] for a in self.features_list)

x = self.Conv2d_3b_1x1(x)
x = self.Conv2d_4a_3x3(x)
x = self.MaxPool_2(x)

if "192" in remaining_features:
features["192"] = adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
remaining_features.remove("192")
if len(remaining_features) == 0:
return tuple(features[a] for a in self.features_list)

x = self.Mixed_5b(x)
x = self.Mixed_5c(x)
x = self.Mixed_5d(x)
x = self.Mixed_6a(x)
x = self.Mixed_6b(x)
x = self.Mixed_6c(x)
x = self.Mixed_6d(x)
x = self.Mixed_6e(x)

if "768" in remaining_features:
features["768"] = adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
remaining_features.remove("768")
if len(remaining_features) == 0:
return tuple(features[a] for a in self.features_list)

x = self.Mixed_7a(x)
x = self.Mixed_7b(x)
x = self.Mixed_7c(x)
x = self.AvgPool(x)
x = torch.flatten(x, 1)

if "2048" in remaining_features:
features["2048"] = x
remaining_features.remove("2048")
if len(remaining_features) == 0:
return tuple(features[a] for a in self.features_list)

if "logits_unbiased" in remaining_features:
x = x.mm(self.fc.weight.T)
# N x 1008 (num_classes)
features["logits_unbiased"] = x
remaining_features.remove("logits_unbiased")
if len(remaining_features) == 0:
return tuple(features[a] for a in self.features_list)

x = x + self.fc.bias.unsqueeze(0)
else:
x = self.fc(x)

features["logits"] = x
return tuple(features[a] for a in self.features_list)

def forward(self, x: Tensor) -> Tensor:
"""Forward pass of neural network with reshaping of output."""
out = super().forward(x)
out = self._torch_fidelity_forward(x)
return out[0].reshape(x.shape[0], -1)


@@ -151,6 +242,10 @@ class FrechetInceptionDistance(Metric):
flag ``real`` determines if the images should update the statistics of the real distribution or the
fake distribution.
This metric is known to be unstable in its calculations, and we recommend for the best results using this metric
that you calculate using ``torch.float64`` (default is ``torch.float32``), which can be set using the ``.set_dtype``
method of the metric.

.. note:: using this metric requires you to have ``scipy`` installed. Either install as ``pip install
torchmetrics[image]`` or ``pip install scipy``
@@ -285,6 +380,8 @@ def update(self, imgs: Tensor, real: bool) -> None:

def compute(self) -> Tensor:
"""Calculate FID score based on accumulated extracted features from the two distributions."""
if self.real_features_num_samples < 2 or self.fake_features_num_samples < 2:
    raise RuntimeError("More than one sample is required for both the real and fake distributions to compute FID")
mean_real = (self.real_features_sum / self.real_features_num_samples).unsqueeze(0)
mean_fake = (self.fake_features_sum / self.fake_features_num_samples).unsqueeze(0)

@@ -307,6 +404,17 @@ def reset(self) -> None:
else:
super().reset()

def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":
"""Transfer all metric state to specific dtype. Special version of standard `type` method.
Arguments:
dst_type (type or string): the desired type.
"""
out = super().set_dtype(dst_type)
if isinstance(out.inception, NoTrainInceptionV3):
out.inception._dtype = dst_type
return out

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
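A hedged end-to-end sketch of the two user-facing changes in this file: ``set_dtype(torch.float64)`` now also switches the dtype used inside the copied Inception forward, and ``compute()`` raises once either distribution holds fewer than two samples. It assumes the ``torch-fidelity``/``scipy`` extras are installed; the random images are placeholders:

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance(feature=64)
fid.set_dtype(torch.float64)   # accumulate feature statistics in double precision

real = torch.randint(0, 256, (8, 3, 299, 299), dtype=torch.uint8)
fake = torch.randint(0, 256, (8, 3, 299, 299), dtype=torch.uint8)
fid.update(real, real=True)
fid.update(fake, real=False)

score = fid.compute()   # with fewer than two samples per distribution this now raises RuntimeError
```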