Skip to content

Commit

Permalink
ENH: Allow empty initialization of adapter weight (#1961)
Browse files Browse the repository at this point in the history
This PR allows to initialize the adpater weights as empty, i.e. on meta
device, by passing low_cpu_mem_usage=True.

Why would this be useful? For PEFT training, it is indeed not useful, as
we need the real weights in order to train the model. However, when
loading a trained PEFT adapter, it is unnecessary to initialize the
adapters for real, as we override them with the loaded weights later.

In the grand scheme of things, loading the base model will typically be
much slower, but if the user loads, say, dozens of adapters, the
overhead could add up. Of course, besides loading the model, this has no
performance impact and is thus not a high priority feature.

For the time being, this is completely opt in. However, it should be safe to
make this default for loading adapters. Therefore, in the future we may change
the default there.

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
BenjaminBossan and sayakpaul authored Sep 23, 2024
1 parent 9bc670e commit af275d2
Show file tree
Hide file tree
Showing 24 changed files with 547 additions and 46 deletions.
29 changes: 29 additions & 0 deletions docs/source/developer_guides/low_level_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Check the table below to see when you should inject adapters.
| the model is modified inplace, keeping all the original attributes and methods | manually write the `from_pretrained` and `save_pretrained` utility functions from Hugging Face to save and load adapters |
| works for any `torch` module and modality | doesn't work with any of the utility methods provided by `PeftModel` such as disabling and merging adapters |

## Creating a new PEFT model

To perform the adapter injection, use the [`inject_adapter_in_model`] method. This method takes 3 arguments, the PEFT config, the model, and an optional adapter name. You can also attach multiple adapters to the model if you call [`inject_adapter_in_model`] multiple times with different adapter names.

For example, to inject LoRA adapters into the `linear` submodule of the `DummyModel` module:
Expand Down Expand Up @@ -85,6 +87,8 @@ DummyModel(
)
```

## Saving the model

To only save the adapter, use the [`get_peft_model_state_dict`] function:

```python
Expand All @@ -95,3 +99,28 @@ print(peft_state_dict)
```

Otherwise, `model.state_dict()` returns the full state dict of the model.

## Loading the model

After loading the saved `state_dict`, it can be applied using the [`set_peft_model_state_dict`] function:

```python
from peft import set_peft_model_state_dict

model = DummyModel()
model = inject_adapter_in_model(lora_config, model)
outcome = set_peft_model_state_dict(model, peft_state_dict)
# check that there were no wrong keys
print(outcome.unexpected_keys)
```

If injecting the adapter is slow or you need to load a large number of adapters, you may use an optimization that allows to create an "empty" adapter on meta device and only fills the weights with real weights when the [`set_peft_model_state_dict`] is called. To do this, pass `low_cpu_mem_usage=True` to both [`inject_adapter_in_model`] and [`set_peft_model_state_dict`].

```python
model = DummyModel()
model = inject_adapter_in_model(lora_config, model, low_cpu_mem_usage=True)

print(model.linear.lora_A["default"].weight.device.type == "meta") # should be True
set_peft_model_state_dict(model, peft_state_dict, low_cpu_mem_usage=True)
print(model.linear.lora_A["default"].weight.device.type == "cpu") # should be True
```
13 changes: 13 additions & 0 deletions docs/source/developer_guides/troubleshooting.md
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,19 @@ TunerModelStatus(
)
```

## Speed

### Loading adapter weights is slow

Loading adapters like LoRA weights should generally be fast compared to loading the base model. However, there can be use cases where the adapter weights are quite large or where users need to load a large number of adapters -- the loading time can add up in this case. The reason for this is that the adapter weights are first initialized and then overridden by the loaded weights, which is wasteful. To speed up the loading time, you can pass the `low_cpu_mem_usage=True` argument to [`~PeftModel.from_pretrained`] and [`~PeftModel.load_adapter`].

<Tip>

If this option works well across different use casese, it may become the default for adapter loading in the future.

</Tip>


## Reproducibility

### Models using batch norm
Expand Down
6 changes: 4 additions & 2 deletions src/peft/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def get_peft_model(


def inject_adapter_in_model(
peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default"
peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default", low_cpu_mem_usage: bool = False
) -> torch.nn.Module:
r"""
A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning
Expand All @@ -210,6 +210,8 @@ def inject_adapter_in_model(
The input model where the adapter will be injected.
adapter_name (`str`, `optional`, defaults to `"default"`):
The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device. Useful to speed up the loading process.
"""
if peft_config.is_prompt_learning or peft_config.is_adaption_prompt:
raise ValueError("`create_and_replace` does not support prompt learning and adaption prompt yet.")
Expand All @@ -222,6 +224,6 @@ def inject_adapter_in_model(
tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type]

# By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules.
peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name)
peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

return peft_model.model
67 changes: 65 additions & 2 deletions src/peft/mixed_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class PeftMixedModel(PushToHubMixin, torch.nn.Module):
The config of the model to be tuned. The adapter type must be compatible.
adapter_name (`str`, `optional`, defaults to `"default"`):
The name of the first adapter.
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device. Useful to speed up the loading process.
"""

def __init__(self, model: nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
Expand Down Expand Up @@ -219,12 +221,38 @@ def disable_adapter(self):
finally:
self.base_model.enable_adapter_layers()

def add_adapter(self, adapter_name: str, peft_config: PeftConfig):
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
"""
Add an adapter to the model based on the passed configuration.
This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`].
The name for the new adapter should be unique.
The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active
adapter.
Args:
adapter_name (`str`):
The name of the adapter to be added.
peft_config ([`PeftConfig`]):
The configuration of the adapter to be added.
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device. Useful to speed up the process when loading saved
adapters.
<Tip>
Don't use `low_cpu_mem_usage=True` when creating a new PEFT adapter for training (training is untested
and discouraged for PeftMixedModel in general).
</Tip>
"""
_check_config_compatible(peft_config)

try:
self.peft_config[adapter_name] = peft_config
self.base_model.inject_adapter(self, adapter_name)
self.base_model.inject_adapter(self, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
except Exception: # something went wrong, roll back
if adapter_name in self.peft_config:
del self.peft_config[adapter_name]
Expand Down Expand Up @@ -323,6 +351,37 @@ def _split_kwargs(cls, kwargs: dict[str, Any]):
return PeftModel._split_kwargs(kwargs)

def load_adapter(self, model_id: str, adapter_name: str, *args: Any, **kwargs: Any):
"""
Load a trained adapter into the model.
The name for the new adapter should be unique.
The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active
adapter.
Args:
adapter_name (`str`):
The name of the adapter to be added.
peft_config ([`PeftConfig`]):
The configuration of the adapter to be added.
is_trainable (`bool`, *optional*, defaults to `False`):
Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be
used for inference.
torch_device (`str`, *optional*, defaults to None):
The device to load the adapter on. If `None`, the device will be inferred.
autocast_adapter_dtype (`bool`, *optional*, defaults to `True`):
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter
weights using float16 and bfloat16 to float32, as this is typically required for stable training, and
only affect select PEFT tuners.
ephemeral_gpu_offload (`bool`, *optional*, defaults to `False`):
Whether to use ephemeral GPU offloading for partially loaded modules. Defaults to `False`.
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device before loading the saved weights. Useful to speed up the
process.
kwargs: (`optional`):
Additional arguments to modify the way the adapter is loaded, e.g. the token for Hugging Face Hub.
"""
# the low_cpu_mem_usage option is handled through kwargs
output = PeftModel.load_adapter(self, model_id, adapter_name, *args, **kwargs)
# TODO: not quite clear why this is necessary but tests fail without it
self.set_adapter(self.active_adapters)
Expand Down Expand Up @@ -373,6 +432,9 @@ def from_pretrained(
The configuration object to use instead of an automatically loaded configuration. This configuration
object is mutually exclusive with `model_id` and `kwargs`. This is useful when configuration is already
loaded before calling `from_pretrained`.
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device before loading the saved weights. Useful to speed up the
process.
kwargs: (`optional`):
Additional keyword arguments passed along to the specific PEFT configuration class.
"""
Expand Down Expand Up @@ -412,5 +474,6 @@ def from_pretrained(

# note: this is different from PeftModel.from_pretrained, we always return a PeftMixedModel
model = cls(model, config, adapter_name)
# the low_cpu_mem_usage option is handled through kwargs
model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
return model
66 changes: 55 additions & 11 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
import inspect
import os
import warnings
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Literal, Optional, Union

import packaging.version
import torch
import transformers
from accelerate import dispatch_model, infer_auto_device_map
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
from accelerate.utils import get_balanced_memory, named_module_tensors
from huggingface_hub import HfFileSystem, ModelCard, ModelCardData, hf_hub_download
Expand Down Expand Up @@ -118,6 +118,14 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
using float16 and bfloat16 to float32, as this is typically required for stable training, and only affect
select PEFT tuners.
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device. Useful to speed up the loading loading process.
<Tip>
Don't use `low_cpu_mem_usage=True` when creating a new PEFT adapter for training.
</Tip>
**Attributes**:
- **base_model** ([`torch.nn.Module`]) -- The base transformer model used for Peft.
Expand All @@ -140,6 +148,7 @@ def __init__(
peft_config: PeftConfig,
adapter_name: str = "default",
autocast_adapter_dtype: bool = True,
low_cpu_mem_usage: bool = False,
) -> None:
super().__init__()
self.modules_to_save = None
Expand All @@ -153,11 +162,13 @@ def __init__(
if self._is_prompt_learning:
self._peft_config = {adapter_name: peft_config}
self.base_model = model
self.add_adapter(adapter_name, peft_config)
self.add_adapter(adapter_name, peft_config, low_cpu_mem_usage=low_cpu_mem_usage)
else:
self._peft_config = None
cls = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type]
self.base_model = cls(model, {adapter_name: peft_config}, adapter_name)
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
self.base_model = cls(model, {adapter_name: peft_config}, adapter_name)
self.set_additional_trainable_modules(peft_config, adapter_name)

if hasattr(self.base_model, "_cast_adapter_dtype"):
Expand Down Expand Up @@ -422,6 +433,7 @@ def from_pretrained(
config: Optional[PeftConfig] = None,
autocast_adapter_dtype: bool = True,
ephemeral_gpu_offload: bool = False,
low_cpu_mem_usage: bool = False,
**kwargs: Any,
) -> PeftModel:
r"""
Expand Down Expand Up @@ -456,6 +468,9 @@ def from_pretrained(
are needed. Rather than perform expensive operations on small data, the data is transferred to the GPU
on-demand, the operation(s) performed, and the results moved back to CPU memory. This brings a slight
momentary VRAM overhead but gives orders of magnitude speedup in certain cases.
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device before loading the saved weights. Useful to speed up the
process.
torch_device (`str`, *optional*, defaults to None):
The device to load the adapter on. If `None`, the device will be inferred.
kwargs: (`optional`):
Expand Down Expand Up @@ -552,14 +567,29 @@ def from_pretrained(
raise ValueError("If model_id is a local path, then `adapters` must be passed in kwargs.")

if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
model = cls(model, config, adapter_name, autocast_adapter_dtype=autocast_adapter_dtype)
model = cls(
model,
config,
adapter_name,
autocast_adapter_dtype=autocast_adapter_dtype,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](
model, config, adapter_name, autocast_adapter_dtype=autocast_adapter_dtype
model,
config,
adapter_name,
autocast_adapter_dtype=autocast_adapter_dtype,
low_cpu_mem_usage=low_cpu_mem_usage,
)

model.load_adapter(
model_id, adapter_name, is_trainable=is_trainable, autocast_adapter_dtype=autocast_adapter_dtype, **kwargs
model_id,
adapter_name,
is_trainable=is_trainable,
autocast_adapter_dtype=autocast_adapter_dtype,
low_cpu_mem_usage=low_cpu_mem_usage,
**kwargs,
)

return model
Expand Down Expand Up @@ -852,7 +882,7 @@ def get_base_model(self) -> torch.nn.Module:
else self.base_model.model
)

def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None:
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
"""
Add an adapter to the model based on the passed configuration.
Expand All @@ -868,6 +898,10 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None:
The name of the adapter to be added.
peft_config ([`PeftConfig`]):
The configuration of the adapter to be added.
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device. Useful to speed up the process when loading saved
adapters. Don't use this option when creating a new PEFT adapter for training.
"""
if peft_config.peft_type != self.peft_type:
raise ValueError(
Expand All @@ -889,7 +923,9 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None:
self.base_model.add_adapter(adapter_name, peft_config)
else:
self.peft_config[adapter_name] = peft_config
self.base_model.inject_adapter(self.base_model.model, adapter_name)
self.base_model.inject_adapter(
self.base_model.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage
)
except Exception: # something went wrong, roll back
if adapter_name in self.peft_config:
del self.peft_config[adapter_name]
Expand Down Expand Up @@ -1077,6 +1113,7 @@ def load_adapter(
torch_device: Optional[str] = None,
autocast_adapter_dtype: bool = True,
ephemeral_gpu_offload: bool = False,
low_cpu_mem_usage: bool = False,
**kwargs: Any,
):
"""
Expand All @@ -1103,6 +1140,9 @@ def load_adapter(
only affect select PEFT tuners.
ephemeral_gpu_offload (`bool`, *optional*, defaults to `False`):
Whether to use ephemeral GPU offloading for partially loaded modules. Defaults to `False`.
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device before loading the saved weights. Useful to speed up the
process.
kwargs: (`optional`):
Additional arguments to modify the way the adapter is loaded, e.g. the token for Hugging Face Hub.
"""
Expand All @@ -1128,14 +1168,18 @@ def load_adapter(
raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
else:
peft_config.inference_mode = not is_trainable
self.add_adapter(adapter_name, peft_config)
self.add_adapter(adapter_name, peft_config, low_cpu_mem_usage=low_cpu_mem_usage)

adapters_weights = load_peft_weights(model_id, device=torch_device, **hf_hub_download_kwargs)

# load the weights into the model
ignore_mismatched_sizes = kwargs.get("ignore_mismatched_sizes", False)
load_result = set_peft_model_state_dict(
self, adapters_weights, adapter_name=adapter_name, ignore_mismatched_sizes=ignore_mismatched_sizes
self,
adapters_weights,
adapter_name=adapter_name,
ignore_mismatched_sizes=ignore_mismatched_sizes,
low_cpu_mem_usage=low_cpu_mem_usage,
)
if (
(getattr(self, "hf_device_map", None) is not None)
Expand Down
2 changes: 2 additions & 0 deletions src/peft/tuners/adalora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class AdaLoraModel(LoraModel):
model ([`transformers.PreTrainedModel`]): The model to be adapted.
config ([`AdaLoraConfig`]): The configuration of the AdaLora model.
adapter_name (`str`): The name of the adapter, defaults to `"default"`.
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device. Useful to speed up the loading process.
Returns:
`torch.nn.Module`: The AdaLora model.
Expand Down
Loading

0 comments on commit af275d2

Please sign in to comment.