diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md
index b5a89c55c96a..f51833bfb6df 100644
--- a/docs/source/en/trainer.md
+++ b/docs/source/en/trainer.md
@@ -518,6 +518,51 @@ trainer.train()
 
 This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.
 
+## Schedule Free Optimizer
+
+The Schedule Free optimizers were introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682).
+Schedule-Free learning replaces the momentum of the base optimizer with a combination of averaging and interpolation, completely removing the need to anneal the learning rate with a traditional schedule.
+The supported Schedule Free optimizers are `"schedule_free_adamw"` and `"schedule_free_sgd"`. First, install schedulefree from PyPI with `pip install schedulefree`.
+
+Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on the IMDB dataset in full precision:
+
+```python
+import torch
+import datasets
+from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
+import trl
+
+train_dataset = datasets.load_dataset('imdb', split='train')
+
+args = TrainingArguments(
+    output_dir="./test-schedulefree",
+    max_steps=1000,
+    per_device_train_batch_size=4,
+    optim="schedule_free_adamw",
+    gradient_checkpointing=True,
+    logging_strategy="steps",
+    logging_steps=1,
+    learning_rate=2e-6,
+    save_strategy="no",
+    run_name="sfo-imdb",
+)
+
+model_id = "google/gemma-2b"
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)
+
+trainer = trl.SFTTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    dataset_text_field='text',
+    max_seq_length=1024,
+)
+
+trainer.train()
+```
+
 ## Accelerate and Trainer
 
 The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
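Editor's note on the docs above: Schedule-Free optimizers expose `train()`/`eval()` modes of their own, which is why the `trainer.py` changes in this PR call `self.optimizer.train()` in `training_step` and `self.optimizer.eval()` in the evaluation loops. When the optimizer is used outside of [`Trainer`], those calls have to be made manually. Below is a minimal sketch, not part of this PR; the toy model, data, and hyperparameters are placeholders, and it assumes the `AdamWScheduleFree` API documented by the schedulefree package:

```python
# Minimal sketch: using schedulefree directly in a plain PyTorch loop.
# The model, data, and hyperparameters here are illustrative placeholders.
import torch
from schedulefree import AdamWScheduleFree

model = torch.nn.Linear(10, 2)
optimizer = AdamWScheduleFree(model.parameters(), lr=2e-6, warmup_steps=100)

for step in range(1000):
    # Schedule-Free optimizers keep separate train/eval parameter states,
    # so the optimizer (not just the model) must be switched into train mode.
    model.train()
    optimizer.train()
    inputs, labels = torch.randn(4, 10), torch.randint(0, 2, (4,))
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Before evaluation or saving, switch back so the averaged weights are used.
model.eval()
optimizer.eval()
```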
diff --git a/setup.py b/setup.py
index 43d051df8b83..14a80d3321be 100644
--- a/setup.py
+++ b/setup.py
@@ -163,6 +163,7 @@
     "sacremoses",
     "safetensors>=0.4.1",
     "sagemaker>=2.31.0",
+    "schedulefree>=1.2.6",
     "scikit-learn",
     "scipy<1.13.0",  # SciPy >= 1.13.0 is not supported with the current jax pin (`jax>=0.4.1,<=0.4.13`)
     "sentencepiece>=0.1.91,!=0.1.92",
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index 23d686efd51d..c199884a1960 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -69,6 +69,7 @@
     "sacremoses": "sacremoses",
     "safetensors": "safetensors>=0.4.1",
     "sagemaker": "sagemaker>=2.31.0",
+    "schedulefree": "schedulefree>=1.2.6",
     "scikit-learn": "scikit-learn",
     "scipy": "scipy<1.13.0",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 46f0c2f3560a..3306f76249fe 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -103,6 +103,7 @@
     is_rjieba_available,
     is_sacremoses_available,
     is_safetensors_available,
+    is_schedulefree_available,
     is_scipy_available,
     is_sentencepiece_available,
     is_seqio_available,
@@ -370,6 +371,14 @@ def require_grokadamw(test_case):
     return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)
 
 
+def require_schedulefree(test_case):
+    """
+    Decorator marking a test that requires schedulefree. These tests are skipped when schedulefree isn't installed.
+    https://github.com/facebookresearch/schedule_free
+    """
+    return unittest.skipUnless(is_schedulefree_available(), "test requires schedulefree")(test_case)
+
+
 def require_cv2(test_case):
     """
     Decorator marking a test that requires OpenCV.
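The new `require_schedulefree` decorator follows the same pattern as the other optional-dependency guards in the test suite. A minimal usage sketch (the test class, test name, and body are hypothetical, not tests added by this PR):

```python
# Hypothetical usage of the new decorator; the names below are illustrative.
import unittest

from transformers.testing_utils import require_schedulefree


class ScheduleFreeImportTest(unittest.TestCase):
    @require_schedulefree
    def test_uses_schedulefree(self):
        # Runs only when the `schedulefree` package is installed;
        # otherwise the decorator marks the test as skipped.
        from schedulefree import AdamWScheduleFree  # noqa: F401
```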
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index c241fd4eb83c..525708645c2c 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -161,6 +161,7 @@
     is_safetensors_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
+    is_schedulefree_available,
     is_torch_compile_available,
     is_torch_mlu_available,
     is_torch_mps_available,
@@ -1488,6 +1489,36 @@ def optimizer_hook(param):
 
             optimizer_cls = AdamW4bit
             optimizer_kwargs.update(adam_kwargs)
+        elif args.optim in [
+            OptimizerNames.SCHEDULE_FREE_ADAMW,
+            OptimizerNames.SCHEDULE_FREE_SGD,
+        ]:
+            if not is_schedulefree_available():
+                raise ImportError(
+                    "You need to install `schedulefree` in order to use schedulefree optimizers."
+                    " Install it with `pip install schedulefree`."
+                )
+            if not is_accelerate_available("0.30.0"):
+                raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers")
+            from schedulefree import AdamWScheduleFree, SGDScheduleFree
+
+            additional_optim_kwargs = {}
+            if args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
+                optimizer_cls = AdamWScheduleFree
+                additional_optim_kwargs = adam_kwargs
+            elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD:
+                optimizer_cls = SGDScheduleFree
+            else:
+                raise ValueError("Invalid schedulefree optimizer")
+            additional_optim_kwargs["weight_decay"] = args.weight_decay
+            additional_optim_kwargs["warmup_steps"] = args.warmup_steps
+            additional_optim_kwargs.update(
+                {
+                    "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
+                    "r": float(optim_args.get("r", 0.0)),
+                }
+            )
+            optimizer_kwargs.update(additional_optim_kwargs)
         else:
             raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
         return optimizer_cls, optimizer_kwargs
@@ -3410,6 +3441,9 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             `torch.Tensor`: The tensor with training loss on this batch.
         """
         model.train()
+        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
+            self.optimizer.train()
+
         inputs = self._prepare_inputs(inputs)
         if is_sagemaker_mp_enabled():
             loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
@@ -3960,6 +3994,8 @@ def evaluation_loop(
         logger.info(f"  Batch size = {batch_size}")
 
         model.eval()
+        if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
+            self.optimizer.eval()
 
         self.callback_handler.eval_dataloader = dataloader
         # Do this before wrapping.
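For reference, the two extra keyword arguments read from `optim_args` in the branch above (`weight_lr_power` and `r`) can be supplied from user code through `TrainingArguments`. A minimal sketch, assuming the existing comma-separated `optim_args` parsing already used by the other pluggable optimizers; the values are illustrative:

```python
# Sketch of how the Schedule-Free knobs read via `optim_args.get(...)` above
# can be set by a user; values here are illustrative, not recommendations.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./test-schedulefree",
    optim="schedule_free_adamw",
    # Parsed by Trainer into the `optim_args` dict used in the branch above;
    # unset keys fall back to the defaults (weight_lr_power=2.0, r=0.0).
    optim_args="weight_lr_power=2.0,r=0.0",
    learning_rate=2e-6,
    warmup_steps=100,   # forwarded to the optimizer as `warmup_steps`
    weight_decay=0.01,  # forwarded to the optimizer as `weight_decay`
)
```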
@@ -4573,6 +4609,8 @@ def prediction_loop(
         inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
 
         model.eval()
+        if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
+            self.optimizer.eval()
 
         if args.past_index >= 0:
             self._past = None
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 6b587bdd65ae..02413c285832 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -178,6 +178,8 @@ class OptimizerNames(ExplicitEnum):
     LOMO = "lomo"
     ADALOMO = "adalomo"
     GROKADAMW = "grokadamw"
+    SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
+    SCHEDULE_FREE_SGD = "schedule_free_sgd"
 
 
 # Sometimes users will pass in a `str` repr of a dict in the CLI
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index dc8e8c88f25f..eee350349f55 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -175,6 +175,7 @@
     is_safetensors_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
+    is_schedulefree_available,
     is_scipy_available,
     is_sentencepiece_available,
     is_seqio_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 3d03c1589477..c9123299b61d 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -103,6 +103,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _galore_torch_available = _is_package_available("galore_torch")
 _lomo_available = _is_package_available("lomo_optim")
 _grokadamw_available = _is_package_available("grokadamw")
+_schedulefree_available = _is_package_available("schedulefree")
 # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
 _bs4_available = importlib.util.find_spec("bs4") is not None
 _coloredlogs_available = _is_package_available("coloredlogs")
@@ -364,6 +365,10 @@ def is_grokadamw_available():
     return _grokadamw_available
 
 
+def is_schedulefree_available():
+    return _schedulefree_available
+
+
 def is_pyctcdecode_available():
     return _pyctcdecode_available
 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 49213f19187f..1837d9890352 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -70,6 +70,7 @@
     require_peft,
     require_ray,
     require_safetensors,
+    require_schedulefree,
     require_sentencepiece,
     require_sigopt,
     require_tensorboard,
@@ -1442,6 +1443,27 @@ def test_grokadamw():
             # Check this works
             _ = trainer.train()
 
+    @require_schedulefree
+    @require_torch_gpu
+    def test_schedulefree_adam(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=1e-9,
+                logging_steps=5,
+                optim="schedule_free_adamw",
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check this works
+            _ = trainer.train()
+
     def test_galore_matched_modules(self):
         regex_patterns = [r".*.attn.*", r".*.mlp.*"]
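Outside the test suite, the new `is_schedulefree_available()` helper exported from `transformers.utils` can be used to guard the optimizer choice in user code. A minimal sketch; the fallback to `adamw_torch` is a choice made for this example, not behavior added by the PR:

```python
# Illustrative availability guard; falling back to "adamw_torch" is an
# assumption of this sketch, not something this PR prescribes.
from transformers import TrainingArguments
from transformers.utils import is_schedulefree_available

optim_name = "schedule_free_adamw" if is_schedulefree_available() else "adamw_torch"

args = TrainingArguments(
    output_dir="./out",
    optim=optim_name,
    learning_rate=2e-6,
    warmup_steps=100,
)
print(f"Using optimizer: {args.optim}")
```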