From ff4c717f1aeb2237ea0b3b784556023e0d7dfe93 Mon Sep 17 00:00:00 2001
From: Yanyi Liu
Date: Wed, 28 Feb 2024 07:34:29 +0000
Subject: [PATCH 1/2] Add cosine_with_min_lr scheduler

---
 src/transformers/optimization.py  | 64 +++++++++++++++++++++++++++++++
 src/transformers/trainer_utils.py |  1 +
 tests/trainer/test_trainer.py     | 23 +++++++++++
 3 files changed, 88 insertions(+)

diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py
index b3861b371a2..6ad36dc7619 100644
--- a/src/transformers/optimization.py
+++ b/src/transformers/optimization.py
@@ -323,6 +323,69 @@ def get_inverse_sqrt_schedule(
     return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
 
 
+def _get_cosine_schedule_with_warmup_lr_lambda(
+    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_cosine_with_min_lr_schedule_with_warmup(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+    min_lr: float = None,
+    min_lr_rate: float = None,
+):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function from the
+    initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the
+    initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to min_lr
+            following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+        min_lr (`float`, *optional*):
+            The minimum learning rate to reach after the cosine schedule.
+        min_lr_rate (`float`, *optional*):
+            The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    if min_lr is not None and min_lr_rate is not None:
+        raise ValueError("Only one of min_lr or min_lr_rate should be set")
+    elif min_lr is not None:
+        min_lr_rate = min_lr / optimizer.defaults["lr"]
+    elif min_lr_rate is None:
+        raise ValueError("One of min_lr or min_lr_rate should be set")
+
+    lr_lambda = partial(
+        _get_cosine_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=num_cycles,
+        min_lr_rate=min_lr_rate,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
 TYPE_TO_SCHEDULER_FUNCTION = {
     SchedulerType.LINEAR: get_linear_schedule_with_warmup,
     SchedulerType.COSINE: get_cosine_schedule_with_warmup,
@@ -332,6 +395,7 @@ def get_inverse_sqrt_schedule(
     SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
     SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
     SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
+    SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup,
 }
 
 
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 803f6fe840e..7838c63dfc9 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -401,6 +401,7 @@ class SchedulerType(ExplicitEnum):
     CONSTANT_WITH_WARMUP = "constant_with_warmup"
     INVERSE_SQRT = "inverse_sqrt"
     REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
+    COSINE_WITH_MIN_LR = "cosine_with_min_lr"
 
 
 class TrainerMemoryTracker:
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 65eeb6d6238..33e3e070985 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -679,6 +679,29 @@ def test_lr_scheduler_kwargs(self):
         self.assertEqual(sched1.lr_lambdas[0].args, sched2.lr_lambdas[0].args)
         self.assertEqual(sched1.lr_lambdas[0].keywords, sched2.lr_lambdas[0].keywords)
 
+    def test_cosine_with_min_lr_scheduler(self):
+        train_dataset = RegressionDataset()
+        model = RegressionModel()
+        num_steps, num_warmup_steps = 10, 2
+        extra_kwargs = {"min_lr": 1e-5}  # Non-default arguments
+        args = TrainingArguments(
+            "./regression",
+            lr_scheduler_type="cosine_with_min_lr",
+            lr_scheduler_kwargs=extra_kwargs,
+            learning_rate=0.2,
+            warmup_steps=num_warmup_steps,
+        )
+        trainer = Trainer(model, args, train_dataset=train_dataset)
+        trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
+
+        # Checking that the scheduler was created
+        self.assertIsNotNone(trainer.lr_scheduler)
+
+        # Check the last learning rate
+        for _ in range(num_steps):
+            trainer.lr_scheduler.step()
+        self.assertEqual(trainer.lr_scheduler.get_last_lr()[0], 1e-5)
+
     def test_reduce_lr_on_plateau_args(self):
         # test passed arguments for a custom ReduceLROnPlateau scheduler
         train_dataset = RegressionDataset(length=64)

From 24c21d49abbe1dfd6ffbfa72902906278800f1f1 Mon Sep 17 00:00:00 2001
From: Yanyi Liu
Date: Tue, 5 Mar 2024 01:18:07 +0000
Subject: [PATCH 2/2] Update error message for missing min_lr or min_lr_rate

---
 src/transformers/optimization.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py
index 6ad36dc7619..0ae9d735888 100644
--- a/src/transformers/optimization.py
+++ b/src/transformers/optimization.py
@@ -374,7 +374,7 @@ def get_cosine_with_min_lr_schedule_with_warmup(
     elif min_lr is not None:
         min_lr_rate = min_lr / optimizer.defaults["lr"]
     elif min_lr_rate is None:
-        raise ValueError("One of min_lr or min_lr_rate should be set")
+        raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")
 
     lr_lambda = partial(
         _get_cosine_schedule_with_warmup_lr_lambda,
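
Not part of the diff: a minimal reviewer sketch of the schedule this series adds, calling the new function directly with the patches applied. The toy Linear model, the AdamW optimizer, and the step counts below are illustrative assumptions; only the import path (src/transformers/optimization.py) and the function signature come from the patch.

# Sketch: linear warmup, half-cosine decay, and a floor at min_lr.
import torch
from transformers.optimization import get_cosine_with_min_lr_schedule_with_warmup

model = torch.nn.Linear(4, 4)                      # throwaway model for illustration
optimizer = torch.optim.AdamW(model.parameters(), lr=0.2)
scheduler = get_cosine_with_min_lr_schedule_with_warmup(
    optimizer,
    num_warmup_steps=2,
    num_training_steps=10,
    min_lr=1e-5,  # alternatively min_lr_rate=5e-5; passing both raises ValueError
)

lrs = []
for _ in range(10):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

# The LR warms up for 2 steps, then follows a cosine down to min_lr (1e-5 here),
# matching the assertion in test_cosine_with_min_lr_scheduler.
print(lrs)

End users would reach the same schedule through Trainer by setting lr_scheduler_type="cosine_with_min_lr" and lr_scheduler_kwargs={"min_lr": ...} (or min_lr_rate) in TrainingArguments, which is exactly the path exercised by the new test above.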