diff --git a/kraken/lib/train.py b/kraken/lib/train.py
index 29f0ef18f..80395ec47 100644
--- a/kraken/lib/train.py
+++ b/kraken/lib/train.py
@@ -665,7 +665,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
             for pg in optimizer.param_groups:
                 pg["lr"] = lr_scale * self.hparams.hyper_params['lrate']
 
-    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
+    def lr_scheduler_step(self, scheduler, metric):
         if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']:
             # step OneCycleLR each batch if not in warmup phase
             if isinstance(scheduler, lr_scheduler.OneCycleLR):
@@ -1080,7 +1080,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
             for pg in optimizer.param_groups:
                 pg["lr"] = lr_scale * self.hparams.hyper_params['lrate']
 
-    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
+    def lr_scheduler_step(self, scheduler, metric):
         if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']:
             # step OneCycleLR each batch if not in warmup phase
             if isinstance(scheduler, lr_scheduler.OneCycleLR):
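
Context for the change above: in PyTorch Lightning 2.x the `lr_scheduler_step` hook lost its `optimizer_idx` argument, so overrides must use the two-argument form `(scheduler, metric)` or Lightning raises a signature mismatch. The following is a minimal sketch of the updated hook in an otherwise hypothetical LightningModule (the `DummyModel` class and its hyperparameters are illustrative only and not part of kraken), assuming pytorch-lightning >= 2.0:

    import torch
    from torch import nn
    from torch.optim import lr_scheduler
    import pytorch_lightning as pl


    class DummyModel(pl.LightningModule):
        """Hypothetical module showing the Lightning 2.x lr_scheduler_step signature."""

        def __init__(self, warmup: int = 100, lrate: float = 1e-3):
            super().__init__()
            self.save_hyperparameters()
            self.net = nn.Linear(10, 2)

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lrate)
            scheduler = lr_scheduler.OneCycleLR(optimizer,
                                                max_lr=self.hparams.lrate,
                                                total_steps=1000)
            # step the scheduler per batch, mirroring the OneCycleLR handling in the patch
            return {'optimizer': optimizer,
                    'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}

        # Lightning 2.x hook: note there is no `optimizer_idx` parameter anymore.
        def lr_scheduler_step(self, scheduler, metric):
            # only advance the scheduler once the (assumed) warmup phase is over
            if not self.hparams.warmup or self.trainer.global_step >= self.hparams.warmup:
                if metric is None:
                    scheduler.step()
                else:
                    scheduler.step(metric)

The hedge on the guard condition mirrors the warmup logic visible in the diff; the rest of the module exists only so the snippet is self-contained.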