diff --git a/kraken/lib/pretrain/model.py b/kraken/lib/pretrain/model.py
index eee832b6d..f89efc918 100644
--- a/kraken/lib/pretrain/model.py
+++ b/kraken/lib/pretrain/model.py
@@ -438,7 +438,10 @@ def lr_scheduler_step(self, scheduler, metric):
                 scheduler.step()
             # step every other scheduler epoch-wise
             elif self.trainer.is_last_batch:
-                scheduler.step()
+                if metric is None:
+                    scheduler.step()
+                else:
+                    scheduler.step(metric)
 
     def setup(self, stage: Optional[str] = None):
         # finalize models in case of appending/loading
diff --git a/kraken/lib/train.py b/kraken/lib/train.py
index dca4b6def..92081bec6 100644
--- a/kraken/lib/train.py
+++ b/kraken/lib/train.py
@@ -1114,7 +1114,10 @@ def lr_scheduler_step(self, scheduler, metric):
                 scheduler.step()
             # step every other scheduler epoch-wise
             elif self.trainer.is_last_batch:
-                scheduler.step()
+                if metric is None:
+                    scheduler.step()
+                else:
+                    scheduler.step(metric)
 
 
 def _configure_optimizer_and_lr_scheduler(hparams, params, len_train_set=None, loss_tracking_mode='max'):
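
Context for the change, as a minimal sketch outside of kraken: metric-dependent schedulers such as `torch.optim.lr_scheduler.ReduceLROnPlateau` require the monitored value to be passed to `step()`, while most other schedulers take no argument. When a Lightning module defines a `monitor` key in its scheduler config, Lightning hands that value to `lr_scheduler_step` as `metric`, so the override has to forward it. The module, layer sizes, and the `val_metric` name below are illustrative only and not part of kraken.

```python
import torch
from torch import nn
from torch.optim import lr_scheduler
import pytorch_lightning as pl


class SketchModule(pl.LightningModule):
    """Toy module showing the metric-aware lr_scheduler_step pattern."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        # logged value referenced by the `monitor` key below
        self.log('val_metric', loss)
        return loss

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=1e-3)
        sched = lr_scheduler.ReduceLROnPlateau(optim, mode='min')
        return {'optimizer': optim,
                'lr_scheduler': {'scheduler': sched, 'monitor': 'val_metric'}}

    def lr_scheduler_step(self, scheduler, metric):
        # Plateau-style schedulers need the monitored metric; schedulers
        # without a `monitor` key receive metric=None and step without it.
        if metric is None:
            scheduler.step()
        else:
            scheduler.step(metric)
```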