ENH Allow disabling input dtype casting for LoRA #2353

Draft: wants to merge 3 commits into base: main
Changes from 1 commit
7 changes: 6 additions & 1 deletion docs/source/package_reference/helpers.md
@@ -14,4 +14,9 @@ A collection of helper functions for PEFT.
## Temporarily Rescaling Adapter Scale in LoraLayer Modules

[[autodoc]] helpers.rescale_adapter_scale
- all

## Context manager to disable input dtype casting in the `forward` method of LoRA layers

[[autodoc]] disable_lora_input_dtype_casting
Review comment (Member):
Should this be specific to LoRA layers, or could we check against BaseTunerLayer too, to get broader coverage? But it's okay to ship this only for LoRA and then propagate it to the other tuners as needed.

- all
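For illustration of the review question above, here is a rough sketch of what the broader variant could look like, checking against `BaseTunerLayer` instead of `LoraLayer`. This is not part of the PR: the `cast_input_dtype_enabled` flag only exists on `LoraLayer` in this change, so the `hasattr` guard and the generalized name `disable_input_dtype_casting` are assumptions.

```python
from contextlib import contextmanager

from torch import nn

from peft.tuners.tuners_utils import BaseTunerLayer


@contextmanager
def disable_input_dtype_casting(model: nn.Module, disable: bool = True):
    """Hypothetical generalization: toggle input dtype casting on any BaseTunerLayer, not just LoraLayer."""
    if not disable:
        yield
        return

    original_values = {}
    for name, module in model.named_modules():
        # hasattr guard because only LoraLayer defines the flag in this PR
        if not isinstance(module, BaseTunerLayer) or not hasattr(module, "cast_input_dtype_enabled"):
            continue
        original_values[name] = module.cast_input_dtype_enabled
        module.cast_input_dtype_enabled = False

    try:
        yield
    finally:
        for name, module in model.named_modules():
            if name in original_values:
                module.cast_input_dtype_enabled = original_values[name]
```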
26 changes: 26 additions & 0 deletions src/peft/helpers.py
@@ -18,6 +18,8 @@
from functools import update_wrapper
from types import MethodType

from torch import nn

from .peft_model import PeftConfig, PeftModel
from .tuners.lora.layer import LoraLayer

@@ -209,3 +211,27 @@ def rescale_adapter_scale(model, multiplier):
# restore original scaling values after exiting the context
for module, scaling in original_scaling.items():
module.scaling = scaling


@contextmanager
def disable_lora_input_dtype_casting(model: nn.Module, disable: bool = True):
"""TODO"""
if not disable:
yield
return

original_values = {}
for name, module in model.named_modules():
if not isinstance(module, LoraLayer):
continue
original_values[name] = module.cast_input_dtype_enabled
module.cast_input_dtype_enabled = False

try:
yield
finally:
for name, module in model.named_modules():
if not isinstance(module, LoraLayer):
continue
if name in original_values:
module.cast_input_dtype_enabled = original_values[name]
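Not part of the diff, but a minimal usage sketch of the new context manager. The casting only matters when the LoRA weight dtype differs from the input dtype (e.g. quantized or layerwise-cast base models); the toy model below uses a single dtype, so it only illustrates the call pattern, and the module name `"0"` is specific to this `nn.Sequential`.

```python
import torch
from torch import nn
from peft import LoraConfig, get_peft_model
from peft.helpers import disable_lora_input_dtype_casting

# toy example: a single Linear wrapped with LoRA
base_model = nn.Sequential(nn.Linear(16, 16))
peft_model = get_peft_model(base_model, LoraConfig(target_modules=["0"]))
x = torch.randn(2, 16)

# default: LoRA layers cast x to the dtype of the LoRA weights before the adapter matmul
out_default = peft_model(x)

# inside the context, LoRA layers leave the input dtype untouched
with disable_lora_input_dtype_casting(peft_model):
    out_no_cast = peft_model(x)

# after exiting, the original cast_input_dtype_enabled flags are restored
```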
4 changes: 1 addition & 3 deletions src/peft/tuners/adalora/bnb.py
@@ -129,9 +129,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.dtype)

output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
if requires_conversion:
3 changes: 1 addition & 2 deletions src/peft/tuners/adalora/gptq.py
@@ -55,8 +55,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
if x.dtype != torch.float32:
x = x.float()
x = self._cast_input_dtype(x, torch.float32)

output = (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
# TODO: here, the dtype conversion is applied on the *whole expression*,
2 changes: 1 addition & 1 deletion src/peft/tuners/adalora/layer.py
@@ -180,7 +180,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
scaling = self.scaling[active_adapter]
ranknum = self.ranknum[active_adapter] + 1e-5

x = x.to(lora_A.dtype)
x = self._cast_input_dtype(x, lora_A.dtype)
result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum

return result
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/aqlm.py
@@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/awq.py
@@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
12 changes: 4 additions & 8 deletions src/peft/tuners/lora/bnb.py
@@ -204,9 +204,7 @@ def _mixed_batch_forward(
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
# layer output
@@ -243,9 +241,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
output = lora_B(lora_A(dropout(x))) * scaling
@@ -470,7 +466,7 @@ def _mixed_batch_forward(
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
# layer output
@@ -514,7 +510,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
output = lora_B(lora_A(dropout(x))) * scaling
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/eetq.py
@@ -76,7 +76,7 @@ def forward(self, x: torch.Tensor):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/gptq.py
@@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
8 changes: 2 additions & 6 deletions src/peft/tuners/lora/hqq.py
@@ -178,9 +178,7 @@ def _mixed_batch_forward(
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
# layer output
@@ -218,9 +216,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
19 changes: 17 additions & 2 deletions src/peft/tuners/lora/layer.py
@@ -62,6 +62,8 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
self.lora_magnitude_vector = torch.nn.ModuleDict() # for DoRA
self._caches: dict[str, Any] = {}
self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload
# flag to enable/disable casting of input to weight dtype during forward call
self.cast_input_dtype_enabled: bool = True
self.kwargs = kwargs

base_layer = self.get_base_layer()
@@ -492,6 +494,19 @@ def _mixed_batch_forward(

return result

def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""
Whether to cast the dtype of the input to the forward method.

Usually, we want to enable this to align the input dtype with the dtype of the weight, but by setting
layer.cast_input_dtype=False, this can be disabled if necessary.

Enabling or disabling can be managed via the peft.helpers.disable_lora_input_dtype_casting context manager.
"""
if (not self.cast_input_dtype_enabled) or (x.dtype == dtype):
return x
return x.to(dtype=dtype)
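Not part of the diff: a small sketch of what this flag means in practice, with `peft_model` as a placeholder for any model carrying LoRA adapters. Flipping `cast_input_dtype_enabled` by hand is what the `disable_lora_input_dtype_casting` context manager in `helpers.py` does, minus the automatic restore on exit.

```python
from peft.tuners.lora.layer import LoraLayer

# peft_model is a placeholder for any model with LoRA adapters applied
for module in peft_model.modules():
    if isinstance(module, LoraLayer):
        # _cast_input_dtype() now returns x unchanged for this layer
        module.cast_input_dtype_enabled = False
```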


# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# and modified to work with PyTorch FSDP
@@ -703,7 +718,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
@@ -1268,7 +1283,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/tp_layer.py
@@ -205,7 +205,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling