From 27eb7f0941e000fcff4ee0f0d693e724951acb82 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 22 Dec 2023 13:45:26 +0100 Subject: [PATCH] =?UTF-8?q?[`Docs`]=C2=A0Add=20unsloth=20optimizations=20i?= =?UTF-8?q?n=20TRL's=20documentation=20(#1119)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add unsloth * Update sft_trainer.mdx (#1124) Co-authored-by: Daniel Han --------- Co-authored-by: Daniel Han --- docs/source/sft_trainer.mdx | 55 +++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx index a6ca3e439c..4c0c1abeac 100644 --- a/docs/source/sft_trainer.mdx +++ b/docs/source/sft_trainer.mdx @@ -410,6 +410,61 @@ We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssis Note however, that the amount of performance gain is _dataset dependent_ and in particular, applying NEFTune on synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains. + +### Accelerate fine-tuning 2x using `unsloth` + +You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) and even full-finetuning (1.1x faster) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama as well) and Mistral architectures. +First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth#installation-instructions---conda). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLlamaModel` or `FastMistralModel` as follows: + +```python +import torch + +from transformers import TrainingArguments +from trl import SFTTrainer +from unsloth import FastLlamaModel, FastMistralModel + +max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number. +dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+ +load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False. + +# Load Llama model +model, tokenizer = FastLlamaModel.from_pretrained( + model_name = "unsloth/llama-2-7b", # Supports any llama model eg meta-llama/Llama-2-7b-hf + max_seq_length = max_seq_length, + dtype = dtype, + load_in_4bit = load_in_4bit, + # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf +) + +# Do model patching and add fast LoRA weights +model = FastLlamaModel.get_peft_model( + model, + r = 16, + target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj",], + lora_alpha = 16, + lora_dropout = 0, # Currently only supports dropout = 0 + bias = "none", # Currently only supports bias = "none" + use_gradient_checkpointing = True, + random_state = 3407, + max_seq_length = max_seq_length, +) + +args = TrainingArguments(output_dir="./output") + +trainer = SFTTrainer( + model = model, + args = args, + train_dataset = dataset, + dataset_text_field = "text", + max_seq_length = max_seq_length, +) + +trainer.train() +``` + +The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth). + ## Best practices Pay attention to the following best practices when training a model with that trainer: