Progress bar shows a different value than TensorBoard and my own output #10653
Comments
Here is a related issue: #9372. The value shown in the progress bar is a running average.
I believe we do a running average just for the loss and not for other metrics. Might be something else happening here.
Yes, I have always been a bit confused by this feature, but it is quite similar to Keras' running mean.
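To make the discussion above concrete, here is a minimal sketch of what a running mean over the per-step loss looks like (this is just an illustration of the concept, not Lightning's or Keras' actual implementation):

```python
# Toy running mean: each update returns the average of all values seen so far.
# This is the kind of smoothing a progress bar may apply to the loss.
class RunningMean:
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1
        return self.total / self.count


rm = RunningMean()
means = [rm.update(v) for v in [1.0, 2.0, 3.0]]
# means -> [1.0, 1.5, 2.0]: the displayed value lags behind the raw per-step value.
```

This explains why a smoothed loss in the progress bar can differ from the raw per-step value in TensorBoard, but as noted above, it would not apply to non-loss metrics.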
It is not related to the running average; that applies only to the loss value. For some reason, the epoch-level metrics displayed in the progress bar are delayed by two epochs instead of one. This seems to be a bug coming from the loop or the LoggerConnector. If you observe the value of the metric logged in the repro below, you can see the offset. Note: the expected behavior would be that the value shown in the progress bar is the epoch metric from the previous epoch. Here is the repro (the issuer did not provide one):

```python
import os
import time

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        preds = self(batch)
        loss = preds.sum()
        self.log("train_loss", loss)
        time.sleep(0.1)
        # Epoch-level metric shown on the progress bar; also printed for comparison.
        self.log("x", self.global_step, prog_bar=True, on_step=False, on_epoch=True)
        print(self.global_step)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        max_epochs=6,
        enable_model_summary=False,
        log_every_n_steps=1,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
```
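The delay described above can be illustrated with a small dependency-free simulation. This is not Lightning code; `simulate` and its `delay` parameter are hypothetical names used only to show the difference between the expected one-epoch lag (an `on_epoch=True` metric is only available after the epoch ends) and the reported two-epoch lag:

```python
# Toy simulation of the progress-bar display lag (not Lightning internals).
def simulate(num_epochs, delay):
    """Return what the progress bar would show during each epoch, given how
    many epochs the display lags behind the logged epoch-level metric."""
    logged = []     # metric value produced at the end of each epoch
    displayed = []  # value shown on the progress bar while running each epoch
    for epoch in range(num_epochs):
        idx = epoch - delay
        displayed.append(logged[idx] if idx >= 0 else None)
        logged.append(float(epoch))  # pretend the metric equals the epoch number
    return displayed


expected = simulate(4, delay=1)  # expected behavior: one-epoch lag
buggy = simulate(4, delay=2)     # reported behavior: two-epoch lag
# expected -> [None, 0.0, 1.0, 2.0]
# buggy    -> [None, None, 0.0, 1.0]
```

With a two-epoch delay, the bar during epoch N shows the metric from epoch N-2, which matches the mismatch between the printed values and the progress bar in the repro.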
🐛 Bug
code:
result:
We can see that the value on the progress bar is different from the output that I print, while the value on TensorBoard matches my output, so I guess something is going wrong in the progress bar.
I only use one GPU.
Environment
cc @tchaton @rohitgr7 @akihironitta @carmocca @edward-io @ananthsub @kamil-kaczmarek @Raalsky @Blaizzy @SeanNaren @kaushikb11