Skip to content

Commit

Permalink
[eval] update eval scripts with webvid for pab (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
oahzxl authored Oct 3, 2024
1 parent 49f77c8 commit 33a7782
Show file tree
Hide file tree
Showing 8 changed files with 439 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,4 @@ pretrained
samples
cache_dir
test_outputs
datasets
205 changes: 205 additions & 0 deletions eval/pab/common_metrics/batch_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import argparse
import os

import imageio
import torch
import torchvision.transforms.functional as F
import tqdm
from calculate_lpips import calculate_lpips
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim


def load_video(video_path):
"""
Load a video from the given path and convert it to a PyTorch tensor.
"""
# Read the video using imageio
reader = imageio.get_reader(video_path, "ffmpeg")

# Extract frames and convert to a list of tensors
frames = []
for frame in reader:
# Convert the frame to a tensor and permute the dimensions to match (C, H, W)
frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
frames.append(frame_tensor)

# Stack the list of tensors into a single tensor with shape (T, C, H, W)
video_tensor = torch.stack(frames)

return video_tensor


def resize_video(video, target_height, target_width):
resized_frames = []
for frame in video:
resized_frame = F.resize(frame, [target_height, target_width])
resized_frames.append(resized_frame)
return torch.stack(resized_frames)


def resize_gt_video(gt_video, gen_video):
gen_video_shape = gen_video.shape
T_gen, _, H_gen, W_gen = gen_video_shape
T_eval, _, H_eval, W_eval = gt_video.shape

if T_eval < T_gen:
raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")

if H_eval < H_gen or W_eval < W_gen:
# Resize the video maintaining the aspect ratio
resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
gt_video = resize_video(gt_video, resize_height, resize_width)
# Recalculate the dimensions
T_eval, _, H_eval, W_eval = gt_video.shape

# Center crop
start_h = (H_eval - H_gen) // 2
start_w = (W_eval - W_gen) // 2
cropped_video = gt_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]

return cropped_video


def get_video_ids(gt_video_dirs, gen_video_dirs):
video_ids = []
for f in os.listdir(gt_video_dirs[0]):
if f.endswith(f".mp4"):
video_ids.append(f.replace(f".mp4", ""))
video_ids.sort()

for video_dir in gt_video_dirs + gen_video_dirs:
tmp_video_ids = []
for f in os.listdir(video_dir):
if f.endswith(f".mp4"):
tmp_video_ids.append(f.replace(f".mp4", ""))
tmp_video_ids.sort()
if tmp_video_ids != video_ids:
raise ValueError(f"Video IDs in {video_dir} are different.")
return video_ids


def get_videos(video_ids, gt_video_dirs, gen_video_dirs):
gt_videos = {}
generated_videos = {}

for gt_video_dir in gt_video_dirs:
tmp_gt_videos_tensor = []
for video_id in video_ids:
gt_video = load_video(os.path.join(gt_video_dir, f"{video_id}.mp4"))
tmp_gt_videos_tensor.append(gt_video)
gt_videos[gt_video_dir] = tmp_gt_videos_tensor

for generated_video_dir in gen_video_dirs:
tmp_generated_videos_tensor = []
for video_id in video_ids:
generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.mp4"))
tmp_generated_videos_tensor.append(generated_video)
generated_videos[generated_video_dir] = tmp_generated_videos_tensor

return gt_videos, generated_videos


def print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs):
out_str = ""

for gt_video_dir in gt_video_dirs:
for generated_video_dir in gen_video_dirs:
if gt_video_dir == generated_video_dir:
continue
lpips = sum(lpips_results[gt_video_dir][generated_video_dir]) / len(
lpips_results[gt_video_dir][generated_video_dir]
)
psnr = sum(psnr_results[gt_video_dir][generated_video_dir]) / len(
psnr_results[gt_video_dir][generated_video_dir]
)
ssim = sum(ssim_results[gt_video_dir][generated_video_dir]) / len(
ssim_results[gt_video_dir][generated_video_dir]
)
out_str += f"\ngt: {gt_video_dir} -> gen: {generated_video_dir}, lpips: {lpips:.4f}, psnr: {psnr:.4f}, ssim: {ssim:.4f}"

return out_str


def main(args):
device = "cuda"
gt_video_dirs = args.gt_video_dirs
gen_video_dirs = args.gen_video_dirs

video_ids = get_video_ids(gt_video_dirs, gen_video_dirs)
print(f"Find {len(video_ids)} videos")

prompt_interval = 1
batch_size = 8
calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True

lpips_results = {}
psnr_results = {}
ssim_results = {}
for gt_video_dir in gt_video_dirs:
lpips_results[gt_video_dir] = {}
psnr_results[gt_video_dir] = {}
ssim_results[gt_video_dir] = {}
for generated_video_dir in gen_video_dirs:
lpips_results[gt_video_dir][generated_video_dir] = []
psnr_results[gt_video_dir][generated_video_dir] = []
ssim_results[gt_video_dir][generated_video_dir] = []

total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)

for idx in tqdm.tqdm(range(total_len)):
video_ids_batch = video_ids[idx * batch_size : (idx + 1) * batch_size]
gt_videos, generated_videos = get_videos(video_ids_batch, gt_video_dirs, gen_video_dirs)

