
Any plan to support CogVideo v1.5 LoRA and full-parameter finetuning? #85

Open · trouble-maker007 opened this issue Nov 12, 2024 · 26 comments

@trouble-maker007

No description provided.

@a-r-r-o-w
Owner

Hi, yes, we are going to run some experiments now that the diffusers adaptation is complete. I think it should already work out of the box and can be finetuned just by changing the model_id parameter in the script.

We're waiting for the authors to validate the implementation and open-source the diffusers-format weights at the moment.
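
For anyone wanting to try: below is a minimal sketch of what "changing the model_id" would amount to, assuming the diffusers-format weights land under the hub id THUDM/CogVideoX1.5-5B (an assumption at the time of writing; the weights were not yet published).

```python
import torch
from diffusers import CogVideoXTransformer3DModel

# Hedged sketch: load the 1.5 transformer the same way the training script
# loads the 1.0 one, only with a different model id. The hub id below is an
# assumption; substitute whatever the authors publish.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX1.5-5B",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
```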

@trouble-maker007
Author

trouble-maker007 commented Nov 12, 2024

@a-r-r-o-w I also believe it needs the diffusers adaptation to be complete. But I found that there are some changes in CogVideo v1.5:

  1. self.proj changed to use nn.Linear (see the sketch after this list): https://github.com/THUDM/CogVideo/blob/main/sat/dit_video_concat.py#L25
  2. What is the role of the OFS embedding? https://github.com/THUDM/CogVideo/blob/main/sat/dit_video_concat.py#L721
  3. According to what you said, if only minor changes are needed, does that mean multi-resolution training still relies on bucketing by resolution?
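
As a hedged illustration of point 1 (dimensions inferred from the error messages later in this thread, not taken from the CogVideo source):

```python
import torch.nn as nn

in_channels, inner_dim = 16, 3072
p, p_t = 2, 2  # spatial patch size; temporal patch size (1.5 only)

# CogVideoX 1.0-style patch embedding: a 2D convolution over latent frames,
# giving weight shape [3072, 16, 2, 2] (inner_dim, channels, patch, patch).
proj_v1_0 = nn.Conv2d(in_channels, inner_dim, kernel_size=p, stride=p)

# CogVideoX 1.5-style patch embedding: flatten each 2x2x2 spatio-temporal
# patch and project it linearly, giving weight shape [3072, 128],
# since 16 channels * 2 (temporal) * 2 * 2 (spatial) = 128 inputs per patch.
proj_v1_5 = nn.Linear(in_channels * p_t * p * p, inner_dim)

print(proj_v1_0.weight.shape)  # torch.Size([3072, 16, 2, 2])
print(proj_v1_5.weight.shape)  # torch.Size([3072, 128])
```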

Following the introduction in https://hub.baai.ac.cn/view/40956:

> Finally, to improve training efficiency, we have constructed an efficient training framework for diffusion models. Through various parallel computations and extreme time optimizations, we are able to quickly train longer video sequences. Drawing on the approach of NaViT, our model can simultaneously train videos of various resolutions and durations without the need for video cropping, thus avoiding biases that may arise from various cropping methods. At the same time, the model also has the capability to generate videos of arbitrary resolutions.

Does this mean it uses a method like NaViT?
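
For reference, the core NaViT idea is sequence packing. Here is a minimal illustrative sketch (not CogVideo code): token sequences from videos of different resolutions are concatenated into one row, and a mask restricts attention to each packed example.

```python
import torch

def pack_examples(token_seqs, max_len):
    # token_seqs: list of [num_tokens, dim] tensors from different videos
    dim = token_seqs[0].shape[-1]
    packed = torch.zeros(max_len, dim)
    seg_ids = torch.full((max_len,), -1, dtype=torch.long)
    offset = 0
    for i, seq in enumerate(token_seqs):
        n = seq.shape[0]
        packed[offset:offset + n] = seq
        seg_ids[offset:offset + n] = i
        offset += n
    # True where query and key belong to the same packed example (padding excluded)
    attn_mask = (seg_ids[:, None] == seg_ids[None, :]) & (seg_ids[:, None] >= 0)
    return packed, attn_mask

seqs = [torch.randn(12, 64), torch.randn(30, 64)]  # e.g. two different resolutions
packed, mask = pack_examples(seqs, max_len=48)
print(packed.shape, mask.shape)  # torch.Size([48, 64]) torch.Size([48, 48])
```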

@sayakpaul
Collaborator

I think these all correspond to the pre-training strategy and deviate a bit from the original question. I will let @zRzRzRzRzRzRzR answer them.

@Cubey42

Cubey42 commented Nov 17, 2024

It seems just changing the model_id isn't enough; I got this error when attempting it. I updated to diffusers 0.32 just in case and made sure it was set to bfloat16.

[2024-11-17 04:40:14,173] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-11-17 04:40:14,890] [INFO] [comm.py:652:init_distributed] cdb=None
[2024-11-17 04:40:14,890] [INFO] [comm.py:683:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Downloading shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 5242.88it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  8.31it/s]
Fetching 3 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 17722.41it/s]
The config attributes {'patch_bias': False, 'patch_size_t': 2} were passed to CogVideoXTransformer3DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/cubeyai/cogvideox-factory/training/cogvideox_text_to_video_lora.py", line 932, in <module>
[rank0]:     main(args)
[rank0]:   File "/home/cubeyai/cogvideox-factory/training/cogvideox_text_to_video_lora.py", line 276, in main
[rank0]:     transformer = CogVideoXTransformer3DModel.from_pretrained(
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/cubeyai/cogvideox-factory/diffusers/src/diffusers/models/modeling_utils.py", line 886, in from_pretrained
[rank0]:     accelerate.load_checkpoint_and_dispatch(
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/accelerate/big_modeling.py", line 613, in load_checkpoint_and_dispatch
[rank0]:     load_checkpoint_in_model(
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/accelerate/utils/modeling.py", line 1806, in load_checkpoint_in_model
[rank0]:     set_module_tensor_to_device(
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/accelerate/utils/modeling.py", line 286, in set_module_tensor_to_device
[rank0]:     raise ValueError(
[rank0]: ValueError: Trying to set a tensor of shape torch.Size([3072, 128]) in "weight" (which has shape torch.Size([3072, 16, 2, 2])), this looks incorrect.
[rank0]:I1117 04:40:16.571000 139790743588992 torch/_dynamo/utils.py:335] TorchDynamo compilation metrics:
[rank0]:I1117 04:40:16.571000 139790743588992 torch/_dynamo/utils.py:335] Function, Runtimes (s)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats constrain_symbol_range: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats evaluate_expr: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats _simplify_floor_div: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats _maybe_guard_rel: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats _find: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats has_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats size_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats simplify: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats _update_divisible: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats replace: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats _maybe_evaluate_static: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats get_implications: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats get_axioms: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats safe_expand: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
[rank0]:V1117 04:40:16.571000 139790743588992 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats uninteresting_files: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
E1117 04:40:17.345000 140637060907136 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 0 (pid: 3775) of binary: /home/cubeyai/cogvideox-factory/vvv/bin/python3
Traceback (most recent call last):
  File "/home/cubeyai/cogvideox-factory/vvv/bin/accelerate", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/accelerate/commands/launch.py", line 1153, in launch_command
    deepspeed_launcher(args)
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/accelerate/commands/launch.py", line 846, in deepspeed_launcher
    distrib_run.run(args)
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
training/cogvideox_text_to_video_lora.py FAILED

@TrickyBarrel

@Cubey42 same error. Let me know if you found a fix.

@Cubey42

Cubey42 commented Nov 18, 2024

> @Cubey42 same error. Let me know if you found a fix.

I tried adding patch_size_t and patch_bias to the config but was met with a different error, which is probably beyond my understanding. It sounds like the patch embed needs to be different, but I'm unsure how.

@trouble-maker007
Author

@Cubey42 @TrickyBarrel Is the diffusers adaptation complete?

@Cubey42

Cubey42 commented Nov 18, 2024

1.5-5B and 1.5-5B-I2V are on Hugging Face, with diffusers updated to 0.32:
https://huggingface.co/THUDM/CogVideoX1.5-5B
There are only a couple of changes to the config.json:
"patch_bias": false,
"patch_size_t": 2,
"use_learned_positional_embeddings": false,

@Cubey42

Cubey42 commented Nov 18, 2024

Got past the tensor size error, but am stuck again here:

***** Running training *****
  Num trainable parameters = 264241152
  Num examples = 1
  Num epochs = 3000
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient accumulation steps = 1
  Total optimization steps = 3000
Steps:   0%|                                                                       | 0/3000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/cubeyai/cogvideox-factory/training/cogvideox_text_to_video_lora.py", line 932, in <module>
    main(args)
  File "/home/cubeyai/cogvideox-factory/training/cogvideox_text_to_video_lora.py", line 700, in main
    model_output = transformer(
                   ^^^^^^^^^^^^
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
    ret_val = func(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 1899, in forward
    loss = self.module(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/diffusers/models/transformers/cogvideox_transformer_3d.py", line 476, in forward
    hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/diffusers/models/embeddings.py", line 431, in forward
    image_embeds = image_embeds.reshape(
                   ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[1, 10, 2, 48, 2, 85, 2, 16]' is invalid for input of size 5483520

EDIT: So I don't over-spam: I've found the root of it to be --load_tensors, but without it I can't actually train.
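
For what it's worth, the numbers in that reshape error are consistent with an odd latent frame count (a hedged reading, spelled out below): 5483520 elements factor as 21 latent frames of 96 x 170 with 16 channels, while the target shape assumes 10 * 2 = 20 frames.

```python
import math

numel = 1 * 21 * 96 * 170 * 16          # 5483520, the "input of size" in the error
target = (1, 10, 2, 48, 2, 85, 2, 16)   # 48*2=96, 85*2=170, but only 10*2=20 frames
print(numel, math.prod(target))         # 5483520 vs 5222400 -> the reshape fails
```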

@sayakpaul
Collaborator

This could be because we might have to adjust the data reshaping accordingly.

What I would recommend doing is:

  • Going to the dataset.py and prepare_dataset.py files and checking whether the shapes are as expected, for example the shapes of these variables (see the sketch after this list):
    image_latents, video_latents, prompt_embeds = self._preprocess_video(self.video_paths[index])
  • Then there might be other nuances regarding scaling of the latents from the VAE.
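
A minimal debugging sketch along those lines; `dataset` here is a hypothetical instance of the dataset class defined in dataset.py:

```python
# Inspect what the dataset actually produces before it reaches the transformer.
image_latents, video_latents, prompt_embeds = dataset._preprocess_video(
    dataset.video_paths[0]
)
print("image_latents:", None if image_latents is None else tuple(image_latents.shape))
print("video_latents:", tuple(video_latents.shape))  # latent frame count should be even for 1.5
print("prompt_embeds:", tuple(prompt_embeds.shape))
```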

It would take a bit of time for us to do it ourselves until the CogVideoX 1.5 PR is actually merged into diffusers.

I can understand if this is frustrating but we cannot do much as of now.

@Cubey42

Cubey42 commented Nov 18, 2024

I'll have to examine this more but thanks for the response, great work so far!

@ukhalidAI

Has anyone experienced the following error?

Traceback (most recent call last):
  File "/mnt/round-cake/home/umar/cogvideox-factory/training/cogvideox_image_to_video_lora.py", line 1004, in <module>
    main(args)
  File "/mnt/round-cake/home/umar/cogvideox-factory/training/cogvideox_image_to_video_lora.py", line 803, in main
    model_output = transformer(
                   ^^^^^^^^^^^^
  File "/mnt/round-cake/home/umar/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/round-cake/home/umar/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/round-cake/home/umar/miniconda3/lib/python3.12/site-packages/accelerate/utils/operations.py", line 823, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/round-cake/home/umar/miniconda3/lib/python3.12/site-packages/accelerate/utils/operations.py", line 811, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/round-cake/home/umar/miniconda3/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/round-cake/home/umar/miniconda3/lib/python3.12/site-packages/diffusers/models/transformers/cogvideox_transformer_3d.py", line 470, in forward
    ofs_emb = self.ofs_proj(ofs)
              ^^^^^^^^^^^^^^^^^^
  File "/mnt/round-cake/home/umar/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/round-cake/home/umar/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/round-cake/home/umar/miniconda3/lib/python3.12/site-packages/diffusers/models/embeddings.py", line 928, in forward
    t_emb = get_timestep_embedding(
            ^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/round-cake/home/umar/miniconda3/lib/python3.12/site-packages/diffusers/models/embeddings.py", line 54, in get_timestep_embedding
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
               ^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'shape'
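
The failing line is ofs_proj(ofs) with ofs=None: the 1.5 I2V transformer has an OFS embedding, so its forward needs an ofs value the way the diffusers inference pipeline supplies one. Here is a hedged sketch of that pipeline-side handling (the fill value 2.0 is taken from the diffusers CogVideoX image-to-video pipeline source; `transformer`, `latents`, and the other tensors are assumed training-loop variables):

```python
# Build an ofs input only when the model defines an ofs embedding, mirroring
# the diffusers image-to-video pipeline, then pass it to the transformer.
ofs_emb = (
    None
    if transformer.config.ofs_embed_dim is None
    else latents.new_full((1,), fill_value=2.0)
)
model_output = transformer(
    hidden_states=noisy_latents,
    encoder_hidden_states=prompt_embeds,
    timestep=timesteps,
    ofs=ofs_emb,  # without this, ofs stays None and ofs_proj(None) fails
    return_dict=False,
)[0]
```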

@lyggyhmm

Following this closely!

@sayakpaul
Collaborator

#92

@Cubey42

Cubey42 commented Nov 23, 2024

> #92

I've been trying to get this PR to work, but it seems my CUDA is not findable. Everything worked before, but now I get the error below. I've rebuilt the venv and rechecked CUDA to make sure.

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/cubeyai/cogvideox-factory/training/cogvideox_text_to_video_lora.py", line 947, in <module>
[rank0]:     main(args)
[rank0]:   File "/home/cubeyai/cogvideox-factory/training/cogvideox_text_to_video_lora.py", line 578, in main
[rank0]:     reset_memory(accelerator.device)
[rank0]:   File "/home/cubeyai/cogvideox-factory/training/utils.py", line 242, in reset_memory
[rank0]:     torch.cuda.reset_peak_memory_stats(device)
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/cuda/memory.py", line 321, in reset_peak_memory_stats
[rank0]:     device = _get_device_index(device, optional=True)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/cuda/_utils.py", line 34, in _get_device_index
[rank0]:     raise ValueError(f"Expected a cuda device, but got: {device}")
[rank0]: ValueError: Expected a cuda device, but got: cpu:0

also:
 CUDA Device Query (Runtime API) version (CUDART static linking)

Detected 1 CUDA Capable device(s)

Device 0: "NVIDIA GeForce RTX 4090"
  CUDA Driver Version / Runtime Version          12.7 / 12.6
  CUDA Capability Major/Minor version number:    8.9
  Total amount of global memory:                 24564 MBytes (25756696576 bytes)
  (128) Multiprocessors, (128) CUDA Cores/MP:    16384 CUDA Cores
  GPU Max Clock rate:                            2565 MHz (2.57 GHz)
  Memory Clock rate:                             10501 Mhz
  Memory Bus Width:                              384-bit
  L2 Cache Size:                                 75497472 bytes
  Maximum Texture Dimension Size (x,y,z)         1D=(131072), 2D=(131072, 65536), 3D=(16384, 16384, 16384)
  Maximum Layered 1D Texture Size, (num) layers  1D=(32768), 2048 layers
  Maximum Layered 2D Texture Size, (num) layers  2D=(32768, 32768), 2048 layers
  Total amount of constant memory:               65536 bytes
  Total amount of shared memory per block:       49152 bytes
  Total shared memory per multiprocessor:        102400 bytes
  Total number of registers available per block: 65536
  Warp size:                                     32
  Maximum number of threads per multiprocessor:  1536
  Maximum number of threads per block:           1024
  Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
  Max dimension size of a grid size    (x,y,z): (2147483647, 65535, 65535)
  Maximum memory pitch:                          2147483647 bytes
  Texture alignment:                             512 bytes
  Concurrent copy and kernel execution:          Yes with 5 copy engine(s)
  Run time limit on kernels:                     Yes
  Integrated GPU sharing Host Memory:            No
  Support host page-locked memory mapping:       Yes
  Alignment requirement for Surfaces:            Yes
  Device has ECC support:                        Disabled
  Device supports Unified Addressing (UVA):      Yes
  Device supports Managed Memory:                Yes
  Device supports Compute Preemption:            Yes
  Supports Cooperative Kernel Launch:            Yes
  Supports MultiDevice Co-op Kernel Launch:      No
  Device PCI Domain ID / Bus ID / location ID:   0 / 1 / 0
  Compute Mode:
     < Default (multiple host threads can use ::cudaSetDevice() with device simultaneously) >

deviceQuery, CUDA Driver = CUDART, CUDA Driver Version = 12.7, CUDA Runtime Version = 12.6, NumDevs = 1
Result = PASS

@a-r-r-o-w
Owner

That is super weird, as I'm able to run it. Are you launching the training script with accelerate and not python3? Could you also share your environment so I can try to replicate? I don't see anything in the PR that looks like it would affect the tensor device.

@Cubey42

Cubey42 commented Nov 23, 2024

> That is super weird, as I'm able to run it. Are you launching the training script with accelerate and not python3? Could you also share your environment so I can try to replicate? I don't see anything in the PR that looks like it would affect the tensor device.

I found the GPU ID was different between them, and adjusting this value resolved the issue. I'm still getting RuntimeError: shape '[1, 6, 2, 30, 2, 45, 2, 16]' is invalid for input of size 1123200, however.
venv:

accelerate==1.1.1
annotated-types==0.7.0
bitsandbytes==0.44.1
certifi==2024.8.30
charset-normalizer==3.4.0
click==8.1.7
decord==0.6.0
deepspeed==0.15.4
diffusers @ git+https://github.com/huggingface/diffusers@b5fd6f13f5434d69d919cc8cedf0b11db664cf06
docker-pycreds==0.4.0
filelock==3.16.1
fsspec==2024.10.0
gitdb==4.0.11
GitPython==3.1.43
hf_transfer==0.1.8
hjson==3.1.0
huggingface-hub==0.26.2
idna==3.10
imageio==2.36.0
imageio-ffmpeg==0.5.1
importlib_metadata==8.5.0
Jinja2==3.1.4
MarkupSafe==3.0.2
mpmath==1.3.0
msgpack==1.1.0
networkx==3.4.2
ninja==1.11.1.1
numpy==2.1.3
nvidia-cublas-cu12==12.4.2.65
nvidia-cuda-cupti-cu12==12.4.99
nvidia-cuda-nvrtc-cu12==12.4.99
nvidia-cuda-runtime-cu12==12.4.99
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.0.44
nvidia-curand-cu12==10.3.5.119
nvidia-cusolver-cu12==11.6.0.99
nvidia-cusparse-cu12==12.3.0.142
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.99
nvidia-nvtx-cu12==12.4.99
opencv-python==4.10.0.84
packaging==24.2
pandas==2.2.3
peft==0.13.2
pillow==11.0.0
platformdirs==4.3.6
protobuf==5.28.3
psutil==6.1.0
py-cpuinfo==9.0.0
pydantic==2.10.1
pydantic_core==2.27.1
python-dateutil==2.9.0.post0
pytz==2024.2
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
safetensors==0.4.5
sentencepiece==0.2.0
sentry-sdk==2.19.0
setproctitle==1.3.4
setuptools==75.6.0
six==1.16.0
smmap==5.0.1
sympy==1.13.1
tokenizers==0.20.3
torch==2.4.1+cu124
torchao==0.6.1
torchaudio==2.4.1+cu124
torchvision==0.19.1
tqdm==4.67.0
transformers==4.46.3
triton==3.0.0
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.2.3
wandb==0.18.7
zipp==3.21.0

If I do I2V instead, I get the following:

 [rank0]: Traceback (most recent call last):
[rank0]:   File "/home/cubeyai/cogvideox-factory/training/cogvideox_image_to_video_lora.py", line 1005, in <module>
[rank0]:     main(args)
[rank0]:   File "/home/cubeyai/cogvideox-factory/training/cogvideox_image_to_video_lora.py", line 804, in main
[rank0]:     model_output = transformer(
[rank0]:                    ^^^^^^^^^^^^
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 1899, in forward
[rank0]:     loss = self.module(*inputs, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/diffusers/models/transformers/cogvideox_transformer_3d.py", line 470, in forward
[rank0]:     ofs_emb = self.ofs_proj(ofs)
[rank0]:               ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/diffusers/models/embeddings.py", line 928, in forward
[rank0]:     t_emb = get_timestep_embedding(
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/cubeyai/cogvideox-factory/vvv/lib/python3.12/site-packages/diffusers/models/embeddings.py", line 54, in get_timestep_embedding
[rank0]:     assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
[rank0]:                ^^^^^^^^^^^^^^^
[rank0]: AttributeError: 'NoneType' object has no attribute 'shape'

@Cubey42

Cubey42 commented Nov 24, 2024

Okay, with the help of a fellow Discord user, it seems I needed to use 56 frames instead of 49 to line things up.

@a-r-r-o-w
Owner

Sorry for the late response, I got busy with something. Yes, Cog 1.5 requires the number of latent frames to be divisible by 2. So for 49 sample frames, you get 13 latent frames (calculated as (x - 1) / 4 + 1). For 53, it is 14 latent frames, which works.
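
A quick check of that rule (a sketch, not repo code): the VAE compresses time 4x, so latent frames = (x - 1) / 4 + 1, and 1.5's temporal patch size of 2 then requires the result to be even.

```python
def latent_frames(frames: int) -> int:
    # CogVideoX VAE: 4x temporal compression, plus the first frame
    return (frames - 1) // 4 + 1

for f in (49, 53, 81, 85):
    lf = latent_frames(f)
    print(f, "->", lf, "ok" if lf % 2 == 0 else "not divisible by 2")
# 49 -> 13 not divisible by 2
# 53 -> 14 ok
# 81 -> 21 not divisible by 2
# 85 -> 22 ok
```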

Btw, would love to join any servers you're discussing on to gather feedback! We have many improvements, with further memory savings, planned in the coming weeks as we try to find time outside of normal work.

@sayakpaul
Collaborator

#92 added support.

Feel free to share your results, too. Do you think this issue can be closed now?

@Cubey42

Cubey42 commented Nov 24, 2024

> Sorry for the late response, I got busy with something. Yes, Cog 1.5 requires the number of latent frames to be divisible by 2. So for 49 sample frames, you get 13 latent frames (calculated as (x - 1) / 4 + 1). For 53, it is 14 latent frames, which works.
>
> Btw, would love to join any servers you're discussing on to gather feedback! We have many improvements, with further memory savings, planned in the coming weeks as we try to find time outside of normal work.

No problem, this is the server (Banodoco): https://discord.gg/qjSMmUKG. We have a wide variety of users and it's mainly about animation/video AI, with channels for CogVideoX along with many other models. More than welcome to join us!

@Yukinoshita-Yukinoe

> > Sorry for the late response, I got busy with something. Yes, Cog 1.5 requires the number of latent frames to be divisible by 2. So for 49 sample frames, you get 13 latent frames (calculated as (x - 1) / 4 + 1). For 53, it is 14 latent frames, which works.
> >
> > Btw, would love to join any servers you're discussing on to gather feedback! We have many improvements, with further memory savings, planned in the coming weeks as we try to find time outside of normal work.
>
> No problem, this is the server (Banodoco): https://discord.gg/qjSMmUKG. We have a wide variety of users and it's mainly about animation/video AI, with channels for CogVideoX along with many other models. More than welcome to join us!

The invite link seems to be invalid.

@Cubey42

Cubey42 commented Nov 24, 2024

> > > Sorry for the late response, I got busy with something. Yes, Cog 1.5 requires the number of latent frames to be divisible by 2. So for 49 sample frames, you get 13 latent frames (calculated as (x - 1) / 4 + 1). For 53, it is 14 latent frames, which works.
> > >
> > > Btw, would love to join any servers you're discussing on to gather feedback! We have many improvements, with further memory savings, planned in the coming weeks as we try to find time outside of normal work.
> >
> > No problem, this is the server (Banodoco): https://discord.gg/qjSMmUKG. We have a wide variety of users and it's mainly about animation/video AI, with channels for CogVideoX along with many other models. More than welcome to join us!
>
> The invite link seems to be invalid.

Try this: https://discord.gg/P9EbDYFT. If not, search for Banodoco on Google, or try https://discord.gg/eKQm3uHKx2

@lyggyhmm

> Okay, with the help of a fellow Discord user, it seems I needed to use 56 frames instead of 49 to line things up.

Hi, have you tried training with 56 frames in I2V? I used diffusers 0.32 but still get "Timesteps should be a 1d-array".

@ukhalidAI

ukhalidAI commented Nov 25, 2024

> Okay, with the help of a fellow Discord user, it seems I needed to use 56 frames instead of 49 to line things up.

I still have the same error even after setting max_num_frames=56 for I2V. Any suggestions to solve it?

@lijain

lijain commented Nov 26, 2024

[WeCom screenshot]

Modifying the max num frames as above does not solve the problem; I find that ofs is None.
