🐛 Bug

With (near) constant targets, `r2_score` can return values larger than 1, which should not be possible.

To Reproduce

See the code example below.
Code sample

    import torch
    from torchmetrics.functional import r2_score
    from sklearn.metrics import r2_score as sklearn_r2_score

    # Near-constant targets: five values of -5.1608 and one of -5.1609
    y_true = torch.tensor([-5.1608, -5.1609, -5.1608, -5.1608, -5.1608, -5.1608])
    y_pred = torch.tensor([-3.9865, -5.4648, -5.0238, -4.3899, -5.6672, -4.7336])

    score = r2_score(preds=y_pred, target=y_true)
    print(score)  # prints 'tensor(82685.5312)'

    score = sklearn_r2_score(y_true=y_true.numpy(), y_pred=y_pred.numpy())
    print(score)  # prints '-301979050.37719727'
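One plausible cause, offered as a hypothesis rather than a confirmed diagnosis: if the total sum of squares (TSS) is accumulated in a single pass as sum(y²) minus sum(y) · mean(y) in float32, the two nearly equal terms (both about 159.8 here) cancel catastrophically. Rounding can then leave a TSS that is zero or even negative, and a negative TSS makes 1 − RSS/TSS exceed 1. The sketch below emulates float32 arithmetic with the standard library and compares that one-pass formula with a two-pass formula that subtracts the mean first. The helper `f32` and all function names are hypothetical illustrations, not TorchMetrics internals.

```python
import struct


def f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value.

    Rounding after every single +, -, * or / emulates float32 arithmetic.
    """
    return struct.unpack("f", struct.pack("f", x))[0]


def tss_one_pass_f32(ys):
    """Hypothetical one-pass TSS: sum(y^2) - sum(y) * mean(y), all in float32."""
    sum_sq = 0.0
    total = 0.0
    for y in ys:
        sum_sq = f32(sum_sq + f32(y * y))
        total = f32(total + y)
    mean = f32(total / len(ys))
    # Both operands are ~159.8, so almost all significant bits cancel here.
    return f32(sum_sq - f32(total * mean))


def tss_two_pass_f32(ys):
    """Two-pass TSS: subtract the mean first, then sum squared deviations (float32)."""
    total = 0.0
    for y in ys:
        total = f32(total + y)
    mean = f32(total / len(ys))
    acc = 0.0
    for y in ys:
        dev = f32(y - mean)  # small deviations are computed exactly
        acc = f32(acc + f32(dev * dev))
    return acc


def r2_reference(y_true, y_pred):
    """Two-pass R^2 computed in float64, used as a reference value."""
    mean = sum(y_true) / len(y_true)
    tss = sum((y - mean) ** 2 for y in y_true)
    rss = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    return 1.0 - rss / tss


# Inputs from the issue, rounded to float32 as torch.tensor() stores them by default.
y_true = [f32(v) for v in [-5.1608, -5.1609, -5.1608, -5.1608, -5.1608, -5.1608]]
y_pred = [f32(v) for v in [-3.9865, -5.4648, -5.0238, -4.3899, -5.6672, -4.7336]]

mean64 = sum(y_true) / len(y_true)
tss_ref = sum((y - mean64) ** 2 for y in y_true)  # float64 reference TSS

print("reference TSS (float64):", tss_ref)
print("two-pass TSS  (float32):", tss_two_pass_f32(y_true))
print("one-pass TSS  (float32):", tss_one_pass_f32(y_true))
print("reference R^2 (float64):", r2_reference(y_true, y_pred))
```

With the one-pass formula, both operands lie between 128 and 256, where float32 values are spaced about 1.5e-5 apart, so the computed TSS is either exactly 0 or off by more than three orders of magnitude from the true value of roughly 8.4e-9, and it can come out negative. The two-pass version stays close to the reference. Casting both tensors with `.double()` before calling `r2_score` may also mitigate the problem, although it only pushes the cancellation to a finer scale rather than removing it.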
Expected behavior

`r2_score` should never return a value above 1. For this input the expected result is a large negative value, as returned by scikit-learn.
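For reference, the textbook definition makes the upper bound explicit. Here $y_i$ are the targets, $\hat{y}_i$ the predictions and $\bar{y}$ the target mean; this is the standard formula, not necessarily how any particular library evaluates it:

$$
R^2 = 1 - \frac{\mathrm{SS}_{\mathrm{res}}}{\mathrm{SS}_{\mathrm{tot}}}, \qquad
\mathrm{SS}_{\mathrm{res}} = \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 \ge 0, \qquad
\mathrm{SS}_{\mathrm{tot}} = \sum_{i=1}^{n} (y_i - \bar{y})^2 \ge 0 .
$$

Whenever $\mathrm{SS}_{\mathrm{tot}} > 0$ the ratio is non-negative, so $R^2 \le 1$. A result above 1 therefore means the computed $\mathrm{SS}_{\mathrm{tot}}$ was negative, which can only be a numerical artifact. With nearly constant targets $\mathrm{SS}_{\mathrm{tot}}$ is tiny, so even a moderate $\mathrm{SS}_{\mathrm{res}}$ drives $R^2$ to a large negative number, consistent with the scikit-learn result.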
Environment

My environment was installed through mamba.

- TorchMetrics version: 0.11.0
- Python version: 3.9.16
- PyTorch version: 1.12.1
Relevant installed packages:

    python              3.9.16   h2782a2a_0_cpython          conda-forge
    pytorch             1.12.1   cuda112py39hb0b7ed5_201     conda-forge
    pytorch-lightning   1.8.1    pyhd8ed1ab_0                conda-forge
    pytorch_geometric   2.2.0    pyhd8ed1ab_0                conda-forge
    pytorch_scatter     2.1.0    cuda112py39h83a068c_0       conda-forge
    pytorch_sparse      0.6.15   py39h83a068c_0              conda-forge
    torchmetrics        0.11.0   pyhd8ed1ab_0                conda-forge
    torchvision         0.13.0   cuda112py39hd2c45b6_0       conda-forge