ENH Allow disabling input dtype casting for LoRA #2353
Conversation
Thanks for helping us with this @BenjaminBossan and making peft techniques possible to run with FP8 as the storage dtype! Here's the reproducer for the error (if we don't disallow peft layers by passing …):

```python
import gc
from typing import Any

import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_layerwise_casting
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
from peft.tuners.lora.layer import Linear as LoRALinear

set_verbosity_debug()


def main():
    # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    pipe.load_lora_weights("Cseti/CogVideoX-LoRA-Wallace_and_Gromit", weight_name="walgro1-3000.safetensors", adapter_name="cogvideox-lora")
    pipe.set_adapters(["cogvideox-lora"], [1.0])

    apply_layerwise_casting(
        pipe.transformer,
        storage_dtype=torch.float8_e4m3fn,
        compute_dtype=torch.bfloat16,
        skip_modules_pattern=["patch_embed", "norm", "^proj_out$"],
    )

    for name, parameter in pipe.transformer.named_parameters():
        if "lora" in name:
            assert parameter.dtype == torch.float8_e4m3fn

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    print(f"Model memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")

    prompt = "walgro1. The scene begins with a close-up of Gromit's face, his expressive eyes filling the frame. His brow furrows slightly, ears perked forward in concentration. The soft lighting highlights the subtle details of his fur, every strand catching the warm sunlight filtering in from a nearby window. His dark, round nose twitches ever so slightly, sensing something in the air, and his gaze darts to the side, following an unseen movement. The camera lingers on Gromit’s face, capturing the subtleties of his expression—a quirked eyebrow and a knowing look that suggests he’s piecing together something clever. His silent, thoughtful demeanor speaks volumes as he watches the scene unfold with quiet intensity. The background remains out of focus, drawing all attention to the sharp intelligence in his eyes and the slight tilt of his head. In the claymation style of Wallace and Gromit."

    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt,
            negative_prompt="",
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            device="cuda",
            dtype=torch.bfloat16,
        )

    video = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        guidance_scale=6,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(42),
    ).frames[0]
    export_to_video(video, "output.mp4", fps=8)
    print(f"Inference memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")


main()
```

Stack trace:

```
Model memory: 15.351 GB
The passed generator was created on 'cpu' even though a tensor on cuda:0 was expected. Tensors will be created on 'cpu' and then moved to cuda:0. Note that one can probably slighly speed up this function by passing a generator that was created on the cuda:0 device.
0%| | 0/50 [00:01<?, ?it/s]
Traceback (most recent call last):
File "/home/aryan/work/diffusers/workflows/layerwise_upcasting/test_layerwise_upcasting_lora_3.py", line 79, in <module>
main()
File "/home/aryan/work/diffusers/workflows/layerwise_upcasting/test_layerwise_upcasting_lora_3.py", line 68, in main
video = pipe(
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/aryan/work/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py", line 721, in __call__
noise_pred = self.transformer(
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aryan/work/diffusers/src/diffusers/models/transformers/cogvideox_transformer_3d.py", line 498, in forward
hidden_states, encoder_hidden_states = block(
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aryan/work/diffusers/src/diffusers/models/transformers/cogvideox_transformer_3d.py", line 135, in forward
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aryan/work/diffusers/src/diffusers/models/attention_processor.py", line 594, in forward
return self.processor(
File "/home/aryan/work/diffusers/src/diffusers/models/attention_processor.py", line 2835, in __call__
query = attn.to_q(hidden_states)
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 622, in forward
result = result + lora_B(lora_A(dropout(x))) * scaling
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aryan/work/diffusers/src/diffusers/hooks/hooks.py", line 148, in new_forward
output = function_reference.forward(*args, **kwargs)
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != c10::BFloat16
```

One way to get it to work is the workaround I mentioned in our group chat - patching the call to …. The reproducer passes with this branch after the following changes in Diffusers: huggingface/diffusers#10685.
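To make the failure mechanism concrete, here is a minimal, self-contained sketch (plain PyTorch, no peft or diffusers involved; the hook and dtypes are stand-ins for the layerwise-casting setup, so treat it as an illustration rather than the actual implementation). Casting the input to the stored weight's dtype defeats the pre-forward hook that upcasts the weight, producing the same kind of mismatch as in the traceback above:

```python
import torch
import torch.nn as nn

storage_dtype, compute_dtype = torch.float8_e4m3fn, torch.bfloat16

lin = nn.Linear(4, 4).to(storage_dtype)  # weights kept in FP8 for storage

def upcast_pre_hook(module, args):
    # stand-in for the layerwise-casting hook: upcast params right before forward
    module.weight.data = module.weight.data.to(compute_dtype)
    module.bias.data = module.bias.data.to(compute_dtype)
    return args

lin.register_forward_pre_hook(upcast_pre_hook)

x = torch.randn(2, 4, dtype=compute_dtype)

# What the LoRA layer currently does: cast the input to the weight's (storage) dtype
# *before* the wrapped module's own pre-forward hook gets a chance to upcast the weight.
try:
    lin(x.to(lin.weight.dtype))  # input is FP8, weight gets upcast to BF16 -> mismatch
except RuntimeError as e:
    print(f"RuntimeError: {e}")

# Leaving the input dtype alone works: the hook upcasts the weight and the BF16 matmul runs.
print(lin(x).dtype)
```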
Thanks Ben!
While I don't have an exact test case for you, I came up with the following, which I think tests the correctness:
```python
from peft import LoraConfig, get_peft_model
from peft.utils import infer_device
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.helpers import disable_lora_input_dtype_casting
import torch
import torch.nn as nn

TORCH_DEVICE = infer_device()
DTYPE = torch.float32


class MLP(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.lin0 = nn.Linear(10, 20, bias=bias)
        self.lin1 = nn.Linear(20, 2, bias=bias)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = self.lin0(X)
        X = self.lin1(X)
        X = self.sm(X)
        return X


def print_dtype_hook(module, input, output):
    assert output[0].dtype == DTYPE
    print(f"Module: {module.__class__.__name__}, Output dtype: {output[0].dtype}")


@torch.no_grad()
def cast_params_to_fp32(module, input):
    for param in module.parameters(recurse=False):
        param.data = param.data.float()
    return input


@torch.no_grad()
def cast_params_back_to_fp16(module, input, output):
    for param in module.parameters(recurse=False):
        param.data = param.data.half()
    return output


config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = MLP()

# Register hooks on the submodule that holds parameters
for module in model.modules():
    if sum(p.numel() for p in module.parameters()) > 0:
        module.register_forward_pre_hook(cast_params_to_fp32)
        module.register_forward_hook(cast_params_back_to_fp16)

model = get_peft_model(model, config0).to(TORCH_DEVICE)
inputs = torch.randn(4, 10, device=TORCH_DEVICE, dtype=DTYPE)

for module in model.modules():
    if isinstance(module, BaseTunerLayer):
        module.register_forward_hook(print_dtype_hook)

with torch.no_grad():
    with disable_lora_input_dtype_casting(model, disable=True):
        out = model(inputs)

print(f"{out.shape=}")
```
Does this work?
```
[[autodoc]] disable_lora_input_dtype_casting
    - all
```
Should this be specific to LoRA layers, or could we check against `BaseTunerLayer`, too, to have broader coverage? But it's okay to ship this only for LoRA and then propagate to the other tuners as needed.
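For illustration, a broader check could look like the hypothetical sketch below. The helper name and the `cast_input_dtype_enabled` flag are assumptions made for this example, not the PR's actual implementation:

```python
from peft.tuners.tuners_utils import BaseTunerLayer


def set_input_dtype_casting(model, enabled: bool) -> None:
    """Hypothetical: toggle input dtype casting on every tuner layer, not just LoRA."""
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            # cast_input_dtype_enabled is an assumed attribute name for this sketch
            module.cast_input_dtype_enabled = enabled
```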
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for providing the example code. With this, I could verify that the context manager works as expected.

@a-r-r-o-w: As expected, your script will raise a `RuntimeError` ….

@sayakpaul: I used your example to craft a unit test. I had to make a change, namely to cast the whole model to float16 before calling ….
The thing is that each layer that wants to support this potentially needs code changes, so I wanted to keep a small scope to get this PR ready quickly. I did check LoKr and LoHa and there is no casting there. Thus I think we have the typical diffusers cases covered. However, I did rename the function to remove the ….
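As a usage illustration only (a sketch, not the PR's final API): downstream code such as the diffusers integration could wrap inference in the context manager so that the LoRA layers leave the input dtype alone and the layerwise-casting hooks decide the compute dtype. The import follows the test script earlier in this thread; since the comment above mentions a rename, the final name in `peft.helpers` may differ:

```python
import torch
from peft.helpers import disable_lora_input_dtype_casting  # name may have changed, see above


def run_without_input_casting(model, inputs):
    # Let pre-forward hooks (e.g. layerwise casting) control dtypes instead of peft.
    with torch.no_grad(), disable_lora_input_dtype_casting(model, disable=True):
        return model(inputs)
```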
```python
@pytest.fixture
def model(self, base_model):
    config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
    model = get_peft_model(base_model, config).to(device=self.device, dtype=torch.float16)
```
I am not sure if this casting is required, though.
We're loading the model in FP32 and keeping its params in FP16. While performing the computations, we're upcasting params to the requested dtype.
Also in line with how we do the tests:
https://github.com/huggingface/diffusers/blob/aad69ac2f323734a083d66fa89197bf7d88e5a57/tests/models/test_modeling_common.py#L1365
The `compute_dtype` is never lower than the storage dtype.
I tried what happens when I remove `dtype=torch.float16`. The model will be in float32, as is the input. Therefore, when PEFT casts the input dtype, it stays in float32. The pre-forward hook basically does nothing, as the model is already in float32. Therefore, there is no `RuntimeError`.
If my understanding is correct, the model has to start out in a dtype that is different from the one that is cast during the pre-forward hook to reflect the situation we have for layerwise upcasting. Is that correct?
Okay that makes sense.
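To illustrate the point discussed above, here is a minimal sketch with a plain `nn.Linear` and the same kind of cast-to-fp32 pre-forward hook as in the test (no peft involved, so it is only an analogy): if the model already starts in float32, both the hook and the input cast are no-ops and nothing can fail, whereas a float16 model triggers the mismatch the test is meant to catch.

```python
import torch
import torch.nn as nn


def cast_params_to_fp32(module, args):
    # same idea as the hook in the test: upcast params right before forward
    for p in module.parameters(recurse=False):
        p.data = p.data.float()
    return args


def run(model_dtype):
    lin = nn.Linear(8, 8).to(model_dtype)
    lin.register_forward_pre_hook(cast_params_to_fp32)
    x = torch.randn(2, 8).to(model_dtype)  # mimic peft casting the input to the weight dtype
    try:
        lin(x)
        return "ok"
    except RuntimeError:
        return "RuntimeError (dtype mismatch)"


print(run(torch.float32))  # hook is a no-op -> "ok", the test would not exercise anything
print(run(torch.float16))  # weights upcast to fp32, input stays fp16 -> mismatch
```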
CUDA and CPU error messages use different wordings
@a-r-r-o-w @sayakpaul Is this PR good to proceed? If you prefer another API to make the integration for diffusers easier, now is the time to let me know.
Thanks for the work, Ben! Will take a look and let you know. But this API should be good.
@BenjaminBossan PR is good to proceed IMO. Let's ship.
Thanks @BenjaminBossan, looks good to me for our use case!
WIP