Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cogvideox-5B Model adapter change #9203

Merged
merged 24 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/source/en/api/pipelines/cogvideox.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m

This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).

There are two models available that can be used with the CogVideoX pipeline:
- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b)
- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b)

## Inference

Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
Expand Down Expand Up @@ -68,7 +72,7 @@ With torch.compile(): Average inference time: 76.27 seconds.

### Memory optimization

CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.

- `pipe.enable_model_cpu_offload()`:
- Without enabling cpu offloading, memory usage is `33 GB`
Expand Down
64 changes: 55 additions & 9 deletions scripts/convert_cogvideox_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
"key_layernorm_list": reassign_query_key_layernorm_inplace,
"adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
"embed_tokens": remove_keys_inplace,
"freqs_sin": remove_keys_inplace,
"freqs_cos": remove_keys_inplace,
"position_embedding": remove_keys_inplace,
}

VAE_KEYS_RENAME_DICT = {
Expand Down Expand Up @@ -123,11 +126,21 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
state_dict[new_key] = state_dict.pop(old_key)


def convert_transformer(ckpt_path: str):
def convert_transformer(
ckpt_path: str,
num_layers: int,
num_attention_heads: int,
use_rotary_positional_embeddings: bool,
dtype: torch.dtype,
):
PREFIX_KEY = "model.diffusion_model."

original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
transformer = CogVideoXTransformer3DModel()
transformer = CogVideoXTransformer3DModel(
num_layers=num_layers,
num_attention_heads=num_attention_heads,
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
).to(dtype=dtype)

for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
Expand All @@ -145,9 +158,9 @@ def convert_transformer(ckpt_path: str):
return transformer


def convert_vae(ckpt_path: str):
def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
vae = AutoencoderKLCogVideoX()
vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)

for key in list(original_state_dict.keys()):
new_key = key[:]
Expand All @@ -172,13 +185,26 @@ def get_args():
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
parser.add_argument(
"--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
)
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
# For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
parser.add_argument(
"--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
)
# For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
return parser.parse_args()


Expand All @@ -188,18 +214,33 @@ def get_args():
transformer = None
vae = None

if args.fp16 and args.bf16:
raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")

dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32

if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path)
transformer = convert_transformer(
args.transformer_ckpt_path,
args.num_layers,
args.num_attention_heads,
args.use_rotary_positional_embeddings,
dtype,
)
if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path)
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)

text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

# Apparently, the conversion does not work any more without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()

scheduler = CogVideoXDDIMScheduler.from_config(
{
"snr_shift_scale": 3.0,
"snr_shift_scale": args.snr_shift_scale,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
Expand All @@ -208,7 +249,7 @@ def get_args():
"prediction_type": "v_prediction",
"rescale_betas_zero_snr": True,
"set_alpha_to_one": True,
"timestep_spacing": "linspace",
"timestep_spacing": "trailing",
}
)

Expand All @@ -218,5 +259,10 @@ def get_args():

if args.fp16:
pipe = pipe.to(dtype=torch.float16)
if args.bf16:
pipe = pipe.to(dtype=torch.bfloat16)

# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
# is either fp16/bf16 here).
pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
142 changes: 142 additions & 0 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,6 +1783,148 @@ def __call__(
return hidden_states


class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)

hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# Apply RoPE if needed
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb

query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states


class FusedCogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)

hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# Apply RoPE if needed
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb

query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states


class XFormersAttnAddedKVProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215):
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
Expand Down
84 changes: 84 additions & 0 deletions src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,90 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
return embeds


def get_3d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.

Args:
embed_dim: (`int`):
The embedding dimension size, corresponding to hidden_size_head.
crops_coords (`Tuple[int]`):
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the spatial positional embedding (height, width).
temporal_size (`int`):
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.

Returns:
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
"""
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)

# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3

# Temporal frequencies
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's refactor with get_1d_rotary_pos_embeds (

def get_1d_rotary_pos_embed(
) (in a follow-up PR )

grid_t = torch.from_numpy(grid_t).float()
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
freqs_t = freqs_t.repeat_interleave(2, dim=-1)

# Spatial frequencies for height and width
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
freqs_w = freqs_w.repeat_interleave(2, dim=-1)

# Broadcast and concatenate tensors along specified dimension
def broadcast(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = {len(t.shape) for t in tensors}
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*(list(t.shape) for t in tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
), "invalid dimensions for broadcastable concatenation"
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
return torch.cat(tensors, dim=dim)

freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)

t, h, w, d = freqs.shape
freqs = freqs.view(t * h * w, d)

# Generate sine and cosine components
sin = freqs.sin()
cos = freqs.cos()

if use_real:
return cos, sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis


def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
"""
RoPE for image tokens with 2d structure.
Expand Down
Loading
Loading