Skip to content

Commit

Permalink
Fix backprop in LPIPS (#2326)
Browse files Browse the repository at this point in the history
* fixes + tests
* fix doctests

(cherry picked from commit 18b181d)
  • Loading branch information
SkafteNicki authored and Borda committed Feb 12, 2024
1 parent 5af8ecc commit 8a98e46
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed how backprop is handled in `LPIPS` metric ([#2326](https://github.com/Lightning-AI/torchmetrics/pull/2326))


- Fixed `MultitaskWrapper` not being able to be logged in lightning when using metric collections ([#2349](https://github.com/Lightning-AI/torchmetrics/pull/2349))


Expand Down
8 changes: 6 additions & 2 deletions src/torchmetrics/functional/image/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def __init__(
net: Indicate backbone to use, choose between ['alex','vgg','squeeze']
spatial: If input should be spatial averaged
pnet_rand: If backbone should be random or use imagenet pre-trained weights
pnet_tune: If backprop should be enabled
pnet_tune: If backprop should be enabled for both backbone and linear layers
use_dropout: If dropout layers should be added
model_path: Model path to load pretained models from
eval_mode: If network should be in evaluation mode
Expand Down Expand Up @@ -327,6 +327,10 @@ def __init__(
if eval_mode:
self.eval()

if not self.pnet_tune:
for param in self.parameters():
param.requires_grad = False

def forward(
self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False
) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]:
Expand Down Expand Up @@ -423,7 +427,7 @@ def learned_perceptual_image_patch_similarity(
>>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> learned_perceptual_image_patch_similarity(img1, img2, net_type='squeeze')
tensor(0.1008, grad_fn=<DivBackward0>)
tensor(0.1008)
"""
net = _NoTrainLpips(net=net_type).to(device=img1.device, dtype=img1.dtype)
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/image/lpip.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class LearnedPerceptualImagePatchSimilarity(Metric):
>>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> lpips(img1, img2)
tensor(0.1046, grad_fn=<SqueezeBackward0>)
tensor(0.1046)
"""

Expand Down
12 changes: 12 additions & 0 deletions tests/unittests/image/test_lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,15 @@ def test_error_on_wrong_update(inp1, inp2):
metric = LearnedPerceptualImagePatchSimilarity()
with pytest.raises(ValueError, match="Expected both input arguments to be normalized tensors .*"):
metric(inp1, inp2)


def test_check_for_backprop():
"""Check that by default the metric supports propagation of gradients, but does not update its parameters."""
metric = LearnedPerceptualImagePatchSimilarity()
assert not metric.net.lin0.model[1].weight.requires_grad
preds, target = _inputs.img1[0], _inputs.img2[0]
preds.requires_grad = True
loss = metric(preds, target)
assert loss.requires_grad
loss.backward()
assert metric.net.lin0.model[1].weight.grad is None

0 comments on commit 8a98e46

Please sign in to comment.