From a1ab373350e80971fc0a5f5e482bdb129557243f Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sat, 25 Nov 2023 15:41:09 +0100 Subject: [PATCH] Fix device and dtype for `LPIPS` functional metric (#2234) (cherry picked from commit a57dfae0c42702a0a537a681fe18c0e2b8622166) --- CHANGELOG.md | 3 +++ src/torchmetrics/functional/image/lpips.py | 2 +- tests/unittests/image/test_lpips.py | 11 +++++++++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 82d28b25dbe..21eb389299b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed numerical stability issue in `UniversalImageQualityIndex` metric ([#2222](https://github.com/Lightning-AI/torchmetrics/pull/2222)) +- Fix device and dtype for `LearnedPerceptualImagePatchSimilarity` functional metric ([#2234](https://github.com/Lightning-AI/torchmetrics/pull/2234)) + + ## [1.2.0] - 2023-09-22 ### Added diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index 1c6e1b58906..63a708969c0 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -426,6 +426,6 @@ def learned_perceptual_image_patch_similarity( tensor(0.1008, grad_fn=) """ - net = _NoTrainLpips(net=net_type) + net = _NoTrainLpips(net=net_type).to(device=img1.device, dtype=img1.dtype) loss, total = _lpips_update(img1, img2, net, normalize) return _lpips_compute(loss.sum(), total, reduction) diff --git a/tests/unittests/image/test_lpips.py b/tests/unittests/image/test_lpips.py index 8c7170626f5..3ca19e7120d 100644 --- a/tests/unittests/image/test_lpips.py +++ b/tests/unittests/image/test_lpips.py @@ -18,6 +18,7 @@ import torch from lpips import LPIPS as LPIPS_reference # noqa: N811 from torch import Tensor +from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCH_GREATER_EQUAL_1_9 @@ -68,6 +69,16 @@ def test_lpips(self, net_type, ddp): metric_args={"net_type": net_type}, ) + def test_lpips_functional(self): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=_inputs.img1, + target=_inputs.img2, + metric_functional=learned_perceptual_image_patch_similarity, + reference_metric=partial(_compare_fn, net_type="alex"), + metric_args={"net_type": "alex"}, + ) + def test_lpips_differentiability(self): """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" self.run_differentiability_test(