Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: CogVideox train dataset _preprocess_data crop video #9574

Merged
merged 9 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/cogvideo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen

> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training.
> You can pass `--video_reshape_mode` video cropping functionality, supporting options: ['center', 'random', 'none']. See [this](https://gist.github.com/glide-the/7658dbfd5f555be0a1a687a4139dba40) notebook for examples.

> [!IMPORTANT]
> The following settings have been tested at the time of adding CogVideoX LoRA training support:
Expand Down
91 changes: 77 additions & 14 deletions examples/cogvideo/train_cogvideox_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,24 @@
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torchvision.transforms as TT
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer

import diffusers
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
Expand Down Expand Up @@ -214,6 +218,12 @@ def get_args():
default=720,
help="All input videos are resized to this width.",
)
parser.add_argument(
"--video_reshape_mode",
type=str,
default="center",
help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
)
parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
parser.add_argument(
"--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
Expand Down Expand Up @@ -413,6 +423,7 @@ def __init__(
video_column: str = "video",
height: int = 480,
width: int = 720,
video_reshape_mode: str = "center",
fps: int = 8,
max_num_frames: int = 49,
skip_frames_start: int = 0,
Expand All @@ -429,6 +440,7 @@ def __init__(
self.video_column = video_column
self.height = height
self.width = width
self.video_reshape_mode = video_reshape_mode
self.fps = fps
self.max_num_frames = max_num_frames
self.skip_frames_start = skip_frames_start
Expand Down Expand Up @@ -532,6 +544,38 @@ def _load_dataset_from_local_path(self):

return instance_prompts, instance_videos

def _resize_for_rectangle_crop(self, arr):
image_size = self.height, self.width
reshape_mode = self.video_reshape_mode
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
arr = resize(
arr,
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
interpolation=InterpolationMode.BICUBIC,
)
else:
arr = resize(
arr,
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
interpolation=InterpolationMode.BICUBIC,
)

h, w = arr.shape[2], arr.shape[3]
arr = arr.squeeze(0)

delta_h = h - image_size[0]
delta_w = w - image_size[1]

if reshape_mode == "random" or reshape_mode == "none":
top = np.random.randint(0, delta_h + 1)
left = np.random.randint(0, delta_w + 1)
elif reshape_mode == "center":
top, left = delta_h // 2, delta_w // 2
else:
raise NotImplementedError
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
return arr

def _preprocess_data(self):
try:
import decord
Expand All @@ -542,15 +586,14 @@ def _preprocess_data(self):

decord.bridge.set_bridge("torch")

videos = []
train_transforms = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
]
progress_dataset_bar = tqdm(
range(0, len(self.instance_video_paths)),
desc="Loading progress resize and crop videos",
)
videos = []

for filename in self.instance_video_paths:
video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
video_reader = decord.VideoReader(uri=filename.as_posix())
video_num_frames = len(video_reader)

start_frame = min(self.skip_frames_start, video_num_frames)
Expand All @@ -576,10 +619,16 @@ def _preprocess_data(self):
assert (selected_num_frames - 1) % 4 == 0

# Training transforms
frames = frames.float()
frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W]
frames = (frames - 127.5) / 127.5
frames = frames.permute(0, 3, 1, 2) # [F, C, H, W]
progress_dataset_bar.set_description(
f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
)
frames = self._resize_for_rectangle_crop(frames)
videos.append(frames.contiguous()) # [F, C, H, W]
progress_dataset_bar.update(1)

progress_dataset_bar.close()
return videos


Expand Down Expand Up @@ -694,8 +743,13 @@ def log_validation(

videos = []
for _ in range(args.num_validation_videos):
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
videos.append(video)
pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])

image_np = VaeImageProcessor.pt_to_numpy(pt_images)
image_pil = VaeImageProcessor.numpy_to_pil(image_np)

videos.append(image_pil)

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
Expand Down Expand Up @@ -1171,6 +1225,7 @@ def load_model_hook(models, input_dir):
video_column=args.video_column,
height=args.height,
width=args.width,
video_reshape_mode=args.video_reshape_mode,
fps=args.fps,
max_num_frames=args.max_num_frames,
skip_frames_start=args.skip_frames_start,
Expand All @@ -1179,13 +1234,21 @@ def load_model_hook(models, input_dir):
id_token=args.id_token,
)

def encode_video(video):
def encode_video(video, bar):
bar.update(1)
video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(video).latent_dist
return latent_dist

train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
progress_encode_bar = tqdm(
range(0, len(train_dataset.instance_videos)),
desc="Loading Encode videos",
)
train_dataset.instance_videos = [
encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
]
progress_encode_bar.close()

def collate_fn(examples):
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
Expand Down
Loading