
[core] Mochi T2V #9769

Merged
merged 79 commits into from Nov 5, 2024
Conversation

@a-r-r-o-w
Member Author

Our implementation with full BF16 == Original implementation with full BF16. Whether you mask out tokens or not, and whether you use SDPA or Flash Attention, as long as you keep the conditions the same in both implementations, the videos are exactly the same and contain the artifacting.

Our implementation with autocast BF16 (on the pipeline) != Original implementation with autocast (only on the transformer). I suspect this might be the problem - the text encoder and VAE might always have to be run in FP32, which is the only way to match the implementations and make sure they are running under similar conditions - but I haven't found the time to try this yet since we're all occupied until tomorrow for you know what 👀 👀 👀
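To make the comparison concrete, here is roughly where autocast sits in the two conditions (an untested sketch, not exact code from either repo; condition B leaves the FP32 VAE decode as a separate step):

import torch
from diffusers import MochiPipeline

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float32)
pipe.enable_model_cpu_offload()
prompt = "a cat walks along the sidewalk of a city"

# Condition A (what we currently do): autocast wraps the whole pipeline call,
# so the text encoder and VAE matmuls also run in BF16.
with torch.autocast("cuda", dtype=torch.bfloat16):
    frames_a = pipe(prompt, num_inference_steps=64, num_frames=85).frames[0]

# Condition B (closer to the original repo): encode the prompt in FP32 outside
# autocast, run only the denoising loop under BF16 autocast, then denormalize
# with the VAE's latents_mean/std and decode with the FP32 VAE outside autocast.
with torch.no_grad():
    embeds, mask, neg_embeds, neg_mask = pipe.encode_prompt(prompt=prompt)
with torch.autocast("cuda", dtype=torch.bfloat16):
    latents = pipe(
        prompt_embeds=embeds,
        prompt_attention_mask=mask,
        negative_prompt_embeds=neg_embeds,
        negative_prompt_attention_mask=neg_mask,
        num_inference_steps=64,
        num_frames=85,
        output_type="latent",
        return_dict=False,
    )[0]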

From some of my earlier experiences, to rule out anything scheduler related, I used to disable the scheduler stepping and just use the original scheduler-code. Would it be possible to try that?

I have numerically matched results so can confidently rule out anything scheduler related.

@YanzuoLu

@a-r-r-o-w
I got it. Let me try running the text_encoder and vae computation in fp32.

@sayakpaul
Member

Our implementation with full BF16 == Original implementation with full BF16. Whether you mask out tokens or not, and whether you use SDPA or Flash Attention, as long as you keep the conditions the same in both implementations, the videos are exactly the same and contain the artifacting.

Oh alright. Thanks for confirming. But are we using any precomputed stuff here?

@YanzuoLu thanks!

@YanzuoLu

YanzuoLu commented Nov 21, 2024

It doesn't get better and the generated content is still very blurry. I think the major problem here is the VAE tiling.
It is absolutely not a lossless optimization, for the same reason I stated in this issue.
Thanks to Aryan for validating the correctness of the DiT model.
Since disabling VAE tiling would make single-GPU inference on 80GB unavailable, I'm going to make this VAE context-parallel and see what happens in the diffusers implementation.
Or maybe someone else has a B200 GPU with more VRAM to help us 😅

000.mp4

@a-r-r-o-w
Member Author

I think the major problem here is the VAE tiling.

Thank you so much for continuously helping with testing things here! If you use normal generation and compare it to tiled vae output and tiled + framewise vae output, I think they are all of the same quality (maybe a few color differences because of alpha blending) but effectively the same results, no?

@a-r-r-o-w
Member Author

I think your video is already definitely better though! In the previous generations with pure bf16 in diffusers, there was a lot of blocky patching and many temporal artifacts, but very few in the video you shared. The FP32 text encoder and VAE definitely seem to be helping, don't you think?

@YanzuoLu

YanzuoLu commented Nov 22, 2024

Hi @a-r-r-o-w
I admit this quality is better than the previous one, but it is still far from the quality level we observe with the original codebase's multi-GPU pipeline (I know the single-GPU pipeline is also problematic, with the wrong tiling).
So I'm first trying to apply context parallelism to the diffusers Mochi VAE to see what happens; otherwise we have no choice but to move to the original codebase to conduct the distillation...

@sayakpaul
Member

but they don't apply FSDP to the VAE no?

@YanzuoLu

@sayakpaul
Right. FSDP doesn't conflict with other sequence or context parallelism.
The reason the VAE is not FSDP-wrapped is that it's too small in size; wrapping it would cause significant communication costs.
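To illustrate the point, here is a toy sketch with plain torch FSDP (not the actual genmo code): shard only the large DiT and keep the small VAE replicated on every rank.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from diffusers import AutoencoderKLMochi, MochiTransformer3DModel

# Toy sketch (assumes a torchrun launch with NCCL): FSDP shards the large DiT's
# parameters across ranks, while the small VAE keeps a full copy per rank since
# sharding it would mostly add communication overhead for little memory saving.
dist.init_process_group("nccl")
device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())
torch.cuda.set_device(device)

transformer = MochiTransformer3DModel.from_pretrained(
    "genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer = FSDP(transformer.to(device))  # parameters sharded across ranks

vae = AutoencoderKLMochi.from_pretrained(
    "genmo/mochi-1-preview", subfolder="vae", torch_dtype=torch.float32
).to(device)  # replicated, no FSDP wrapper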

@a-r-r-o-w
Member Author

but they don't apply FSDP to the VAE no?

FSDP is different from context-parallel and can be applied separately: https://github.com/genmoai/mochi/blob/5ebe56a403ca2bbdc577edc3a11da1c0b7d624fa/src/genmo/mochi_preview/vae/models.py#L1013

@YanzuoLu

YanzuoLu commented Nov 22, 2024

@a-r-r-o-w
I'm confused about the VAE attention block implemented in the original Mochi codebase: the shape is [BHW, T, C], but no all_to_all operation is performed to gather the sequence and scatter the heads.
It seems the attention is applied within each GPU's temporal sub-sequence.
Am I missing something? 😅
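For reference, this is the rough pattern I'm looking at (a toy reshape sketch, not the actual Mochi code):

import torch
import torch.nn.functional as F

# Toy sketch of the [BHW, T, C] pattern: VAE features [B, C, T, H, W] are folded
# so attention runs over the temporal axis only, once per spatial location.
B, C, T, H, W = 1, 64, 16, 8, 8
x = torch.randn(B, C, T, H, W)

# [B, C, T, H, W] -> [B*H*W, T, C]
tokens = x.permute(0, 3, 4, 2, 1).reshape(B * H * W, T, C)

# Self-attention over T for every (b, h, w) position.
out = F.scaled_dot_product_attention(tokens, tokens, tokens)

# If T is sharded across GPUs for context parallelism, each rank only attends
# over its local chunk of T here unless an all_to_all/all_gather reassembles the
# full temporal sequence first.
print(out.shape)  # torch.Size([64, 16, 64])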

@a-r-r-o-w
Member Author

Are you talking about the attention in the VAE encoder, or the attention in the transformer? They are different implementations IIRC. I'll take a look soon after the Flux release and address the questions in the other PR as well 😅

@YanzuoLu

Yeah, I mean the attention in the VAE.
Okay, thank you so much.

@sayakpaul
Member

sayakpaul commented Nov 22, 2024

@YanzuoLu are you running the original codebase on a single GPU? If so, we can set the option so that tiling is disabled, no?
https://github.com/genmoai/mochi/blob/5ebe56a403ca2bbdc577edc3a11da1c0b7d624fa/src/genmo/mochi_preview/pipelines.py#L449

Edit: not possible to run without VAE tiling on a single GPU.

@YanzuoLu

YanzuoLu commented Nov 22, 2024

@sayakpaul No, I am using the diffusers implementation.
When tiling is disabled, a single GPU with 80GB VRAM OOMs in both diffusers and the original codebase, IIRC.
I want to check the correctness of the diffusers implementation by applying sequence & context parallelism to all components, so that I can see whether the quality with diffusers can be aligned with the original codebase (the multi-GPU pipeline, I mean; the single-GPU pipeline in the original codebase is as problematic as diffusers).
But the original VAE attention implementation is quite confusing to me right now...

@YanzuoLu

YanzuoLu commented Nov 22, 2024

My final goal is to make sure the diffusers implementation is correct (at least the multi-GPU pipeline).
This quality level should be the norm, and we indeed see the same with the original codebase's multi-GPU pipeline.
I think the current diffusers single-GPU VAE tiling is far from it, and that's what I want to fix 😊

@sayakpaul
Member

Thank you very much for your generous support! Happy to create a new channel with you on Slack (since we're already connected via Hyper-SD). Could you ping me there?

@Ednaordinary

Is this closer to ideal? I'll post details after a bit more testing

mochi_rl_gnf15_yl_6_3095062583.mp4

@a-r-r-o-w
Member Author

@Ednaordinary This does indeed look much better. Please spill your secrets! Could you maybe also share something with more dynamic motion?

@Ednaordinary

Ednaordinary commented Nov 23, 2024

Here's something with more motion.

mochi.mp4

It's a few changes to my original configuration in #9769 (comment) (the notable attributes there being fp16, float8_e4m3fn, and a raised cfg scale).

Compared to that, I boosted the cfg scale to 8.5, changed the VAE to fp32, used framewise decoding instead (no tiling, no slicing), and upped the inference steps to 64 (49 may work well too; I have a theory that quadratic values are better for the scheduler, but have yet to confirm it).

There are still noticeable artifacts with high motion, though that could also be due to me downcasting some layers to float8. Framewise decoding also seems better for spatial understanding but worse for temporal understanding than slicing + tiling.
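For clarity, the two decode configurations I'm comparing map to these toggles (a sketch; note that _enable_framewise_decoding is currently a private helper on the Mochi VAE, so it may change):

import torch
from diffusers import MochiPipeline

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

# Configuration 1: tiled + sliced VAE decoding
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()

# Configuration 2: framewise decoding with tiling and slicing turned off
pipe.disable_vae_tiling()
pipe.disable_vae_slicing()
pipe.vae._enable_framewise_decoding()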

@Ednaordinary

Comparison with tiled + slicing (only 49 steps, though)

mochi.mp4

Noticeably, some artifacts around the legs in the beginning are gone, though the legs now glitch through each other more (could also be the 49 steps vs 64)

@Ednaordinary

Ednaordinary commented Nov 23, 2024

Whoops, I made a mistake. My last three comments still had the vae in fp16. Give me a bit to correct that and see what's different

@Ednaordinary

Okay, I have something new to drop.

I played around with the guidance scale across the timestep schedule and found that a large part of the clarity issues can be resolved by simply increasing the guidance scale towards the end of the schedule:

Low end guidance:

mochi_rl_yl_gnd_e2.mp4

High end guidance:

mochi_rl_yl_gnd_e25.mp4

Both of these are in just 25 steps, 2 minutes on a 3090 Ti

After modeling out some equations in Desmos (which you can find here), I ported them into the callback for the pipeline, making the guidance scale a parabola. After playing around with values, I've found that l/guidance_scale_begin has some effect on temporal motion (which makes sense, because it's denoising at a large sigma). I also found that j/guidance_scale_mid should be a lower value, around 4.5, otherwise the model is extremely susceptible to cfg baking. The n/mid_point defines where j sits on the timestep schedule, allowing better control over where the susceptible timesteps are (I found 0.3 works best, and even 0.4 lets stuff bake). The k/guidance_scale_end value is the extremely important one and is the difference between the two videos above: it defines the guidance towards the end of the denoising process, which seems to be where stuff is getting blurry.
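Pulled out of the script below into a standalone helper, the schedule is just two parabolic segments that meet at the midpoint with value guidance_scale_mid (same math as the callback, only the argument names are mine):

def guidance_at(step, steps, gs_begin, gs_mid, gs_end, mid_point):
    # Falling parabola from gs_begin down to gs_mid before the midpoint,
    # then a rising parabola from gs_mid up to gs_end after it, so that:
    # guidance_at(0, ...) == gs_begin, guidance_at(steps * mid_point, ...) == gs_mid,
    # and guidance_at(steps, ...) == gs_end.
    pivot = steps * mid_point
    if step > pivot:
        return gs_mid + (gs_end - gs_mid) * ((step - pivot) / (steps - pivot)) ** 2
    return gs_mid + (gs_begin - gs_mid) * ((step - pivot) / pivot) ** 2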

Here's the script (still optimized for 24 GB of VRAM, including the float8 downcasting; if you just want the callback, you should be able to move callback_dynamic_cfg + the values into your own script):

from diffusers import MochiPipeline, MochiTransformer3DModel, AutoencoderKLMochi
from diffusers.models.transformers.transformer_mochi import MochiTransformerBlock
import torch
from diffusers.utils import export_to_video
import math

# The below defines how many layers to quantize. The amount of layers kept is quant_div / quant_mod. Think of this as an inverted quality slider
quant_div = 1
quant_mod = 2
full_dtype = torch.float16
cast_dtype = torch.float8_e4m3fn
#torch.manual_seed(42)
torch.manual_seed(3095062583)

if quant_div > quant_mod: print("quant_div should be less than or equal to quant_mod")

# Credits to `dn6`
# Copy-pasted + slight edit from 
# https://github.com/huggingface/diffusers/blob/layerwise-upcasting/src/diffusers/models/modeling_utils.py
def enable_layerwise_upcasting(model, upcast_dtype=None, original_dtype=None):
    upcast_dtype = upcast_dtype or torch.float32
    original_dtype = original_dtype or model.dtype

    def upcast_dtype_hook_fn(module, *args, **kwargs):
        module = module.to(upcast_dtype)

    def cast_to_original_dtype_hook_fn(module, *args, **kwargs):
        module = module.to(original_dtype)

    def fn_recursive_upcast(module):
        # Register hooks on this module: upcast to the compute dtype before its
        # forward and cast back to the original (storage) dtype afterwards.
        module.register_forward_pre_hook(upcast_dtype_hook_fn)
        module.register_forward_hook(cast_to_original_dtype_hook_fn)

        # Recurse so every child module gets the same pair of hooks.
        for child in module.children():
            fn_recursive_upcast(child)

    for module in model.children():
        fn_recursive_upcast(module)

print("Loading transformer")
transformer = MochiTransformer3DModel.from_pretrained("genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.float32)
print("Adding cast hooks to transformer")
block_idx = 0
mean_attns = []
block = 0
for idx, i in enumerate(transformer.modules()):
    if isinstance(i, MochiTransformerBlock) and idx != 2084: # 2084 is the last layer, and should likely be skipped.
        mean_attns.append((idx, torch.mean(torch.ravel(i.attn1.norm_q.weight)), block)) # can be changed with norm_q, values seem similar
        block += 1
mean_attns.sort(key=lambda x: x[1])
attn_high = [x[0] for x in mean_attns[:int(len(mean_attns) // (quant_mod / quant_div))]]
print([x[2] for x in mean_attns[:int(len(mean_attns) // (quant_mod / quant_div))]])
transformer.to(torch.float16)
for idx, i in enumerate(transformer.modules()):
    if isinstance(i, MochiTransformerBlock):
        if idx in attn_high:
            print("\r", idx, end='')
            i.to(cast_dtype)
            enable_layerwise_upcasting(i, upcast_dtype=full_dtype, original_dtype=cast_dtype)

vae = AutoencoderKLMochi.from_pretrained("genmo/mochi-1-preview", subfolder="vae", torch_dtype=torch.float32)
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=full_dtype, transformer=transformer, vae=vae)
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()

guidance_scale_begin = 10.0 # Large effect on temporal motion ?
guidance_scale_mid = 3.5 # Should stay at ~4.5
mid_point = 0.3 #between 0 and 1. 0.3 is a good default.
guidance_scale_end = 25.0 # Large effect on spatial clarity
steps=25

def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    step_index = step_index + 1 # liar
    if step_index > (steps*mid_point):
        pipe._guidance_scale = (((guidance_scale_end-guidance_scale_mid) / math.pow(steps - (steps*mid_point),2)) * math.pow(step_index - (steps*mid_point),2)) + guidance_scale_mid
    else:
        pipe._guidance_scale = (guidance_scale_begin-guidance_scale_mid)*(math.pow(step_index-(steps*mid_point),2)/math.pow(steps*mid_point,2)) + guidance_scale_mid
    print("Current guidance scale:", pipe._guidance_scale)
    return callback_kwargs
frames = pipe("A woman in a white apron is diligently washing dishes in a spacious, well-lit kitchen.The gleaming stainless steel sink is filled with soapy water, and a variety of colorful plates, bowls, and utensils await cleaning. The countertops, adorned with a minimalist arrangement of olive oil, salt, and pepper, offer a backdrop of warm wood against the cool metallic surfaces. Soft, ambient lighting casts a warm glow across the scene, highlighting the woman's focused expression as she meticulously cleans eachpiece.", num_inference_steps=steps, guidance_scale=guidance_scale_begin, height=480, width=848, num_frames=19, callback_on_step_end=callback_dynamic_cfg, callback_on_step_end_tensor_inputs=["latents", "prompt_embeds", "negative_prompt_embeds"]).frames[0]
export_to_video(frames, "mochi.mp4", fps=15)

(Either video = self.vae.decode(latents, return_dict=False)[0] has to be changed to video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] in pipeline_mochi.py, or switch to an fp16 VAE in the script above.)

@sayakpaul
Member

@Ednaordinary thank you! Are you running your tests from the main branch of diffusers?

@Ednaordinary

@sayakpaul I have the Pyramid Attention Broadcast branch installed, though that shouldn't make a difference compared to main in this instance

@sayakpaul
Member

@Ednaordinary @YanzuoLu thanks for your continuous guidance here.

With the attached script, I am able to do:

mochi.mp4

I am on the main branch of diffusers

Script
from diffusers import MochiPipeline, AutoencoderKLMochi
from diffusers.utils import export_to_video
from diffusers.video_processor import VideoProcessor
import torch
import gc

pipe = MochiPipeline.from_pretrained(
    "genmo/mochi-1-preview", transformer=None, vae=None
).to("cuda")
prompt =  "a cat walks along the sidewalk of a city. The camera follows the cat at knee level. The city has many people and cars moving around, with advertisement billboards in the background"

with torch.no_grad():
    prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
        pipe.encode_prompt(prompt=prompt)
    )

del pipe.text_encoder 
del pipe 
gc.collect()

pipe = MochiPipeline.from_pretrained(
    "genmo/mochi-1-preview", text_encoder=None, vae=None
).to("cuda")

with torch.autocast("cuda", dtype=torch.bfloat16):
    frames = pipe(
        prompt_embeds=prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_embeds=negative_prompt_embeds,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        guidance_scale=8.5,
        num_inference_steps=64,
        height=480,
        width=848,
        num_frames=85,
        output_type="latent",
        return_dict=False,
    )[0]

print(f"{frames.shape=}, {frames.dtype=}")

del pipe.transformer
del pipe
gc.collect()

vae = AutoencoderKLMochi.from_pretrained("genmo/mochi-1-preview", subfolder="vae").to("cuda")
vae._enable_framewise_decoding()
video_processor = VideoProcessor(vae_scale_factor=8)

has_latents_mean = hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None
has_latents_std = hasattr(vae.config, "latents_std") and vae.config.latents_std is not None
if has_latents_mean and has_latents_std:
    latents_mean = (
        torch.tensor(vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
    )
    latents_std = (
        torch.tensor(vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
    )
    frames = frames * latents_std / vae.config.scaling_factor + latents_mean
else:
    frames = frames / vae.config.scaling_factor

with torch.no_grad():
    video = vae.decode(frames.to(vae.dtype), return_dict=False)[0]

video = video_processor.postprocess_video(video)[0]
export_to_video(video, "mochi.mp4", fps=30)

We should note that the entire generation process takes quite a while (~7 mins) to complete, so any optimization (the simpler the better) we can do here would be worth pursuing.

Main changes in the script:

  • Keeping the text encoder and the VAE in FP32, and keeping their inputs and outputs in FP32 as well.
  • Keeping the transformer in FP32 and running its computation under bf16 autocast.
  • No tiling or slicing in the VAE, and enabling framewise decoding.
  • High guidance scale, high number of inference steps.

@yiyixuxu yiyixuxu added the roadmap Add to current release roadmap label Dec 4, 2024
@jmahajan117

Hey Y'all,

I think I might have something new to drop.

I realized that we might be able to split the transformer part of this model across multiple GPUs, much like Flux.

All I had to do was add

_no_split_modules = ["MochiTransformerBlock"]

in the MochiTransformer3DModel class.

With that, I was able to split the transformer like this:

transformer = MochiTransformer3DModel.from_pretrained(
    id,
    subfolder="transformer",
    device_map="auto",
    max_memory={0: "48GB", 1: "48GB"},
    torch_dtype=torch.float16,
)

It seems to work on my end. If someone can verify this, that would be awesome!

@sayakpaul
Member

Hey @jmahajan117! Thanks for testing that. If you want to open a PR, I can help test it.

We have a bunch of tests to verify its effectiveness:

Also, if you put it together in an end-to-end code snippet (with the pipeline call), that would be good. Here's our docs.
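Something roughly like this is what I have in mind for the end-to-end snippet (an untested sketch; it assumes the _no_split_modules addition above and two ~48GB GPUs, and it moves the non-sharded components manually since calling .to() on the whole pipeline may not be allowed once a component is device-mapped):

import torch
from diffusers import MochiPipeline, MochiTransformer3DModel
from diffusers.utils import export_to_video

# Shard the transformer across two GPUs (requires the _no_split_modules change).
transformer = MochiTransformer3DModel.from_pretrained(
    "genmo/mochi-1-preview",
    subfolder="transformer",
    device_map="auto",
    max_memory={0: "48GB", 1: "48GB"},
    torch_dtype=torch.float16,
)

pipe = MochiPipeline.from_pretrained(
    "genmo/mochi-1-preview", transformer=transformer, torch_dtype=torch.float16
)
# The transformer is already dispatched across GPUs, so place the remaining
# components individually instead of moving the whole pipeline.
pipe.text_encoder.to("cuda:0")
pipe.vae.to("cuda:0")
pipe.enable_vae_tiling()

frames = pipe(
    "a cat walks along the sidewalk of a city",
    num_inference_steps=64,
    guidance_scale=8.5,
    height=480,
    width=848,
    num_frames=85,
).frames[0]
export_to_video(frames, "mochi.mp4", fps=30)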

sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* update

* udpate

* update transformer

* make style

* fix

* add conversion script

* update

* fix

* update

* fix

* update

* fixes

* make style

* update

* update

* update

* init

* update

* update

* add

* up

* up

* up

* update

* mochi transformer

* remove original implementation

* make style

* update inits

* update conversion script

* docs

* Update src/diffusers/pipelines/mochi/pipeline_mochi.py

Co-authored-by: Dhruv Nair <[email protected]>

* Update src/diffusers/pipelines/mochi/pipeline_mochi.py

Co-authored-by: Dhruv Nair <[email protected]>

* fix docs

* pipeline fixes

* make style

* invert sigmas in scheduler; fix pipeline

* fix pipeline num_frames

* flip proj and gate in swiglu

* make style

* fix

* make style

* fix tests

* latent mean and std fix

* update

* cherry-pick 1069d21

* remove additional sigma already handled by flow match scheduler

* fix

* remove hardcoded value

* replace conv1x1 with linear

* Update src/diffusers/pipelines/mochi/pipeline_mochi.py

Co-authored-by: Dhruv Nair <[email protected]>

* framewise decoding and conv_cache

* make style

* Apply suggestions from code review

* mochi vae encoder changes

* rebase correctly

* Update scripts/convert_mochi_to_diffusers.py

* fix tests

* fixes

* make style

* update

* make style

* update

* add framewise and tiled encoding

* make style

* make original vae implementation behaviour the default; note: framewise encoding does not work

* remove framewise encoding implementation due to presence of attn layers

* fight test 1

* fight test 2

---------

Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: yiyixuxu <[email protected]>
Successfully merging this pull request may close these issues: Add Mochi Video Model