FIX Decrease memory overhead of merging (#1944)
snarayan21 authored Jul 23, 2024
1 parent ebcd079 commit 2ce83e0
Showing 1 changed file with 6 additions and 6 deletions.
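For context: every hunk below makes the same change inside a merge method of a LoRA layer class in src/peft/tuners/lora/layer.py, replacing an out-of-place tensor addition with an in-place one. These merge paths are reached, for example, through PEFT's merge_and_unload. A minimal usage sketch, assuming a LoRA-wrapped model; the base model name and LoRA config values are illustrative, not taken from the commit:

    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    # Illustrative base model and config; any LoRA-wrapped model exercises
    # the same merge() code paths touched by this commit.
    base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    config = LoraConfig(r=8, target_modules=["q_proj", "v_proj"])
    model = get_peft_model(base, config)

    # merge_and_unload folds each adapter's delta into base_layer.weight.data
    # via the layer-level merge(); safe_merge=True takes the clone-and-check
    # branch shown in the diff below.
    merged = model.merge_and_unload(safe_merge=True)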
src/peft/tuners/lora/layer.py: 6 additions, 6 deletions
@@ -426,7 +426,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                     orig_weights = base_layer.weight.data.clone()
                     delta_weight = self.get_delta_weight(active_adapter)
                     if not self.use_dora[active_adapter]:
-                        orig_weights = orig_weights + delta_weight
+                        orig_weights += delta_weight
                     else:
                         # handle dora
                         # since delta_weight already includes scaling, set it to 1 here
@@ -452,7 +452,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                 else:
                     delta_weight = self.get_delta_weight(active_adapter)
                     if not self.use_dora[active_adapter]:
-                        base_layer.weight.data = base_layer.weight.data + delta_weight
+                        base_layer.weight.data += delta_weight
                     else:
                         # handle dora
                         # since delta_weight already includes scaling, set it to 1 here
@@ -659,7 +659,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                     # Note that safe_merge will be slower than the normal merge
                     # because of the copy operation.
                     orig_weights = base_layer.weight.data.clone()
-                    orig_weights = orig_weights + self.get_delta_weight(active_adapter)
+                    orig_weights += self.get_delta_weight(active_adapter)

                     if not torch.isfinite(orig_weights).all():
                         raise ValueError(
@@ -668,7 +668,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N

                     base_layer.weight.data = orig_weights
                 else:
-                    base_layer.weight.data = base_layer.weight.data + self.get_delta_weight(active_adapter)
+                    base_layer.weight.data += self.get_delta_weight(active_adapter)
                 self.merged_adapters.append(active_adapter)

     def unmerge(self) -> None:
@@ -900,7 +900,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                     delta_weight = self.get_delta_weight(active_adapter)

                     if not self.use_dora[active_adapter]:
-                        orig_weights = orig_weights + delta_weight
+                        orig_weights += delta_weight
                     else:
                         # handle dora
                         # since delta_weight already includes scaling, set it to 1 here
@@ -924,7 +924,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                 else:
                     delta_weight = self.get_delta_weight(active_adapter)
                     if not self.use_dora[active_adapter]:
-                        base_layer.weight.data = base_layer.weight.data + delta_weight
+                        base_layer.weight.data += delta_weight
                     else:
                         # handle dora
                         # since delta_weight already includes scaling, set it to 1 here
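Why the one-character change helps: w = w + d materializes the sum in a newly allocated tensor before rebinding the name, so three weight-sized buffers briefly coexist (the old weights, the delta, and the sum), whereas w += d writes the result into the existing storage of w. The updates here target base_layer.weight.data, which sits outside autograd, so the in-place form is safe. A minimal sketch of the difference, assuming a CUDA device is available; the tensor shape and the use of the CUDA memory counters are illustrative, not part of the commit:

    import torch

    # Illustrative size: a 4096x4096 fp32 tensor occupies 64 MiB.
    w = torch.zeros(4096, 4096, device="cuda")
    d = torch.ones_like(w)

    torch.cuda.reset_peak_memory_stats()
    w = w + d  # out-of-place: allocates a third weight-sized buffer for the sum
    print(torch.cuda.max_memory_allocated())  # peak ~192 MiB: w, d, and the sum

    torch.cuda.reset_peak_memory_stats()
    w += d  # in-place: the result is written into w's storage, nothing new allocated
    print(torch.cuda.max_memory_allocated())  # peak ~128 MiB: just w and d

The safe_merge branches still pay for one clone() of the base weights up front, but with the in-place addition they no longer allocate a second temporary on top of it.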
