diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 340c1711b40f7..dbe163aa33cae 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -476,7 +476,7 @@ def _update_learning_rates(
             if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
                 continue
 
-            current_idx = self.trainer.fit_loop.batch_idx if interval == "step" else self.trainer.current_epoch
+            current_idx = self.batch_idx if interval == "step" else self.trainer.current_epoch
             current_idx += 1  # account for both batch and epoch starts from 0
             # Take step if call to update_learning_rates matches the interval key and
             # the current step modulo the schedulers frequency is zero
@@ -502,7 +502,7 @@ def _update_learning_rates(
                         )
                         continue
 
-                self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready()
+                self.scheduler_progress.increment_ready()
 
                 # update LR
                 if lr_scheduler["reduce_on_plateau"]:
@@ -510,7 +510,7 @@ def _update_learning_rates(
                 else:
                     lr_scheduler["scheduler"].step()
 
-                self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_completed()
+                self.scheduler_progress.increment_completed()
 
     def _get_monitor_value(self, key: str) -> Any:
         # this is a separate method to aid in testing
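
The hunks above all apply the same simplification: `_update_learning_rates` runs on the training epoch loop itself, so reaching that loop again through `self.trainer.fit_loop.epoch_loop` (or its `batch_idx` through `self.trainer.fit_loop`) is a redundant round-trip and collapses to `self`. The following minimal sketch, which is illustrative only and not Lightning code (all class and attribute names here are hypothetical), shows the object-identity assumption the diff relies on:

```python
# Hypothetical stand-ins for Trainer / FitLoop / TrainingEpochLoop, used only to
# illustrate why the long attribute chain and plain `self` refer to the same object.


class EpochLoop:
    def __init__(self) -> None:
        self.batch_idx = 0
        self.trainer = None  # back-reference, set by the owning trainer below

    def update(self) -> int:
        # Before the diff-style change: self.trainer.fit_loop.epoch_loop.batch_idx
        # After:                        self.batch_idx  (same object, no round-trip)
        return self.batch_idx


class FitLoop:
    def __init__(self, epoch_loop: EpochLoop) -> None:
        self.epoch_loop = epoch_loop


class Trainer:
    def __init__(self) -> None:
        self.fit_loop = FitLoop(EpochLoop())
        self.fit_loop.epoch_loop.trainer = self


trainer = Trainer()
loop = trainer.fit_loop.epoch_loop
# The round-trip resolves back to the loop itself, which is what makes the
# replacement of the long chain with `self` behavior-preserving:
assert loop.trainer.fit_loop.epoch_loop is loop
```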