AUC score legend not plotted when score=True #1941

Closed
yztxwd opened this issue Jul 26, 2023 · 1 comment · Fixed by #1948
Labels
bug / fix (Something isn't working), help wanted (Extra attention is needed)

Comments


yztxwd commented Jul 26, 2023

🐛 Bug

When calling `.plot(score=True)` on curve metrics such as `BinaryPrecisionRecallCurve` or `BinaryROC`, the AUC score is not shown.

To Reproduce

Steps to reproduce the behavior:

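```python
import torch.nn.functional as F
from torch import randint, randn
from torchmetrics.classification import BinaryPrecisionRecallCurve

# Random binary-classification data: positive-class probabilities and 0/1 targets
preds = F.softmax(randn(20, 2), dim=1)
target = randint(2, (20,))

metric = BinaryPrecisionRecallCurve()
metric.update(preds[:, 1], target)
# Plotting needs matplotlib; with score=True a legend with the AUC should appear
fig_, ax_ = metric.plot(score=True)
```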

Expected behavior

The plot should include a legend showing the AUC score.
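The AUC score the legend should display is the area under the plotted curve. As a rough illustration, the sketch below computes that area with the trapezoidal rule on plain Python lists. It assumes that `_auc_compute_without_check` (used in the `roc.py` snippet below) behaves like a trapezoidal integral scaled by a direction factor; the helper name `trapezoid_auc` is hypothetical and not part of TorchMetrics.

```python
# Hypothetical helper (not TorchMetrics API): trapezoidal area under a curve
# given as parallel lists of x and y values. `direction` flips the sign for
# curves whose x values decrease, assumed to mirror the 1.0 passed in roc.py.
def trapezoid_auc(x, y, direction=1.0):
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    area = 0.0
    for i in range(1, len(x)):
        # Area of the trapezoid between two consecutive points
        area += (x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2.0
    return area * direction


# A perfect ROC curve (0, 0) -> (0, 1) -> (1, 1) has AUC 1.0
print(trapezoid_auc([0.0, 0.0, 1.0], [0.0, 1.0, 1.0]))  # 1.0
# The chance diagonal has AUC 0.5
print(trapezoid_auc([0.0, 1.0], [0.0, 1.0]))  # 0.5
```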

Environment

  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 1.0.1, installed via pip
  • Python & PyTorch Version (e.g., 1.0): PyTorch 1.13, Python 3.10.11
  • Any other relevant information such as OS (e.g., Linux): Linux

Additional context

The bug seems to be that the score is only computed when no `curve` is passed in and `score` is `True`. However, `curve` is reassigned to `self.compute()` before this check, so it is never empty and the score is never computed. Here is the relevant code from `BinaryROC`:

```python
# line 156-164 in src/torchmetrics/classification/roc.py
curve = curve or self.compute()
score = _auc_compute_without_check(curve[0], curve[1], 1.0) if not curve and score is True else None
return plot_curve(
    curve,
    score=score,
    ax=ax,
    label_names=("False positive rate", "True positive rate"),
    name=self.__class__.__name__,
)
```

It works as expected after changing it to:

```python
curve_computed = curve or self.compute()
score = _auc_compute_without_check(curve_computed[0], curve_computed[1], 1.0) if not curve and score is True else None
return plot_curve(
    curve_computed,
    score=score,
    ax=ax,
    label_names=("False positive rate", "True positive rate"),
    name=self.__class__.__name__,
)
```
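The root cause is ordinary Python truthiness, so it can be reproduced without TorchMetrics. In this minimal sketch, `fake_compute` and `fake_auc` are hypothetical stand-ins for `self.compute()` and `_auc_compute_without_check`, using lists instead of tensors; `plot_score_buggy` and `plot_score_fixed` mirror the original and the patched expressions.

```python
# Hypothetical stand-ins: the real metric returns tensors, here plain lists.
def fake_compute():
    # (x values, y values, thresholds) of a perfect ROC curve
    return ([0.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 0.5, 0.0])


def fake_auc(x, y):
    # Trapezoidal area under the curve
    return sum((x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2 for i in range(1, len(x)))


def plot_score_buggy(curve=None, score=None):
    curve = curve or fake_compute()  # curve is now always a non-empty tuple
    # `not curve` is therefore always False, so the score is never computed
    return fake_auc(curve[0], curve[1]) if not curve and score is True else None


def plot_score_fixed(curve=None, score=None):
    curve_computed = curve or fake_compute()
    # The check now looks at the caller's argument, not the reassigned value
    return fake_auc(curve_computed[0], curve_computed[1]) if not curve and score is True else None


print(plot_score_buggy(score=True))  # None
print(plot_score_fixed(score=True))  # 1.0
```

Even with the fix, a score is only computed when no `curve` is passed in. Testing `curve is None` instead of `not curve` would state that intent more explicitly.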
yztxwd added the bug / fix and help wanted labels on Jul 26, 2023
github-actions commented:

Hi! thanks for your contribution!, great first issue!
