From 9d2703541b4845048013b24c3419580249c947a8 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 4 Aug 2021 09:29:11 +0200 Subject: [PATCH 1/2] fix --- tests/regression/test_r2.py | 6 ++++++ torchmetrics/functional/regression/r2.py | 5 +++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/regression/test_r2.py b/tests/regression/test_r2.py index d723c9478cb..6882191b870 100644 --- a/tests/regression/test_r2.py +++ b/tests/regression/test_r2.py @@ -143,6 +143,12 @@ def test_error_on_too_few_samples(metric_class=R2Score): metric = metric_class() with pytest.raises(ValueError, match="Needs at least two samples to calculate r2 score."): metric(torch.randn(1), torch.randn(1)) + metric.reset() + + # calling update twice should still work + metric.update(torch.randn(1), torch.randn(1)) + metric.update(torch.randn(1), torch.randn(1)) + assert metric.compute() def test_warning_on_too_large_adjusted(metric_class=R2Score): diff --git a/torchmetrics/functional/regression/r2.py b/torchmetrics/functional/regression/r2.py index 23e3959aee8..7ee01bf2ff7 100644 --- a/torchmetrics/functional/regression/r2.py +++ b/torchmetrics/functional/regression/r2.py @@ -34,8 +34,6 @@ def _r2_score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Ten "Expected both prediction and target to be 1D or 2D tensors," f" but received tensors with dimension {preds.shape}" ) - if len(preds) < 2: - raise ValueError("Needs at least two samples to calculate r2 score.") sum_obs = torch.sum(target, dim=0) sum_squared_obs = torch.sum(target * target, dim=0) @@ -77,6 +75,9 @@ def _r2_score_compute( >>> _r2_score_compute(sum_squared_obs, sum_obs, rss, n_obs, multioutput="raw_values") tensor([0.9654, 0.9082]) """ + if n_obs < 2: + raise ValueError("Needs at least two samples to calculate r2 score.") + mean_obs = sum_obs / n_obs tss = sum_squared_obs - sum_obs * mean_obs raw_scores = 1 - (rss / tss) From 25c3de909493db26dcda2190ec20198b5e33c70b Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 4 Aug 2021 09:35:17 +0200 Subject: [PATCH 2/2] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f30ab1d83e..a6c5885b728 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,6 +56,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Use `torch.argmax` instead of `torch.topk` when `k=1` for better performance ([#419](https://github.com/PyTorchLightning/metrics/pull/419)) +- Moved check for number of samples in R2 score to support single sample updating ([#426](https://github.com/PyTorchLightning/metrics/pull/426)) + + ### Deprecated - Rename `r2score` >> `r2_score` and `kldivergence` >> `kl_divergence` in `functional` ([#371](https://github.com/PyTorchLightning/metrics/pull/371))