diff --git a/auto_editor/__main__.py b/auto_editor/__main__.py
index b8a6ab3e2..721f3e730 100755
--- a/auto_editor/__main__.py
+++ b/auto_editor/__main__.py
@@ -286,7 +286,16 @@ def get_domain(url: str) -> str:
 
 
 def main() -> None:
-    subcommands = ("test", "info", "levels", "subdump", "desc", "repl", "palet")
+    subcommands = (
+        "test",
+        "info",
+        "levels",
+        "subdump",
+        "desc",
+        "repl",
+        "palet",
+        "cache",
+    )
 
     if len(sys.argv) > 1 and sys.argv[1] in subcommands:
         obj = __import__(
diff --git a/auto_editor/analyze.py b/auto_editor/analyze.py
index 7672d900c..edc1aaa0f 100644
--- a/auto_editor/analyze.py
+++ b/auto_editor/analyze.py
@@ -4,6 +4,7 @@
 import re
 from dataclasses import dataclass
 from fractions import Fraction
+from hashlib import sha1
 from math import ceil
 from tempfile import gettempdir
 from typing import TYPE_CHECKING
@@ -154,8 +155,15 @@ def iter_motion(
 
 def obj_tag(path: Path, kind: str, tb: Fraction, obj: Sequence[object]) -> str:
+    """Return the cache entry name for analysing `path` as `kind`.
+
+    The path, its modification time and the timebase are hashed into a
+    16-hex-digit prefix; `kind` and the values of `obj` follow in plain text.
+    """
     mod_time = int(path.stat().st_mtime)
-    key = f"{path.name}:{mod_time:x}:{kind}:{tb}:"
-    return key + ",".join(f"{v}" for v in obj)
+    key = f"{path}:{mod_time:x}:{tb.numerator}:{tb.denominator}:"
+    part1 = sha1(key.encode()).hexdigest()[:16]
+
+    return f"{part1}{kind}," + ",".join(f"{v}" for v in obj)
 
 
 @dataclass(slots=True)
@@ -206,31 +214,38 @@ def read_cache(self, kind: str, obj: Sequence[object]) -> None | np.ndarray:
         if self.no_cache:
             return None
 
-        workfile = os.path.join(gettempdir(), f"ae-{__version__}", "cache.npz")
+        # Every cache entry is stored in its own file, named after its tag.
+        key = obj_tag(self.src.path, kind, self.tb, obj)
+        cache_file = os.path.join(gettempdir(), f"ae-{__version__}", f"{key}.npz")
 
         try:
-            npzfile = np.load(workfile, allow_pickle=False)
+            # Read the array while the archive is open so the handle is closed.
+            with np.load(cache_file, allow_pickle=False) as npzfile:
+                arr = npzfile["data"]
         except Exception as e:
+            # Missing or unreadable entries are treated as a cache miss.
             self.log.debug(e)
             return None
 
-        key = obj_tag(self.src.path, kind, self.tb, obj)
-        if key not in npzfile.files:
-            return None
-
         self.log.debug("Using cache")
-        return npzfile[key]
+        return arr
 
     def cache(self, arr: np.ndarray, kind: str, obj: Sequence[object]) -> np.ndarray:
         if self.no_cache:
             return arr
 
-        workdur = os.path.join(gettempdir(), f"ae-{__version__}")
-        if not os.path.exists(workdur):
-            os.mkdir(workdur)
+        workdir = os.path.join(gettempdir(), f"ae-{__version__}")
 
         key = obj_tag(self.src.path, kind, self.tb, obj)
-        np.savez(os.path.join(workdur, "cache.npz"), **{key: arr})
+        cache_file = os.path.join(workdir, f"{key}.npz")
+
+        # Caching is best effort: directory creation and the write are
+        # attempted together so that any failure only produces a warning.
+        try:
+            os.makedirs(workdir, exist_ok=True)
+            np.savez(cache_file, data=arr)
+        except Exception as e:
+            self.log.warning(f"Cache write failed: {e}")
 
         return arr
 
@@ -257,14 +272,15 @@ def audio(self, stream: int) -> NDArray[np.float32]:
         bar = self.bar
         bar.start(inaccurate_dur, "Analyzing audio volume")
 
-        result = np.zeros((inaccurate_dur), dtype=np.float32)
+        result: NDArray[np.float32] = np.zeros(inaccurate_dur, dtype=np.float32)
         index = 0
 
         for value in iter_audio(audio, self.tb):
             if index > len(result) - 1:
                 result = np.concatenate(
-                    (result, np.zeros((len(result)), dtype=np.float32))
+                    (result, np.zeros(len(result), dtype=np.float32))
                 )
+
             result[index] = value
             bar.tick(index)
             index += 1
@@ -296,13 +312,13 @@ def motion(self, stream: int, blur: int, width: int) -> NDArray[np.float32]:
         bar = self.bar
         bar.start(inaccurate_dur, "Analyzing motion")
 
-        result = np.zeros((inaccurate_dur), dtype=np.float32)
+        result: NDArray[np.float32] = np.zeros(inaccurate_dur, dtype=np.float32)
         index = 0
 
         for value in iter_motion(video, self.tb, blur, width):
             if index > len(result) - 1:
                 result = np.concatenate(
-                    (result, np.zeros((len(result)), dtype=np.float32))
+                    (result, np.zeros(len(result), dtype=np.float32))
                 )
             result[index] = value
             bar.tick(index)
diff --git a/auto_editor/subcommands/cache.py b/auto_editor/subcommands/cache.py
new file mode 100644
index 000000000..b9062d14b
--- /dev/null
+++ b/auto_editor/subcommands/cache.py
@@ -0,0 +1,70 @@
+import glob
+import os
+import sys
+from shutil import rmtree
+from tempfile import gettempdir
+
+import numpy as np
+
+from auto_editor import __version__
+
+
+def format_bytes(size: float) -> str:
+    """Format a byte count with a binary unit, e.g. ``1.50 KiB``."""
+    for unit in ("B", "KiB", "MiB", "GiB", "TiB"):
+        if size < 1024:
+            return f"{size:.2f} {unit}"
+        size /= 1024
+    return f"{size:.2f} PiB"
+
+
+def main(sys_args: list[str] = sys.argv[1:]) -> None:
+    """List the analysis cache entries, or delete them with `clean`/`clear`."""
+    cache_dir = os.path.join(gettempdir(), f"ae-{__version__}")
+
+    if sys_args and sys_args[0] in ("clean", "clear"):
+        rmtree(cache_dir, ignore_errors=True)
+        return
+
+    # glob() yields nothing for a missing directory, so one check covers both
+    # "no cache directory" and "no entries". Sorting keeps the listing stable.
+    cache_files = sorted(glob.glob(os.path.join(cache_dir, "*.npz")))
+    if not cache_files:
+        print("Empty cache")
+        return
+
+    # Only emit ANSI escape codes when a terminal will interpret them.
+    if sys.stdout.isatty() and not os.environ.get("NO_COLOR"):
+        GRAY, GREEN, BLUE = "\033[90m", "\033[32m", "\033[34m"
+        YELLOW, RESET = "\033[33m", "\033[0m"
+    else:
+        GRAY = GREEN = BLUE = YELLOW = RESET = ""
+
+    total_size = 0
+    for cache_file in cache_files:
+        try:
+            with np.load(cache_file, allow_pickle=False) as npzfile:
+                size = npzfile["data"].nbytes
+        except Exception as e:
+            # Skip unreadable entries, keeping stdout free of diagnostics.
+            print(f"Error reading {cache_file}: {e}", file=sys.stderr)
+            continue
+
+        total_size += size
+
+        # The first 16 characters are the hash prefix written by obj_tag().
+        key = os.path.splitext(os.path.basename(cache_file))[0]
+        hash_part, rest_part = key[:16], key[16:]
+        size_num, size_unit = format_bytes(size).rsplit(" ", 1)
+
+        print(
+            f"{YELLOW}entry: {GRAY}{hash_part}{RESET}{rest_part} "
+            f"{YELLOW}size: {GREEN}{size_num} {BLUE}{size_unit}{RESET}"
+        )
+
+    total_num, total_unit = format_bytes(total_size).rsplit(" ", 1)
+    print(f"\n{YELLOW}total cache size: {GREEN}{total_num} {BLUE}{total_unit}{RESET}")
+
+
+if __name__ == "__main__":
+    main()