FIX low_cpu_mem_usage consolidates devices (#2113)
See: huggingface/diffusers#9510 (comment)

Currently, the low_cpu_mem_usage=True option does not consolidate devices. For example, when the model is on GPU and the state_dict is on CPU, the adapter weights end up on CPU after loading, even though they should be on GPU. This fix ensures that the devices are consolidated.
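
Background: with assign=True, torch.nn.Module.load_state_dict (PyTorch >= 2.1) preserves the properties of the state_dict tensors, including their device, rather than copying their values into the existing parameters. A minimal standalone sketch of the symptom, assuming a CUDA device is available (the toy module and variable names are illustrative, not part of this commit):

import torch
from torch import nn

linear = nn.Linear(4, 4).to("cuda")  # module lives on the GPU
cpu_sd = {k: v.cpu() for k, v in linear.state_dict().items()}  # state_dict on CPU

linear.load_state_dict(cpu_sd)  # default copy: weight stays on cuda
print(linear.weight.device)  # cuda:0

linear.load_state_dict(cpu_sd, assign=True)  # assign: weight now lives on cpu
print(linear.weight.device)  # cpu
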
BenjaminBossan committed Oct 8, 2024
1 parent f0b066e commit 61c57f4
Showing 2 changed files with 55 additions and 0 deletions.
src/peft/utils/save_and_load.py (4 additions, 0 deletions)

@@ -456,6 +456,10 @@ def renamed_dora_weights(k):
     )
     if low_cpu_mem_usage:
         load_result = model.load_state_dict(peft_model_state_dict, strict=False, assign=True)
+        # ensure that the correct device is set
+        for module in model.modules():
+            if hasattr(module, "_move_adapter_to_device_of_base_layer"):
+                module._move_adapter_to_device_of_base_layer(adapter_name)
     else:
         load_result = model.load_state_dict(peft_model_state_dict, strict=False)

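The helper invoked in that loop is internal to PEFT's tuner layers; conceptually, it relocates the named adapter's parameters to the device of the base layer they wrap. A rough sketch of the idea, with hypothetical attribute names (base_layer, lora_A, lora_B, modeled on PEFT's LoRA layers) rather than the actual implementation:

def move_adapter_to_base_device(module, adapter_name):
    # Hedged sketch, not PEFT's real method: move the named adapter's
    # submodules to wherever the wrapped base layer's weight lives.
    device = module.base_layer.weight.device
    for adapter_dict in (module.lora_A, module.lora_B):  # nn.ModuleDicts
        if adapter_name in adapter_dict:
            adapter_dict[adapter_name].to(device)  # nn.Module.to moves in place
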
tests/test_gpu_examples.py (51 additions, 0 deletions)

@@ -55,8 +55,11 @@
     PromptEncoderConfig,
     TaskType,
     get_peft_model,
+    get_peft_model_state_dict,
+    inject_adapter_in_model,
     prepare_model_for_kbit_training,
     replace_lora_weights_loftq,
+    set_peft_model_state_dict,
 )
 from peft.tuners import boft
 from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device
@@ -3226,3 +3229,51 @@ def test_p_tuning_exactly_reproducible_after_loading(self, tmp_path):

         torch.testing.assert_close(output_loaded, output_peft)
         torch.testing.assert_close(gen_loaded, gen_peft)


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+@pytest.mark.single_gpu_tests
+class TestLowCpuMemUsageDifferentDevices:
+    """Test for the low CPU memory usage option for loading PEFT models.
+
+    There are already tests for this in test_initialization.py but here we want to specifically test diverging
+    devices for the model and state_dict.
+    """
+
+    model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+
+    @pytest.mark.parametrize("device_model, device_sd", [("cpu", "cuda"), ("cuda", "cpu")])
+    def test_low_cpu_mem_usage_model_model_on_gpu_state_dict_on_cpu_works(self, device_model, device_sd):
+        inputs = {"input_ids": torch.randint(0, 100, (1, 10)), "attention_mask": torch.ones(1, 10)}
+        inputs = {k: v.to(device_model) for k, v in inputs.items()}
+
+        model = AutoModelForCausalLM.from_pretrained(self.model_id).to(device_model)
+        lora_config = LoraConfig(init_lora_weights=False, target_modules="all-linear")
+        model = get_peft_model(model, lora_config)
+        model.eval()
+        logits_not_low_cpu_mem = model(**inputs).logits
+
+        state_dict = get_peft_model_state_dict(model)
+        peft_model_state_dict = {}
+        # remap the state dict so that it can be correctly loaded, and move weights to the other device
+        prefix = "base_model.model."
+        for k, v in state_dict.items():
+            k = k[len(prefix) :]
+            peft_model_state_dict[k] = v.to(device_sd)
+
+        del model
+
+        model = AutoModelForCausalLM.from_pretrained(self.model_id).to(device_model)
+        model.eval()
+        inject_adapter_in_model(lora_config, model, low_cpu_mem_usage=True)
+        load_result = set_peft_model_state_dict(model, peft_model_state_dict, low_cpu_mem_usage=True)
+
+        # sanity check: all lora keys are matched
+        assert not any("lora" in k for k in load_result.missing_keys)
+        assert not any("lora" in k for k in load_result.unexpected_keys)
+
+        logits_low_cpu_mem = model(**inputs).logits
+
+        assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem)
+        assert {p.device.type for p in model.parameters()} == {device_model}
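
For end users, the path exercised by this fix is adapter loading with low_cpu_mem_usage=True. A hedged usage sketch, assuming a PEFT version that exposes the flag on PeftModel.from_pretrained and a saved LoRA adapter at the placeholder path "path/to/adapter":

from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-OPTForCausalLM"
).to("cuda")
# low_cpu_mem_usage=True loads the adapter state_dict with assign=True (see the diff above)
model = PeftModel.from_pretrained(base, "path/to/adapter", low_cpu_mem_usage=True)

# with the fix, the adapter weights follow the base model's device
assert {p.device.type for p in model.parameters()} == {"cuda"}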
