Skip to content

Commit

Permalink
New metric: Perceptual Path Length (#1939)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 9, 2023
1 parent 5d977e9 commit 628ee1c
Show file tree
Hide file tree
Showing 13 changed files with 726 additions and 22 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for evaluating `"segm"` and `"bbox"` detection in `MeanAveragePrecision` at the same time ([#1928](https://github.com/Lightning-AI/torchmetrics/pull/1928))


- Added `PerceptualPathLength` to image package ([#1939](https://github.com/Lightning-AI/torchmetrics/pull/1939))


- Added support for multioutput evaluation in `MeanSquaredError` ([#1937](https://github.com/Lightning-AI/torchmetrics/pull/1937))


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,9 @@ ________________
.. autoclass:: torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.learned_perceptual_image_patch_similarity
:noindex:
23 changes: 23 additions & 0 deletions docs/source/image/perceptual_path_length.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Perceptual Path Length
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

############################
Perceptual Path Length (PPL)
############################

Module Interface
________________

.. autoclass:: torchmetrics.image.perceptual_path_length.PerceptualPathLength
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.perceptual_path_length.perceptual_path_length
:noindex:
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
.. _Seamless Scene Segmentation paper: https://arxiv.org/abs/1905.01220
.. _Fleiss kappa: https://en.wikipedia.org/wiki/Fleiss%27_kappa
.. _VIF: https://ieeexplore.ieee.org/abstract/document/1576816
.. _PPL: https://arxiv.org/pdf/1812.04948
.. _CIOU: https://arxiv.org/abs/2005.03572
.. _DIOU: https://arxiv.org/abs/1911.08287v1
.. _GIOU: https://arxiv.org/abs/1902.09630
Expand Down
2 changes: 1 addition & 1 deletion requirements/image.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@

scipy >1.0.0, <1.11.0
torchvision >=0.8, <=0.15.2
torch-fidelity <=0.3.0
torch-fidelity <=0.4.0 # bumping to allow installing the version from master, now used in testing
lpips <=0.1.4
1 change: 1 addition & 0 deletions requirements/image_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ kornia >=0.6.7, <0.7.1
pytorch-msssim ==1.0.0
sewar >=0.4.4, <=0.4.6
numpy <1.25.0
torch-fidelity @ git+https://github.com/toshas/torch-fidelity@master
4 changes: 4 additions & 0 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from torchmetrics.functional.image.d_lambda import spectral_distortion_index
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
from torchmetrics.functional.image.perceptual_path_length import perceptual_path_length
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
from torchmetrics.functional.image.psnrb import peak_signal_noise_ratio_with_blocked_effect
from torchmetrics.functional.image.rase import relative_average_spectral_error
Expand Down Expand Up @@ -41,4 +43,6 @@
"total_variation",
"universal_image_quality_index",
"visual_information_fidelity",
"learned_perceptual_image_patch_similarity",
"perceptual_path_length",
]
39 changes: 26 additions & 13 deletions src/torchmetrics/functional/image/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def spatial_average(in_tens: Tensor, keepdim: bool = True) -> Tensor:
return in_tens.mean([2, 3], keepdim=keepdim)


def upsam(in_tens: Tensor, out_hw: Tuple[int, ...] = (64, 64)) -> Tensor:
    """Bilinearly upsample ``in_tens`` to the spatial size ``out_hw`` (height, width)."""
    upsampler = nn.Upsample(size=out_hw, mode="bilinear", align_corners=False)
    return upsampler(in_tens)

Expand All @@ -197,6 +197,13 @@ def normalize_tensor(in_feat: Tensor, eps: float = 1e-10) -> Tensor:
return in_feat / (norm_factor + eps)


def resize_tensor(x: Tensor, size: int = 64) -> Tensor:
    """Resize a batched image tensor to ``(size, size)``.

    Area interpolation is used when both spatial dims shrink; bilinear otherwise.
    Reference: https://github.com/toshas/torch-fidelity/blob/master/torch_fidelity/sample_similarity_lpips.py#L127C22-L132.
    """
    target = (size, size)
    downscaling = x.shape[-1] > size and x.shape[-2] > size
    mode = "area" if downscaling else "bilinear"
    if downscaling:
        return torch.nn.functional.interpolate(x, target, mode=mode)
    return torch.nn.functional.interpolate(x, target, mode=mode, align_corners=False)


class ScalingLayer(nn.Module):
"""Scaling layer."""

Expand Down Expand Up @@ -238,6 +245,7 @@ def __init__(
use_dropout: bool = True,
model_path: Optional[str] = None,
eval_mode: bool = True,
resize: Optional[int] = None,
) -> None:
"""Initializes a perceptual loss torch.nn.Module.
Expand All @@ -250,6 +258,7 @@ def __init__(
use_dropout: If dropout layers should be added
model_path: Model path to load pretrained models from
eval_mode: If network should be in evaluation mode
resize: If input should be resized to this size
"""
super().__init__()
Expand All @@ -258,6 +267,7 @@ def __init__(
self.pnet_tune = pnet_tune
self.pnet_rand = pnet_rand
self.spatial = spatial
self.resize = resize
self.scaling_layer = ScalingLayer()

if self.pnet_type in ["vgg", "vgg16"]:
Expand Down Expand Up @@ -298,31 +308,34 @@ def __init__(

def forward(
self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False
) -> Union[int, Tuple[int, List[Tensor]]]:
) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]:
if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
in0 = 2 * in0 - 1
in1 = 2 * in1 - 1

# v0.0 - original release had a bug, where input was not scaled
# normalize input
in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)

# resize input if needed
if self.resize is not None:
in0_input = resize_tensor(in0_input, size=self.resize)
in1_input = resize_tensor(in1_input, size=self.resize)

outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
feats0, feats1, diffs = {}, {}, {}

for kk in range(self.L):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

if self.spatial:
res = [
upsam(self.lins[kk](diffs[kk]), out_hw=in0.shape[2:]) for kk in range(self.L) # type: ignore[arg-type]
]
else:
res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]

val = 0
for layer in range(self.L):
val += res[layer] # type: ignore[assignment]
res = []
for kk in range(self.L):
if self.spatial:
res.append(upsam(self.lins[kk](diffs[kk]), out_hw=tuple(in0.shape[2:])))
else:
res.append(spatial_average(self.lins[kk](diffs[kk]), keepdim=True))

val: Tensor = sum(res) # type: ignore[assignment]
if retperlayer:
return (val, res)
return val
Expand Down
Loading

0 comments on commit 628ee1c

Please sign in to comment.