Skip to content

Commit

Permalink
feature: makes generating videos more programmatic.
Browse files Browse the repository at this point in the history
the generate_video function previously involved logic for saving the new file and wouldn't return anything. now it will return a list of the generated samples.
  • Loading branch information
jaydrennan committed Mar 2, 2024
1 parent e6a1c98 commit 030d126
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 80 deletions.
2 changes: 1 addition & 1 deletion imaginairy/api/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _record_step(img, description, image_count, step_count, prompt):
if videogen:
try:
generate_video(
input_path=filepath,
input_img=filepath,
)
except FileNotFoundError as e:
logger.error(str(e))
Expand Down
77 changes: 14 additions & 63 deletions imaginairy/api/video_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@
import math
import os
import random
import re
import time
from glob import glob
from pathlib import Path
from typing import Any, Optional
from typing import Any, List, Optional

import cv2
import numpy as np
Expand Down Expand Up @@ -36,8 +33,7 @@


def generate_video(
input_path: str, # Can either be image file or folder with image files
output_folder: str | None = None,
input_images: List[LazyLoadingImage],
size=(1024, 576),
num_frames: int = 6,
num_steps: int = 30,
Expand All @@ -50,13 +46,12 @@ def generate_video(
decoding_t: int = 1, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
device: Optional[str] = None,
repetitions=1,
output_format="webp",
):
"""
Generates a video from a single image or multiple images, conditioned on the provided input_path.
Args:
input_path (str): Path to an image file or a directory containing image files.
input_images (List[LazyLoadingImage]): List of LazyLoading images to be transformed into videos
output_folder (str | None, optional): Directory where the generated video will be saved.
Defaults to "outputs/video/" if None.
num_frames (int, optional): Number of frames in the generated video. Defaults to 6.
Expand Down Expand Up @@ -101,8 +96,7 @@ def generate_video(

num_frames = default(num_frames, video_model_config.defaults.get("frames", 12))
num_steps = default(num_steps, video_model_config.defaults.get("steps", 30))
output_folder_str = default(output_folder, "outputs/video/")
del output_folder

video_config_path = f"{PKG_ROOT}/{video_model_config.architecture.config_path}"

model, safety_filter = load_model(
Expand All @@ -113,42 +107,19 @@ def generate_video(
weights_url=video_model_config.weights_location,
)

if input_path.startswith("http"):
all_img_paths = [input_path]
else:
path = Path(input_path)
if path.is_file():
if any(input_path.endswith(x) for x in ["jpg", "jpeg", "png"]):
all_img_paths = [input_path]
else:
raise ValueError("Path is not valid image file.")
elif path.is_dir():
all_img_paths = sorted(
[
str(f)
for f in path.iterdir()
if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
]
)
if len(all_img_paths) == 0:
raise ValueError("Folder does not contain any images.")
else:
msg = f"Could not find file or folder at {input_path}"
raise FileNotFoundError(msg)

expected_size = (vid_width, vid_height)
all_samples = []
for _ in range(repetitions):
for input_path in all_img_paths:
for image in input_images:
start_time = time.perf_counter()
_seed = default(seed, random.randint(0, 1000000))
torch.manual_seed(_seed)
logger.info(
f"Generating a {num_frames} frame video from {input_path}. Device:{device} seed:{_seed}"
f"Generating a {num_frames} frame video from {image}. Device:{device} seed:{_seed}"
)
if input_path.startswith("http"):
image = LazyLoadingImage(url=input_path).as_pillow()
else:
image = LazyLoadingImage(filepath=input_path).as_pillow()

image = image.as_pillow()

crop_coords = None
if image.mode == "RGBA":
image = image.convert("RGB")
Expand Down Expand Up @@ -275,21 +246,16 @@ def denoiser(_input, sigma, c):
left, upper, right, lower = crop_coords
samples = samples[:, :, upper:lower, left:right]

os.makedirs(output_folder_str, exist_ok=True)
base_count = len(glob(os.path.join(output_folder_str, "*.*"))) + 1
source_slug = make_safe_filename(input_path)
video_filename = f"{base_count:06d}_{model_name}_{_seed}_{fps_id}fps_{source_slug}.{output_format}"
video_path = os.path.join(output_folder_str, video_filename)

samples = safety_filter(samples)
# save_video(samples, video_path, output_fps)
save_video_bounce(samples, video_path, output_fps)
all_samples.append(samples)

duration = time.perf_counter() - start_time
logger.info(
f"Video of {num_frames} frames generated in {duration:.2f} seconds and saved to {video_path}\n"
f"Video of {num_frames} frames generated in {duration:.2f} seconds\n"
)

return all_samples, output_fps


def save_video(samples: torch.Tensor, video_filename: str, output_fps: int):
"""
Expand Down Expand Up @@ -458,18 +424,3 @@ def pillow_fit_image_within(
if (w, h) != image.size:
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
return image


def make_safe_filename(input_string):
stripped_url = re.sub(r"^https?://[^/]+/", "", input_string)

# Remove directory path if present
base_name = os.path.basename(stripped_url)

# Remove file extension
name_without_extension = os.path.splitext(base_name)[0]

# Keep only alphanumeric characters and dashes
safe_name = re.sub(r"[^a-zA-Z0-9\-]", "", name_without_extension)

return safe_name
106 changes: 90 additions & 16 deletions imaginairy/cli/videogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,29 +85,103 @@ def videogen_cmd(
aimg videogen --start-image assets/rocket-wide.png
"""
import os
from glob import glob

from imaginairy.api.video_sample import generate_video
from imaginairy.utils import default
from imaginairy.utils.log_utils import configure_logging

configure_logging()

output_fps = output_fps or fps

all_images = []

try:
generate_video(
input_path=start_image,
num_frames=num_frames,
num_steps=steps,
model_name=model,
fps_id=fps,
size=size,
output_fps=output_fps,
output_format=output_format,
motion_bucket_id=motion_amount,
cond_aug=cond_aug,
seed=seed,
decoding_t=decoding_t,
output_folder=output_folder,
repetitions=repeats,
)
all_images.extend(load_images(start_image))
except FileNotFoundError as e:
logger.error(str(e))
exit(1)

output_folder_str = default(output_folder, "outputs/video/")

os.makedirs(output_folder_str, exist_ok=True)

samples, output_fps = generate_video(
input_images=all_images,
num_frames=num_frames,
num_steps=steps,
model_name=model,
fps_id=fps,
size=size,
output_fps=output_fps,
motion_bucket_id=motion_amount,
cond_aug=cond_aug,
seed=seed,
decoding_t=decoding_t,
repetitions=repeats,
)

for sample in samples:
base_count = len(glob(os.path.join(output_folder_str, "*.*"))) + 1
source_slug = make_safe_filename(sample)
video_filename = (
f"{base_count:06d}_{model}_{seed}_{fps}fps_{source_slug}.{output_format}"
)
video_path = os.path.join(output_folder_str, video_filename)

from imaginairy.api.video_sample import save_video_bounce

save_video_bounce(samples, video_path, output_fps)


def load_images(start_image):
from pathlib import Path

from imaginairy.schema import LazyLoadingImage

if start_image.startswith("http"):
image = LazyLoadingImage(url=start_image).as_pillow()
return [image]
else:
path = Path(start_image)
if path.is_file():
if any(start_image.endswith(x) for x in ["jpg", "jpeg", "png"]):
return [LazyLoadingImage(filepath=start_image).as_pillow()]
else:
raise ValueError("Path is not a valid image file.")
elif path.is_dir():
all_img_paths = sorted(
[
str(f)
for f in path.iterdir()
if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
]
)
if len(all_img_paths) == 0:
raise ValueError("Folder does not contain any images.")
return [
LazyLoadingImage(filepath=image).as_pillow() for image in all_img_paths
]
else:
msg = f"Could not find file or folder at {start_image}"
raise FileNotFoundError(msg)


def make_safe_filename(input_string):
import os
import re

stripped_url = re.sub(r"^https?://[^/]+/", "", input_string)

# Remove directory path if present
base_name = os.path.basename(stripped_url)

# Remove file extension
name_without_extension = os.path.splitext(base_name)[0]

# Keep only alphanumeric characters and dashes
safe_name = re.sub(r"[^a-zA-Z0-9\-]", "", name_without_extension)

return safe_name

0 comments on commit 030d126

Please sign in to comment.