Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new metrics: SAM #885

Merged
merged 36 commits into from
Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
f0ea2e3
Added new image metric - SAM
vumichien Mar 11, 2022
ab82404
Added new image metric - SAM
vumichien Mar 11, 2022
925caa1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2022
74a3d31
revise commit indent
vumichien Mar 11, 2022
440632c
Merge remote-tracking branch 'origin/feature/sam' into feature/sam
vumichien Mar 11, 2022
02700c2
revise library
vumichien Mar 11, 2022
d4fb716
revise docs string
vumichien Mar 11, 2022
9a6d02a
Merge branch 'master' into feature/sam
Borda Mar 11, 2022
0c95581
update docs ang changelog
vumichien Mar 11, 2022
c066b2e
revise conflict with master
vumichien Mar 12, 2022
9c63be6
revise compute correctly spectral angle, change functional name
vumichien Mar 13, 2022
d5cdf3c
Apply suggestions from code review
Borda Mar 18, 2022
24c7d77
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2022
5167ac4
Merge branch 'master' into feature/sam
mergify[bot] Mar 19, 2022
4c2d55a
wording
Borda Mar 19, 2022
9f1e17a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2022
cad085a
Merge branch 'master' into feature/sam
mergify[bot] Mar 19, 2022
ff6026d
add citation
vumichien Mar 20, 2022
197c362
Merge branch 'master' into feature/sam
mergify[bot] Mar 20, 2022
73096f2
fix move dim with lower version of pytorch
vumichien Mar 20, 2022
ac63809
fix old PT
vumichien Mar 20, 2022
b18abb4
fix old PT torch clip
vumichien Mar 20, 2022
608f056
Update torchmetrics/image/sam.py
vumichien Mar 21, 2022
5e5e455
Update torchmetrics/functional/image/sam.py
vumichien Mar 21, 2022
8ee687d
Update torchmetrics/image/sam.py
vumichien Mar 21, 2022
0e6e49b
Update torchmetrics/image/sam.py
vumichien Mar 21, 2022
bc9dc03
move reference test function, change docs
vumichien Mar 21, 2022
743725c
resolve conflict with master
vumichien Mar 21, 2022
60aa750
Merge branch 'master' into feature/sam
Borda Mar 21, 2022
0f4b419
fix
vumichien Mar 21, 2022
768fd04
Merge remote-tracking branch 'origin/feature/sam' into feature/sam
vumichien Mar 21, 2022
d9cc65a
fix changelog
SkafteNicki Mar 21, 2022
bfe12fe
fix docs
vumichien Mar 21, 2022
ac1b9f8
fix mypy
SkafteNicki Mar 21, 2022
9aafd37
Merge branch 'feature/sam' of https://github.com/vumichien/metrics in…
SkafteNicki Mar 21, 2022
bb917de
fix mypy
SkafteNicki Mar 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added new image metric `SpectralAngleMapper` ([#885](https://github.com/PyTorchLightning/metrics/pull/885))


- Added `CoverageError` to classification metrics ([#787](https://github.com/PyTorchLightning/metrics/pull/787))

Expand Down
8 changes: 8 additions & 0 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,14 @@ peak_signal_noise_ratio [func]
.. autofunction:: torchmetrics.functional.peak_signal_noise_ratio
:noindex:


spectral_angle_mapper [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.spectral_angle_mapper
:noindex:


universal_image_quality_index [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Borda marked this conversation as resolved.
Show resolved Hide resolved

Expand Down
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,12 @@ PeakSignalNoiseRatio
.. autoclass:: torchmetrics.PeakSignalNoiseRatio
:noindex:

SpectralAngleMapper
~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.SpectralAngleMapper
:noindex:

StructuralSimilarityIndexMeasure
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
15 changes: 15 additions & 0 deletions tests/helpers/reference_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics._regression import _check_reg_targets
from sklearn.utils import assert_all_finite, check_consistent_length, column_or_1d

Expand Down Expand Up @@ -198,3 +200,16 @@ def _calibration_error(
loss += np.sum(np.nan_to_num(debias))
loss = np.sqrt(max(loss, 0.0))
return loss


def _sk_sam(preds, target, reduction):
similarity = F.cosine_similarity(preds, target)
sam_score = torch.clamp(similarity, -1, 1).acos()
# reduction
if reduction == "sum":
to_return = torch.sum(sam_score)
elif reduction == "elementwise_mean":
to_return = torch.mean(sam_score)
else:
to_return = sam_score
return to_return
106 changes: 106 additions & 0 deletions tests/image/test_sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import torch

from tests.helpers import seed_all
from tests.helpers.reference_metrics import _sk_sam
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional.image.sam import spectral_angle_mapper
from torchmetrics.image.sam import SpectralAngleMapper

seed_all(42)

Input = namedtuple("Input", ["preds", "target"])

_inputs = []
for size, channel, dtype in [
(12, 3, torch.float),
(13, 3, torch.float32),
(14, 3, torch.double),
(15, 3, torch.float64),
]:
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
target = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
_inputs.append(Input(preds=preds, target=target))


@pytest.mark.parametrize("reduction", ["sum", "elementwise_mean"])
@pytest.mark.parametrize(
"preds, target",
[(i.preds, i.target) for i in _inputs],
)
class TestSpectralAngleMapper(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_sam(self, reduction, preds, target, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
SpectralAngleMapper,
partial(_sk_sam, reduction=reduction),
dist_sync_on_step,
metric_args=dict(reduction=reduction),
)

def test_sam_functional(self, reduction, preds, target):
self.run_functional_metric_test(
preds,
target,
spectral_angle_mapper,
partial(_sk_sam, reduction=reduction),
metric_args=dict(reduction=reduction),
)

# SAM half + cpu does not work due to missing support in torch.log
@pytest.mark.xfail(reason="SAM metric does not support cpu + half precision")
def test_sam_half_cpu(self, reduction, preds, target):
self.run_precision_test_cpu(
preds,
target,
SpectralAngleMapper,
spectral_angle_mapper,
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_sam_half_gpu(self, reduction, preds, target):
self.run_precision_test_gpu(preds, target, SpectralAngleMapper, spectral_angle_mapper)


def test_error_on_different_shape(metric_class=SpectralAngleMapper):
metric = metric_class()
with pytest.raises(RuntimeError):
metric(torch.randn([1, 3, 16, 16]), torch.randn([1, 1, 16, 16]))


def test_error_on_invalid_shape(metric_class=SpectralAngleMapper):
metric = metric_class()
with pytest.raises(ValueError):
metric(torch.randn([3, 16, 16]), torch.randn([3, 16, 16]))


def test_error_on_invalid_type(metric_class=SpectralAngleMapper):
metric = metric_class()
with pytest.raises(TypeError):
metric(torch.randn([3, 16, 16]), torch.randn([3, 16, 16], dtype=torch.float64))


def test_error_on_grayscale_image(metric_class=SpectralAngleMapper):
metric = metric_class()
with pytest.raises(ValueError):
metric(torch.randn([16, 1, 16, 16]), torch.randn([16, 1, 16, 16]))
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from torchmetrics.image import ( # noqa: E402
MultiScaleStructuralSimilarityIndexMeasure,
PeakSignalNoiseRatio,
SpectralAngleMapper,
StructuralSimilarityIndexMeasure,
UniversalImageQualityIndex,
)
Expand Down Expand Up @@ -167,6 +168,7 @@
"SignalNoiseRatio",
"SpearmanCorrCoef",
"Specificity",
"SpectralAngleMapper",
"SQuAD",
"StructuralSimilarityIndexMeasure",
"StatScores",
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from torchmetrics.functional.classification.stat_scores import stat_scores
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
from torchmetrics.functional.image.sam import spectral_angle_mapper
from torchmetrics.functional.image.ssim import (
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
Expand Down Expand Up @@ -152,6 +153,7 @@
"symmetric_mean_absolute_percentage_error",
"translation_edit_rate",
"universal_image_quality_index",
"spectral_angle_mapper",
"word_error_rate",
"char_error_rate",
"match_error_rate",
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.functional.image.gradients import image_gradients # noqa: F401
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio # noqa: F401
from torchmetrics.functional.image.sam import spectral_angle_mapper # noqa: F401
from torchmetrics.functional.image.ssim import ( # noqa: F401
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
Expand Down
119 changes: 119 additions & 0 deletions torchmetrics/functional/image/sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.distributed import reduce


def _sam_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""Updates and returns variables required to compute Spectral Angle Mapper. Checks for same shape and type of
the input tensors.

Args:
preds: Predicted tensor
target: Ground truth tensor
"""

if preds.dtype != target.dtype:
raise TypeError(
"Expected `preds` and `target` to have the same data type."
f" Got preds: {preds.dtype} and target: {target.dtype}."
)
_check_same_shape(preds, target)
if len(preds.shape) != 4:
raise ValueError(
"Expected `preds` and `target` to have BxCxHxW shape."
f" Got preds: {preds.shape} and target: {target.shape}."
)
if (preds.shape[1] <= 1) or (target.shape[1] <= 1):
raise ValueError(
"Expected channel dimension of `preds` and `target` to be larger than 1."
f" Got preds: {preds.shape[1]} and target: {target.shape[1]}."
)
return preds, target


def _sam_compute(
preds: Tensor,
target: Tensor,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
"""Computes Spectral Angle Mapper.

Args:
preds: estimated image
target: ground truth image
reduction: a method to reduce metric score over labels.

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'`` or ``None``: no reduction will be applied

Example:
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> preds, target = _sam_update(preds, target)
>>> _sam_compute(preds, target)
tensor(0.5943)
"""
dot_product = (preds * target).sum(dim=1)
preds_norm = preds.norm(dim=1)
target_norm = target.norm(dim=1)
sam_score = torch.clamp(dot_product / (preds_norm * target_norm), -1, 1).acos()
return reduce(sam_score, reduction)


def spectral_angle_mapper(
preds: Tensor,
target: Tensor,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
"""Universal Spectral Angle Mapper.

Args:
preds: estimated image
target: ground truth image
reduction: a method to reduce metric score over labels.

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'`` or ``None``: no reduction will be applied

Return:
Tensor with Spectral Angle Mapper score

Raises:
TypeError:
If ``preds`` and ``target`` don't have the same data type.
ValueError:
If ``preds`` and ``target`` don't have ``BxCxHxW shape``.

Example:
>>> from torchmetrics.functional import spectral_angle_mapper
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> spectral_angle_mapper(preds, target)
tensor(0.5943)

References: Roberta H. Yuhas, Alexander F. H. Goetz and Joe W. Boardman, "Discrimination among semi-arid
landscape endmembers using the Spectral Angle Mapper (SAM) algorithm" in PL, Summaries of the Third Annual JPL
Airborne Geoscience Workshop, vol. 1, June 1, 1992.
"""
preds, target = _sam_update(preds, target)
return _sam_compute(preds, target, reduction)
1 change: 1 addition & 0 deletions torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.image.psnr import PeakSignalNoiseRatio # noqa: F401
from torchmetrics.image.sam import SpectralAngleMapper # noqa: F401
from torchmetrics.image.ssim import ( # noqa: F401
MultiScaleStructuralSimilarityIndexMeasure,
StructuralSimilarityIndexMeasure,
Expand Down
Loading