diff --git a/utils/trainer.py b/utils/trainer.py index 0710926..8c9731f 100644 --- a/utils/trainer.py +++ b/utils/trainer.py @@ -243,13 +243,7 @@ def scheduler_step(self, metrics): "ReduceLROnPlateau", "IncreaseBSOnPlateau", ): - scheduler_metric = metrics[self.scheduler_metric] - - if type(self.scheduler).__name__ == "IncreaseBSOnPlateau": - self.scheduler.step(metric=scheduler_metric) - # TODO: keep the same name - else: - self.scheduler.step(scheduler_metric) + self.scheduler.step(metrics=metrics[self.scheduler_metric]) else: self.scheduler.step()