diff --git a/backend/src/api/node_context.py b/backend/src/api/node_context.py
index 270cd52bb..e5d70e823 100644
--- a/backend/src/api/node_context.py
+++ b/backend/src/api/node_context.py
@@ -148,9 +148,11 @@ def storage_dir(self) -> Path:
         """
 
     @abstractmethod
-    def add_cleanup(self, fn: Callable[[], None]) -> None:
+    def add_cleanup(
+        self, fn: Callable[[], None], after: Literal["node", "chain"] = "chain"
+    ) -> None:
         """
-        Registers a function that will be called when the chain execution is finished.
+        Registers a function that will be called when the chain execution is finished (after="chain", the default) or when the current node's execution is finished (after="node").
 
         Registering the same function (object) twice will only result in the function being called once.
         """
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py
index c409a03e2..43d18954e 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py
@@ -220,7 +220,11 @@ def align_image_to_reference_node(
     alignment_passes: int,
     blur_strength: float,
 ) -> np.ndarray:
-    context.add_cleanup(safe_cuda_cache_empty)
+    exec_options = get_settings(context)
+    context.add_cleanup(
+        safe_cuda_cache_empty,
+        after="node" if exec_options.force_cache_wipe else "chain",
+    )
     multiplier = precision.value / 1000
     return align_images(
         context,
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/guided_upscale.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/guided_upscale.py
index c0a00b276..b7778f225 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/guided_upscale.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/guided_upscale.py
@@ -79,11 +79,15 @@ def guided_upscale_node(
     iterations: float,
     split_mode: SplitMode,
 ) -> np.ndarray:
-    context.add_cleanup(safe_cuda_cache_empty)
+    exec_options = get_settings(context)
+    context.add_cleanup(
+        safe_cuda_cache_empty,
+        after="node" if exec_options.force_cache_wipe else "chain",
+    )
     return pix_transform_auto_split(
         source=source,
         guide=guide,
-        device=get_settings(context).device,
+        device=exec_options.device,
         params=Params(iteration=int(iterations * 1000)),
         split_mode=split_mode,
     )
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py
index 8b66d6ee8..226d1ce2e 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py
@@ -114,7 +114,9 @@ def inpaint_node(
     ), "Input image and mask must have the same resolution"
 
     exec_options = get_settings(context)
-
-    context.add_cleanup(safe_cuda_cache_empty)
+    context.add_cleanup(
+        safe_cuda_cache_empty,
+        after="node" if exec_options.force_cache_wipe else "chain",
+    )
 
     return inpaint(img, mask, model, exec_options)
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py
index e4db77f6d..b57750d83 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py
@@ -265,7 +265,10 @@ def upscale_image_node(
 ) -> np.ndarray:
     exec_options = get_settings(context)
 
-    context.add_cleanup(safe_cuda_cache_empty)
+    context.add_cleanup(
+        safe_cuda_cache_empty,
+        after="node" if exec_options.force_cache_wipe else "chain",
+    )
 
     in_nc = model.input_channels
     out_nc = model.output_channels
@@ -296,5 +299,4 @@ def inner_upscale(img: np.ndarray) -> np.ndarray:
     if not use_custom_scale or scale == 1 or in_nc != out_nc:
         # no custom scale
         custom_scale = scale
-
     return custom_scale_upscale(img, inner_upscale, scale, custom_scale, separate_alpha)
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/wavelet_color_fix.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/wavelet_color_fix.py
index 8aa0cd6e1..75413b82a 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/wavelet_color_fix.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/wavelet_color_fix.py
@@ -79,7 +79,10 @@ def wavelet_color_fix_node(
     )
 
     exec_options = get_settings(context)
-    context.add_cleanup(safe_cuda_cache_empty)
+    context.add_cleanup(
+        safe_cuda_cache_empty,
+        after="node" if exec_options.force_cache_wipe else "chain",
+    )
     device = exec_options.device
 
     # convert to tensors
diff --git a/backend/src/packages/chaiNNer_pytorch/settings.py b/backend/src/packages/chaiNNer_pytorch/settings.py
index 7efa28cf9..1bbdc99f7 100644
--- a/backend/src/packages/chaiNNer_pytorch/settings.py
+++ b/backend/src/packages/chaiNNer_pytorch/settings.py
@@ -75,6 +75,16 @@
     )
 )
 
+if nvidia.is_available:
+    package.add_setting(
+        ToggleSetting(
+            label="Force CUDA Cache Wipe (not recommended)",
+            key="force_cache_wipe",
+            description="Clears PyTorch's CUDA cache after each inference. This is NOT recommended by us or by PyTorch's developers, as it interferes with how PyTorch's caching allocator is designed to work and can significantly slow down inference. Only enable this if you're experiencing issues with VRAM allocation.",
+            default=False,
+        )
+    )
+
 
 @dataclass(frozen=True)
 class PyTorchSettings:
@@ -82,6 +92,7 @@ class PyTorchSettings:
     use_fp16: bool
     gpu_index: int
     budget_limit: int
+    force_cache_wipe: bool = False
 
     # PyTorch 2.0 does not support FP16 when using CPU
     def __post_init__(self):
@@ -122,4 +133,5 @@ def get_settings(context: NodeContext) -> PyTorchSettings:
         use_fp16=settings.get_bool("use_fp16", False),
         gpu_index=settings.get_int("gpu_index", 0, parse_str=True),
         budget_limit=settings.get_int("budget_limit", 0, parse_str=True),
+        force_cache_wipe=settings.get_bool("force_cache_wipe", False),
     )
diff --git a/backend/src/process.py b/backend/src/process.py
index fbaf5d4bf..d7052bb9a 100644
--- a/backend/src/process.py
+++ b/backend/src/process.py
@@ -8,7 +8,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Callable, Iterable, List, NewType, Sequence, Union
+from typing import Callable, Iterable, List, Literal, NewType, Sequence, Union
 
 from sanic.log import logger
 
@@ -342,7 +342,8 @@ def __init__(
         self.__settings = settings
         self._storage_dir = storage_dir
 
-        self.cleanup_fns: set[Callable[[], None]] = set()
+        self.chain_cleanup_fns: set[Callable[[], None]] = set()
+        self.node_cleanup_fns: set[Callable[[], None]] = set()
 
     @property
     def aborted(self) -> bool:
@@ -373,8 +374,15 @@ def settings(self) -> SettingsParser:
     def storage_dir(self) -> Path:
         return self._storage_dir
 
-    def add_cleanup(self, fn: Callable[[], None]) -> None:
-        self.cleanup_fns.add(fn)
+    def add_cleanup(
+        self, fn: Callable[[], None], after: Literal["node", "chain"] = "chain"
+    ) -> None:
+        if after == "chain":
+            self.chain_cleanup_fns.add(fn)
+        elif after == "node":
+            self.node_cleanup_fns.add(fn)
+        else:
+            raise ValueError(f"Unknown cleanup type: {after}")
 
 
 class Executor:
@@ -591,6 +599,14 @@ def get_lazy_evaluation_time():
             )
             await self.progress.suspend()
 
+            # iterate over a copy so the set is not mutated mid-iteration
+            for fn in list(context.node_cleanup_fns):
+                try:
+                    fn()
+                except Exception as e:
+                    logger.error(f"Error running cleanup function: {e}")
+            context.node_cleanup_fns.clear()
+
             lazy_time_after = get_lazy_evaluation_time()
             execution_time -= lazy_time_after - lazy_time_before
@@ -824,7 +840,7 @@ async def __process_nodes(self):
 
         # Run cleanup functions
         for context in self.__context_cache.values():
-            for fn in context.cleanup_fns:
+            for fn in context.chain_cleanup_fns:
                 try:
                     fn()
                 except Exception as e:
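For illustration only, a minimal sketch of how a node author would use the extended add_cleanup API from this diff. The node function is hypothetical; add_cleanup's "after" parameter, get_settings, safe_cuda_cache_empty, and the force_cache_wipe setting are the pieces introduced above, and the imports are assumed to match the usual chaiNNer node-module scope.

    # Hypothetical example node -- not part of this diff. Assumes np (numpy),
    # NodeContext, get_settings, and safe_cuda_cache_empty are already imported,
    # as in the processing nodes changed above.
    def example_pytorch_node(context: NodeContext, img: np.ndarray) -> np.ndarray:
        exec_options = get_settings(context)

        # after="node": wipe the CUDA cache as soon as this node finishes
        # (only when the user enabled force_cache_wipe; slower, frees VRAM early).
        # after="chain": defer the wipe until the whole chain finishes (default).
        context.add_cleanup(
            safe_cuda_cache_empty,
            after="node" if exec_options.force_cache_wipe else "chain",
        )

        return img  # a real node would run inference here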