diff --git a/examples/cogvideo/README.md b/examples/cogvideo/README.md
index 398ae9543150..8e57c2b19188 100644
--- a/examples/cogvideo/README.md
+++ b/examples/cogvideo/README.md
@@ -180,6 +180,7 @@ Note that setting the `` is not necessary. From some limited experimen
 
 > [!TIP]
 > You can pass `--use_8bit_adam` to reduce the memory requirements of training.
+> You can pass `--video_reshape_mode` to enable video cropping/reshaping; supported options are ['center', 'random', 'none']. See [this](https://gist.github.com/glide-the/7658dbfd5f555be0a1a687a4139dba40) notebook for examples.
 
 > [!IMPORTANT]
 > The following settings have been tested at the time of adding CogVideoX LoRA training support:
diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index 6787c37f93a8..2fc05bf692bb 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -21,7 +21,9 @@
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
+import torchvision.transforms as TT
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -29,12 +31,14 @@
 from huggingface_hub import create_repo, upload_folder
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import resize
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
 
 import diffusers
 from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
@@ -214,6 +218,12 @@ def get_args():
         default=720,
         help="All input videos are resized to this width.",
     )
+    parser.add_argument(
+        "--video_reshape_mode",
+        type=str,
+        default="center",
+        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
+    )
     parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
     parser.add_argument(
         "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
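The new `--video_reshape_mode` argument is parsed as a free-form string and is only rejected later, when the crop helper reaches its `raise NotImplementedError` branch. A minimal sketch (standard library only, not the script's code) of how `argparse`'s `choices` would surface a typo at parse time instead:

```python
# Sketch only: an alternative way to declare the flag so invalid values fail
# at parse time. The training script as written validates inside the crop helper.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--video_reshape_mode",
    type=str,
    default="center",
    choices=["center", "random", "none"],
    help="How input videos are resized and cropped to --height x --width.",
)

args = parser.parse_args(["--video_reshape_mode", "random"])
print(args.video_reshape_mode)  # "random"; an unknown value would exit with a usage error
```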
@@ -413,6 +423,7 @@ def __init__(
         video_column: str = "video",
         height: int = 480,
         width: int = 720,
+        video_reshape_mode: str = "center",
         fps: int = 8,
         max_num_frames: int = 49,
         skip_frames_start: int = 0,
@@ -429,6 +440,7 @@ def __init__(
         self.video_column = video_column
         self.height = height
         self.width = width
+        self.video_reshape_mode = video_reshape_mode
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frames_start = skip_frames_start
@@ -532,6 +544,38 @@ def _load_dataset_from_local_path(self):
 
         return instance_prompts, instance_videos
 
+    def _resize_for_rectangle_crop(self, arr):
+        image_size = self.height, self.width
+        reshape_mode = self.video_reshape_mode
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)
+
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
+        if reshape_mode == "random" or reshape_mode == "none":
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
     def _preprocess_data(self):
         try:
             import decord
@@ -542,15 +586,14 @@ def _preprocess_data(self):
 
         decord.bridge.set_bridge("torch")
 
-        videos = []
-        train_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
-            ]
+        progress_dataset_bar = tqdm(
+            range(0, len(self.instance_video_paths)),
+            desc="Resizing and cropping videos",
         )
+        videos = []
 
         for filename in self.instance_video_paths:
-            video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
+            video_reader = decord.VideoReader(uri=filename.as_posix())
             video_num_frames = len(video_reader)
 
             start_frame = min(self.skip_frames_start, video_num_frames)
@@ -576,10 +619,16 @@ def _preprocess_data(self):
             assert (selected_num_frames - 1) % 4 == 0
 
             # Training transforms
-            frames = frames.float()
-            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
+            frames = (frames - 127.5) / 127.5
+            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
+            progress_dataset_bar.set_description(
+                f"Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
+            )
+            frames = self._resize_for_rectangle_crop(frames)
+            videos.append(frames.contiguous())  # [F, C, H, W]
+            progress_dataset_bar.update(1)
 
+        progress_dataset_bar.close()
         return videos
 
@@ -694,8 +743,13 @@ def log_validation(
 
     videos = []
     for _ in range(args.num_validation_videos):
-        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
-        videos.append(video)
+        pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
+        pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
+
+        image_np = VaeImageProcessor.pt_to_numpy(pt_images)
+        image_pil = VaeImageProcessor.numpy_to_pil(image_np)
+
+        videos.append(image_pil)
 
     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
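`_resize_for_rectangle_crop` first resizes each clip so that it fully covers the target box while preserving aspect ratio, then crops to `height` x `width` according to `video_reshape_mode` (note that 'random' and 'none' both fall back to a random crop in the current logic). A standalone sanity check that mirrors the same math on a dummy tensor, with an assumed source resolution, so the output shape can be verified without loading any videos:

```python
# Standalone shape check mirroring the resize-then-center-crop path above.
# The 536x960 source resolution is an arbitrary example.
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import crop, resize

frames = torch.rand(49, 3, 536, 960)  # [F, C, H, W]
target_h, target_w = 480, 720

# Resize so the frame covers the target box on both axes (aspect ratio preserved).
if frames.shape[3] / frames.shape[2] > target_w / target_h:
    new_size = [target_h, int(frames.shape[3] * target_h / frames.shape[2])]
else:
    new_size = [int(frames.shape[2] * target_w / frames.shape[3]), target_w]
frames = resize(frames, size=new_size, interpolation=InterpolationMode.BICUBIC)

# Center crop, i.e. the default --video_reshape_mode="center".
top = (frames.shape[2] - target_h) // 2
left = (frames.shape[3] - target_w) // 2
frames = crop(frames, top=top, left=left, height=target_h, width=target_w)

print(frames.shape)  # torch.Size([49, 3, 480, 720])
```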
@@ -1171,6 +1225,7 @@ def load_model_hook(models, input_dir):
         video_column=args.video_column,
         height=args.height,
         width=args.width,
+        video_reshape_mode=args.video_reshape_mode,
         fps=args.fps,
         max_num_frames=args.max_num_frames,
         skip_frames_start=args.skip_frames_start,
@@ -1179,13 +1234,21 @@ def load_model_hook(models, input_dir):
         id_token=args.id_token,
     )
 
-    def encode_video(video):
+    def encode_video(video, bar):
+        bar.update(1)
         video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
         video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
         latent_dist = vae.encode(video).latent_dist
         return latent_dist
 
-    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+    progress_encode_bar = tqdm(
+        range(0, len(train_dataset.instance_videos)),
+        desc="Encoding videos",
+    )
+    train_dataset.instance_videos = [
+        encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
+    ]
+    progress_encode_bar.close()
 
     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
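`encode_video` stores the VAE's latent *distribution* for each clip rather than a single sample, and `collate_fn` calls `.sample()` (scaled by `vae.config.scaling_factor`) every time a batch is built, so repeated epochs see slightly different latents for the same video. A small sketch of that pattern with a stand-in for diffusers' `DiagonalGaussianDistribution`; the shape and scaling value below are placeholders:

```python
# Sketch of the precompute-then-sample pattern; FakeLatentDist stands in for
# the object returned by vae.encode(video).latent_dist.
import torch


class FakeLatentDist:
    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        self.mean, self.std = mean, std

    def sample(self) -> torch.Tensor:
        # Reparameterized sample: mean + std * N(0, 1)
        return self.mean + self.std * torch.randn_like(self.mean)


scaling_factor = 0.7  # placeholder for vae.config.scaling_factor
latent_shape = (1, 16, 13, 60, 90)  # placeholder latent dimensions

# One distribution per training video, computed once before training starts.
instance_videos = [FakeLatentDist(torch.zeros(latent_shape), torch.ones(latent_shape))]

# Each collate call draws a fresh latent, so the same clip is not reused verbatim.
batch_a = [dist.sample() * scaling_factor for dist in instance_videos]
batch_b = [dist.sample() * scaling_factor for dist in instance_videos]
print(torch.equal(batch_a[0], batch_b[0]))  # False
```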