Skip to content

Commit

Permalink
Fix some tiled VAE decoding issues with LTX-Video.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Nov 22, 2024
1 parent e5c3f4b commit 6e8cdcd
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
12 changes: 10 additions & 2 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None):
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.upscale_ratio = 8
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
Expand Down Expand Up @@ -370,7 +370,9 @@ def decode(self, samples_in):
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
pixel_samples = self.decode_tiled_3d(samples_in)
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))

pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
Expand Down Expand Up @@ -434,6 +436,12 @@ def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
def get_sd(self):
return self.first_stage_model.state_dict()

def spacial_compression_decode(self):
try:
return self.upscale_ratio[-1]
except:
return self.upscale_ratio

class StyleModel:
def __init__(self, model, device="cpu"):
self.model = model
Expand Down
3 changes: 2 additions & 1 deletion nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ def INPUT_TYPES(s):
def decode(self, vae, samples, tile_size, overlap=64):
if tile_size < overlap * 4:
overlap = tile_size // 4
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
compression = vae.spacial_compression_decode()
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
return (images, )
Expand Down

0 comments on commit 6e8cdcd

Please sign in to comment.