Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Allow disabling input dtype casting for LoRA #2353

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion docs/source/package_reference/helpers.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,9 @@ A collection of helper functions for PEFT.
## Temporarily Rescaling Adapter Scale in LoraLayer Modules

[[autodoc]] helpers.rescale_adapter_scale
- all
- all

## Context manager to disable input dtype casting in the `forward` method of LoRA layers

[[autodoc]] helpers.disable_input_dtype_casting
- all
43 changes: 42 additions & 1 deletion src/peft/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from functools import update_wrapper
from types import MethodType

from torch import nn

from .peft_model import PeftConfig, PeftModel
from .tuners.lora.layer import LoraLayer
from .tuners.lora import LoraLayer


def update_forward_signature(model: PeftModel) -> None:
Expand Down Expand Up @@ -209,3 +211,42 @@ def rescale_adapter_scale(model, multiplier):
# restore original scaling values after exiting the context
for module, scaling in original_scaling.items():
module.scaling = scaling


@contextmanager
def disable_input_dtype_casting(model: nn.Module, active: bool = True):
"""
Context manager disables input dtype casting to the dtype of the weight.

Currently specifically works for LoRA.

Parameters:
model (nn.Module):
The model containing PEFT modules whose input dtype casting is to be adjusted.
active (bool):
Whether the context manager is active (default) or inactive.

"""
# Additional info: Normally, the dtype of the weight and input need to match, which is why the dtype is cast.
# However, in certain circumustances, this is handled by forward hooks, e.g. when using layerwise casting in
# diffusers. In that case, PEFT casting the dtype interferes with the layerwise casting, which is why the option to
# disable it is given.
if not active:
yield
return

original_values = {}
for name, module in model.named_modules():
if not isinstance(module, LoraLayer):
continue
original_values[name] = module.cast_input_dtype_enabled
module.cast_input_dtype_enabled = False

try:
yield
finally:
for name, module in model.named_modules():
if not isinstance(module, LoraLayer):
continue
if name in original_values:
module.cast_input_dtype_enabled = original_values[name]
4 changes: 1 addition & 3 deletions src/peft/tuners/adalora/bnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.dtype)

output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
if requires_conversion:
Expand Down
3 changes: 1 addition & 2 deletions src/peft/tuners/adalora/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
if x.dtype != torch.float32:
x = x.float()
x = self._cast_input_dtype(x, torch.float32)

