-
Notifications
You must be signed in to change notification settings - Fork 124
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[eval] update eval scripts with webvid for pab (#225)
- Loading branch information
Showing
8 changed files
with
439 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -165,3 +165,4 @@ pretrained | |
samples | ||
cache_dir | ||
test_outputs | ||
datasets |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
import argparse | ||
import os | ||
|
||
import imageio | ||
import torch | ||
import torchvision.transforms.functional as F | ||
import tqdm | ||
from calculate_lpips import calculate_lpips | ||
from calculate_psnr import calculate_psnr | ||
from calculate_ssim import calculate_ssim | ||
|
||
|
||
def load_video(video_path): | ||
""" | ||
Load a video from the given path and convert it to a PyTorch tensor. | ||
""" | ||
# Read the video using imageio | ||
reader = imageio.get_reader(video_path, "ffmpeg") | ||
|
||
# Extract frames and convert to a list of tensors | ||
frames = [] | ||
for frame in reader: | ||
# Convert the frame to a tensor and permute the dimensions to match (C, H, W) | ||
frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1) | ||
frames.append(frame_tensor) | ||
|
||
# Stack the list of tensors into a single tensor with shape (T, C, H, W) | ||
video_tensor = torch.stack(frames) | ||
|
||
return video_tensor | ||
|
||
|
||
def resize_video(video, target_height, target_width): | ||
resized_frames = [] | ||
for frame in video: | ||
resized_frame = F.resize(frame, [target_height, target_width]) | ||
resized_frames.append(resized_frame) | ||
return torch.stack(resized_frames) | ||
|
||
|
||
def resize_gt_video(gt_video, gen_video): | ||
gen_video_shape = gen_video.shape | ||
T_gen, _, H_gen, W_gen = gen_video_shape | ||
T_eval, _, H_eval, W_eval = gt_video.shape | ||
|
||
if T_eval < T_gen: | ||
raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).") | ||
|
||
if H_eval < H_gen or W_eval < W_gen: | ||
# Resize the video maintaining the aspect ratio | ||
resize_height = max(H_gen, int(H_gen * (H_eval / W_eval))) | ||
resize_width = max(W_gen, int(W_gen * (W_eval / H_eval))) | ||
gt_video = resize_video(gt_video, resize_height, resize_width) | ||
# Recalculate the dimensions | ||
T_eval, _, H_eval, W_eval = gt_video.shape | ||
|
||
# Center crop | ||
start_h = (H_eval - H_gen) // 2 | ||
start_w = (W_eval - W_gen) // 2 | ||
cropped_video = gt_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen] | ||
|
||
return cropped_video | ||
|
||
|
||
def get_video_ids(gt_video_dirs, gen_video_dirs): | ||
video_ids = [] | ||
for f in os.listdir(gt_video_dirs[0]): | ||
if f.endswith(f".mp4"): | ||
video_ids.append(f.replace(f".mp4", "")) | ||
video_ids.sort() | ||
|
||
for video_dir in gt_video_dirs + gen_video_dirs: | ||
tmp_video_ids = [] | ||
for f in os.listdir(video_dir): | ||
if f.endswith(f".mp4"): | ||
tmp_video_ids.append(f.replace(f".mp4", "")) | ||
tmp_video_ids.sort() | ||
if tmp_video_ids != video_ids: | ||
raise ValueError(f"Video IDs in {video_dir} are different.") | ||
return video_ids | ||
|
||
|
||
def get_videos(video_ids, gt_video_dirs, gen_video_dirs): | ||
gt_videos = {} | ||
generated_videos = {} | ||
|
||
for gt_video_dir in gt_video_dirs: | ||
tmp_gt_videos_tensor = [] | ||
for video_id in video_ids: | ||
gt_video = load_video(os.path.join(gt_video_dir, f"{video_id}.mp4")) | ||
tmp_gt_videos_tensor.append(gt_video) | ||
gt_videos[gt_video_dir] = tmp_gt_videos_tensor | ||
|
||
for generated_video_dir in gen_video_dirs: | ||
tmp_generated_videos_tensor = [] | ||
for video_id in video_ids: | ||
generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.mp4")) | ||
tmp_generated_videos_tensor.append(generated_video) | ||
generated_videos[generated_video_dir] = tmp_generated_videos_tensor | ||
|
||
return gt_videos, generated_videos | ||
|
||
|
||
def print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs): | ||
out_str = "" | ||
|
||
for gt_video_dir in gt_video_dirs: | ||
for generated_video_dir in gen_video_dirs: | ||
if gt_video_dir == generated_video_dir: | ||
continue | ||
lpips = sum(lpips_results[gt_video_dir][generated_video_dir]) / len( | ||
lpips_results[gt_video_dir][generated_video_dir] | ||
) | ||
psnr = sum(psnr_results[gt_video_dir][generated_video_dir]) / len( | ||
psnr_results[gt_video_dir][generated_video_dir] | ||
) | ||
ssim = sum(ssim_results[gt_video_dir][generated_video_dir]) / len( | ||
ssim_results[gt_video_dir][generated_video_dir] | ||
) | ||
out_str += f"\ngt: {gt_video_dir} -> gen: {generated_video_dir}, lpips: {lpips:.4f}, psnr: {psnr:.4f}, ssim: {ssim:.4f}" | ||
|
||
return out_str | ||
|
||
|
||
def main(args): | ||
device = "cuda" | ||
gt_video_dirs = args.gt_video_dirs | ||
gen_video_dirs = args.gen_video_dirs | ||
|
||
video_ids = get_video_ids(gt_video_dirs, gen_video_dirs) | ||
print(f"Find {len(video_ids)} videos") | ||
|
||
prompt_interval = 1 | ||
batch_size = 8 | ||
calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True | ||
|
||
lpips_results = {} | ||
psnr_results = {} | ||
ssim_results = {} | ||
for gt_video_dir in gt_video_dirs: | ||
lpips_results[gt_video_dir] = {} | ||
psnr_results[gt_video_dir] = {} | ||
ssim_results[gt_video_dir] = {} | ||
for generated_video_dir in gen_video_dirs: | ||
lpips_results[gt_video_dir][generated_video_dir] = [] | ||
psnr_results[gt_video_dir][generated_video_dir] = [] | ||
ssim_results[gt_video_dir][generated_video_dir] = [] | ||
|
||
total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0) | ||
|
||
for idx in tqdm.tqdm(range(total_len)): | ||
video_ids_batch = video_ids[idx * batch_size : (idx + 1) * batch_size] | ||
gt_videos, generated_videos = get_videos(video_ids_batch, gt_video_dirs, gen_video_dirs) | ||
|
||
for gt_video_dir, gt_videos_tensor in gt_videos.items(): | ||
for generated_video_dir, generated_videos_tensor in generated_videos.items(): | ||
if gt_video_dir == generated_video_dir: | ||
continue | ||
|
||
if not isinstance(gt_videos_tensor, torch.Tensor): | ||
for i in range(len(gt_videos_tensor)): | ||
gt_videos_tensor[i] = resize_gt_video(gt_videos_tensor[i], generated_videos_tensor[0]) | ||
gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu() | ||
|
||
generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu() | ||
|
||
if calculate_lpips_flag: | ||
result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device) | ||
result = result["value"].values() | ||
result = float(sum(result) / len(result)) | ||
lpips_results[gt_video_dir][generated_video_dir].append(result) | ||
|
||
if calculate_psnr_flag: | ||
result = calculate_psnr(gt_videos_tensor, generated_videos_tensor) | ||
result = result["value"].values() | ||
result = float(sum(result) / len(result)) | ||
psnr_results[gt_video_dir][generated_video_dir].append(result) | ||
|
||
if calculate_ssim_flag: | ||
result = calculate_ssim(gt_videos_tensor, generated_videos_tensor) | ||
result = result["value"].values() | ||
result = float(sum(result) / len(result)) | ||
ssim_results[gt_video_dir][generated_video_dir].append(result) | ||
|
||
if (idx + 1) % prompt_interval == 0: | ||
out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs) | ||
print(f"Processed {idx + 1} / {total_len} videos. {out_str}") | ||
|
||
out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs) | ||
|
||
# save | ||
with open(f"./batch_eval.txt", "w+") as f: | ||
f.write(out_str) | ||
|
||
print(f"Processed all videos. {out_str}") | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--gt_video_dirs", type=str, nargs="+") | ||
parser.add_argument("--gen_video_dirs", type=str, nargs="+") | ||
|
||
args = parser.parse_args() | ||
|
||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import csv | ||
import os | ||
import time | ||
|
||
import requests | ||
import tqdm | ||
|
||
|
||
def read_csv(csv_file): | ||
with open(csv_file, "r") as f: | ||
reader = csv.reader(f) | ||
data = list(reader) | ||
data = data[1:] | ||
print(f"Read {len(data)} rows from {csv_file}") | ||
return data | ||
|
||
|
||
def select_csv(data, min_text_len, min_vid_len, select_num): | ||
results = [] | ||
assert 0 <= min_vid_len <= 60 | ||
min_vid_len_str = f"PT00H00M{min_vid_len:02d}S" | ||
for d in data: | ||
# [id, link, duration, page, text] | ||
if d[2] < min_vid_len_str: | ||
continue | ||
token_num = len(d[4].split(" ")) | ||
if token_num < min_text_len: | ||
continue | ||
results.append(d) | ||
if len(results) == select_num: | ||
break | ||
return results | ||
|
||
|
||
def save_data_list(data, save_path): | ||
with open(save_path, "w") as f: | ||
writer = csv.writer(f) | ||
writer.writerow(["id", "link", "duration", "page", "text"]) | ||
for d in data: | ||
writer.writerow(d) | ||
|
||
|
||
def download_video(data, save_path): | ||
os.makedirs(save_path, exist_ok=True) | ||
for d in tqdm.tqdm(data): | ||
url = d[1] | ||
video_path = os.path.join(save_path, f"{d[0]}.mp4") | ||
while True: | ||
try: | ||
r = requests.get(url, stream=True) | ||
with open(video_path, "wb") as f: | ||
for chunk in r.iter_content(chunk_size=1024): | ||
if chunk: | ||
f.write(chunk) | ||
break | ||
except ConnectionError: | ||
time.sleep(1) | ||
print(f"Failed to download {url}, retrying...") | ||
continue | ||
time.sleep(0.1) | ||
|
||
|
||
if __name__ == "__main__": | ||
data = read_csv("./datasets/webvid.csv") | ||
selected_data = select_csv(data, 20, 5, 500) | ||
save_data_list(selected_data, "./datasets/webvid_selected.csv") | ||
download_video(selected_data, "./datasets/webvid") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from utils import generate_func, load_eval_prompts | ||
|
||
from videosys import LatteConfig, LattePABConfig, VideoSysEngine | ||
|
||
|
||
def eval_base(prompt_list): | ||
config = LatteConfig() | ||
engine = VideoSysEngine(config) | ||
generate_func(engine, prompt_list, "./samples/latte_base") | ||
|
||
|
||
def eval_pab1(prompt_list): | ||
pab_config = LattePABConfig( | ||
spatial_range=2, | ||
temporal_range=3, | ||
cross_range=6, | ||
) | ||
config = LatteConfig(enable_pab=True, pab_config=pab_config) | ||
engine = VideoSysEngine(config) | ||
generate_func(engine, prompt_list, "./samples/latte_pab1") | ||
|
||
|
||
def eval_pab2(prompt_list): | ||
pab_config = LattePABConfig( | ||
spatial_range=3, | ||
temporal_range=4, | ||
cross_range=7, | ||
) | ||
config = LatteConfig(enable_pab=True, pab_config=pab_config) | ||
engine = VideoSysEngine(config) | ||
generate_func(engine, prompt_list, "./samples/latte_pab2") | ||
|
||
|
||
def eval_pab3(prompt_list): | ||
pab_config = LattePABConfig( | ||
spatial_range=4, | ||
temporal_range=6, | ||
cross_range=9, | ||
) | ||
config = LatteConfig(enable_pab=True, pab_config=pab_config) | ||
engine = VideoSysEngine(config) | ||
generate_func(engine, prompt_list, "./samples/latte_pab3") | ||
|
||
|
||
if __name__ == "__main__": | ||
prompt_list = load_eval_prompts("./datasets/webvid_selected.csv") | ||
eval_base(prompt_list) | ||
eval_pab1(prompt_list) | ||
eval_pab2(prompt_list) | ||
eval_pab3(prompt_list) |
Oops, something went wrong.