Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ltxv: add noise to guidance image to ensure generated motion. #5937

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions comfy/ldm/lightricks/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def __init__(self,
positional_embedding_max_pos=[20, 2048, 2048],
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.generator = None
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
Expand Down Expand Up @@ -417,6 +418,7 @@ def __init__(self,

def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
image_noise_scale = transformer_options.get("image_noise_scale", 0.15)

indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
Expand All @@ -435,6 +437,17 @@ def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_l
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
if image_noise_scale > 0:
if self.generator is None:
self.generator = torch.Generator(device=x.device).manual_seed(42)
elif self.generator.device != x.device:
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())

noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
guiding_noise = image_noise_scale * (input_ts ** 2) * torch.randn(size=noise_shape, device=x.device, generator=self.generator)

x[:, :, 0] += guiding_noise[:, :, 0]


orig_shape = list(x.shape)

Expand Down
6 changes: 5 additions & 1 deletion comfy_extras/nodes_lt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import nodes
import node_helpers
import torch
Expand Down Expand Up @@ -77,6 +78,7 @@ def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
"image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 100, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."})
},
"optional": {"latent": ("LATENT",), }
}
Expand All @@ -86,7 +88,7 @@ def INPUT_TYPES(s):

CATEGORY = "advanced/model"

def patch(self, model, max_shift, base_shift, latent=None):
def patch(self, model, max_shift, base_shift, image_noise_scale, latent=None):
m = model.clone()

if latent is None:
Expand All @@ -109,6 +111,8 @@ class ModelSamplingAdvanced(sampling_base, sampling_type):
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift=shift)
m.add_object_patch("model_sampling", model_sampling)
m.model_options.setdefault("transformer_options", {})["image_noise_scale"] = image_noise_scale

return (m, )


Expand Down