output = (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
# TODO: here, the dtype conversion is applied on the *whole expression*,
Expand Down
2 changes: 1 addition & 1 deletion src/peft/tuners/adalora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
scaling = self.scaling[active_adapter]
ranknum = self.ranknum[active_adapter] + 1e-5

x = x.to(lora_A.dtype)
x = self._cast_input_dtype(x, lora_A.dtype)
result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum

return result
Expand Down
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/aqlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
Expand Down
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
Expand Down
12 changes: 4 additions & 8 deletions src/peft/tuners/lora/bnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,7 @@ def _mixed_batch_forward(
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
# layer output
Expand Down Expand Up @@ -243,9 +241,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
output = lora_B(lora_A(dropout(x))) * scaling
Expand Down Expand Up @@ -470,7 +466,7 @@ def _mixed_batch_forward(
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
# layer output
Expand Down Expand Up @@ -514,7 +510,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
output = lora_B(lora_A(dropout(x))) * scaling
Expand Down
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/eetq.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def forward(self, x: torch.Tensor):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
Expand Down
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
Expand Down
8 changes: 2 additions & 6 deletions src/peft/tuners/lora/hqq.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,7 @@ def _mixed_batch_forward(
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
# layer output
Expand Down Expand Up @@ -218,9 +216,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
Expand Down
19 changes: 17 additions & 2 deletions src/peft/tuners/lora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
self.lora_magnitude_vector = torch.nn.ModuleDict() # for DoRA
self._caches: dict[str, Any] = {}
self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload
# flag to enable/disable casting of input to weight dtype during forward call
self.cast_input_dtype_enabled: bool = True
self.kwargs = kwargs

base_layer = self.get_base_layer()
Expand Down Expand Up @@ -492,6 +494,19 @@ def _mixed_batch_forward(

return result

def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""
Whether to cast the dtype of the input to the forward method.

Usually, we want to enable this to align the input dtype with the dtype of the weight, but by setting
layer.cast_input_dtype=False, this can be disabled if necessary.

Enabling or disabling can be managed via the peft.helpers.disable_lora_input_dtype_casting context manager.
"""
if (not self.cast_input_dtype_enabled) or (x.dtype == dtype):
return x
return x.to(dtype=dtype)


# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# and modified to work with PyTorch FSDP
Expand Down Expand Up @@ -703,7 +718,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
Expand Down Expand Up @@ -1268,7 +1283,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
Expand Down
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/tp_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
Expand Down
103 changes: 102 additions & 1 deletion tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
import pytest
import torch
from diffusers import StableDiffusionPipeline
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import LoraConfig, get_peft_model
from peft.helpers import check_if_peft_model, rescale_adapter_scale
from peft.helpers import check_if_peft_model, disable_input_dtype_casting, rescale_adapter_scale
from peft.tuners.lora.layer import LoraLayer
from peft.utils import infer_device


class TestCheckIsPeftModel:
Expand Down Expand Up @@ -369,3 +371,102 @@ def test_merging_adapter(self, tokenizer):
logits_merged_scaling = model(**inputs).logits

assert torch.allclose(logits_merged_scaling, logits_unmerged_scaling, atol=1e-4, rtol=1e-4)


class TestDisableInputDtypeCasting:
"""Test the context manager `disable_input_dtype_casting` that temporarily disables input dtype casting
in the model.

The test works as follows:

We create a simple MLP and convert it to a PeftModel. The model dtype is set to float16. Then a pre-foward hook is
added that casts the model parameters to float32. Moreover, a post-forward hook is added that casts the weights
back to float16. The input dtype is float32.

Without the disable_input_dtype_casting context, what would happen is that PEFT detects that the input dtype is
float32 but the weight dtype is float16, so it casts the input to float16. Then the pre-forward hook casts the
weight to float32, which results in a RuntimeError.

With the disable_input_dtype_casting context, the input dtype is left as float32 and there is no error. We also add
a hook to record the dtype of the result from the LoraLayer to ensure that it is indeed float32.

"""

device = infer_device()
dtype_record = []

@torch.no_grad()
def cast_params_to_fp32_pre_hook(self, module, input):
for param in module.parameters(recurse=False):
param.data = param.data.float()
return input

@torch.no_grad()
def cast_params_to_fp16_hook(self, module, input, output):
for param in module.parameters(recurse=False):
param.data = param.data.half()
return output

def record_dtype_hook(self, module, input, output):
self.dtype_record.append(output[0].dtype)

@pytest.fixture
def inputs(self):
return torch.randn(4, 10, device=self.device, dtype=torch.float32)

@pytest.fixture
def base_model(self):
class MLP(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.lin0 = nn.Linear(10, 20, bias=bias)
self.lin1 = nn.Linear(20, 2, bias=bias)
self.sm = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = self.lin0(X)
X = self.lin1(X)
X = self.sm(X)
return X

return MLP()

@pytest.fixture
def model(self, base_model):
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(base_model, config).to(device=self.device, dtype=torch.float16)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if this casting is required, though.

We're loading the model in FP32 and keeping its params in FP16. While performing the computations, we're upcasting params to the requested dtype.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also in line with how we do the tests:
https://github.com/huggingface/diffusers/blob/aad69ac2f323734a083d66fa89197bf7d88e5a57/tests/models/test_modeling_common.py#L1365

The compute_dtype is never lower than the storage dtype.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried what happens when I remove dtype=torch.float16. The model will be in float32, as is the input. Therefore, when PEFT casts the input dtype, it stays in float32. The pre-forward hook basically does nothing, as the model is already in float32. Therefore, there is no RuntimeError.

If my understanding is correct, the model has to start out in a dtype that is different from the one that is cast during the pre-forward hook to reflect the situation we have for layerwise upcasting. Is that correct?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay that makes sense.

# Register hooks on the submodule that holds parameters
for module in model.modules():
if sum(p.numel() for p in module.parameters()) > 0:
module.register_forward_pre_hook(self.cast_params_to_fp32_pre_hook)
module.register_forward_hook(self.cast_params_to_fp16_hook)
if isinstance(module, LoraLayer):
module.register_forward_hook(self.record_dtype_hook)
return model

def test_disable_input_dtype_casting_active(self, model, inputs):
self.dtype_record.clear()
with disable_input_dtype_casting(model, active=True):
model(inputs)
assert self.dtype_record == [torch.float32]

def test_no_disable_input_dtype_casting(self, model, inputs):
msg = "expected mat1 and mat2 to have the same dtype"
with pytest.raises(RuntimeError, match=msg):
model(inputs)

def test_disable_input_dtype_casting_inactive(self, model, inputs):
msg = "expected mat1 and mat2 to have the same dtype"
with pytest.raises(RuntimeError, match=msg):
with disable_input_dtype_casting(model, active=False):
model(inputs)

def test_disable_input_dtype_casting_inactive_after_existing_context(self, model, inputs):
# this is to ensure that when the context is left, we return to the previous behavior
with disable_input_dtype_casting(model, active=True):
model(inputs)

# after the context exited, we're back to the error
msg = "expected mat1 and mat2 to have the same dtype"
with pytest.raises(RuntimeError, match=msg):
model(inputs)
Loading