diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d73cba87cd..87a70b4012f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Added `CohenKappa` ([#69](https://github.com/PyTorchLightning/metrics/pull/69)) * Added `MatthewsCorrcoef` ([#98](https://github.com/PyTorchLightning/metrics/pull/98)) * Added `PearsonCorrcoef` ([#157](https://github.com/PyTorchLightning/metrics/pull/157)) + * Added `SpearmanCorrcoef` ([#158](https://github.com/PyTorchLightning/metrics/pull/158)) * Added `Hinge` ([#120](https://github.com/PyTorchLightning/metrics/pull/120)) - Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110)) - Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114)) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 23811ab0816..3fc08d56de3 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -210,19 +210,27 @@ psnr [func] :noindex: -ssim [func] -~~~~~~~~~~~ +r2score [func] +~~~~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.ssim +.. autofunction:: torchmetrics.functional.r2score :noindex: -r2score [func] -~~~~~~~~~~~~~~ +spearman_corrcoef [func] +~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.r2score +.. autofunction:: torchmetrics.functional.spearman_corrcoef :noindex: + +ssim [func] +~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.ssim + :noindex: + + *** NLP *** diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 52e3736e4c9..979846785ec 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -264,20 +264,28 @@ PSNR :noindex: -SSIM -~~~~ +R2Score +~~~~~~~ -.. autoclass:: torchmetrics.SSIM +.. 
autoclass:: torchmetrics.R2Score :noindex: -R2Score -~~~~~~~ +SpearmanCorrcoef +~~~~~~~~~~~~~~~~ -.. autoclass:: torchmetrics.R2Score +.. autoclass:: torchmetrics.SpearmanCorrcoef :noindex: +SSIM +~~~~ + +.. autoclass:: torchmetrics.SSIM + :noindex: + + + ********* Retrieval ********* diff --git a/tests/regression/test_spearman.py b/tests/regression/test_spearman.py new file mode 100644 index 00000000000..45e64f65e8c --- /dev/null +++ b/tests/regression/test_spearman.py @@ -0,0 +1,101 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Seed once at import time so the random fixtures below are reproducible.
# The order of the rand/randn calls matters: reordering them changes the generated data.
seed_all(42)

# Simple container pairing predictions with their targets.
Input = namedtuple('Input', ["preds", "target"])

# Scores drawn uniformly from [0, 1).
_single_target_inputs1 = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)

# Scores drawn from a standard normal distribution.
_single_target_inputs2 = Input(
    preds=torch.randn(NUM_BATCHES, BATCH_SIZE),
    target=torch.randn(NUM_BATCHES, BATCH_SIZE),
)


@pytest.mark.parametrize(
    "preds, target", [
        (_single_target_inputs1.preds, _single_target_inputs1.target),
        (_single_target_inputs2.preds, _single_target_inputs2.target),
    ]
)
def test_ranking(preds, target):
    """Check that ``_rank_data`` agrees element-wise with ``scipy.stats.rankdata`` for every batch.

    NOTE(review): the fixtures are continuous random floats, so repeated values are very unlikely
    and the tie-averaging branch of ``_rank_data`` is presumably never exercised here - consider
    adding an input that contains duplicates.
    """
    for p, t in zip(preds, target):
        # reference ranks from scipy, computed batch by batch
        scipy_ranking = [rankdata(p.numpy()), rankdata(t.numpy())]
        tm_ranking = [_rank_data(p), _rank_data(t)]
        assert (torch.tensor(scipy_ranking[0]) == tm_ranking[0]).all()
        assert (torch.tensor(scipy_ranking[1]) == tm_ranking[1]).all()


def _sk_metric(preds, target):
    """Reference implementation: scipy's ``spearmanr`` on the flattened inputs (correlation only)."""
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()
    # spearmanr returns (correlation, p-value); only the correlation is compared
    return spearmanr(sk_target, sk_preds)[0]


@pytest.mark.parametrize(
    "preds, target", [
        (_single_target_inputs1.preds, _single_target_inputs1.target),
        (_single_target_inputs2.preds, _single_target_inputs2.target),
    ]
)
class TestSpearmanCorrcoef(MetricTester):
    """Runs the shared ``MetricTester`` checks for the class and functional Spearman metric."""

    # presumably the absolute tolerance MetricTester uses when comparing with the reference - TODO confirm
    atol = 1e-2

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_spearman_corrcoef(self, preds, target, ddp, dist_sync_on_step):
        """Compare the ``SpearmanCorrcoef`` module against ``_sk_metric`` with and without ddp."""
        self.run_class_metric_test(
            ddp,
            preds,
            target,
            SpearmanCorrcoef,
            _sk_metric,
            dist_sync_on_step,
        )

    def test_spearman_corrcoef_functional(self, preds, target):
        """Compare the functional ``spearman_corrcoef`` against ``_sk_metric``."""
        self.run_functional_metric_test(preds, target, spearman_corrcoef, _sk_metric)

    # Spearman half + cpu does not work due to missing support in torch.arange
    @pytest.mark.xfail(reason="Spearman metric does not support cpu + half precision")
    def test_spearman_corrcoef_half_cpu(self, preds, target):
        """Half precision on cpu; marked as an expected failure (see the comment above)."""
        self.run_precision_test_cpu(preds, target, SpearmanCorrcoef, spearman_corrcoef)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
    def test_spearman_corrcoef_half_gpu(self, preds, target):
        """Half precision on gpu; skipped when cuda is not available."""
        self.run_precision_test_gpu(preds, target, SpearmanCorrcoef, spearman_corrcoef)


def test_error_on_different_shape():
    """Check the errors raised for mismatching shapes and for inputs that are not 1 dimensional."""
    metric = SpearmanCorrcoef()
    # same dtype but different lengths
    with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
        metric(torch.randn(100, ), torch.randn(50, ))

    # same shape, but 2 dimensional
    with pytest.raises(ValueError, match='Expected both predictions and target to be 1 dimensional tensors.'):
        metric(torch.randn(100, 2), torch.randn(100, 2))
def _find_repeats(data: Tensor) -> Tensor:
    """Return the values that occur more than once in ``data``.

    Args:
        data: tensor whose values are inspected

    Returns:
        1d tensor containing every repeated value exactly once, sorted in ascending order
    """
    # torch.unique sorts by default and reports how often each distinct value occurs
    unique, counts = torch.unique(data.detach(), return_counts=True)
    return unique[counts > 1]


def _rank_data(data: Tensor) -> Tensor:
    """Calculate the rank of each element of a 1d tensor.

    The rank of an element is its index in the sorted tensor (starting from 1). Elements that share
    the same value are all assigned the mean of the ranks they occupy.

    Adopted from:
        https://github.com/scipy/scipy/blob/v1.6.2/scipy/stats/stats.py#L4140-L4303

    Args:
        data: 1d tensor of values to rank

    Returns:
        Tensor of ranks with the same shape as ``data``. The dtype of ``data`` is kept for floating
        point inputs; other inputs are ranked in the default float dtype.
    """
    n = data.numel()
    # averaged ranks of tied values are fractional, so non-floating inputs need a float rank tensor
    rank_dtype = data.dtype if data.is_floating_point() else torch.get_default_dtype()
    rank = torch.empty_like(data, dtype=rank_dtype)
    idx = data.argsort()
    rank[idx] = torch.arange(1, n + 1, dtype=rank_dtype, device=data.device)

    for value in _find_repeats(data):
        # ties must be selected by *value*: comparing the repeated value with `rank` would pick
        # whichever element happens to hold that rank number and leave the real ties untouched
        tied = data == value
        rank[tied] = rank[tied].mean()
    return rank


def _spearman_corrcoef_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
    """Validate the inputs of the spearman correlation and return them unchanged.

    Args:
        preds: estimated scores
        target: ground truth scores

    Returns:
        The validated ``(preds, target)`` pair

    Raises:
        TypeError: if ``preds`` and ``target`` have different dtypes
        ValueError: if ``preds`` or ``target`` has more than one dimension

    Mismatching shapes are rejected by ``_check_same_shape``.
    """
    if preds.dtype != target.dtype:
        raise TypeError(
            "Expected `preds` and `target` to have the same data type."
            f" Got pred: {preds.dtype} and target: {target.dtype}."
        )
    _check_same_shape(preds, target)

    if preds.ndim > 1 or target.ndim > 1:
        raise ValueError('Expected both predictions and target to be 1 dimensional tensors.')

    return preds, target


def _spearman_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor:
    """Compute the pearson correlation of the ranks of ``preds`` and ``target``.

    Args:
        preds: 1d tensor of estimated scores
        target: 1d tensor of ground truth scores
        eps: small constant added to the denominator to avoid a division by zero

    Returns:
        Scalar tensor with the correlation, clamped to ``[-1, 1]``
    """
    preds = _rank_data(preds)
    target = _rank_data(target)

    preds_diff = preds - preds.mean()
    target_diff = target - target.mean()

    cov = (preds_diff * target_diff).mean()
    preds_std = torch.sqrt((preds_diff * preds_diff).mean())
    target_std = torch.sqrt((target_diff * target_diff).mean())

    corrcoef = cov / (preds_std * target_std + eps)
    # eps and rounding can push the value marginally outside the valid range
    return torch.clamp(corrcoef, -1.0, 1.0)


def spearman_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
    r"""
    Computes `spearmans rank correlation coefficient
    <https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient>`_:

    .. math::
        r_s = \frac{cov(rg_x, rg_y)}{\sigma_{rg_x} * \sigma_{rg_y}}

    where :math:`rg_x` and :math:`rg_y` are the rank associated to the variables x and y. Spearmans correlations
    coefficient corresponds to the standard pearsons correlation coefficient calculated on the rank variables.

    Args:
        preds: estimated scores
        target: ground truth scores

    Returns:
        Scalar tensor with the correlation coefficient

    Raises:
        TypeError: if ``preds`` and ``target`` have different dtypes
        ValueError: if ``preds`` or ``target`` has more than one dimension

    Example:
        >>> from torchmetrics.functional import spearman_corrcoef
        >>> target = torch.tensor([3, -0.5, 2, 7])
        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
        >>> spearman_corrcoef(preds, target)
        tensor(1.0000)

    """
    preds, target = _spearman_corrcoef_update(preds, target)
    return _spearman_corrcoef_compute(preds, target)
class SpearmanCorrcoef(Metric):
    r"""
    Computes `spearmans rank correlation coefficient
    <https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient>`_.

    .. math::
        r_s = \frac{cov(rg_x, rg_y)}{\sigma_{rg_x} * \sigma_{rg_y}}

    where :math:`rg_x` and :math:`rg_y` are the rank associated to the variables x and y. Spearmans correlations
    coefficient corresponds to the standard pearsons correlation coefficient calculated on the rank variables.

    All predictions and targets passed to ``update()`` are stored and only ranked in ``compute()``.

    Args:
        compute_on_step:
            Forward only calls ``update()`` and return None if this is set to False. default: True
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: False
        process_group:
            Specify the process group on which synchronization is called. default: None (which selects the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather

    Example:
        >>> from torchmetrics import SpearmanCorrcoef
        >>> target = torch.tensor([3, -0.5, 2, 7])
        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
        >>> spearman = SpearmanCorrcoef()
        >>> spearman(preds, target)
        tensor(1.0000)
    """

    def __init__(
        self,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ) -> None:
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        # every update() call appends to the states below, so warn about the memory this can take
        rank_zero_warn(
            'Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer.'
            ' For large datasets, this may lead to large memory footprint.'
        )

        # list states holding one tensor per update() call; they are concatenated in compute()
        self.add_state("preds", default=[], dist_reduce_fx=None)
        self.add_state("target", default=[], dist_reduce_fx=None)

    def update(self, preds: Tensor, target: Tensor) -> None:
        """
        Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        # validates dtype, shape and dimensionality before the tensors are stored
        preds, target = _spearman_corrcoef_update(preds, target)
        self.preds.append(preds)
        self.target.append(target)

    def compute(self) -> Tensor:
        """
        Computes spearmans correlation coefficient over all stored predictions and targets
        """
        # ranks are only meaningful over the complete data, hence the concatenation of all batches
        preds = torch.cat(self.preds, dim=0)
        target = torch.cat(self.target, dim=0)
        return _spearman_corrcoef_compute(preds, target)