Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New metric: Perceptual Path Length #1939

Merged
merged 40 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
0d94f75
initial implementation
SkafteNicki Jul 25, 2023
aa16025
add more to module
SkafteNicki Jul 25, 2023
e126a88
changelog
SkafteNicki Jul 25, 2023
0e78bb2
add some docstrings
SkafteNicki Jul 25, 2023
38b35f5
add to doc pages
SkafteNicki Jul 25, 2023
1e0ecc0
more docs
SkafteNicki Jul 25, 2023
54be531
improve testing
SkafteNicki Jul 25, 2023
1d78d8c
Merge branch 'master' into newmetric/ppl
SkafteNicki Jul 31, 2023
ed8afaa
fix typing issues
SkafteNicki Jul 31, 2023
753aab6
fix docs build
SkafteNicki Jul 31, 2023
6b10f84
improve generator testing
SkafteNicki Jul 31, 2023
37c4de3
compatibility with older
SkafteNicki Jul 31, 2023
e12c9ea
improve testing
SkafteNicki Aug 2, 2023
c20f0d3
docstrings + doctests
SkafteNicki Aug 2, 2023
47b4ee6
Merge branch 'master' into newmetric/ppl
SkafteNicki Aug 2, 2023
1730ac6
fix
SkafteNicki Aug 3, 2023
d214170
Merge branch 'master' into newmetric/ppl
SkafteNicki Aug 3, 2023
8b6220e
Merge branch 'master' into newmetric/ppl
SkafteNicki Aug 3, 2023
f366b55
skip on missing import
SkafteNicki Aug 3, 2023
e38b846
Merge branch 'master' into newmetric/ppl
Borda Aug 3, 2023
8c4e2d2
Merge branch 'master' into newmetric/ppl
mergify[bot] Aug 4, 2023
426b4ae
Merge branch 'master' into newmetric/ppl
Borda Aug 4, 2023
e9b7a6d
Merge branch 'master' into newmetric/ppl
SkafteNicki Aug 7, 2023
62cde3e
Merge branch 'master' into newmetric/ppl
Borda Aug 7, 2023
650165d
Merge branch 'master' into newmetric/ppl
Borda Aug 7, 2023
49081a8
Merge branch 'master' into newmetric/ppl
Borda Aug 8, 2023
4037cc1
move requirement to tests
SkafteNicki Aug 8, 2023
1fe9475
fix link
SkafteNicki Aug 8, 2023
0d0f771
add resize functionality
SkafteNicki Aug 8, 2023
65a4a00
reformat to use own implementation of lpips
SkafteNicki Aug 8, 2023
a29d1d5
add tests
SkafteNicki Aug 8, 2023
d07a27d
Merge branch 'master' into newmetric/ppl
SkafteNicki Aug 8, 2023
6912d0c
Merge branch 'master' into newmetric/ppl
Borda Aug 8, 2023
fecf65c
req.
Borda Aug 8, 2023
0139533
Merge branch 'master' into newmetric/ppl
mergify[bot] Aug 8, 2023
dbcbf04
Merge branch 'master' into newmetric/ppl
SkafteNicki Aug 9, 2023
e2808e4
fix mypy
SkafteNicki Aug 9, 2023
715a360
skip on random
SkafteNicki Aug 9, 2023
b2d79a7
device placement
SkafteNicki Aug 9, 2023
e07de45
seed
SkafteNicki Aug 9, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for evaluating `"segm"` and `"bbox"` detection in `MeanAveragePrecision` at the same time ([#1928](https://github.com/Lightning-AI/torchmetrics/pull/1928))


- Added `PerceptualPathLength` to image package ([#1939](https://github.com/Lightning-AI/torchmetrics/pull/1939))


- Added support for multioutput evaluation in `MeanSquaredError` ([#1937](https://github.com/Lightning-AI/torchmetrics/pull/1937))


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,9 @@ ________________
.. autoclass:: torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.learned_perceptual_image_patch_similarity
:noindex:
23 changes: 23 additions & 0 deletions docs/source/image/perceptual_path_length.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Perceptual Path Length
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

############################
Perceptual Path Length (PPL)
############################

Module Interface
________________

.. autoclass:: torchmetrics.image.perceptual_path_length.PerceptualPathLength
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.perceptual_path_length.perceptual_path_length
:noindex:
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
.. _Seamless Scene Segmentation paper: https://arxiv.org/abs/1905.01220
.. _Fleiss kappa: https://en.wikipedia.org/wiki/Fleiss%27_kappa
.. _VIF: https://ieeexplore.ieee.org/abstract/document/1576816
.. _PPL : https://arxiv.org/pdf/1812.04948
.. _CIOU: https://arxiv.org/abs/2005.03572
.. _DIOU: https://arxiv.org/abs/1911.08287v1
.. _GIOU: https://arxiv.org/abs/1902.09630
Expand Down
2 changes: 1 addition & 1 deletion requirements/image.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@

scipy >1.0.0, <1.11.0
torchvision >=0.8, <=0.15.2
torch-fidelity <=0.3.0
torch-fidelity <=0.4.0 # bumping to alow install version from master, now used in testing
lpips <=0.1.4
1 change: 1 addition & 0 deletions requirements/image_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ kornia >=0.6.7, <0.7.1
pytorch-msssim ==1.0.0
sewar >=0.4.4, <=0.4.6
numpy <1.25.0
torch-fidelity @ git+https://github.com/toshas/torch-fidelity@master
4 changes: 4 additions & 0 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from torchmetrics.functional.image.d_lambda import spectral_distortion_index
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
from torchmetrics.functional.image.perceptual_path_length import perceptual_path_length
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
from torchmetrics.functional.image.psnrb import peak_signal_noise_ratio_with_blocked_effect
from torchmetrics.functional.image.rase import relative_average_spectral_error
Expand Down Expand Up @@ -41,4 +43,6 @@
"total_variation",
"universal_image_quality_index",
"visual_information_fidelity",
"learned_perceptual_image_patch_similarity",
"perceptual_path_length",
]
39 changes: 26 additions & 13 deletions src/torchmetrics/functional/image/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def spatial_average(in_tens: Tensor, keepdim: bool = True) -> Tensor:
return in_tens.mean([2, 3], keepdim=keepdim)


def upsam(in_tens: Tensor, out_hw: Tuple[int, int] = (64, 64)) -> Tensor:
def upsam(in_tens: Tensor, out_hw: Tuple[int, ...] = (64, 64)) -> Tensor:
"""Upsample input with bilinear interpolation."""
return nn.Upsample(size=out_hw, mode="bilinear", align_corners=False)(in_tens)

Expand All @@ -197,6 +197,13 @@ def normalize_tensor(in_feat: Tensor, eps: float = 1e-10) -> Tensor:
return in_feat / (norm_factor + eps)


def resize_tensor(x: Tensor, size: int = 64) -> Tensor:
"""https://github.com/toshas/torch-fidelity/blob/master/torch_fidelity/sample_similarity_lpips.py#L127C22-L132."""
if x.shape[-1] > size and x.shape[-2] > size:
return torch.nn.functional.interpolate(x, (size, size), mode="area")
return torch.nn.functional.interpolate(x, (size, size), mode="bilinear", align_corners=False)


class ScalingLayer(nn.Module):
"""Scaling layer."""

Expand Down Expand Up @@ -238,6 +245,7 @@ def __init__(
use_dropout: bool = True,
model_path: Optional[str] = None,
eval_mode: bool = True,
resize: Optional[int] = None,
) -> None:
"""Initializes a perceptual loss torch.nn.Module.

Expand All @@ -250,6 +258,7 @@ def __init__(
use_dropout: If dropout layers should be added
model_path: Model path to load pretained models from
eval_mode: If network should be in evaluation mode
resize: If input should be resized to this size

"""
super().__init__()
Expand All @@ -258,6 +267,7 @@ def __init__(
self.pnet_tune = pnet_tune
self.pnet_rand = pnet_rand
self.spatial = spatial
self.resize = resize
self.scaling_layer = ScalingLayer()

if self.pnet_type in ["vgg", "vgg16"]:
Expand Down Expand Up @@ -298,31 +308,34 @@ def __init__(

def forward(
self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False
) -> Union[int, Tuple[int, List[Tensor]]]:
) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]:
if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
in0 = 2 * in0 - 1
in1 = 2 * in1 - 1

# v0.0 - original release had a bug, where input was not scaled
# normalize input
in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)

# resize input if needed
if self.resize is not None:
in0_input = resize_tensor(in0_input, size=self.resize)
in1_input = resize_tensor(in1_input, size=self.resize)

outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
feats0, feats1, diffs = {}, {}, {}

for kk in range(self.L):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

if self.spatial:
res = [
upsam(self.lins[kk](diffs[kk]), out_hw=in0.shape[2:]) for kk in range(self.L) # type: ignore[arg-type]
]
else:
res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]

val = 0
for layer in range(self.L):
val += res[layer] # type: ignore[assignment]
res = []
for kk in range(self.L):
if self.spatial:
res.append(upsam(self.lins[kk](diffs[kk]), out_hw=tuple(in0.shape[2:])))
else:
res.append(spatial_average(self.lins[kk](diffs[kk]), keepdim=True))

val: Tensor = sum(res) # type: ignore[assignment]
if retperlayer:
return (val, res)
return val
Expand Down
Loading