From 3189ff10b43df975cab63790d13e16eb7f5a3ffc Mon Sep 17 00:00:00 2001 From: zspo Date: Thu, 30 May 2024 11:04:16 +0800 Subject: [PATCH] fix get_scheduler args --- src/transformers/optimization.py | 3 +++ tests/optimization/test_optimization.py | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py index 79a2c71c384f30..a462e3d8240099 100644 --- a/src/transformers/optimization.py +++ b/src/transformers/optimization.py @@ -540,6 +540,9 @@ def scheduler_hook(param): if name == SchedulerType.INVERSE_SQRT: return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + if name == SchedulerType.WARMUP_STABLE_DECAY: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **scheduler_specific_kwargs) + # All other schedulers require `num_training_steps` if num_training_steps is None: raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") diff --git a/tests/optimization/test_optimization.py b/tests/optimization/test_optimization.py index 6d6707db5a4b67..5240b03779b735 100644 --- a/tests/optimization/test_optimization.py +++ b/tests/optimization/test_optimization.py @@ -36,6 +36,7 @@ get_inverse_sqrt_schedule, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup, + get_scheduler, get_wsd_schedule, ) @@ -176,6 +177,27 @@ def test_schedulers(self): lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) self.assertListEqual(lrs_1, lrs_2, msg=f"failed for {scheduler_func} in save and reload") + def test_get_scheduler(self): + test_params = [ + { + "name": "warmup_stable_decay", + "optimizer": self.optimizer, + "num_warmup_steps": 2, + "scheduler_specific_kwargs": {"num_stable_steps": 1, "num_decay_steps": 3}, + }, + { + "name": "warmup_stable_decay", + "optimizer": self.optimizer, + "num_warmup_steps": 2, + "num_training_steps": 10, + "scheduler_specific_kwargs": {"num_stable_steps": 1, "num_decay_steps": 3}, + }, + {"name": "cosine", "optimizer": self.optimizer, "num_warmup_steps": 2, "num_training_steps": 10}, + ] + + for param in test_params: + self.assertTrue(get_scheduler(**param), msg=f"failed for {param['name']} in get_scheduler") + class LambdaScheduleWrapper: """See https://github.com/huggingface/transformers/issues/21689"""