Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix backprop in LPIPS #2326

Merged
merged 14 commits into from
Feb 12, 2024
Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed how backprop is handled in `LPIPS` metric ([#2326](https://github.com/Lightning-AI/torchmetrics/pull/2326))


- Fixed `MultitaskWrapper` not being able to be logged in lightning when using metric collections ([#2349](https://github.com/Lightning-AI/torchmetrics/pull/2349))


Expand Down
8 changes: 6 additions & 2 deletions src/torchmetrics/functional/image/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def __init__(
net: Indicate backbone to use, choose between ['alex','vgg','squeeze']
spatial: If input should be spatial averaged
pnet_rand: If backbone should be random or use imagenet pre-trained weights
pnet_tune: If backprop should be enabled
pnet_tune: If backprop should be enabled for both backbone and linear layers
use_dropout: If dropout layers should be added
        model_path: Model path to load pretrained models from
eval_mode: If network should be in evaluation mode
Expand Down Expand Up @@ -327,6 +327,10 @@ def __init__(
if eval_mode:
self.eval()

if not self.pnet_tune:
for param in self.parameters():
param.requires_grad = False

def forward(
self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False
) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]:
Expand Down Expand Up @@ -423,7 +427,7 @@ def learned_perceptual_image_patch_similarity(
>>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> learned_perceptual_image_patch_similarity(img1, img2, net_type='squeeze')
tensor(0.1008, grad_fn=<DivBackward0>)
tensor(0.1008)

"""
net = _NoTrainLpips(net=net_type).to(device=img1.device, dtype=img1.dtype)
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/image/lpip.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class LearnedPerceptualImagePatchSimilarity(Metric):
>>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> lpips(img1, img2)
tensor(0.1046, grad_fn=<SqueezeBackward0>)
tensor(0.1046)

"""

Expand Down
12 changes: 12 additions & 0 deletions tests/unittests/image/test_lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,15 @@ def test_error_on_wrong_update(inp1, inp2):
metric = LearnedPerceptualImagePatchSimilarity()
with pytest.raises(ValueError, match="Expected both input arguments to be normalized tensors .*"):
metric(inp1, inp2)


def test_check_for_backprop():
    """Check that by default the metric supports propagation of gradients, but does not update its parameters."""
    lpips = LearnedPerceptualImagePatchSimilarity()
    # Pick one of the learned linear-layer weights as a representative frozen parameter.
    frozen_weight = lpips.net.lin0.model[1].weight
    assert not frozen_weight.requires_grad
    img1, img2 = _inputs.img1[0], _inputs.img2[0]
    # Gradients should flow back to the *inputs* even though the metric's own weights are frozen.
    img1.requires_grad = True
    score = lpips(img1, img2)
    assert score.requires_grad
    score.backward()
    # Frozen parameters must not accumulate gradients during backprop.
    assert frozen_weight.grad is None
Loading