Skip to content

Commit

Permalink
[Docs] Add unsloth optimizations in TRL's documentation (#1119)
Browse files Browse the repository at this point in the history
* add unsloth

* Update sft_trainer.mdx (#1124)

Co-authored-by: Daniel Han <[email protected]>

---------

Co-authored-by: Daniel Han <[email protected]>
  • Loading branch information
younesbelkada and danielhanchen authored Dec 22, 2023
1 parent 814fe39 commit f11e213
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions docs/source/sft_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,61 @@ We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssis
</div>

Note however, that the amount of performance gain is _dataset dependent_ and in particular, applying NEFTune on synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains.

### Accelerate fine-tuning 2x using `unsloth`

You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) and even full-finetuning (1.1x faster) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama as well) and Mistral architectures.
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth#installation-instructions---conda). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLlamaModel` or `FastMistralModel` as follows:

```python
import torch

from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLlamaModel, FastMistralModel

max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number.
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

# Load Llama model
model, tokenizer = FastLlamaModel.from_pretrained(
model_name = "unsloth/llama-2-7b", # Supports any llama model eg meta-llama/Llama-2-7b-hf
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

# Do model patching and add fast LoRA weights
model = FastLlamaModel.get_peft_model(
model,
r = 16,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",],
lora_alpha = 16,
lora_dropout = 0, # Currently only supports dropout = 0
bias = "none", # Currently only supports bias = "none"
use_gradient_checkpointing = True,
random_state = 3407,
max_seq_length = max_seq_length,
)

args = TrainingArguments(output_dir="./output")

trainer = SFTTrainer(
model = model,
args = args,
train_dataset = dataset,
dataset_text_field = "text",
max_seq_length = max_seq_length,
)

trainer.train()
```

The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).

## Best practices

Pay attention to the following best practices when training a model with that trainer:
Expand Down

0 comments on commit f11e213

Please sign in to comment.