for gt_video_dir, gt_videos_tensor in gt_videos.items():
for generated_video_dir, generated_videos_tensor in generated_videos.items():
if gt_video_dir == generated_video_dir:
continue

if not isinstance(gt_videos_tensor, torch.Tensor):
for i in range(len(gt_videos_tensor)):
gt_videos_tensor[i] = resize_gt_video(gt_videos_tensor[i], generated_videos_tensor[0])
gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()

generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()

if calculate_lpips_flag:
result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
result = result["value"].values()
result = float(sum(result) / len(result))
lpips_results[gt_video_dir][generated_video_dir].append(result)

if calculate_psnr_flag:
result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
result = result["value"].values()
result = float(sum(result) / len(result))
psnr_results[gt_video_dir][generated_video_dir].append(result)

if calculate_ssim_flag:
result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
result = result["value"].values()
result = float(sum(result) / len(result))
ssim_results[gt_video_dir][generated_video_dir].append(result)

if (idx + 1) % prompt_interval == 0:
out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)
print(f"Processed {idx + 1} / {total_len} videos. {out_str}")

out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)

# save
with open(f"./batch_eval.txt", "w+") as f:
f.write(out_str)

print(f"Processed all videos. {out_str}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--gt_video_dirs", type=str, nargs="+")
parser.add_argument("--gen_video_dirs", type=str, nargs="+")

args = parser.parse_args()

main(args)
8 changes: 4 additions & 4 deletions eval/pab/experiments/opensora_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def eval_base(prompt_list):
config = OpenSoraPlanConfig()
config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5)

Expand All @@ -15,7 +15,7 @@ def eval_pab1(prompt_list):
temporal_gap=4,
cross_gap=6,
)
config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config)
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/opensoraplan_pab1", loop=5)

Expand All @@ -26,7 +26,7 @@ def eval_pab2(prompt_list):
temporal_gap=5,
cross_gap=7,
)
config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config)
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/opensoraplan_pab2", loop=5)

Expand All @@ -37,7 +37,7 @@ def eval_pab3(prompt_list):
temporal_gap=7,
cross_gap=9,
)
config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config)
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/opensoraplan_pab3", loop=5)

Expand Down
67 changes: 67 additions & 0 deletions eval/pab/webvid/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import csv
import os
import time

import requests
import tqdm


def read_csv(csv_file):
with open(csv_file, "r") as f:
reader = csv.reader(f)
data = list(reader)
data = data[1:]
print(f"Read {len(data)} rows from {csv_file}")
return data


def select_csv(data, min_text_len, min_vid_len, select_num):
results = []
assert 0 <= min_vid_len <= 60
min_vid_len_str = f"PT00H00M{min_vid_len:02d}S"
for d in data:
# [id, link, duration, page, text]
if d[2] < min_vid_len_str:
continue
token_num = len(d[4].split(" "))
if token_num < min_text_len:
continue
results.append(d)
if len(results) == select_num:
break
return results


def save_data_list(data, save_path):
with open(save_path, "w") as f:
writer = csv.writer(f)
writer.writerow(["id", "link", "duration", "page", "text"])
for d in data:
writer.writerow(d)


def download_video(data, save_path):
os.makedirs(save_path, exist_ok=True)
for d in tqdm.tqdm(data):
url = d[1]
video_path = os.path.join(save_path, f"{d[0]}.mp4")
while True:
try:
r = requests.get(url, stream=True)
with open(video_path, "wb") as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
break
except ConnectionError:
time.sleep(1)
print(f"Failed to download {url}, retrying...")
continue
time.sleep(0.1)


if __name__ == "__main__":
data = read_csv("./datasets/webvid.csv")
selected_data = select_csv(data, 20, 5, 500)
save_data_list(selected_data, "./datasets/webvid_selected.csv")
download_video(selected_data, "./datasets/webvid")
50 changes: 50 additions & 0 deletions eval/pab/webvid/latte.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from utils import generate_func, load_eval_prompts

from videosys import LatteConfig, LattePABConfig, VideoSysEngine


def eval_base(prompt_list):
config = LatteConfig()
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/latte_base")


def eval_pab1(prompt_list):
pab_config = LattePABConfig(
spatial_range=2,
temporal_range=3,
cross_range=6,
)
config = LatteConfig(enable_pab=True, pab_config=pab_config)
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/latte_pab1")


def eval_pab2(prompt_list):
pab_config = LattePABConfig(
spatial_range=3,
temporal_range=4,
cross_range=7,
)
config = LatteConfig(enable_pab=True, pab_config=pab_config)
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/latte_pab2")


def eval_pab3(prompt_list):
pab_config = LattePABConfig(
spatial_range=4,
temporal_range=6,
cross_range=9,
)
config = LatteConfig(enable_pab=True, pab_config=pab_config)
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/latte_pab3")


if __name__ == "__main__":
prompt_list = load_eval_prompts("./datasets/webvid_selected.csv")
eval_base(prompt_list)
eval_pab1(prompt_list)
eval_pab2(prompt_list)
eval_pab3(prompt_list)
Loading

0 comments on commit 33a7782

Please sign in to comment.