
Progressbar showed different value vs tensorboard and myself output #10653

Closed
Luciennnnnnn opened this issue Nov 20, 2021 · 4 comments · Fixed by #11069 or #11689

Labels: bug (Something isn't working), logging (Related to the `LoggerConnector` and `log()`), priority: 0 (High priority task), progress bar: rich
Milestone: 1.5.x

Luciennnnnnn commented Nov 20, 2021

🐛 Bug

Code: (screenshot in the original issue)

Result: (screenshot in the original issue)

The value shown in the progress bar is different from the output I print, while the value in TensorBoard matches my output, so I suspect something is going wrong in the progress bar.

I only use one GPU.

Environment

  • PyTorch Lightning Version: 1.4.6
  • PyTorch Version: 1.8.1

cc @tchaton @rohitgr7 @akihironitta @carmocca @edward-io @ananthsub @kamil-kaczmarek @Raalsky @Blaizzy @SeanNaren @kaushikb11

Luciennnnnnn added the bug label Nov 20, 2021

ananthsub (Contributor) commented:

Here is a related issue: #9372

The value shown in the progress bar is a running average

rohitgr7 (Contributor) commented:

> Here is a related issue: #9372
>
> The value shown in the progress bar is a running average

I believe we do a running average just for the loss and not for other metrics, so something else might be happening here.
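
For context, here is a minimal sketch of what such a windowed running mean for the progress-bar loss could look like. The class name and window size are purely illustrative, not Lightning's actual implementation:

```python
from collections import deque

import torch


class RunningMean:
    """Windowed running mean, similar in spirit to the smoothing applied to the progress-bar loss."""

    def __init__(self, window_length: int = 20):
        # hypothetical window size, chosen only for illustration
        self.window = deque(maxlen=window_length)

    def append(self, value: torch.Tensor) -> None:
        # store a detached scalar so the computation graph is not kept alive
        self.window.append(value.detach().float().item())

    def mean(self) -> float:
        return sum(self.window) / len(self.window)


if __name__ == "__main__":
    running = RunningMean(window_length=3)
    for step, loss in enumerate([1.0, 2.0, 3.0, 4.0]):
        running.append(torch.tensor(loss))
        # a progress bar showing this smoothed value lags behind the raw loss
        print(f"step={step} raw_loss={loss} running_mean={running.mean():.2f}")
```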

tchaton (Contributor) commented Nov 22, 2021

Yes, I have always been a bit confused by this feature, but it is quite similar to Keras' running mean.

awaelchli (Contributor) commented:

It is not related to the running average. This applies only to the loss value.

For some reason, the epoch-level metrics displayed in the progress bar are delayed by two epochs instead of one.
You can see this in the printout when you compare the value manually logged one epoch earlier with the one shown in the progress bar (taking the rounding into account).

This seems to be a bug coming from the loop or the LoggerConnector. If you inspect the progress_bar_metrics property, it is not yet updated by the time the progress bar accesses it, and therefore displays an old value.

Note: The expected behavior would be that the value shown in the progress bar is the epoch-metric from the previous epoch.

Here is a repro (the issue author did not provide one):

import os
import time

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        preds = self(batch)
        loss = preds.sum()
        self.log("train_loss", loss)
        time.sleep(0.1)
        # log an epoch-level metric to the progress bar; its value tracks the global step
        self.log("x", self.global_step, prog_bar=True, on_step=False, on_epoch=True)
        # print the true value so it can be compared with what the progress bar displays
        print(self.global_step)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        # a single batch per epoch means one logged value per epoch, making the lag easy to spot
        limit_train_batches=1,
        max_epochs=6,
        enable_model_summary=False,
        log_every_n_steps=1,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
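
To make the lag visible directly, one could attach a small callback (not part of the original repro; the probe class below is my own naming) that prints trainer.progress_bar_metrics at the end of each epoch and compare it with the values printed from training_step:

```python
from pytorch_lightning import Callback


class ProgressBarMetricsProbe(Callback):
    """Prints what the progress bar reads at the end of each training epoch (illustrative helper)."""

    def on_train_epoch_end(self, trainer, pl_module):
        # compare against the global_step values printed inside training_step
        print(f"epoch={trainer.current_epoch} progress_bar_metrics={trainer.progress_bar_metrics}")
```

Passing callbacks=[ProgressBarMetricsProbe()] to the Trainer in the repro above should show the displayed value trailing the manually logged one.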

tchaton added the priority: 0 label Nov 23, 2021
awaelchli added the logging label Dec 6, 2021
awaelchli added this to the 1.5.x milestone Dec 6, 2021
carmocca self-assigned this Dec 15, 2021
carmocca assigned kaushikb11 and unassigned carmocca Dec 16, 2021
carmocca reopened this Dec 16, 2021
carmocca assigned rohitgr7 and unassigned kaushikb11 Jan 31, 2022