feature: videogen improvements #447

Merged · 1 commit · Jan 8, 2024
83 changes: 55 additions & 28 deletions imaginairy/api/video_sample.py
@@ -26,7 +26,9 @@
instantiate_from_config,
platform_appropriate_autocast,
)
from imaginairy.utils.animations import make_bounce_animation
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.named_resolutions import normalize_image_size
from imaginairy.utils.paths import PKG_ROOT

logger = logging.getLogger(__name__)
@@ -35,6 +37,7 @@
def generate_video(
input_path: str, # Can either be image file or folder with image files
output_folder: str | None = None,
size=(1024, 576),
num_frames: int = 6,
num_steps: int = 30,
model_name: str = "svd-xt",
@@ -46,6 +49,7 @@ def generate_video(
decoding_t: int = 1, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
device: Optional[str] = None,
repetitions=1,
output_format="webp",
):
"""
Generates a video from a single image or multiple images, conditioned on the provided input_path.
@@ -71,7 +75,7 @@
None: The function saves the generated video(s) to the specified output folder.
"""
device = default(device, get_device)

vid_width, vid_height = normalize_image_size(size)
if device == "mps":
msg = "Apple Silicon MPS (M1, M2, etc) is not currently supported for video generation. Switching to cpu generation."
logger.warning(msg)
@@ -88,7 +92,6 @@
logger.warning(msg)

start_time = time.perf_counter()
seed = default(seed, random.randint(0, 1000000))
output_fps = default(output_fps, fps_id)

video_model_config = config.MODEL_WEIGHT_CONFIG_LOOKUP.get(model_name, None)
@@ -102,17 +105,13 @@
del output_folder
video_config_path = f"{PKG_ROOT}/{video_model_config.architecture.config_path}"

logger.info(
f"Generating a {num_frames} frame video from {input_path}. Device:{device} seed:{seed}"
)
model, safety_filter = load_model(
config=video_config_path,
device="cpu",
num_frames=num_frames,
num_steps=num_steps,
weights_url=video_model_config.weights_location,
)
torch.manual_seed(seed)

if input_path.startswith("http"):
all_img_paths = [input_path]
@@ -137,9 +136,14 @@
msg = f"Could not find file or folder at {input_path}"
raise FileNotFoundError(msg)

expected_size = (1024, 576)
expected_size = (vid_width, vid_height)
for _ in range(repetitions):
for input_path in all_img_paths:
_seed = default(seed, random.randint(0, 1000000))
torch.manual_seed(_seed)
logger.info(
f"Generating a {num_frames} frame video from {input_path}. Device:{device} seed:{_seed}"
)
if input_path.startswith("http"):
image = LazyLoadingImage(url=input_path).as_pillow()
else:
@@ -207,7 +211,6 @@ def generate_video(
value_dict["cond_aug"] = cond_aug
value_dict["cond_frames_without_noise"] = image
value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
value_dict["cond_aug"] = cond_aug

with torch.no_grad(), platform_appropriate_autocast():
reload_model(model.conditioner, device=device)
@@ -272,37 +275,61 @@ def denoiser(_input, sigma, c):
samples = samples[:, :, upper:lower, left:right]

os.makedirs(output_folder_str, exist_ok=True)
base_count = len(glob(os.path.join(output_folder_str, "*.mp4"))) + 1
base_count = len(glob(os.path.join(output_folder_str, "*.*"))) + 1
source_slug = make_safe_filename(input_path)
video_filename = f"{base_count:06d}_{model_name}_{seed}_{fps_id}fps_{source_slug}.mp4"
video_filename = f"{base_count:06d}_{model_name}_{_seed}_{fps_id}fps_{source_slug}.{output_format}"
video_path = os.path.join(output_folder_str, video_filename)
writer = cv2.VideoWriter(
video_path,
cv2.VideoWriter_fourcc(*"MP4V"), # type: ignore
output_fps,
(samples.shape[-1], samples.shape[-2]),
)

samples = safety_filter(samples)
vid = (
(rearrange(samples, "t c h w -> t h w c") * 255)
.cpu()
.numpy()
.astype(np.uint8)
)
for frame in vid:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
writer.write(frame)
writer.release()
video_path_h264 = video_path[:-4] + "_h264.mp4"
os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}")
# save_video(samples, video_path, output_fps)
save_video_bounce(samples, video_path, output_fps)

duration = time.perf_counter() - start_time
logger.info(
f"Video of {num_frames} frames generated in {duration:.2f} seconds and saved to {video_path}\n"
)


def save_video(samples: torch.Tensor, video_filename: str, output_fps: int):
"""
Saves a video from given tensor samples.

Args:
samples (torch.Tensor): Tensor containing video frame data.
video_filename (str): The full path and filename where the video will be saved.
output_fps (int): Frames per second for the output video.

Returns:
None: The video is written to video_filename, with an H.264 copy saved alongside it.
"""
vid = (torch.permute(samples, (0, 2, 3, 1)) * 255).cpu().numpy().astype(np.uint8)
writer = cv2.VideoWriter(
video_filename,
cv2.VideoWriter_fourcc(*"MP4V"), # type: ignore
output_fps,
(samples.shape[-1], samples.shape[-2]),
)
for frame in vid:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
writer.write(frame)
writer.release()
video_path_h264 = video_filename[:-4] + "_h264.mp4"
os.system(f"ffmpeg -i {video_filename} -c:v libx264 {video_path_h264}")


def save_video_bounce(samples: torch.Tensor, video_filename: str, output_fps: int):
frames_np = (
(torch.permute(samples, (0, 2, 3, 1)) * 255).cpu().numpy().astype(np.uint8)
)

make_bounce_animation(
imgs=[Image.fromarray(frame) for frame in frames_np],
outpath=video_filename,
end_pause_duration_ms=750,
)


def get_unique_embedder_keys_from_conditioner(conditioner):
return list({x.input_key for x in conditioner.embedders})

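Taken together, the video_sample.py changes make the output size and container format configurable and move seeding inside the generation loop, so each input image (and each repetition) draws its own seed when none is fixed. Below is a minimal sketch of how the updated function might be called; the paths are placeholders, and only parameters visible in the diff above are used:

```python
from imaginairy.api.video_sample import generate_video

# Hypothetical invocation; "pearl-girl.jpg" and "outputs/video" are example paths.
generate_video(
    input_path="pearl-girl.jpg",   # image file, folder of images, or http(s) URL
    output_folder="outputs/video",
    size=(1024, 576),              # resolved internally via normalize_image_size
    num_frames=6,
    output_format="webp",          # new default; "mp4" and "gif" also accepted
    repetitions=2,                 # with seed=None, each repetition gets a fresh seed
    seed=None,
)
```
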
17 changes: 17 additions & 0 deletions imaginairy/cli/videogen.py
@@ -25,7 +25,20 @@
@click.option(
"--fps", default=6, type=int, help="FPS for the AI to target when generating video"
)
@click.option(
"--size",
default="1024,576",
show_default=True,
type=str,
help="Video dimensions. Can be a named size, single integer, or WIDTHxHEIGHT pair. Should be multiple of 8. Examples: SVD, 512x512, 4k, UHD, 8k, 512, 1080p",
)
@click.option("--output-fps", default=None, type=int, help="FPS for the output video")
@click.option(
"--output-format",
default="webp",
help="Output video format",
type=click.Choice(["webp", "mp4", "gif"]),
)
@click.option(
"--motion-amount",
default=127,
@@ -54,7 +67,9 @@ def videogen_cmd(
steps,
model,
fps,
size,
output_fps,
output_format,
motion_amount,
repeats,
cond_aug,
@@ -83,7 +98,9 @@ def videogen_cmd(
num_steps=steps,
model_name=model,
fps_id=fps,
size=size,
output_fps=output_fps,
output_format=output_format,
motion_bucket_id=motion_amount,
cond_aug=cond_aug,
seed=seed,
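The new --size option is forwarded as a plain string and resolved inside generate_video by normalize_image_size. The sketch below is a hypothetical simplification of that resolution, inferred only from the option's help text; the real implementation lives in imaginairy/utils/named_resolutions.py and supports more spellings:

```python
# Hypothetical resolver mirroring the --size help text (named size, single
# integer, or WIDTHxHEIGHT / WIDTH,HEIGHT pair). Not the library's actual code.
NAMED_SIZES = {"svd": (1024, 576), "1080p": (1920, 1080), "uhd": (3840, 2160)}

def resolve_size(size: "str | tuple[int, int]") -> "tuple[int, int]":
    if isinstance(size, (tuple, list)):       # already a (width, height) pair
        return int(size[0]), int(size[1])
    text = str(size).strip().lower()
    if text in NAMED_SIZES:                   # named size, e.g. "SVD" or "1080p"
        return NAMED_SIZES[text]
    for sep in ("x", ","):                    # "512x512" or "1024,576"
        if sep in text:
            width, height = text.split(sep)
            return int(width), int(height)
    return int(text), int(text)               # single integer -> square

resolve_size("1024,576")  # -> (1024, 576), matching the option's default
```
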
61 changes: 40 additions & 21 deletions imaginairy/utils/animations.py
@@ -1,5 +1,6 @@
"""Functions for creating animations from images."""
import os.path
from typing import TYPE_CHECKING, List

import cv2
import torch
@@ -12,18 +13,24 @@
pillow_img_to_opencv_img,
)

if TYPE_CHECKING:
from PIL import Image

from imaginairy.utils.img_utils import LazyLoadingImage


def make_bounce_animation(
imgs,
outpath,
imgs: "List[Image.Image | LazyLoadingImage | torch.Tensor]",
outpath: str,
transition_duration_ms=500,
start_pause_duration_ms=1000,
end_pause_duration_ms=2000,
max_fps=20,
):
first_img = imgs[0]
last_img = imgs[-1]
middle_imgs = imgs[1:-1]
max_fps = 20
last_img = imgs[-1]

max_frames = int(round(transition_duration_ms / 1000 * max_fps))
min_duration = int(1000 / 20)
if middle_imgs:
@@ -37,20 +44,8 @@ def make_bounce_animation(
frames = [first_img, *middle_imgs, last_img, *list(reversed(middle_imgs))]

# convert from latents
converted_frames = []

for frame in frames:
if isinstance(frame, torch.Tensor):
frame = model_latents_to_pillow_imgs(frame)[0]
converted_frames.append(frame)
frames = converted_frames
max_size = max([frame.size for frame in frames])
converted_frames = []
for frame in frames:
if frame.size != max_size:
frame = frame.resize(max_size)
converted_frames.append(frame)
frames = converted_frames
converted_frames = _ensure_pillow_images(frames)
converted_frames = _ensure_images_same_size(converted_frames)

durations = (
[start_pause_duration_ms]
@@ -59,7 +54,29 @@
+ [progress_duration] * len(middle_imgs)
)

make_animation(imgs=frames, outpath=outpath, frame_duration_ms=durations)
make_animation(imgs=converted_frames, outpath=outpath, frame_duration_ms=durations)


def _ensure_pillow_images(
imgs: "List[Image.Image | LazyLoadingImage | torch.Tensor]",
) -> "List[Image.Image]":
converted_frames: "List[Image.Image]" = []
for frame in imgs:
if isinstance(frame, torch.Tensor):
converted_frames.append(model_latents_to_pillow_imgs(frame)[0])
else:
converted_frames.append(frame) # type: ignore
return converted_frames


def _ensure_images_same_size(imgs: "List[Image.Image]") -> "List[Image.Image]":
max_size = max([frame.size for frame in imgs])
converted_frames = []
for frame in imgs:
if frame.size != max_size:
frame = frame.resize(max_size)
converted_frames.append(frame)
return converted_frames


def make_slideshow_animation(
@@ -79,7 +96,9 @@ def make_slideshow_animation(
make_animation(imgs=converted_frames, outpath=outpath, frame_duration_ms=durations)


def make_animation(imgs, outpath, frame_duration_ms=100, captions=None):
def make_animation(
imgs, outpath, frame_duration_ms: int | List[int] = 100, captions=None
):
imgs = imgpaths_to_imgs(imgs)
ext = os.path.splitext(outpath)[1].lower().strip(".")

@@ -89,7 +108,7 @@ def make_animation(imgs, outpath, frame_duration_ms=100, captions=None):
for img, caption in zip(imgs, captions):
add_caption_to_image(img, caption)

if ext == "gif":
if ext == "gif" or ext == "webp":
make_gif_animation(
imgs=imgs, outpath=outpath, frame_duration_ms=frame_duration_ms
)
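For reference, the bounce ordering that make_bounce_animation builds (and that save_video_bounce now relies on) plays the frames forward and then back, with the start and end pauses holding the first and last frames. A worked example with four stand-in frames:

```python
# Strings stand in for PIL images; the ordering matches the diff above.
imgs = ["A", "B", "C", "D"]
first_img, middle_imgs, last_img = imgs[0], imgs[1:-1], imgs[-1]

frames = [first_img, *middle_imgs, last_img, *list(reversed(middle_imgs))]
print(frames)  # ['A', 'B', 'C', 'D', 'C', 'B'] -- forward then back, loops cleanly
```

Because make_animation now treats "webp" the same as "gif", the default --output-format webp is written by the animated-image path rather than cv2's VideoWriter.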