diff --git a/imaginairy/api/generate.py b/imaginairy/api/generate.py index 3806b347..e7683478 100755 --- a/imaginairy/api/generate.py +++ b/imaginairy/api/generate.py @@ -118,7 +118,7 @@ def _record_step(img, description, image_count, step_count, prompt): if videogen: try: generate_video( - input_path=filepath, + input_img=filepath, ) except FileNotFoundError as e: logger.error(str(e)) diff --git a/imaginairy/api/video_sample.py b/imaginairy/api/video_sample.py index 0d3e17a6..a612c30d 100644 --- a/imaginairy/api/video_sample.py +++ b/imaginairy/api/video_sample.py @@ -4,11 +4,8 @@ import math import os import random -import re import time -from glob import glob -from pathlib import Path -from typing import Any, Optional +from typing import Any, List, Optional import cv2 import numpy as np @@ -36,8 +33,7 @@ def generate_video( - input_path: str, # Can either be image file or folder with image files - output_folder: str | None = None, + input_images: List[LazyLoadingImage], size=(1024, 576), num_frames: int = 6, num_steps: int = 30, @@ -50,13 +46,12 @@ def generate_video( decoding_t: int = 1, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. device: Optional[str] = None, repetitions=1, - output_format="webp", ): """ Generates a video from a single image or multiple images, conditioned on the provided input_path. Args: - input_path (str): Path to an image file or a directory containing image files. + input_images (List[LazyLoadingImage]): List of LazyLoading images to be transformed into videos output_folder (str | None, optional): Directory where the generated video will be saved. Defaults to "outputs/video/" if None. num_frames (int, optional): Number of frames in the generated video. Defaults to 6. @@ -101,8 +96,7 @@ def generate_video( num_frames = default(num_frames, video_model_config.defaults.get("frames", 12)) num_steps = default(num_steps, video_model_config.defaults.get("steps", 30)) - output_folder_str = default(output_folder, "outputs/video/") - del output_folder + video_config_path = f"{PKG_ROOT}/{video_model_config.architecture.config_path}" model, safety_filter = load_model( @@ -113,42 +107,19 @@ def generate_video( weights_url=video_model_config.weights_location, ) - if input_path.startswith("http"): - all_img_paths = [input_path] - else: - path = Path(input_path) - if path.is_file(): - if any(input_path.endswith(x) for x in ["jpg", "jpeg", "png"]): - all_img_paths = [input_path] - else: - raise ValueError("Path is not valid image file.") - elif path.is_dir(): - all_img_paths = sorted( - [ - str(f) - for f in path.iterdir() - if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] - ] - ) - if len(all_img_paths) == 0: - raise ValueError("Folder does not contain any images.") - else: - msg = f"Could not find file or folder at {input_path}" - raise FileNotFoundError(msg) - expected_size = (vid_width, vid_height) + all_samples = [] for _ in range(repetitions): - for input_path in all_img_paths: + for image in input_images: start_time = time.perf_counter() _seed = default(seed, random.randint(0, 1000000)) torch.manual_seed(_seed) logger.info( - f"Generating a {num_frames} frame video from {input_path}. Device:{device} seed:{_seed}" + f"Generating a {num_frames} frame video from {image}. Device:{device} seed:{_seed}" ) - if input_path.startswith("http"): - image = LazyLoadingImage(url=input_path).as_pillow() - else: - image = LazyLoadingImage(filepath=input_path).as_pillow() + + image = image.as_pillow() + crop_coords = None if image.mode == "RGBA": image = image.convert("RGB") @@ -275,21 +246,16 @@ def denoiser(_input, sigma, c): left, upper, right, lower = crop_coords samples = samples[:, :, upper:lower, left:right] - os.makedirs(output_folder_str, exist_ok=True) - base_count = len(glob(os.path.join(output_folder_str, "*.*"))) + 1 - source_slug = make_safe_filename(input_path) - video_filename = f"{base_count:06d}_{model_name}_{_seed}_{fps_id}fps_{source_slug}.{output_format}" - video_path = os.path.join(output_folder_str, video_filename) - samples = safety_filter(samples) - # save_video(samples, video_path, output_fps) - save_video_bounce(samples, video_path, output_fps) + all_samples.append(samples) duration = time.perf_counter() - start_time logger.info( - f"Video of {num_frames} frames generated in {duration:.2f} seconds and saved to {video_path}\n" + f"Video of {num_frames} frames generated in {duration:.2f} seconds\n" ) + return all_samples, output_fps + def save_video(samples: torch.Tensor, video_filename: str, output_fps: int): """ @@ -458,18 +424,3 @@ def pillow_fit_image_within( if (w, h) != image.size: image = image.resize((w, h), resample=Image.Resampling.LANCZOS) return image - - -def make_safe_filename(input_string): - stripped_url = re.sub(r"^https?://[^/]+/", "", input_string) - - # Remove directory path if present - base_name = os.path.basename(stripped_url) - - # Remove file extension - name_without_extension = os.path.splitext(base_name)[0] - - # Keep only alphanumeric characters and dashes - safe_name = re.sub(r"[^a-zA-Z0-9\-]", "", name_without_extension) - - return safe_name diff --git a/imaginairy/cli/videogen.py b/imaginairy/cli/videogen.py index 9f3a45a1..05ec8948 100644 --- a/imaginairy/cli/videogen.py +++ b/imaginairy/cli/videogen.py @@ -85,29 +85,103 @@ def videogen_cmd( aimg videogen --start-image assets/rocket-wide.png """ + import os + from glob import glob + from imaginairy.api.video_sample import generate_video + from imaginairy.utils import default from imaginairy.utils.log_utils import configure_logging configure_logging() output_fps = output_fps or fps + + all_images = [] + try: - generate_video( - input_path=start_image, - num_frames=num_frames, - num_steps=steps, - model_name=model, - fps_id=fps, - size=size, - output_fps=output_fps, - output_format=output_format, - motion_bucket_id=motion_amount, - cond_aug=cond_aug, - seed=seed, - decoding_t=decoding_t, - output_folder=output_folder, - repetitions=repeats, - ) + all_images.extend(load_images(start_image)) except FileNotFoundError as e: logger.error(str(e)) exit(1) + + output_folder_str = default(output_folder, "outputs/video/") + + os.makedirs(output_folder_str, exist_ok=True) + + samples, output_fps = generate_video( + input_images=all_images, + num_frames=num_frames, + num_steps=steps, + model_name=model, + fps_id=fps, + size=size, + output_fps=output_fps, + motion_bucket_id=motion_amount, + cond_aug=cond_aug, + seed=seed, + decoding_t=decoding_t, + repetitions=repeats, + ) + + for sample in samples: + base_count = len(glob(os.path.join(output_folder_str, "*.*"))) + 1 + source_slug = make_safe_filename(sample) + video_filename = ( + f"{base_count:06d}_{model}_{seed}_{fps}fps_{source_slug}.{output_format}" + ) + video_path = os.path.join(output_folder_str, video_filename) + + from imaginairy.api.video_sample import save_video_bounce + + save_video_bounce(samples, video_path, output_fps) + + +def load_images(start_image): + from pathlib import Path + + from imaginairy.schema import LazyLoadingImage + + if start_image.startswith("http"): + image = LazyLoadingImage(url=start_image).as_pillow() + return [image] + else: + path = Path(start_image) + if path.is_file(): + if any(start_image.endswith(x) for x in ["jpg", "jpeg", "png"]): + return [LazyLoadingImage(filepath=start_image).as_pillow()] + else: + raise ValueError("Path is not a valid image file.") + elif path.is_dir(): + all_img_paths = sorted( + [ + str(f) + for f in path.iterdir() + if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] + ] + ) + if len(all_img_paths) == 0: + raise ValueError("Folder does not contain any images.") + return [ + LazyLoadingImage(filepath=image).as_pillow() for image in all_img_paths + ] + else: + msg = f"Could not find file or folder at {start_image}" + raise FileNotFoundError(msg) + + +def make_safe_filename(input_string): + import os + import re + + stripped_url = re.sub(r"^https?://[^/]+/", "", input_string) + + # Remove directory path if present + base_name = os.path.basename(stripped_url) + + # Remove file extension + name_without_extension = os.path.splitext(base_name)[0] + + # Keep only alphanumeric characters and dashes + safe_name = re.sub(r"[^a-zA-Z0-9\-]", "", name_without_extension) + + return safe_name