BinnedPrecisionRecall not working for multi-dimensional arrays #663
Comments
Hi! Thanks for your contribution, great first issue!
Since the metric works with single-value comparisons, I think we can just flatten an n-dimensional array of shape (N, C, ...) to (N, C, F), where F is the flattened axis.
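A minimal sketch of the flattening idea, assuming multiclass probabilities `preds` of shape (N, C, ...) and integer labels `target` of shape (N, ...). The helper name `flatten_extra_dims` is hypothetical. It goes one step past (N, C, F): it also folds F into the batch axis, producing the (N * F, C) / (N * F,) layout that a per-sample multiclass metric would expect.

```python
import torch


def flatten_extra_dims(preds: torch.Tensor, target: torch.Tensor):
    """Hypothetical helper: treat every element of the extra dims as a sample.

    preds:  (N, C, ...) class probabilities
    target: (N, ...)    integer class labels
    returns preds of shape (N * F, C) and target of shape (N * F,)
    """
    n, c = preds.shape[:2]
    # (N, C, ...) -> (N, C, F), where F is the product of the extra dims
    preds = preds.reshape(n, c, -1)
    # move the class axis last, (N, C, F) -> (N, F, C), then merge N and F
    preds = preds.permute(0, 2, 1).reshape(-1, c)
    # (N, ...) -> (N * F,), in the same element order as the rows of preds
    target = target.reshape(-1)
    return preds, target


target = torch.randint(0, 5, (8, 4, 4))
pred = torch.randn(8, 5, 4, 4).softmax(dim=1)
flat_pred, flat_target = flatten_extra_dims(pred, target)
print(flat_pred.shape, flat_target.shape)  # torch.Size([128, 5]) torch.Size([128])
```

Row `i` of `flat_pred` lines up with `flat_target[i]` because both tensors flatten the extra dimensions in the same (row-major) order.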
@omerferhatt mind sending a PR so we can check your suggestion? :]
@omerferhatt how is it going? It would be nice to have it in the next bugfix release... 🐰
This issue will be fixed by the classification refactor: see issue #1001 and PR #1195 for all changes.

Small recap: this issue describes that multi-dimensional tensors are not supported in the binned precision-recall curve. With the refactor, multi-dimensional input is supported, for example:

```python
import torch
from torchmetrics.classification import MulticlassPrecisionRecallCurve

# provided example is for multiclass problems
target = torch.randint(0, 5, (8, 224, 224))
pred = torch.randn(8, 5, 224, 224).softmax(dim=1)
# using the thresholds argument will choose to use a binning approach for calculating the metric
prcurve = MulticlassPrecisionRecallCurve(num_classes=5, thresholds=3)
prcurve(pred, target)
# output (exact values vary because the inputs are random):
# (tensor([[0.1996, 0.1996, 0.0000, 1.0000],
#          [0.2009, 0.1984, 0.0000, 1.0000],
#          [0.2005, 0.2021, 0.0000, 1.0000],
#          [0.2002, 0.2000, 0.0000, 1.0000],
#          [0.1988, 0.1961, 0.0000, 1.0000]]),
#  tensor([[1.0000, 0.0663, 0.0000, 0.0000],
#          [1.0000, 0.0652, 0.0000, 0.0000],
#          [1.0000, 0.0667, 0.0000, 0.0000],
#          [1.0000, 0.0655, 0.0000, 0.0000],
#          [1.0000, 0.0659, 0.0000, 0.0000]]),
#  tensor([0.0000, 0.5000, 1.0000]))
```

The issue will be closed when #1195 is merged.
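To illustrate what the binning approach computes, and why extra dimensions are no obstacle to it, here is a minimal pure-PyTorch sketch for the binary case. It is not the torchmetrics implementation: `binned_pr_curve` is a hypothetical helper, the thresholds are assumed to be evenly spaced in [0, 1], and defining precision (or recall) as 0 when its denominator is 0 is an assumption made for this sketch.

```python
import torch


def binned_pr_curve(preds: torch.Tensor, target: torch.Tensor, num_thresholds: int = 3):
    """Hypothetical sketch of a binned (fixed-threshold) binary PR curve.

    preds:  probabilities of any shape (N, ...)
    target: 0/1 labels with the same shape as preds
    Extra dimensions are flattened, so every element counts as one sample.
    """
    preds = preds.reshape(-1)
    target = target.reshape(-1).bool()
    thresholds = torch.linspace(0, 1, num_thresholds)
    # (T, M) boolean matrix: is element m predicted positive at threshold t?
    predicted_pos = preds.unsqueeze(0) >= thresholds.unsqueeze(1)
    tp = (predicted_pos & target).sum(dim=1).float()
    fp = (predicted_pos & ~target).sum(dim=1).float()
    fn = (~predicted_pos & target).sum(dim=1).float()
    # assumption: report 0 instead of 0/0 when a denominator is empty
    precision = torch.where(tp + fp > 0, tp / (tp + fp), torch.zeros_like(tp))
    recall = torch.where(tp + fn > 0, tp / (tp + fn), torch.zeros_like(tp))
    return precision, recall, thresholds


# a tiny 2x2 "image" per sample, N = 2 samples with 2 pixels each
preds = torch.tensor([[0.9, 0.2], [0.6, 0.4]])
target = torch.tensor([[1, 0], [0, 1]])
precision, recall, thresholds = binned_pr_curve(preds, target)
# precision -> [0.5, 0.5, 0.0], recall -> [1.0, 0.5, 0.0], thresholds -> [0.0, 0.5, 1.0]
```

Because only counts at a fixed set of thresholds are kept, memory stays constant in the number of samples, which is why flattening large (N, H, W) inputs is cheap here.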
🐛 Bug
The binned precision-recall curve does not work as expected with multi-class, multi-dimensional input/target.
To Reproduce
Steps to reproduce the behavior:
Here's the stack trace:
Code sample
Expected behavior
It should work for multi-dimensional data.
Environment
Installation method (conda, pip, source): pip