[core] Mochi T2V #9769
Conversation
Our implementation with full BF16 == Original implementation with full BF16. Whether you mask out tokens or don't, or whether you use SDPA or Flash attention, as long as you keep the conditions same in both implementations, the videos are the exact same and contain the artifacting. Our implementation with autocast BF16 (on pipeline) != Original implementation with autocast (only on transformer). I suspect this might be the problem - text encoder and vae might always have to be run in FP32, which is the only way to match the implementations and make sure they are running under similar conditions - but I haven't found the time to try this yet since we're all occupied until tomorrow for you know what 👀 👀 👀
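One way to test the FP32 text encoder / VAE theory would be something like the untested sketch below: load everything in FP32 and wrap only the transformer forward in BF16 autocast. The forward monkeypatch is only an illustration (not an API the pipeline exposes), and the prompt/settings are placeholders.

```python
# Untested sketch (not a verified repro): keep the text encoder and VAE in FP32 and
# autocast only the transformer forward to BF16, to mimic "autocast only on transformer".
# FP32 weights need a large GPU; this is about matching numerics, not efficiency.
import torch
from diffusers import MochiPipeline
from diffusers.utils import export_to_video

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float32)
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()

# Wrap just the transformer forward in BF16 autocast; the T5 text encoder and the
# VAE decode stay in FP32.
orig_forward = pipe.transformer.forward

def bf16_forward(*args, **kwargs):
    with torch.autocast("cuda", dtype=torch.bfloat16):
        return orig_forward(*args, **kwargs)

pipe.transformer.forward = bf16_forward

frames = pipe(
    "a cat walks along the sidewalk of a city",  # placeholder prompt
    num_inference_steps=64,
    guidance_scale=4.5,
    height=480,
    width=848,
    num_frames=19,
).frames[0]
export_to_video(frames, "mochi.mp4", fps=15)
```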
I have numerically matched results so can confidently rule out anything scheduler related. |
@a-r-r-o-w |
Oh alright. Thanks for confirming. But are we using any precomputed stuff here? @YanzuoLu thanks! |
It doesn't get better and the generated content is so blurry. And I think the major problem here would be the vae tiling. 000.mp4 |
Thank you so much for continuously helping with testing things here! If you use normal generation and compare it to tiled vae output and tiled + framewise vae output, I think they are all of the same quality (maybe a few color differences because of alpha blending) but effectively the same results, no? |
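For a controlled comparison, an untested sketch like the one below could generate with the same seed three times while only changing how the VAE decodes (the prompt, step count, and frame count are placeholders; _enable_framewise_decoding is the private helper also used later in this thread).

```python
import torch
from diffusers import MochiPipeline
from diffusers.utils import export_to_video

# Untested sketch: same seed, three decode paths (plain, tiled, tiled + framewise).
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "a cat walks along the sidewalk of a city"  # placeholder prompt

def run(tag):
    generator = torch.Generator().manual_seed(42)
    frames = pipe(prompt, num_inference_steps=64, num_frames=19, generator=generator).frames[0]
    export_to_video(frames, f"mochi_{tag}.mp4", fps=15)

run("plain")                           # full-frame decode; may OOM on smaller GPUs

pipe.enable_vae_tiling()
run("tiled")                           # spatial tiling only

pipe.vae._enable_framewise_decoding()  # private helper used elsewhere in this thread
run("tiled_framewise")                 # tiling + framewise decoding
```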
I think your video is already definitely better though! In the previous generations with pure bf16 of diffusers, there is a lot of blocky patching and temporal artifacts, but very few in the video you shared. FP32 text encoder and VAE definitely seems to be helping, don't you think? |
Hi @a-r-r-o-w |
@YanzuoLu but they don't apply FSDP to the VAE no? |
@sayakpaul |
FSDP is different from context-parallel and can be applied separately: https://github.com/genmoai/mochi/blob/5ebe56a403ca2bbdc577edc3a11da1c0b7d624fa/src/genmo/mochi_preview/vae/models.py#L1013 |
@a-r-r-o-w |
Are you talking about the attention in VAE encoder? Or the attention in transformer? Since they are different implementations IIRC. I'll take a look soon after Flux release and address the questions in the other PR as well 😅 |
Yeah I mean the attention in vae. |
@YanzuoLu are you running the original codebase on a single GPU? If so, we can set the option so that no tiling is enabled, no? Edit: not possible to run without VAE tiling on a single GPU. |
@sayakpaul No I am using the diffusers implementation. |
My final goal is to make sure the diffusers implementation is correct (at least in the multi-GPU pipeline). |
Thank you very much for your generous support! Happy to create a new channel with you on Slack (since we're already connected via Hyper-SD). Could you ping me there? |
Is this closer to ideal? I'll post details after a bit more testing mochi_rl_gnf15_yl_6_3095062583.mp4 |
@Ednaordinary This does indeed look much better. Please spill your secrets! Could you maybe also share something with more dynamic motion? |
Here's something with more motion. mochi.mp4

It's a few changes to my original configuration in #9769 (comment) (notable attributes here being fp16, float8_e4m3fn, raised cfg scale). Compared to that, I boosted the cfg scale to 8.5. There are still noticeable artifacts with high motion, though that could also be due to me downcasting some layers to float8. Framewise decoding also seems better for spatial understanding but worse for temporal understanding with slicing + tiling. |
Comparison with tiled + slicing (only 49 steps, though): mochi.mp4

Noticeably, some artifacts around the legs in the beginning are gone, though the legs now glitch through each other more (could also be the 49 steps vs 64). |
Whoops, I made a mistake. My last three comments still had the vae in fp16. Give me a bit to correct that and see what's different |
Okay, I have something new to drop. I played around with the guidance scale across the timestep schedule and found that a large part of the clarity issues can be resolved by simply increasing the guidance scale towards the end of the schedule.

Low end guidance: mochi_rl_yl_gnd_e2.mp4

High end guidance: mochi_rl_yl_gnd_e25.mp4

Both of these are in just 25 steps, 2 minutes on a 3090 Ti.

After modeling out some equations in Desmos (which you can find here), I ported them as the callback for the pipeline, making the guidance scale a parabola. After playing around with values, I've found that l/guidance_scale_begin has some effect on temporal motion (which makes sense because it's denoising at a large sigma). I also found that j/guidance_scale_mid should be a lower value around 4.5, otherwise the model is extremely susceptible to cfg baking. The n/mid_point defines where j sits on the timestep scale, allowing better control over where the susceptible timesteps are (I found 0.3 works best, and even 0.4 lets stuff bake). The k/guidance_scale_end value is the extremely important one and is the difference between the two videos above. It defines the guidance towards the end of the denoising process, which seems to be where stuff is getting blurry.

Here's the script (still optimized for 24 GB of VRAM, including the float8 downcasting; if you just want the callback, you should be able to move callback_dynamic_cfg and its values to your own script):

from diffusers import MochiPipeline, MochiTransformer3DModel, AutoencoderKLMochi
from diffusers.models.transformers.transformer_mochi import MochiTransformerBlock
import torch
from diffusers.utils import export_to_video
import math
# The below defines how many layers to quantize. The amount of layers kept is quant_div / quant_mod. Think of this as an inverted quality slider
quant_div = 1
quant_mod = 2
full_dtype = torch.float16
cast_dtype = torch.float8_e4m3fn
#torch.manual_seed(42)
torch.manual_seed(3095062583)
if quant_div > quant_mod: print("quant_div should be less than or equal to quant_mod")
# Credits to `dn6`
# Copy-pasted + slight edit from
# https://github.com/huggingface/diffusers/blob/layerwise-upcasting/src/diffusers/models/modeling_utils.py
def enable_layerwise_upcasting(model, upcast_dtype=None, original_dtype=None):
    upcast_dtype = upcast_dtype or torch.float32
    original_dtype = original_dtype or model.dtype

    def upcast_dtype_hook_fn(module, *args, **kwargs):
        module = module.to(upcast_dtype)

    def cast_to_original_dtype_hook_fn(module, *args, **kwargs):
        module = module.to(original_dtype)

    def fn_recursive_upcast(module):
        # Upcast the module right before its forward and cast it back right after,
        # then recurse so every submodule gets the same pair of hooks.
        module.register_forward_pre_hook(upcast_dtype_hook_fn)
        module.register_forward_hook(cast_to_original_dtype_hook_fn)
        for child in module.children():
            fn_recursive_upcast(child)

    for module in model.children():
        fn_recursive_upcast(module)
print("Loading transformer")
transformer = MochiTransformer3DModel.from_pretrained("genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.float32)
print("Adding cast hooks to transformer")
block_idx = 0
mean_attns = []
block = 0
for idx, i in enumerate(transformer.modules()):
    if isinstance(i, MochiTransformerBlock) and idx != 2084: # 2084 is the last layer, and should likely be skipped.
        mean_attns.append((idx, torch.mean(torch.ravel(i.attn1.norm_q.weight)), block)) # can be changed with norm_q, values seem similar
        block += 1
mean_attns.sort(key=lambda x: x[1])
attn_high = [x[0] for x in mean_attns[:int(len(mean_attns) // (quant_mod / quant_div))]]
print([x[2] for x in mean_attns[:int(len(mean_attns) // (quant_mod / quant_div))]])
transformer.to(torch.float16)
for idx, i in enumerate(transformer.modules()):
    if isinstance(i, MochiTransformerBlock):
        if idx in attn_high:
            print("\r", idx, end='')
            i.to(cast_dtype)
            enable_layerwise_upcasting(i, upcast_dtype=full_dtype, original_dtype=cast_dtype)
vae = AutoencoderKLMochi.from_pretrained("genmo/mochi-1-preview", subfolder="vae", torch_dtype=torch.float32)
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=full_dtype, transformer=transformer, vae=vae)
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
guidance_scale_begin = 10.0 # Large effect on temporal motion ?
guidance_scale_mid = 3.5 # Should stay at ~4.5
mid_point = 0.3 #between 0 and 1. 0.3 is a good default.
guidance_scale_end = 25.0 # Large effect on spatial clarity
steps=25
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    step_index = step_index + 1 # liar (make step_index 1-indexed)
    if step_index > (steps*mid_point):
        pipe._guidance_scale = (((guidance_scale_end-guidance_scale_mid) / math.pow(steps - (steps*mid_point),2)) * math.pow(step_index - (steps*mid_point),2)) + guidance_scale_mid
    else:
        pipe._guidance_scale = (guidance_scale_begin-guidance_scale_mid)*(math.pow(step_index-(steps*mid_point),2)/math.pow(steps*mid_point,2)) + guidance_scale_mid
    print("Current guidance scale:", pipe._guidance_scale)
    return callback_kwargs
frames = pipe("A woman in a white apron is diligently washing dishes in a spacious, well-lit kitchen.The gleaming stainless steel sink is filled with soapy water, and a variety of colorful plates, bowls, and utensils await cleaning. The countertops, adorned with a minimalist arrangement of olive oil, salt, and pepper, offer a backdrop of warm wood against the cool metallic surfaces. Soft, ambient lighting casts a warm glow across the scene, highlighting the woman's focused expression as she meticulously cleans eachpiece.", num_inference_steps=steps, guidance_scale=guidance_scale_begin, height=480, width=848, num_frames=19, callback_on_step_end=callback_dynamic_cfg, callback_on_step_end_tensor_inputs=["latents", "prompt_embeds", "negative_prompt_embeds"]).frames[0]
export_to_video(frames, "mochi.mp4", fps=15)

(Either |
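For reference, reading the callback off the script above, the guidance schedule it implements is a piecewise parabola in the (1-indexed) step s, with N = steps and m = N · mid_point:

$$
g(s) =
\begin{cases}
g_{\text{mid}} + \left(g_{\text{begin}} - g_{\text{mid}}\right)\dfrac{(s-m)^2}{m^2}, & s \le m,\\[1.5ex]
g_{\text{mid}} + \left(g_{\text{end}} - g_{\text{mid}}\right)\dfrac{(s-m)^2}{(N-m)^2}, & s > m.
\end{cases}
$$

So the scale starts near guidance_scale_begin, dips to guidance_scale_mid at the midpoint, and rises to guidance_scale_end on the final step.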
@Ednaordinary thank you! Are you running your tests from the |
@sayakpaul I have the Pyramid Attention Broadcast branch installed, though that shouldn't make a difference compared to main in this instance |
@Ednaordinary @YanzuoLu thanks for your continuous guidance here. With the attached script, I am able to do: mochi.mp4

I am on the

Script:

from diffusers import MochiPipeline, AutoencoderKLMochi
from diffusers.utils import export_to_video
from diffusers.video_processor import VideoProcessor
import torch
import gc
pipe = MochiPipeline.from_pretrained(
"genmo/mochi-1-preview", transformer=None, vae=None
).to("cuda")
prompt = "a cat walks along the sidewalk of a city. The camera follows the cat at knee level. The city has many people and cars moving around, with advertisement billboards in the background"
with torch.no_grad():
    prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
        pipe.encode_prompt(prompt=prompt)
    )
del pipe.text_encoder
del pipe
gc.collect()
pipe = MochiPipeline.from_pretrained(
"genmo/mochi-1-preview", text_encoder=None, vae=None
).to("cuda")
with torch.autocast("cuda", dtype=torch.bfloat16):
    frames = pipe(
        prompt_embeds=prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_embeds=negative_prompt_embeds,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        guidance_scale=8.5,
        num_inference_steps=64,
        height=480,
        width=848,
        num_frames=85,
        output_type="latent",
        return_dict=False,
    )[0]
print(f"{frames.shape=}, {frames.dtype=}")
del pipe.transformer
del pipe
gc.collect()
vae = AutoencoderKLMochi.from_pretrained("genmo/mochi-1-preview", subfolder="vae").to("cuda")
vae._enable_framewise_decoding()
video_processor = VideoProcessor(vae_scale_factor=8)
has_latents_mean = hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None
has_latents_std = hasattr(vae.config, "latents_std") and vae.config.latents_std is not None
if has_latents_mean and has_latents_std:
    latents_mean = (
        torch.tensor(vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
    )
    latents_std = (
        torch.tensor(vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
    )
    frames = frames * latents_std / vae.config.scaling_factor + latents_mean
else:
    frames = frames / vae.config.scaling_factor
with torch.no_grad():
    video = vae.decode(frames.to(vae.dtype), return_dict=False)[0]
video = video_processor.postprocess_video(video)[0]
export_to_video(video, "mochi.mp4", fps=30)

We should note that the entire generation process takes quite a bit (~7 mins) to complete, so any optimization (simpler preferred) we do here should be realized. Main changes in the script: the prompt is encoded with the FP32 text encoder first, the transformer then runs under BF16 autocast and returns latents, and the latents are finally decoded by the FP32 VAE with framewise decoding, loading and freeing each component separately to keep memory in check.
|
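A small side note on the manual decode stage in the script above: when latents_mean and latents_std are present in the VAE config, the latents z are mapped back to the VAE's input space before vae.decode as

$$
z_{\text{vae}} = z \cdot \frac{\sigma_{\text{latents}}}{s} + \mu_{\text{latents}}, \qquad s = \text{scaling factor},
$$

with the mean and std broadcast over the 12 latent channels; otherwise the latents are simply divided by the scaling factor.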
Hey y'all, I think I might have something new to drop. I realized that we might be able to split the transformer part of this model across GPUs, much like Flux. All I had to do was add _no_split_modules = ["MochiTransformerBlock"] in the MochiTransformer3DModel class. With that, I was able to split the transformer like this:

transformer = MochiTransformer3DModel.from_pretrained(
id,
subfolder="transformer",
device_map="auto",
max_memory={0: "48GB", 1: "48GB"},
torch_dtype=torch.float16
)

It seems to work on my end. If someone can verify this, that would be awesome! |
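For completeness, an end-to-end version with the pipeline call would presumably look something like the untested sketch below (the model id, memory limits, prompt, and the manual device placement of the text encoder and VAE are assumptions, not verified settings):

```python
import torch
from diffusers import MochiPipeline, MochiTransformer3DModel
from diffusers.utils import export_to_video

# Untested sketch: shard only the transformer across two GPUs via device_map="auto"
# (relies on _no_split_modules = ["MochiTransformerBlock"] being set on the model class),
# keep the text encoder and VAE on one GPU, then run the usual pipeline call.
model_id = "genmo/mochi-1-preview"
transformer = MochiTransformer3DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    device_map="auto",
    max_memory={0: "48GB", 1: "48GB"},
    torch_dtype=torch.float16,
)

pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
# The sharded transformer already lives on the GPUs; place the remaining components manually
# (exact placement is an assumption here).
pipe.text_encoder.to("cuda:0")
pipe.vae.to("cuda:0")
pipe.enable_vae_tiling()

frames = pipe(
    "a cat walks along the sidewalk of a city",  # placeholder prompt
    num_inference_steps=64,
    guidance_scale=4.5,
    height=480,
    width=848,
    num_frames=19,
).frames[0]
export_to_video(frames, "mochi.mp4", fps=15)
```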
Hey @jmahajan117! Thanks for testing that. If you want to open a PR, I can help test it. We have a bunch of tests to verify that it works as expected:
Also, if you put it together in an end-to-end code snippet (with the pipeline call), that would be good. Here's our docs. |
* update * udpate * update transformer * make style * fix * add conversion script * update * fix * update * fix * update * fixes * make style * update * update * update * init * update * update * add * up * up * up * update * mochi transformer * remove original implementation * make style * update inits * update conversion script * docs * Update src/diffusers/pipelines/mochi/pipeline_mochi.py Co-authored-by: Dhruv Nair <[email protected]> * Update src/diffusers/pipelines/mochi/pipeline_mochi.py Co-authored-by: Dhruv Nair <[email protected]> * fix docs * pipeline fixes * make style * invert sigmas in scheduler; fix pipeline * fix pipeline num_frames * flip proj and gate in swiglu * make style * fix * make style * fix tests * latent mean and std fix * update * cherry-pick 1069d21 * remove additional sigma already handled by flow match scheduler * fix * remove hardcoded value * replace conv1x1 with linear * Update src/diffusers/pipelines/mochi/pipeline_mochi.py Co-authored-by: Dhruv Nair <[email protected]> * framewise decoding and conv_cache * make style * Apply suggestions from code review * mochi vae encoder changes * rebase correctly * Update scripts/convert_mochi_to_diffusers.py * fix tests * fixes * make style * update * make style * update * add framewise and tiled encoding * make style * make original vae implementation behaviour the default; note: framewise encoding does not work * remove framewise encoding implementation due to presence of attn layers * fight test 1 * fight test 2 --------- Co-authored-by: Dhruv Nair <[email protected]> Co-authored-by: yiyixuxu <[email protected]>
Fixes #9744
GitHub: https://github.com/genmoai/models
Model: https://huggingface.co/genmo/mochi-1-preview