Adding optimizer_cls_and_kwargs to Trainer.__init__ #34358

Merged
merged 10 commits into from Oct 29, 2024
106 changes: 67 additions & 39 deletions docs/source/en/trainer.md
@@ -252,7 +252,70 @@ trainer = Trainer(..., args=training_args)

NEFTune is disabled after training to restore the original embedding layer to avoid any unexpected behavior.

## Liger Kernel

[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels developed by LinkedIn specifically for LLM training. It provides Hugging Face-compatible implementations of RMSNorm, RoPE, SwiGLU, CrossEntropy, FusedLinearCrossEntropy, and more, with additional kernels to come. It can increase multi-GPU training throughput by 20% and reduce memory usage by 60%. The kernels work out of the box with Flash Attention, PyTorch FSDP, and Microsoft DeepSpeed.

<Tip>
Gain +20% throughput and reduce memory usage by 60% when training the LLaMA 3-8B model, enabling longer context lengths and larger batch sizes. The kernels are also useful for scaling to large vocabulary sizes and for multi-head training (e.g. Medusa). See details and examples in the [Liger examples](https://github.com/linkedin/Liger-Kernel/tree/main/examples).
</Tip>

First, make sure to install the official Liger package:
```bash
pip install liger-kernel
```

Then pass `use_liger_kernel=True` to apply the Liger kernels to your model, for example:

```py
from transformers import TrainingArguments

training_args = TrainingArguments(
output_dir="your-model",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=2,
weight_decay=0.01,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
push_to_hub=True,
use_liger_kernel=True
)
```

The kernel supports the Llama, Gemma, Mistral, and Mixtral model architectures. The most up-to-date list of supported models can be found [here](https://github.com/linkedin/Liger-Kernel). When `use_liger_kernel` is set to `True`, the corresponding layers in the original model will be patched with Liger's efficient implementation, so you don't need to do anything extra other than setting the argument value.
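
With these arguments set, training proceeds through [`Trainer`] as usual; a minimal sketch, following the ellipsis convention used elsewhere on this page for the model and dataset arguments:

```py
from transformers import Trainer

# The Liger patching happens automatically inside Trainer when use_liger_kernel=True.
trainer = Trainer(..., args=training_args)
trainer.train()
```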


## Optimizers

You can choose a built-in optimizer for training using:

```python
from transformers import TrainingArguments
training_args = TrainingArguments(..., optim="adamw_torch")
```

See [`OptimizerNames`](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py) for a full list of choices. We include advanced examples in the sections below.
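
If you prefer to inspect the available names from a Python session instead of the source file, the enum can be iterated directly (a small sketch, assuming `OptimizerNames` is importable from `transformers.training_args` as in the linked file):

```python
from transformers.training_args import OptimizerNames

# Each member's value is the string accepted by TrainingArguments(optim=...).
print(sorted(name.value for name in OptimizerNames))
```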

You can also use an arbitrary PyTorch optimizer via:

```python
import torch

optimizer_cls = torch.optim.AdamW
optimizer_kwargs = {
"lr": 4e-3,
"betas": (0.9, 0.999),
"weight_decay": 0.05,
}

from transformers import Trainer
trainer = Trainer(..., optimizer_cls_and_kwargs=(optimizer_cls, optimizer_kwargs))
```

### GaLore

Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning but is more memory-efficient than common low-rank adaptation methods, such as LoRA.
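
The full training example is collapsed in the hunk below; as a minimal sketch of the typical setup (assuming `pip install galore-torch`, and that the regular expressions match your model's attention and MLP module names):

```py
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./galore-test",
    optim="galore_adamw",
    # Apply the low-rank gradient projection only to the matched modules.
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
)
```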

@@ -382,42 +445,7 @@ trainer.train()

Note that layerwise optimization is somewhat experimental and does not support DDP (Distributed Data Parallel), so you can only run the training script on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping and DeepSpeed might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such an issue.

### LOMO optimizer

The LOMO optimizers were introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).
Both are efficient full-parameter fine-tuning methods that fuse the gradient computation and the parameter update into one step to reduce memory usage. The supported LOMO optimizers are `"lomo"` and `"adalomo"`. First, either install LOMO from PyPI with `pip install lomo-optim` or install it from source with `pip install git+https://github.com/OpenLMLab/LOMO.git`.
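
The full script is collapsed in the hunk below; a minimal sketch of selecting AdaLomo through `TrainingArguments` (hyperparameter values are placeholders):

```py
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./lomo-test",
    optim="adalomo",              # or "lomo"
    learning_rate=1e-3,
    gradient_checkpointing=True,  # commonly paired with LOMO to save additional memory
)
```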
@@ -467,7 +495,7 @@ trainer = trl.SFTTrainer(
trainer.train()
```

### GrokAdamW optimizer

The GrokAdamW optimizer is designed to enhance training performance and stability, particularly for models that benefit from grokking signal functions. To use GrokAdamW, first install the optimizer package with `pip install grokadamw`.
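
The complete fine-tuning script is collapsed in the hunk below; the essential piece is selecting the optimizer by name (a sketch, with placeholder hyperparameters):

```py
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./grokadamw-test",
    optim="grokadamw",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
)
```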

@@ -518,7 +546,7 @@ trainer.train()

This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.

### Schedule Free Optimizer

The Schedule Free optimizers have been introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682).
Schedule-Free learning replaces the momentum of the base optimizer with a combination of averaging and interpolation, to completely remove the need to anneal the learning rate with a traditional schedule.
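
The supported optimizer names and the full example are collapsed below; as a minimal sketch (assuming `pip install schedulefree` and that `"schedule_free_adamw"` is the registered name):

```py
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./schedule-free-test",
    optim="schedule_free_adamw",
    lr_scheduler_type="constant",  # Schedule-Free removes the need for a decaying schedule
)
```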
18 changes: 15 additions & 3 deletions src/transformers/trainer.py
@@ -34,7 +34,7 @@
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union


# Integrations must be imported before ML frameworks:
@@ -358,6 +358,11 @@ class Trainer:
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
A tuple containing the optimizer class and keyword arguments to use.
Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.

Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
A function that preprocesses the logits right before caching them at each evaluation step. Must take two
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
@@ -401,7 +406,8 @@ def __init__(
compute_loss_func: Optional[Callable] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
optimizer_cls_and_kwargs: Optional[Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] = None,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
if args is None:
@@ -603,6 +609,9 @@ def __init__(
self.compute_metrics = compute_metrics
self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
self.optimizer, self.lr_scheduler = optimizers
self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs
if self.optimizer_cls_and_kwargs is not None and self.optimizer is not None:
raise RuntimeError("Passing both `optimizers` and `optimizer_cls_and_kwargs` arguments is incompatible.")
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
raise RuntimeError(
"Passing a `model_init` is incompatible with providing the `optimizers` argument. "
@@ -1171,7 +1180,10 @@ def create_optimizer(self):
},
]

if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)

# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.