diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 2e38b7ce..bae0e686 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -955,7 +955,7 @@ def configure_optimizers( self, ) -> Tuple[torch.optim.Optimizer, Dict[str, Any]]: """ - Initialize the optimizer. + Initialize the optimizer and lr_scheduler. This is used by pytorch-lightning when preparing the model for training. @@ -965,7 +965,7 @@ def configure_optimizers( The initialized Adam optimizer and its learning rate scheduler. """ optimizer = torch.optim.Adam(self.parameters(), **self.opt_kwargs) - # Apply learning rate scheduler per step. + # Add linear learning rate scheduler for warmup lr_schedulers = [ torch.optim.lr_scheduler.LinearLR( optimizer, @@ -976,7 +976,6 @@ def configure_optimizers( lr_schedulers.append( CosineScheduler( optimizer, - warmup=self.warmup_iters, max_iters=self.max_iters ) ) @@ -989,14 +988,15 @@ def configure_optimizers( total_iters=self.max_iters ) ) - + #Combine learning rate schedulers lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler(lr_schedulers) + # Apply learning rate scheduler per step. return [optimizer], {"scheduler": lr_scheduler, "interval": "step"} class CosineScheduler(torch.optim.lr_scheduler._LRScheduler): """ - Learning rate scheduler with linear warm up followed by cosine shaped decay. + Learning rate scheduler with cosine shaped decay. Parameters ---------- diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index 013cc8e8..6950e749 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -222,7 +222,7 @@ def initialize_model(self, train: bool) -> None: top_match=self.config.top_match, n_log=self.config.n_log, tb_summarywriter=self.config.tb_summarywriter, - lr_schedule=self.config.tb_summarywriter, + lr_schedule=self.config.lr_schedule, warmup_iters=self.config.warmup_iters, max_iters=self.config.max_iters, lr=self.config.learning_rate,