CogView4 (supports different length c and uc) #10649

Merged (88 commits), Feb 15, 2025
Changes from 79 commits
2640bcf
init
zRzRzRzRzRzRzR Jan 14, 2025
eba11fa
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Jan 14, 2025
6163679
encode with glm
zRzRzRzRzRzRzR Jan 14, 2025
6090ea7
draft schedule
zRzRzRzRzRzRzR Jan 15, 2025
c7d1227
feat(scheduler): Add CogView scheduler implementation
OleehyO Jan 16, 2025
e9f6626
Merge remote-tracking branch 'origin/cogview4' into cogview4
OleehyO Jan 16, 2025
549b357
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Jan 16, 2025
004d002
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Jan 16, 2025
f4457fb
feat(embeddings): add CogView 2D rotary positional embedding
OleehyO Jan 17, 2025
5f8d33b
Merge remote-tracking branch 'origin/cogview4' into cogview4
OleehyO Jan 17, 2025
9a93218
1
zRzRzRzRzRzRzR Jan 17, 2025
ca000dd
Update pipeline_cogview4.py
zRzRzRzRzRzRzR Jan 17, 2025
7ab4a3f
fix the timestep init and sigma
zRzRzRzRzRzRzR Jan 18, 2025
56ceaa6
update latent
zRzRzRzRzRzRzR Jan 19, 2025
a7179a2
draft patch(not work)
zRzRzRzRzRzRzR Jan 19, 2025
c9ddf50
Merge branch 'cogview4'
zRzRzRzRzRzRzR Jan 22, 2025
2f30cc1
Merge pull request #2 from zRzRzRzRzRzRzR/main
zRzRzRzRzRzRzR Jan 22, 2025
e6b8907
fix
zRzRzRzRzRzRzR Jan 22, 2025
0ab7260
[WIP][cogview4]: implement initial CogView4 pipeline
OleehyO Jan 23, 2025
f608f82
[WIP][cogview4][refactor]: Split condition/uncondition forward pass i…
OleehyO Jan 23, 2025
b86bfd4
use with -2 hidden state
zRzRzRzRzRzRzR Jan 23, 2025
c4d1e69
remove text_projector
zRzRzRzRzRzRzR Jan 23, 2025
7916140
1
zRzRzRzRzRzRzR Jan 23, 2025
f8945ce
[WIP] Add tensor-reload to align input from transformer block
OleehyO Jan 24, 2025
bf7f322
[WIP] for older glm
zRzRzRzRzRzRzR Jan 24, 2025
dd6568b
use with cogview4 transformers forward twice of u and uc
zRzRzRzRzRzRzR Jan 25, 2025
6f5407e
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Jan 25, 2025
9e5b991
Update convert_cogview4_to_diffusers.py
zRzRzRzRzRzRzR Jan 25, 2025
36b1682
remove this
zRzRzRzRzRzRzR Jan 26, 2025
804f5cc
Merge pull request #3 from zRzRzRzRzRzRzR/main
zRzRzRzRzRzRzR Jan 28, 2025
16c2397
use main example
zRzRzRzRzRzRzR Jan 28, 2025
601696d
change back
zRzRzRzRzRzRzR Jan 28, 2025
84115dc
reset
zRzRzRzRzRzRzR Jan 28, 2025
95a103f
setback
zRzRzRzRzRzRzR Jan 28, 2025
d932f67
back
zRzRzRzRzRzRzR Jan 28, 2025
b04f15d
back 4
zRzRzRzRzRzRzR Jan 28, 2025
5d33f3f
Fix qkv conversion logic for CogView4 to Diffusers format
zRzRzRzRzRzRzR Jan 28, 2025
b889b37
back5
zRzRzRzRzRzRzR Jan 28, 2025
e239c3c
revert to sat to cogview4 version
zRzRzRzRzRzRzR Jan 28, 2025
310da29
update a new convert from megatron
zRzRzRzRzRzRzR Jan 28, 2025
3bd6d30
[WIP][cogview4]: implement CogView4 attention processor
OleehyO Jan 28, 2025
f826aec
[cogview4] implement CogView4 transformer block
OleehyO Jan 28, 2025
8d8ed8b
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Jan 28, 2025
bf1fdc8
with new attn
zRzRzRzRzRzRzR Jan 28, 2025
6a3a07f
[bugfix] fix dimension mismatch in CogView4 attention
OleehyO Jan 28, 2025
de274f3
[cogview4][WIP]: update final normalization in CogView4 transformer
OleehyO Jan 28, 2025
e94999e
Merge remote-tracking branch 'origin/cogview4' into cogview4
OleehyO Jan 28, 2025
e238284
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Feb 1, 2025
a9b1e16
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Feb 5, 2025
46277b2
1
zRzRzRzRzRzRzR Feb 5, 2025
ebbaa5b
put back
zRzRzRzRzRzRzR Feb 5, 2025
f1ccdd2
Update transformer_cogview4.py
zRzRzRzRzRzRzR Feb 5, 2025
030a467
change time_shift
zRzRzRzRzRzRzR Feb 6, 2025
ad40575
Update pipeline_cogview4.py
zRzRzRzRzRzRzR Feb 6, 2025
81d39ee
change timesteps
zRzRzRzRzRzRzR Feb 6, 2025
45f9e88
fix
zRzRzRzRzRzRzR Feb 6, 2025
1dbeaa8
change text_encoder_id
zRzRzRzRzRzRzR Feb 6, 2025
f209600
[cogview4][rope] align RoPE implementation with Megatron
OleehyO Feb 6, 2025
992f5a3
[cogview4][bugfix] apply silu activation to time embeddings in CogView4
OleehyO Feb 6, 2025
03a1c3b
[cogview4][chore] clean up pipeline code
OleehyO Feb 6, 2025
dd34794
Merge remote-tracking branch 'origin/cogview4' into cogview4
OleehyO Feb 6, 2025
3dab073
[cogview4][scheduler] Implement CogView4 scheduler and pipeline
OleehyO Feb 6, 2025
63982d6
now It work
zRzRzRzRzRzRzR Feb 6, 2025
90a5706
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Feb 6, 2025
d4748e0
add timestep
zRzRzRzRzRzRzR Feb 7, 2025
95f851d
batch
zRzRzRzRzRzRzR Feb 7, 2025
cb56282
change convert scipt
zRzRzRzRzRzRzR Feb 7, 2025
fedf325
refactor pt. 1; make style
a-r-r-o-w Feb 10, 2025
90d29c7
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Feb 10, 2025
4c01c9d
refactor pt. 2
a-r-r-o-w Feb 12, 2025
c1b8004
refactor pt. 3
a-r-r-o-w Feb 12, 2025
9d55d0a
add tests
a-r-r-o-w Feb 12, 2025
5e6de42
make fix-copies
a-r-r-o-w Feb 12, 2025
30dd0ad
Merge branch 'main' into cogview4
a-r-r-o-w Feb 12, 2025
2046cf2
update toctree.yml
a-r-r-o-w Feb 12, 2025
39e1198
use flow match scheduler instead of custom
a-r-r-o-w Feb 13, 2025
b566a9f
Merge branch 'main' into cogview4
a-r-r-o-w Feb 13, 2025
b4c9fde
remove scheduling_cogview.py
a-r-r-o-w Feb 13, 2025
a137e17
add tiktoken to test dependencies
a-r-r-o-w Feb 13, 2025
da420fb
Update src/diffusers/models/embeddings.py
a-r-r-o-w Feb 13, 2025
4003b9c
apply suggestions from review
a-r-r-o-w Feb 13, 2025
35c0ec6
use diffusers apply_rotary_emb
a-r-r-o-w Feb 13, 2025
d328c5e
update flow match scheduler to accept timesteps
a-r-r-o-w Feb 14, 2025
d637d3a
Merge branch 'main' into cogview4
a-r-r-o-w Feb 14, 2025
4c37ef0
fix comment
a-r-r-o-w Feb 14, 2025
90c240b
apply review sugestions
a-r-r-o-w Feb 14, 2025
5c11298
Merge branch 'main' into cogview4
a-r-r-o-w Feb 14, 2025
2f12b7a
Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
a-r-r-o-w Feb 14, 2025
docs/source/en/_toctree.yml (4 additions, 0 deletions)
@@ -278,6 +278,8 @@
title: ConsisIDTransformer3DModel
- local: api/models/cogview3plus_transformer2d
title: CogView3PlusTransformer2DModel
- local: api/models/cogview4_transformer2d
title: CogView4Transformer2DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
- local: api/models/flux_transformer
@@ -382,6 +384,8 @@
title: CogVideoX
- local: api/pipelines/cogview3
title: CogView3
- local: api/pipelines/cogview4
title: CogView4
- local: api/pipelines/consisid
title: ConsisID
- local: api/pipelines/consistency_models
docs/source/en/api/models/cogview4_transformer2d.md (30 additions, 0 deletions)
@@ -0,0 +1,30 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# CogView4Transformer2DModel

A Diffusion Transformer model for 2D data from [CogView4]().

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import CogView4Transformer2DModel

transformer = CogView4Transformer2DModel.from_pretrained("THUDM/CogView4-6B", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```

## CogView4Transformer2DModel

[[autodoc]] CogView4Transformer2DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
docs/source/en/api/pipelines/cogview4.md (34 additions, 0 deletions)
@@ -0,0 +1,34 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->

# CogView4

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase and weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
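
A minimal text-to-image sketch (the prompt and generation settings below are illustrative, not prescribed by this PR; the checkpoint id matches the one used by the conversion script):

```python
import torch

from diffusers import CogView4Pipeline

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Illustrative prompt and settings; tune steps and guidance for your use case.
prompt = "A serene lakeside cabin at sunrise, photorealistic"
image = pipe(prompt=prompt, num_inference_steps=50, guidance_scale=3.5).images[0]
image.save("cogview4_output.png")
```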

## CogView4Pipeline

[[autodoc]] CogView4Pipeline
- all
- __call__

## CogView4PipelineOutput

[[autodoc]] pipelines.cogview4.pipeline_output.CogView4PipelineOutput
scripts/convert_cogview4_to_diffusers.py (243 additions, 0 deletions)
@@ -0,0 +1,243 @@
"""
Convert a CogView4 checkpoint from SAT(https://github.com/THUDM/SwissArmyTransformer) to the Diffusers format.
(deprecated Since 2025-02-07 and will remove it in later CogView4 version)
This script converts a CogView4 checkpoint to the Diffusers format, which can then be used
with the Diffusers library.
Example usage:
python scripts/convert_cogview4_to_diffusers.py \
--transformer_checkpoint_path 'your path/cogview4_6b/1/mp_rank_00_model_states.pt' \
--vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
--output_path "THUDM/CogView4-6B" \
--dtype "bf16"
Arguments:
--transformer_checkpoint_path: Path to Transformer state dict.
--vae_checkpoint_path: Path to VAE state dict.
--output_path: The path to save the converted model.
--push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
--text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
Default is "bf16" because CogView4 uses bfloat16 for Training.
Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
"""

import argparse
from contextlib import nullcontext

import torch
from accelerate import init_empty_weights
from transformers import GlmForCausalLM, PreTrainedTokenizerFast

from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available


CTX = init_empty_weights if is_accelerate_available() else nullcontext

parser = argparse.ArgumentParser()
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
parser.add_argument("--vae_checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
parser.add_argument("--dtype", type=str, default="bf16")

args = parser.parse_args()


# This is specific to `AdaLayerNormContinuous`:
# the Diffusers implementation splits the linear projection into (scale, shift), while CogView4 splits it into (shift, scale)
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=dim)
new_weight = torch.cat([scale, shift], dim=dim)
return new_weight
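

# Illustration with a hypothetical 4-element vector laid out as [shift_0, shift_1, scale_0, scale_1]:
# swap_scale_shift(torch.tensor([1.0, 2.0, 3.0, 4.0]), dim=0)  # -> tensor([3., 4., 1., 2.])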


def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path):
original_state_dict = torch.load(ckpt_path, map_location="cpu")
original_state_dict = original_state_dict["module"]
original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}

new_state_dict = {}

# Convert patch_embed
new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")

# Convert time_condition_embed
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
"time_embed.0.weight"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
"time_embed.0.bias"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
"time_embed.2.weight"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
"time_embed.2.bias"
)
new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
"label_emb.0.0.weight"
)
new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
"label_emb.0.0.bias"
)
new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
"label_emb.0.2.weight"
)
new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
"label_emb.0.2.bias"
)

# Convert the transformer blocks (CogView4 has 28 blocks)
for i in range(28):
block_prefix = f"transformer_blocks.{i}."
old_prefix = f"transformer.layers.{i}."
adaln_prefix = f"mixins.adaln.adaln_modules.{i}."
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")

qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
q, k, v = qkv_weight.chunk(3, dim=0)
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
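
# The fused QKV weight stacks Q, K and V along dim 0 with shape [3 * hidden_size, hidden_size]
# (hidden_size = num_attention_heads * attention_head_dim = 32 * 128 = 4096 for CogView4-6B),
# so a 3-way chunk along dim 0 recovers the separate projections.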

new_state_dict[block_prefix + "attn1.to_q.weight"] = q
new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias

new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
old_prefix + "attention.dense.weight"
)
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
old_prefix + "attention.dense.bias"
)

new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
old_prefix + "mlp.dense_h_to_4h.weight"
)
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
old_prefix + "mlp.dense_h_to_4h.bias"
)
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
old_prefix + "mlp.dense_4h_to_h.weight"
)
new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")

# Convert final norm and projection
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
)
new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")

return new_state_dict


def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)


def main(args):
if args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16
elif args.dtype == "fp32":
dtype = torch.float32
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")

transformer = None
vae = None

if args.transformer_checkpoint_path is not None:
converted_transformer_state_dict = convert_cogview4_transformer_checkpoint_to_diffusers(
args.transformer_checkpoint_path
)
transformer = CogView4Transformer2DModel(
patch_size=2,
in_channels=16,
num_layers=28,
attention_head_dim=128,
num_attention_heads=32,
out_channels=16,
text_embed_dim=4096,
time_embed_dim=512,
condition_dim=256,
pos_embed_max_size=128,
)
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
if dtype is not None:
# Cast to the requested dtype; the original checkpoint dtype is preserved only when `dtype` is None
transformer = transformer.to(dtype=dtype)

if args.vae_checkpoint_path is not None:
vae_config = {
"in_channels": 3,
"out_channels": 3,
"down_block_types": ("DownEncoderBlock2D",) * 4,
"up_block_types": ("UpDecoderBlock2D",) * 4,
"block_out_channels": (128, 512, 1024, 1024),
"layers_per_block": 3,
"act_fn": "silu",
"latent_channels": 16,
"norm_num_groups": 32,
"sample_size": 1024,
"scaling_factor": 1.0,
"force_upcast": True,
"use_quant_conv": False,
"use_post_quant_conv": False,
"mid_block_add_attention": False,
}
converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_state_dict, strict=True)
if dtype is not None:
vae = vae.to(dtype=dtype)

text_encoder_id = "THUDM/glm-4-9b-hf"
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
text_encoder = GlmForCausalLM.from_pretrained(
text_encoder_id,
cache_dir=args.text_encoder_cache_dir,
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
)

for param in text_encoder.parameters():
param.data = param.data.contiguous()

scheduler = FlowMatchEulerDiscreteScheduler(
base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
)
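# Note on the config above: with use_dynamic_shifting=True, the timestep shift is interpolated
# per image between base_shift (at base_image_seq_len tokens) and max_shift as resolution grows,
# and time_shift_type="linear" selects the linear rather than exponential form of that shift.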
Comment on lines +225 to +227

Member: @zRzRzRzRzRzRzR Same comment as above

Contributor (author): Based on the current algorithm comparison and the images produced in practice, this change seems to be functioning properly.

pipe = CogView4Pipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)

# Shard the saved checkpoint into files of at most 5GB. This helps users with limited memory
# (e.g. Colab notebooks), since smaller shards reduce peak memory while loading the model.
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)


if __name__ == "__main__":
main(args)