From 542b4d0973b05b941427ba4f745c5d363212b3b7 Mon Sep 17 00:00:00 2001
From: Eric Hartford
Date: Thu, 8 Aug 2024 02:58:47 -0400
Subject: [PATCH 1/5] add grokadamw

---
 src/transformers/trainer.py       | 12 ++++++++++++
 src/transformers/training_args.py |  1 +
 2 files changed, 13 insertions(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 59f0ed438bf7..50bab3f9285d 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1445,6 +1445,18 @@ def optimizer_hook(param):
                 optimizer_cls = Lomo
 
             optimizer_kwargs.update({"model": model})
+        elif args.optim == OptimizerNames.GROKADAMW:
+            try:
+                from grokadamw import GrokAdamW
+                optimizer_cls = GrokAdamW
+                optimizer_kwargs.update({"alpha_init": float(optim_args.get("alpha_init", 0.98)),
+                                         "lamb": float(optim_args.get("lamb", 2.0)),
+                                         "gamma": float(optim_args.get("gamma", 0.1)),
+                                         "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)),
+                                         "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
+                                         })
+            except ImportError:
+                raise ValueError("Please install grokadamw with `pip install grokadamw`")
         else:
             raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
         return optimizer_cls, optimizer_kwargs
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 7becf0dbf61d..48fd31dd9266 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -175,6 +175,7 @@ class OptimizerNames(ExplicitEnum):
     GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
     LOMO = "lomo"
     ADALOMO = "adalomo"
+    GROKADAMW = "grokadamw"
 
 
 # Sometimes users will pass in a `str` repr of a dict in the CLI

From 98eac85fe2441073fb91e1c2a1398f8580eb68ad Mon Sep 17 00:00:00 2001
From: Eric Hartford
Date: Fri, 9 Aug 2024 00:14:10 -0400
Subject: [PATCH 2/5] reformat

---
 src/transformers/trainer.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 52e3530aa35e..3da180cacfe9 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1445,13 +1445,17 @@ def optimizer_hook(param):
         elif args.optim == OptimizerNames.GROKADAMW:
             try:
                 from grokadamw import GrokAdamW
+
                 optimizer_cls = GrokAdamW
-                optimizer_kwargs.update({"alpha_init": float(optim_args.get("alpha_init", 0.98)),
-                                         "lamb": float(optim_args.get("lamb", 2.0)),
-                                         "gamma": float(optim_args.get("gamma", 0.1)),
-                                         "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)),
-                                         "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
-                                         })
+                optimizer_kwargs.update(
+                    {
+                        "alpha_init": float(optim_args.get("alpha_init", 0.98)),
+                        "lamb": float(optim_args.get("lamb", 2.0)),
+                        "gamma": float(optim_args.get("gamma", 0.1)),
+                        "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)),
+                        "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
+                    }
+                )
             except ImportError:
                 raise ValueError("Please install grokadamw with `pip install grokadamw`")
         else:
             raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
         return optimizer_cls, optimizer_kwargs

From f43b17b46a47bd452e7d12a04536755f4bbaa2a0 Mon Sep 17 00:00:00 2001
From: Eric Hartford
Date: Fri, 9 Aug 2024 04:24:58 -0400
Subject: [PATCH 3/5] code review feedback, unit test

---
 docs/source/en/trainer.md              | 51 ++++++++++++++++++++++++++
 src/transformers/testing_utils.py      |  6 +++
 src/transformers/trainer.py            | 30 ++++++++--------
 src/transformers/utils/__init__.py     |  1 +
 src/transformers/utils/import_utils.py |  4 ++
 tests/trainer/test_trainer.py          | 23 ++++++++++++
 6 files changed, 101 insertions(+), 14 deletions(-)

diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md
index 916ae6428e87..37d8baf3d7ec 100644
--- a/docs/source/en/trainer.md
+++ b/docs/source/en/trainer.md
@@ -432,6 +432,57 @@ trainer = trl.SFTTrainer(
 trainer.train()
 ```
 
+## GrokAdamW optimizer
+
+The GrokAdamW optimizer is designed to enhance training performance and stability, particularly for models that benefit from grokking signal functions. To use GrokAdamW, first install the optimizer package with `pip install grokadamw`.
+
+<Tip>
+
+GrokAdamW is particularly useful for models that require advanced optimization techniques to achieve better performance and stability.
+
+</Tip>
+
+Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on the IMDB dataset using the GrokAdamW optimizer:
+
+```python
+import torch
+import datasets
+from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM, Trainer
+
+# Load the IMDB dataset
+train_dataset = datasets.load_dataset('imdb', split='train')
+
+# Define the training arguments
+args = TrainingArguments(
+    output_dir="./test-grokadamw",
+    max_steps=1000,
+    per_device_train_batch_size=4,
+    optim="grokadamw",
+    logging_strategy="steps",
+    logging_steps=1,
+    learning_rate=2e-5,
+    save_strategy="no",
+    run_name="grokadamw-imdb",
+)
+
+# Load the model and tokenizer
+model_id = "google/gemma-2b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)
+
+# Initialize the Trainer
+trainer = Trainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+)
+
+# Train the model
+trainer.train()
+```
+
+The only GrokAdamW-specific change is setting `optim="grokadamw"` in `TrainingArguments`; the model and dataset are passed to the `Trainer`, which handles the rest of the fine-tuning loop.
+
 ## Accelerate and Trainer
 
 The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 608e278ecfe8..230907845c24 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -84,6 +84,7 @@
     is_levenshtein_available,
     is_librosa_available,
     is_lomo_available,
+    is_grokadamw_available,
     is_natten_available,
     is_nltk_available,
     is_onnx_available,
@@ -357,6 +358,11 @@ def require_lomo(test_case):
     """
     return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)
 
+def require_grokadamw(test_case):
+    """
+    Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed.
+ """ + return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case) def require_cv2(test_case): """ diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 3da180cacfe9..06eef577c524 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -156,6 +156,7 @@ is_in_notebook, is_ipex_available, is_lomo_available, + is_grokadamw_available, is_peft_available, is_safetensors_available, is_sagemaker_dp_enabled, @@ -1443,21 +1444,22 @@ def optimizer_hook(param): optimizer_kwargs.update({"model": model}) elif args.optim == OptimizerNames.GROKADAMW: - try: - from grokadamw import GrokAdamW - - optimizer_cls = GrokAdamW - optimizer_kwargs.update( - { - "alpha_init": float(optim_args.get("alpha_init", 0.98)), - "lamb": float(optim_args.get("lamb", 2.0)), - "gamma": float(optim_args.get("gamma", 0.1)), - "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)), - "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)), - } - ) - except ImportError: + if not is_grokadamw_available(): raise ValueError("Please install grokadamw with `pip install grokadamw`") + + from grokadamw import GrokAdamW + + optimizer_cls = GrokAdamW + optimizer_kwargs.update( + { + "alpha_init": float(optim_args.get("alpha_init", 0.98)), + "lamb": float(optim_args.get("lamb", 2.0)), + "gamma": float(optim_args.get("gamma", 0.1)), + "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)), + "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)), + } + ) + else: raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") return optimizer_cls, optimizer_kwargs diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index efe473a6cded..3f62a3702fc2 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -148,6 +148,7 @@ is_levenshtein_available, is_librosa_available, is_lomo_available, + is_grokadamw_available, is_mlx_available, is_natten_available, is_ninja_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 9b4be00ee8dd..b96489756cc2 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -101,6 +101,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _fbgemm_gpu_available = _is_package_available("fbgemm_gpu") _galore_torch_available = _is_package_available("galore_torch") _lomo_available = _is_package_available("lomo_optim") +_grokadamw_available = _is_package_available("grokadamw") # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. 
 _bs4_available = importlib.util.find_spec("bs4") is not None
 _coloredlogs_available = _is_package_available("coloredlogs")
@@ -352,6 +353,9 @@ def is_galore_torch_available():
 def is_lomo_available():
     return _lomo_available
 
+def is_grokadamw_available():
+    return _grokadamw_available
+
 
 def is_pyctcdecode_available():
     return _pyctcdecode_available
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 7378a597c39c..be20d8a7acca 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -64,6 +64,7 @@
     require_galore_torch,
     require_intel_extension_for_pytorch,
     require_lomo,
+    require_grokadamw,
     require_optuna,
     require_peft,
     require_ray,
@@ -1365,6 +1366,28 @@ def test_adalomo(self):
 
         # Check this works
         _ = trainer.train()
+    
+    @require_grokadamw
+    @require_torch_gpu
+    def test_grokadamw(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=2e-5,
+                logging_steps=5,
+                optim="grokadamw",
+                max_steps=20,
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check this works
+            _ = trainer.train()
 
     def test_galore_matched_modules(self):
         regex_patterns = [r".*.attn.*", r".*.mlp.*"]

From 43eb440fe97bee93ed6d3789ee152c528709b589 Mon Sep 17 00:00:00 2001
From: Eric Hartford
Date: Fri, 9 Aug 2024 04:29:08 -0400
Subject: [PATCH 4/5] reformat

---
 src/transformers/testing_utils.py      | 2 ++
 src/transformers/utils/import_utils.py | 1 +
 tests/trainer/test_trainer.py          | 2 +-
 3 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 230907845c24..f54f6fa6b9bf 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -358,12 +358,14 @@ def require_lomo(test_case):
     """
     return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)
 
+
 def require_grokadamw(test_case):
     """
     Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed.
     """
     return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)
 
+
 def require_cv2(test_case):
     """
     Decorator marking a test that requires OpenCV.
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index b96489756cc2..765df1996bf0 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -353,6 +353,7 @@ def is_galore_torch_available():
 def is_lomo_available():
     return _lomo_available
 
+
 def is_grokadamw_available():
     return _grokadamw_available
 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index be20d8a7acca..62534d4764f4 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1366,7 +1366,7 @@ def test_adalomo(self):
 
         # Check this works
         _ = trainer.train()
-    
+
     @require_grokadamw
     @require_torch_gpu
     def test_grokadamw(self):

From 20958974ccfda169a00127c34e6877fbc0839088 Mon Sep 17 00:00:00 2001
From: Eric Hartford
Date: Fri, 9 Aug 2024 04:32:58 -0400
Subject: [PATCH 5/5] reformat

---
 src/transformers/testing_utils.py  | 2 +-
 src/transformers/trainer.py        | 2 +-
 src/transformers/utils/__init__.py | 2 +-
 tests/trainer/test_trainer.py      | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index f54f6fa6b9bf..6b04ed742606 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -76,6 +76,7 @@
     is_g2p_en_available,
     is_galore_torch_available,
     is_gguf_available,
+    is_grokadamw_available,
     is_ipex_available,
     is_jieba_available,
     is_jinja_available,
@@ -84,7 +85,6 @@
     is_levenshtein_available,
     is_librosa_available,
     is_lomo_available,
-    is_grokadamw_available,
     is_natten_available,
     is_nltk_available,
     is_onnx_available,
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 06eef577c524..cde8d0215b28 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -153,10 +153,10 @@
     is_bitsandbytes_available,
     is_datasets_available,
     is_galore_torch_available,
+    is_grokadamw_available,
     is_in_notebook,
     is_ipex_available,
     is_lomo_available,
-    is_grokadamw_available,
     is_peft_available,
     is_safetensors_available,
     is_sagemaker_dp_enabled,
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 3f62a3702fc2..a8aa670c07a7 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -137,6 +137,7 @@
     is_g2p_en_available,
     is_galore_torch_available,
     is_gguf_available,
+    is_grokadamw_available,
     is_hqq_available,
     is_in_notebook,
     is_ipex_available,
@@ -148,7 +149,6 @@
     is_levenshtein_available,
     is_librosa_available,
     is_lomo_available,
-    is_grokadamw_available,
     is_mlx_available,
     is_natten_available,
     is_ninja_available,
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 62534d4764f4..ca133a277c41 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -62,9 +62,9 @@
     require_bitsandbytes,
     require_deepspeed,
     require_galore_torch,
+    require_grokadamw,
     require_intel_extension_for_pytorch,
     require_lomo,
-    require_grokadamw,
     require_optuna,
     require_peft,
     require_ray,
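The patches above expose GrokAdamW's hyperparameters only through the generic `optim_args` string on `TrainingArguments`: the `optim_args.get(...)` calls in the `GROKADAMW` branch default to `alpha_init=0.98`, `lamb=2.0`, `gamma=0.1`, `grokking_signal_decay_rate=0.1`, and `gradient_clipping=1.0`. Below is a minimal sketch of overriding those defaults from user code; it assumes the comma-separated `key=value` format that `TrainingArguments.optim_args` already uses for other optimizers such as GaLore, and the override values are arbitrary illustrations rather than recommendations.

```python
from transformers import TrainingArguments

# Select the optimizer added by these patches and override its defaults.
# The keys mirror the optim_args.get(...) calls in trainer.py; the values
# below are arbitrary examples, not tuned settings.
args = TrainingArguments(
    output_dir="./test-grokadamw",
    optim="grokadamw",
    optim_args="alpha_init=0.95,lamb=1.5,gamma=0.1,grokking_signal_decay_rate=0.2,gradient_clipping=0.5",
    per_device_train_batch_size=4,
    max_steps=20,
)

# `args` can then be passed to Trainer(model=..., args=args, train_dataset=...)
# exactly as in the documentation example added in PATCH 3/5.
```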