Skip to content

Commit

Permalink
FEAT / Trainer: LOMO optimizer support (#30178)
Browse files Browse the repository at this point in the history
* add V1 - adalomo not working yet

* add todo docs + refactor from comments

* adjust LR

* add docs

* add more elaborated test

* Apply suggestions from code review

Co-authored-by: Zach Mueller <[email protected]>

* fix

* push

* add accelerate check

* fix DDP case

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

* fix

* init kwargs

* safely add attribute

* revert to enum logic

* Update src/transformers/trainer.py

---------

Co-authored-by: Zach Mueller <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
3 people authored May 21, 2024
1 parent c876d12 commit 8871b26
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 4 deletions.
50 changes: 50 additions & 0 deletions docs/source/en/trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,56 @@ trainer.train()

Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue.

## LOMO optimizer

The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).
They both consist of an efficient full-parameter fine-tuning method. These optimizers fuse the gradient computation and the parameter update in one step to reduce memory usage. Supported optimizers for LOMO are `"lomo"` and `"adalomo"`. First either install LOMO from pypi `pip install lomo-optim` or install it from source with `pip install git+https://github.com/OpenLMLab/LOMO.git`.

<Tip>

According to the authors, it is recommended to use `AdaLomo` without `grad_norm` to get better performance and higher throughput.

</Tip>

Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on IMDB dataset in full precision:

```python
import torch
import datasets
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
import trl

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
output_dir="./test-lomo",
max_steps=1000,
per_device_train_batch_size=4,
optim="adalomo",
gradient_checkpointing=True,
logging_strategy="steps",
logging_steps=1,
learning_rate=2e-6,
save_strategy="no",
run_name="lomo-imdb",
)

model_id = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)

trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=1024,
)

trainer.train()
```

## Accelerate and Trainer

The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
Expand Down
9 changes: 9 additions & 0 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
is_keras_nlp_available,
is_levenshtein_available,
is_librosa_available,
is_lomo_available,
is_natten_available,
is_nltk_available,
is_onnx_available,
Expand Down Expand Up @@ -338,6 +339,14 @@ def require_galore_torch(test_case):
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)


def require_lomo(test_case):
"""
Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
https://github.com/OpenLMLab/LOMO
"""
return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)


def require_cv2(test_case):
"""
Decorator marking a test that requires OpenCV.
Expand Down
42 changes: 38 additions & 4 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@
is_galore_torch_available,
is_in_notebook,
is_ipex_available,
is_lomo_available,
is_peft_available,
is_safetensors_available,
is_sagemaker_dp_enabled,
Expand Down Expand Up @@ -1059,12 +1060,18 @@ def create_optimizer(self):
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")

# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")

# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes

Expand Down Expand Up @@ -1382,6 +1389,26 @@ def optimizer_hook(param):

if args.optim == OptimizerNames.GALORE_ADAFACTOR:
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
if not is_lomo_available():
raise ImportError(
"You need to install `lomo_optim` in order to use LOMO optimizers"
" install it with `pip install lomo-optim`"
)
if not is_accelerate_available("0.30.0"):
raise ImportError("You need to have `accelerate>=0.30.0` to be able to use LOMO optimizers")

if model is None:
raise ValueError("You need to pass a `model` in order to correctly initialize a LOMO optimizer.")

from lomo_optim import AdaLomo, Lomo

if "ada" in args.optim:
optimizer_cls = AdaLomo
else:
optimizer_cls = Lomo

optimizer_kwargs.update({"model": model})
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
Expand Down Expand Up @@ -2045,6 +2072,9 @@ def _inner_training_loop(
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
self.model, self.optimizer, self.lr_scheduler
)
elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
# In this case we are in DDP + LOMO, which should be supported
self.optimizer = self.accelerator.prepare(self.optimizer)

if self.is_fsdp_enabled:
self.model = self.model_wrapped = model
Expand Down Expand Up @@ -2143,7 +2173,6 @@ def _inner_training_loop(
self._globalstep_last_logged = self.state.global_step
model.zero_grad()
grad_norm: Optional[float] = None

self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

total_batched_samples = 0
Expand Down Expand Up @@ -2275,8 +2304,8 @@ def _inner_training_loop(
else:
grad_norm = _grad_norm

# Optimizer step
self.optimizer.step()

optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
if optimizer_was_run:
# Delay optimizer scheduling until metrics are generated
Expand Down Expand Up @@ -3229,7 +3258,6 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
"""
model.train()
inputs = self._prepare_inputs(inputs)

if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)
Expand All @@ -3240,14 +3268,20 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
del inputs
torch.cuda.empty_cache()

kwargs = {}

# For LOMO optimizers you need to explicitly use the learnign rate
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
kwargs["learning_rate"] = self._get_learning_rate()

if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training

if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
self.accelerator.backward(loss)
self.accelerator.backward(loss, **kwargs)

return loss.detach() / self.args.gradient_accumulation_steps

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ class OptimizerNames(ExplicitEnum):
GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
LOMO = "lomo"
ADALOMO = "adalomo"


# Sometimes users will pass in a `str` repr of a dict in the CLI
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@
is_keras_nlp_available,
is_levenshtein_available,
is_librosa_available,
is_lomo_available,
is_mlx_available,
is_natten_available,
is_ninja_available,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_bitsandbytes_available = _is_package_available("bitsandbytes")
_eetq_available = _is_package_available("eetq")
_galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
Expand Down Expand Up @@ -328,6 +329,10 @@ def is_galore_torch_available():
return _galore_torch_available


def is_lomo_available():
return _lomo_available


def is_pyctcdecode_available():
return _pyctcdecode_available

Expand Down
44 changes: 44 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
require_deepspeed,
require_galore_torch,
require_intel_extension_for_pytorch,
require_lomo,
require_optuna,
require_peft,
require_ray,
Expand Down Expand Up @@ -1229,6 +1230,49 @@ def test_dataloader_without_dataset(self):
trainer.train()
trainer.evaluate()

@require_lomo
@require_torch_gpu
def test_lomo(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)

previous_params = {n: p.clone() for n, p in tiny_llama.named_parameters()}

x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(tmpdir, learning_rate=1e-2, logging_steps=5, optim="lomo", max_steps=20)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

# Check this works
_ = trainer.train()

for name, param in tiny_llama.named_parameters():
self.assertFalse(torch.allclose(param, previous_params[name].to(param.device), rtol=1e-12, atol=1e-12))

@require_lomo
@require_torch_gpu
def test_adalomo(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="adalomo",
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

# Check this works
_ = trainer.train()

def test_galore_matched_modules(self):
regex_patterns = [r".*.attn.*", r".*.mlp.*"]

Expand Down

0 comments on commit 8871b26

Please sign in to comment.