[PEFT] Support low_cpu_mem_usage option for PEFT loading adapters (#33725)

* [PEFT] Support low_cpu_mem_usage for PEFT loading

PEFT added support for low_cpu_mem_usage=True when loading adapters in
huggingface/peft#1961. The feature is available as of PEFT v0.13.0. With
this PR, the option is also supported when loading PEFT adapters directly
into transformers models.

Additionally, this PR unblocks huggingface/diffusers#9510, which
implements the same option in diffusers.

* Fix typo
BenjaminBossan authored Oct 3, 2024
1 parent bf0ffe3 commit 6500f78
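
For context, a minimal usage sketch of the new option (the model and adapter
ids below are illustrative, not taken from this commit; requires PEFT >= 0.13.0):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
    # With low_cpu_mem_usage=True, adapter weights are created on the meta device
    # and then materialized from the loaded state dict, reducing peak CPU memory.
    model.load_adapter("ybelkada/opt-350m-lora", low_cpu_mem_usage=True)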
Showing 2 changed files with 67 additions and 2 deletions.
25 changes: 23 additions & 2 deletions src/transformers/integrations/peft.py
@@ -11,10 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import inspect
 import warnings
 from typing import Any, Dict, List, Optional, Union
 
+from packaging import version
+
 from ..utils import (
     check_peft_version,
     find_adapter_config_file,
@@ -77,6 +80,7 @@ def load_adapter(
         offload_index: Optional[int] = None,
         peft_config: Dict[str, Any] = None,
         adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
+        low_cpu_mem_usage: bool = False,
         adapter_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
@@ -129,12 +133,27 @@ def load_adapter(
             adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
                 The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
                 dicts
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
+                Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process.
+                Requires PEFT version 0.13.0 or higher.
             adapter_kwargs (`Dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
                 `find_adapter_config_file` method.
         """
         check_peft_version(min_version=MIN_PEFT_VERSION)
 
+        # peft only supports low_cpu_mem_usage starting from v0.13.0
+        peft_load_kwargs = {}
+        if low_cpu_mem_usage:
+            min_version_lcmu = "0.13.0"
+            if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu):
+                peft_load_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+            else:
+                raise ValueError(
+                    "The version of PEFT you are using does not support `low_cpu_mem_usage` yet, "
+                    f"please install PEFT >= {min_version_lcmu}."
+                )
+
         adapter_name = adapter_name if adapter_name is not None else "default"
         if adapter_kwargs is None:
             adapter_kwargs = {}
@@ -192,7 +211,7 @@ def load_adapter(
             )
 
         # Create and add fresh new adapters into the model.
-        inject_adapter_in_model(peft_config, self, adapter_name)
+        inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
 
         if not self._hf_peft_config_loaded:
             self._hf_peft_config_loaded = True
@@ -211,7 +230,9 @@ def load_adapter(
                 processed_adapter_state_dict[new_key] = value
 
         # Load state dict
-        incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name)
+        incompatible_keys = set_peft_model_state_dict(
+            self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs
+        )
 
         if incompatible_keys is not None:
             # check only for unexpected keys
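
The peft_load_kwargs plumbing above forwards the flag to PEFT's own entry
points at the two changed call sites. A standalone sketch of the equivalent
calls under PEFT >= 0.13.0, where model and adapter_state_dict are assumed
to already exist:

    from peft import LoraConfig, inject_adapter_in_model
    from peft.utils import set_peft_model_state_dict

    peft_config = LoraConfig()
    # First inject empty (meta-device) adapter modules instead of initialized ones...
    inject_adapter_in_model(peft_config, model, "default", low_cpu_mem_usage=True)
    # ...then assign the real weights directly from the state dict.
    set_peft_model_state_dict(model, adapter_state_dict, "default", low_cpu_mem_usage=True)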
44 changes: 44 additions & 0 deletions tests/peft_integration/test_peft_integration.py
@@ -12,11 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import os
 import tempfile
 import unittest
 
 from huggingface_hub import hf_hub_download
+from packaging import version
 
 from transformers import AutoModelForCausalLM, OPTForCausalLM
 from transformers.testing_utils import (
@@ -478,6 +480,48 @@ def test_peft_add_adapter_with_state_dict(self):
         # dummy generation
         _ = model.generate(input_ids=dummy_input)
 
+    def test_peft_add_adapter_with_state_dict_low_cpu_mem_usage(self):
+        """
+        Check the usage of low_cpu_mem_usage, which is supported in PEFT >= 0.13.0
+        """
+        from peft import LoraConfig
+
+        min_version_lcmu = "0.13.0"
+        is_lcmu_supported = version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu)
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig()
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                # this should always work
+                model.load_adapter(
+                    adapter_state_dict=dummy_state_dict, peft_config=peft_config, low_cpu_mem_usage=False
+                )
+
+                if is_lcmu_supported:
+                    # if supported, this should not raise an error
+                    model.load_adapter(
+                        adapter_state_dict=dummy_state_dict,
+                        adapter_name="other",
+                        peft_config=peft_config,
+                        low_cpu_mem_usage=True,
+                    )
+                    # after loading, no meta device should be remaining
+                    self.assertFalse(any((p.device.type == "meta") for p in model.parameters()))
+                else:
+                    err_msg = r"The version of PEFT you are using does not support `low_cpu_mem_usage` yet"
+                    with self.assertRaisesRegex(ValueError, err_msg):
+                        model.load_adapter(
+                            adapter_state_dict=dummy_state_dict,
+                            adapter_name="other",
+                            peft_config=peft_config,
+                            low_cpu_mem_usage=True,
+                        )
+
     def test_peft_from_pretrained_hub_kwargs(self):
         """
         Tests different combinations of PEFT model + from_pretrained + hub kwargs
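
The assertFalse check in the new test captures the key invariant of the
feature: once loading finishes, every adapter weight must have been
materialized. A one-line spot-check for the same property (a sketch; model
is any model with adapters loaded):

    # No parameter may remain on the meta device after loading; a tensor still
    # on "meta" would mean its weights were never filled in from the state dict.
    assert not any(p.device.type == "meta" for p in model.parameters())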
