From 75d180889a5cea87b11651520aea580a217b9812 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 5 Apr 2021 16:01:40 +0200 Subject: [PATCH 01/13] ranking --- .../functional/regression/spearman.py | 53 +++++++++ torchmetrics/regression/spearman.py | 106 ++++++++++++++++++ 2 files changed, 159 insertions(+) create mode 100644 torchmetrics/functional/regression/spearman.py create mode 100644 torchmetrics/regression/spearman.py diff --git a/torchmetrics/functional/regression/spearman.py b/torchmetrics/functional/regression/spearman.py new file mode 100644 index 00000000000..3a61ac17b69 --- /dev/null +++ b/torchmetrics/functional/regression/spearman.py @@ -0,0 +1,53 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torch import Tensor + + +def _find_repeats(data: Tensor): + temp = data.detach().clone() + temp = temp.sort()[0] + + change = torch.cat([torch.tensor([True]), temp[1:] != temp[:-1]]) + unique = temp[change] + change_idx = torch.cat([torch.nonzero(change), torch.tensor([[n]])]).flatten() + freq = change_idx[1:] - change_idx[:-1] + atleast2 = freq > 1 + return unique[atleast2] + +def _rank_data(data: Tensor): + n = data.numel() + rank = torch.empty_like(data) + idx = data.argsort() + rank[idx[:n]] = torch.arange(1, n+1, dtype=torch.float) + + repeats = _find_repeats(data) + for r in repeats: + condition = (data == r).filled(False) + rank[condition] = rank[condition].mean() + return rank + +def _spearman_corrcoef_update(preds: Tensor, target: Tensor): + + +def _spearman_corrcoef_compute(preds: Tensor, target: Tensor): + rank_preds = _rank_data(preds) + rank_target = _rank_data(target) + + cov = ((rank_preds - rank_preds.mean()) * (rank_target - rank_target.mean())).sum() + return cov / (rank_preds.std() * rank_target.std()) + + +def spearman_corrcoef() + + \ No newline at end of file diff --git a/torchmetrics/regression/spearman.py b/torchmetrics/regression/spearman.py new file mode 100644 index 00000000000..2b8f9de884c --- /dev/null +++ b/torchmetrics/regression/spearman.py @@ -0,0 +1,106 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Sequence + +import torch +from torch import Tensor + +from torchmetrics.functional.regression.ssim import _ssim_compute, _ssim_update +from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn + + +class SpearmanCorrCoef(Metric): + """ + Computes `Structual Similarity Index Measure + `_ (SSIM). 
+ + Args: + kernel_size: size of the gaussian kernel (default: (11, 11)) + sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + data_range: Range of the image. If ``None``, it is determined from the image (max - min) + k1: Parameter of SSIM. Default: 0.01 + k2: Parameter of SSIM. Default: 0.03 + + Return: + Tensor with SSIM score + + Example: + >>> from torchmetrics import SSIM + >>> preds = torch.rand([16, 1, 16, 16]) + >>> target = preds * 0.75 + >>> ssim = SSIM() + >>> ssim(preds, target) + tensor(0.9219) + """ + + def __init__( + self, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: str = "elementwise_mean", + data_range: Optional[float] = None, + k1: float = 0.01, + k2: float = 0.03, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + rank_zero_warn( + 'Metric `SSIM` will save all targets and' + ' predictions in buffer. For large datasets this may lead' + ' to large memory footprint.' + ) + + self.add_state("y", default=[], dist_reduce_fx=None) + self.add_state("y_pred", default=[], dist_reduce_fx=None) + self.kernel_size = kernel_size + self.sigma = sigma + self.data_range = data_range + self.k1 = k1 + self.k2 = k2 + self.reduction = reduction + + def update(self, preds: Tensor, target: Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + preds, target = _ssim_update(preds, target) + self.y_pred.append(preds) + self.y.append(target) + + def compute(self): + """ + Computes explained variance over state. + """ + preds = torch.cat(self.y_pred, dim=0) + target = torch.cat(self.y, dim=0) + return _ssim_compute( + preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2 + ) From 46cad73c7edec730bbc11de87d960fb70f402229 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 5 Apr 2021 19:59:25 +0200 Subject: [PATCH 02/13] init files --- tests/regression/test_spearman.py | 119 ++++++++++++++++++ torchmetrics/__init__.py | 1 + torchmetrics/functional/__init__.py | 1 + .../functional/regression/__init__.py | 1 + .../functional/regression/spearman.py | 19 ++- torchmetrics/regression/__init__.py | 1 + torchmetrics/regression/spearman.py | 83 ++++++------ 7 files changed, 175 insertions(+), 50 deletions(-) create mode 100644 tests/regression/test_spearman.py diff --git a/tests/regression/test_spearman.py b/tests/regression/test_spearman.py new file mode 100644 index 00000000000..0f820455f17 --- /dev/null +++ b/tests/regression/test_spearman.py @@ -0,0 +1,119 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple + +import pytest +import torch +from scipy.stats import pearsonr + +from tests.helpers import seed_all +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from torchmetrics.functional.regression.pearson import _update_cov, _update_mean, pearson_corrcoef +from torchmetrics.regression.pearson import PearsonCorrcoef + +seed_all(42) + + +def test_update_functions(tmpdir): + """ Test that updating the estimates are equal to estimating them on all data """ + data = torch.randn(100, 2) + batch1, batch2 = data.chunk(2) + + def _mean_cov(data): + mean = data.mean(0) + diff = data - mean + cov = diff.T @ diff + return mean, cov + + mean_update, cov_update, size_update = torch.zeros(2), torch.zeros(2, 2), torch.zeros(1) + for batch in [batch1, batch2]: + new_mean = _update_mean(mean_update, size_update, batch) + new_cov = _update_cov(cov_update, mean_update, new_mean, batch) + + assert not torch.allclose(new_mean, mean_update), "mean estimate did not update" + assert not torch.allclose(new_cov, cov_update), "covariance estimate did not update" + + size_update += batch.shape[0] + mean_update = new_mean + cov_update = new_cov + + mean, cov = _mean_cov(data) + + assert torch.allclose(mean, mean_update), "updated mean does not correspond to mean of all data" + assert torch.allclose(cov, cov_update), "updated covariance does not correspond to covariance of all data" + + +Input = namedtuple('Input', ["preds", "target"]) + +_single_target_inputs1 = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.rand(NUM_BATCHES, BATCH_SIZE), +) + +_single_target_inputs2 = Input( + preds=torch.randn(NUM_BATCHES, BATCH_SIZE), + target=torch.randn(NUM_BATCHES, BATCH_SIZE), +) + + +def _sk_metric(preds, target): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + return pearsonr(sk_target, sk_preds)[0] + + +@pytest.mark.parametrize("preds, target", [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + (_single_target_inputs2.preds, _single_target_inputs2.target), +]) +class TestPearsonCorrcoef(MetricTester): + atol = 1e-4 + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_explained_variance(self, preds, target, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + PearsonCorrcoef, + _sk_metric, + dist_sync_on_step, + ) + + def test_pearson_corrcoef_functional(self, preds, target): + self.run_functional_metric_test( + preds, + target, + pearson_corrcoef, + _sk_metric + ) + + # Pearson half + cpu does not work due to missing support in torch.sqrt + @pytest.mark.xfail(reason="PSNR metric does not support cpu + half precision") + def test_pearson_corrcoef_half_cpu(self, preds, target): + self.run_precision_test_cpu(preds, target, PearsonCorrcoef, pearson_corrcoef) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + def test_pearson_corrcoef_half_gpu(self, preds, target): + self.run_precision_test_gpu(preds, target, PearsonCorrcoef, pearson_corrcoef) + + +def test_error_on_different_shape(): + metric = PearsonCorrcoef() + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) + + with pytest.raises(ValueError, match='Expected both predictions and target to be 1 dimensional tensors.'): + 
metric(torch.randn(100, 2), torch.randn(100, 2)) \ No newline at end of file diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 7e21b7886df..6fd8bd9568a 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -48,6 +48,7 @@ MeanSquaredError, MeanSquaredLogError, R2Score, + SpearmanCorrcoef ) from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR, RetrievalPrecision # noqa: F401 E402 from torchmetrics.wrappers import BootStrapper # noqa: F401 E402 diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 18d90b6d9ce..77c36135b89 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -36,6 +36,7 @@ from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error # noqa: F401 from torchmetrics.functional.regression.psnr import psnr # noqa: F401 from torchmetrics.functional.regression.r2score import r2score # noqa: F401 +from torchmetrics.functional.regression.spearman import spearman_corrcoef # noqa: F401 from torchmetrics.functional.regression.ssim import ssim # noqa: F401 from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401 from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401 diff --git a/torchmetrics/functional/regression/__init__.py b/torchmetrics/functional/regression/__init__.py index 63c2aabb1e2..db72cd45756 100644 --- a/torchmetrics/functional/regression/__init__.py +++ b/torchmetrics/functional/regression/__init__.py @@ -17,4 +17,5 @@ from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error # noqa: F401 from torchmetrics.functional.regression.psnr import psnr # noqa: F401 from torchmetrics.functional.regression.r2score import r2score # noqa: F401 +from torchmetrics.functional.regression.spearman import spearman_corrcoef # noqa: F401 from torchmetrics.functional.regression.ssim import ssim # noqa: F401 diff --git a/torchmetrics/functional/regression/spearman.py b/torchmetrics/functional/regression/spearman.py index 3a61ac17b69..0099cb2f9be 100644 --- a/torchmetrics/functional/regression/spearman.py +++ b/torchmetrics/functional/regression/spearman.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch from torch import Tensor - +from torchmetrics.utilities.checks import _check_same_shape def _find_repeats(data: Tensor): temp = data.detach().clone() @@ -38,7 +39,17 @@ def _rank_data(data: Tensor): return rank def _spearman_corrcoef_update(preds: Tensor, target: Tensor): - + if preds.dtype != target.dtype: + raise TypeError( + "Expected `preds` and `target` to have the same data type." + f" Got pred: {preds.dtype} and target: {target.dtype}." 
+ ) + _check_same_shape(preds, target) + + if preds.ndim > 1 or target.ndim > 1: + raise ValueError('Expected both predictions and target to be 1 dimensional tensors.') + + return preds, target def _spearman_corrcoef_compute(preds: Tensor, target: Tensor): rank_preds = _rank_data(preds) @@ -48,6 +59,8 @@ def _spearman_corrcoef_compute(preds: Tensor, target: Tensor): return cov / (rank_preds.std() * rank_target.std()) -def spearman_corrcoef() +def spearman_corrcoef(preds: Tensor, target: Tensor) -> Tensor: + preds, target = _spearman_corrcoef_update(preds, target) + return _spearman_corrcoef_compute(preds, target) \ No newline at end of file diff --git a/torchmetrics/regression/__init__.py b/torchmetrics/regression/__init__.py index bf2da61095e..99ab467fc04 100644 --- a/torchmetrics/regression/__init__.py +++ b/torchmetrics/regression/__init__.py @@ -17,4 +17,5 @@ from torchmetrics.regression.mean_squared_log_error import MeanSquaredLogError # noqa: F401 from torchmetrics.regression.psnr import PSNR # noqa: F401 from torchmetrics.regression.r2score import R2Score # noqa: F401 +from torchmetrics.regression.spearman import SpearmanCorrcoef # noqa: F401 from torchmetrics.regression.ssim import SSIM # noqa: F401 diff --git a/torchmetrics/regression/spearman.py b/torchmetrics/regression/spearman.py index 2b8f9de884c..0b48edcb4fe 100644 --- a/torchmetrics/regression/spearman.py +++ b/torchmetrics/regression/spearman.py @@ -11,77 +11,68 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence +from typing import Any, Optional, Sequence, Callable import torch from torch import Tensor -from torchmetrics.functional.regression.ssim import _ssim_compute, _ssim_update +from torchmetrics.functional.regression.spearman import _spearman_corrcoef_compute, _spearman_corrcoef_update from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn -class SpearmanCorrCoef(Metric): +class SpearmanCorrcoef(Metric): """ - Computes `Structual Similarity Index Measure - `_ (SSIM). + Computes `spearmans rank correlation coefficient + `_. - Args: - kernel_size: size of the gaussian kernel (default: (11, 11)) - sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - data_range: Range of the image. If ``None``, it is determined from the image (max - min) - k1: Parameter of SSIM. Default: 0.01 - k2: Parameter of SSIM. Default: 0.03 + .. math: + r_s = = \frac{cov(rg_x, rg_y)}{\sigma_{rg_x} * \sigma_{rg_y}} - Return: - Tensor with SSIM score + where rg_x and rg_y are the rank associated to the variables x and y. Spearmans correlations coefficient + corresponds to the standard pearsons correlation coefficient calculated on the rank variables. + Args: + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. 
When ``None``, DDP + will be used to perform the allgather Example: - >>> from torchmetrics import SSIM - >>> preds = torch.rand([16, 1, 16, 16]) - >>> target = preds * 0.75 - >>> ssim = SSIM() - >>> ssim(preds, target) - tensor(0.9219) + >>> from torchmetrics import SpearmanCorrcoef + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> spearman = SpearmanCorrcoef() + >>> spearman(preds, target) + tensor(0.9849) """ def __init__( self, - kernel_size: Sequence[int] = (11, 11), - sigma: Sequence[float] = (1.5, 1.5), - reduction: str = "elementwise_mean", - data_range: Optional[float] = None, - k1: float = 0.01, - k2: float = 0.03, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) rank_zero_warn( - 'Metric `SSIM` will save all targets and' + 'Metric `SpearmanCorrcoef` will save all targets and' ' predictions in buffer. For large datasets this may lead' ' to large memory footprint.' ) - self.add_state("y", default=[], dist_reduce_fx=None) - self.add_state("y_pred", default=[], dist_reduce_fx=None) - self.kernel_size = kernel_size - self.sigma = sigma - self.data_range = data_range - self.k1 = k1 - self.k2 = k2 - self.reduction = reduction + self.add_state("preds", default=[], dist_reduce_fx=None) + self.add_state("target", default=[], dist_reduce_fx=None) def update(self, preds: Tensor, target: Tensor): """ @@ -91,16 +82,14 @@ def update(self, preds: Tensor, target: Tensor): preds: Predictions from model target: Ground truth values """ - preds, target = _ssim_update(preds, target) - self.y_pred.append(preds) - self.y.append(target) + preds, target = _spearman_corrcoef_update(preds, target) + self.preds.append(preds) + self.target.append(target) def compute(self): """ - Computes explained variance over state. + Computes spearmans correlation coefficient """ preds = torch.cat(self.y_pred, dim=0) target = torch.cat(self.y, dim=0) - return _ssim_compute( - preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2 - ) + return _spearman_corrcoef_compute(preds, target) From da6398390e45887158f8c089123d3e3fa4f8154b Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 5 Apr 2021 20:08:15 +0200 Subject: [PATCH 03/13] update --- torchmetrics/functional/regression/spearman.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/torchmetrics/functional/regression/spearman.py b/torchmetrics/functional/regression/spearman.py index 0099cb2f9be..0bedc8eb935 100644 --- a/torchmetrics/functional/regression/spearman.py +++ b/torchmetrics/functional/regression/spearman.py @@ -15,18 +15,28 @@ from torch import Tensor from torchmetrics.utilities.checks import _check_same_shape + def _find_repeats(data: Tensor): + """ find and return values which have repeats i.e. 
the same value are more than once in the tensor """ temp = data.detach().clone() temp = temp.sort()[0] change = torch.cat([torch.tensor([True]), temp[1:] != temp[:-1]]) unique = temp[change] - change_idx = torch.cat([torch.nonzero(change), torch.tensor([[n]])]).flatten() + change_idx = torch.cat([torch.nonzero(change), torch.tensor([[temp.numel()]])]).flatten() freq = change_idx[1:] - change_idx[:-1] atleast2 = freq > 1 return unique[atleast2] + def _rank_data(data: Tensor): + """ Calculate the rank for each element of a tensor. The rank refers to the indices of an element in the + corresponding sorted tensor (starting from 1). Duplicates of the same value will be assigned the mean of + their rank + + Adopted from: + https://github.com/scipy/scipy/blob/v1.6.2/scipy/stats/stats.py#L4140-L4303 + """ n = data.numel() rank = torch.empty_like(data) idx = data.argsort() @@ -60,6 +70,9 @@ def _spearman_corrcoef_compute(preds: Tensor, target: Tensor): def spearman_corrcoef(preds: Tensor, target: Tensor) -> Tensor: + """ + + """ preds, target = _spearman_corrcoef_update(preds, target) return _spearman_corrcoef_compute(preds, target) From e1c3b3ad651de51b496e07ce058053fa074dc7bd Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 7 Apr 2021 21:27:42 +0200 Subject: [PATCH 04/13] nearly working --- CHANGELOG.md | 3 + tests/regression/test_spearman.py | 81 ++++++------------- torchmetrics/__init__.py | 2 +- .../functional/regression/spearman.py | 62 +++++++++----- torchmetrics/regression/spearman.py | 10 +-- 5 files changed, 76 insertions(+), 82 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ab569b5666..5e71b3eb0cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142)) +- Added `SpearmanCorrcoef` metric ([#158](https://github.com/PyTorchLightning/metrics/pull/158)) + + ### Changed - Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68)) diff --git a/tests/regression/test_spearman.py b/tests/regression/test_spearman.py index 0f820455f17..bd99c09c91a 100644 --- a/tests/regression/test_spearman.py +++ b/tests/regression/test_spearman.py @@ -15,45 +15,15 @@ import pytest import torch -from scipy.stats import pearsonr +from scipy.stats import spearmanr from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.functional.regression.pearson import _update_cov, _update_mean, pearson_corrcoef -from torchmetrics.regression.pearson import PearsonCorrcoef +from torchmetrics.functional.regression.spearman import spearman_corrcoef +from torchmetrics.regression.spearman import SpearmanCorrcoef seed_all(42) - -def test_update_functions(tmpdir): - """ Test that updating the estimates are equal to estimating them on all data """ - data = torch.randn(100, 2) - batch1, batch2 = data.chunk(2) - - def _mean_cov(data): - mean = data.mean(0) - diff = data - mean - cov = diff.T @ diff - return mean, cov - - mean_update, cov_update, size_update = torch.zeros(2), torch.zeros(2, 2), torch.zeros(1) - for batch in [batch1, batch2]: - new_mean = _update_mean(mean_update, size_update, batch) - new_cov = _update_cov(cov_update, mean_update, new_mean, batch) - - assert not torch.allclose(new_mean, mean_update), "mean estimate did 
not update" - assert not torch.allclose(new_cov, cov_update), "covariance estimate did not update" - - size_update += batch.shape[0] - mean_update = new_mean - cov_update = new_cov - - mean, cov = _mean_cov(data) - - assert torch.allclose(mean, mean_update), "updated mean does not correspond to mean of all data" - assert torch.allclose(cov, cov_update), "updated covariance does not correspond to covariance of all data" - - Input = namedtuple('Input', ["preds", "target"]) _single_target_inputs1 = Input( @@ -70,50 +40,47 @@ def _mean_cov(data): def _sk_metric(preds, target): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() - return pearsonr(sk_target, sk_preds)[0] + return spearmanr(sk_target, sk_preds)[0] -@pytest.mark.parametrize("preds, target", [ - (_single_target_inputs1.preds, _single_target_inputs1.target), - (_single_target_inputs2.preds, _single_target_inputs2.target), -]) -class TestPearsonCorrcoef(MetricTester): - atol = 1e-4 +@pytest.mark.parametrize( + "preds, target", [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + (_single_target_inputs2.preds, _single_target_inputs2.target), + ] +) +class TestSpearmanCorrcoef(MetricTester): + atol = 1e-2 @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_explained_variance(self, preds, target, ddp, dist_sync_on_step): + def test_spearman_corrcoef(self, preds, target, ddp, dist_sync_on_step): self.run_class_metric_test( ddp, preds, target, - PearsonCorrcoef, + SpearmanCorrcoef, _sk_metric, dist_sync_on_step, ) - def test_pearson_corrcoef_functional(self, preds, target): - self.run_functional_metric_test( - preds, - target, - pearson_corrcoef, - _sk_metric - ) + def test_spearman_corrcoef_functional(self, preds, target): + self.run_functional_metric_test(preds, target, spearman_corrcoef, _sk_metric) - # Pearson half + cpu does not work due to missing support in torch.sqrt - @pytest.mark.xfail(reason="PSNR metric does not support cpu + half precision") - def test_pearson_corrcoef_half_cpu(self, preds, target): - self.run_precision_test_cpu(preds, target, PearsonCorrcoef, pearson_corrcoef) + # Spearman half + cpu does not work due to missing support in torch.arange + @pytest.mark.xfail(reason="Spearman metric does not support cpu + half precision") + def test_spearman_corrcoef_half_cpu(self, preds, target): + self.run_precision_test_cpu(preds, target, SpearmanCorrcoef, spearman_corrcoef) @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') - def test_pearson_corrcoef_half_gpu(self, preds, target): - self.run_precision_test_gpu(preds, target, PearsonCorrcoef, pearson_corrcoef) + def test_spearman_corrcoef_half_gpu(self, preds, target): + self.run_precision_test_gpu(preds, target, SpearmanCorrcoef, spearman_corrcoef) def test_error_on_different_shape(): - metric = PearsonCorrcoef() + metric = SpearmanCorrcoef() with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): metric(torch.randn(100, ), torch.randn(50, )) with pytest.raises(ValueError, match='Expected both predictions and target to be 1 dimensional tensors.'): - metric(torch.randn(100, 2), torch.randn(100, 2)) \ No newline at end of file + metric(torch.randn(100, 2), torch.randn(100, 2)) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 6fd8bd9568a..755f4bf543e 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -48,7 +48,7 @@ MeanSquaredError, MeanSquaredLogError, 
R2Score, - SpearmanCorrcoef + SpearmanCorrcoef, ) from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR, RetrievalPrecision # noqa: F401 E402 from torchmetrics.wrappers import BootStrapper # noqa: F401 E402 diff --git a/torchmetrics/functional/regression/spearman.py b/torchmetrics/functional/regression/spearman.py index 0bedc8eb935..c72758f1055 100644 --- a/torchmetrics/functional/regression/spearman.py +++ b/torchmetrics/functional/regression/spearman.py @@ -13,6 +13,7 @@ # limitations under the License. import torch from torch import Tensor + from torchmetrics.utilities.checks import _check_same_shape @@ -20,10 +21,13 @@ def _find_repeats(data: Tensor): """ find and return values which have repeats i.e. the same value are more than once in the tensor """ temp = data.detach().clone() temp = temp.sort()[0] - - change = torch.cat([torch.tensor([True]), temp[1:] != temp[:-1]]) + + change = torch.cat([torch.tensor([True], device=temp.device), temp[1:] != temp[:-1]]) unique = temp[change] - change_idx = torch.cat([torch.nonzero(change), torch.tensor([[temp.numel()]])]).flatten() + change_idx = torch.cat([ + torch.nonzero(change), + torch.tensor([[temp.numel()]], device=temp.device) + ]).flatten() freq = change_idx[1:] - change_idx[:-1] atleast2 = freq > 1 return unique[atleast2] @@ -32,22 +36,25 @@ def _find_repeats(data: Tensor): def _rank_data(data: Tensor): """ Calculate the rank for each element of a tensor. The rank refers to the indices of an element in the corresponding sorted tensor (starting from 1). Duplicates of the same value will be assigned the mean of - their rank - + their rank + Adopted from: https://github.com/scipy/scipy/blob/v1.6.2/scipy/stats/stats.py#L4140-L4303 """ n = data.numel() rank = torch.empty_like(data) idx = data.argsort() - rank[idx[:n]] = torch.arange(1, n+1, dtype=torch.float) - + rank[idx[:n]] = torch.arange(1, n + 1, dtype=data.dtype, device=data.device) + repeats = _find_repeats(data) for r in repeats: - condition = (data == r).filled(False) + import pdb + pdb.set_trace() + condition = rank == r rank[condition] = rank[condition].mean() return rank + def _spearman_corrcoef_update(preds: Tensor, target: Tensor): if preds.dtype != target.dtype: raise TypeError( @@ -55,25 +62,42 @@ def _spearman_corrcoef_update(preds: Tensor, target: Tensor): f" Got pred: {preds.dtype} and target: {target.dtype}." ) _check_same_shape(preds, target) - + if preds.ndim > 1 or target.ndim > 1: raise ValueError('Expected both predictions and target to be 1 dimensional tensors.') - + return preds, target - + + def _spearman_corrcoef_compute(preds: Tensor, target: Tensor): rank_preds = _rank_data(preds) rank_target = _rank_data(target) - - cov = ((rank_preds - rank_preds.mean()) * (rank_target - rank_target.mean())).sum() + cov = ((rank_preds - rank_preds.mean()) * (rank_target - rank_target.mean())).mean() return cov / (rank_preds.std() * rank_target.std()) - - + + def spearman_corrcoef(preds: Tensor, target: Tensor) -> Tensor: - """ - + r""" + Computes `spearmans rank correlation coefficient + `_. + + .. math: + r_s = = \frac{cov(rg_x, rg_y)}{\sigma_{rg_x} * \sigma_{rg_y}} + + where rg_x and rg_y are the rank associated to the variables x and y. Spearmans correlations coefficient + corresponds to the standard pearsons correlation coefficient calculated on the rank variables. 
+ + Args: + preds: estimated scores + target: ground truth scores + + Example: + >>> from torchmetrics.functional import spearman_corrcoef + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> spearman_corrcoef(preds, target) + tensor(0.9849) + """ preds, target = _spearman_corrcoef_update(preds, target) return _spearman_corrcoef_compute(preds, target) - - \ No newline at end of file diff --git a/torchmetrics/regression/spearman.py b/torchmetrics/regression/spearman.py index 0b48edcb4fe..9a01ff2e7d8 100644 --- a/torchmetrics/regression/spearman.py +++ b/torchmetrics/regression/spearman.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Callable +from typing import Any, Callable, Optional import torch from torch import Tensor @@ -22,7 +22,7 @@ class SpearmanCorrcoef(Metric): - """ + r""" Computes `spearmans rank correlation coefficient `_. @@ -57,7 +57,7 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Callable = None, + dist_sync_fn: Optional[Callable] = None, ): super().__init__( compute_on_step=compute_on_step, @@ -90,6 +90,6 @@ def compute(self): """ Computes spearmans correlation coefficient """ - preds = torch.cat(self.y_pred, dim=0) - target = torch.cat(self.y, dim=0) + preds = torch.cat(self.preds, dim=0) + target = torch.cat(self.target, dim=0) return _spearman_corrcoef_compute(preds, target) From b7573dc32e7f9b1ea66bd9ae8d800318af606941 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 14:07:05 +0200 Subject: [PATCH 05/13] fix tests --- tests/regression/test_spearman.py | 19 ++++++++++++-- .../functional/regression/spearman.py | 26 ++++++++++++------- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/tests/regression/test_spearman.py b/tests/regression/test_spearman.py index bd99c09c91a..839b99a6bb7 100644 --- a/tests/regression/test_spearman.py +++ b/tests/regression/test_spearman.py @@ -15,11 +15,11 @@ import pytest import torch -from scipy.stats import spearmanr +from scipy.stats import spearmanr, rankdata from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.functional.regression.spearman import spearman_corrcoef +from torchmetrics.functional.regression.spearman import spearman_corrcoef, _rank_data from torchmetrics.regression.spearman import SpearmanCorrcoef seed_all(42) @@ -37,6 +37,21 @@ ) +@pytest.mark.parametrize( + "preds, target", [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + (_single_target_inputs2.preds, _single_target_inputs2.target), + ] +) +def test_ranking(preds, target): + """ test that ranking function works as expected """ + for p, t in zip(preds, target): + scipy_ranking = [rankdata(p.numpy()), rankdata(t.numpy())] + tm_ranking = [_rank_data(p), _rank_data(t)] + assert (torch.tensor(scipy_ranking[0]) == tm_ranking[0]).all() + assert (torch.tensor(scipy_ranking[1]) == tm_ranking[1]).all() + + def _sk_metric(preds, target): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() diff --git a/torchmetrics/functional/regression/spearman.py b/torchmetrics/functional/regression/spearman.py index c72758f1055..345659e0435 100644 --- a/torchmetrics/functional/regression/spearman.py +++ 
b/torchmetrics/functional/regression/spearman.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple + import torch from torch import Tensor @@ -25,7 +27,7 @@ def _find_repeats(data: Tensor): change = torch.cat([torch.tensor([True], device=temp.device), temp[1:] != temp[:-1]]) unique = temp[change] change_idx = torch.cat([ - torch.nonzero(change), + torch.nonzero(change), torch.tensor([[temp.numel()]], device=temp.device) ]).flatten() freq = change_idx[1:] - change_idx[:-1] @@ -48,14 +50,12 @@ def _rank_data(data: Tensor): repeats = _find_repeats(data) for r in repeats: - import pdb - pdb.set_trace() condition = rank == r rank[condition] = rank[condition].mean() return rank -def _spearman_corrcoef_update(preds: Tensor, target: Tensor): +def _spearman_corrcoef_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: if preds.dtype != target.dtype: raise TypeError( "Expected `preds` and `target` to have the same data type." @@ -69,11 +69,19 @@ def _spearman_corrcoef_update(preds: Tensor, target: Tensor): return preds, target -def _spearman_corrcoef_compute(preds: Tensor, target: Tensor): - rank_preds = _rank_data(preds) - rank_target = _rank_data(target) - cov = ((rank_preds - rank_preds.mean()) * (rank_target - rank_target.mean())).mean() - return cov / (rank_preds.std() * rank_target.std()) +def _spearman_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor: + preds = _rank_data(preds) + target = _rank_data(target) + + preds_diff = preds - preds.mean() + target_diff = target - target.mean() + + cov = (preds_diff * target_diff).mean() + preds_std = torch.sqrt((preds_diff * preds_diff).mean()) + target_std = torch.sqrt((target_diff * target_diff).mean()) + + corrcoef = cov / (preds_std * target_std + eps) + return torch.clamp(corrcoef, -1.0, 1.0) def spearman_corrcoef(preds: Tensor, target: Tensor) -> Tensor: From 93b2d4d0680183e757fe99d37f9f6b365456a6c5 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 14:08:13 +0200 Subject: [PATCH 06/13] pep8 --- torchmetrics/functional/regression/spearman.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/regression/spearman.py b/torchmetrics/functional/regression/spearman.py index 345659e0435..b0445fca88e 100644 --- a/torchmetrics/functional/regression/spearman.py +++ b/torchmetrics/functional/regression/spearman.py @@ -72,7 +72,7 @@ def _spearman_corrcoef_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Te def _spearman_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor: preds = _rank_data(preds) target = _rank_data(target) - + preds_diff = preds - preds.mean() target_diff = target - target.mean() From efd0dbd5c43a40574d2ddbbbf34776f4d220ef50 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 14:13:30 +0200 Subject: [PATCH 07/13] add docs --- docs/source/references/functional.rst | 20 ++++++++++++++------ docs/source/references/modules.rst | 20 ++++++++++++++------ 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index a69404eab35..de3d5a99809 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -203,19 +203,27 @@ psnr [func] :noindex: -ssim [func] -~~~~~~~~~~~ +r2score [func] +~~~~~~~~~~~~~~ -.. 
autofunction:: torchmetrics.functional.ssim +.. autofunction:: torchmetrics.functional.r2score :noindex: -r2score [func] -~~~~~~~~~~~~~~ +spearman_corrcoef [func] +~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.r2score +.. autofunction:: torchmetrics.functional.spearman_corrcoef :noindex: + +ssim [func] +~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.ssim + :noindex: + + *** NLP *** diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 29d58252e16..866b9388e97 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -257,20 +257,28 @@ PSNR :noindex: -SSIM -~~~~ +R2Score +~~~~~~~ -.. autoclass:: torchmetrics.SSIM +.. autoclass:: torchmetrics.R2Score :noindex: -R2Score -~~~~~~~ +SpearmanCorrcoef +~~~~~~~~~~~~~~~~ -.. autoclass:: torchmetrics.R2Score +.. autoclass:: torchmetrics.SpearmanCorrcoef :noindex: +SSIM +~~~~ + +.. autoclass:: torchmetrics.SSIM + :noindex: + + + ********* Retrieval ********* From 1da0cb3efe81a68e233122e42474ffe0ed374f5c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 14:22:23 +0200 Subject: [PATCH 08/13] fix doctests --- torchmetrics/functional/regression/spearman.py | 2 +- torchmetrics/regression/spearman.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/regression/spearman.py b/torchmetrics/functional/regression/spearman.py index b0445fca88e..2687be21e51 100644 --- a/torchmetrics/functional/regression/spearman.py +++ b/torchmetrics/functional/regression/spearman.py @@ -104,7 +104,7 @@ def spearman_corrcoef(preds: Tensor, target: Tensor) -> Tensor: >>> target = torch.tensor([3, -0.5, 2, 7]) >>> preds = torch.tensor([2.5, 0.0, 2, 8]) >>> spearman_corrcoef(preds, target) - tensor(0.9849) + tensor(1.0000) """ preds, target = _spearman_corrcoef_update(preds, target) diff --git a/torchmetrics/regression/spearman.py b/torchmetrics/regression/spearman.py index 9a01ff2e7d8..ede0eb1aa99 100644 --- a/torchmetrics/regression/spearman.py +++ b/torchmetrics/regression/spearman.py @@ -49,7 +49,7 @@ class SpearmanCorrcoef(Metric): >>> preds = torch.tensor([2.5, 0.0, 2, 8]) >>> spearman = SpearmanCorrcoef() >>> spearman(preds, target) - tensor(0.9849) + tensor(1.0000) """ def __init__( From 7a8bfe9975a9d00327347aa1805f5e6ff5288507 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 14:34:29 +0200 Subject: [PATCH 09/13] fix docs --- torchmetrics/functional/regression/spearman.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmetrics/functional/regression/spearman.py b/torchmetrics/functional/regression/spearman.py index 2687be21e51..b9fa123bcb8 100644 --- a/torchmetrics/functional/regression/spearman.py +++ b/torchmetrics/functional/regression/spearman.py @@ -87,13 +87,13 @@ def _spearman_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) def spearman_corrcoef(preds: Tensor, target: Tensor) -> Tensor: r""" Computes `spearmans rank correlation coefficient - `_. + `_: .. math: r_s = = \frac{cov(rg_x, rg_y)}{\sigma_{rg_x} * \sigma_{rg_y}} - where rg_x and rg_y are the rank associated to the variables x and y. Spearmans correlations coefficient - corresponds to the standard pearsons correlation coefficient calculated on the rank variables. + where :math:`rg_x` and :math:`rg_y` are the rank associated to the variables x and y. Spearmans correlations + coefficient corresponds to the standard pearsons correlation coefficient calculated on the rank variables. 
Args: preds: estimated scores From 1f7095272b271e7067efde62161f30162ad41567 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 14:36:49 +0200 Subject: [PATCH 10/13] pep8 --- torchmetrics/functional/regression/spearman.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/regression/spearman.py b/torchmetrics/functional/regression/spearman.py index b9fa123bcb8..a5f89c7ebc8 100644 --- a/torchmetrics/functional/regression/spearman.py +++ b/torchmetrics/functional/regression/spearman.py @@ -92,7 +92,7 @@ def spearman_corrcoef(preds: Tensor, target: Tensor) -> Tensor: .. math: r_s = = \frac{cov(rg_x, rg_y)}{\sigma_{rg_x} * \sigma_{rg_y}} - where :math:`rg_x` and :math:`rg_y` are the rank associated to the variables x and y. Spearmans correlations + where :math:`rg_x` and :math:`rg_y` are the rank associated to the variables x and y. Spearmans correlations coefficient corresponds to the standard pearsons correlation coefficient calculated on the rank variables. Args: From 36c52e0a26b13ab6e1e0e9435c023cd36c6fdf9d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 14:40:02 +0200 Subject: [PATCH 11/13] isort --- tests/regression/test_spearman.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/regression/test_spearman.py b/tests/regression/test_spearman.py index 839b99a6bb7..45e64f65e8c 100644 --- a/tests/regression/test_spearman.py +++ b/tests/regression/test_spearman.py @@ -15,11 +15,11 @@ import pytest import torch -from scipy.stats import spearmanr, rankdata +from scipy.stats import rankdata, spearmanr from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.functional.regression.spearman import spearman_corrcoef, _rank_data +from torchmetrics.functional.regression.spearman import _rank_data, spearman_corrcoef from torchmetrics.regression.spearman import SpearmanCorrcoef seed_all(42) From 7206effc16037d9907b2cfcb34555c7ab7d9bcc1 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 13 Apr 2021 16:12:33 +0200 Subject: [PATCH 12/13] ghlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e417c43d0a..01385fd9768 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added other metrics: * Added `CohenKappa` ([#69](https://github.com/PyTorchLightning/metrics/pull/69)) * Added `MatthewsCorrcoef` ([#98](https://github.com/PyTorchLightning/metrics/pull/98)) + * Added `SpearmanCorrcoef` ([#158](https://github.com/PyTorchLightning/metrics/pull/158)) * Added `Hinge` ([#120](https://github.com/PyTorchLightning/metrics/pull/120)) - Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110)) - Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114)) @@ -28,7 +29,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ) - Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70)) - Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142)) -- Added `SpearmanCorrcoef` metric ([#158](https://github.com/PyTorchLightning/metrics/pull/158)) ### Changed From 
b5f4c5d53b079864bedac33b8a2661751893d6d3 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Tue, 13 Apr 2021 16:15:47 +0200
Subject: [PATCH 13/13] Apply suggestions from code review

---
 torchmetrics/regression/spearman.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchmetrics/regression/spearman.py b/torchmetrics/regression/spearman.py
index ede0eb1aa99..9dddcac9f97 100644
--- a/torchmetrics/regression/spearman.py
+++ b/torchmetrics/regression/spearman.py
@@ -43,6 +43,7 @@ class SpearmanCorrcoef(Metric):
         dist_sync_fn:
             Callback that performs the allgather operation on the metric state. When ``None``, DDP
             will be used to perform the allgather
+
     Example:
         >>> from torchmetrics import SpearmanCorrcoef
         >>> target = torch.tensor([3, -0.5, 2, 7])
@@ -66,9 +67,8 @@ def __init__(
             dist_sync_fn=dist_sync_fn,
         )
         rank_zero_warn(
-            'Metric `SpearmanCorrcoef` will save all targets and'
-            ' predictions in buffer. For large datasets this may lead'
-            ' to large memory footprint.'
+            'Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer.'
+            ' For large datasets, this may lead to large memory footprint.'
         )
         self.add_state("preds", default=[], dist_reduce_fx=None)
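
The helpers ``_find_repeats`` and ``_rank_data`` added in this series implement the tie-aware ranking that Spearman's coefficient needs. Below is a minimal standalone sketch of that ranking step, an illustrative re-implementation rather than the repository code; ``find_repeats`` and ``rank_data`` are placeholder names for the underscored helpers in the diffs. The key detail is that the tie mask is built from the data values, so every occurrence of a repeated value receives the mean of the ranks it spans::

    # Illustrative sketch of tie-aware ranking, mirroring the helpers introduced in this PR.
    import torch
    from torch import Tensor

    def find_repeats(data: Tensor) -> Tensor:
        """Return the values that occur more than once in ``data``."""
        temp = data.detach().clone().sort()[0]
        change = torch.cat([torch.tensor([True], device=temp.device), temp[1:] != temp[:-1]])
        unique = temp[change]
        change_idx = torch.cat(
            [torch.nonzero(change), torch.tensor([[temp.numel()]], device=temp.device)]
        ).flatten()
        freq = change_idx[1:] - change_idx[:-1]
        return unique[freq > 1]

    def rank_data(data: Tensor) -> Tensor:
        """Rank elements from 1..n; tied values all receive the mean of their ranks."""
        n = data.numel()
        rank = torch.empty_like(data)
        rank[data.argsort()] = torch.arange(1, n + 1, dtype=data.dtype, device=data.device)
        for value in find_repeats(data):
            mask = data == value  # select ties by value, not by provisional rank
            rank[mask] = rank[mask].mean()
        return rank

    print(rank_data(torch.tensor([1.0, 3.0, 2.0, 3.0, 3.0])))
    # tensor([1., 4., 2., 4., 4.])  the three 3.0 entries share rank (3 + 4 + 5) / 3 = 4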
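
The compute step is then just the Pearson correlation applied to the two rank vectors,

.. math::
    r_s = \frac{\operatorname{cov}(rg_x, rg_y)}{\sigma_{rg_x}\, \sigma_{rg_y}},

where :math:`rg_x` and :math:`rg_y` are the ranks of the two inputs. As a worked check of the doctests above: ``preds = [2.5, 0.0, 2, 8]`` and ``target = [3, -0.5, 2, 7]`` both rank to ``[3, 1, 2, 4]``, so the covariance of the ranks equals the product of their standard deviations and the coefficient is exactly 1, consistent with the ``tensor(1.0000)`` shown in the corrected examples.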
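
Finally, a small end-to-end check of the new API, assuming the final state of this PR is installed. It mirrors the unit tests by comparing both interfaces against ``scipy.stats.spearmanr``; since the metric computes in the input dtype (float32 here), only approximate agreement with SciPy's float64 result should be expected::

    import torch
    from scipy.stats import spearmanr

    from torchmetrics import SpearmanCorrcoef
    from torchmetrics.functional import spearman_corrcoef

    preds = torch.randn(100)
    target = torch.randn(100)

    # Functional interface: one-shot computation on two 1d tensors of the same shape and dtype.
    functional_score = spearman_corrcoef(preds, target)

    # Module interface: stores all preds/targets in its state, so it can be fed batch by batch
    # (and synced across processes in DDP) before calling compute().
    metric = SpearmanCorrcoef()
    for p, t in zip(preds.chunk(4), target.chunk(4)):
        metric.update(p, t)
    module_score = metric.compute()

    reference = spearmanr(target.numpy(), preds.numpy())[0]
    print(functional_score, module_score, reference)  # the three values should agree to a few decimals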