
Incorrect Precision/Recall/F1 score compared to sklearn #3035

Closed
junwen-austin opened this issue Aug 18, 2020 · 14 comments · Fixed by #3322
Labels
bug (Something isn't working) · help wanted (Open to be worked on) · priority: 0 (High priority task)

Comments

@junwen-austin

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

  1. Copy the code
  2. Run the code from top to bottom
  3. Compare print results
  4. See the difference between sklearn and Lightning

Code

import torch
import numpy as np
import pytorch_lightning as pl
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

print(pl.__version__)


#### Generate binary data
pl.seed_everything(2020)
n = 10000  # number of samples
y = np.random.choice([0, 1], n)
y_pred = np.random.choice([0, 1], n, p=[0.1, 0.9])
y_tensor = torch.tensor(y)
y_pred_tensor = torch.tensor(y_pred)


# Accuracy appears alright
print('accuracy from sklearn', accuracy_score(y, y_pred))
print('accuracy from lightning functional', pl.metrics.functional.accuracy(y_pred_tensor, y_tensor, num_classes=2))
print('accuracy from lightning tensor', pl.metrics.Accuracy(num_classes=2)(y_pred_tensor, y_tensor))

## results
## accuracy from sklearn 0.4986
## accuracy from lightning functional tensor(0.4986)
## accuracy from lightning tensor tensor(0.4986)

# Precision appears to be off, compared to sklearn
print('precision from sklearn', precision_score(y, y_pred))
print('precision from lightning functional', pl.metrics.functional.precision(y_pred_tensor, y_tensor, num_classes=2))
print('precision from lightning tensor', pl.metrics.Precision(num_classes=2)(y_pred_tensor, y_tensor))

## precision from sklearn 0.5005544466622311
## precision from lightning functional tensor(0.4906)
## precision from lightning tensor tensor(0.4906)

#Recall appears to be off, compared to sklearn
print('recall from sklearn', recall_score(y, y_pred))
print('recall from lightning functional', pl.metrics.functional.recall(y_pred_tensor, y_tensor, num_classes=2))
print('recall from lightning tensor', pl.metrics.Recall(num_classes=2)(y_pred_tensor, y_tensor))

## recall from sklearn 0.8984872611464968
## recall from lightning functional tensor(0.4967)
## recall from lightning tensor tensor(0.4967)

#F1 appears to be off, compared to sklearn
print('F1 from sklearn', f1_score(y, y_pred))
print('F1 from lightning functional', pl.metrics.functional.f1_score(y_pred_tensor, y_tensor, num_classes=2))
print('F1 from lightning tensor', pl.metrics.F1(num_classes=2)(y_pred_tensor, y_tensor))

## F1 from sklearn 0.6429283577837915
## F1 from lightning functional tensor(0.4007)
## F1 from lightning tensor tensor(0.4007)

Expected behavior

Precision/Recall/F1 results are expected to be consistent with those from sklearn.

Environment

Please copy and paste the output from our
environment collection script
(or fill out the checklist below manually).

You can get the script and run it with:

wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/tests/collect_env_details.py
# For security purposes, please check the contents of collect_env_details.py before running it.
python collect_env_details.py
  • PyTorch Version: 1.5.1
  • OS (e.g., Linux): macOS
  • How you installed PyTorch (conda, pip, source): Pip
  • Build command you used (if compiling from source):
  • Python version: 3.7
  • CUDA/cuDNN version: None
  • GPU models and configuration: None
  • Any other relevant information:

Additional context

@junwen-austin junwen-austin added bug Something isn't working help wanted Open to be worked on labels Aug 18, 2020
@junwen-austin
Author

By the way, Precision/Recall/F1 scores are also off in PyTorch Lightning 0.8.5.

@ananyahjha93 ananyahjha93 self-assigned this Aug 18, 2020
@ananyahjha93 ananyahjha93 added the priority: 0 High priority task label Aug 18, 2020
@williamFalcon
Contributor

I thought we tested against sklearn?

@Borda
Member

Borda commented Aug 18, 2020

@justusschock @SkafteNicki mind having a look, pls 🐰

@SkafteNicki
Member

It's because we calculate the macro average instead of the micro average, which is the default in sklearn.
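
For illustration, a minimal sklearn-only sketch of how the averaging modes differ on the same kind of random binary data as in the report (exact numbers depend on the seed):

import numpy as np
from sklearn.metrics import precision_score

# random binary labels and skewed predictions, as in the original report
rng = np.random.RandomState(2020)
y = rng.choice([0, 1], 10000)
y_pred = rng.choice([0, 1], 10000, p=[0.1, 0.9])

# sklearn's precision_score defaults to average='binary' for binary targets
# (score for the positive class only); 'macro' averages the per-class scores,
# which is roughly what Lightning's default reduction returns here; 'micro'
# aggregates global TP/FP counts.
for avg in ('binary', 'micro', 'macro'):
    print(avg, precision_score(y, y_pred, average=avg))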

@SkafteNicki
Member

At some point we should probably support the different averaging methods that sklearn also has, as one averaging method may be more meaningful than another in some cases (like very unbalanced datasets).

@ananyahjha93 ananyahjha93 removed their assignment Aug 18, 2020
@junwen-austin
Author

I figured out the reason why there is a discrepancy:

For binary classification, to recover sklearn's results, precision/recall/F1 should be computed something like below:

pl.metrics.functional.precision(y_pred_tensor, y_tensor, num_classes=2, reduction='none')[1]

where reduction by default is elementwise_mean instead of none, and the [1] returns the score for class 1.
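
For example, a minimal sketch of the full comparison (assuming recall and f1_score in this version accept the same reduction argument as precision):

import torch
import numpy as np
import pytorch_lightning as pl
from sklearn.metrics import precision_score, recall_score, f1_score

pl.seed_everything(2020)
y = np.random.choice([0, 1], 10000)
y_pred = np.random.choice([0, 1], 10000, p=[0.1, 0.9])
y_t, y_p = torch.tensor(y), torch.tensor(y_pred)

# reduction='none' gives one score per class; [1] selects the positive class,
# which is what sklearn reports by default for binary targets
print(precision_score(y, y_pred),
      pl.metrics.functional.precision(y_p, y_t, num_classes=2, reduction='none')[1])
print(recall_score(y, y_pred),
      pl.metrics.functional.recall(y_p, y_t, num_classes=2, reduction='none')[1])
print(f1_score(y, y_pred),
      pl.metrics.functional.f1_score(y_p, y_t, num_classes=2, reduction='none')[1])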

We can close the issue for now, but it would be really good to update the documentation to reflect these subtle differences.

For multi-class problems, I assume there will be more nuances between Lightning and sklearn, given the different ways of averaging (macro, micro and so on).

@Borda
Member

Borda commented Aug 19, 2020

@junwen-austin mind updating the docs so we avoid similar questions in the future...

@junwen-austin
Author

@Borda Yes I plan to do more testing on metrics if you do not mind and then update the docs so that we have more examples. Does this sound good to you?

@Borda
Member

Borda commented Aug 21, 2020

@Borda Yes I plan to do more testing on metrics if you do not mind and then update the docs so that we have more examples. Does this sound good to you?

that would be perfect!

@Borda Borda reopened this Aug 21, 2020
@raynardj

raynardj commented Sep 1, 2020

🐛 Bug

We cannot produce sklearn's micro F1 with PL, right?

  • For some scenarios, like classifying among 200 classes where most of the predicted class indices are right, micro F1 makes a lot more sense than macro F1.
  • Macro F1 for multi-class problems fluctuates greatly with batch size, since many classes appear in neither the prediction nor the label, as illustrated by the tiny-batch F1 score below.

Steps to reproduce the behavior:

  • Copy the code
  • Run the code from top to bottom
  • Compare print results
  • See the difference between sklearn and Lightning

from sklearn.metrics import f1_score as sklearn_f1
from pytorch_lightning.metrics import F1
import torch

# create a sample label tensor (used as both label and prediction, i.e. a perfect model)
y = torch.randint(high=199, size=(210,))

print("dummy label/prediction")
print(y)

sk_macro_f1 = sklearn_f1(y.numpy(), y.numpy(), labels=list(range(200)), average='macro')
sk_macro_f1_tiny_batch = sklearn_f1(y[:10].numpy(), y[:10].numpy(),
                                    labels=list(range(200)), average='macro')
sk_micro_f1 = sklearn_f1(y.numpy(), y.numpy(), labels=list(range(200)), average='micro')

pl_f1 = F1(200, reduction="elementwise_mean")
pl_ele_f1 = pl_f1(y, y)

print(f"""sklearn macro f1:\t{sk_macro_f1}
sklearn macro f1 (tiny batch):\t{sk_macro_f1_tiny_batch}
sklearn micro f1:\t{sk_micro_f1}
pl_elementwise f1:\t{pl_ele_f1}
""")

This will output the following. While PL produces a macro F1 of about 0.65, the tiny-batch macro F1 is much worse, even though the model predicted perfectly:

dummy label/prediction
tensor([  4,  61, 120,  64,  60,  18, 182, 123,  65, 149, 145,   2, 182, 154,
         46, 125,  39, 142, 144,  93, 164,  45,  70,  60, 102, 121,  39, 150,
         54, 109,  61, 120, 180,  52, 184, 189,   4,  89,  56,   5,  24, 100,
        194, 148, 152, 133,  75, 141,   6,  76,  93, 160, 173, 164,  13, 134,
        186, 176, 103,  30, 179, 172, 110, 164,  45, 157, 188, 187,  80,  54,
         77,   3,  80, 146,  42,  65,  84, 195, 132,  15,  35, 167, 110,  61,
         38, 197, 151, 102, 193,  78,  77, 169,  93, 129, 162, 168,  97, 190,
        129, 117,  38, 118, 145,  95, 173, 148,  70,  69, 147, 121, 138,  95,
         47,  41, 160, 131, 167, 116, 188, 171,  68, 196,  29,  22, 183,  29,
         90, 157, 179,  13,  26,  89, 148, 166, 193, 125, 100,  74, 130, 187,
         79, 166, 166, 131, 147, 191,  11, 147, 101, 139,  94,  20,  22, 187,
        149,  61,  55, 141, 176, 120, 152, 187, 146, 197, 192, 180, 180,  68,
          1, 115, 142,   5, 161,  77,  54, 115, 175,  39, 110,  68, 151,  98,
        102, 147,  37,  42, 154,  53, 105, 170, 114, 109,  53,  16,  62,  57,
         75,  79,  33,  42,  74,  92, 130, 151,  50, 112, 174, 113,  69,  34])
sklearn macro f1:	0.65
sklearn macro f1 (tiny batch):	0.05
sklearn micro f1:	1.0
pl_elementwise f1:	0.6499999761581421
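
For contrast, a micro-averaged F1 accumulated from global TP/FP/FN counts would not fluctuate with batch size. A plain-PyTorch sketch (not an existing PL API):

import torch

def micro_f1_from_batches(batches, num_classes=200):
    # accumulate global TP/FP/FN over all batches and classes, then compute one F1
    tp = fp = fn = 0
    for preds, target in batches:
        for c in range(num_classes):
            tp += ((preds == c) & (target == c)).sum().item()
            fp += ((preds == c) & (target != c)).sum().item()
            fn += ((preds != c) & (target == c)).sum().item()
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

y = torch.randint(high=199, size=(210,))
# perfect predictions split into tiny batches of 10 still give micro F1 == 1.0
batches = [(y[i:i + 10], y[i:i + 10]) for i in range(0, len(y), 10)]
print(micro_f1_from_batches(batches))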

@Borda @SkafteNicki

@justusschock
Member

@raynardj We are already tracking it in this issue and it will be part of our new aggregation system. However this may take a while to lay out.

@raynardj

raynardj commented Sep 2, 2020

@raynardj We are already tracking it in this issue and it will be part of our new aggregation system. However this may take a while to lay out.

I'm also in the Slack under the same username; is there anything I can contribute to the matter?

@SkafteNicki
Member

@raynardj if you want to help, please write to me on slack (username Nicki Skafte), as I already have some code ready that you could help finish :]
