Skip to content

Commit

Permalink
Integrate total_batch_idx with progress tracking (Lightning-AI#8598)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored and four4fish committed Aug 16, 2021
1 parent 76a47fb commit 7f7c521
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 11 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


- Progress tracking
* Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))


- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))


Expand Down
17 changes: 9 additions & 8 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def __init__(self, min_steps: int, max_steps: int):
self.min_steps: int = min_steps
self.max_steps: int = max_steps
self.global_step: int = 0
# the total batch index across all epochs
self.total_batch_idx: int = 0
# manually tracking which is the last batch is necessary for iterable dataset support
self.is_last_batch: Optional[bool] = None
self.batch_progress = Progress()
self.scheduler_progress = SchedulerProgress()
Expand All @@ -53,6 +52,13 @@ def __init__(self, min_steps: int, max_steps: int):
self._warning_cache: WarningCache = WarningCache()
self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None

@property
def total_batch_idx(self) -> int:
"""Returns the current batch index (across epochs)"""
# use `ready` instead of `completed` in case this is accessed after `completed` has been increased
# but before the next `ready` increase
return self.batch_progress.total.ready - 1

@property
def batch_idx(self) -> int:
"""Returns the current batch index (within this epoch)"""
Expand Down Expand Up @@ -176,14 +182,9 @@ def on_advance_end(self):
# update plateau LR scheduler after metrics are logged
self.update_lr_schedulers("step", update_plateau_schedulers=True)

self.total_batch_idx += 1

# progress global step according to grads progress
self._increment_accumulated_grad_global_step()

if self.done:
raise StopIteration

def on_run_end(self) -> List[List[STEP_OUTPUT]]:
"""Calls the on_epoch_end hook.
Expand Down Expand Up @@ -351,7 +352,7 @@ def _increment_accumulated_grad_global_step(self) -> None:
"""Increments global step according to grads progress"""
if not self._should_accumulate():
self.global_step = self.trainer.accelerator.update_global_step(
self.total_batch_idx, self.trainer.global_step
self.batch_progress.current.ready, self.trainer.global_step
)

def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ def global_step(self, value: int) -> None:

@property
def total_batch_idx(self) -> int:
"""Returns the total number of batches already run (across all epochs)"""
"""Returns the current batch index (across epochs)"""
return self.epoch_loop.total_batch_idx

@property
def batch_idx(self) -> int:
"""Returns the number of batches already run within this epoch"""
"""Returns the current batch index (within this epoch)"""
return self.epoch_loop.batch_idx

@property
Expand Down
2 changes: 1 addition & 1 deletion tests/tuner/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def __init__(self):

assert lrfinder.suggestion() != 1e-3
assert len(lrfinder.results["lr"]) == 100
assert lrfinder._total_batch_idx == 200
assert lrfinder._total_batch_idx == 199


def test_suggestion_parameters_work(tmpdir):
Expand Down

0 comments on commit 7f7c521

Please sign in to comment.