
Refactor to make DoRA and QDoRA work with FSDP #1806

Merged

Conversation

@BenjaminBossan BenjaminBossan commented May 28, 2024

Supersedes #1797

Description

This is an alternative to the aforementioned PR to make DoRA and QDoRA work with FSDP.

Implementation

This PR is similar to #1797 but it moves all the DoRA functionality into a separate module class. This is necessary because otherwise the DoRA magnitude lives on the lora.Linear layer as a parameter, not as a separate module. Since the FSDP auto wrap policy operates at the level of modules, not parameters, there is no way to modify the auto wrap policy to wrap the DoRA parameter; it must be its own module.
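
To illustrate the point, here is a minimal sketch (DoraMagnitudeModule and its forward computation are made up for illustration and are not PEFT's actual class): an FSDP auto wrap policy can only target module classes, so the magnitude has to live in its own nn.Module.

# Minimal sketch with hypothetical names, not PEFT code: FSDP auto wrap policies
# match module classes, so the DoRA magnitude must be an nn.Module of its own
# rather than a bare nn.Parameter attached to lora.Linear.
import functools

import torch
from torch import nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class DoraMagnitudeModule(nn.Module):
    """Holds the DoRA magnitude vector as a wrappable module."""

    def __init__(self, out_features: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(out_features))

    def forward(self, weight_norm: torch.Tensor) -> torch.Tensor:
        # Scaling factor applied to the weight-decomposed output; the real
        # computation in PEFT differs in its details.
        return self.weight / weight_norm


# The module class can now be listed in the auto wrap policy next to the model's
# transformer block class; a plain parameter could never be targeted this way.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={DoraMagnitudeModule},  # plus the model's decoder layer class
)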

If not for this reason, #1797 would be preferable, since the number of code changes is smaller overall. This PR touches more lines, but the majority of the changes only move code around without altering the actual logic.

An additional required change was to make a defensive copy of the base layer before dequantizing its weight in order to calculate the weight norm for DoRA. Without this defensive copy, a side effect is triggered in FSDP that results in

ValueError: Cannot flatten integer dtype tensors

even though the compute dtype of bnb is correctly set to float.
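
As a rough sketch of the defensive-copy idea (the dequantization helper below is a stand-in for the real bnb/PEFT dequantization, kept only so the example runs on its own):

# Sketch only, not PEFT's actual code: deepcopy the quantized base layer so that
# dequantizing it for the DoRA weight norm does not mutate state that FSDP later
# tries to flatten.
import copy

import torch


def dequantize_module_weight(module: torch.nn.Module) -> torch.Tensor:
    # Placeholder for the bnb dequantization step; here we simply return a float
    # copy of the weight so the sketch stays self-contained.
    return module.weight.float()


def dora_weight_norm(base_layer: torch.nn.Module, lora_weight: torch.Tensor, scaling: float) -> torch.Tensor:
    layer_copy = copy.deepcopy(base_layer)  # defensive copy; the FSDP-managed layer stays untouched
    weight = dequantize_module_weight(layer_copy)
    # DoRA needs the column-wise norm of the combined base + LoRA weight.
    return torch.linalg.norm(weight + scaling * lora_weight, dim=1)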

Compatibility

Since we introduce a new submodule, an extra step is required to ensure that old DoRA state dicts can still be loaded correctly. This involves a fairly trivial extra remapping step in set_peft_model_state_dict. This is covered by the new DoRA regression tests introduced in #1792; I ran them locally and they passed. I also performed a manual test to be extra sure.
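
For illustration, the remapping boils down to something like the following sketch (the key pattern is an assumption for illustration and may not match the final implementation exactly):

# Rough sketch of the extra remapping step: the refactor turns lora_magnitude_vector
# from a ParameterDict entry into a submodule, so old keys need a ".weight" suffix.
def remap_old_dora_keys(state_dict: dict) -> dict:
    remapped = {}
    for key, value in state_dict.items():
        if "lora_magnitude_vector" in key and not key.endswith(".weight"):
            key = f"{key}.weight"  # old: ...lora_magnitude_vector.default -> new: ...default.weight
        remapped[key] = value
    return remapped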

Caveats

Some tests currently fail for me locally. This seems to be related to some strange side-effect that is triggered by creating a deepcopy of a bnb layer.

Two PRs have been merged to fix these issues:

As is, 8bit QDoRA will work once these fixes are released. For now, I exclude 8bit QDoRA, but I will update this PR, or create a follow-up, once the bnb release is out and I have tested it successfully.

Experiments

QDoRA (DoRA + bnb)

Experimental results are based on the QLoRA scripts in examples/sft/. The only differences are that DoRA is enabled and the following changes to the script:

1c1
< accelerate launch --config_file "configs/fsdp_config_qlora.yaml"  train.py \
---
> accelerate launch --config_file "configs/fsdp_config_qlora.yaml"  bb_train.py \
3c3
< --model_name_or_path "meta-llama/Llama-2-70b-hf" \
---
> --model_name_or_path "facebook/opt-125m" \
9c9
< --max_seq_len 2048 \
---
> --max_seq_len 256 \
16,19c16
< --push_to_hub \
< --hub_private_repo True \
< --hub_strategy "every_save" \
< --bf16 True \
---
> --bf16 False \
26c23
< --output_dir "llama-sft-qlora-fsdp" \
---
> --output_dir "/tmp/llama-sft-qlora-fsdp" \
33c30
< --use_flash_attn True \
---
> --use_flash_attn False \
41,42c38,39
< --bnb_4bit_compute_dtype "bfloat16" \
< --bnb_4bit_quant_storage_dtype "bfloat16"
\ No newline at end of file
---
> --bnb_4bit_compute_dtype "float32" \
> --bnb_4bit_quant_storage_dtype "float32"

Results for 4bit bnb + facebook/opt-125m

trainable params: 1,410,048 || all params: 126,649,344 || trainable%: 1.1133
***** Running training *****
  Num examples = 48,106
  Num Epochs = 1
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 2
  Total optimization steps = 6,013
  Number of trainable parameters = 705,024
{'loss': 2.8957, 'grad_norm': 0.8884555697441101, 'learning_rate': 9.999982939289716e-05, 'epoch': 0.0}                                                  
{'loss': 2.7893, 'grad_norm': 0.9625115990638733, 'learning_rate': 9.999931757275294e-05, 'epoch': 0.0}                                                  
{'loss': 2.8415, 'grad_norm': 0.9538012146949768, 'learning_rate': 9.999846454306009e-05, 'epoch': 0.0}                                                  
{'loss': 2.6327, 'grad_norm': 0.9133555889129639, 'learning_rate': 9.999727030964001e-05, 'epoch': 0.0}                                                  
{'loss': 2.818, 'grad_norm': 1.0490074157714844, 'learning_rate': 9.999573488064242e-05, 'epoch': 0.0}                                                   
{'loss': 2.8878, 'grad_norm': 0.9281569719314575, 'learning_rate': 9.999385826654554e-05, 'epoch': 0.0}                                                  
{'loss': 2.8005, 'grad_norm': 0.976289689540863, 'learning_rate': 9.999164048015593e-05, 'epoch': 0.01}                                                  
{'loss': 2.8677, 'grad_norm': 1.1817463636398315, 'learning_rate': 9.998908153660838e-05, 'epoch': 0.01}                                                 
{'loss': 2.7082, 'grad_norm': 0.9303300380706787, 'learning_rate': 9.998618145336587e-05, 'epoch': 0.01}                                                 
{'loss': 2.7316, 'grad_norm': 1.15218985080719, 'learning_rate': 9.998294025021936e-05, 'epoch': 0.01}                                                   
{'loss': 2.6354, 'grad_norm': 4.059926986694336, 'learning_rate': 9.997935794928776e-05, 'epoch': 0.01}                                                  
{'loss': 2.716, 'grad_norm': 1.0557539463043213, 'learning_rate': 9.997543457501773e-05, 'epoch': 0.01}                                                  
{'loss': 2.6986, 'grad_norm': 1.0222344398498535, 'learning_rate': 9.997117015418345e-05, 'epoch': 0.01}                                                 
{'loss': 2.6453, 'grad_norm': 1.0596083402633667, 'learning_rate': 9.996656471588657e-05, 'epoch': 0.01}                                                 
{'loss': 2.6364, 'grad_norm': 0.9815739989280701, 'learning_rate': 9.996161829155588e-05, 'epoch': 0.01}                                                 
{'loss': 2.8204, 'grad_norm': 1.0197216272354126, 'learning_rate': 9.995633091494722e-05, 'epoch': 0.01}                                                 
{'loss': 2.7232, 'grad_norm': 1.0193029642105103, 'learning_rate': 9.995070262214313e-05, 'epoch': 0.01}                                                 
{'loss': 2.7582, 'grad_norm': 1.1623954772949219, 'learning_rate': 9.994473345155267e-05, 'epoch': 0.01}                                                 
{'loss': 2.5555, 'grad_norm': 0.973100483417511, 'learning_rate': 9.993842344391118e-05, 'epoch': 0.02}                                                  
{'loss': 2.7272, 'grad_norm': 1.0402048826217651, 'learning_rate': 9.993177264227992e-05, 'epoch': 0.02}                                                 
{'loss': 2.7338, 'grad_norm': 1.1844696998596191, 'learning_rate': 9.992478109204589e-05, 'epoch': 0.02}                                                 
{'loss': 2.7329, 'grad_norm': 1.2366889715194702, 'learning_rate': 9.991744884092137e-05, 'epoch': 0.02}                                                 
{'loss': 2.6435, 'grad_norm': 1.0378166437149048, 'learning_rate': 9.990977593894374e-05, 'epoch': 0.02}                                                 
{'loss': 2.5958, 'grad_norm': 1.067109227180481, 'learning_rate': 9.990176243847507e-05, 'epoch': 0.02}                                                  
{'loss': 2.6902, 'grad_norm': 1.067004680633545, 'learning_rate': 9.989340839420176e-05, 'epoch': 0.02}                                                  
{'loss': 2.6238, 'grad_norm': 1.0179526805877686, 'learning_rate': 9.988471386313418e-05, 'epoch': 0.02}                                                 
{'loss': 2.8078, 'grad_norm': 1.1208285093307495, 'learning_rate': 9.987567890460628e-05, 'epoch': 0.02}                                                 
{'loss': 2.6946, 'grad_norm': 1.125577688217163, 'learning_rate': 9.98663035802752e-05, 'epoch': 0.02}                                                   
{'loss': 2.6906, 'grad_norm': 1.0433801412582397, 'learning_rate': 9.985658795412079e-05, 'epoch': 0.02}                                                 
{'loss': 2.5874, 'grad_norm': 1.0063128471374512, 'learning_rate': 9.984653209244525e-05, 'epoch': 0.02}                                                 
{'loss': 2.6727, 'grad_norm': 1.2142236232757568, 'learning_rate': 9.983613606387265e-05, 'epoch': 0.03}                                                 
{'loss': 2.5666, 'grad_norm': 1.0894137620925903, 'learning_rate': 9.982539993934844e-05, 'epoch': 0.03}                                                 
{'loss': 2.6646, 'grad_norm': 1.0149219036102295, 'learning_rate': 9.981432379213898e-05, 'epoch': 0.03}                                                 
{'loss': 2.6702, 'grad_norm': 1.197405457496643, 'learning_rate': 9.980290769783103e-05, 'epoch': 0.03}                                                  
{'loss': 2.7494, 'grad_norm': 1.03690767288208, 'learning_rate': 9.979115173433128e-05, 'epoch': 0.03}                                                   
{'loss': 2.5861, 'grad_norm': 1.1195775270462036, 'learning_rate': 9.977905598186578e-05, 'epoch': 0.03}                                                 
{'loss': 2.7037, 'grad_norm': 1.0463085174560547, 'learning_rate': 9.976662052297935e-05, 'epoch': 0.03}                                                 
{'loss': 2.6419, 'grad_norm': 1.220333218574524, 'learning_rate': 9.97538454425351e-05, 'epoch': 0.03}                                                   
{'loss': 2.4702, 'grad_norm': 1.4125850200653076, 'learning_rate': 9.974073082771382e-05, 'epoch': 0.03}                                                 
{'loss': 2.6116, 'grad_norm': 1.117189884185791, 'learning_rate': 9.972727676801338e-05, 'epoch': 0.03}                                                  
{'loss': 2.7215, 'grad_norm': 1.1210697889328003, 'learning_rate': 9.971348335524808e-05, 'epoch': 0.03}                                                 
{'loss': 2.6411, 'grad_norm': 1.0325214862823486, 'learning_rate': 9.969935068354807e-05, 'epoch': 0.03}                                                 
  3%|███▉                                                                                                           | 210/6013 [07:44<3:45:27,  2.33s/it]

Results for 4bit bnb + meta-llama/Llama-2-7b-hf

trainable params: 21,348,352 || all params: 6,759,829,504 || trainable%: 0.3158
***** Running training *****
  Num examples = 54,783
  Num Epochs = 1
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 2
  Total optimization steps = 6,848
  Number of trainable parameters = 10,674,176
{'loss': 1.6363, 'grad_norm': 0.4917793869972229, 'learning_rate': 9.999986846174718e-05, 'epoch': 0.0}                                                  
{'loss': 1.5694, 'grad_norm': 0.39926865696907043, 'learning_rate': 9.999947384768081e-05, 'epoch': 0.0}                                                 
{'loss': 1.4558, 'grad_norm': 0.3987913131713867, 'learning_rate': 9.999881615987716e-05, 'epoch': 0.0}                                                  
{'loss': 1.6005, 'grad_norm': 0.43162259459495544, 'learning_rate': 9.999789540179668e-05, 'epoch': 0.0}                                                 
{'loss': 1.5293, 'grad_norm': 0.43303075432777405, 'learning_rate': 9.999671157828396e-05, 'epoch': 0.0}                                                 
{'loss': 1.3762, 'grad_norm': 0.48176607489585876, 'learning_rate': 9.999526469556774e-05, 'epoch': 0.0}                                                 
  0%|▍                                                                                                              | 30/6848 [17:10<64:44:41, 34.19s/it]

DoRA (without quantization)

DoRA without quantization should work. If you use gradient checkpointing, remember to manually call model.gradient_checkpointing_enable(), as this is not automatically called without quantization.
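
A minimal sketch of that setup, assuming facebook/opt-125m and typical target modules (the exact config values are illustrative, not prescribed by this PR):

# Minimal sketch: without quantization (and thus without prepare_model_for_kbit_training),
# gradient checkpointing has to be enabled manually on the base model.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model.gradient_checkpointing_enable()  # not called automatically in the non-quantized path

config = LoraConfig(use_dora=True, r=16, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, config)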

Results for facebook/opt-125m

***** Running training *****
  Num examples = 48,106
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 3,007
  Number of trainable parameters = 705,024
  0%|                                                                                                                           | 0/3007 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
{'loss': 2.7924, 'grad_norm': 0.6611325144767761, 'learning_rate': 9.999931779967976e-05, 'epoch': 0.0}                                                  
{'loss': 2.714, 'grad_norm': 0.5837051868438721, 'learning_rate': 9.999727121733491e-05, 'epoch': 0.0}                                                   
{'loss': 2.8399, 'grad_norm': 0.6707738637924194, 'learning_rate': 9.999386030881264e-05, 'epoch': 0.0}                                                  
{'loss': 2.8308, 'grad_norm': 0.6908975839614868, 'learning_rate': 9.998908516718984e-05, 'epoch': 0.01}                                                 
{'loss': 2.7241, 'grad_norm': 0.7178784608840942, 'learning_rate': 9.998294592277063e-05, 'epoch': 0.01}                                                 
{'loss': 2.6833, 'grad_norm': 1.1266907453536987, 'learning_rate': 9.997544274308281e-05, 'epoch': 0.01}                                                 
{'loss': 2.6706, 'grad_norm': 0.6946617364883423, 'learning_rate': 9.996657583287326e-05, 'epoch': 0.01}                                                 
{'loss': 2.7254, 'grad_norm': 0.6686468124389648, 'learning_rate': 9.99563454341023e-05, 'epoch': 0.01}                                                  
{'loss': 2.7375, 'grad_norm': 0.7301179766654968, 'learning_rate': 9.994475182593722e-05, 'epoch': 0.01}                                                 
{'loss': 2.6461, 'grad_norm': 0.6995381712913513, 'learning_rate': 9.99317953247445e-05, 'epoch': 0.02}                                                  
{'loss': 2.7338, 'grad_norm': 0.7221236228942871, 'learning_rate': 9.991747628408137e-05, 'epoch': 0.02}                                                 
{'loss': 2.6135, 'grad_norm': 0.7124046683311462, 'learning_rate': 9.990179509468596e-05, 'epoch': 0.02}                                                 
{'loss': 2.6589, 'grad_norm': 0.7173681259155273, 'learning_rate': 9.988475218446675e-05, 'epoch': 0.02}                                                 
{'loss': 2.7504, 'grad_norm': 0.7079339623451233, 'learning_rate': 9.986634801849093e-05, 'epoch': 0.02}                                                 
{'loss': 2.6336, 'grad_norm': 0.6653233766555786, 'learning_rate': 9.984658309897162e-05, 'epoch': 0.02}                                                 
{'loss': 2.6204, 'grad_norm': 0.6643339991569519, 'learning_rate': 9.982545796525415e-05, 'epoch': 0.03}                                                 
{'loss': 2.6584, 'grad_norm': 0.7454470992088318, 'learning_rate': 9.980297319380147e-05, 'epoch': 0.03}                                                 
{'loss': 2.6579, 'grad_norm': 0.7237271070480347, 'learning_rate': 9.977912939817832e-05, 'epoch': 0.03}                                                 
{'loss': 2.6649, 'grad_norm': 0.7269802689552307, 'learning_rate': 9.97539272290345e-05, 'epoch': 0.03}                                                  
{'loss': 2.5364, 'grad_norm': 0.7742022275924683, 'learning_rate': 9.97273673740871e-05, 'epoch': 0.03}                                                  
  3%|███▊                                                                                                           | 104/3007 [02:00<1:13:12,  1.51s/it]

Results for meta-llama/Llama-2-7b-hf

Without quantization, I run out of memory (OOM).

8bit bnb does not work because, after making a deepcopy of an Int8Params instance, it cannot be successfully dequantized in PEFT, as it loses some attributes.
@BenjaminBossan BenjaminBossan marked this pull request as ready for review May 31, 2024 10:16

Contributor

@younesbelkada younesbelkada left a comment

Thanks a lot! I don't have any major comments. One point you raised offline is about forward compatibility (being able to load DoRA models with old PEFT versions); it would make sense to remap the state dict when loading, as you suggested!

@@ -233,7 +237,14 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
             if not self.use_dora[active_adapter]:
                 output = lora_B(lora_A(dropout(x))) * scaling
             else:
-                output = self._apply_dora(x, lora_A, lora_B, scaling, active_adapter)
+                x = dropout(x)
Contributor

Will this always apply dropout, even during eval mode?

Member Author

Ah yes, I wanted to mention this in the PR description but forgot. I think this was actually a bug in QDoRA previously, as dropout was not applied at all. For reference, see how we do it in normal LoRA:

x = dropout(x)
result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter)

Since dropout is an instance of nn.Dropout, when we call .eval(), PyTorch will automatically deactivate dropout.

>>> drop = nn.Dropout(0.5)
>>> x = torch.ones(10)
>>> drop(x)
tensor([2., 0., 0., 0., 2., 0., 0., 2., 2., 2.])
>>> drop.eval()
>>> drop(x)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

Contributor

Makes sense, thank you for explaining!


import torch
from torch import nn

Contributor

Suggested change
import torch.nn.functional as F

@@ -348,6 +348,18 @@ def set_peft_model_state_dict(
                 " PRNG initialisation to restore these projections using `config.projection_prng_key`, which may"
                 " not be accurate on all system configurations."
             )
+    elif config.peft_type == PeftType.LORA:
+        # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
Contributor

Nice, thanks for taking care of that!

Member Author

Done. I guess you didn't like F = nn.functional.

Contributor

yeah hahah

Member Author

It's less typing though ;-)

Contributor

@younesbelkada younesbelkada left a comment

Thanks a lot, Benjamin!

@BenjaminBossan BenjaminBossan merged commit a0788a3 into huggingface:main May 31, 2024
14 checks passed