From baff9624dd6919e811f532ad965728277a42924b Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Sun, 3 Dec 2023 03:08:56 +0100
Subject: [PATCH] Merge regression with LR schedulers

---
 kraken/lib/train.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/kraken/lib/train.py b/kraken/lib/train.py
index 29f0ef18f..80395ec47 100644
--- a/kraken/lib/train.py
+++ b/kraken/lib/train.py
@@ -665,7 +665,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
             for pg in optimizer.param_groups:
                 pg["lr"] = lr_scale * self.hparams.hyper_params['lrate']
 
-    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
+    def lr_scheduler_step(self, scheduler, metric):
         if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']:
             # step OneCycleLR each batch if not in warmup phase
             if isinstance(scheduler, lr_scheduler.OneCycleLR):
@@ -1080,7 +1080,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
             for pg in optimizer.param_groups:
                 pg["lr"] = lr_scale * self.hparams.hyper_params['lrate']
 
-    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
+    def lr_scheduler_step(self, scheduler, metric):
         if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']:
             # step OneCycleLR each batch if not in warmup phase
             if isinstance(scheduler, lr_scheduler.OneCycleLR):
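
A minimal sketch of how the hook reads after this change, assuming a recent PyTorch Lightning release in which the `optimizer_idx` argument was removed from the `lr_scheduler_step` hook; the body simply mirrors the context lines shown in the hunks above and is not part of the patch itself:

    from torch.optim import lr_scheduler

    def lr_scheduler_step(self, scheduler, metric):
        # Only step the scheduler once the linear warmup phase (handled in
        # optimizer_step) has finished, or when no warmup is configured.
        if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']:
            # step OneCycleLR each batch if not in warmup phase
            if isinstance(scheduler, lr_scheduler.OneCycleLR):
                scheduler.step()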