From cfe5e87797e07cf6c429e99b07648db1ddfb3e4c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 19 Jan 2022 22:34:39 +0100 Subject: [PATCH] Fix Matthews correlation coefficient when the denominator is 0 (#781) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec --- CHANGELOG.md | 3 +++ tests/classification/test_matthews_corrcoef.py | 7 +++++++ .../functional/classification/matthews_corrcoef.py | 10 +++++++++- 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a35f752030..9df5c265111 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed check for available modules ([#772](https://github.com/PyTorchLightning/metrics/pull/772)) +- Fixed Matthews correlation coefficient when the denominator is 0 ([#781](https://github.com/PyTorchLightning/metrics/pull/781)) + + ## [0.7.0] - 2022-01-17 ### Added diff --git a/tests/classification/test_matthews_corrcoef.py b/tests/classification/test_matthews_corrcoef.py index 692c11a8419..b5cb8aae8ea 100644 --- a/tests/classification/test_matthews_corrcoef.py +++ b/tests/classification/test_matthews_corrcoef.py @@ -140,3 +140,10 @@ def test_matthews_corrcoef_differentiability(self, preds, target, sk_metric, num "threshold": THRESHOLD, }, ) + + +def test_zero_case(): + """Cases where the denominator in the matthews corrcoef is 0, the score should return 0.""" + # Example where neither 1 or 2 is present in the target tensor + out = matthews_corrcoef(torch.tensor([0, 1, 2]), torch.tensor([0, 0, 0]), 3) + assert out == 0.0 diff --git a/torchmetrics/functional/classification/matthews_corrcoef.py b/torchmetrics/functional/classification/matthews_corrcoef.py index 00e396d97db..a0807699806 100644 --- a/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/torchmetrics/functional/classification/matthews_corrcoef.py @@ -37,7 +37,15 @@ def _matthews_corrcoef_compute(confmat: Tensor) -> Tensor: pk = confmat.sum(dim=0).float() c = torch.trace(confmat).float() s = confmat.sum().float() - return (c * s - sum(tk * pk)) / (torch.sqrt(s ** 2 - sum(pk * pk)) * torch.sqrt(s ** 2 - sum(tk * tk))) + + cov_ytyp = c * s - sum(tk * pk) + cov_ypyp = s ** 2 - sum(pk * pk) + cov_ytyt = s ** 2 - sum(tk * tk) + + if cov_ypyp * cov_ytyt == 0: + return torch.tensor(0, dtype=confmat.dtype, device=confmat.device) + else: + return cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp) def matthews_corrcoef(