Add option to change the learning rate scheduler and make it easier to add a new one.
Justin Sanders committed Sep 28, 2023
1 parent 86630e3 commit e402555
Showing 4 changed files with 34 additions and 8 deletions.
1 change: 1 addition & 0 deletions casanovo/config.py
@@ -53,6 +53,7 @@ class Config:
residues=dict,
n_log=int,
tb_summarywriter=str,
lr_schedule=str,
warmup_iters=int,
max_iters=int,
learning_rate=float,
2 changes: 2 additions & 0 deletions casanovo/config.yaml
@@ -84,6 +84,8 @@ dim_intensity:
custom_encoder:
# Max decoded peptide length
max_length: 100
# Type of learning rate schedule to use. One of {constant, linear, cosine}.
lr_schedule: "cosine"
# Number of warmup iterations for learning rate scheduler
warmup_iters: 100_000
# Max number of iterations for learning rate scheduler
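For context, a minimal sketch of reading the new option back from the file with PyYAML; this is purely illustrative (Casanovo itself loads configuration through its Config class, and the file path here assumes the repository root as the working directory):

```python
# Illustrative only: peek at the new option with PyYAML. Casanovo loads its
# configuration through the Config class, not via this snippet.
import yaml

with open("casanovo/config.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["lr_schedule"] in {"constant", "linear", "cosine"}
print(cfg["lr_schedule"])  # "cosine" by default
```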
38 changes: 30 additions & 8 deletions casanovo/denovo/model.py
@@ -110,6 +110,7 @@ def __init__(
tb_summarywriter: Optional[
torch.utils.tensorboard.SummaryWriter
] = None,
lr_schedule: Optional[str] = None,
warmup_iters: int = 100_000,
max_iters: int = 600_000,
out_writer: Optional[ms_io.MztabWriter] = None,
@@ -143,6 +144,7 @@ def __init__(
self.softmax = torch.nn.Softmax(2)
self.celoss = torch.nn.CrossEntropyLoss(ignore_index=0)
# Optimizer settings.
self.lr_schedule = lr_schedule
self.warmup_iters = warmup_iters
self.max_iters = max_iters
self.opt_kwargs = kwargs
@@ -964,13 +966,35 @@ def configure_optimizers(
"""
optimizer = torch.optim.Adam(self.parameters(), **self.opt_kwargs)
# Apply learning rate scheduler per step.
lr_scheduler = CosineWarmupScheduler(
optimizer, warmup=self.warmup_iters, max_iters=self.max_iters
)
lr_schedulers = [
torch.optim.lr_scheduler.LinearLR(
optimizer,
start_factor=1e-10,
total_iters=self.warmup_iters)
]
if self.lr_schedule == 'cosine':
lr_schedulers.append(
CosineScheduler(
optimizer,
max_iters=self.max_iters
)
)
elif self.lr_schedule == 'linear':
lr_schedulers.append(
torch.optim.lr_scheduler.LinearLR(
optimizer,
start_factor=1,
end_factor=0,
total_iters=self.max_iters
)
)

lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler(lr_schedulers)
return [optimizer], {"scheduler": lr_scheduler, "interval": "step"}
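For readers unfamiliar with ChainedScheduler, here is a minimal, self-contained sketch (not Casanovo code) of how the warmup LinearLR combines with a decay scheduler, mirroring the lr_schedule == 'linear' branch above; the toy model, learning rate, and iteration counts are made-up values for illustration:

```python
# Standalone illustration of chaining a warmup LinearLR with a linear decay,
# as in the 'linear' branch above. All values here are toy values.
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

warmup_iters, max_iters = 5, 20
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1e-10, total_iters=warmup_iters
)
decay = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=0.0, total_iters=max_iters
)
scheduler = torch.optim.lr_scheduler.ChainedScheduler([warmup, decay])

for step in range(max_iters):
    optimizer.step()   # normally preceded by a forward/backward pass
    scheduler.step()   # advances both chained schedulers each training step
    print(step, optimizer.param_groups[0]["lr"])
```

Under this scheme, 'constant' (or an unset lr_schedule) leaves only the warmup scheduler in the chain, so the rate stays flat after warmup; 'linear' decays to zero over max_iters; and 'cosine' swaps the second scheduler for the CosineScheduler defined below.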


class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
class CosineScheduler(torch.optim.lr_scheduler._LRScheduler):
"""
Learning rate scheduler with cosine shaped decay; linear warm up is applied separately via the chained LinearLR scheduler.
@@ -985,9 +1009,9 @@ class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
"""

def __init__(
self, optimizer: torch.optim.Optimizer, warmup: int, max_iters: int
self, optimizer: torch.optim.Optimizer, max_iters: int
):
self.warmup, self.max_iters = warmup, max_iters
self.max_iters = max_iters
super().__init__(optimizer)

def get_lr(self):
@@ -996,8 +1020,6 @@ def get_lr(self):

def get_lr_factor(self, epoch):
lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_iters))
if epoch <= self.warmup:
lr_factor *= epoch / self.warmup
return lr_factor
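As a quick sanity check on the factor above (a standalone snippet using the model's default max_iters of 600,000), the cosine term decays smoothly from 1 at step 0 to 0 at max_iters:

```python
import numpy as np

max_iters = 600_000  # default value of the max_iters argument above


def lr_factor(step: int) -> float:
    return 0.5 * (1 + np.cos(np.pi * step / max_iters))


print(lr_factor(0))               # 1.0   -> full base learning rate
print(lr_factor(max_iters // 2))  # ~0.5  -> halfway through the decay
print(lr_factor(max_iters))       # ~0.0  -> fully decayed
```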


1 change: 1 addition & 0 deletions casanovo/denovo/model_runner.py
@@ -222,6 +222,7 @@ def initialize_model(self, train: bool) -> None:
top_match=self.config.top_match,
n_log=self.config.n_log,
tb_summarywriter=self.config.tb_summarywriter,
lr_schedule=self.config.lr_schedule,
warmup_iters=self.config.warmup_iters,
max_iters=self.config.max_iters,
lr=self.config.learning_rate,
