
Wrong numbers for multiclass acc/prec/f1 if some class is not listed in true/pred tensors #295

Closed
notonlyvandalzzz opened this issue Jun 15, 2021 · 7 comments · Fixed by #303
Labels: bug / fix · help wanted · Priority: Critical

Comments

@notonlyvandalzzz

🐛 Bug

TorchMetrics' output for precision, accuracy, and F1 differs from sklearn's version of the same metrics when average='macro' is used and the pred/true tensors contain labels for all classes except one.

To Reproduce

Steps to reproduce the behavior:

  1. Take one tensor of class labels as y_pred, copy it to y_true, and pass both through torchmetrics.functional.f1/accuracy/precision with average="macro" and num_classes set to the number of classes you used.
  2. Convert the tensors to lists and pass them into the corresponding sklearn metrics, with average="macro" where applicable.
  3. Compare the numbers.

Code sample

import torch
import torchmetrics
from sklearn.metrics import accuracy_score, f1_score, precision_score

y_pred = y_true = torch.tensor([1, 6, 6, 6, 3, 6, 3, 6, 6, 3, 6, 3, 6, 6, 3, 6, 3, 3, 6, 6, 6, 6, 6, 6,
        6, 3, 5, 6, 6, 3, 6, 6, 6, 6, 3, 6, 6, 3, 3, 1, 3, 1, 6, 3, 3, 1, 3, 6,
        4, 6, 6, 6, 6, 6, 6, 3, 3, 6, 3, 6, 1, 0, 5, 3, 6, 6, 6, 6, 3, 0, 6, 3,
        3, 3, 6, 3, 4])  # no sample with class id=2!
print('Metrics:')
print(f'F1: {f1_score(y_true.numpy().astype(int).tolist(), y_pred.numpy().astype(int).tolist(), average="macro")}')
print(f'ACC: {accuracy_score(y_true.numpy().astype(int).tolist(), y_pred.numpy().astype(int).tolist())}')
print(f'PREC: {precision_score(y_true.numpy().astype(int).tolist(), y_pred.numpy().astype(int).tolist(), average="macro")}')
print(f'TM F1: {torchmetrics.functional.f1(y_pred, y_true, average="macro", num_classes=7)}')
print(f'TM acc: {torchmetrics.functional.accuracy(y_pred, y_true, average="macro", num_classes=7)}')
print(f'TM prec: {torchmetrics.functional.precision(y_pred, y_true, average="macro", num_classes=7)}')
Metrics:
F1: 1.0
ACC: 1.0
PREC: 1.0
TM F1: 0.8571428656578064
TM acc: 0.8571428656578064
TM prec: 0.8571428656578064

Expected behavior

They should all be 1.0, as with the sklearn versions.
Adding even a single item with id=2 makes the numbers correct, so this is definitely a bug that appears when a class is included in num_classes but has no examples in either the pred or true tensor. The sketch below illustrates where the 0.8571 comes from.
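
A minimal sketch (not the library code) of the suspected behavior: with num_classes=7 and average="macro", the absent class id=2 appears to contribute a score of 0 to the per-class average, giving 6/7 ≈ 0.8571 instead of 1.0.

import torch

# Per-class scores as torchmetrics seems to compute them for the example above;
# class 2 has no samples, so its score defaults to 0 instead of being excluded.
per_class_score = torch.tensor([1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0])
print(per_class_score.mean())  # tensor(0.8571) -- matches the TM output above

sklearn, by contrast, only averages over the classes that actually occur in y_true/y_pred, hence 1.0.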

Environment

  • CUDA:
    - GPU: NVIDIA Quadro RTX 5000
    - available: True
    - version: 11.0
  • Packages:
    - numpy: 1.20.3
    - pyTorch_debug: True
    - pyTorch_version: 1.7.0+cu110
    - pytorch-lightning: 1.3.4
    - torchmetrics: 0.3.2
    - tqdm: 4.51.0
  • System:
    - OS: Linux
    - architecture: 64bit
    - processor: x86_64
    - python: 3.7.10
    - version: #1 SMP Sun Feb 14 18:10:38 EST 2021
notonlyvandalzzz added the bug / fix and help wanted labels on Jun 15, 2021
@github-actions

Hi! Thanks for your contribution, great first issue!

notonlyvandalzzz changed the title from "Wrong numbers for multiclass acc/prec/f1 if some class is missing in labels" to "Wrong numbers for multiclass acc/prec/f1 if some class is not listed in true/pred tensors" on Jun 15, 2021
Borda added the Priority: Critical label on Jun 15, 2021
@vatch123
Contributor

Hi @Borda, I can have a look at this if someone is not already involved here.

@Borda
Member

Borda commented Jun 17, 2021

Hi @Borda, I can have a look at this if someone is not already involved here.

That would be great, thank you!

Borda assigned vatch123 and unassigned SkafteNicki on Jun 17, 2021
@SkafteNicki
Member

Hi @vatch123,
Thank you for wanting to contribute. Just wanted to let you know that I debugged this a bit and I think it can be solved by changing this line
https://github.com/PyTorchLightning/metrics/blob/1841cad3839f5d1907a1bb8bb6a266de5c5333f9/torchmetrics/functional/classification/stat_scores.py#L186
to also account for the zero_div_mask, so:

weights = torch.where(zero_div_mask | ignore_mask, tensor(0.0, device=weights.device), weights)
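
For context, a self-contained sketch (not the actual torchmetrics internals; the tensors below are made up to mirror the example in this issue) of how zeroing the weights for zero-division classes would change the macro reduction:

import torch
from torch import tensor

# Per-class scores for the example above: class 2 never occurs, so its
# denominator is zero and its score defaults to 0.
scores = tensor([1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0])
zero_div_mask = tensor([False, False, True, False, False, False, False])
ignore_mask = torch.zeros_like(zero_div_mask)  # no ignore_index in this example

# Macro averaging starts from equal weights per class.
weights = torch.ones_like(scores)

# Current behavior: only ignore_mask zeroes the weights -> 6/7 ≈ 0.8571
print((scores * weights).sum() / weights.sum())

# Suggested change: also zero the weights where zero_div_mask is set -> 6/6 = 1.0
weights = torch.where(zero_div_mask | ignore_mask, tensor(0.0, device=weights.device), weights)
print((scores * weights).sum() / weights.sum())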

@vatch123
Contributor

Hi @vatch123,
Thank you for wanting to contribute. Just wanted to let you know that I debugged this a bit and I think it can be solved by changing this line
https://github.com/PyTorchLightning/metrics/blob/1841cad3839f5d1907a1bb8bb6a266de5c5333f9/torchmetrics/functional/classification/stat_scores.py#L186

to also account for the zero_div_mask, so:

weights = torch.where(zero_div_mask | ignore_mask, tensor(0.0, device=weights.device), weights)

Hi. Thanks for this. Let me have a look at this.

@celsofranssa

Any update on this?

@SkafteNicki
Member

@celsofranssa we have an open PR #303 that should fix it.
