-
Notifications
You must be signed in to change notification settings - Fork 411
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'bugfix/broken_cloning' of https://github.com/PyTorchLig…
…htning/metrics into bugfix/broken_cloning
- Loading branch information
Showing
12 changed files
with
601 additions
and
258 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
.. customcarditem:: | ||
:header: Total Variation (TV) | ||
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg | ||
:tags: Image | ||
|
||
.. include:: ../links.rst | ||
|
||
#################### | ||
Total Variation (TV) | ||
#################### | ||
|
||
Module Interface | ||
________________ | ||
|
||
.. autoclass:: torchmetrics.TotalVariation | ||
:noindex: | ||
|
||
Functional Interface | ||
____________________ | ||
|
||
.. autofunction:: torchmetrics.functional.total_variation | ||
:noindex: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
scikit-image>0.17.1 | ||
kornia | ||
pytorch-msssim==0.2.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
__version__ = "0.10.0" | ||
__version__ = "0.11.0dev" | ||
__author__ = "Lightning-AI et al." | ||
__author_email__ = "[email protected]" | ||
__license__ = "Apache-2.0" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Tuple | ||
|
||
from torch import Tensor | ||
from typing_extensions import Literal | ||
|
||
|
||
def _total_variation_update(img: Tensor) -> Tuple[Tensor, int]: | ||
"""Computes total variation statistics on current batch.""" | ||
if img.ndim != 4: | ||
raise RuntimeError(f"Expected input `img` to be an 4D tensor, but got {img.shape}") | ||
diff1 = img[..., 1:, :] - img[..., :-1, :] | ||
diff2 = img[..., :, 1:] - img[..., :, :-1] | ||
|
||
res1 = diff1.abs().sum([1, 2, 3]) | ||
res2 = diff2.abs().sum([1, 2, 3]) | ||
score = res1 + res2 | ||
return score, img.shape[0] | ||
|
||
|
||
def _total_variation_compute( | ||
score: Tensor, num_elements: int, reduction: Literal["mean", "sum", "none", None] | ||
) -> Tensor: | ||
"""Compute final total variation score.""" | ||
if reduction == "mean": | ||
return score.sum() / num_elements | ||
elif reduction == "sum": | ||
return score.sum() | ||
elif reduction is None or reduction == "none": | ||
return score | ||
else: | ||
raise ValueError("Expected argument `reduction` to either be 'sum', 'mean', 'none' or None") | ||
|
||
|
||
def total_variation(img: Tensor, reduction: Literal["mean", "sum", "none", None] = "sum") -> Tensor: | ||
"""Computes total variation loss. | ||
Args: | ||
img: A `Tensor` of shape `(N, C, H, W)` consisting of images | ||
reduction: a method to reduce metric score over samples. | ||
- ``'mean'``: takes the mean over samples | ||
- ``'sum'``: takes the sum over samples | ||
- ``None`` or ``'none'``: return the score per sample | ||
Returns: | ||
A loss scalar value containing the total variation | ||
Raises: | ||
ValueError: | ||
If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'`` or ``None`` | ||
RuntimeError: | ||
If ``img`` is not 4D tensor | ||
Example: | ||
>>> import torch | ||
>>> from torchmetrics.functional import total_variation | ||
>>> _ = torch.manual_seed(42) | ||
>>> img = torch.rand(5, 3, 28, 28) | ||
>>> total_variation(img) | ||
tensor(7546.8018) | ||
""" | ||
# code adapted from: | ||
# from kornia.losses import total_variation as kornia_total_variation | ||
score, num_elements = _total_variation_update(img) | ||
return _total_variation_compute(score, num_elements, reduction) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any | ||
|
||
import torch | ||
from torch import Tensor, tensor | ||
from typing_extensions import Literal | ||
|
||
from torchmetrics.functional.image.tv import _total_variation_compute, _total_variation_update | ||
from torchmetrics.metric import Metric | ||
from torchmetrics.utilities.data import dim_zero_cat | ||
|
||
|
||
class TotalVariation(Metric): | ||
"""Computes Total Variation loss (`TV`_). | ||
Args: | ||
reduction: a method to reduce metric score over samples | ||
- ``'mean'``: takes the mean over samples | ||
- ``'sum'``: takes the sum over samples | ||
- ``None`` or ``'none'``: return the score per sample | ||
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. | ||
Raises: | ||
ValueError: | ||
If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'`` or ``None`` | ||
Example: | ||
>>> import torch | ||
>>> from torchmetrics import TotalVariation | ||
>>> _ = torch.manual_seed(42) | ||
>>> tv = TotalVariation() | ||
>>> img = torch.rand(5, 3, 28, 28) | ||
>>> tv(img) | ||
tensor(7546.8018) | ||
""" | ||
|
||
full_state_update: bool = False | ||
is_differentiable: bool = True | ||
higher_is_better: bool = False | ||
|
||
def __init__(self, reduction: Literal["mean", "sum", "none", None] = "sum", **kwargs: Any) -> None: | ||
super().__init__(**kwargs) | ||
if reduction is not None and reduction not in ("sum", "mean", "none"): | ||
raise ValueError("Expected argument `reduction` to either be 'sum', 'mean', 'none' or None") | ||
self.reduction = reduction | ||
|
||
if self.reduction is None or self.reduction == "none": | ||
self.add_state("score", default=[], dist_reduce_fx="cat") | ||
else: | ||
self.add_state("score", default=tensor(0, dtype=torch.float), dist_reduce_fx="sum") | ||
self.add_state("num_elements", default=tensor(0, dtype=torch.int), dist_reduce_fx="sum") | ||
|
||
def update(self, img: Tensor) -> None: # type: ignore | ||
"""Update current score with batch of input images. | ||
Args: | ||
img: A `Tensor` of shape `(N, C, H, W)` consisting of images | ||
""" | ||
score, num_elements = _total_variation_update(img) | ||
if self.reduction is None or self.reduction == "none": | ||
self.score.append(score) | ||
else: | ||
self.score += score.sum() | ||
self.num_elements += num_elements | ||
|
||
def compute(self) -> Tensor: | ||
"""Compute final total variation.""" | ||
if self.reduction is None or self.reduction == "none": | ||
score = dim_zero_cat(self.score) | ||
else: | ||
score = self.score | ||
return _total_variation_compute(score, self.num_elements, self.reduction) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from collections import namedtuple | ||
from functools import partial | ||
|
||
import pytest | ||
import torch | ||
from kornia.losses import total_variation as kornia_total_variation | ||
|
||
from torchmetrics.functional.image.tv import total_variation | ||
from torchmetrics.image.tv import TotalVariation | ||
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8 | ||
from unittests.helpers import seed_all | ||
from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester | ||
|
||
seed_all(42) | ||
|
||
|
||
# add extra argument to make the metric and reference fit into the testing framework | ||
class TotalVariationTester(TotalVariation): | ||
def update(self, img, *args): | ||
super().update(img=img) | ||
|
||
|
||
def total_variaion_tester(preds, target, reduction="mean"): | ||
return total_variation(preds, reduction) | ||
|
||
|
||
def total_variation_kornia_tester(preds, target, reduction): | ||
score = kornia_total_variation(preds).sum(-1) | ||
if reduction == "sum": | ||
return score.sum() | ||
elif reduction == "mean": | ||
return score.mean() | ||
return score | ||
|
||
|
||
# define inputs | ||
Input = namedtuple("Input", ["preds", "target"]) | ||
|
||
_inputs = [] | ||
for size, channel, dtype in [ | ||
(12, 3, torch.float), | ||
(13, 3, torch.float32), | ||
(14, 3, torch.double), | ||
(15, 3, torch.float64), | ||
]: | ||
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) | ||
target = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) | ||
_inputs.append(Input(preds=preds, target=target)) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"preds, target", | ||
[(i.preds, i.target) for i in _inputs], | ||
) | ||
@pytest.mark.parametrize("reduction", ["sum", "mean", None]) | ||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_8, reason="Kornia used as reference requires min PT version") | ||
class TestTotalVariation(MetricTester): | ||
@pytest.mark.parametrize("ddp", [True, False]) | ||
@pytest.mark.parametrize("dist_sync_on_step", [True, False]) | ||
def test_total_variation(self, preds, target, reduction, ddp, dist_sync_on_step): | ||
"""Test modular implementation.""" | ||
if reduction is None and ddp: | ||
pytest.skip("reduction=None and ddp=True runs out of memory on CI hardware, but it does work") | ||
self.run_class_metric_test( | ||
ddp, | ||
preds, | ||
target, | ||
TotalVariationTester, | ||
partial(total_variation_kornia_tester, reduction=reduction), | ||
dist_sync_on_step, | ||
metric_args={"reduction": reduction}, | ||
) | ||
|
||
def test_total_variation_functional(self, preds, target, reduction): | ||
"""Test for functional implementation.""" | ||
self.run_functional_metric_test( | ||
preds, | ||
target, | ||
total_variaion_tester, | ||
partial(total_variation_kornia_tester, reduction=reduction), | ||
metric_args={"reduction": reduction}, | ||
) | ||
|
||
@pytest.mark.skipif( | ||
not _TORCH_GREATER_EQUAL_1_6, reason="half support of core operations on not support before pytorch v1.6" | ||
) | ||
def test_sam_half_cpu(self, preds, target, reduction): | ||
"""Test for half precision on CPU.""" | ||
self.run_precision_test_cpu( | ||
preds, | ||
target, | ||
TotalVariationTester, | ||
total_variaion_tester, | ||
) | ||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") | ||
def test_sam_half_gpu(self, preds, target, reduction): | ||
"""Test for half precision on GPU.""" | ||
self.run_precision_test_gpu(preds, target, TotalVariationTester, total_variaion_tester) | ||
|
||
|
||
def test_correct_args(): | ||
"""that that arguments have the right type and sizes.""" | ||
with pytest.raises(ValueError, match="Expected argument `reduction`.*"): | ||
_ = TotalVariation(reduction="diff") | ||
|
||
with pytest.raises(RuntimeError, match="Expected input `img` to.*"): | ||
_ = total_variation(torch.randn(1, 2, 3)) |