FEAT / Trainer: LOMO optimizer support #30178

Merged
23 commits, merged May 21, 2024

Commits
e7d7bbe
add V1 - adalomo not working yet
younesbelkada Apr 11, 2024
8cdc21e
add todo docs + refactor from comments
younesbelkada Apr 11, 2024
029a9c9
adjust LR
younesbelkada Apr 11, 2024
62b5e0e
add docs
younesbelkada Apr 11, 2024
629413c
Merge remote-tracking branch 'upstream/main' into add-lomo
younesbelkada Apr 17, 2024
4907531
add more elaborated test
younesbelkada Apr 17, 2024
51c8e9e
Apply suggestions from code review
younesbelkada Apr 22, 2024
d9499c5
fix
younesbelkada Apr 22, 2024
afaabfc
push
younesbelkada Apr 22, 2024
a57dd5e
Merge remote-tracking branch 'upstream/main' into add-lomo
younesbelkada May 3, 2024
beb7edc
add accelerate check
younesbelkada May 3, 2024
5184057
Merge remote-tracking branch 'upstream/main' into add-lomo
younesbelkada May 7, 2024
ac007ee
fix DDP case
younesbelkada May 7, 2024
741a1a4
Merge remote-tracking branch 'origin/main' into add-lomo
younesbelkada May 16, 2024
80105e1
Apply suggestions from code review
younesbelkada May 16, 2024
49ce45e
fix
younesbelkada May 16, 2024
8d008a5
Merge branch 'add-lomo' of https://github.com/younesbelkada/transform…
younesbelkada May 16, 2024
40db2fa
init kwargs
younesbelkada May 16, 2024
5a536bf
safely add attribute
younesbelkada May 16, 2024
c1ac8bf
Merge remote-tracking branch 'origin/main' into add-lomo
younesbelkada May 16, 2024
9d547be
revert to enum logic
younesbelkada May 17, 2024
efe04a5
Update src/transformers/trainer.py
younesbelkada May 17, 2024
6cadb75
Merge remote-tracking branch 'origin/main' into add-lomo
younesbelkada May 21, 2024
50 changes: 50 additions & 0 deletions docs/source/en/trainer.md
@@ -382,6 +382,56 @@ trainer.train()

Note that layerwise optimization is somewhat experimental and does not support DDP (Distributed Data Parallel); you can therefore run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc. might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such an issue.

## LOMO optimizer

The LOMO optimizers were introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).
They are both efficient full-parameter fine-tuning methods that fuse the gradient computation and the parameter update into a single step to reduce memory usage. The supported LOMO optimizers are `"lomo"` and `"adalomo"`. First install LOMO either from PyPI with `pip install lomo-optim` or from source with `pip install git+https://github.com/OpenLMLab/LOMO.git`.

<Tip>

According to the authors, it is recommended to use `AdaLomo` without `grad_norm` to get better performance and higher throughput.

</Tip>
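
For example, gradient-norm clipping can be turned off through `TrainingArguments` (a minimal sketch, assuming the [`Trainer`] only applies clipping when `max_grad_norm` is greater than zero):

```python
from transformers import TrainingArguments

# Sketch: follow the tip above and run AdaLomo without gradient-norm clipping.
# Assumption: a `max_grad_norm` of 0.0 skips the Trainer-side clipping step.
args = TrainingArguments(
    output_dir="./test-adalomo",  # illustrative output path
    optim="adalomo",
    max_grad_norm=0.0,
)
```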

Below is a simple script demonstrating how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on the IMDB dataset in full precision:

```python
import torch
import datasets
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
import trl

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
output_dir="./test-lomo",
max_steps=1000,
per_device_train_batch_size=4,
optim="adalomo",
gradient_checkpointing=True,
logging_strategy="steps",
logging_steps=1,
learning_rate=2e-6,
save_strategy="no",
run_name="lomo-imdb",
)

model_id = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)

trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=1024,
)

trainer.train()
```

## Accelerate and Trainer

The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
9 changes: 9 additions & 0 deletions src/transformers/testing_utils.py
@@ -82,6 +82,7 @@
is_keras_nlp_available,
is_levenshtein_available,
is_librosa_available,
is_lomo_available,
is_natten_available,
is_nltk_available,
is_onnx_available,
@@ -338,6 +339,14 @@ def require_galore_torch(test_case):
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)


def require_lomo(test_case):
"""
Decorator marking a test that requires LOMO. These tests are skipped when `lomo-optim` isn't installed.
https://github.com/OpenLMLab/LOMO
"""
return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)


def require_cv2(test_case):
"""
Decorator marking a test that requires OpenCV.
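As an aside, a hypothetical test using the new `require_lomo` decorator could look like the sketch below (illustrative only, not part of this PR; the test is simply skipped when `lomo_optim` is not installed):

```python
import unittest

from transformers.testing_utils import require_lomo


class LomoAvailabilityTest(unittest.TestCase):
    @require_lomo
    def test_lomo_classes_importable(self):
        # Only runs when `lomo_optim` is installed; otherwise unittest skips it.
        from lomo_optim import AdaLomo, Lomo

        self.assertIsNotNone(Lomo)
        self.assertIsNotNone(AdaLomo)
```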
42 changes: 38 additions & 4 deletions src/transformers/trainer.py
@@ -153,6 +153,7 @@
is_galore_torch_available,
is_in_notebook,
is_ipex_available,
is_lomo_available,
is_peft_available,
is_safetensors_available,
is_sagemaker_dp_enabled,
@@ -1059,12 +1060,18 @@ def create_optimizer(self):
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")

# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")

# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes

@@ -1382,6 +1389,26 @@ def optimizer_hook(param):

if args.optim == OptimizerNames.GALORE_ADAFACTOR:
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
if not is_lomo_available():
raise ImportError(
"You need to install `lomo_optim` in order to use LOMO optimizers"
" install it with `pip install lomo-optim`"
)
if not is_accelerate_available("0.30.0"):
raise ImportError("You need to have `accelerate>=0.30.0` to be able to use LOMO optimizers")

if model is None:
raise ValueError("You need to pass a `model` in order to correctly initialize a LOMO optimizer.")

from lomo_optim import AdaLomo, Lomo

if "ada" in args.optim:
optimizer_cls = AdaLomo
else:
optimizer_cls = Lomo

optimizer_kwargs.update({"model": model})
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
@@ -2045,6 +2072,9 @@ def _inner_training_loop(
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
self.model, self.optimizer, self.lr_scheduler
)
elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
# In this case we are in DDP + LOMO, which should be supported
self.optimizer = self.accelerator.prepare(self.optimizer)

if self.is_fsdp_enabled:
self.model = self.model_wrapped = model
@@ -2143,7 +2173,6 @@ def _inner_training_loop(
self._globalstep_last_logged = self.state.global_step
model.zero_grad()
grad_norm: Optional[float] = None

self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

total_batched_samples = 0
@@ -2275,8 +2304,8 @@ def _inner_training_loop(
else:
grad_norm = _grad_norm

# Optimizer step
self.optimizer.step()

optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
if optimizer_was_run:
# Delay optimizer scheduling until metrics are generated
@@ -3229,7 +3258,6 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
"""
model.train()
inputs = self._prepare_inputs(inputs)

if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)
@@ -3240,14 +3268,20 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
del inputs
torch.cuda.empty_cache()

kwargs = {}

# For LOMO optimizers you need to explicitly pass the learning rate
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
kwargs["learning_rate"] = self._get_learning_rate()

if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training

if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
self.accelerator.backward(loss)
self.accelerator.backward(loss, **kwargs)
Collaborator:

What happens if we pass the learning rate through when lomo isn't being used?

Contributor Author:

It will break 😢, but we:
1. raise an error when initializing the Trainer with a LOMO optimizer if users do not have the correct accelerate version
2. pass `learning_rate` only if the optimizer is a LOMO optimizer
3. removed `kwargs` in `training_step`

So hopefully this should be safe enough 🙏
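
The pattern described in this reply, written out as a standalone sketch (hypothetical helper mirroring the diff above; it assumes `Accelerator.backward` only receives a `learning_rate` keyword for LOMO-style optimizers, with `accelerate>=0.30.0`):

```python
from accelerate import Accelerator


def backward_with_optional_lr(accelerator: Accelerator, loss, optim_name: str, lr: float):
    # Only LOMO optimizers fuse the parameter update into backward(), so only
    # they need the current learning rate here; for any other optimizer no
    # extra kwargs are passed, and the regular backward path is unaffected.
    kwargs = {}
    if optim_name in ("lomo", "adalomo"):
        kwargs["learning_rate"] = lr
    accelerator.backward(loss, **kwargs)
```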


return loss.detach() / self.args.gradient_accumulation_steps

2 changes: 2 additions & 0 deletions src/transformers/training_args.py
@@ -171,6 +171,8 @@ class OptimizerNames(ExplicitEnum):
GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
LOMO = "lomo"
ADALOMO = "adalomo"


# Sometimes users will pass in a `str` repr of a dict in the CLI
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -141,6 +141,7 @@
is_keras_nlp_available,
is_levenshtein_available,
is_librosa_available,
is_lomo_available,
is_mlx_available,
is_natten_available,
is_ninja_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -99,6 +99,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_bitsandbytes_available = _is_package_available("bitsandbytes")
_eetq_available = _is_package_available("eetq")
_galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
@@ -328,6 +329,10 @@ def is_galore_torch_available():
return _galore_torch_available


def is_lomo_available():
return _lomo_available


def is_pyctcdecode_available():
return _pyctcdecode_available

44 changes: 44 additions & 0 deletions tests/trainer/test_trainer.py
@@ -63,6 +63,7 @@
require_deepspeed,
require_galore_torch,
require_intel_extension_for_pytorch,
require_lomo,
require_optuna,
require_peft,
require_ray,
@@ -1229,6 +1230,49 @@ def test_dataloader_without_dataset(self):
trainer.train()
trainer.evaluate()

@require_lomo
@require_torch_gpu
def test_lomo(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)

previous_params = {n: p.clone() for n, p in tiny_llama.named_parameters()}

x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(tmpdir, learning_rate=1e-2, logging_steps=5, optim="lomo", max_steps=20)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

# Check this works
_ = trainer.train()

for name, param in tiny_llama.named_parameters():
self.assertFalse(torch.allclose(param, previous_params[name].to(param.device), rtol=1e-12, atol=1e-12))
Collaborator:

This tolerance is super small; do we expect optimizers to make changes on this order?

Contributor Author:

It is OK to put it higher; I decided to put it low so that even small changes are captured by the test (sometimes higher tolerances would fail even though the weights are properly updated, even with a high learning rate), so this is just to be on the safe side.


@require_lomo
@require_torch_gpu
def test_adalomo(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="adalomo",
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

# Check this works
_ = trainer.train()

def test_galore_matched_modules(self):
regex_patterns = [r".*.attn.*", r".*.mlp.*"]
