Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove special handling of loss in progress bar #16192

Merged
merged 16 commits into from
Jan 16, 2023
7 changes: 7 additions & 0 deletions src/lightning_fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [2.0.0] - 202Y-MM-DD

### Removed

- Removed moving average of training step loss value for the progress bar ([#16192](https://github.com/Lightning-AI/lightning/issues/16192))
awaelchli marked this conversation as resolved.
Show resolved Hide resolved


## [1.9.0] - 202Y-MM-DD

### Added
Expand Down
19 changes: 11 additions & 8 deletions src/pytorch_lightning/callbacks/progress/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def get_metrics(self, trainer, model):

def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
r"""
Returns several standard metrics displayed in the progress bar, including the average loss value,
Returns several standard metrics displayed in the progress bar, including the latest loss value,
split index of BPTT (if used) and the version of the experiment when using a logger.

.. code-block::
Expand All @@ -267,17 +267,20 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule")
Return:
Dictionary with the standard metrics to be displayed in the progress bar.
"""
loss_value = None
loss_metric = None

# call .item() only once but store elements without graphs
running_train_loss = trainer.fit_loop.running_loss.mean()
avg_training_loss = None
if running_train_loss is not None:
avg_training_loss = running_train_loss.cpu().item()
if trainer.training:
loss_metric = trainer.fit_loop._results.get("training_step.train_loss")
carmocca marked this conversation as resolved.
Show resolved Hide resolved
if loss_metric is not None:
loss_value = loss_metric.value.cpu().item()
elif pl_module.automatic_optimization:
avg_training_loss = float("NaN")
loss_value = float("NaN")

items_dict: Dict[str, Union[int, str]] = {}
if avg_training_loss is not None:
items_dict["loss"] = f"{avg_training_loss:.3g}"
if loss_value is not None:
items_dict["loss"] = f"{loss_value:.3g}"

if trainer.loggers:
version = _version(trainer.loggers)
Expand Down
25 changes: 1 addition & 24 deletions src/pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import numpy as np
import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning import loops # import as loops to avoid circular imports
Expand All @@ -28,7 +27,7 @@
from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
Expand Down Expand Up @@ -60,8 +59,6 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None
self.batch_progress = BatchProgress()
self.scheduler_progress = SchedulerProgress()

self.accumulated_loss = TensorRunningAccum(window_length=20)
self.running_loss = TensorRunningAccum(window_length=20)
self.optimizer_loop = OptimizerLoop()
self.manual_loop = ManualOptimization()

Expand Down Expand Up @@ -294,11 +291,6 @@ def teardown(self) -> None:
self._results.cpu()
self.optimizer_loop.teardown()
self.manual_loop.teardown()
# release memory
if self.accumulated_loss.memory is not None:
self.accumulated_loss.memory = self.accumulated_loss.memory.cpu()
if self.running_loss.memory is not None:
self.running_loss.memory = self.running_loss.memory.cpu()
self.val_loop.teardown()

def on_save_checkpoint(self) -> Dict:
Expand Down Expand Up @@ -554,21 +546,6 @@ def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> Orde
kwargs["batch_idx"] = batch_idx
return kwargs

def _update_running_loss(self, current_loss: Tensor) -> None:
"""Updates the running loss value with the current value."""
if self.trainer.lightning_module.automatic_optimization:
# track total loss for logging (avoid mem leaks)
self.accumulated_loss.append(current_loss)

accumulated_loss = self.accumulated_loss.mean()

if accumulated_loss is not None:
# calculate running loss for display
self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)

# reset for next set of accumulated grads
self.accumulated_loss.reset()


def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> List[Optional[Dict[str, Any]]]:
"""Converts an optimizer dict to a list in which the key of the dict determines the position of the element.
Expand Down
10 changes: 1 addition & 9 deletions src/pytorch_lightning/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
AbstractDataFetcher,
Expand Down Expand Up @@ -104,11 +104,6 @@ def max_steps(self, value: int) -> None:
)
self.epoch_loop.max_steps = value

@property
def running_loss(self) -> TensorRunningAccum:
"""Returns the running loss."""
return self.epoch_loop.running_loss

@Loop.restarting.setter
def restarting(self, restarting: bool) -> None:
# if the last epoch completely finished, we are not actually restarting
Expand Down Expand Up @@ -233,9 +228,6 @@ def on_advance_start(self) -> None:
# changing gradient according accumulation_scheduler
self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

# stores accumulated grad fractions per batch
self.epoch_loop.accumulated_loss.reset(window_length=self.trainer.accumulate_grad_batches)

self.epoch_progress.increment_ready()

self.trainer._logger_connector.on_epoch_start()
Expand Down
6 changes: 0 additions & 6 deletions src/pytorch_lightning/loops/optimization/optimizer_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,6 @@ def _run_optimization(self, kwargs: OrderedDict, optimizer: torch.optim.Optimize

result = closure.consume_result()

if result.loss is not None:
# if no result, user decided to skip optimization
# otherwise update running loss + reset accumulated loss
# TODO: find proper way to handle updating running loss
self.trainer.fit_loop.epoch_loop._update_running_loss(result.loss)

# untoggle model params
self._run_optimization_end(opt_idx)
return result
Expand Down
78 changes: 0 additions & 78 deletions src/pytorch_lightning/trainer/supporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import torch
from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
from torch import Tensor
from torch.utils.data import Dataset
from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
from torch.utils.data.dataset import IterableDataset
Expand All @@ -33,83 +32,6 @@
from pytorch_lightning.utilities.imports import _fault_tolerant_training


class TensorRunningAccum:
"""Tracks a running accumulation values (min, max, mean) without graph references.

Examples:
>>> accum = TensorRunningAccum(5)
>>> accum.last(), accum.mean()
(None, None)
>>> accum.append(torch.tensor(1.5))
>>> accum.last(), accum.mean()
(tensor(1.5000), tensor(1.5000))
>>> accum.append(torch.tensor(2.5))
>>> accum.last(), accum.mean()
(tensor(2.5000), tensor(2.))
>>> accum.reset()
>>> _= [accum.append(torch.tensor(i)) for i in range(13)]
>>> accum.last(), accum.mean(), accum.min(), accum.max()
(tensor(12.), tensor(10.), tensor(8.), tensor(12.))
"""

def __init__(self, window_length: int):
self.window_length = window_length
self.reset(window_length)

def reset(self, window_length: Optional[int] = None) -> None:
"""Empty the accumulator."""
if window_length is not None:
self.window_length = window_length
self.memory: Optional[Tensor] = None
self.current_idx: int = 0
self.last_idx: Optional[int] = None
self.rotated: bool = False

def last(self) -> Optional[Tensor]:
"""Get the last added element."""
if self.last_idx is not None:
assert isinstance(self.memory, Tensor)
return self.memory[self.last_idx].float()

def append(self, x: Tensor) -> None:
"""Add an element to the accumulator."""
if self.memory is None:
# tradeoff memory for speed by keeping the memory on device
self.memory = torch.zeros(self.window_length, *x.shape, device=x.device, dtype=x.dtype)

# store without grads
with torch.no_grad():
self.memory[self.current_idx] = x
self.last_idx = self.current_idx

# increase index
self.current_idx += 1

# reset index when hit limit of tensor
self.current_idx = self.current_idx % self.window_length
if self.current_idx == 0:
self.rotated = True

def mean(self) -> Optional[Tensor]:
"""Get mean value from stored elements."""
return self._agg_memory("mean")

def max(self) -> Optional[Tensor]:
"""Get maximal value from stored elements."""
return self._agg_memory("max")

def min(self) -> Optional[Tensor]:
"""Get minimal value from stored elements."""
return self._agg_memory("min")

def _agg_memory(self, how: str) -> Optional[Tensor]:
if self.last_idx is not None:
assert isinstance(self.memory, Tensor)
if self.rotated:
return getattr(self.memory.float(), how)()
return getattr(self.memory[: self.current_idx].float(), how)()


@dataclass
class SharedCycleIteratorState:
"""A state shared between all CycleIterators in a CombinedLoader.
Expand Down
3 changes: 2 additions & 1 deletion src/pytorch_lightning/tuner/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,8 @@ def on_train_batch_end(
if self.progress_bar:
self.progress_bar.update()

loss_tensor = trainer.fit_loop.running_loss.last()
# TODO: should we read it from the local variable "outputs"?
loss_tensor = trainer.fit_loop._results["training_step.train_loss"].value
assert loss_tensor is not None
current_loss = loss_tensor.item()
current_step = trainer.global_step
Expand Down
22 changes: 0 additions & 22 deletions tests/tests_pytorch/trainer/test_supporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,13 @@
CombinedLoader,
CombinedLoaderIterator,
CycleIterator,
TensorRunningAccum,
)
from pytorch_lightning.utilities.auto_restart import CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.runif import RunIf


def test_tensor_running_accum_reset():
"""Test that reset would set all attributes to the initialization state."""

window_length = 10

accum = TensorRunningAccum(window_length=window_length)
assert accum.last() is None
assert accum.mean() is None

accum.append(torch.tensor(1.5))
assert accum.last() == torch.tensor(1.5)
assert accum.mean() == torch.tensor(1.5)

accum.reset()
assert accum.window_length == window_length
assert accum.memory is None
assert accum.current_idx == 0
assert accum.last_idx is None
assert not accum.rotated


def test_cycle_iterator():
"""Test the cycling function of `CycleIterator`"""

Expand Down