
Adding support for FSDP+Qlora. #572

Merged
merged 21 commits into main on Jul 11, 2024
Conversation

HamidShojanazeri
Contributor

What does this PR do?

This PR adds support for FSDP + QLoRA, enabling fine-tuning of Llama 70B while significantly lowering the compute resource requirements.
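
For context, the key to making QLoRA play well with FSDP is loading the base model in 4-bit with a quantization storage dtype that FSDP can flatten and shard. Below is a minimal sketch of that loading pattern using the standard transformers/bitsandbytes APIs; the exact arguments this recipe passes may differ.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 quantization; bnb_4bit_quant_storage must match the dtype FSDP
# shards in (bfloat16 here) so the quantized weights can be flattened and sharded.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)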

  • Logs
(fsdp-qlora) bash-5.1$ FSDP_CPU_RAM_EFFICIENT_LOADING=1 ACCELERATE_USE_FSDP=1 torchrun --nnodes 1 --nproc_per_node 4 recipes/finetuning/finetuning.py --enable_fsdp --model_name meta-llama/Meta-Llama-3-70B --use_peft --peft_method lora --gradient_accumulation_steps 2 --output_dir peft-output --low_cpu_fsdp --quantization int4 --mixed_precision False --max_train_step 5 --max_eval_step 5
W0624 10:21:48.393000 139928613528064 torch/distributed/run.py:757] 
W0624 10:21:48.393000 139928613528064 torch/distributed/run.py:757] *****************************************
W0624 10:21:48.393000 139928613528064 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0624 10:21:48.393000 139928613528064 torch/distributed/run.py:757] *****************************************
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|██████████| 30/30 [00:28<00:00,  1.06it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
trainable params: 16,384,000 || all params: 70,570,090,496 || trainable%: 0.0232
Mixed precision: None
--> Model meta-llama/Meta-Llama-3-70B

--> meta-llama/Meta-Llama-3-70B has 2102.665216 Million params

NCCL version 2.20.5+cuda12.4
--> applying fsdp activation checkpointing...
/home/hamidnazeri/miniconda3/envs/fsdp-qlora/lib/python3.10/site-packages/datasets/load.py:1491: FutureWarning: The repository for samsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/samsum
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
  warnings.warn(
--> Training Set Length = 14732
--> Validation Set Length = 818
Preprocessing dataset: 100%|██████████| 14732/14732 [00:02<00:00, 7093.02it/s]
Preprocessing dataset: 100%|██████████| 818/818 [00:00<00:00, 7319.68it/s]
--> Num of Validation Set Batches loaded = 8
/home/hamidnazeri/miniconda3/envs/fsdp-qlora/lib/python3.10/site-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|          | 0/19 [00:00<?, ?it/s]
Training Epoch: 1/3, step 4/39 completed (loss: 0.8952157497406006):  11%|█▏        | 2/19 [02:04<14:05, 49.76s/it]
max training steps reached, stopping training, total train steps finished:  5
Training Epoch: 1/3, step 4/39 completed (loss: 0.7581993341445923):  11%|█▏        | 2/19 [02:04<17:40, 62.37s/it]
Training Epoch: 1/3, step 4/39 completed (loss: 0.8486832976341248):  11%|█▏        | 2/19 [02:04<17:39, 62.34s/it]
Training Epoch: 1/3, step 4/39 completed (loss: 0.9907992482185364):  11%|█▏        | 2/19 [02:04<17:40, 62.40s/it]
Training Epoch: 1/3, step 4/39 completed (loss: 0.8952157497406006):  11%|█▏        | 2/19 [02:04<17:38, 62.26s/it]
Max CUDA memory allocated was 63 GB
Max CUDA memory reserved was 78 GB
Peak active CUDA memory was 63 GB
CUDA Malloc retries : 0
CPU Total Peak Memory consumed during the train (max): 9 GB
evaluating Epoch:  62%|██████▏   | 5/8 [00:09<00:05,  1.81s/it]
max eval steps reached, stopping evaluation, total_eval_steps:  5
 eval_ppl=tensor(2.8264, device='cuda:0') eval_epoch_loss=tensor(1.0390, device='cuda:0')
we are about to save the PEFT modules
/home/hamidnazeri/miniconda3/envs/fsdp-qlora/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
PEFT modules are saved in peft-output directory
best eval loss on epoch 1 is 1.0389947891235352
Epoch 1: train_perplexity=1.1290, train_epoch_loss=0.1214, epoch time 125.38527124101529s
Key: avg_train_prep, Value: 1.129026174545288
Key: avg_train_loss, Value: 0.12135542184114456
Key: avg_eval_prep, Value: 2.8263745307922363
Key: avg_eval_loss, Value: 1.0389947891235352
Key: avg_epoch_time, Value: 125.38527124101529
Key: avg_checkpoint_time, Value: 1.9115908490202855
  • Test B
    Logs for Test B
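
The run above also uses --low_cpu_fsdp, which avoids materializing the full 70B model on every rank. Below is a simplified sketch of that idea, assuming standard PyTorch FSDP and transformers APIs; it is not the exact recipe code.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoConfig, AutoModelForCausalLM

dist.init_process_group("nccl")
rank = dist.get_rank()

# Rank 0 loads the real weights once; every other rank builds the model on
# the meta device, so host RAM is not multiplied by the number of processes.
if rank == 0:
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-70B")
else:
    config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-70B")
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config)

# sync_module_states=True broadcasts rank 0's weights during FSDP init;
# param_init_fn gives meta-device ranks real (empty) storage to receive them.
model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,
    param_init_fn=(
        (lambda m: m.to_empty(device=torch.device("cuda"), recurse=False))
        if rank != 0
        else None
    ),
)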

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Thanks for contributing 🎉!

@HamidShojanazeri HamidShojanazeri requested a review from mreso June 24, 2024 17:55
Contributor

@mreso mreso left a comment
Overall LGTM. I added some suggestions that should be addressed before merging, and we should also enable users to override the setting with --quantization.quant_type.
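
To illustrate the kind of override requested here, a hypothetical shape for the quantization config is sketched below; the class name, field names, and defaults are assumptions for illustration, not necessarily what was merged.

from dataclasses import dataclass
import torch
from transformers import BitsAndBytesConfig

@dataclass
class quantization_config:
    quant_type: str = "nf4"  # overridable from the CLI, e.g. --quantization.quant_type fp4
    compute_dtype: torch.dtype = torch.bfloat16
    use_double_quant: bool = True
    quant_storage: torch.dtype = torch.bfloat16

    def create_bnb_config(self) -> BitsAndBytesConfig:
        # Map the recipe-level fields onto the bitsandbytes 4-bit options.
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=self.quant_type,
            bnb_4bit_compute_dtype=self.compute_dtype,
            bnb_4bit_use_double_quant=self.use_double_quant,
            bnb_4bit_quant_storage=self.quant_storage,
        )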

Review threads (all resolved):
src/llama_recipes/configs/__init__.py
src/llama_recipes/configs/quantization.py
src/llama_recipes/finetuning.py (6 threads)
@mreso mreso merged commit 808a3f7 into main Jul 11, 2024
3 checks passed