diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 98ec7556bed49..4a1cedba6d003 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -50,6 +50,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Removed the `LoggerConnector.on_train_split_start` method
 
 - Removed the `LightningModule.precision` attribute ([#16203](https://github.com/Lightning-AI/lightning/pull/16203))
 
+- Removed the automatic addition of a moving average of the `training_step` loss in the progress bar. Use `self.log("loss", ..., prog_bar=True)` instead. ([#16192](https://github.com/Lightning-AI/lightning/issues/16192))
 
 ## [1.9.0] - 2023-01-17
diff --git a/src/pytorch_lightning/callbacks/progress/base.py b/src/pytorch_lightning/callbacks/progress/base.py
index b757325143c62..749a3c9e83211 100644
--- a/src/pytorch_lightning/callbacks/progress/base.py
+++ b/src/pytorch_lightning/callbacks/progress/base.py
@@ -242,7 +242,7 @@ def get_metrics(self, trainer, model):
         Return:
             Dictionary with the items to be displayed in the progress bar.
         """
-        standard_metrics = get_standard_metrics(trainer, pl_module)
+        standard_metrics = get_standard_metrics(trainer)
         pbar_metrics = trainer.progress_bar_metrics
         duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
         if duplicates:
@@ -255,30 +255,20 @@ def get_metrics(self, trainer, model):
         return {**standard_metrics, **pbar_metrics}
 
 
-def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
+def get_standard_metrics(trainer: "pl.Trainer") -> Dict[str, Union[int, str]]:
     r"""
-    Returns several standard metrics displayed in the progress bar, including the average loss value,
-    split index of BPTT (if used) and the version of the experiment when using a logger.
+    Returns the standard metrics displayed in the progress bar.
+    Currently, it only includes the version of the experiment when using a logger.
 
     .. code-block::
 
-        Epoch 1:   4%|▎         | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]
+        Epoch 1:   4%|▎         | 40/1095 [00:03<01:37, 10.84it/s, v_num=10]
 
     Return:
         Dictionary with the standard metrics to be displayed in the progress bar.
""" - # call .item() only once but store elements without graphs - running_train_loss = trainer.fit_loop.running_loss.mean() - avg_training_loss = None - if running_train_loss is not None: - avg_training_loss = running_train_loss.cpu().item() - elif pl_module.automatic_optimization: - avg_training_loss = float("NaN") items_dict: Dict[str, Union[int, str]] = {} - if avg_training_loss is not None: - items_dict["loss"] = f"{avg_training_loss:.3g}" - if trainer.loggers: version = _version(trainer.loggers) if version is not None: diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py index 980d6aa133c3d..e63569f67a960 100644 --- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -18,7 +18,6 @@ import numpy as np import torch from lightning_utilities.core.apply_func import apply_to_collection -from torch import Tensor import pytorch_lightning as pl from pytorch_lightning import loops # import as loops to avoid circular imports @@ -28,7 +27,7 @@ from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress -from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum +from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher @@ -60,8 +59,6 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None self.batch_progress = BatchProgress() self.scheduler_progress = SchedulerProgress() - self.accumulated_loss = TensorRunningAccum(window_length=20) - self.running_loss = TensorRunningAccum(window_length=20) self.optimizer_loop = OptimizerLoop() self.manual_loop = ManualOptimization() @@ -294,11 +291,6 @@ def teardown(self) -> None: self._results.cpu() self.optimizer_loop.teardown() self.manual_loop.teardown() - # release memory - if self.accumulated_loss.memory is not None: - self.accumulated_loss.memory = self.accumulated_loss.memory.cpu() - if self.running_loss.memory is not None: - self.running_loss.memory = self.running_loss.memory.cpu() self.val_loop.teardown() def on_save_checkpoint(self) -> Dict: @@ -554,21 +546,6 @@ def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> Orde kwargs["batch_idx"] = batch_idx return kwargs - def _update_running_loss(self, current_loss: Tensor) -> None: - """Updates the running loss value with the current value.""" - if self.trainer.lightning_module.automatic_optimization: - # track total loss for logging (avoid mem leaks) - self.accumulated_loss.append(current_loss) - - accumulated_loss = self.accumulated_loss.mean() - - if accumulated_loss is not None: - # calculate running loss for display - self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) - - # reset for next set of accumulated grads - self.accumulated_loss.reset() - def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> List[Optional[Dict[str, Any]]]: """Converts an optimizer dict to a list in which the key of the dict determines the position of the element. 
diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py
index c0ac03cc98a95..341be3932deb3 100644
--- a/src/pytorch_lightning/loops/fit_loop.py
+++ b/src/pytorch_lightning/loops/fit_loop.py
@@ -23,7 +23,7 @@
 from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
 from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
 from pytorch_lightning.trainer.progress import Progress
-from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum
+from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import (
     AbstractDataFetcher,
@@ -104,11 +104,6 @@ def max_steps(self, value: int) -> None:
             )
         self.epoch_loop.max_steps = value
 
-    @property
-    def running_loss(self) -> TensorRunningAccum:
-        """Returns the running loss."""
-        return self.epoch_loop.running_loss
-
     @Loop.restarting.setter
     def restarting(self, restarting: bool) -> None:
         # if the last epoch completely finished, we are not actually restarting
@@ -233,9 +228,6 @@ def on_advance_start(self) -> None:
         # changing gradient according accumulation_scheduler
         self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
 
-        # stores accumulated grad fractions per batch
-        self.epoch_loop.accumulated_loss.reset(window_length=self.trainer.accumulate_grad_batches)
-
         self.epoch_progress.increment_ready()
 
         self.trainer._logger_connector.on_epoch_start()
diff --git a/src/pytorch_lightning/loops/optimization/optimizer_loop.py b/src/pytorch_lightning/loops/optimization/optimizer_loop.py
index 0158f97f686cf..e45fe2be49aa4 100644
--- a/src/pytorch_lightning/loops/optimization/optimizer_loop.py
+++ b/src/pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -242,12 +242,6 @@ def _run_optimization(self, kwargs: OrderedDict, optimizer: torch.optim.Optimize
 
         result = closure.consume_result()
 
-        if result.loss is not None:
-            # if no result, user decided to skip optimization
-            # otherwise update running loss + reset accumulated loss
-            # TODO: find proper way to handle updating running loss
-            self.trainer.fit_loop.epoch_loop._update_running_loss(result.loss)
-
         # untoggle model params
         self._run_optimization_end(opt_idx)
         return result
diff --git a/src/pytorch_lightning/trainer/supporters.py b/src/pytorch_lightning/trainer/supporters.py
index a1fe23ea02e06..af84c27c40a6f 100644
--- a/src/pytorch_lightning/trainer/supporters.py
+++ b/src/pytorch_lightning/trainer/supporters.py
@@ -18,7 +18,6 @@
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
-from torch import Tensor
 from torch.utils.data import Dataset
 from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
 from torch.utils.data.dataset import IterableDataset
@@ -33,83 +32,6 @@
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 
 
-class TensorRunningAccum:
-    """Tracks a running accumulation values (min, max, mean) without graph references.
-
-    Examples:
-        >>> accum = TensorRunningAccum(5)
-        >>> accum.last(), accum.mean()
-        (None, None)
-        >>> accum.append(torch.tensor(1.5))
-        >>> accum.last(), accum.mean()
-        (tensor(1.5000), tensor(1.5000))
-        >>> accum.append(torch.tensor(2.5))
-        >>> accum.last(), accum.mean()
-        (tensor(2.5000), tensor(2.))
-        >>> accum.reset()
-        >>> _= [accum.append(torch.tensor(i)) for i in range(13)]
-        >>> accum.last(), accum.mean(), accum.min(), accum.max()
-        (tensor(12.), tensor(10.), tensor(8.), tensor(12.))
-    """
-
-    def __init__(self, window_length: int):
-        self.window_length = window_length
-        self.reset(window_length)
-
-    def reset(self, window_length: Optional[int] = None) -> None:
-        """Empty the accumulator."""
-        if window_length is not None:
-            self.window_length = window_length
-        self.memory: Optional[Tensor] = None
-        self.current_idx: int = 0
-        self.last_idx: Optional[int] = None
-        self.rotated: bool = False
-
-    def last(self) -> Optional[Tensor]:
-        """Get the last added element."""
-        if self.last_idx is not None:
-            assert isinstance(self.memory, Tensor)
-            return self.memory[self.last_idx].float()
-
-    def append(self, x: Tensor) -> None:
-        """Add an element to the accumulator."""
-        if self.memory is None:
-            # tradeoff memory for speed by keeping the memory on device
-            self.memory = torch.zeros(self.window_length, *x.shape, device=x.device, dtype=x.dtype)
-
-        # store without grads
-        with torch.no_grad():
-            self.memory[self.current_idx] = x
-            self.last_idx = self.current_idx
-
-        # increase index
-        self.current_idx += 1
-
-        # reset index when hit limit of tensor
-        self.current_idx = self.current_idx % self.window_length
-        if self.current_idx == 0:
-            self.rotated = True
-
-    def mean(self) -> Optional[Tensor]:
-        """Get mean value from stored elements."""
-        return self._agg_memory("mean")
-
-    def max(self) -> Optional[Tensor]:
-        """Get maximal value from stored elements."""
-        return self._agg_memory("max")
-
-    def min(self) -> Optional[Tensor]:
-        """Get minimal value from stored elements."""
-        return self._agg_memory("min")
-
-    def _agg_memory(self, how: str) -> Optional[Tensor]:
-        if self.last_idx is not None:
-            assert isinstance(self.memory, Tensor)
-            if self.rotated:
-                return getattr(self.memory.float(), how)()
-            return getattr(self.memory[: self.current_idx].float(), how)()
-
-
 @dataclass
 class SharedCycleIteratorState:
     """A state shared between all CycleIterators in a CombinedLoader.
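`TensorRunningAccum` is removed without a public replacement. Downstream code that imported it only to keep a windowed mean of detached scalars can carry a small local substitute; the `RunningMean` class below is a hypothetical stand-in sketched for that purpose, not a Lightning API:

.. code-block:: python

    from collections import deque
    from typing import Optional

    import torch


    class RunningMean:
        """Keep the last ``window_length`` detached scalar tensors and report their mean."""

        def __init__(self, window_length: int = 20) -> None:
            self.window: deque = deque(maxlen=window_length)

        def append(self, x: torch.Tensor) -> None:
            # store without graph references, as the removed class did
            self.window.append(x.detach().cpu())

        def mean(self) -> Optional[torch.Tensor]:
            # returns None while empty, matching the old behaviour
            return torch.stack(list(self.window)).float().mean() if self.window else None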
diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 563e65653f27f..e83ef6eb89511 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -379,7 +379,7 @@ def on_train_batch_end(
         if self.progress_bar:
             self.progress_bar.update()
 
-        loss_tensor = trainer.fit_loop.running_loss.last()
+        loss_tensor = outputs if isinstance(outputs, torch.Tensor) else outputs["loss"]
         assert loss_tensor is not None
         current_loss = loss_tensor.item()
         current_step = trainer.global_step
diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
index b1a7082ef6448..a400a22628100 100644
--- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
@@ -353,7 +353,6 @@ class MockedProgressBar(RichProgressBar):
         def get_metrics(self, trainer, pl_module):
             items = super().get_metrics(trainer, model)
             del items["v_num"]
-            del items["loss"]
             # this is equivalent to mocking `set_postfix` as this method gets called every time
             self.calls[trainer.state.fn].append(
                 (trainer.state.stage, trainer.current_epoch, trainer.global_step, items)
diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
index c0f93d092459a..865c65cefd4b3 100644
--- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
@@ -660,7 +660,6 @@ def get_metrics(self, trainer: Trainer, model: LightningModule):
     model = BoringModel()
     trainer.fit(model)
     standard_metrics = progress_bar.get_metrics(trainer, model)
-    assert "loss" in standard_metrics.keys()
     assert "v_num" not in standard_metrics.keys()
 
 
@@ -673,7 +672,6 @@ class MockedProgressBar(TQDMProgressBar):
         def get_metrics(self, trainer, pl_module):
             items = super().get_metrics(trainer, model)
             del items["v_num"]
-            del items["loss"]
             # this is equivalent to mocking `set_postfix` as this method gets called every time
             self.calls[trainer.state.fn].append(
                 (trainer.state.stage, trainer.current_epoch, trainer.global_step, items)
diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
index e2308ea18389c..93bac56264e14 100644
--- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
+++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
@@ -352,21 +352,6 @@ def test_epoch_end(self, outputs):
     trainer.test(model)
 
 
-def test_logging_to_progress_bar_with_reserved_key(tmpdir):
-    """Test that logging a metric with a reserved name to the progress bar raises a warning."""
-
-    class TestModel(BoringModel):
-        def training_step(self, *args, **kwargs):
-            output = super().training_step(*args, **kwargs)
-            self.log("loss", output["loss"], prog_bar=True)
-            return output
-
-    model = TestModel()
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
-        trainer.fit(model)
-
-
 @pytest.mark.parametrize("add_dataloader_idx", [False, True])
 def test_auto_add_dataloader_idx(tmpdir, add_dataloader_idx):
     """test that auto_add_dataloader_idx argument works."""
diff --git a/tests/tests_pytorch/trainer/logging_/test_progress_bar_logging.py b/tests/tests_pytorch/trainer/logging_/test_progress_bar_logging.py
deleted file mode 100644
index 1524c117588c5..0000000000000
--- a/tests/tests_pytorch/trainer/logging_/test_progress_bar_logging.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import pytest
-
-from pytorch_lightning import Trainer
-from pytorch_lightning.demos.boring_classes import BoringModel
-
-
-def test_logging_to_progress_bar_with_reserved_key(tmpdir):
-    """Test that logging a metric with a reserved name to the progress bar raises a warning."""
-
-    class TestModel(BoringModel):
-        def training_step(self, *args, **kwargs):
-            output = super().training_step(*args, **kwargs)
-            self.log("loss", output["loss"], prog_bar=True)
-            return output
-
-    model = TestModel()
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=2)
-    with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
-        trainer.fit(model)
diff --git a/tests/tests_pytorch/trainer/test_supporters.py b/tests/tests_pytorch/trainer/test_supporters.py
index 8533ad5fdb467..b5f1d9bdf776c 100644
--- a/tests/tests_pytorch/trainer/test_supporters.py
+++ b/tests/tests_pytorch/trainer/test_supporters.py
@@ -32,7 +32,6 @@
     CombinedLoader,
     CombinedLoaderIterator,
     CycleIterator,
-    TensorRunningAccum,
 )
 from pytorch_lightning.utilities.auto_restart import CaptureMapDataset, FastForwardSampler
 from pytorch_lightning.utilities.data import get_len
@@ -40,27 +39,6 @@
 from tests_pytorch.helpers.runif import RunIf
 
 
-def test_tensor_running_accum_reset():
-    """Test that reset would set all attributes to the initialization state."""
-
-    window_length = 10
-
-    accum = TensorRunningAccum(window_length=window_length)
-    assert accum.last() is None
-    assert accum.mean() is None
-
-    accum.append(torch.tensor(1.5))
-    assert accum.last() == torch.tensor(1.5)
-    assert accum.mean() == torch.tensor(1.5)
-
-    accum.reset()
-    assert accum.window_length == window_length
-    assert accum.memory is None
-    assert accum.current_idx == 0
-    assert accum.last_idx is None
-    assert not accum.rotated
-
-
 def test_cycle_iterator():
     """Test the cycling function of `CycleIterator`"""
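The `lr_finder.py` hunk above shows the pattern that replaces `trainer.fit_loop.running_loss`: read the per-batch loss from the `outputs` argument of `on_train_batch_end`. A hedged sketch of a user callback doing the same (the `TrackLoss` name is illustrative, not part of this patch):

.. code-block:: python

    from typing import Any

    import torch
    from pytorch_lightning import Callback, LightningModule, Trainer


    class TrackLoss(Callback):
        """Collect the scalar training loss reported by each batch."""

        def __init__(self) -> None:
            self.losses = []

        def on_train_batch_end(
            self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int
        ) -> None:
            # `outputs` is whatever `training_step` returned: a tensor or a dict with a "loss" key
            loss = outputs if isinstance(outputs, torch.Tensor) else outputs["loss"]
            self.losses.append(loss.item())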