From 6fa8f79857a6c63766acc414ed232ddeb4f5abb8 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 16 Aug 2024 15:18:30 +0200 Subject: [PATCH] Regression in LR schedulers with metric tracking Fixes #635 --- kraken/lib/pretrain/model.py | 5 ++++- kraken/lib/train.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/kraken/lib/pretrain/model.py b/kraken/lib/pretrain/model.py index eee832b6..f89efc91 100644 --- a/kraken/lib/pretrain/model.py +++ b/kraken/lib/pretrain/model.py @@ -438,7 +438,10 @@ def lr_scheduler_step(self, scheduler, metric): scheduler.step() # step every other scheduler epoch-wise elif self.trainer.is_last_batch: - scheduler.step() + if metric is None: + scheduler.step() + else: + scheduler.step(metric) def setup(self, stage: Optional[str] = None): # finalize models in case of appending/loading diff --git a/kraken/lib/train.py b/kraken/lib/train.py index dca4b6de..92081bec 100644 --- a/kraken/lib/train.py +++ b/kraken/lib/train.py @@ -1114,7 +1114,10 @@ def lr_scheduler_step(self, scheduler, metric): scheduler.step() # step every other scheduler epoch-wise elif self.trainer.is_last_batch: - scheduler.step() + if metric is None: + scheduler.step() + else: + scheduler.step(metric) def _configure_optimizer_and_lr_scheduler(hparams, params, len_train_set=None, loss_tracking_mode='max'):