
Implement DoRA #1474

Merged: 21 commits, Feb 27, 2024
docs/source/developer_guides/lora.md (10 additions, 0 deletions)

@@ -69,6 +69,16 @@ from peft import LoraConfig
config = LoraConfig(use_rslora=True, ...)
```

### Weight-Decomposed Low-Rank Adaptation (DoRA)

This technique decomposes the updates of the weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, especially at low ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces a bigger overhead than pure LoRA, so it is recommended to merge weights for inference. For more information, see https://arxiv.org/abs/2402.09353.
Member: If we support conv layers for regular LoRA, would it be very non-trivial to support conv layers with DoRA?

Member: I would also hyperlink "merge weights for inference" so that users know what we're talking about here.

Member Author: I'll add a link to merging.

> would it be very non-trivial to support conv layers with DoRA?

Yes, I think it's not quite straightforward because DoRA is defined for 2d weights, but for conv weights that wouldn't apply. There is probably no theoretical limitation, so I hope we can add it in a future PR.

Member: @nbasyl what are your thoughts here? If we can support this (potentially in a future PR) it might open new avenues for the diffusion community.

Comment: I think DoRA can definitely be applied on conv2d similar to how LoRA is applied to conv2d. And the normalization could be calculated per-input channel.

Member: Thanks! We will loop you in when we add the support!

Cc: @BenjaminBossan

Member Author: Great. Let's add Conv2d in a follow-up PR once this one is merged.
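Purely to illustrate the idea discussed in this thread (hypothetical, not part of this PR), a per-channel norm for a 4D conv weight could be computed by flattening the non-channel dimensions, analogous to the column-wise norm used for linear layers; which axis a conv DoRA should normalize over is left open here:

```py
import torch

# Hypothetical sketch: a Conv2d weight has shape (out_channels, in_channels, kH, kW).
conv_weight = torch.randn(16, 3, 3, 3)

# One L2 norm per output channel, flattening the remaining dimensions,
# mirroring the per-row norm used for 2D linear weights.
per_channel_norm = conv_weight.flatten(1).norm(p=2, dim=1)
print(per_channel_norm.shape)  # torch.Size([16])
```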


```py
from peft import LoraConfig

config = LoraConfig(use_dora=True, ...)
```
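
For intuition, the reparameterization described above can be written out for a single linear layer as in the sketch below. This is only an illustration of the idea, not PEFT's internal implementation; the function and argument names are placeholders, and `weight` is assumed to have shape `(out_features, in_features)`.

```py
import torch
import torch.nn.functional as F

def dora_linear_forward(x, weight, lora_A, lora_B, magnitude, scaling):
    # direction: the base weight plus the (scaled) low-rank LoRA update
    adapted = weight + scaling * (lora_B @ lora_A)
    # L2 norm of each row of the adapted weight, treated as a constant
    # (detached) as suggested in section 4.3 of the DoRA paper
    norm = torch.linalg.norm(adapted, dim=1).detach()
    # magnitude: a learnable vector with one entry per output feature
    dora_weight = (magnitude / norm).view(-1, 1) * adapted
    return F.linear(x, dora_weight)
```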

### QLoRA-style training

The default LoRA settings in PEFT add trainable weights to the query and value layers of each attention block. But [QLoRA](https://hf.co/papers/2305.14314), which adds trainable weights to all the linear layers of a transformer model, can provide performance equal to a fully finetuned model. To apply LoRA to all the linear layers, like in QLoRA, set `target_modules="all-linear"` (easier than specifying individual modules by name which can vary depending on the architecture).
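
As a quick illustration (the rank and alpha values below are placeholders, not taken from this diff):

```py
from peft import LoraConfig

# QLoRA-style targeting: adapt every linear layer instead of only query/value
config = LoraConfig(r=16, lora_alpha=32, target_modules="all-linear")
```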
src/peft/tuners/lora/bnb.py (26 additions, 2 deletions)

@@ -38,13 +38,25 @@ def __init__(
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
**kwargs,
) -> None:
super().__init__()
LoraLayer.__init__(self, base_layer)

if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.update_layer(
adapter_name,
r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
)

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
@@ -216,13 +228,25 @@ def __init__(
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
**kwargs,
) -> None:
super().__init__()
LoraLayer.__init__(self, base_layer)

if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.update_layer(
adapter_name,
r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
)

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
src/peft/tuners/lora/config.py (23 additions, 0 deletions)

@@ -101,6 +101,13 @@ class LoraConfig(PeftConfig):
The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights
and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a
quantized model in this case, as LoftQ will quantize the model itself.
use_dora (`bool`):
Member: Nice docstrings.

Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the weights
into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is
handled by a separate learnable parameter. This can improve the performance of LoRA, especially at low
ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces a bigger overhead than
pure LoRA, so it is recommended to merge weights for inference. For more information, see
https://arxiv.org/abs/2402.09353.
"""

r: int = field(default=8, metadata={"help": "Lora attention dimension"})
@@ -224,6 +231,19 @@ class LoraConfig(PeftConfig):
)
},
)
use_dora: bool = field(
default=False,
metadata={
"help": (
"Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the "
"weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the "
"magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, "
"especially at low ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces "
"a bigger overhead than pure LoRA, so it is recommended to merge weights for inference. For more "
"information, see https://arxiv.org/abs/2402.09353."
)
},
)

def __post_init__(self):
Contributor: Should we block some mis-intended usage? E.g., if one passes use_dora & loftq_config, etc.

Member Author: I added a check for loftq and megatron.
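
A small sketch of the behavior the new check is expected to produce (illustrative only, not a snippet from this PR):

```py
from peft import LoraConfig

# Combining DoRA with LoftQ initialization should now be rejected in __post_init__
try:
    LoraConfig(use_dora=True, init_lora_weights="loftq")
except ValueError as err:
    print(err)  # DoRA does not support megatron_core or LoftQ. Please set `use_dora=False`.
```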

self.peft_type = PeftType.LORA
@@ -238,6 +258,9 @@ def __post_init__(self):
if isinstance(self.target_modules, str) and self.layers_pattern is not None:
raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")

if self.use_dora and (self.megatron_config or self.init_lora_weights == "loftq"):
raise ValueError("DoRA does not support megatron_core or LoftQ. Please set `use_dora=False`.")

# handle init_lora_weights and loftq_config
if self.init_lora_weights == "loftq":
import importlib
src/peft/tuners/lora/gptq.py (13 additions, 1 deletion)

@@ -31,16 +31,28 @@ def __init__(
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
**kwargs,
):
super().__init__()
LoraLayer.__init__(self, base_layer)

if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
# for backwards compatibility
self.quant_linear_module = base_layer
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.update_layer(
adapter_name,
r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
)

def forward(self, x: torch.Tensor):
# note: logic differs from default Linear because merging is not supported
src/peft/tuners/lora/layer.py (137 additions, 10 deletions)

@@ -22,6 +22,7 @@
from transformers.pytorch_utils import Conv1D

from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from peft.utils.integrations import gather_params_ctx
from peft.utils.other import transpose

from .config import LoraConfig
Expand All @@ -47,6 +48,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
self.use_dora: dict[str, bool] = {}
self.lora_magnitude_vector: Optional[torch.nn.ParameterDict] = None # for DoRA
self._caches: dict[str, Any] = {}
self.kwargs = kwargs

base_layer = self.get_base_layer()
@@ -75,7 +79,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self.in_features = in_features
self.out_features = out_features

def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora):
def update_layer(
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora: bool = False
):
# This code works for linear layers, override for other layer types
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
@@ -111,6 +117,13 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
else:
self.to(weight.device)
break

if use_dora:
self.dora_init(adapter_name)
self.use_dora[adapter_name] = True
Contributor: this technically means you can have multiple adapters, with some of them using DoRA and some not?

Member Author: I guess so, although it wouldn't currently be possible to configure it as such.

Contributor: even if you do add_adapter with a LoraConfig that has use_dora=True?

Member Author: I guess we could have multiple adapters, some of which use DoRA and some don't. I believe this would be fine, or can you think of an issue?

Contributor: No, I don't think this would be an issue at all. Maybe we could emphasize it somewhere, as it could be an interesting use case.
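
A rough sketch of what that use case might look like (the model name and adapter names are placeholders; this is not code from the PR):

```py
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder model

dora_config = LoraConfig(r=8, use_dora=True)
plain_config = LoraConfig(r=8, use_dora=False)

# first adapter uses DoRA, the second one is plain LoRA
model = get_peft_model(base, dora_config, adapter_name="dora")
model.add_adapter("plain", plain_config)

# switch between the two adapters as needed
model.set_adapter("plain")
```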

else:
self.use_dora[adapter_name] = False

self.set_adapter(self.active_adapters)

def reset_lora_parameters(self, adapter_name, init_lora_weights):
@@ -153,6 +166,47 @@ def loftq_init(self, adapter_name):
self.lora_embedding_B[adapter_name].weight.data = lora_B
self.get_base_layer().weight.data = qweight

def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
# calculate L2 norm of weight matrix, column-wise
weight = weight + scaling * lora_weight
weight_norm = torch.linalg.norm(weight, dim=1)
return weight_norm

def dora_init(self, adapter_name: str) -> None:
lora_A = self.lora_A[adapter_name]
lora_B = self.lora_B[adapter_name]
scaling = self.scaling[adapter_name]
with gather_params_ctx(self.get_base_layer()):
Contributor: Do you think there is a way we could always avoid calling that context manager? I wonder if this might create some weird interactions if one does not use DeepSpeed Zero-3 with Trainer, as is_deepspeed_zero_3_available() seems quite specific to Trainer: https://github.com/huggingface/transformers/blob/a3f9221a449e9b949e71d9b047c66186f023481f/src/transformers/integrations/deepspeed.py#L286

Member Author: I'm not sure I understand the question. We have to call the context manager or else this will break with DeepSpeed. If a user does not use DS but uses Trainer, should is_deepspeed_zero3_enabled() not return False? In this case the context manager just does nothing and should still work.

Contributor: ok yes, makes sense!
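
For readers not familiar with it, here is a rough approximation of what a helper like `gather_params_ctx` has to do; this is a sketch under the assumption that the DeepSpeed and transformers helpers named below behave as described, not PEFT's actual implementation:

```py
from contextlib import nullcontext

def gather_params_ctx_sketch(module):
    # Under DeepSpeed ZeRO-3 the base layer's weight is sharded across ranks, so it
    # must be gathered before its norm can be computed for DoRA initialization.
    # Outside of ZeRO-3 the helper should be a no-op.
    try:
        import deepspeed
        from transformers.integrations import is_deepspeed_zero3_enabled
    except ImportError:
        return nullcontext()
    if not is_deepspeed_zero3_enabled():
        return nullcontext()
    return deepspeed.zero.GatheredParameters(list(module.parameters()))
```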

weight = self.get_base_layer().weight
lora_weight = lora_B.weight @ lora_A.weight
weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
self.lora_magnitude_vector = nn.ParameterDict()
self.lora_magnitude_vector[adapter_name] = nn.Parameter(weight_norm, requires_grad=True)
# add lora_magnitude_vector to the list of learnable parameters
self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)

def _cache_store(self, key: str, value: Any) -> None:
self._caches[key] = value

def _cache_pop(self, key: str) -> Any:
value = self._caches.pop(key)
return value

def apply_dora(self, x, lora_weight, active_adapter):
scaling = self.scaling[active_adapter]
magnitude = self.lora_magnitude_vector[active_adapter]
weight = self.get_base_layer().weight
weight_norm = self._get_weight_norm(weight, lora_weight, scaling)

Comment: should this be the following? There are a few other places that we need to do the same, otherwise we get mismatching dimensions.

weight_norm = self._get_weight_norm(transpose(weight, self.fan_in_fan_out), lora_weight, scaling)

Member Author: Thanks for pointing this out. Indeed, models that use Conv1D like GPT2 wouldn't work right now. I created a PR to fix this: #1588.
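
To make the shape issue concrete (a small illustration, not part of this diff): GPT-2's Conv1D stores its weight transposed relative to nn.Linear, which is why the norm needs the transposed weight when fan_in_fan_out is set.

```py
import torch.nn as nn
from transformers.pytorch_utils import Conv1D

# Conv1D keeps its weight as (in_features, out_features) ...
print(Conv1D(nf=8, nx=4).weight.shape)   # torch.Size([4, 8])
# ... while nn.Linear keeps it as (out_features, in_features)
print(nn.Linear(4, 8).weight.shape)      # torch.Size([8, 4])
```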

# see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
# "[...] we suggest treating ||V +∆V ||_c in
# Eq. (5) as a constant, thereby detaching it from the gradient
# graph. This means that while ||V + ∆V ||_c dynamically
# reflects the updates of ∆V , it won’t receive any gradient
# during backpropagation"
weight_norm = weight_norm.detach()
dora_weight = transpose(weight + scaling * lora_weight, self.fan_in_fan_out)
return (magnitude / weight_norm - 1).view(1, -1) * F.linear(x, dora_weight)

def set_scale(self, adapter, scale):
if adapter not in self.scaling:
# Ignore the case where the adapter is not in the layer
@@ -203,14 +257,23 @@ def __init__(
is_target_conv_1d_layer: bool = False,
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
**kwargs,
) -> None:
super().__init__()
LoraLayer.__init__(self, base_layer, **kwargs)
self.fan_in_fan_out = fan_in_fan_out

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.update_layer(
adapter_name,
r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
@@ -238,7 +301,19 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
orig_weights = base_layer.weight.data.clone()
orig_weights += self.get_delta_weight(active_adapter)
delta_weight = self.get_delta_weight(active_adapter)
if not self.use_dora[active_adapter]:
orig_weights += delta_weight
else:
# handle dora
# since delta_weight already includes scaling, set it to 1 here
weight_norm = self._get_weight_norm(orig_weights, delta_weight, scaling=1).detach()
# We need to cache weight_norm because it has to be based on the original weights. We
# cannot calculate it on the fly based on the merged weights when unmerging because it's a
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
orig_weights = dora_factor.view(-1, 1) * (orig_weights + delta_weight)

if not torch.isfinite(orig_weights).all():
raise ValueError(
@@ -247,7 +322,21 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N

base_layer.weight.data = orig_weights
else:
base_layer.weight.data += self.get_delta_weight(active_adapter)
delta_weight = self.get_delta_weight(active_adapter)
if not self.use_dora[active_adapter]:
base_layer.weight.data += delta_weight
else:
# handle dora
# since delta_weight already includes scaling, set it to 1 here
weight_norm = self._get_weight_norm(base_layer.weight, delta_weight, scaling=1).detach()
# We need to cache weight_norm because it has to be based on the original weights. We
# cannot calculate it on the fly based on the merged weights when unmerging because it's a
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
new_weight = dora_factor.view(-1, 1) * (base_layer.weight.data + delta_weight)
base_layer.weight.data = new_weight

self.merged_adapters.append(active_adapter)

def unmerge(self) -> None:
@@ -260,7 +349,15 @@ def unmerge(self) -> None:
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.lora_A.keys():
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
weight = self.get_base_layer().weight
delta_weight = self.get_delta_weight(active_adapter)
if not self.use_dora[active_adapter]:
weight.data -= delta_weight
else:
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight
weight.data = weight_orig

def get_delta_weight(self, adapter) -> torch.Tensor:
"""
@@ -314,7 +411,13 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
result += lora_B(lora_A(dropout(x))) * scaling

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
x = dropout(x)
lora_weight = lora_B.weight @ lora_A.weight
result = result + lora_B(lora_A(x)) * scaling + self.apply_dora(x, lora_weight, active_adapter)

result = result.to(torch_result_dtype)
return result
@@ -335,15 +438,27 @@ def __init__(
lora_dropout: float = 0.0,
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
**kwargs,
) -> None:
super().__init__()
LoraLayer.__init__(self, base_layer)

if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.update_layer(
adapter_name,
r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
)

def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora):
def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

@@ -511,15 +626,27 @@ def __init__(
lora_dropout: float = 0.0,
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
**kwargs,
) -> None:
super().__init__()
LoraLayer.__init__(self, base_layer)

if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.update_layer(
adapter_name,
r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
)

def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora):
def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
