Refactor to make DoRA and QDoRA work with FSDP #1806
Conversation
This causes trouble with loading trained DoRA models
8bit bnb does not work because, after making a deepcopy of an Int8Params instance, it cannot be successfully dequantized in PEFT, as it loses some attributes.
Thanks a lot! I don't have any major comment. One comment you raised offline is about forward compatibility (being able to load DoRA models with old PEFT versions); it would make sense to remap the state dict when loading, as you suggested!
@@ -233,7 +237,14 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    if not self.use_dora[active_adapter]:
        output = lora_B(lora_A(dropout(x))) * scaling
    else:
        output = self._apply_dora(x, lora_A, lora_B, scaling, active_adapter)
        x = dropout(x)
Will this always apply dropout, even during eval mode?
Ah yes, I wanted to mention this in the PR description but forgot. I think this was actually a bug in QDoRA previously, as dropout was not applied at all. For reference, see how we do it in normal LoRA:
peft/src/peft/tuners/lora/layer.py
Lines 564 to 565 in cb0bf07
x = dropout(x)
result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter)
Since dropout is an instance of nn.Dropout, when we call .eval(), PyTorch will automatically deactivate dropout.
>>> drop = nn.Dropout(0.5)
>>> x = torch.ones(10)
>>> drop(x)
tensor([2., 0., 0., 0., 2., 0., 0., 2., 2., 2.])
>>> drop.eval()
>>> drop(x)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
Makes sense, thank you for explaining!
import torch
from torch import nn
import torch.nn.functional as F
@@ -348,6 +348,18 @@ def set_peft_model_state_dict(
            " PRNG initialisation to restore these projections using `config.projection_prng_key`, which may"
            " not be accurate on all system configurations."
        )
    elif config.peft_type == PeftType.LORA:
        # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a ModuleDict
Nice, thanks for taking care of that!
Done. I guess you didn't like F = nn.functional.
yeah hahah
It's less typing though ;-)
thanks a lot Benjamin!
Supersedes #1797
Description
This is an alternative to the aforementioned PR to make DoRA and QDoRA work with FSDP.
Implementation
This PR is similar to #1797, but it moves all the DoRA functionality into a separate module class. This is necessary because otherwise the DoRA parameter lives on the lora.Linear layer as a parameter, not as a separate module. Since the FSDP auto wrap policy operates on the level of modules, not parameters, there is no way to modify the auto wrap policy to wrap the DoRA parameter; it must be its own module (see the sketch below).

If not for this reason, #1797 would be preferable, since the number of code changes is smaller overall. This PR touches more lines, but the majority of the changes only move code around rather than altering it.
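To illustrate the constraint, here is a minimal sketch (the class name DoraMagnitudeModule and the policy setup are illustrative, not the actual classes from this PR): FSDP auto wrap policies only ever see modules, so a magnitude vector stored as a bare nn.Parameter on lora.Linear cannot be targeted, while a small wrapper module can.

```python
import functools

import torch
from torch import nn
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

class DoraMagnitudeModule(nn.Module):
    """Hypothetical wrapper: the DoRA magnitude vector lives in its own module
    instead of being a bare nn.Parameter on the lora.Linear layer."""

    def __init__(self, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(out_features))

    def forward(self, weight_norm: torch.Tensor) -> torch.Tensor:
        # Simplified DoRA scaling: magnitude / ||W + scaling * B @ A||
        return self.weight / weight_norm

# Auto wrap policies decide per module, so this policy can target
# DoraMagnitudeModule, but it could never target an individual parameter.
dora_auto_wrap_policy = functools.partial(
    lambda_auto_wrap_policy,
    lambda_fn=lambda module: isinstance(module, DoraMagnitudeModule),
)
```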
An additional required change was to make a defensive copy of the base layer before dequantizing its weight in order to calculate the weight norm for DoRA. Without this defensive copy, some side effect is triggered in FSDP that results in a dtype-related error, even though the compute dtype of bnb is correctly set to float.
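A minimal sketch of the defensive-copy idea described above (this is not the PR's actual code; the dequantize_fn argument stands in for whatever bnb-specific dequantization PEFT uses internally):

```python
import copy

import torch
from torch import nn

def dora_weight_norm(base_layer: nn.Module, lora_A: nn.Module, lora_B: nn.Module,
                     scaling: float, dequantize_fn) -> torch.Tensor:
    # Deepcopy the quantized base layer so that dequantizing its weight cannot
    # mutate the original, possibly FSDP-managed, bnb layer as a side effect.
    layer_copy = copy.deepcopy(base_layer)
    weight = dequantize_fn(layer_copy)  # returns the float weight of the copy
    # DoRA needs the column-wise norm of W + scaling * B @ A.
    lora_weight = lora_B.weight @ lora_A.weight
    return torch.linalg.norm(weight + scaling * lora_weight, dim=1)
```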
Compatibility
Since we introduce a new submodule, an extra step is required to ensure that old DoRA state dicts can still be loaded correctly. This involves a fairly trivial remapping step in set_peft_model_state_dict (a sketch follows below). The test for this is performed via the new DoRA regression tests introduced in #1792; I ran them locally and they passed. I also performed a manual test to be extra sure.
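A minimal sketch of what such a remapping can look like (the key pattern below is an assumption for illustration, not copied from the PR): pre-refactor checkpoints store the magnitude as a bare parameter, while the new layout nests it inside the submodule, so old keys need a suffix when loading.

```python
def remap_dora_state_dict(state_dict: dict) -> dict:
    """Map pre-refactor DoRA keys onto the new submodule layout.

    Assumed (illustrative) patterns:
        old: ...lora_magnitude_vector.<adapter>
        new: ...lora_magnitude_vector.<adapter>.weight
    """
    remapped = {}
    for key, value in state_dict.items():
        if "lora_magnitude_vector" in key and not key.endswith(".weight"):
            key = key + ".weight"
        remapped[key] = value
    return remapped
```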
Caveats
Some tests currently fail for me locally. This seems to be related to a strange side effect that is triggered by creating a deepcopy of a bnb layer. Two PRs have been merged to fix these issues.
As is, 8bit QDoRA will work once these fixes are released. For now, I exclude 8bit QDoRA, but I will update this PR, or create a follow-up, once the bnb release is out and I have tested it successfully.
Experiments
QDoRA (DoRA + bnb)
Experimental results are based on the QLoRA scripts in examples/sft/. The only difference is that DoRA is enabled, together with one small change to the script (see the sketch at the end of this subsection).

Results for 4bit bnb + facebook/opt-125m

Results for 4bit bnb + meta-llama/Llama-2-7b-hf
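For reference, a minimal sketch of how DoRA is typically enabled in a PEFT config (this is not the exact script change from the PR; the hyperparameters are placeholders, and use_dora is the documented LoraConfig switch):

```python
from peft import LoraConfig

peft_config = LoraConfig(
    r=16,                         # placeholder rank
    lora_alpha=32,                # placeholder scaling
    lora_dropout=0.05,            # placeholder dropout
    target_modules="all-linear",  # placeholder target selection
    use_dora=True,                # the switch that turns LoRA into DoRA
)
```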
DoRA (without quantization)
DoRA without quantization should work. If you use gradient checkpointing, remember to manually call model.gradient_checkpointing_enable(), as this is not called automatically without quantization (see the sketch at the end of this section).

Results for facebook/opt-125m

Results for meta-llama/Llama-2-7b-hf
Without quantization, I run OOM.
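A minimal sketch of the gradient checkpointing note above (model choice and config are placeholders; gradient_checkpointing_enable and enable_input_require_grads are standard transformers methods):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Without quantization there is no prepare_model_for_kbit_training call, so
# gradient checkpointing must be enabled explicitly before wrapping with PEFT.
model.gradient_checkpointing_enable()
model.enable_input_require_grads()  # keep grads flowing into checkpointed inputs with a frozen base model

peft_model = get_peft_model(model, LoraConfig(use_dora=True, target_modules="all-linear"))
```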