Skip to content

Commit

Permalink
PiSSA, OLoRA: Delete initial adapter after conversion instead of the …
Browse files Browse the repository at this point in the history
…active adapter (#1933)

Resolves #1860

As discussed in that issue, it's not user-friendly to delete the default
adapter of a PiSSA/OLoRA model after calling save_pretrained with weight
conversion. Instead, it is much more intuitive to delete the initial
adapter, since it is loaded inside the method and not by the
user, so it's really an implementation detail.

Apart from this, I made the following related changes:

- Put everything in a try ... finally to ensure that the initial adapter
  does not hang around if there is an error (thus not hogging memory).
- Renamed initial_adapter to initial_adapter_name, to make it clear that
  this is the name and not the adapter itself.
  • Loading branch information
BenjaminBossan authored Jul 24, 2024
1 parent 2ce83e0 commit 05f57e9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 16 deletions.
37 changes: 21 additions & 16 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,24 +271,29 @@ def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion,
str(peft_config.init_lora_weights).lower().startswith(prefix) for prefix in ["pissa", "olora", "true"]
):
warnings.warn(
"`path_initial_model_for_weight_conversion` only works for converting a PiSSA or OLoRA adapter to a LoRA adapter"
"`path_initial_model_for_weight_conversion` only works for converting a PiSSA or OLoRA adapter to "
"a LoRA adapter"
)
initial_adapter = os.path.basename(path_initial_model_for_weight_conversion)
self.load_adapter(
os.path.dirname(path_initial_model_for_weight_conversion),
subfolder=initial_adapter,
adapter_name=initial_adapter,
)
if any(
str(self.peft_config[initial_adapter].init_lora_weights).lower().startswith(prefix)
for prefix in ["pissa", "olora"]
):
raise ValueError(
"The `init_lora_weights` parameter of the initial adapter should be set to `True`. "
"Otherwise, `self.load_adapter` will subtract the decomposed values again based on the residual model."
initial_adapter_name = os.path.basename(path_initial_model_for_weight_conversion)
try:
self.load_adapter(
os.path.dirname(path_initial_model_for_weight_conversion),
subfolder=initial_adapter_name,
adapter_name=initial_adapter_name,
)
is_pissa = str(self.peft_config[initial_adapter_name].init_lora_weights).lower().startswith("pissa")
is_olora = str(self.peft_config[initial_adapter_name].init_lora_weights).lower() == "olora"
if is_pissa or is_olora:
raise ValueError(
"The `init_lora_weights` parameter of the initial adapter should be set to `True`. "
"Otherwise, `self.load_adapter` will subtract the decomposed values again based on the "
"residual model."
)
output_state_dict = self.base_model.subtract_mutated_init(
output_state_dict, initial_adapter_name, kwargs
)
output_state_dict = self.base_model.subtract_mutated_init(output_state_dict, initial_adapter, kwargs)
self.delete_adapter(adapter_name)
finally:
self.delete_adapter(initial_adapter_name)
return output_state_dict

if is_main_process:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,13 @@ def test_lora_pissa_conversion_same_output_after_loading(self, data, tmp_path):
)

# save the model with conversion
peft_config_keys_before = list(peft_model.peft_config.keys())
peft_model.save_pretrained(
tmp_path / "pissa-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
)
peft_config_keys_after = list(peft_model.peft_config.keys())
assert peft_config_keys_before == peft_config_keys_after

model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model-converted")
output_converted = model_converted(data)[0]

Expand Down Expand Up @@ -597,9 +601,13 @@ def test_olora_conversion_same_output_after_loading(self, data, tmp_path):
)

# save the model with conversion
peft_config_keys_before = list(peft_model.peft_config.keys())
peft_model.save_pretrained(
tmp_path / "olora-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
)
peft_config_keys_after = list(peft_model.peft_config.keys())
assert peft_config_keys_before == peft_config_keys_after

model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "olora-model-converted")
output_converted = model_converted(data)[0]

Expand Down

0 comments on commit 05f57e9

Please sign in to comment.