Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move checking in r2 for number of samples to compute #426

Merged
merged 4 commits into from
Aug 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Use `torch.argmax` instead of `torch.topk` when `k=1` for better performance ([#419](https://github.com/PyTorchLightning/metrics/pull/419))


- Moved check for number of samples in R2 score to support single sample updating ([#426](https://github.com/PyTorchLightning/metrics/pull/426))


### Deprecated

- Rename `r2score` >> `r2_score` and `kldivergence` >> `kl_divergence` in `functional` ([#371](https://github.com/PyTorchLightning/metrics/pull/371))
Expand Down
6 changes: 6 additions & 0 deletions tests/regression/test_r2.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ def test_error_on_too_few_samples(metric_class=R2Score):
metric = metric_class()
with pytest.raises(ValueError, match="Needs at least two samples to calculate r2 score."):
metric(torch.randn(1), torch.randn(1))
metric.reset()

# calling update twice should still work
metric.update(torch.randn(1), torch.randn(1))
metric.update(torch.randn(1), torch.randn(1))
assert metric.compute()


def test_warning_on_too_large_adjusted(metric_class=R2Score):
Expand Down
5 changes: 3 additions & 2 deletions torchmetrics/functional/regression/r2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def _r2_score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Ten
"Expected both prediction and target to be 1D or 2D tensors,"
f" but received tensors with dimension {preds.shape}"
)
if len(preds) < 2:
raise ValueError("Needs at least two samples to calculate r2 score.")

sum_obs = torch.sum(target, dim=0)
sum_squared_obs = torch.sum(target * target, dim=0)
Expand Down Expand Up @@ -77,6 +75,9 @@ def _r2_score_compute(
>>> _r2_score_compute(sum_squared_obs, sum_obs, rss, n_obs, multioutput="raw_values")
tensor([0.9654, 0.9082])
"""
if n_obs < 2:
raise ValueError("Needs at least two samples to calculate r2 score.")

mean_obs = sum_obs / n_obs
tss = sum_squared_obs - sum_obs * mean_obs
raw_scores = 1 - (rss / tss)
Expand Down