Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spearman correlation coefficient #158

Merged
merged 16 commits into from
Apr 13, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Added `CohenKappa` ([#69](https://github.com/PyTorchLightning/metrics/pull/69))
* Added `MatthewsCorrcoef` ([#98](https://github.com/PyTorchLightning/metrics/pull/98))
* Added `PearsonCorrcoef` ([#157](https://github.com/PyTorchLightning/metrics/pull/157))
* Added `SpearmanCorrcoef` ([#158](https://github.com/PyTorchLightning/metrics/pull/158))
* Added `Hinge` ([#120](https://github.com/PyTorchLightning/metrics/pull/120))
- Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))
- Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114))
Expand Down
20 changes: 14 additions & 6 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -210,19 +210,27 @@ psnr [func]
:noindex:


ssim [func]
~~~~~~~~~~~
r2score [func]
~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.ssim
.. autofunction:: torchmetrics.functional.r2score
:noindex:


r2score [func]
~~~~~~~~~~~~~~
spearman_corrcoef [func]
~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.r2score
.. autofunction:: torchmetrics.functional.spearman_corrcoef
:noindex:


ssim [func]
~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.ssim
:noindex:


***
NLP
***
Expand Down
20 changes: 14 additions & 6 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -264,20 +264,28 @@ PSNR
:noindex:


SSIM
~~~~
R2Score
~~~~~~~

.. autoclass:: torchmetrics.SSIM
.. autoclass:: torchmetrics.R2Score
:noindex:


R2Score
~~~~~~~
SpearmanCorrcoef
~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.R2Score
.. autoclass:: torchmetrics.SpearmanCorrcoef
:noindex:


SSIM
~~~~

.. autoclass:: torchmetrics.SSIM
:noindex:



*********
Retrieval
*********
Expand Down
101 changes: 101 additions & 0 deletions tests/regression/test_spearman.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple

import pytest
import torch
from scipy.stats import rankdata, spearmanr

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional.regression.spearman import _rank_data, spearman_corrcoef
from torchmetrics.regression.spearman import SpearmanCorrcoef

seed_all(42)

Input = namedtuple('Input', ["preds", "target"])

_single_target_inputs1 = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)

_single_target_inputs2 = Input(
preds=torch.randn(NUM_BATCHES, BATCH_SIZE),
target=torch.randn(NUM_BATCHES, BATCH_SIZE),
)


@pytest.mark.parametrize(
"preds, target", [
(_single_target_inputs1.preds, _single_target_inputs1.target),
(_single_target_inputs2.preds, _single_target_inputs2.target),
]
)
def test_ranking(preds, target):
""" test that ranking function works as expected """
for p, t in zip(preds, target):
scipy_ranking = [rankdata(p.numpy()), rankdata(t.numpy())]
tm_ranking = [_rank_data(p), _rank_data(t)]
assert (torch.tensor(scipy_ranking[0]) == tm_ranking[0]).all()
assert (torch.tensor(scipy_ranking[1]) == tm_ranking[1]).all()


def _sk_metric(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return spearmanr(sk_target, sk_preds)[0]


@pytest.mark.parametrize(
"preds, target", [
(_single_target_inputs1.preds, _single_target_inputs1.target),
(_single_target_inputs2.preds, _single_target_inputs2.target),
]
)
class TestSpearmanCorrcoef(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_spearman_corrcoef(self, preds, target, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
SpearmanCorrcoef,
_sk_metric,
dist_sync_on_step,
)

def test_spearman_corrcoef_functional(self, preds, target):
self.run_functional_metric_test(preds, target, spearman_corrcoef, _sk_metric)

# Spearman half + cpu does not work due to missing support in torch.arange
@pytest.mark.xfail(reason="Spearman metric does not support cpu + half precision")
def test_spearman_corrcoef_half_cpu(self, preds, target):
self.run_precision_test_cpu(preds, target, SpearmanCorrcoef, spearman_corrcoef)

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_spearman_corrcoef_half_gpu(self, preds, target):
self.run_precision_test_gpu(preds, target, SpearmanCorrcoef, spearman_corrcoef)


def test_error_on_different_shape():
metric = SpearmanCorrcoef()
with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
metric(torch.randn(100, ), torch.randn(50, ))

with pytest.raises(ValueError, match='Expected both predictions and target to be 1 dimensional tensors.'):
metric(torch.randn(100, 2), torch.randn(100, 2))
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
MeanSquaredLogError,
PearsonCorrcoef,
R2Score,
SpearmanCorrcoef,
)
from torchmetrics.retrieval import ( # noqa: F401 E402
RetrievalFallOut,
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from torchmetrics.functional.regression.pearson import pearson_corrcoef # noqa: F401
from torchmetrics.functional.regression.psnr import psnr # noqa: F401
from torchmetrics.functional.regression.r2score import r2score # noqa: F401
from torchmetrics.functional.regression.spearman import spearman_corrcoef # noqa: F401
from torchmetrics.functional.regression.ssim import ssim # noqa: F401
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401
from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out # noqa: F401
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@
from torchmetrics.functional.regression.pearson import pearson_corrcoef # noqa: F401
from torchmetrics.functional.regression.psnr import psnr # noqa: F401
from torchmetrics.functional.regression.r2score import r2score # noqa: F401
from torchmetrics.functional.regression.spearman import spearman_corrcoef # noqa: F401
from torchmetrics.functional.regression.ssim import ssim # noqa: F401
111 changes: 111 additions & 0 deletions torchmetrics/functional/regression/spearman.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape


def _find_repeats(data: Tensor):
""" find and return values which have repeats i.e. the same value are more than once in the tensor """
temp = data.detach().clone()
temp = temp.sort()[0]

change = torch.cat([torch.tensor([True], device=temp.device), temp[1:] != temp[:-1]])
unique = temp[change]
change_idx = torch.cat([
torch.nonzero(change),
torch.tensor([[temp.numel()]], device=temp.device)
]).flatten()
freq = change_idx[1:] - change_idx[:-1]
atleast2 = freq > 1
return unique[atleast2]


def _rank_data(data: Tensor):
""" Calculate the rank for each element of a tensor. The rank refers to the indices of an element in the
corresponding sorted tensor (starting from 1). Duplicates of the same value will be assigned the mean of
their rank

Adopted from:
https://github.com/scipy/scipy/blob/v1.6.2/scipy/stats/stats.py#L4140-L4303
"""
n = data.numel()
rank = torch.empty_like(data)
idx = data.argsort()
rank[idx[:n]] = torch.arange(1, n + 1, dtype=data.dtype, device=data.device)

repeats = _find_repeats(data)
for r in repeats:
condition = rank == r
rank[condition] = rank[condition].mean()
return rank


def _spearman_corrcoef_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
if preds.dtype != target.dtype:
raise TypeError(
"Expected `preds` and `target` to have the same data type."
f" Got pred: {preds.dtype} and target: {target.dtype}."
)
_check_same_shape(preds, target)

if preds.ndim > 1 or target.ndim > 1:
raise ValueError('Expected both predictions and target to be 1 dimensional tensors.')

return preds, target


def _spearman_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor:
preds = _rank_data(preds)
target = _rank_data(target)

preds_diff = preds - preds.mean()
target_diff = target - target.mean()

cov = (preds_diff * target_diff).mean()
preds_std = torch.sqrt((preds_diff * preds_diff).mean())
target_std = torch.sqrt((target_diff * target_diff).mean())

corrcoef = cov / (preds_std * target_std + eps)
return torch.clamp(corrcoef, -1.0, 1.0)


def spearman_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
r"""
Computes `spearmans rank correlation coefficient
<https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient>`_:

.. math:
r_s = = \frac{cov(rg_x, rg_y)}{\sigma_{rg_x} * \sigma_{rg_y}}

where :math:`rg_x` and :math:`rg_y` are the rank associated to the variables x and y. Spearmans correlations
coefficient corresponds to the standard pearsons correlation coefficient calculated on the rank variables.

Args:
preds: estimated scores
target: ground truth scores

Example:
>>> from torchmetrics.functional import spearman_corrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> spearman_corrcoef(preds, target)
tensor(1.0000)

"""
preds, target = _spearman_corrcoef_update(preds, target)
return _spearman_corrcoef_compute(preds, target)
1 change: 1 addition & 0 deletions torchmetrics/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@
from torchmetrics.regression.pearson import PearsonCorrcoef # noqa: F401
from torchmetrics.regression.psnr import PSNR # noqa: F401
from torchmetrics.regression.r2score import R2Score # noqa: F401
from torchmetrics.regression.spearman import SpearmanCorrcoef # noqa: F401
from torchmetrics.regression.ssim import SSIM # noqa: F401
Loading