Skip to content

Commit

Permalink
Added asserts for scheduler arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
ancestor-mithril committed Apr 30, 2024
1 parent 6118d99 commit 6047a3b
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions utils/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,23 @@


def init_scheduler(args, optimizer: Optimizer, train_loader: DataLoader):
if args.scheduler == 'IncreaseBSOnPlateau':
# "{'mode':'min', 'factor':2.0, 'max_batch_size': 1000}"
if args.scheduler == 'IncreaseBSOnPlateau': # "{'mode':'min', 'factor':2.0, 'max_batch_size': 1000}"

assert 'factor' in args.scheduler_params
assert 'max_batch_size' in args.scheduler_params
scheduler = IncreaseBSOnPlateau(train_loader, **args.scheduler_params)
elif args.scheduler == 'ReduceLROnPlateau':
# "{'mode':'min', 'factor':0.5}"
elif args.scheduler == 'ReduceLROnPlateau': # "{'mode':'min', 'factor':0.5}"

assert 'factor' in args.scheduler_params
scheduler = ReduceLROnPlateau(optimizer, **args.scheduler_params)
elif args.scheduler == 'StepBS':
# "{'step_size':30, 'gamma': 2.0, 'max_batch_size': 1000}"
elif args.scheduler == 'StepBS': # "{'step_size':30, 'gamma': 2.0, 'max_batch_size': 1000}"

assert 'step_size' in args.scheduler_params
assert 'max_batch_size' in args.scheduler_params
scheduler = StepBS(train_loader, **args.scheduler_params)
elif args.scheduler == 'StepLR':
# "{'step_size':30, 'gamma': 2.0}"
elif args.scheduler == 'StepLR': # "{'step_size':30, 'gamma': 2.0}"

assert 'step_size' in args.scheduler_params
scheduler = StepLR(optimizer, **args.scheduler_params)
else:
raise NotImplementedError(f'Scheduler {args.scheduler} not implemented')
Expand Down

0 comments on commit 6047a3b

Please sign in to comment.