diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py
index aa0b181540c7ed..bd0ca16f865f4c 100644
--- a/src/transformers/integrations/peft.py
+++ b/src/transformers/integrations/peft.py
@@ -11,10 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import inspect
 import warnings
 from typing import Any, Dict, List, Optional, Union
 
+from packaging import version
+
 from ..utils import (
     check_peft_version,
     find_adapter_config_file,
@@ -77,6 +80,7 @@ def load_adapter(
         offload_index: Optional[int] = None,
         peft_config: Dict[str, Any] = None,
         adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
+        low_cpu_mem_usage: bool = False,
         adapter_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
@@ -129,12 +133,27 @@ def load_adapter(
             adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
                 The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
                 dicts
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
+                Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process.
+                Requires PEFT version 0.13.0 or higher.
             adapter_kwargs (`Dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
                 `find_adapter_config_file` method.
         """
         check_peft_version(min_version=MIN_PEFT_VERSION)
 
+        # peft only supports low_cpu_mem_usage starting from v0.13.0
+        peft_load_kwargs = {}
+        if low_cpu_mem_usage:
+            min_version_lcmu = "0.13.0"
+            if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu):
+                peft_load_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+            else:
+                raise ValueError(
+                    "The version of PEFT you are using does not support `low_cpu_mem_usage` yet, "
+                    f"please install PEFT >= {min_version_lcmu}."
+                )
+
         adapter_name = adapter_name if adapter_name is not None else "default"
         if adapter_kwargs is None:
             adapter_kwargs = {}
@@ -192,7 +211,7 @@ def load_adapter(
             )
 
             # Create and add fresh new adapters into the model.
-            inject_adapter_in_model(peft_config, self, adapter_name)
+            inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
 
         if not self._hf_peft_config_loaded:
             self._hf_peft_config_loaded = True
@@ -211,7 +230,9 @@ def load_adapter(
                 processed_adapter_state_dict[new_key] = value
 
             # Load state dict
-            incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name)
+            incompatible_keys = set_peft_model_state_dict(
+                self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs
+            )
 
             if incompatible_keys is not None:
                 # check only for unexpected keys
diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py
index 602ed04d9c6271..e7b336623c1a43 100644
--- a/tests/peft_integration/test_peft_integration.py
+++ b/tests/peft_integration/test_peft_integration.py
@@ -12,11 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import os
 import tempfile
 import unittest
 
 from huggingface_hub import hf_hub_download
+from packaging import version
 
 from transformers import AutoModelForCausalLM, OPTForCausalLM
 from transformers.testing_utils import (
@@ -478,6 +480,48 @@ def test_peft_add_adapter_with_state_dict(self):
             # dummy generation
             _ = model.generate(input_ids=dummy_input)
 
+    def test_peft_add_adapter_with_state_dict_low_cpu_mem_usage(self):
+        """
+        Check the usage of low_cpu_mem_usage, which is supported in PEFT >= 0.13.0
+        """
+        from peft import LoraConfig
+
+        min_version_lcmu = "0.13.0"
+        is_lcmu_supported = version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu)
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig()
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                # this should always work
+                model.load_adapter(
+                    adapter_state_dict=dummy_state_dict, peft_config=peft_config, low_cpu_mem_usage=False
+                )
+
+                if is_lcmu_supported:
+                    # if supported, this should not raise an error
+                    model.load_adapter(
+                        adapter_state_dict=dummy_state_dict,
+                        adapter_name="other",
+                        peft_config=peft_config,
+                        low_cpu_mem_usage=True,
+                    )
+                    # after loading, no parameter should remain on the meta device
+                    self.assertFalse(any((p.device.type == "meta") for p in model.parameters()))
+                else:
+                    err_msg = r"The version of PEFT you are using does not support `low_cpu_mem_usage` yet"
+                    with self.assertRaisesRegex(ValueError, err_msg):
+                        model.load_adapter(
+                            adapter_state_dict=dummy_state_dict,
+                            adapter_name="other",
+                            peft_config=peft_config,
+                            low_cpu_mem_usage=True,
+                        )
+
     def test_peft_from_pretrained_hub_kwargs(self):
         """
         Tests different combinations of PEFT model + from_pretrained + hub kwargs
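
For context, a minimal usage sketch of the new flag (not part of the diff itself; the model and adapter IDs below are illustrative placeholders, and PEFT >= 0.13.0 is assumed to be installed):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

# With low_cpu_mem_usage=True, the adapter weights are created empty (on the meta
# device) by inject_adapter_in_model and only materialized when the state dict is
# loaded, skipping the redundant initialization that the default path performs.
model.load_adapter("ybelkada/opt-350m-lora", low_cpu_mem_usage=True)
```

On an older PEFT version, the same call raises the `ValueError` added in this patch rather than silently ignoring the flag.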