Implement DoRA #1474
Changes from 16 commits
@@ -101,6 +101,13 @@ class LoraConfig(PeftConfig):
        The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights
        and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a
        quantized model in this case, as LoftQ will quantize the model itself.
        use_dora (`bool`):
            Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the weights
            into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is
            handled by a separate learnable parameter. This can improve the performance of LoRA, especially at low
            ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces a bigger overhead than
            pure LoRA, so it is recommended to merge weights for inference. For more information, see
            https://arxiv.org/abs/2402.09353.

Review comment: Nice docstrings.

    """

    r: int = field(default=8, metadata={"help": "Lora attention dimension"})
@@ -224,6 +231,19 @@ class LoraConfig(PeftConfig):
            )
        },
    )
    use_dora: bool = field(
        default=False,
        metadata={
            "help": (
                "Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the "
                "weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the "
                "magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, "
                "especially at low ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces "
                "a bigger overhead than pure LoRA, so it is recommended to merge weights for inference. For more "
                "information, see https://arxiv.org/abs/2402.09353."
            )
        },
    )

    def __post_init__(self):

Review thread:
- Should we block some mis-intended usage? E.g. if one passes …
- I added a check for loftq and megatron.

        self.peft_type = PeftType.LORA
@@ -238,6 +258,9 @@ def __post_init__(self):
        if isinstance(self.target_modules, str) and self.layers_pattern is not None:
            raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")

        if self.use_dora and (self.megatron_config or self.init_lora_weights == "loftq"):
            raise ValueError("DoRA does not support megatron_core or LoftQ. Please set `use_dora=False`.")

        # handle init_lora_weights and loftq_config
        if self.init_lora_weights == "loftq":
            import importlib
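For context, a minimal usage sketch of the new flag. Nothing below is taken from the PR itself: the base model, target modules, and hyperparameters are placeholders; the only element this diff actually adds is `use_dora` on `LoraConfig`.

```python
# Hedged usage sketch: model name and hyperparameters are arbitrary placeholders.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # DoRA currently supports non-quantized linear layers
    use_dora=True,                        # new flag introduced in this PR
)

model = get_peft_model(base_model, config)
model.print_trainable_parameters()  # lora_A, lora_B, plus the new lora_magnitude_vector
```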
@@ -22,6 +22,7 @@
from transformers.pytorch_utils import Conv1D

from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from peft.utils.integrations import gather_params_ctx
from peft.utils.other import transpose

from .config import LoraConfig
@@ -47,6 +48,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
        # Mark the weight as unmerged
        self._disable_adapters = False
        self.merged_adapters = []
        self.use_dora: dict[str, bool] = {}
        self.lora_magnitude_vector: Optional[torch.nn.ParameterDict] = None  # for DoRA
        self._caches: dict[str, Any] = {}
        self.kwargs = kwargs

        base_layer = self.get_base_layer()

@@ -75,7 +79,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
        self.in_features = in_features
        self.out_features = out_features

    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora):
    def update_layer(
        self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora: bool = False
    ):
        # This code works for linear layers, override for other layer types
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

@@ -111,6 +117,13 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
                else:
                    self.to(weight.device)
                break

        if use_dora:
            self.dora_init(adapter_name)
            self.use_dora[adapter_name] = True
        else:
            self.use_dora[adapter_name] = False

Review thread:
- This technically means you can have multiple adapters, with some of them using DoRA and some not?
- I guess so, although it wouldn't currently be possible to configure it as such.
- Even if you do …
- I guess we could have multiple adapters, some of which use DoRA and some don't. I believe this would be fine, or can you think of an issue?
- No, I don't think this would be an issue at all. Maybe we could emphasize it somewhere, as it could be an interesting use case.

        self.set_adapter(self.active_adapters)

    def reset_lora_parameters(self, adapter_name, init_lora_weights):
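As discussed in the review thread above, `use_dora` is tracked per adapter name, so in principle one model could host a DoRA adapter next to a plain LoRA adapter. A hedged sketch of how that could be configured, assuming the usual `get_peft_model`/`add_adapter`/`set_adapter` API; model name, adapter names, and hyperparameters are placeholders, not part of this PR.

```python
# Hedged sketch of the mixed-adapter scenario from the thread above.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

dora_config = LoraConfig(r=8, target_modules=["q_proj", "v_proj"], use_dora=True)
lora_config = LoraConfig(r=8, target_modules=["q_proj", "v_proj"], use_dora=False)

# first adapter uses DoRA, second one is plain LoRA
model = get_peft_model(base_model, dora_config, adapter_name="dora_adapter")
model.add_adapter("plain_lora", lora_config)

# use_dora is a per-adapter dict, so switching adapters also switches whether
# the DoRA path is taken in forward()
model.set_adapter("dora_adapter")
# ... run or train the DoRA adapter ...
model.set_adapter("plain_lora")
# ... run or train the plain LoRA adapter ...
```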
@@ -153,6 +166,47 @@ def loftq_init(self, adapter_name):
            self.lora_embedding_B[adapter_name].weight.data = lora_B
        self.get_base_layer().weight.data = qweight

    def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
        # calculate L2 norm of weight matrix, column-wise
        weight = weight + scaling * lora_weight
        weight_norm = torch.linalg.norm(weight, dim=1)
        return weight_norm

    def dora_init(self, adapter_name: str) -> None:
        lora_A = self.lora_A[adapter_name]
        lora_B = self.lora_B[adapter_name]
        scaling = self.scaling[adapter_name]
        with gather_params_ctx(self.get_base_layer()):
            weight = self.get_base_layer().weight
            lora_weight = lora_B.weight @ lora_A.weight
            weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
        self.lora_magnitude_vector = nn.ParameterDict()
        self.lora_magnitude_vector[adapter_name] = nn.Parameter(weight_norm, requires_grad=True)
        # add lora_magnitude_vector to the list of learnable parameters
        self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)

Review thread:
- Do you think there is a way we could always avoid calling that context manager? I wonder if this might create some weird interactions if one does not use DeepSpeed Zero-3 with Trainer as …
- I'm not sure I understand the question. We have to call the context manager or else this will break with DeepSpeed. If a user does not use DS but uses …
- Ok, yes, makes sense!

    def _cache_store(self, key: str, value: Any) -> None:
        self._caches[key] = value

    def _cache_pop(self, key: str) -> Any:
        value = self._caches.pop(key)
        return value

    def apply_dora(self, x, lora_weight, active_adapter):
        scaling = self.scaling[active_adapter]
        magnitude = self.lora_magnitude_vector[active_adapter]
        weight = self.get_base_layer().weight
        weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
        # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
        # "[...] we suggest treating ||V + ∆V ||_c in
        # Eq. (5) as a constant, thereby detaching it from the gradient
        # graph. This means that while ||V + ∆V ||_c dynamically
        # reflects the updates of ∆V , it won’t receive any gradient
        # during backpropagation"
        weight_norm = weight_norm.detach()
        dora_weight = transpose(weight + scaling * lora_weight, self.fan_in_fan_out)
        return (magnitude / weight_norm - 1).view(1, -1) * F.linear(x, dora_weight)

Review thread:
- Should this be the following? There are a few other places that we need to do the same, otherwise we get mismatching dimensions.
- Thanks for pointing this out. Indeed, models that use Conv1D like GPT2 wouldn't work right now. I created a PR to fix this: #1588.

    def set_scale(self, adapter, scale):
        if adapter not in self.scaling:
            # Ignore the case where the adapter is not in the layer
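To make the arithmetic in `_get_weight_norm` and `apply_dora` easier to follow, here is a self-contained sketch on plain tensors. Everything in it (shapes, values, the non-zero `lora_B` initialization, the simulated magnitude drift) is local to the example and not how the PEFT classes are used; it only restates the math above and checks that base output + LoRA term + DoRA correction equals applying the magnitude-rescaled weight directly.

```python
# Standalone, toy illustration of the DoRA computation above.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, in_features, out_features, r = 3, 4, 6, 2
scaling = 2.0

W = torch.randn(out_features, in_features)   # frozen base weight
A = torch.randn(r, in_features)              # stand-in for lora_A.weight
B = torch.randn(out_features, r) * 0.1       # stand-in for lora_B.weight (PEFT inits this to zero)
x = torch.randn(batch, in_features)

lora_weight = B @ A                           # (out_features, in_features)

# dora_init: the norm of W + scaling * BA along dim=1 seeds the magnitude parameter
weight_norm = torch.linalg.norm(W + scaling * lora_weight, dim=1)
magnitude = torch.nn.Parameter(weight_norm.clone())
with torch.no_grad():                         # pretend training moved the magnitude a bit
    magnitude += 0.1 * torch.randn_like(magnitude)

# forward: base result + LoRA term + DoRA correction (norm detached, section 4.3)
base_result = F.linear(x, W)
lora_result = F.linear(F.linear(x, A), B) * scaling
dora_result = (magnitude / weight_norm.detach() - 1).view(1, -1) * F.linear(
    x, W + scaling * lora_weight
)
result = base_result + lora_result + dora_result

# the same output expressed with a "merged" weight: m / ||W + s*BA|| * (W + s*BA)
dora_factor = (magnitude / weight_norm).view(-1, 1)
merged = F.linear(x, dora_factor * (W + scaling * lora_weight))
assert torch.allclose(result, merged, atol=1e-5)
```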
@@ -203,14 +257,23 @@ def __init__(
        is_target_conv_1d_layer: bool = False,
        init_lora_weights: Union[bool, str] = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        LoraLayer.__init__(self, base_layer, **kwargs)
        self.fan_in_fan_out = fan_in_fan_out

        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
        )
        self.is_target_conv_1d_layer = is_target_conv_1d_layer

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
@@ -238,7 +301,19 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
                    # Note that safe_merge will be slower than the normal merge
                    # because of the copy operation.
                    orig_weights = base_layer.weight.data.clone()
                    orig_weights += self.get_delta_weight(active_adapter)
                    delta_weight = self.get_delta_weight(active_adapter)
                    if not self.use_dora[active_adapter]:
                        orig_weights += delta_weight
                    else:
                        # handle dora
                        # since delta_weight already includes scaling, set it to 1 here
                        weight_norm = self._get_weight_norm(orig_weights, delta_weight, scaling=1).detach()
                        # We need to cache weight_norm because it has to be based on the original weights. We
                        # cannot calculate it on the fly based on the merged weights when unmerging because it's a
                        # different value
                        self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
                        dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
                        orig_weights = dora_factor.view(-1, 1) * (orig_weights + delta_weight)

                    if not torch.isfinite(orig_weights).all():
                        raise ValueError(
@@ -247,7 +322,21 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N

                    base_layer.weight.data = orig_weights
                else:
                    base_layer.weight.data += self.get_delta_weight(active_adapter)
                    delta_weight = self.get_delta_weight(active_adapter)
                    if not self.use_dora[active_adapter]:
                        base_layer.weight.data += delta_weight
                    else:
                        # handle dora
                        # since delta_weight already includes scaling, set it to 1 here
                        weight_norm = self._get_weight_norm(base_layer.weight, delta_weight, scaling=1).detach()
                        # We need to cache weight_norm because it has to be based on the original weights. We
                        # cannot calculate it on the fly based on the merged weights when unmerging because it's a
                        # different value
                        self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
                        dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
                        new_weight = dora_factor.view(-1, 1) * (base_layer.weight.data + delta_weight)
                        base_layer.weight.data = new_weight

                self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
@@ -260,7 +349,15 @@ def unmerge(self) -> None:
        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter in self.lora_A.keys():
                self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
                weight = self.get_base_layer().weight
                delta_weight = self.get_delta_weight(active_adapter)
                if not self.use_dora[active_adapter]:
                    weight.data -= delta_weight
                else:
                    weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
                    dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
                    weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight
                    weight.data = weight_orig

    def get_delta_weight(self, adapter) -> torch.Tensor:
        """
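The merge/unmerge pair above can be sanity-checked on plain tensors: merging caches the norm of the original weight and folds `magnitude / weight_norm` into the base weight, and unmerging divides that factor back out before subtracting the delta. A hedged toy version (all names are local to the sketch, not the layer's actual attributes):

```python
# Toy round trip of the DoRA merge/unmerge logic above.
import torch

torch.manual_seed(0)
out_features, in_features = 6, 4

W_orig = torch.randn(out_features, in_features)              # base_layer.weight before merging
delta_weight = 0.1 * torch.randn(out_features, in_features)  # get_delta_weight(), scaling included
magnitude = torch.rand(out_features) + 0.5                   # lora_magnitude_vector

# merge: the norm is computed from the *original* weight and cached for unmerging
cached_weight_norm = torch.linalg.norm(W_orig + delta_weight, dim=1)  # scaling=1, already in delta
dora_factor = magnitude / cached_weight_norm
W_merged = dora_factor.view(-1, 1) * (W_orig + delta_weight)

# unmerge: divide the cached factor back out, then remove the delta
W_restored = W_merged / dora_factor.view(-1, 1) - delta_weight

assert torch.allclose(W_orig, W_restored, atol=1e-6)
```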
@@ -314,7 +411,13 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]
                x = x.to(lora_A.weight.dtype)
                result += lora_B(lora_A(dropout(x))) * scaling

                if not self.use_dora[active_adapter]:
                    result = result + lora_B(lora_A(dropout(x))) * scaling
                else:
                    x = dropout(x)
                    lora_weight = lora_B.weight @ lora_A.weight
                    result = result + lora_B(lora_A(x)) * scaling + self.apply_dora(x, lora_weight, active_adapter)

            result = result.to(torch_result_dtype)
        return result
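The DoRA branch of `forward` recomputes `lora_B.weight @ lora_A.weight` and a weight norm on every call, which is the extra overhead the docstring warns about. For pure inference the recommendation is therefore to merge. A hedged sketch, assuming `model` is a PeftModel built as in the earlier config example:

```python
# Hedged sketch: merging the adapter for inference to avoid the extra per-forward work.
merged_model = model.merge_and_unload()  # folds the adapter into the base weights

# Alternatively, merge in place and keep the PeftModel wrapper:
# model.merge_adapter()
# ... run inference ...
# model.unmerge_adapter()  # relies on the cached weight_norm shown in the merge/unmerge diff above
```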
@@ -335,15 +438,27 @@ def __init__(
        lora_dropout: float = 0.0,
        init_lora_weights: Union[bool, str] = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        LoraLayer.__init__(self, base_layer)

        if use_dora:
            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
        )

    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora):
    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora):
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
@@ -511,15 +626,27 @@ def __init__(
        lora_dropout: float = 0.0,
        init_lora_weights: Union[bool, str] = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        LoraLayer.__init__(self, base_layer)

        if use_dora:
            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
        )

    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora):
    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora):
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
Review thread:
- If we support conv layers for regular LoRA, would it be very non-trivial to support conv layers with DoRA?
- I would also hyperlink "merge weights for inference" so that users know what we're talking about here.
- I'll add a link to merging. Yes, I think it's not quite straightforward because DoRA is defined for 2d weights, but for conv weights that wouldn't apply. There is probably no theoretical limitation, so I hope we can add it in a future PR.
- @nbasyl what are your thoughts here? If we can support this (potentially in a future PR) it might open new avenues for the diffusion community.
- I think DoRA can definitely be applied on conv2d similar to how LoRA is applied to conv2d. And the normalization could be calculated per-input channel.
- Thanks! We will loop you in when we add the support! Cc: @BenjaminBossan
- Great. Let's add Conv2d in a follow-up PR once this one is merged.
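Purely as an illustration of the idea floated in this thread, and not part of the PR: a Conv2d kernel's 4d weight would need to be reduced to a 1d magnitude vector, and the axis over which the norm is taken (per output filter, or per input channel as suggested above) is exactly the kind of detail a follow-up PR would have to settle. Both variants are easy to express:

```python
# Speculative sketch only (not part of this PR): candidate norm definitions for a
# Conv2d kernel of shape (out_channels, in_channels, kh, kw).
import torch

out_channels, in_channels, kh, kw = 8, 3, 3, 3
conv_weight = torch.randn(out_channels, in_channels, kh, kw)

# one norm per output filter, analogous to the dim=1 norm used for Linear above
norm_per_output = torch.linalg.norm(conv_weight.reshape(out_channels, -1), dim=1)  # shape (8,)

# one norm per input channel, as suggested in the thread
norm_per_input = torch.linalg.norm(
    conv_weight.permute(1, 0, 2, 3).reshape(in_channels, -1), dim=1
)  # shape (3,)
```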