Remove special handling of loss in progress bar (#16192)
Co-authored-by: Carlos Mocholí <[email protected]>
2 people authored and lantiga committed Jan 19, 2023
1 parent 83f6d93 commit cf0952b
Showing 12 changed files with 9 additions and 192 deletions.
1 change: 1 addition & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -65,6 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Removed the `LoggerConnector.on_train_split_start` method

- Removed the `LightningModule.precision` attribute ([#16203](https://github.com/Lightning-AI/lightning/pull/16203))
- Removed the automatic addition of a moving average of the `training_step` loss in the progress bar. Use `self.log("loss", ..., prog_bar=True)` instead. ([#16192](https://github.com/Lightning-AI/lightning/issues/16192))


### Fixed
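As the changelog entry above says, the trainer no longer injects a moving average of the `training_step` loss into the progress bar; logging the loss yourself with `prog_bar=True` restores the readout. A minimal sketch, assuming a toy module and loss that are not part of this commit:

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):  # illustrative module, not from this commit
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # the loss is no longer displayed automatically; log it explicitly
        self.log("loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

If a smoothed value is preferred, compute the average before logging; a windowed-mean sketch follows the `supporters.py` hunk below.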
20 changes: 5 additions & 15 deletions src/pytorch_lightning/callbacks/progress/base.py
@@ -242,7 +242,7 @@ def get_metrics(self, trainer, model):
Return:
Dictionary with the items to be displayed in the progress bar.
"""
standard_metrics = get_standard_metrics(trainer, pl_module)
standard_metrics = get_standard_metrics(trainer)
pbar_metrics = trainer.progress_bar_metrics
duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
if duplicates:
@@ -255,30 +255,20 @@ def get_metrics(self, trainer, model):
return {**standard_metrics, **pbar_metrics}


def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
def get_standard_metrics(trainer: "pl.Trainer") -> Dict[str, Union[int, str]]:
r"""
Returns several standard metrics displayed in the progress bar, including the average loss value,
split index of BPTT (if used) and the version of the experiment when using a logger.
Returns the standard metrics displayed in the progress bar.
Currently, it only includes the version of the experiment when using a logger.
.. code-block::
Epoch 1: 4%|▎ | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]
Epoch 1: 4%|▎ | 40/1095 [00:03<01:37, 10.84it/s, v_num=10]
Return:
Dictionary with the standard metrics to be displayed in the progress bar.
"""
# call .item() only once but store elements without graphs
running_train_loss = trainer.fit_loop.running_loss.mean()
avg_training_loss = None
if running_train_loss is not None:
avg_training_loss = running_train_loss.cpu().item()
elif pl_module.automatic_optimization:
avg_training_loss = float("NaN")

items_dict: Dict[str, Union[int, str]] = {}
if avg_training_loss is not None:
items_dict["loss"] = f"{avg_training_loss:.3g}"

if trainer.loggers:
version = _version(trainer.loggers)
if version is not None:
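`get_metrics` above merges the remaining standard metrics (essentially just `v_num`) with everything logged via `prog_bar=True`, warning on duplicate keys. Anything else shown in the bar now has to come from a logged metric or from a progress-bar subclass; a hedged sketch of the latter, with an illustrative class name:

from pytorch_lightning.callbacks import TQDMProgressBar


class CustomProgressBar(TQDMProgressBar):  # illustrative subclass, not part of this commit
    def get_metrics(self, trainer, pl_module):
        # standard metrics merged with metrics logged via prog_bar=True
        items = super().get_metrics(trainer, pl_module)
        # e.g. hide the experiment version from the bar
        items.pop("v_num", None)
        return items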
25 changes: 1 addition & 24 deletions src/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -18,7 +18,6 @@
import numpy as np
import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning import loops # import as loops to avoid circular imports
@@ -28,7 +27,7 @@
from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
@@ -60,8 +59,6 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None
self.batch_progress = BatchProgress()
self.scheduler_progress = SchedulerProgress()

self.accumulated_loss = TensorRunningAccum(window_length=20)
self.running_loss = TensorRunningAccum(window_length=20)
self.optimizer_loop = OptimizerLoop()
self.manual_loop = ManualOptimization()

@@ -294,11 +291,6 @@ def teardown(self) -> None:
self._results.cpu()
self.optimizer_loop.teardown()
self.manual_loop.teardown()
# release memory
if self.accumulated_loss.memory is not None:
self.accumulated_loss.memory = self.accumulated_loss.memory.cpu()
if self.running_loss.memory is not None:
self.running_loss.memory = self.running_loss.memory.cpu()
self.val_loop.teardown()

def on_save_checkpoint(self) -> Dict:
@@ -554,21 +546,6 @@ def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> Orde
kwargs["batch_idx"] = batch_idx
return kwargs

def _update_running_loss(self, current_loss: Tensor) -> None:
"""Updates the running loss value with the current value."""
if self.trainer.lightning_module.automatic_optimization:
# track total loss for logging (avoid mem leaks)
self.accumulated_loss.append(current_loss)

accumulated_loss = self.accumulated_loss.mean()

if accumulated_loss is not None:
# calculate running loss for display
self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)

# reset for next set of accumulated grads
self.accumulated_loss.reset()


def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> List[Optional[Dict[str, Any]]]:
"""Converts an optimizer dict to a list in which the key of the dict determines the position of the element.
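The `_convert_optim_dict` docstring above describes a dict-to-list conversion in which the key picks the position; the body is outside this hunk, so the following is only a sketch of the described behaviour, not the actual implementation:

from typing import Any, Dict, List, Optional


def convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> List[Optional[Dict[str, Any]]]:
    # the dict key selects the list position; positions without an entry become None
    return [outs.get(idx) for idx in range(num_optimizers)]


assert convert_optim_dict({0: {"loss": 1.0}, 2: {"loss": 3.0}}, 3) == [{"loss": 1.0}, None, {"loss": 3.0}]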
10 changes: 1 addition & 9 deletions src/pytorch_lightning/loops/fit_loop.py
@@ -23,7 +23,7 @@
from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
AbstractDataFetcher,
@@ -104,11 +104,6 @@ def max_steps(self, value: int) -> None:
)
self.epoch_loop.max_steps = value

@property
def running_loss(self) -> TensorRunningAccum:
"""Returns the running loss."""
return self.epoch_loop.running_loss

@Loop.restarting.setter
def restarting(self, restarting: bool) -> None:
# if the last epoch completely finished, we are not actually restarting
@@ -233,9 +228,6 @@ def on_advance_start(self) -> None:
# changing gradient according accumulation_scheduler
self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

# stores accumulated grad fractions per batch
self.epoch_loop.accumulated_loss.reset(window_length=self.trainer.accumulate_grad_batches)

self.epoch_progress.increment_ready()

self.trainer._logger_connector.on_epoch_start()
6 changes: 0 additions & 6 deletions src/pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -242,12 +242,6 @@ def _run_optimization(self, kwargs: OrderedDict, optimizer: torch.optim.Optimize

result = closure.consume_result()

if result.loss is not None:
# if no result, user decided to skip optimization
# otherwise update running loss + reset accumulated loss
# TODO: find proper way to handle updating running loss
self.trainer.fit_loop.epoch_loop._update_running_loss(result.loss)

# untoggle model params
self._run_optimization_end(opt_idx)
return result
78 changes: 0 additions & 78 deletions src/pytorch_lightning/trainer/supporters.py
@@ -18,7 +18,6 @@

import torch
from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
from torch import Tensor
from torch.utils.data import Dataset
from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
from torch.utils.data.dataset import IterableDataset
@@ -33,83 +32,6 @@
from pytorch_lightning.utilities.imports import _fault_tolerant_training


class TensorRunningAccum:
"""Tracks a running accumulation values (min, max, mean) without graph references.
Examples:
>>> accum = TensorRunningAccum(5)
>>> accum.last(), accum.mean()
(None, None)
>>> accum.append(torch.tensor(1.5))
>>> accum.last(), accum.mean()
(tensor(1.5000), tensor(1.5000))
>>> accum.append(torch.tensor(2.5))
>>> accum.last(), accum.mean()
(tensor(2.5000), tensor(2.))
>>> accum.reset()
>>> _= [accum.append(torch.tensor(i)) for i in range(13)]
>>> accum.last(), accum.mean(), accum.min(), accum.max()
(tensor(12.), tensor(10.), tensor(8.), tensor(12.))
"""

def __init__(self, window_length: int):
self.window_length = window_length
self.reset(window_length)

def reset(self, window_length: Optional[int] = None) -> None:
"""Empty the accumulator."""
if window_length is not None:
self.window_length = window_length
self.memory: Optional[Tensor] = None
self.current_idx: int = 0
self.last_idx: Optional[int] = None
self.rotated: bool = False

def last(self) -> Optional[Tensor]:
"""Get the last added element."""
if self.last_idx is not None:
assert isinstance(self.memory, Tensor)
return self.memory[self.last_idx].float()

def append(self, x: Tensor) -> None:
"""Add an element to the accumulator."""
if self.memory is None:
# tradeoff memory for speed by keeping the memory on device
self.memory = torch.zeros(self.window_length, *x.shape, device=x.device, dtype=x.dtype)

# store without grads
with torch.no_grad():
self.memory[self.current_idx] = x
self.last_idx = self.current_idx

# increase index
self.current_idx += 1

# reset index when hit limit of tensor
self.current_idx = self.current_idx % self.window_length
if self.current_idx == 0:
self.rotated = True

def mean(self) -> Optional[Tensor]:
"""Get mean value from stored elements."""
return self._agg_memory("mean")

def max(self) -> Optional[Tensor]:
"""Get maximal value from stored elements."""
return self._agg_memory("max")

def min(self) -> Optional[Tensor]:
"""Get minimal value from stored elements."""
return self._agg_memory("min")

def _agg_memory(self, how: str) -> Optional[Tensor]:
if self.last_idx is not None:
assert isinstance(self.memory, Tensor)
if self.rotated:
return getattr(self.memory.float(), how)()
return getattr(self.memory[: self.current_idx].float(), how)()


@dataclass
class SharedCycleIteratorState:
"""A state shared between all CycleIterators in a CombinedLoader.
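The removed `TensorRunningAccum` above kept a windowed, graph-free running mean of the loss (the training loop used a window of 20). Users who want the old smoothed display back can keep a small window themselves and log its mean with `prog_bar=True`; a minimal stand-in, not part of Lightning:

from collections import deque

import torch


class RunningMean:
    """Windowed mean of detached scalars, similar in spirit to the removed TensorRunningAccum."""

    def __init__(self, window_length: int = 20):
        self.window = deque(maxlen=window_length)

    def append(self, x: torch.Tensor) -> None:
        self.window.append(x.detach())

    def mean(self):
        if not self.window:
            return None
        return torch.stack(list(self.window)).float().mean()

Appending the detached `training_step` loss each step and calling `self.log("loss", running.mean(), prog_bar=True)` approximates the old readout, minus the `accumulate_grad_batches` scaling the loop used to apply.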
2 changes: 1 addition & 1 deletion src/pytorch_lightning/tuner/lr_finder.py
@@ -379,7 +379,7 @@ def on_train_batch_end(
if self.progress_bar:
self.progress_bar.update()

loss_tensor = trainer.fit_loop.running_loss.last()
loss_tensor = outputs if isinstance(outputs, torch.Tensor) else outputs["loss"]
assert loss_tensor is not None
current_loss = loss_tensor.item()
current_step = trainer.global_step
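The learning-rate finder now takes the loss straight from the `training_step` output that `on_train_batch_end` receives, instead of the removed `running_loss` accumulator. Other callbacks that tracked the running loss can follow the same pattern; an illustrative sketch, not part of this commit:

import torch

from pytorch_lightning import Callback


class LossHistory(Callback):  # illustrative callback
    def __init__(self):
        self.losses = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # `outputs` is whatever training_step returned: a tensor or a dict with a "loss" key
        loss = outputs if isinstance(outputs, torch.Tensor) else outputs["loss"]
        self.losses.append(loss.item())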
@@ -353,7 +353,6 @@ class MockedProgressBar(RichProgressBar):
def get_metrics(self, trainer, pl_module):
items = super().get_metrics(trainer, model)
del items["v_num"]
del items["loss"]
# this is equivalent to mocking `set_postfix` as this method gets called every time
self.calls[trainer.state.fn].append(
(trainer.state.stage, trainer.current_epoch, trainer.global_step, items)
@@ -660,7 +660,6 @@ def get_metrics(self, trainer: Trainer, model: LightningModule):
model = BoringModel()
trainer.fit(model)
standard_metrics = progress_bar.get_metrics(trainer, model)
assert "loss" in standard_metrics.keys()
assert "v_num" not in standard_metrics.keys()


@@ -673,7 +672,6 @@ class MockedProgressBar(TQDMProgressBar):
def get_metrics(self, trainer, pl_module):
items = super().get_metrics(trainer, model)
del items["v_num"]
del items["loss"]
# this is equivalent to mocking `set_postfix` as this method gets called every time
self.calls[trainer.state.fn].append(
(trainer.state.stage, trainer.current_epoch, trainer.global_step, items)
15 changes: 0 additions & 15 deletions tests/tests_pytorch/trainer/logging_/test_logger_connector.py
@@ -352,21 +352,6 @@ def test_epoch_end(self, outputs):
trainer.test(model)


def test_logging_to_progress_bar_with_reserved_key(tmpdir):
"""Test that logging a metric with a reserved name to the progress bar raises a warning."""

class TestModel(BoringModel):
def training_step(self, *args, **kwargs):
output = super().training_step(*args, **kwargs)
self.log("loss", output["loss"], prog_bar=True)
return output

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
trainer.fit(model)


@pytest.mark.parametrize("add_dataloader_idx", [False, True])
def test_auto_add_dataloader_idx(tmpdir, add_dataloader_idx):
"""test that auto_add_dataloader_idx argument works."""
19 changes: 0 additions & 19 deletions tests/tests_pytorch/trainer/logging_/test_progress_bar_logging.py

This file was deleted.

22 changes: 0 additions & 22 deletions tests/tests_pytorch/trainer/test_supporters.py
@@ -32,35 +32,13 @@
CombinedLoader,
CombinedLoaderIterator,
CycleIterator,
TensorRunningAccum,
)
from pytorch_lightning.utilities.auto_restart import CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.runif import RunIf


def test_tensor_running_accum_reset():
"""Test that reset would set all attributes to the initialization state."""

window_length = 10

accum = TensorRunningAccum(window_length=window_length)
assert accum.last() is None
assert accum.mean() is None

accum.append(torch.tensor(1.5))
assert accum.last() == torch.tensor(1.5)
assert accum.mean() == torch.tensor(1.5)

accum.reset()
assert accum.window_length == window_length
assert accum.memory is None
assert accum.current_idx == 0
assert accum.last_idx is None
assert not accum.rotated


def test_cycle_iterator():
"""Test the cycling function of `CycleIterator`"""

