diff --git a/.gitignore b/.gitignore
index 3476c46f..683c65ee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,3 +165,4 @@ pretrained
 samples
 cache_dir
 test_outputs
+datasets
diff --git a/eval/pab/common_metrics/batch_eval.py b/eval/pab/common_metrics/batch_eval.py
new file mode 100644
index 00000000..16962350
--- /dev/null
+++ b/eval/pab/common_metrics/batch_eval.py
@@ -0,0 +1,205 @@
+import argparse
+import os
+
+import imageio
+import torch
+import torchvision.transforms.functional as F
+import tqdm
+from calculate_lpips import calculate_lpips
+from calculate_psnr import calculate_psnr
+from calculate_ssim import calculate_ssim
+
+
+def load_video(video_path):
+    """
+    Load a video from the given path and convert it to a PyTorch tensor.
+    """
+    # Read the video using imageio
+    reader = imageio.get_reader(video_path, "ffmpeg")
+
+    # Extract frames and convert to a list of tensors
+    frames = []
+    for frame in reader:
+        # Convert the frame to a tensor and permute the dimensions to match (C, H, W)
+        frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
+        frames.append(frame_tensor)
+
+    # Stack the list of tensors into a single tensor with shape (T, C, H, W)
+    video_tensor = torch.stack(frames)
+
+    return video_tensor
+
+
+def resize_video(video, target_height, target_width):
+    resized_frames = []
+    for frame in video:
+        resized_frame = F.resize(frame, [target_height, target_width])
+        resized_frames.append(resized_frame)
+    return torch.stack(resized_frames)
+
+
+def resize_gt_video(gt_video, gen_video):
+    T_gen, _, H_gen, W_gen = gen_video.shape
+    T_eval, _, H_eval, W_eval = gt_video.shape
+
+    if T_eval < T_gen:
+        raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")
+
+    if H_eval < H_gen or W_eval < W_gen:
+        # Upscale with a single scale factor so the aspect ratio is preserved and
+        # both dimensions become at least as large as the generated video's.
+        scale = max(H_gen / H_eval, W_gen / W_eval)
+        resize_height = max(H_gen, int(round(H_eval * scale)))
+        resize_width = max(W_gen, int(round(W_eval * scale)))
+        gt_video = resize_video(gt_video, resize_height, resize_width)
+        # Recalculate the dimensions
+        T_eval, _, H_eval, W_eval = gt_video.shape
+
+    # Center crop to the generated resolution and truncate to T_gen frames
+    start_h = (H_eval - H_gen) // 2
+    start_w = (W_eval - W_gen) // 2
+    cropped_video = gt_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]
+
+    return cropped_video
+
+
+def get_video_ids(gt_video_dirs, gen_video_dirs):
+    video_ids = []
+    for f in os.listdir(gt_video_dirs[0]):
+        if f.endswith(".mp4"):
+            video_ids.append(f.replace(".mp4", ""))
+    video_ids.sort()
+
+    for video_dir in gt_video_dirs + gen_video_dirs:
+        tmp_video_ids = []
+        for f in os.listdir(video_dir):
+            if f.endswith(".mp4"):
+                tmp_video_ids.append(f.replace(".mp4", ""))
+        tmp_video_ids.sort()
+        if tmp_video_ids != video_ids:
+            raise ValueError(f"Video IDs in {video_dir} do not match those in {gt_video_dirs[0]}.")
+    return video_ids
+
+
+def get_videos(video_ids, gt_video_dirs, gen_video_dirs):
+    gt_videos = {}
+    generated_videos = {}
+
+    for gt_video_dir in gt_video_dirs:
+        tmp_gt_videos_tensor = []
+        for video_id in video_ids:
+            gt_video = load_video(os.path.join(gt_video_dir, f"{video_id}.mp4"))
+            tmp_gt_videos_tensor.append(gt_video)
+        gt_videos[gt_video_dir] = tmp_gt_videos_tensor
+
+    for generated_video_dir in gen_video_dirs:
+        tmp_generated_videos_tensor = []
+        for video_id in video_ids:
+            generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.mp4"))
+            tmp_generated_videos_tensor.append(generated_video)
+        generated_videos[generated_video_dir] = tmp_generated_videos_tensor
+
+    return gt_videos, generated_videos
+
+
+def print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs):
+    out_str = ""
+
+    for gt_video_dir in gt_video_dirs:
+        for generated_video_dir in gen_video_dirs:
+            if gt_video_dir == generated_video_dir:
+                continue
+            lpips = sum(lpips_results[gt_video_dir][generated_video_dir]) / len(
+                lpips_results[gt_video_dir][generated_video_dir]
+            )
+            psnr = sum(psnr_results[gt_video_dir][generated_video_dir]) / len(
+                psnr_results[gt_video_dir][generated_video_dir]
+            )
+            ssim = sum(ssim_results[gt_video_dir][generated_video_dir]) / len(
+                ssim_results[gt_video_dir][generated_video_dir]
+            )
+            out_str += f"\ngt: {gt_video_dir} -> gen: {generated_video_dir}, lpips: {lpips:.4f}, psnr: {psnr:.4f}, ssim: {ssim:.4f}"
+
+    return out_str
+
+
+def main(args):
+    device = "cuda"
+    gt_video_dirs = args.gt_video_dirs
+    gen_video_dirs = args.gen_video_dirs
+
+    video_ids = get_video_ids(gt_video_dirs, gen_video_dirs)
+    print(f"Found {len(video_ids)} videos.")
+
+    prompt_interval = 1  # log running averages every `prompt_interval` batches
+    batch_size = 8
+    calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True
+
+    lpips_results = {}
+    psnr_results = {}
+    ssim_results = {}
+    for gt_video_dir in gt_video_dirs:
+        lpips_results[gt_video_dir] = {}
+        psnr_results[gt_video_dir] = {}
+        ssim_results[gt_video_dir] = {}
+        for generated_video_dir in gen_video_dirs:
+            lpips_results[gt_video_dir][generated_video_dir] = []
+            psnr_results[gt_video_dir][generated_video_dir] = []
+            ssim_results[gt_video_dir][generated_video_dir] = []
+
+    total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)
+
+    for idx in tqdm.tqdm(range(total_len)):
+        video_ids_batch = video_ids[idx * batch_size : (idx + 1) * batch_size]
+        gt_videos, generated_videos = get_videos(video_ids_batch, gt_video_dirs, gen_video_dirs)
+
+        for gt_video_dir, gt_videos_tensor in gt_videos.items():
+            for generated_video_dir, generated_videos_tensor in generated_videos.items():
+                if gt_video_dir == generated_video_dir:
+                    continue
+
+                # On the first pass for this gt dir the entry is still a list of per-video
+                # tensors: crop/resize each to the generated resolution, then stack and
+                # normalize to [0, 1]. Later passes reuse the already-stacked tensor.
+                if not isinstance(gt_videos_tensor, torch.Tensor):
+                    for i in range(len(gt_videos_tensor)):
+                        gt_videos_tensor[i] = resize_gt_video(gt_videos_tensor[i], generated_videos_tensor[0])
+                    gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()
+
+                generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()
+
+                if calculate_lpips_flag:
+                    result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
+                    result = result["value"].values()
+                    result = float(sum(result) / len(result))
+                    lpips_results[gt_video_dir][generated_video_dir].append(result)
+
+                if calculate_psnr_flag:
+                    result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
+                    result = result["value"].values()
+                    result = float(sum(result) / len(result))
+                    psnr_results[gt_video_dir][generated_video_dir].append(result)
+
+                if calculate_ssim_flag:
+                    result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
+                    result = result["value"].values()
+                    result = float(sum(result) / len(result))
+                    ssim_results[gt_video_dir][generated_video_dir].append(result)
+
+        if (idx + 1) % prompt_interval == 0:
+            out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)
+            print(f"Processed {idx + 1} / {total_len} batches. {out_str}")
+
+    out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)
+
+    # save
+    with open("./batch_eval.txt", "w") as f:
+        f.write(out_str)
+
+    print(f"Processed all videos. {out_str}")
{out_str}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--gt_video_dirs", type=str, nargs="+") + parser.add_argument("--gen_video_dirs", type=str, nargs="+") + + args = parser.parse_args() + + main(args) diff --git a/eval/pab/experiments/opensora_plan.py b/eval/pab/experiments/opensora_plan.py index 3a5716fa..c7ed44d6 100644 --- a/eval/pab/experiments/opensora_plan.py +++ b/eval/pab/experiments/opensora_plan.py @@ -4,7 +4,7 @@ def eval_base(prompt_list): - config = OpenSoraPlanConfig() + config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512") engine = VideoSysEngine(config) generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5) @@ -15,7 +15,7 @@ def eval_pab1(prompt_list): temporal_gap=4, cross_gap=6, ) - config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config) + config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config) engine = VideoSysEngine(config) generate_func(engine, prompt_list, "./samples/opensoraplan_pab1", loop=5) @@ -26,7 +26,7 @@ def eval_pab2(prompt_list): temporal_gap=5, cross_gap=7, ) - config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config) + config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config) engine = VideoSysEngine(config) generate_func(engine, prompt_list, "./samples/opensoraplan_pab2", loop=5) @@ -37,7 +37,7 @@ def eval_pab3(prompt_list): temporal_gap=7, cross_gap=9, ) - config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config) + config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config) engine = VideoSysEngine(config) generate_func(engine, prompt_list, "./samples/opensoraplan_pab3", loop=5) diff --git a/eval/pab/webvid/download.py b/eval/pab/webvid/download.py new file mode 100644 index 00000000..3b698eb1 --- /dev/null +++ b/eval/pab/webvid/download.py @@ -0,0 +1,67 @@ +import csv +import os +import time + +import requests +import tqdm + + +def read_csv(csv_file): + with open(csv_file, "r") as f: + reader = csv.reader(f) + data = list(reader) + data = data[1:] + print(f"Read {len(data)} rows from {csv_file}") + return data + + +def select_csv(data, min_text_len, min_vid_len, select_num): + results = [] + assert 0 <= min_vid_len <= 60 + min_vid_len_str = f"PT00H00M{min_vid_len:02d}S" + for d in data: + # [id, link, duration, page, text] + if d[2] < min_vid_len_str: + continue + token_num = len(d[4].split(" ")) + if token_num < min_text_len: + continue + results.append(d) + if len(results) == select_num: + break + return results + + +def save_data_list(data, save_path): + with open(save_path, "w") as f: + writer = csv.writer(f) + writer.writerow(["id", "link", "duration", "page", "text"]) + for d in data: + writer.writerow(d) + + +def download_video(data, save_path): + os.makedirs(save_path, exist_ok=True) + for d in tqdm.tqdm(data): + url = d[1] + video_path = os.path.join(save_path, f"{d[0]}.mp4") + while True: + try: + r = requests.get(url, stream=True) + with open(video_path, "wb") as f: + for chunk in r.iter_content(chunk_size=1024): + if chunk: + f.write(chunk) + break + except ConnectionError: + time.sleep(1) + print(f"Failed to download {url}, retrying...") + continue + time.sleep(0.1) + + +if __name__ == "__main__": + data = read_csv("./datasets/webvid.csv") + selected_data = select_csv(data, 20, 5, 500) + save_data_list(selected_data, "./datasets/webvid_selected.csv") + 
+    download_video(selected_data, "./datasets/webvid")
diff --git a/eval/pab/webvid/latte.py b/eval/pab/webvid/latte.py
new file mode 100644
index 00000000..e19298f2
--- /dev/null
+++ b/eval/pab/webvid/latte.py
@@ -0,0 +1,50 @@
+from utils import generate_func, load_eval_prompts
+
+from videosys import LatteConfig, LattePABConfig, VideoSysEngine
+
+
+def eval_base(prompt_list):
+    config = LatteConfig()
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/latte_base")
+
+
+def eval_pab1(prompt_list):
+    pab_config = LattePABConfig(
+        spatial_range=2,
+        temporal_range=3,
+        cross_range=6,
+    )
+    config = LatteConfig(enable_pab=True, pab_config=pab_config)
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/latte_pab1")
+
+
+def eval_pab2(prompt_list):
+    pab_config = LattePABConfig(
+        spatial_range=3,
+        temporal_range=4,
+        cross_range=7,
+    )
+    config = LatteConfig(enable_pab=True, pab_config=pab_config)
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/latte_pab2")
+
+
+def eval_pab3(prompt_list):
+    pab_config = LattePABConfig(
+        spatial_range=4,
+        temporal_range=6,
+        cross_range=9,
+    )
+    config = LatteConfig(enable_pab=True, pab_config=pab_config)
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/latte_pab3")
+
+
+if __name__ == "__main__":
+    prompt_list = load_eval_prompts("./datasets/webvid_selected.csv")
+    eval_base(prompt_list)
+    eval_pab1(prompt_list)
+    eval_pab2(prompt_list)
+    eval_pab3(prompt_list)
diff --git a/eval/pab/webvid/open_sora.py b/eval/pab/webvid/open_sora.py
new file mode 100644
index 00000000..90527259
--- /dev/null
+++ b/eval/pab/webvid/open_sora.py
@@ -0,0 +1,37 @@
+from utils import generate_func, load_eval_prompts
+
+from videosys import OpenSoraConfig, OpenSoraPABConfig, VideoSysEngine
+
+
+def eval_base(prompt_list):
+    config = OpenSoraConfig()
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/opensora_base")
+
+
+def eval_pab1(prompt_list):
+    config = OpenSoraConfig(enable_pab=True)
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/opensora_pab1")
+
+
+def eval_pab2(prompt_list):
+    pab_config = OpenSoraPABConfig(spatial_range=3, temporal_range=5, cross_range=7)
+    config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/opensora_pab2")
+
+
+def eval_pab3(prompt_list):
+    pab_config = OpenSoraPABConfig(spatial_range=5, temporal_range=7, cross_range=9)
+    config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/opensora_pab3")
+
+
+if __name__ == "__main__":
+    prompt_list = load_eval_prompts("./datasets/webvid_selected.csv")
+    eval_base(prompt_list)
+    eval_pab1(prompt_list)
+    eval_pab2(prompt_list)
+    eval_pab3(prompt_list)
diff --git a/eval/pab/webvid/opensora_plan.py b/eval/pab/webvid/opensora_plan.py
new file mode 100644
index 00000000..a505f8c1
--- /dev/null
+++ b/eval/pab/webvid/opensora_plan.py
@@ -0,0 +1,50 @@
+from utils import generate_func, load_eval_prompts
+
+from videosys import OpenSoraPlanConfig, OpenSoraPlanV110PABConfig, VideoSysEngine
+
+
+def eval_base(prompt_list):
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/opensoraplan_base")
+
+
+def eval_pab1(prompt_list):
+    pab_config = OpenSoraPlanV110PABConfig(
+        spatial_range=2,
+        temporal_range=4,
+        cross_range=6,
+    )
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config)
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/opensoraplan_pab1")
+
+
+def eval_pab2(prompt_list):
+    pab_config = OpenSoraPlanV110PABConfig(
+        spatial_range=3,
+        temporal_range=5,
+        cross_range=7,
+    )
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config)
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/opensoraplan_pab2")
+
+
+def eval_pab3(prompt_list):
+    pab_config = OpenSoraPlanV110PABConfig(
+        spatial_range=5,
+        temporal_range=7,
+        cross_range=9,
+    )
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config)
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/opensoraplan_pab3")
+
+
+if __name__ == "__main__":
+    prompt_list = load_eval_prompts("./datasets/webvid_selected.csv")
+    eval_base(prompt_list)
+    eval_pab1(prompt_list)
+    eval_pab2(prompt_list)
+    eval_pab3(prompt_list)
diff --git a/eval/pab/webvid/utils.py b/eval/pab/webvid/utils.py
new file mode 100644
index 00000000..1968fd5d
--- /dev/null
+++ b/eval/pab/webvid/utils.py
@@ -0,0 +1,25 @@
+import csv
+import os
+
+import tqdm
+
+
+def load_eval_prompts(csv_file_path):
+    prompts_dict = {}
+    # Read the CSV file: one prompt (the WebVid caption) per video id
+    with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
+        csvreader = csv.DictReader(csvfile)
+        for row in csvreader:
+            prompts_dict[row["id"]] = row["text"]
+    return prompts_dict
+
+
+def generate_func(pipeline, prompt_list, output_dir, kwargs: dict = None):
+    # Avoid a mutable default argument, and copy so the caller's dict is not mutated.
+    kwargs = dict(kwargs) if kwargs is not None else {}
+    kwargs["verbose"] = False
+    kwargs["seed"] = 0
+    os.makedirs(output_dir, exist_ok=True)
+    for idx, prompt in tqdm.tqdm(list(prompt_list.items())):
+        if os.path.exists(os.path.join(output_dir, f"{idx}.mp4")):
+            print(f"Skipping {idx}: output already exists")
+            continue
+        video = pipeline.generate(prompt, **kwargs).video[0]
+        pipeline.save_video(video, os.path.join(output_dir, f"{idx}.mp4"))
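
For reference, a minimal end-to-end run of the pipeline added in this diff might look as follows. This is a sketch, not part of the patch: the entry points are the files above, and the --gt_video_dirs/--gen_video_dirs values simply reuse the default paths hard-coded in those scripts. Note that batch_eval.py requires every listed directory to contain the same set of {id}.mp4 files, so failed downloads or partial generation runs must be cleaned up first.

    # download and filter the WebVid reference clips (writes ./datasets/webvid)
    python eval/pab/webvid/download.py
    # generate videos for the selected prompts, with and without PAB
    python eval/pab/webvid/open_sora.py
    # score generated videos against the references with LPIPS/PSNR/SSIM
    python eval/pab/common_metrics/batch_eval.py \
        --gt_video_dirs ./datasets/webvid \
        --gen_video_dirs ./samples/opensora_base ./samples/opensora_pab1 ./samples/opensora_pab2 ./samples/opensora_pab3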