
HunyuanVideoPipeline produces NaN values #10314

Closed
smedegaard opened this issue Dec 20, 2024 · 19 comments
Labels
bug Something isn't working

Comments

@smedegaard

Describe the bug

Running diffusers.utils.export_to_video() on the output of HunyuanVideoPipeline results in

/app/diffusers/src/diffusers/image_processor.py:147: RuntimeWarning: invalid value encountered in cast
  images = (images * 255).round().astype("uint8")

After adding some checks to numpy_to_pil() in image_processor.py, I have confirmed that the output contains NaN values:

  File "/app/pipeline.py", line 37, in <module>
    output = pipe(
             ^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py", line 677, in __call__
    video = self.video_processor.postprocess_video(video, output_type=output_type)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/video_processor.py", line 103, in postprocess_video
    batch_output = self.postprocess(batch_vid, output_type)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/image_processor.py", line 823, in postprocess
    return self.numpy_to_pil(image)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/image_processor.py", line 158, in numpy_to_pil
    raise ValueError("Image array contains NaN values")
ValueError: Image array contains NaN values

Reproduction

import os

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from torch.profiler import ProfilerActivity, profile, record_function

os.environ["TOKENIZERS_PARALLELISM"] = "false"


MODEL_ID = "tencent/HunyuanVideo"
PROMPT = "a whale shark floating through outer space"
profile_dir = os.environ.get("PROFILE_OUT_PATH", "./")
profile_file_name = os.environ.get("PROFILE_OUT_FILE_NAME", "hunyuan_profile.json")
profile_path = os.path.join(profile_dir, profile_file_name)

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    MODEL_ID, subfolder="transformer", torch_dtype=torch.float16, revision="refs/pr/18"
)
pipe = HunyuanVideoPipeline.from_pretrained(
    MODEL_ID, transformer=transformer, torch_dtype=torch.float16, revision="refs/pr/18"
)
pipe.vae.enable_tiling()
pipe.to("cuda")

print(f"\nStarting profiling of {MODEL_ID}\n")

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True
) as prof:
    with record_function("model_inference"):
        output = pipe(
            prompt=PROMPT,
            height=320,
            width=512,
            num_frames=61,
            num_inference_steps=30,
        )

# Export and print profiling results
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace(profile_path)
print(f"{profile_file_name} ready")

# export video
video = output.frames[0]

print(" ====== raw video matrix =====")
print(video)
print()

print(" ====== Exporting video =====")
export_to_video(video, "hunyuan_example.mp4", fps=15)
print()
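For reference, the NaNs can also be surfaced without patching diffusers by requesting numpy output from the pipeline and checking it directly. A small diagnostic sketch, reusing pipe and PROMPT from the script above:

import numpy as np

check = pipe(
    prompt=PROMPT,
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
    output_type="np",
)
frames = check.frames[0]  # float array in [0, 1], shape (frames, height, width, channels)
print("NaNs in output:", np.isnan(frames).any())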

Logs

No response

System Info

GPU: AMD MI300X

ARG BASE_IMAGE=python:3.11-slim
FROM ${BASE_IMAGE}

ENV PYTHONUNBUFFERED=true
ENV CUDA_VISIBLE_DEVICES=0

WORKDIR /app

# Install tools
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    git \
    libgl1-mesa-glx \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libfontconfig1 \
    ffmpeg \
    build-essential && \
    rm -rf /var/lib/apt/lists/*

# install ROCm pytorch and python dependencies
RUN python -m pip install --no-cache-dir \
    torch torchvision --index-url https://download.pytorch.org/whl/rocm6.2 && \
    python -m pip install --no-cache-dir \
    accelerate transformers sentencepiece protobuf opencv-python imageio imageio-ffmpeg

# install diffusers from source to include newest pipeline classes
COPY diffusers diffusers
RUN cd diffusers && \
    python -m pip install -e .

# Copy the profiling script
ARG PIPELINE_FILE
COPY ${PIPELINE_FILE} pipeline.py

# run the script
CMD ["python", "pipeline.py"]

Who can help?

@DN6 @a-r-r-o-w

@smedegaard smedegaard added the bug Something isn't working label Dec 20, 2024
@a-r-r-o-w
Member

Transformer needs to be in bfloat16. Could you try with that?

@smedegaard
Author

> Transformer needs to be in bfloat16. Could you try with that?

Same result @a-r-r-o-w

@hlky
Collaborator

hlky commented Dec 20, 2024

On CUDA we've seen the same issue when not using the latest PyTorch. From torch torchvision --index-url https://download.pytorch.org/whl/rocm6.2 it looks like you should have either 2.5.0 or 2.5.1. If it's 2.5.0, can you try 2.5.1? If it's 2.5.1, can you try nightly?
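A quick way to confirm which build is actually installed (torch.version.hip is set on ROCm wheels and None on CUDA ones):

python -c "import torch; print(torch.__version__, torch.version.hip)"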

@smedegaard
Author

Thanks for the suggestion @hlky, I'll try some more combinations.

| ROCm | PyTorch | Result |
| --- | --- | --- |
| 6.4 | 2.6.0a0+gitb7a45db | ⛔ NaN values in output |

@tanshuai0219

Same, I also get NaN values.

@a-r-r-o-w
Member

@tanshuai0219 Is this on a CUDA GPU or MPS/ROCm? I'm unable to replicate it when using the transformer in bfloat16 on torch >= 2.5. I can try some previous versions of PyTorch to try and make it work for CUDA devices, but for other devices, I'm afraid we will need help from the community to make it work.

@tanshuai0219

> @tanshuai0219 Is this on a CUDA GPU or MPS/ROCm? I'm unable to replicate it when using the transformer in bfloat16 on torch >= 2.5. I can try some previous versions of PyTorch to try and make it work for CUDA devices, but for other devices, I'm afraid we will need help from the community to make it work.

Yes, it's on a CUDA GPU, CUDA version 12.4.
I pulled the latest version of diffusers and installed it with pip install -e .

Then I ran:

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

import numpy as np

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
pipe.vae.enable_tiling()
pipe.to("cuda")

output = pipe(
    prompt="A cat walks on the grass, realistic",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
).frames[0]

print(np.array(output[0]))

export_to_video(output, "output.mp4", fps=15)


np.array(output[0]) is all zeros, and the saved output.mp4 is all black:

output.mp4
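A quick way to confirm the frames are uniformly zero rather than NaN at this point (by this stage numpy_to_pil has already converted them to uint8; a small diagnostic sketch, not from the original comment):

import numpy as np

arr = np.stack([np.asarray(frame) for frame in output])  # (num_frames, H, W, C), uint8
print(arr.min(), arr.max())  # an all-black video prints: 0 0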

@a-r-r-o-w
Member

Can you share the output of diffusers-cli env? I verified once more that it works for me. I'll take a look at other torch versions soon. Here's my output:

- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): 0.8.5 (cpu)
- Jax version: 0.4.31
- JaxLib version: 0.4.31
- Huggingface_hub version: 0.26.2
- Transformers version: 4.48.0.dev0
- Accelerate version: 1.1.0.dev0
- PEFT version: 0.13.3.dev0
- Bitsandbytes version: 0.43.3
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA DGX Display, 4096 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
output.mp4

@tanshuai0219
> Can you share the output of diffusers-cli env? I verified once more that it works for me. I'll take a look at other torch versions soon.

Here is mine:

- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.4.0+cu124 (True)
- Huggingface_hub version: 0.24.2
- Transformers version: 4.46.3
- Accelerate version: 0.33.0
- PEFT version: 0.12.0
- Bitsandbytes version: 0.43.2
- Safetensors version: 0.4.3
- xFormers version: 0.0.27
- Accelerator: NVIDIA A100-80GB, 81920 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?:

@tanshuai0219
If I upgrade transformers from 4.46.3 to 4.48.0.dev0, I get an error like:

RuntimeError: Failed to import diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video because of the following error (look up to see its traceback): Failed to import diffusers.loaders.lora_pipeline because of the following error (look up to see its traceback): cannot import name 'shard_checkpoint' from 'transformers.modeling_utils'
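That particular failure is a version incompatibility rather than the NaN bug: as the message says, shard_checkpoint no longer exists in transformers.modeling_utils, while the older diffusers checkout still imports it. A hedged suggestion (not from the thread) that sidesteps it either way:

pip install "transformers<4.48"
# or refresh the diffusers source install so it no longer imports shard_checkpoint
cd diffusers && git pull && pip install -e .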

@a-r-r-o-w
Copy link
Member

I would recommend trying to replicate in a clean environment if you are currently in a broken state. At least 5 people have confirmed so far that upgrading torch to 2.5.1 no longer leads to black videos. We are still unsure why it doesn't work on 2.4 or below.
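For example, a clean setup along these lines (a sketch, assuming CUDA 12.4 wheels; adjust the index URL for ROCm):

python -m venv fresh-env && source fresh-env/bin/activate
pip install torch==2.5.1 torchvision --index-url https://download.pytorch.org/whl/cu124
pip install git+https://github.com/huggingface/diffusers transformers accelerate sentencepiece protobuf imageio imageio-ffmpeg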

@smedegaard
Author

smedegaard commented Jan 7, 2025

I was not able to get a usable output with PyTorch 2.5.1 either.

| ROCm | PyTorch | Result |
| --- | --- | --- |
| 6.3.0.60300-39~22.04 | 2.5.1+gitabbfe77 | ⛔ NaN values in output |
| 6.4 | 2.6.0a0+gitb7a45db | ⛔ NaN values in output |

Hardware: AMD Instinct MI300X

- 🤗 Diffusers version: 0.32.0.dev0
- Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.15
- PyTorch version (GPU?): 2.5.1+gitabbfe77 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.27.1
- Transformers version: 4.47.1
- Accelerate version: 1.2.1
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.5.0
- xFormers version: not installed
- Accelerator: NA
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
`pip freeze`

```bash
absl-py==2.1.0 accelerate==1.2.1 aiohappyeyeballs==2.4.4 aiohttp==3.11.9 aiosignal==1.3.1 amdsmi @ file:///opt/rocm-6.3.0/share/amd_smi apex @ file:///var/lib/jenkins/apex asgiref==3.8.1 astunparse==1.6.3 async-timeout==5.0.1 attrs==24.2.0 audioread==3.0.1 autocommand==2.2.2 backports.tarfile==1.2.0 boto3==1.19.12 botocore==1.22.12 cachetools==5.5.0 certifi==2024.8.30 cffi==1.17.1 charset-normalizer==3.4.0 click==8.1.7 colorama==0.4.6 coremltools==5.0b5 cryptography==44.0.0 Cython==3.0.11 decorator==5.1.1 Deprecated==1.2.15 -e git+https://github.com/huggingface/diffusers.git@1826a1e#egg=diffusers dill==0.3.7 Django==5.1.4 exceptiongroup==1.2.2 execnet==2.1.1 expecttest==0.2.1 fbscribelogger==0.1.6 filelock==3.16.1 flatbuffers==2.0 frozenlist==1.5.0 fsspec==2024.10.0 future==1.0.0 geojson==2.5.0 ghstack==0.8.0 google-auth==2.36.0 google-auth-oauthlib==1.0.0 grpcio==1.68.1 huggingface-hub==0.27.1 hypothesis==5.35.1 idna==3.10 image==1.5.33 imageio==2.36.1 imageio-ffmpeg==0.5.1 importlib_metadata==8.0.0 importlib_resources==6.4.0 inflect==7.3.1 iniconfig==2.0.0 jaraco.collections==5.1.0 jaraco.context==5.3.0 jaraco.functools==4.0.1 jaraco.text==3.12.1 Jinja2==3.1.4 jmespath==0.10.0 joblib==1.4.2 junitparser==2.1.1 lark==0.12.0 lazy_loader==0.4 librosa==0.10.2.post1 lintrunner==0.12.5 llvmlite==0.38.1 lxml==5.0.0 Markdown==3.7 MarkupSafe==3.0.2 ml_dtypes==0.5.0 more-itertools==10.3.0 mpmath==1.3.0 msgpack==1.1.0 multidict==6.1.0 mypy==1.10.0 mypy-extensions==1.0.0 networkx==2.8.8 numba==0.55.2 numpy==1.21.2 oauthlib==3.2.2 onnx==1.16.1 onnxscript==0.1.0.dev20240817 opencv-python==4.10.0.84 opt-einsum==3.3.0 optionloop==1.0.7 optree==0.12.1 packaging==24.2 pillow==10.3.0 platformdirs==4.3.6 pluggy==1.5.0 ply==3.11 pooch==1.8.2 propcache==0.2.1 protobuf==3.20.2 psutil==6.1.0 pyasn1==0.6.1 pyasn1_modules==0.4.1 pycparser==2.22 PyGithub==2.3.0 Pygments==2.15.0 PyJWT==2.10.1 PyNaCl==1.5.0 pytest==7.3.2 pytest-cpp==2.3.0 pytest-flakefinder==1.1.0 pytest-rerunfailures==14.0 pytest-xdist==3.3.1 python-dateutil==2.9.0.post0 PyWavelets==1.4.1 PyYAML @ file:///croot/pyyaml_1728657952215/work redis==5.2.0 regex==2024.11.6 requests==2.32.3 requests-oauthlib==2.0.0 rockset==1.0.3 rsa==4.9 s3transfer==0.5.2 safetensors==0.5.0 scikit-image==0.22.0 scikit-learn==1.5.2 scipy==1.10.1 sentencepiece==0.2.0 six @ file:///tmp/build/80754af9/six_1644875935023/work sortedcontainers==2.4.0 soundfile==0.12.1 soxr==0.5.0.post1 sqlparse==0.5.2 sympy==1.13.1 tb-nightly==2.13.0a20230426 tensorboard==2.13.0 tensorboard-data-server==0.7.2 threadpoolctl==3.5.0 thriftpy2==0.5.2 tifffile==2024.9.20 tlparse==0.3.7 tokenizers==0.21.0 tomli==2.2.1 torch @ file:///var/lib/jenkins/pytorch/dist/torch-2.5.1%2Bgitabbfe77-cp310-cp310-linux_x86_64.whl#sha256=b5fecdb1e666ea7de99d5ca164c7dbe22f341f4bd07a288beeeddca65f2232be torchvision==0.20.0a0+afc54f7 tqdm==4.67.1 transformers==4.47.1
# Editable install with no version control (triton==3.1.0)
-e /var/lib/jenkins/triton/python
typeguard==4.3.0 typing_extensions==4.12.2 unittest-xml-reporting==3.2.0 urllib3==1.26.20 Werkzeug==3.1.3 wrapt==1.17.0 xdoctest==1.1.0 yarl==1.18.3 z3-solver==4.12.2.0 zipp==3.19.2
```

@smedegaard
Author

Updated to Diffusers version 0.33.0.dev0.
Same result as above.

@smedegaard
Author

Tested with model_id="hunyuanvideo-community/HunyuanVideo"

| Result | ROCm | PyTorch | Diffusers |
| --- | --- | --- | --- |
| ⛔ NaN values in output | 6.3.0.60300-39~22.04 | 2.5.1+gitabbfe77 | 0.33.0.dev0 |

@hlky
Collaborator

hlky commented Jan 7, 2025

@smedegaard Could you test with these changes?

diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index 6cb97af9..84610471 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -713,15 +713,15 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         condition_sequence_length = encoder_hidden_states.shape[1]
         sequence_length = latent_sequence_length + condition_sequence_length
         attention_mask = torch.zeros(
-            batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
-        )  # [B, N, N]
+            batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
+        )  # [B, N]
 
         effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
         effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
 
         for i in range(batch_size):
-            attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
-        attention_mask = attention_mask.unsqueeze(1)  # [B, 1, N, N], for broadcasting across attention heads
+            attention_mask[i, : effective_sequence_length[i]] = True
+        attention_mask = attention_mask.unsqueeze(1)  # [B, 1, N], for broadcasting across attention heads
 
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:

I was able to generate successfully on CUDA with PyTorch 2.4.1, which is also known to produce NaN.

output.mp4

cc @a-r-r-o-w
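For intuition on why the original [B, 1, N, N] mask construction can produce NaN: a padded query row attends to nothing, so the softmax inside attention runs over a row that is entirely -inf. A minimal illustrative sketch (not code from the PR):

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 1, 4, 8)  # (batch, heads, seq_len, head_dim)

# Boolean [B, 1, N, N] mask where the last query row attends to nothing,
# mimicking a padded position under the old mask construction
mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)
mask[..., 3, :] = False

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out[..., 3, :])  # NaN: softmax over a row that is entirely -inf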

There's also a small performance gain.

Code
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16).to("cuda")
pipe.vae.enable_tiling()

output = pipe(
    prompt="A cat walks on the grass, realistic",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=15)
| 2.4.1 with fix | 2.5.1 | 2.5.1 with fix |
| --- | --- | --- |
| 30/30 [01:56<00:00, 3.88s/it] | 30/30 [02:04<00:00, 4.16s/it] | [01:56<00:00, 3.89s/it] |

@smedegaard
Author

@hlky Thanks for the tip. I'm afraid it didn't fix the problem for me.

I added your suggested changes to diffusers/models/transformers/transformer_hunyuan_video.py but I still get NaN in the output.

@smedegaard
Author

smedegaard commented Jan 7, 2025

For clarity, here are my changes to numpy_to_pil() in src/diffusers/image_processor.py. This is where I detect the NaN values in the output.

@staticmethod
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
    """
    Convert numpy image array(s) to PIL images with validation.

    Args:
        images (np.ndarray): Image array in range [0, 1] with shape (N, H, W, C) or (H, W, C)

    Returns:
        List[PIL.Image.Image]: List of PIL images

    Raises:
        ValueError: If images contain invalid values
        TypeError: If input is not a numpy array or has invalid shape/type
    """
    if not isinstance(images, np.ndarray):
        raise TypeError(f"Expected numpy array, got {type(images)}")

    # Handle single image case
    if images.ndim == 3:
        images = images[None, ...]
    elif images.ndim != 4:
        raise ValueError(f"Expected 3D or 4D array, got {images.ndim}D")

    # Check for NaN/inf before any operations
    if np.any(np.isnan(images)):
        raise ValueError("Image array contains NaN values")
    if np.any(np.isinf(images)):
        raise ValueError("Image array contains infinite values")

    # Check value range
    min_val = np.min(images)
    max_val = np.max(images)
    if min_val < 0 or max_val > 1:
        raise ValueError(
            f"Image values must be in range [0, 1], got range [{min_val}, {max_val}]"
        )

    try:
        # Convert to uint8 (no NaN check needed afterwards: uint8 has no NaN representation)
        images_uint8 = (images * 255).round().astype("uint8")
    except Exception as e:
        raise ValueError(f"Failed to convert to uint8: {str(e)}")

    try:
        # Convert to PIL images
        if images.shape[-1] == 1:
            pil_images = [PIL.Image.fromarray(image.squeeze(), mode="L") for image in images_uint8]
        else:
            pil_images = [PIL.Image.fromarray(image) for image in images_uint8]
        return pil_images
    except Exception as e:
        raise ValueError(f"Failed to create PIL images: {str(e)}")

@hlky
Collaborator

hlky commented Jan 7, 2025

Could you double-check with PR #10482? I was able to generate the following on an AMD Instinct MI300X using the PR branch.

output.10.mp4
output.9.mp4

@smedegaard
Author

Thanks @hlky and @a-r-r-o-w, we have confirmed on our side that it produces proper video frames after the recent patch.
