Remove {running,accumulated} loss
carmocca committed Sep 8, 2021
1 parent 91ce0d0 commit a979944
Showing 5 changed files with 0 additions and 126 deletions.
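With the built-in running-loss display removed, a `loss` entry can still be surfaced in the progress bar by logging it from `training_step`. Below is a minimal sketch, not part of this commit: the `LitModel` class, its layer, and the loss function are illustrative, only `self.log(..., prog_bar=True)` is the existing Lightning API.

```python
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule


class LitModel(LightningModule):  # illustrative module, not from this commit
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self.layer(x), y)
        # `prog_bar=True` keeps a "loss" entry in the progress bar now that the
        # automatic running-loss display is gone
        self.log("loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```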
11 changes: 0 additions & 11 deletions pytorch_lightning/core/lightning.py
@@ -1723,18 +1723,7 @@ def get_progress_bar_dict(self):
Return:
Dictionary with the items to be displayed in the progress bar.
"""
# call .item() only once but store elements without graphs
running_train_loss = self.trainer.fit_loop.running_loss.mean()
avg_training_loss = None
if running_train_loss is not None:
avg_training_loss = running_train_loss.cpu().item()
elif self.automatic_optimization:
avg_training_loss = float("NaN")

tqdm_dict = {}
if avg_training_loss is not None:
tqdm_dict["loss"] = f"{avg_training_loss:.3g}"

if self.truncated_bptt_steps > 0:
tqdm_dict["split_idx"] = self.trainer.fit_loop.split_idx

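The `get_progress_bar_dict` hook itself survives this change, so custom entries can also be added by overriding it. A hedged, standalone sketch; the `_last_loss` attribute and the module internals are hypothetical, only the hook and `super().get_progress_bar_dict()` come from the code shown above.

```python
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule


class LitModelWithBar(LightningModule):  # illustrative; `_last_loss` is a hypothetical attribute
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)
        self._last_loss = float("nan")

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self.layer(x), y)
        self._last_loss = loss.detach().cpu().item()  # track the value yourself
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def get_progress_bar_dict(self):
        # the hook is still called by the trainer; put the tracked value back
        items = super().get_progress_bar_dict()
        items["loss"] = f"{self._last_loss:.3g}"
        return items
```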
18 changes: 0 additions & 18 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -22,7 +22,6 @@
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.batch.manual import ManualOptimization
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache
@@ -33,9 +32,7 @@ class TrainingBatchLoop(Loop):

def __init__(self) -> None:
super().__init__()
self.accumulated_loss: Optional[Tensor] = None
self.batch_outputs: Optional[List[List[STEP_OUTPUT]]] = None
self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20)
# the current split index when the batch gets split into chunks in truncated backprop through time
self.split_idx: Optional[int] = None
self.optimizer_loop = OptimizerLoop()
@@ -160,21 +157,6 @@ def _tbptt_split_batch(self, batch: Any) -> List[Any]:
splits = model_ref.tbptt_split_batch(batch, tbptt_steps)
return splits

def _update_running_loss(self, current_loss: Tensor) -> None:
"""Updates the running loss value with the current value."""
if self.trainer.lightning_module.automatic_optimization:
# track total loss for logging (avoid mem leaks)
self.accumulated_loss.append(current_loss)

accumulated_loss = self.accumulated_loss.mean()

if accumulated_loss is not None:
# calculate running loss for display
self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)

# reset for next set of accumulated grads
self.accumulated_loss.reset()

def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]:
"""Returns the currently active optimizers. When multiple optimizers are used with different frequencies,
only one of the optimizers is active at a time.
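For reference, the removed `_update_running_loss` averaged the losses collected over one gradient-accumulation cycle and scaled that mean by `accumulate_grad_batches` before appending it to the display window. A toy sketch of that arithmetic follows; the tensor values are made up, and the interpretation that the scaling compensates for the loss being split across accumulation steps is an assumption, not stated in this diff.

```python
import torch

# hypothetical per-step losses from one accumulation cycle of length 2
accumulated = [torch.tensor(0.9), torch.tensor(1.1)]
accumulate_grad_batches = 2

# value the removed helper would have appended to the running-loss window:
# mean over the cycle, rescaled by the accumulation factor
smoothed = torch.stack(accumulated).mean() * accumulate_grad_batches
print(f"{smoothed.item():.3g}")  # -> 2
```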
11 changes: 0 additions & 11 deletions pytorch_lightning/loops/fit_loop.py
@@ -19,7 +19,6 @@
from pytorch_lightning.loops.epoch import TrainingEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities.exceptions import MisconfigurationException

log = logging.getLogger(__name__)
@@ -108,11 +107,6 @@ def max_steps(self, value: int) -> None:
raise MisconfigurationException(f"`max_steps` must be a positive integer or -1. You passed in {value}.")
self.epoch_loop.max_steps = value

@property
def running_loss(self) -> TensorRunningAccum:
"""Returns the running loss."""
return self.epoch_loop.batch_loop.running_loss

@property
def _skip_backward(self) -> bool:
"""Determines whether the loop will skip backward during automatic optimization."""
@@ -213,11 +207,6 @@ def on_advance_start(self) -> None:
# changing gradient according accumulation_scheduler
self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

# stores accumulated grad fractions per batch
self.epoch_loop.batch_loop.accumulated_loss = TensorRunningAccum(
window_length=self.trainer.accumulate_grad_batches
)

self.epoch_progress.increment_ready()

def advance(self) -> None:
10 changes: 0 additions & 10 deletions pytorch_lightning/loops/optimizer/optimizer_loop.py
@@ -160,16 +160,6 @@ def _run_optimization(

result = closure.get_result()

if result:
# if no result, user decided to skip optimization
# otherwise update running loss + reset accumulated loss
# TODO: find proper way to handle updating running loss
assert self.trainer.fit_loop is not None
assert self.trainer.fit_loop.epoch_loop is not None
assert self.trainer.fit_loop.epoch_loop.batch_loop is not None
assert result.loss is not None
self.trainer.fit_loop.epoch_loop.batch_loop._update_running_loss(result.loss)

# untoggle model params
self._run_optimization_end(opt_idx)
return result
76 changes: 0 additions & 76 deletions pytorch_lightning/trainer/supporters.py
@@ -36,82 +36,6 @@
from pytorch_lightning.utilities.imports import _fault_tolerant_training


class TensorRunningAccum:
"""Tracks a running accumulation values (min, max, mean) without graph references.
Examples:
>>> accum = TensorRunningAccum(5)
>>> accum.last(), accum.mean()
(None, None)
>>> accum.append(torch.tensor(1.5))
>>> accum.last(), accum.mean()
(tensor(1.5000), tensor(1.5000))
>>> accum.append(torch.tensor(2.5))
>>> accum.last(), accum.mean()
(tensor(2.5000), tensor(2.))
>>> accum.reset()
>>> _= [accum.append(torch.tensor(i)) for i in range(13)]
>>> accum.last(), accum.mean(), accum.min(), accum.max()
(tensor(12.), tensor(10.), tensor(8.), tensor(12.))
"""

def __init__(self, window_length: int):
self.window_length = window_length
self.memory = None
self.current_idx: int = 0
self.last_idx: Optional[int] = None
self.rotated: bool = False

def reset(self) -> None:
"""Empty the accumulator."""
self.__init__(self.window_length)

def last(self):
"""Get the last added element."""
if self.last_idx is not None:
return self.memory[self.last_idx]

def append(self, x):
"""Add an element to the accumulator."""
if self.memory is None:
self.memory = torch.zeros(self.window_length, *x.shape)

# ensure same device and type
if self.memory.device != x.device or self.memory.type() != x.type():
x = x.to(self.memory)

# store without grads
with torch.no_grad():
self.memory[self.current_idx] = x
self.last_idx = self.current_idx

# increase index
self.current_idx += 1

# reset index when hit limit of tensor
self.current_idx = self.current_idx % self.window_length
if self.current_idx == 0:
self.rotated = True

def mean(self):
"""Get mean value from stored elements."""
return self._agg_memory("mean")

def max(self):
"""Get maximal value from stored elements."""
return self._agg_memory("max")

def min(self):
"""Get minimal value from stored elements."""
return self._agg_memory("min")

def _agg_memory(self, how: str):
if self.last_idx is not None:
if self.rotated:
return getattr(self.memory, how)()
return getattr(self.memory[: self.current_idx], how)()


@dataclass
class SharedCycleIteratorState:
"""A state shared between all CylceIterators in a CombinedLoader.
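Code that still needs a windowed, graph-free running mean after `TensorRunningAccum` is deleted can carry a small stand-in of its own. A hedged sketch using `collections.deque` in place of the deleted circular buffer; the class and names are illustrative, not a Lightning API.

```python
from collections import deque

import torch


class RunningMean:
    """Illustrative stand-in for the removed TensorRunningAccum (mean only)."""

    def __init__(self, window_length: int = 20):
        # old values rotate out automatically once the window is full
        self.window = deque(maxlen=window_length)

    def append(self, x: torch.Tensor) -> None:
        # store without graph references, mirroring the removed class
        self.window.append(x.detach().cpu())

    def mean(self):
        if not self.window:
            return None
        return torch.stack(list(self.window)).mean()


acc = RunningMean(window_length=5)
acc.append(torch.tensor(1.5))
acc.append(torch.tensor(2.5))
print(acc.mean())  # tensor(2.)
```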
