Commit 4b8ec87
Added schedulers
Exponential, Polynomial, CosineAnnealing, CosineAnnealingWarmRestarts
ancestor-mithril committed Apr 30, 2024
1 parent 6047a3b commit 4b8ec87
Showing 1 changed file with 50 additions and 2 deletions.
utils/scheduler.py
@@ -1,6 +1,8 @@
-from bs_scheduler import IncreaseBSOnPlateau, StepBS
+from bs_scheduler import IncreaseBSOnPlateau, StepBS, ExponentialBS, PolynomialBS, CosineAnnealingBS, \
+    CosineAnnealingBSWithWarmRestarts
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
+from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, ExponentialLR, PolynomialLR, CosineAnnealingLR, \
+    CosineAnnealingWarmRestarts
 from torch.utils.data import DataLoader


@@ -23,6 +25,52 @@ def init_scheduler(args, optimizer: Optimizer, train_loader: DataLoader):
 
         assert 'step_size' in args.scheduler_params
         scheduler = StepLR(optimizer, **args.scheduler_params)
+
+    elif args.scheduler == 'ExponentialBS':  # "{'gamma':1.1, 'max_batch_size': 1000}"
+
+        assert 'gamma' in args.scheduler_params
+        assert 'max_batch_size' in args.scheduler_params
+        scheduler = ExponentialBS(train_loader, **args.scheduler_params)
+
+    elif args.scheduler == 'ExponentialLR':  # "{'gamma':0.9}"
+
+        assert 'gamma' in args.scheduler_params
+        scheduler = ExponentialLR(optimizer, **args.scheduler_params)
+
+    elif args.scheduler == 'PolynomialBS':  # "{'total_iters':200, 'max_batch_size': 1000}"
+
+        assert 'total_iters' in args.scheduler_params
+        assert 'max_batch_size' in args.scheduler_params
+        scheduler = PolynomialBS(train_loader, **args.scheduler_params)
+
+    elif args.scheduler == 'PolynomialLR':  # "{'total_iters':200}"
+
+        assert 'total_iters' in args.scheduler_params
+        scheduler = PolynomialLR(optimizer, **args.scheduler_params)
+
+    elif args.scheduler == 'CosineAnnealingBS':  # "{'total_iters':200, 'max_batch_size': 1000}"
+
+        assert 'total_iters' in args.scheduler_params
+        assert 'max_batch_size' in args.scheduler_params
+        scheduler = CosineAnnealingBS(train_loader, **args.scheduler_params)
+
+    elif args.scheduler == 'CosineAnnealingLR':  # "{'T_max':200}"
+
+        assert 'T_max' in args.scheduler_params
+        scheduler = CosineAnnealingLR(optimizer, **args.scheduler_params)
+
+    elif args.scheduler == 'CosineAnnealingBSWithWarmRestarts':  # "{'t_0':100, 'factor':1, 'max_batch_size': 1000}"
+
+        assert 't_0' in args.scheduler_params
+        assert 'factor' in args.scheduler_params
+        assert 'max_batch_size' in args.scheduler_params
+        scheduler = CosineAnnealingBSWithWarmRestarts(train_loader, **args.scheduler_params)
+
+    elif args.scheduler == 'CosineAnnealingWarmRestarts':  # "{'T_0':100, 'T_mult': 1}"
+
+        assert 'T_0' in args.scheduler_params
+        assert 'T_mult' in args.scheduler_params
+        scheduler = CosineAnnealingWarmRestarts(optimizer, **args.scheduler_params)
     else:
         raise NotImplementedError(f'Scheduler {args.scheduler} not implemented')
     return scheduler
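
For context, a minimal sketch of how this dispatcher might be driven. Everything here is illustrative rather than part of the commit: the toy model, optimizer, and loader are placeholders, and args.scheduler_params is assumed to have already been parsed from its string form (shown in the inline comments above) into a dict.

# Hypothetical usage sketch, not part of this commit.
from argparse import Namespace

import torch
from torch.utils.data import DataLoader, TensorDataset

from utils.scheduler import init_scheduler

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
train_loader = DataLoader(TensorDataset(torch.randn(64, 10)), batch_size=8)

# Mirrors the example in the ExponentialBS comment above, parsed into a dict.
args = Namespace(scheduler='ExponentialBS',
                 scheduler_params={'gamma': 1.1, 'max_batch_size': 1000})
scheduler = init_scheduler(args, optimizer, train_loader)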

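Worth noting in this change: each batch-size scheduler from bs_scheduler wraps the DataLoader and grows the batch size, while each learning-rate counterpart acts on the Optimizer. Below is a hedged sketch of how either kind would then be stepped once per epoch, assuming bs_scheduler follows the same step() convention as torch.optim.lr_scheduler; run_epoch is a hypothetical helper, and the plateau-based variants are assumed to consume a monitored metric as ReduceLROnPlateau does.

# Illustrative epoch loop (assumed convention, not part of this commit).
from bs_scheduler import IncreaseBSOnPlateau
from torch.optim.lr_scheduler import ReduceLROnPlateau

def run_epoch() -> float:
    ...  # placeholder: train for one epoch, return a validation loss

for epoch in range(100):
    val_loss = run_epoch()
    if isinstance(scheduler, (IncreaseBSOnPlateau, ReduceLROnPlateau)):
        scheduler.step(val_loss)  # plateau schedulers react to the metric
    else:
        scheduler.step()          # the schedulers added here advance each call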