From 8a98e463282ca1fa8b5fccf7a77cf5c4edd395af Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 12 Feb 2024 13:40:36 +0100 Subject: [PATCH] Fix backprop in LPIPS (#2326) * fixes + tests * fix doctests (cherry picked from commit 18b181d522a53beaa3a1ae39ca6e2c8323958a68) --- CHANGELOG.md | 3 +++ src/torchmetrics/functional/image/lpips.py | 8 ++++++-- src/torchmetrics/image/lpip.py | 2 +- tests/unittests/image/test_lpips.py | 12 ++++++++++++ 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1377394d14f..3c1c8cbedd9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed how backprop is handled in `LPIPS` metric ([#2326](https://github.com/Lightning-AI/torchmetrics/pull/2326)) + + - Fixed `MultitaskWrapper` not being able to be logged in lightning when using metric collections ([#2349](https://github.com/Lightning-AI/torchmetrics/pull/2349)) diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index 6f8d4b4a450..2ce1de9886d 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -275,7 +275,7 @@ def __init__( net: Indicate backbone to use, choose between ['alex','vgg','squeeze'] spatial: If input should be spatial averaged pnet_rand: If backbone should be random or use imagenet pre-trained weights - pnet_tune: If backprop should be enabled + pnet_tune: If backprop should be enabled for both backbone and linear layers use_dropout: If dropout layers should be added model_path: Model path to load pretained models from eval_mode: If network should be in evaluation mode @@ -327,6 +327,10 @@ def __init__( if eval_mode: self.eval() + if not self.pnet_tune: + for param in self.parameters(): + param.requires_grad = False + def forward( self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False ) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]: @@ -423,7 +427,7 @@ def learned_perceptual_image_patch_similarity( >>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1 >>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1 >>> learned_perceptual_image_patch_similarity(img1, img2, net_type='squeeze') - tensor(0.1008, grad_fn=) + tensor(0.1008) """ net = _NoTrainLpips(net=net_type).to(device=img1.device, dtype=img1.dtype) diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index 77b87f61344..c094003fa4f 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -86,7 +86,7 @@ class LearnedPerceptualImagePatchSimilarity(Metric): >>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1 >>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1 >>> lpips(img1, img2) - tensor(0.1046, grad_fn=) + tensor(0.1046) """ diff --git a/tests/unittests/image/test_lpips.py b/tests/unittests/image/test_lpips.py index 3db11ba5d78..75182338960 100644 --- a/tests/unittests/image/test_lpips.py +++ b/tests/unittests/image/test_lpips.py @@ -133,3 +133,15 @@ def test_error_on_wrong_update(inp1, inp2): metric = LearnedPerceptualImagePatchSimilarity() with pytest.raises(ValueError, match="Expected both input arguments to be normalized tensors .*"): metric(inp1, inp2) + + +def test_check_for_backprop(): + """Check that by default the metric supports propagation of gradients, but does not update its parameters.""" + metric = LearnedPerceptualImagePatchSimilarity() + assert not metric.net.lin0.model[1].weight.requires_grad + preds, target = _inputs.img1[0], _inputs.img2[0] + preds.requires_grad = True + loss = metric(preds, target) + assert loss.requires_grad + loss.backward() + assert metric.net.lin0.model[1].weight.grad is None