
ENH Allow disabling input dtype casting for LoRA #2353

Draft: wants to merge 3 commits into main

Conversation

@BenjaminBossan (Member)

WIP

@a-r-r-o-w (Member)

Thanks for helping us with this @BenjaminBossan, and for making it possible to run PEFT techniques with FP8 as the storage dtype!

Here's the reproducer for the error (it occurs if we don't exclude the PEFT layers by passing skip_modules_pattern to Diffusers layerwise casting, huggingface/diffusers#10347):

code
import gc
from typing import Any

import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_layerwise_casting
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
from peft.tuners.lora.layer import Linear as LoRALinear

set_verbosity_debug()


def main():
    # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    pipe.load_lora_weights("Cseti/CogVideoX-LoRA-Wallace_and_Gromit", weight_name="walgro1-3000.safetensors", adapter_name="cogvideox-lora")
    pipe.set_adapters(["cogvideox-lora"], [1.0])
    
    apply_layerwise_casting(
        pipe.transformer,
        storage_dtype=torch.float8_e4m3fn,
        compute_dtype=torch.bfloat16,
        skip_modules_pattern=["patch_embed", "norm", "^proj_out$"]
    )
    
    for name, parameter in pipe.transformer.named_parameters():
        if "lora" in name:
            assert(parameter.dtype == torch.float8_e4m3fn)

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    print(f"Model memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")

    prompt = "walgro1. The scene begins with a close-up of Gromit's face, his expressive eyes filling the frame. His brow furrows slightly, ears perked forward in concentration. The soft lighting highlights the subtle details of his fur, every strand catching the warm sunlight filtering in from a nearby window. His dark, round nose twitches ever so slightly, sensing something in the air, and his gaze darts to the side, following an unseen movement. The camera lingers on Gromit’s face, capturing the subtleties of his expression—a quirked eyebrow and a knowing look that suggests he’s piecing together something clever. His silent, thoughtful demeanor speaks volumes as he watches the scene unfold with quiet intensity. The background remains out of focus, drawing all attention to the sharp intelligence in his eyes and the slight tilt of his head. In the claymation style of Wallace and Gromit."

    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt,
            negative_prompt="",
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            device="cuda",
            dtype=torch.bfloat16,
        )

    video = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        guidance_scale=6,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(42)
    ).frames[0]
    export_to_video(video, "output.mp4", fps=8)

    print(f"Inference memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")

main()
stack trace
Model memory: 15.351 GB
The passed generator was created on 'cpu' even though a tensor on cuda:0 was expected. Tensors will be created on 'cpu' and then moved to cuda:0. Note that one can probably slighly speed up this function by passing a generator that was created on the cuda:0 device.
  0%|                                                                                                                                                                                                                                                                                | 0/50 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/home/aryan/work/diffusers/workflows/layerwise_upcasting/test_layerwise_upcasting_lora_3.py", line 79, in <module>
    main()
  File "/home/aryan/work/diffusers/workflows/layerwise_upcasting/test_layerwise_upcasting_lora_3.py", line 68, in main
    video = pipe(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py", line 721, in __call__
    noise_pred = self.transformer(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/transformers/cogvideox_transformer_3d.py", line 498, in forward
    hidden_states, encoder_hidden_states = block(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/transformers/cogvideox_transformer_3d.py", line 135, in forward
    attn_hidden_states, attn_encoder_hidden_states = self.attn1(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/attention_processor.py", line 594, in forward
    return self.processor(
  File "/home/aryan/work/diffusers/src/diffusers/models/attention_processor.py", line 2835, in __call__
    query = attn.to_q(hidden_states)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 622, in forward
    result = result + lora_B(lora_A(dropout(x))) * scaling
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/hooks/hooks.py", line 148, in new_forward
    output = function_reference.forward(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != c10::BFloat16
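
Boiled down, the failure is this dtype disagreement: PEFT casts the activation to the LoRA weight's storage dtype (FP8), while the layerwise-casting hook upcasts the weight to the compute dtype (BF16) just before F.linear. A standalone sketch (not taken from the reproducer) that triggers the same RuntimeError:

sketch
import torch
import torch.nn.functional as F

weight = torch.randn(16, 16).to(torch.float8_e4m3fn)  # LoRA weight in the storage dtype
x = torch.randn(2, 16, dtype=torch.bfloat16)          # activation in the compute dtype

x = x.to(weight.dtype)              # what the LoRA forward does today
weight = weight.to(torch.bfloat16)  # what the layerwise-casting hook does right before the matmul
F.linear(x, weight)                 # RuntimeError: expected mat1 and mat2 to have the same dtype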

One way to get it to work is the workaround I mentioned in our group chat: patching the call to torch::Tensor::to with a context manager when PEFT layers are called into. This is what I thought you'd implement when we settled on the context manager solution, but what we have right now looks much better and cleaner.
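
For context, a rough sketch of what that patching workaround could have looked like (names and scoping are hypothetical, and this is not what the PR implements):

sketch
import contextlib
import torch

@contextlib.contextmanager
def skip_dtype_casts():
    # Hypothetical workaround: while active, pure dtype casts via Tensor.to become no-ops,
    # so PEFT's x.to(lora_A.weight.dtype) would leave the activation in the compute dtype.
    original_to = torch.Tensor.to

    def patched_to(self, *args, **kwargs):
        if len(args) == 1 and isinstance(args[0], torch.dtype) and not kwargs:
            return self  # skip the dtype cast
        return original_to(self, *args, **kwargs)

    torch.Tensor.to = patched_to
    try:
        yield
    finally:
        torch.Tensor.to = original_to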

The reproducer passes with this branch after the following changes in Diffusers: huggingface/diffusers#10685.

(attached: output.mp4)

@sayakpaul (Member) left a comment

Thanks Ben!

While I don't have an exact test case for you, I came up with the following, which I think tests the correctness:

test
from peft import LoraConfig, get_peft_model
from peft.utils import infer_device
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.helpers import disable_lora_input_dtype_casting
import torch 
import torch.nn as nn

TORCH_DEVICE = infer_device()
DTYPE = torch.float32

class MLP(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.lin0 = nn.Linear(10, 20, bias=bias)
        self.lin1 = nn.Linear(20, 2, bias=bias)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = self.lin0(X)
        X = self.lin1(X)
        X = self.sm(X)
        return X

def print_dtype_hook(module, input, output):
    assert output[0].dtype == DTYPE
    print(f"Module: {module.__class__.__name__}, Output dtype: {output[0].dtype}")

@torch.no_grad()
def cast_params_to_fp32(module, input):
    for param in module.parameters(recurse=False):
        param.data = param.data.float()
    return input

@torch.no_grad()
def cast_params_back_to_fp16(module, input, output):
    for param in module.parameters(recurse=False):
        param.data = param.data.half()
    return output


config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = MLP()

# Register hooks on the submodule that holds parameters
for module in model.modules():
    if sum(p.numel() for p in module.parameters()) > 0:
        module.register_forward_pre_hook(cast_params_to_fp32)
        module.register_forward_hook(cast_params_back_to_fp16)

model = get_peft_model(model, config0).to(TORCH_DEVICE)
inputs = torch.randn(4, 10, device=TORCH_DEVICE, dtype=DTYPE)

for module in model.modules():
    if isinstance(module, BaseTunerLayer):
        module.register_forward_hook(print_dtype_hook)

with torch.no_grad():
    with disable_lora_input_dtype_casting(model, disable=True):
        out = model(inputs)
        print(f"{out.shape=}")

Does this work?

Comment on lines 19 to 22

[[autodoc]] disable_lora_input_dtype_casting
- all
Member

Should this be specific to LoRA layers or could we check against BaseTunerLayer, too, to have a broader coverage? But okay to ship this only for LoRA and then propagate to other ones as needed.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member, Author)

Thanks for providing the example code. With this, I could verify that the context manager works as expected.

@a-r-r-o-w: As expected, your script will raise a RuntimeError as is. When I add with disable_input_dtype_casting(pipe.transformer): before the pipe call, the error goes away. As you correctly noted, I wanted to avoid patching torch::Tensor::to, as this could have unintended side-effects that could be really tough to debug.
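
For reference, the change to the reproducer is just wrapping the pipeline call (a sketch assuming the renamed helper is importable from peft.helpers, as on this branch):

sketch
from peft.helpers import disable_input_dtype_casting  # renamed helper from this branch

with disable_input_dtype_casting(pipe.transformer):
    video = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        guidance_scale=6,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(42),
    ).frames[0]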

@sayakpaul I used your example to craft a unit test. I had to make one change, namely to cast the whole model to float16 before calling forward. However, IIUC, this is analogous to what would happen in diffusers layerwise upcasting. Could you please check that the logic is sound?

> Should this be specific to LoRA layers or could we check against BaseTunerLayer, too, to have a broader coverage? But okay to ship this only for LoRA and then propagate to other ones as needed.

The thing is that each layer that wants to support this potentially needs code changes, so I wanted to keep the scope small to get this PR ready quickly. I did check LoKr and LoHa and there is no casting there, so I think we have the typical diffusers cases covered. However, I did rename the function to remove the lora_ part. This way, if we ever find the need to extend the scope, we don't need to rename the function.

@pytest.fixture
def model(self, base_model):
    config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
    model = get_peft_model(base_model, config).to(device=self.device, dtype=torch.float16)
Member

I am not sure if this casting is required, though.

We're loading the model in FP32 and keeping its params in FP16. While performing the computations, we're upcasting params to the requested dtype.

Member

Also in line with how we do the tests:
https://github.com/huggingface/diffusers/blob/aad69ac2f323734a083d66fa89197bf7d88e5a57/tests/models/test_modeling_common.py#L1365

The compute_dtype is never lower than the storage dtype.

@BenjaminBossan (Member, Author)

I tried removing dtype=torch.float16 to see what happens. The model is then in float32, as is the input, so when PEFT casts the input dtype, it stays in float32. The pre-forward hook basically does nothing, as the model is already in float32, and therefore there is no RuntimeError.

If my understanding is correct, the model has to start out in a dtype that is different from the one that is cast during the pre-forward hook to reflect the situation we have for layerwise upcasting. Is that correct?
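
To spell out the scenario, here is a rough standalone sketch of that setup (a toy model, a hand-rolled upcast hook standing in for layerwise upcasting, and the helper assumed to be importable from peft.helpers):

sketch
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from peft.helpers import disable_input_dtype_casting  # renamed helper from this branch

# The model, LoRA weights included, starts out in float16 (the "storage" dtype).
base = nn.Sequential(nn.Linear(10, 10))
model = get_peft_model(base, LoraConfig(target_modules=["0"])).half()

def upcast_params(module, args):
    # Stand-in for layerwise upcasting: cast params to float32 right before forward.
    for p in module.parameters(recurse=False):
        p.data = p.data.float()

for m in model.modules():
    if isinstance(m, nn.Linear):  # base_layer, lora_A and lora_B
        m.register_forward_pre_hook(upcast_params)

x = torch.randn(2, 10)  # input in the compute dtype, float32

# Without the context manager, PEFT would cast x to float16 (the LoRA weight's dtype at
# call time) while the hook upcasts the weight to float32, giving the dtype mismatch.
with disable_input_dtype_casting(model):
    model(x)  # the input stays float32 and matches the upcast weights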

Member

Okay that makes sense.

CUDA and CPU error messages use different wordings
@BenjaminBossan (Member, Author)

@a-r-r-o-w @sayakpaul Is this PR good to proceed? If you prefer another API to make the integration for diffusers easier, now is the time to let me know.

@sayakpaul (Member)

sayakpaul commented Jan 31, 2025 via email

@sayakpaul (Member)

@BenjaminBossan PR is good to proceed IMO. Let's ship.

@a-r-r-o-w (Member) left a comment

Thanks @BenjaminBossan, looks good to me for our use case!
