diff --git a/CHANGELOG.md b/CHANGELOG.md index d03efe15340..361117a4f82 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for evaluating `"segm"` and `"bbox"` detection in `MeanAveragePrecision` at the same time ([#1928](https://github.com/Lightning-AI/torchmetrics/pull/1928)) +- Added `PerceptualPathLength` to image package ([#1939](https://github.com/Lightning-AI/torchmetrics/pull/1939)) + + - Added support for multioutput evaluation in `MeanSquaredError` ([#1937](https://github.com/Lightning-AI/torchmetrics/pull/1937)) diff --git a/docs/source/image/learned_perceptual_image_patch_similarity.rst b/docs/source/image/learned_perceptual_image_patch_similarity.rst index 347caa47aaa..1b7cf955da4 100644 --- a/docs/source/image/learned_perceptual_image_patch_similarity.rst +++ b/docs/source/image/learned_perceptual_image_patch_similarity.rst @@ -15,3 +15,9 @@ ________________ .. autoclass:: torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity :noindex: :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.image.learned_perceptual_image_patch_similarity + :noindex: diff --git a/docs/source/image/perceptual_path_length.rst b/docs/source/image/perceptual_path_length.rst new file mode 100644 index 00000000000..a22fa214179 --- /dev/null +++ b/docs/source/image/perceptual_path_length.rst @@ -0,0 +1,23 @@ +.. customcarditem:: + :header: Perceptual Path Length + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg + :tags: Image + +.. include:: ../links.rst + +############################ +Perceptual Path Length (PPL) +############################ + +Module Interface +________________ + +.. 
autoclass:: torchmetrics.image.perceptual_path_length.PerceptualPathLength + :noindex: + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.image.perceptual_path_length.perceptual_path_length + :noindex: diff --git a/docs/source/links.rst b/docs/source/links.rst index d24e965509d..7f7d44e5555 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -144,6 +144,7 @@ .. _Seamless Scene Segmentation paper: https://arxiv.org/abs/1905.01220 .. _Fleiss kappa: https://en.wikipedia.org/wiki/Fleiss%27_kappa .. _VIF: https://ieeexplore.ieee.org/abstract/document/1576816 +.. _PPL: https://arxiv.org/pdf/1812.04948 .. _CIOU: https://arxiv.org/abs/2005.03572 .. _DIOU: https://arxiv.org/abs/1911.08287v1 .. _GIOU: https://arxiv.org/abs/1902.09630 diff --git a/requirements/image.txt b/requirements/image.txt index 8d3f2f14638..7e02ee2ccb5 100644 --- a/requirements/image.txt +++ b/requirements/image.txt @@ -3,5 +3,5 @@ scipy >1.0.0, <1.11.0 torchvision >=0.8, <=0.15.2 -torch-fidelity <=0.3.0 +torch-fidelity <=0.4.0 # bumping to allow install version from master, now used in testing lpips <=0.1.4 diff --git a/requirements/image_test.txt b/requirements/image_test.txt index 9e54ea0fa69..4a911c154b1 100644 --- a/requirements/image_test.txt +++ b/requirements/image_test.txt @@ -6,3 +6,4 @@ kornia >=0.6.7, <0.7.1 pytorch-msssim ==1.0.0 sewar >=0.4.4, <=0.4.6 numpy <1.25.0 +torch-fidelity @ git+https://github.com/toshas/torch-fidelity@master diff --git a/src/torchmetrics/functional/image/__init__.py b/src/torchmetrics/functional/image/__init__.py index f16ec72ec6e..329b33b66fe 100644 --- a/src/torchmetrics/functional/image/__init__.py +++ b/src/torchmetrics/functional/image/__init__.py @@ -14,6 +14,8 @@ from torchmetrics.functional.image.d_lambda import spectral_distortion_index from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis from
torchmetrics.functional.image.gradients import image_gradients +from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity +from torchmetrics.functional.image.perceptual_path_length import perceptual_path_length from torchmetrics.functional.image.psnr import peak_signal_noise_ratio from torchmetrics.functional.image.psnrb import peak_signal_noise_ratio_with_blocked_effect from torchmetrics.functional.image.rase import relative_average_spectral_error @@ -41,4 +43,6 @@ "total_variation", "universal_image_quality_index", "visual_information_fidelity", + "learned_perceptual_image_patch_similarity", + "perceptual_path_length", ] diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index 708cec05bac..594556d014a 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -186,7 +186,7 @@ def spatial_average(in_tens: Tensor, keepdim: bool = True) -> Tensor: return in_tens.mean([2, 3], keepdim=keepdim) -def upsam(in_tens: Tensor, out_hw: Tuple[int, int] = (64, 64)) -> Tensor: +def upsam(in_tens: Tensor, out_hw: Tuple[int, ...] 
= (64, 64)) -> Tensor: """Upsample input with bilinear interpolation.""" return nn.Upsample(size=out_hw, mode="bilinear", align_corners=False)(in_tens) @@ -197,6 +197,13 @@ def normalize_tensor(in_feat: Tensor, eps: float = 1e-10) -> Tensor: return in_feat / (norm_factor + eps) +def resize_tensor(x: Tensor, size: int = 64) -> Tensor: + """https://github.com/toshas/torch-fidelity/blob/master/torch_fidelity/sample_similarity_lpips.py#L127C22-L132.""" + if x.shape[-1] > size and x.shape[-2] > size: + return torch.nn.functional.interpolate(x, (size, size), mode="area") + return torch.nn.functional.interpolate(x, (size, size), mode="bilinear", align_corners=False) + + class ScalingLayer(nn.Module): """Scaling layer.""" @@ -238,6 +245,7 @@ def __init__( use_dropout: bool = True, model_path: Optional[str] = None, eval_mode: bool = True, + resize: Optional[int] = None, ) -> None: """Initializes a perceptual loss torch.nn.Module. @@ -250,6 +258,7 @@ def __init__( use_dropout: If dropout layers should be added model_path: Model path to load pretained models from eval_mode: If network should be in evaluation mode + resize: If input should be resized to this size """ super().__init__() @@ -258,6 +267,7 @@ def __init__( self.pnet_tune = pnet_tune self.pnet_rand = pnet_rand self.spatial = spatial + self.resize = resize self.scaling_layer = ScalingLayer() if self.pnet_type in ["vgg", "vgg16"]: @@ -298,13 +308,19 @@ def __init__( def forward( self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False - ) -> Union[int, Tuple[int, List[Tensor]]]: + ) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]: if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] in0 = 2 * in0 - 1 in1 = 2 * in1 - 1 - # v0.0 - original release had a bug, where input was not scaled + # normalize input in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1) + + # resize input if needed + if self.resize is not None: + in0_input = 
resize_tensor(in0_input, size=self.resize) + in1_input = resize_tensor(in1_input, size=self.resize) + outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) feats0, feats1, diffs = {}, {}, {} @@ -312,17 +328,14 @@ def forward( feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 - if self.spatial: - res = [ - upsam(self.lins[kk](diffs[kk]), out_hw=in0.shape[2:]) for kk in range(self.L) # type: ignore[arg-type] - ] - else: - res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] - - val = 0 - for layer in range(self.L): - val += res[layer] # type: ignore[assignment] + res = [] + for kk in range(self.L): + if self.spatial: + res.append(upsam(self.lins[kk](diffs[kk]), out_hw=tuple(in0.shape[2:]))) + else: + res.append(spatial_average(self.lins[kk](diffs[kk]), keepdim=True)) + val: Tensor = sum(res) # type: ignore[assignment] if retperlayer: return (val, res) return val diff --git a/src/torchmetrics/functional/image/perceptual_path_length.py b/src/torchmetrics/functional/image/perceptual_path_length.py new file mode 100644 index 00000000000..31626ca066d --- /dev/null +++ b/src/torchmetrics/functional/image/perceptual_path_length.py @@ -0,0 +1,271 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math +from typing import Literal, Optional, Tuple, Union + +import torch +from torch import Tensor, nn + +from torchmetrics.functional.image.lpips import _LPIPS +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_10, _TORCHVISION_AVAILABLE + +if not _TORCHVISION_AVAILABLE: + __doctest_skip__ = ["perceptual_path_length"] + + +class _GeneratorType(nn.Module): + @property + def num_classes(self) -> int: + raise NotImplementedError + + def sample(self, num_samples: int) -> Tensor: + raise NotImplementedError + + +def _validate_generator_model(generator: _GeneratorType, conditional: bool = False) -> None: + """Validate that the user provided generator has the right methods and attributes. + + Args: + generator: Generator model + conditional: Whether the generator is conditional or not (i.e. whether it takes labels as input). + + """ + if not hasattr(generator, "sample"): + raise NotImplementedError( + "The generator must have a `sample` method with signature `sample(num_samples: int) -> Tensor` where the" + " returned tensor has shape `(num_samples, z_size)`." 
+ ) + if not callable(generator.sample): + raise ValueError("The generator's `sample` method must be callable.") + if conditional and not hasattr(generator, "num_classes"): + raise AttributeError("The generator must have a `num_classes` attribute when `conditional=True`.") + if conditional and not isinstance(generator.num_classes, int): + raise ValueError("The generator's `num_classes` attribute must be an integer when `conditional=True`.") + + +def _perceptual_path_length_validate_arguments( + num_samples: int = 10_000, + conditional: bool = False, + batch_size: int = 128, + interpolation_method: Literal["lerp", "slerp_any", "slerp_unit"] = "lerp", + epsilon: float = 1e-4, + resize: Optional[int] = 64, + lower_discard: Optional[float] = 0.01, + upper_discard: Optional[float] = 0.99, +) -> None: + """Validate arguments for perceptual path length.""" + if not (isinstance(num_samples, int) and num_samples > 0): + raise ValueError(f"Argument `num_samples` must be a positive integer, but got {num_samples}.") + if not isinstance(conditional, bool): + raise ValueError(f"Argument `conditional` must be a boolean, but got {conditional}.") + if not (isinstance(batch_size, int) and batch_size > 0): + raise ValueError(f"Argument `batch_size` must be a positive integer, but got {batch_size}.") + if interpolation_method not in ["lerp", "slerp_any", "slerp_unit"]: + raise ValueError( + f"Argument `interpolation_method` must be one of 'lerp', 'slerp_any', 'slerp_unit'," + f"got {interpolation_method}." 
+ ) + if not (isinstance(epsilon, float) and epsilon > 0): + raise ValueError(f"Argument `epsilon` must be a positive float, but got {epsilon}.") + if resize is not None and not (isinstance(resize, int) and resize > 0): + raise ValueError(f"Argument `resize` must be a positive integer or `None`, but got {resize}.") + if lower_discard is not None and not (isinstance(lower_discard, float) and 0 <= lower_discard <= 1): + raise ValueError( + f"Argument `lower_discard` must be a float between 0 and 1 or `None`, but got {lower_discard}." + ) + if upper_discard is not None and not (isinstance(upper_discard, float) and 0 <= upper_discard <= 1): + raise ValueError( + f"Argument `upper_discard` must be a float between 0 and 1 or `None`, but got {upper_discard}." + ) + + +def _interpolate( + latents1: Tensor, + latents2: Tensor, + epsilon: float = 1e-4, + interpolation_method: Literal["lerp", "slerp_any", "slerp_unit"] = "lerp", +) -> Tensor: + """Interpolate between two sets of latents. + + Inspired by: https://github.com/toshas/torch-fidelity/blob/master/torch_fidelity/noise.py + + Args: + latents1: First set of latents. + latents2: Second set of latents. + epsilon: Spacing between the points on the path between latent points. + interpolation_method: Interpolation method to use. Choose from 'lerp', 'slerp_any', 'slerp_unit'. 
+ + """ + eps = 1e-7 + if latents1.shape != latents2.shape: + raise ValueError("Latents must have the same shape.") + if interpolation_method == "lerp": + return latents1 + (latents2 - latents1) * epsilon + if interpolation_method == "slerp_any": + ndims = latents1.dim() - 1 + z_size = latents1.shape[-1] + latents1_norm = latents1 / (latents1**2).sum(dim=-1, keepdim=True).sqrt().clamp_min(eps) + latents2_norm = latents2 / (latents2**2).sum(dim=-1, keepdim=True).sqrt().clamp_min(eps) + d = (latents1_norm * latents2_norm).sum(dim=-1, keepdim=True) + mask_zero = (latents1_norm.norm(dim=-1, keepdim=True) < eps) | (latents2_norm.norm(dim=-1, keepdim=True) < eps) + mask_collinear = (d > 1 - eps) | (d < -1 + eps) + mask_lerp = (mask_zero | mask_collinear).repeat([1 for _ in range(ndims)] + [z_size]) + omega = d.acos() + denom = omega.sin().clamp_min(eps) + coef_latents1 = ((1 - epsilon) * omega).sin() / denom + coef_latents2 = (epsilon * omega).sin() / denom + out = coef_latents1 * latents1 + coef_latents2 * latents2 + out[mask_lerp] = _interpolate(latents1, latents2, epsilon, interpolation_method="lerp")[mask_lerp] + return out + if interpolation_method == "slerp_unit": + out = _interpolate(latents1=latents1, latents2=latents2, epsilon=epsilon, interpolation_method="slerp_any") + return out / (out**2).sum(dim=-1, keepdim=True).sqrt().clamp_min(eps) + raise ValueError( + f"Interpolation method {interpolation_method} not supported. Choose from 'lerp', 'slerp_any', 'slerp_unit'." 
+ ) + + +def perceptual_path_length( + generator: _GeneratorType, + num_samples: int = 10_000, + conditional: bool = False, + batch_size: int = 64, + interpolation_method: Literal["lerp", "slerp_any", "slerp_unit"] = "lerp", + epsilon: float = 1e-4, + resize: Optional[int] = 64, + lower_discard: Optional[float] = 0.01, + upper_discard: Optional[float] = 0.99, + sim_net: Union[nn.Module, Literal["alex", "vgg", "squeeze"]] = "vgg", + device: Union[str, torch.device] = "cpu", +) -> Tuple[Tensor, Tensor, Tensor]: + r"""Computes the perceptual path length (`PPL`_) of a generator model. + + The perceptual path length can be used to measure the consistency of interpolation in latent-space models. It is + defined as + + .. math:: + PPL = \mathbb{E}\left[\frac{1}{\epsilon^2} D(G(I(z_1, z_2, t)), G(I(z_1, z_2, t+\epsilon)))\right] + + where :math:`G` is the generator, :math:`I` is the interpolation function, :math:`D` is a similarity metric, + :math:`z_1` and :math:`z_2` are two sets of latent points, and :math:`t` is a parameter between 0 and 1. The metric + thus works by interpolating between two sets of latent points, and measuring the similarity between the generated + images. The expectation is approximated by sampling :math:`z_1` and :math:`z_2` from the generator, and averaging + the calculated distances. The similarity metric :math:`D` is by default the `LPIPS`_ metric, but can be changed by + setting the `sim_net` argument. + + The provided generator model must have a `sample` method with signature `sample(num_samples: int) -> Tensor` where + the returned tensor has shape `(num_samples, z_size)`. If the generator is conditional, it must also have a + `num_classes` attribute. The `forward` method of the generator must have signature `forward(z: Tensor) -> Tensor` + if `conditional=False`, and `forward(z: Tensor, labels: Tensor) -> Tensor` if `conditional=True`. The returned + tensor should have shape `(num_samples, C, H, W)` and be scaled to the range [0, 255].
+ + Args: + generator: Generator model, with specific requirements. See above. + num_samples: Number of samples to use for the PPL computation. + conditional: Whether the generator is conditional or not (i.e. whether it takes labels as input). + batch_size: Batch size to use for the PPL computation. + interpolation_method: Interpolation method to use. Choose from 'lerp', 'slerp_any', 'slerp_unit'. + epsilon: Spacing between the points on the path between latent points. + resize: Resize images to this size before computing the similarity between generated images. + lower_discard: Lower quantile to discard from the distances, before computing the mean and standard deviation. + upper_discard: Upper quantile to discard from the distances, before computing the mean and standard deviation. + sim_net: Similarity network to use. Can be a `nn.Module` or one of 'alex', 'vgg', 'squeeze', where the three + latter options correspond to the pretrained networks from the `LPIPS`_ paper. + device: Device to use for the computation. + + Returns: + A tuple containing the mean, standard deviation and all distances. + + Example:: + >>> from torchmetrics.functional.image import perceptual_path_length + >>> import torch + >>> _ = torch.manual_seed(42) + >>> class DummyGenerator(torch.nn.Module): + ... def __init__(self, z_size) -> None: + ... super().__init__() + ... self.z_size = z_size + ... self.model = torch.nn.Sequential(torch.nn.Linear(z_size, 3*128*128), torch.nn.Sigmoid()) + ... def forward(self, z): + ... return 255 * (self.model(z).reshape(-1, 3, 128, 128) + 1) + ... def sample(self, num_samples): + ... 
return torch.randn(num_samples, self.z_size) + >>> generator = DummyGenerator(2) + >>> perceptual_path_length(generator, num_samples=10) # doctest: +SKIP + (tensor(0.1945), + tensor(0.1222), + tensor([0.0990, 0.4173, 0.1628, 0.3573, 0.1875, 0.0335, 0.1095, 0.1887, 0.1953])) + + """ + if not _TORCHVISION_AVAILABLE: + raise ModuleNotFoundError( + "Metric `perceptual_path_length` requires torchvision which is not installed." + "Install with `pip install torchvision` or `pip install torchmetrics[image]`" + ) + _perceptual_path_length_validate_arguments( + num_samples, conditional, batch_size, interpolation_method, epsilon, resize, lower_discard, upper_discard + ) + _validate_generator_model(generator, conditional) + generator = generator.to(device) + + latent1 = generator.sample(num_samples).to(device) + latent2 = generator.sample(num_samples).to(device) + latent2 = _interpolate(latent1, latent2, epsilon, interpolation_method=interpolation_method) + + if conditional: + labels = torch.randint(0, generator.num_classes, (num_samples,)).to(device) + + if isinstance(sim_net, nn.Module): + net = sim_net.to(device) + elif sim_net in ["alex", "vgg", "squeeze"]: + net = _LPIPS(pretrained=True, net=sim_net, resize=resize).to(device) + else: + raise ValueError(f"sim_net must be a nn.Module or one of 'alex', 'vgg', 'squeeze', got {sim_net}") + + decorator = torch.inference_mode if _TORCH_GREATER_EQUAL_1_10 else torch.no_grad + with decorator(): + distances = [] + num_batches = math.ceil(num_samples / batch_size) + for batch_idx in range(num_batches): + batch_latent1 = latent1[batch_idx * batch_size : (batch_idx + 1) * batch_size].to(device) + batch_latent2 = latent2[batch_idx * batch_size : (batch_idx + 1) * batch_size].to(device) + + if conditional: + batch_labels = labels[batch_idx * batch_size : (batch_idx + 1) * batch_size].to(device) + outputs = generator( + torch.cat((batch_latent1, batch_latent2), dim=0), torch.cat((batch_labels, batch_labels), dim=0) + ) + else: + outputs 
= generator(torch.cat((batch_latent1, batch_latent2), dim=0)) + + out1, out2 = outputs.chunk(2, dim=0) + + # rescale to lpips expected domain: [0, 255] -> [0, 1] -> [-1, 1] + out1_rescale = 2 * (out1 / 255) - 1 + out2_rescale = 2 * (out2 / 255) - 1 + + similarity = net(out1_rescale, out2_rescale) + dist = similarity / epsilon**2 + distances.append(dist.detach()) + + distances = torch.cat(distances) + + lower = torch.quantile(distances, lower_discard, interpolation="lower") if lower_discard is not None else 0.0 + upper = ( + torch.quantile(distances, upper_discard, interpolation="lower") + if upper_discard is not None + else max(distances) + ) + distances = distances[(distances >= lower) & (distances <= upper)] + + return distances.mean(), distances.std(), distances diff --git a/src/torchmetrics/image/__init__.py b/src/torchmetrics/image/__init__.py index 3e6a09cc591..ace4a50de22 100644 --- a/src/torchmetrics/image/__init__.py +++ b/src/torchmetrics/image/__init__.py @@ -23,7 +23,7 @@ from torchmetrics.image.tv import TotalVariation from torchmetrics.image.uqi import UniversalImageQualityIndex from torchmetrics.image.vif import VisualInformationFidelity -from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCH_FIDELITY_AVAILABLE +from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE, _TORCHVISION_AVAILABLE __all__ = [ "SpectralDistortionIndex", @@ -52,7 +52,8 @@ "KernelInceptionDistance", ] -if _LPIPS_AVAILABLE: - from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity # noqa: F401 +if _TORCHVISION_AVAILABLE: + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + from torchmetrics.image.perceptual_path_length import PerceptualPathLength - __all__.append("LearnedPerceptualImagePatchSimilarity") + __all__ += ["LearnedPerceptualImagePatchSimilarity", "PerceptualPathLength"] diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index d0bf22ae2ff..f209892e2e2 100644 --- 
a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -20,7 +20,7 @@ from torchmetrics.functional.image.lpips import _LPIPS, _lpips_compute, _lpips_update, _NoTrainLpips from torchmetrics.metric import Metric from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout -from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: @@ -111,10 +111,10 @@ def __init__( ) -> None: super().__init__(**kwargs) - if not _LPIPS_AVAILABLE: + if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( - "LPIPS metric requires that lpips is installed." - " Either install as `pip install torchmetrics[image]` or `pip install lpips`." + "LPIPS metric requires that torchvision is installed." + " Either install as `pip install torchmetrics[image]` or `pip install torchvision`." ) valid_net_type = ("vgg", "alex", "squeeze") diff --git a/src/torchmetrics/image/perceptual_path_length.py b/src/torchmetrics/image/perceptual_path_length.py new file mode 100644 index 00000000000..440b9e9abf2 --- /dev/null +++ b/src/torchmetrics/image/perceptual_path_length.py @@ -0,0 +1,178 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Literal, Optional, Tuple, Union + +from torch import Tensor, nn + +from torchmetrics.functional.image.lpips import _LPIPS +from torchmetrics.functional.image.perceptual_path_length import ( + _GeneratorType, + _perceptual_path_length_validate_arguments, + _validate_generator_model, + perceptual_path_length, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE + +if not _TORCHVISION_AVAILABLE: + __doctest_skip__ = ["PerceptualPathLength"] + + +class PerceptualPathLength(Metric): + r"""Computes the perceptual path length (`PPL`_) of a generator model. + + The perceptual path length can be used to measure the consistency of interpolation in latent-space models. It is + defined as + + .. math:: + PPL = \mathbb{E}\left[\frac{1}{\epsilon^2} D(G(I(z_1, z_2, t)), G(I(z_1, z_2, t+\epsilon)))\right] + + where :math:`G` is the generator, :math:`I` is the interpolation function, :math:`D` is a similarity metric, + :math:`z_1` and :math:`z_2` are two sets of latent points, and :math:`t` is a parameter between 0 and 1. The metric + thus works by interpolating between two sets of latent points, and measuring the similarity between the generated + images. The expectation is approximated by sampling :math:`z_1` and :math:`z_2` from the generator, and averaging + the calculated distances. The similarity metric :math:`D` is by default the `LPIPS`_ metric, but can be changed by + setting the `sim_net` argument. + + The provided generator model must have a `sample` method with signature `sample(num_samples: int) -> Tensor` where + the returned tensor has shape `(num_samples, z_size)`. If the generator is conditional, it must also have a + `num_classes` attribute. The `forward` method of the generator must have signature `forward(z: Tensor) -> Tensor` + if `conditional=False`, and `forward(z: Tensor, labels: Tensor) -> Tensor` if `conditional=True`.
The returned + tensor should have shape `(num_samples, C, H, W)` and be scaled to the range [0, 255]. + + .. note:: using this metric with the default feature extractor requires that ``torchvision`` is installed. + Either install as ``pip install torchmetrics[image]`` or ``pip install torchvision`` + + As input to ``forward`` and ``update`` the metric accepts the following input + + - ``generator`` (:class:`~torch.nn.Module`): Generator model, with specific requirements. See above. + + As output of `forward` and `compute` the metric returns the following output + + - ``ppl_mean`` (:class:`~torch.Tensor`): float scalar tensor with mean PPL value over distances + - ``ppl_std`` (:class:`~torch.Tensor`): float scalar tensor with std PPL value over distances + - ``ppl_raw`` (:class:`~torch.Tensor`): float tensor with the raw PPL distances + + Args: + num_samples: Number of samples to use for the PPL computation. + conditional: Whether the generator is conditional or not (i.e. whether it takes labels as input). + batch_size: Batch size to use for the PPL computation. + interpolation_method: Interpolation method to use. Choose from 'lerp', 'slerp_any', 'slerp_unit'. + epsilon: Spacing between the points on the path between latent points. + resize: Resize images to this size before computing the similarity between generated images. + lower_discard: Lower quantile to discard from the distances, before computing the mean and standard deviation. + upper_discard: Upper quantile to discard from the distances, before computing the mean and standard deviation. + sim_net: Similarity network to use. Can be a `nn.Module` or one of 'alex', 'vgg', 'squeeze', where the three + latter options correspond to the pretrained networks from the `LPIPS`_ paper. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Raises: + ModuleNotFoundError: + If ``torchvision`` is not installed. + ValueError: + If ``num_samples`` is not a positive integer.
+ ValueError: + If `conditional` is not a boolean. + ValueError: + If ``batch_size`` is not a positive integer. + ValueError: + If ``interpolation_method`` is not one of 'lerp', 'slerp_any', 'slerp_unit'. + ValueError: + If ``epsilon`` is not a positive float. + ValueError: + If ``resize`` is not a positive integer. + ValueError: + If ``lower_discard`` is not a float between 0 and 1 or None. + ValueError: + If ``upper_discard`` is not a float between 0 and 1 or None. + + Example:: + >>> from torchmetrics.image import PerceptualPathLength + >>> import torch + >>> _ = torch.manual_seed(42) + >>> class DummyGenerator(torch.nn.Module): + ... def __init__(self, z_size) -> None: + ... super().__init__() + ... self.z_size = z_size + ... self.model = torch.nn.Sequential(torch.nn.Linear(z_size, 3*128*128), torch.nn.Sigmoid()) + ... def forward(self, z): + ... return 255 * (self.model(z).reshape(-1, 3, 128, 128) + 1) + ... def sample(self, num_samples): + ... return torch.randn(num_samples, self.z_size) + >>> generator = DummyGenerator(2) + >>> ppl = PerceptualPathLength(num_samples=10) + >>> ppl(generator) # doctest: +SKIP + (tensor(0.2371), + tensor(0.1763), + tensor([0.3502, 0.1362, 0.2535, 0.0902, 0.1784, 0.0769, 0.5871, 0.0691, 0.3921])) + + """ + + def __init__( + self, + num_samples: int = 10_000, + conditional: bool = False, + batch_size: int = 128, + interpolation_method: Literal["lerp", "slerp_any", "slerp_unit"] = "lerp", + epsilon: float = 1e-4, + resize: Optional[int] = 64, + lower_discard: Optional[float] = 0.01, + upper_discard: Optional[float] = 0.99, + sim_net: Union[nn.Module, Literal["alex", "vgg", "squeeze"]] = "vgg", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if not _TORCHVISION_AVAILABLE: + raise ModuleNotFoundError( + "Metric `PerceptualPathLength` requires torchvision which is not installed." 
+ " Install with `pip install torchvision` or `pip install torchmetrics[image]`" + ) + _perceptual_path_length_validate_arguments( + num_samples, conditional, batch_size, interpolation_method, epsilon, resize, lower_discard, upper_discard + ) + self.num_samples = num_samples + self.conditional = conditional + self.batch_size = batch_size + self.interpolation_method = interpolation_method + self.epsilon = epsilon + self.resize = resize + self.lower_discard = lower_discard + self.upper_discard = upper_discard + + if isinstance(sim_net, nn.Module): + self.net = sim_net + elif sim_net in ["alex", "vgg", "squeeze"]: + self.net = _LPIPS(pretrained=True, net=sim_net, resize=resize) + else: + raise ValueError(f"sim_net must be a nn.Module or one of 'alex', 'vgg', 'squeeze', got {sim_net}") + + def update(self, generator: _GeneratorType) -> None: + """Update the generator model.""" + _validate_generator_model(generator, self.conditional) + self.generator = generator + + def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + """Compute the perceptual path length for the generator stored by ``update``, using the configured settings.""" + return perceptual_path_length( + generator=self.generator, + num_samples=self.num_samples, + conditional=self.conditional, batch_size=self.batch_size, + interpolation_method=self.interpolation_method, + epsilon=self.epsilon, + resize=self.resize, + lower_discard=self.lower_discard, + upper_discard=self.upper_discard, + sim_net=self.net, + device=self.device, + ) diff --git a/tests/unittests/image/test_perceptual_path_length.py b/tests/unittests/image/test_perceptual_path_length.py new file mode 100644 index 00000000000..8535f74a524 --- /dev/null +++ b/tests/unittests/image/test_perceptual_path_length.py @@ -0,0 +1,203 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from operator import attrgetter + +import pytest +import torch +import torch_fidelity +from torch import nn +from torch_fidelity.sample_similarity_lpips import SampleSimilarityLPIPS +from torch_fidelity.utils import batch_interp +from torchmetrics.functional.image.lpips import _LPIPS +from torchmetrics.functional.image.perceptual_path_length import _interpolate, perceptual_path_length +from torchmetrics.image.perceptual_path_length import PerceptualPathLength +from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE + +from unittests.helpers import seed_all + +seed_all(42) + + +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity") +@pytest.mark.parametrize("interpolation_method", ["lerp", "slerp_any", "slerp_unit"]) +def test_interpolation_methods(interpolation_method): + """Test that interpolation method works as expected.""" + latent1 = torch.randn(100, 25) + latent2 = torch.randn(100, 25) + + res1 = _interpolate(latent1, latent2, 1e-4, interpolation_method) + res2 = batch_interp(latent1, latent2, 1e-4, interpolation_method) + assert torch.allclose(res1, res2) + + +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity") +def test_sim_net(): + """Check that the similarity network is the same as the one used in torch_fidelity.""" + compare = SampleSimilarityLPIPS("sample_similarity", resize=64) + simnet = _LPIPS(net="vgg", resize=64) + + # check that the weights are the same + for name, weight in compare.named_parameters(): + getter = attrgetter(name) + weight2 =
getter(simnet) + assert torch.allclose(weight, weight2) + + img1 = torch.rand(1, 3, 64, 64) + img2 = torch.rand(1, 3, 64, 64) + + # note that by default the two networks expect different scaling of the images + out = compare(255 * img1, 255 * img2) + out2 = simnet(2 * img1 - 1, 2 * img2 - 1) + + assert torch.allclose(out, out2) + + +class DummyGenerator(torch.nn.Module): + """From https://github.com/toshas/torch-fidelity/blob/master/examples/sngan_cifar10.py.""" + + def __init__(self, z_size) -> None: + super().__init__() + self.z_size = z_size + self.model = torch.nn.Sequential( + torch.nn.ConvTranspose2d(z_size, 512, 4, stride=1), + torch.nn.BatchNorm2d(512), + torch.nn.ReLU(), + torch.nn.ConvTranspose2d(512, 256, 4, stride=2, padding=(1, 1)), + torch.nn.BatchNorm2d(256), + torch.nn.ReLU(), + torch.nn.ConvTranspose2d(256, 128, 4, stride=2, padding=(1, 1)), + torch.nn.BatchNorm2d(128), + torch.nn.ReLU(), + torch.nn.ConvTranspose2d(128, 64, 4, stride=2, padding=(1, 1)), + torch.nn.BatchNorm2d(64), + torch.nn.ReLU(), + torch.nn.ConvTranspose2d(64, 3, 3, stride=1, padding=(1, 1)), + torch.nn.Tanh(), + ) + + def forward(self, z): + """Generate images from latent vectors.""" + fake = self.model(z.view(-1, self.z_size, 1, 1)) + if not self.training: + fake = 255 * (fake.clamp(-1, 1) * 0.5 + 0.5) + fake = fake.to(torch.uint8) + return fake + + def sample(self, num_samples): + """Sample latent vectors.""" + return torch.randn(num_samples, self.z_size) + + +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity") +@pytest.mark.parametrize( + ("argument", "match"), + [ + ({"num_samples": 0}, "Argument `num_samples` must be a positive integer, but got 0."), + ({"conditional": 2}, "Argument `conditional` must be a boolean, but got 2."), + ({"batch_size": 0}, "Argument `batch_size` must be a positive integer, but got 0."), + ({"interpolation_method": "wrong"}, "Argument `interpolation_method` must be one of.*"), + ({"epsilon": 0}, "Argument 
`epsilon` must be a positive float, but got 0."), + ({"resize": 0}, "Argument `resize` must be a positive integer or `None`, but got 0."), + ({"lower_discard": -1}, "Argument `lower_discard` must be a float between 0 and 1 or `None`, but got -1"), + ({"upper_discard": 2}, "Argument `upper_discard` must be a float between 0 and 1 or `None`, but got 2"), + ], +) +def test_raises_error_on_wrong_arguments(argument, match): + """Test that appropriate errors are raised on wrong arguments.""" + with pytest.raises(ValueError, match=match): + perceptual_path_length(DummyGenerator(128), **argument) + + with pytest.raises(ValueError, match=match): + PerceptualPathLength(**argument) + + +class _WrongGenerator1(nn.Module): + pass + + +class _WrongGenerator2(nn.Module): + sample = 1 + + +class _WrongGenerator3(nn.Module): + def sample(self, n): + return torch.randn(n, 2) + + +class _WrongGenerator4(nn.Module): + def sample(self, n): + return torch.randn(n, 2) + + @property + def num_classes(self): + return [10, 10] + + +@pytest.mark.parametrize( + ("generator", "errortype", "match"), + [ + (_WrongGenerator1(), NotImplementedError, "The generator must have a `sample` method.*"), + (_WrongGenerator2(), ValueError, "The generator's `sample` method must be callable."), + ( + _WrongGenerator3(), + AttributeError, + "The generator must have a `num_classes` attribute when `conditional=True`.", + ), + ( + _WrongGenerator4(), + ValueError, + "The generator's `num_classes` attribute must be an integer when `conditional=True`.", + ), + ], +) +def test_raises_error_on_wrong_generator(generator, errortype, match): + """Test that appropriate errors are raised on wrong generator.""" + with pytest.raises(errortype, match=match): + perceptual_path_length(generator, conditional=True) + + ppl = PerceptualPathLength(conditional=True) + with pytest.raises(errortype, match=match): + ppl.update(generator=generator) + + +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires 
torch_fidelity") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_compare(): + """Test against torch_fidelity. + + Because it is a sample based metric, the results are not deterministic. Thus we need a large amount of samples to + even get close to the reference value. Even then we are going to allow a 6% deviation on the mean and 6% deviation + on the standard deviation. + + """ + generator = DummyGenerator(128) + + compare = torch_fidelity.calculate_metrics( + input1=torch_fidelity.GenerativeModelModuleWrapper(generator, 128, "normal", 10), + input1_model_num_samples=50000, + ppl=True, + ppl_reduction="none", + input_model_num_classes=0, + ppl_discard_percentile_lower=None, + ppl_discard_percentile_higher=None, + ) + compare = torch.tensor(compare["perceptual_path_length_raw"]) + + result = perceptual_path_length( + generator, num_samples=50000, conditional=False, lower_discard=None, upper_discard=None, device="cuda" + ) + result = result[-1].cpu() + + assert 0.94 * result.mean() <= compare.mean() <= 1.06 * result.mean() + assert 0.94 * result.std() <= compare.std() <= 1.06 * result.std()