Skip to content

Commit

Permalink
Add normalize arg to LPIPS metric (#1216)
Browse files Browse the repository at this point in the history
* add normalize arg
* changelog
  • Loading branch information
SkafteNicki authored Sep 13, 2022
1 parent ebdaea7 commit 8086097
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 20 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `Perplexity` metric ([#922](https://github.com/PyTorchLightning/metrics/pull/922))


- Added argument `normalize` to `LPIPS` metric ([#1216](https://github.com/Lightning-AI/metrics/pull/1216))


### Changed

- Classification refactor (
Expand Down
29 changes: 18 additions & 11 deletions src/torchmetrics/image/lpip.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ def train(self, mode: bool) -> "NoTrainLpips":
return super().train(False)


def _valid_img(img: Tensor) -> bool:
def _valid_img(img: Tensor, normalize: bool) -> bool:
"""check that input is a valid image to the network."""
return img.ndim == 4 and img.shape[1] == 3 and img.min() >= -1.0 and img.max() <= 1.0
value_check = img.max() <= 1.0 and img.min() >= 0.0 if normalize else img.min() >= -1
return img.ndim == 4 and img.shape[1] == 3 and value_check


class LearnedPerceptualImagePatchSimilarity(Metric):
Expand All @@ -48,8 +49,8 @@ class LearnedPerceptualImagePatchSimilarity(Metric):
pre-defined network. This measure has been shown to match human perception well. A low LPIPS score means that
image patches are perceptually similar.
Both input image patches are expected to have shape `[N, 3, H, W]` and be normalized to the [-1,1]
range. The minimum size of `H, W` depends on the chosen backbone (see `net_type` arg).
Both input image patches are expected to have shape `[N, 3, H, W]`.
The minimum size of `H, W` depends on the chosen backbone (see `net_type` arg).
.. note:: using this metrics requires you to have ``lpips`` package installed. Either install
as ``pip install torchmetrics[image]`` or ``pip install lpips``
Expand All @@ -60,6 +61,8 @@ class LearnedPerceptualImagePatchSimilarity(Metric):
Args:
net_type: str indicating backbone network type to use. Choose between `'alex'`, `'vgg'` or `'squeeze'`
reduction: str indicating how to reduce over the batch dimension. Choose between `'sum'` or `'mean'`.
normalize: by default this is ``False`` meaning that the input is expected to be in the [-1,1] range. If set
to ``True`` will instead expect input to be in the ``[0,1]`` range.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
Expand Down Expand Up @@ -95,6 +98,7 @@ def __init__(
self,
net_type: str = "alex",
reduction: Literal["sum", "mean"] = "mean",
normalize: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -115,6 +119,10 @@ def __init__(
raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}")
self.reduction = reduction

if not isinstance(normalize, bool):
raise ValueError(f"Argument `normalize` should be a bool but got {normalize}")
self.normalize = normalize

self.add_state("sum_scores", torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum")

def update(self, img1: Tensor, img2: Tensor) -> None:  # type: ignore
    """Update internal accumulators with the LPIPS score of a new batch.

    Args:
        img1: tensor with images of shape ``[N, 3, H, W]``
        img2: tensor with images of shape ``[N, 3, H, W]``

    Raises:
        ValueError: if either input has the wrong shape, or values outside the
            expected range (``[0, 1]`` when ``self.normalize`` else ``[-1, 1]``).
    """
    if not (_valid_img(img1, self.normalize) and _valid_img(img2, self.normalize)):
        raise ValueError(
            "Expected both input arguments to be normalized tensors with shape [N, 3, H, W]."
            f" Got input with shape {img1.shape} and {img2.shape} and values in range"
            f" {[img1.min(), img1.max()]} and {[img2.min(), img2.max()]} when all values are"
            # leading space added: the previous fragment rendered as "areexpected"
            f" expected to be in the {[0,1] if self.normalize else [-1,1]} range."
        )
    # Forward the normalize flag so the backbone rescales [0,1] input internally.
    loss = self.net(img1, img2, normalize=self.normalize).squeeze()
    self.sum_scores += loss.sum()
    self.total += img1.shape[0]

Expand Down
19 changes: 10 additions & 9 deletions tests/unittests/image/test_lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,49 +34,50 @@
)


def _compare_fn(img1: Tensor, img2: Tensor, net_type: str, normalize: bool, reduction: str = "mean") -> Tensor:
    """comparison function for tm implementation."""
    reference_net = LPIPS_reference(net=net_type)
    scores = reference_net(img1, img2, normalize=normalize).detach().cpu().numpy()
    return scores.mean() if reduction == "mean" else scores.sum()


# NOTE(review): this span is unified-diff residue with old and new lines interleaved
# (duplicate decorators at the class and method level, duplicate method signatures).
# Which parametrize lines survive in the committed file cannot be determined from
# this view alone — confirm against the repository before treating this as source.
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
@pytest.mark.parametrize("net_type", ["vgg", "alex", "squeeze"])
class TestLPIPS(MetricTester):
atol: float = 1e-6

@pytest.mark.parametrize("net_type", ["vgg", "alex", "squeeze"])
@pytest.mark.parametrize("normalize", [False, True])
@pytest.mark.parametrize("ddp", [True, False])
def test_lpips(self, net_type, ddp):
def test_lpips(self, net_type, normalize, ddp):
"""test modular implementation for correctness."""
self.run_class_metric_test(
ddp=ddp,
preds=_inputs.img1,
target=_inputs.img2,
metric_class=LearnedPerceptualImagePatchSimilarity,
sk_metric=partial(_compare_fn, net_type=net_type),
sk_metric=partial(_compare_fn, net_type=net_type, normalize=normalize),
dist_sync_on_step=False,
check_scriptable=False,
check_state_dict=False,
metric_args={"net_type": net_type},
metric_args={"net_type": net_type, "normalize": normalize},
)

def test_lpips_differentiability(self, net_type):
def test_lpips_differentiability(self):
"""test for differentiability of LPIPS metric."""
self.run_differentiability_test(
preds=_inputs.img1, target=_inputs.img2, metric_module=LearnedPerceptualImagePatchSimilarity
)

# LPIPS half + cpu does not work due to missing support in torch.min
@pytest.mark.xfail(reason="LPIPS metric does not support cpu + half precision")
def test_lpips_half_cpu(self, net_type):
def test_lpips_half_cpu(self):
"""test for half + cpu support."""
self.run_precision_test_cpu(_inputs.img1, _inputs.img2, LearnedPerceptualImagePatchSimilarity)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_lpips_half_gpu(self, net_type):
def test_lpips_half_gpu(self):
"""test for half + gpu support."""
self.run_precision_test_gpu(_inputs.img1, _inputs.img2, LearnedPerceptualImagePatchSimilarity)

Expand Down

0 comments on commit 8086097

Please sign in to comment.