Hunyuan Video does not support batch size > 1 #10453

Closed
Nerogar opened this issue Jan 4, 2025 · 2 comments · Fixed by #10454
Labels: bug (Something isn't working)

Nerogar (Contributor) commented Jan 4, 2025

Describe the bug

The HunyuanVideoPipeline (and, I believe, the model itself) does not support execution with a batch size > 1: there are shape mismatches in the attention calculation. Setting the batch size to 2 results in the error shown in the Logs section below.

Reproduction

This example is taken directly from the model card at https://huggingface.co/hunyuanvideo-community/HunyuanVideo. The only change is the added line num_videos_per_prompt=2.

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)

# Enable memory savings
pipe.vae.enable_tiling()
pipe.enable_model_cpu_offload()

output = pipe(
    prompt="A cat walks on the grass, realistic",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
    num_videos_per_prompt=2, # <--- This is the only line I changed
).frames[0]
export_to_video(output, "output.mp4", fps=15)

Logs

File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\src\diffusers\src\diffusers\pipelines\hunyuan_video\pipeline_hunyuan_video.py", line 647, in __call__
    noise_pred = self.transformer(
  File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\accelerate\hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\src\diffusers\src\diffusers\models\transformers\transformer_hunyuan_video.py", line 763, in forward
    hidden_states, encoder_hidden_states = block(
  File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\src\diffusers\src\diffusers\models\transformers\transformer_hunyuan_video.py", line 478, in forward
    attn_output, context_attn_output = self.attn(
  File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\src\diffusers\src\diffusers\models\attention_processor.py", line 588, in forward
    return self.processor(
  File "H:\stable-diffusion\one-trainer\venv\src\diffusers\src\diffusers\models\transformers\transformer_hunyuan_video.py", line 117, in __call__
    hidden_states = F.scaled_dot_product_attention(
RuntimeError: The expanded size of the tensor (24) must match the existing size (2) at non-singleton dimension 1.  Target sizes: [2, 24, 10496, 10496].  Tensor sizes: [2, 10496, 10496]
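
For context, F.scaled_dot_product_attention expects attn_mask to be broadcastable to [batch, heads, q_len, kv_len]. A 3-D mask of shape [batch, q_len, kv_len] only happens to broadcast when batch == 1, because right-aligned broadcasting pads it to [1, batch, q_len, kv_len] and the original batch dimension lands on the head axis. A minimal sketch with small, made-up shapes (not the pipeline's actual tensors) reproduces an equivalent error:

import torch
import torch.nn.functional as F

# Stand-in shapes; HunyuanVideo uses 24 attention heads and a much longer sequence.
batch, heads, seq_len, head_dim = 2, 24, 8, 16

query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# A 3-D mask shaped [batch, q_len, kv_len], like the one reaching the processor
attention_mask = torch.ones(batch, seq_len, seq_len, dtype=torch.bool)

try:
    F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
except RuntimeError as e:
    # Fails with a broadcasting error analogous to the one in the traceback above
    print(e)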


### System Info

- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Windows-10-10.0.22631-SP0
- Running on Google Colab?: No
- Python version: 3.10.8
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.26.2
- Transformers version: 4.47.0
- Accelerate version: 1.0.1
- PEFT version: not installed
- Bitsandbytes version: 0.44.1
- Safetensors version: 0.4.5
- xFormers version: 0.0.28.post3
- Accelerator: NVIDIA RTX A5000, 24564 MiB
- Using GPU in script?: NVIDIA RTX A5000
- Using distributed or parallel set-up in script?: no


### Who can help?

_No response_
@Nerogar Nerogar added the bug Something isn't working label Jan 4, 2025
@a-r-r-o-w a-r-r-o-w self-assigned this Jan 4, 2025
a-r-r-o-w (Member) commented:

Sorry, this was untested when the PR was merged. Will take a look

Nerogar (Contributor, Author) commented Jan 4, 2025

Changing the attention mask shape in the attention processor (the F.scaled_dot_product_attention call in transformer_hunyuan_video.py from the traceback above) like this seems to fix the issue.

before

        # 5. Attention
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

after

        # 5. Attention
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask.unsqueeze(1), dropout_p=0.0, is_causal=False
        )

Not sure if this is the best approach, or if the attention mask shape should already be changed before calling the attention processor.
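
For what it's worth, the unsqueeze makes the mask [batch, 1, q_len, kv_len], whose singleton head dimension broadcasts over all heads for any batch size. A self-contained sketch with the same small, made-up shapes as above (again, not the pipeline's actual tensors):

import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 24, 8, 16
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)
attention_mask = torch.ones(batch, seq_len, seq_len, dtype=torch.bool)

# [batch, q_len, kv_len] -> [batch, 1, q_len, kv_len]: the singleton head dim
# broadcasts against `heads`, so this works for batch == 1 and batch > 1 alike.
out = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask.unsqueeze(1), dropout_p=0.0, is_causal=False
)
print(out.shape)  # torch.Size([2, 24, 8, 16])

Whether that reshape belongs in the processor or should happen before the processor is called is the open question above.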
