CogView4 (supports different length c and uc) #10649

Status: Merged
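The "different length c and uc" in the title refers to classifier-free guidance where the conditional and unconditional prompt embeddings may have different sequence lengths, so they cannot be stacked into a single batch; instead, the transformer is run once per branch (see commit dd6568b below). A minimal sketch of that pattern follows; the `transformer` call signature here is an assumption for illustration, not the pipeline's actual one:

```python
def cfg_denoise(transformer, latents, timestep, prompt_embeds, negative_prompt_embeds, guidance_scale):
    # prompt_embeds: (1, seq_len_cond, dim); negative_prompt_embeds: (1, seq_len_uncond, dim).
    # The two sequence lengths may differ, so concatenating along the batch dim is not possible;
    # run the transformer separately for the conditional and unconditional branches.
    noise_cond = transformer(hidden_states=latents, encoder_hidden_states=prompt_embeds, timestep=timestep)
    noise_uncond = transformer(hidden_states=latents, encoder_hidden_states=negative_prompt_embeds, timestep=timestep)
    # Standard classifier-free guidance combination.
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)
```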
Commits (88 total; changes shown from 79 commits):
- 2640bcf init (zRzRzRzRzRzRzR)
- eba11fa Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- 6163679 encode with glm (zRzRzRzRzRzRzR)
- 6090ea7 draft schedule (zRzRzRzRzRzRzR)
- c7d1227 feat(scheduler): Add CogView scheduler implementation (OleehyO)
- e9f6626 Merge remote-tracking branch 'origin/cogview4' into cogview4 (OleehyO)
- 549b357 Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- 004d002 Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- f4457fb feat(embeddings): add CogView 2D rotary positional embedding (OleehyO)
- 5f8d33b Merge remote-tracking branch 'origin/cogview4' into cogview4 (OleehyO)
- 9a93218 1 (zRzRzRzRzRzRzR)
- ca000dd Update pipeline_cogview4.py (zRzRzRzRzRzRzR)
- 7ab4a3f fix the timestep init and sigma (zRzRzRzRzRzRzR)
- 56ceaa6 update latent (zRzRzRzRzRzRzR)
- a7179a2 draft patch(not work) (zRzRzRzRzRzRzR)
- c9ddf50 Merge branch 'cogview4' (zRzRzRzRzRzRzR)
- 2f30cc1 Merge pull request #2 from zRzRzRzRzRzRzR/main (zRzRzRzRzRzRzR)
- e6b8907 fix (zRzRzRzRzRzRzR)
- 0ab7260 [WIP][cogview4]: implement initial CogView4 pipeline (OleehyO)
- f608f82 [WIP][cogview4][refactor]: Split condition/uncondition forward pass i… (OleehyO)
- b86bfd4 use with -2 hidden state (zRzRzRzRzRzRzR)
- c4d1e69 remove text_projector (zRzRzRzRzRzRzR)
- 7916140 1 (zRzRzRzRzRzRzR)
- f8945ce [WIP] Add tensor-reload to align input from transformer block (OleehyO)
- bf7f322 [WIP] for older glm (zRzRzRzRzRzRzR)
- dd6568b use with cogview4 transformers forward twice of u and uc (zRzRzRzRzRzRzR)
- 6f5407e Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- 9e5b991 Update convert_cogview4_to_diffusers.py (zRzRzRzRzRzRzR)
- 36b1682 remove this (zRzRzRzRzRzRzR)
- 804f5cc Merge pull request #3 from zRzRzRzRzRzRzR/main (zRzRzRzRzRzRzR)
- 16c2397 use main example (zRzRzRzRzRzRzR)
- 601696d change back (zRzRzRzRzRzRzR)
- 84115dc reset (zRzRzRzRzRzRzR)
- 95a103f setback (zRzRzRzRzRzRzR)
- d932f67 back (zRzRzRzRzRzRzR)
- b04f15d back 4 (zRzRzRzRzRzRzR)
- 5d33f3f Fix qkv conversion logic for CogView4 to Diffusers format (zRzRzRzRzRzRzR)
- b889b37 back5 (zRzRzRzRzRzRzR)
- e239c3c revert to sat to cogview4 version (zRzRzRzRzRzRzR)
- 310da29 update a new convert from megatron (zRzRzRzRzRzRzR)
- 3bd6d30 [WIP][cogview4]: implement CogView4 attention processor (OleehyO)
- f826aec [cogview4] implement CogView4 transformer block (OleehyO)
- 8d8ed8b Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- bf1fdc8 with new attn (zRzRzRzRzRzRzR)
- 6a3a07f [bugfix] fix dimension mismatch in CogView4 attention (OleehyO)
- de274f3 [cogview4][WIP]: update final normalization in CogView4 transformer (OleehyO)
- e94999e Merge remote-tracking branch 'origin/cogview4' into cogview4 (OleehyO)
- e238284 Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- a9b1e16 Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- 46277b2 1 (zRzRzRzRzRzRzR)
- ebbaa5b put back (zRzRzRzRzRzRzR)
- f1ccdd2 Update transformer_cogview4.py (zRzRzRzRzRzRzR)
- 030a467 change time_shift (zRzRzRzRzRzRzR)
- ad40575 Update pipeline_cogview4.py (zRzRzRzRzRzRzR)
- 81d39ee change timesteps (zRzRzRzRzRzRzR)
- 45f9e88 fix (zRzRzRzRzRzRzR)
- 1dbeaa8 change text_encoder_id (zRzRzRzRzRzRzR)
- f209600 [cogview4][rope] align RoPE implementation with Megatron (OleehyO)
- 992f5a3 [cogview4][bugfix] apply silu activation to time embeddings in CogView4 (OleehyO)
- 03a1c3b [cogview4][chore] clean up pipeline code (OleehyO)
- dd34794 Merge remote-tracking branch 'origin/cogview4' into cogview4 (OleehyO)
- 3dab073 [cogview4][scheduler] Implement CogView4 scheduler and pipeline (OleehyO)
- 63982d6 now It work (zRzRzRzRzRzRzR)
- 90a5706 Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- d4748e0 add timestep (zRzRzRzRzRzRzR)
- 95f851d batch (zRzRzRzRzRzRzR)
- cb56282 change convert scipt (zRzRzRzRzRzRzR)
- fedf325 refactor pt. 1; make style (a-r-r-o-w)
- 90d29c7 Merge branch 'huggingface:main' into cogview4 (zRzRzRzRzRzRzR)
- 4c01c9d refactor pt. 2 (a-r-r-o-w)
- c1b8004 refactor pt. 3 (a-r-r-o-w)
- 9d55d0a add tests (a-r-r-o-w)
- 5e6de42 make fix-copies (a-r-r-o-w)
- 30dd0ad Merge branch 'main' into cogview4 (a-r-r-o-w)
- 2046cf2 update toctree.yml (a-r-r-o-w)
- 39e1198 use flow match scheduler instead of custom (a-r-r-o-w)
- b566a9f Merge branch 'main' into cogview4 (a-r-r-o-w)
- b4c9fde remove scheduling_cogview.py (a-r-r-o-w)
- a137e17 add tiktoken to test dependencies (a-r-r-o-w)
- da420fb Update src/diffusers/models/embeddings.py (a-r-r-o-w)
- 4003b9c apply suggestions from review (a-r-r-o-w)
- 35c0ec6 use diffusers apply_rotary_emb (a-r-r-o-w)
- d328c5e update flow match scheduler to accept timesteps (a-r-r-o-w)
- d637d3a Merge branch 'main' into cogview4 (a-r-r-o-w)
- 4c37ef0 fix comment (a-r-r-o-w)
- 90c240b apply review sugestions (a-r-r-o-w)
- 5c11298 Merge branch 'main' into cogview4 (a-r-r-o-w)
- 2f12b7a Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py (a-r-r-o-w)
Files changed

New file (30 added lines): documentation page for CogView4Transformer2DModel
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# CogView4Transformer2DModel

A Diffusion Transformer model for 2D data from [CogView4]().

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import CogView4Transformer2DModel

transformer = CogView4Transformer2DModel.from_pretrained(
    "THUDM/CogView4-6B", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")
```

## CogView4Transformer2DModel

[[autodoc]] CogView4Transformer2DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
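The commit log above mentions a CogView 2D rotary positional embedding (f4457fb), later aligned with Megatron (f209600) and routed through diffusers' `apply_rotary_emb` (35c0ec6). As rough orientation only, a generic 2D RoPE (not this PR's exact implementation) devotes half of each attention head's rotary frequencies to the row index and half to the column index:

```python
import torch

def rope_1d(positions: torch.Tensor, dim: int, theta: float = 10000.0):
    # Standard 1D rotary angles for one spatial axis; `dim` must be even.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(positions.float(), freqs)  # (len(positions), dim // 2)
    return angles.cos(), angles.sin()

def rope_2d(height: int, width: int, head_dim: int):
    # Split the head dimension: one half encodes rows, the other half columns.
    cos_h, sin_h = rope_1d(torch.arange(height), head_dim // 2)
    cos_w, sin_w = rope_1d(torch.arange(width), head_dim // 2)
    # Broadcast both axes over the (height, width) grid, then flatten to a token sequence.
    cos = torch.cat(
        [cos_h[:, None, :].expand(height, width, -1), cos_w[None, :, :].expand(height, width, -1)], dim=-1
    )
    sin = torch.cat(
        [sin_h[:, None, :].expand(height, width, -1), sin_w[None, :, :].expand(height, width, -1)], dim=-1
    )
    return cos.reshape(height * width, -1), sin.reshape(height * width, -1)
```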
New file (34 added lines): documentation page for the CogView4 pipeline
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->

# CogView4

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).

## CogView4Pipeline

[[autodoc]] CogView4Pipeline
  - all
  - __call__

## CogView4PipelineOutput

[[autodoc]] pipelines.cogview4.pipeline_output.CogView4PipelineOutput
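For orientation, here is a minimal text-to-image sketch using this pipeline. It assumes the common diffusers text-to-image calling convention (a `prompt` argument and an `.images` list on the pipeline output); refer to the `__call__` autodoc above for the authoritative signature:

```python
import torch

from diffusers import CogView4Pipeline

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
image = pipe(prompt="a photo of an astronaut riding a horse on the moon").images[0]
image.save("cogview4.png")
```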
New file (243 added lines): scripts/convert_cogview4_to_diffusers.py
""" | ||
Convert a CogView4 checkpoint from SAT(https://github.com/THUDM/SwissArmyTransformer) to the Diffusers format. | ||
(deprecated Since 2025-02-07 and will remove it in later CogView4 version) | ||
This script converts a CogView4 checkpoint to the Diffusers format, which can then be used | ||
with the Diffusers library. | ||
Example usage: | ||
python scripts/convert_cogview4_to_diffusers.py \ | ||
--transformer_checkpoint_path 'your path/cogview4_6b/1/mp_rank_00_model_states.pt' \ | ||
--vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \ | ||
--output_path "THUDM/CogView4-6B" \ | ||
--dtype "bf16" | ||
Arguments: | ||
--transformer_checkpoint_path: Path to Transformer state dict. | ||
--vae_checkpoint_path: Path to VAE state dict. | ||
--output_path: The path to save the converted model. | ||
--push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`. | ||
--text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used | ||
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered. | ||
Default is "bf16" because CogView4 uses bfloat16 for Training. | ||
Note: You must provide either --original_state_dict_repo_id or --checkpoint_path. | ||
""" | ||
|
||
import argparse | ||
from contextlib import nullcontext | ||
|
||
import torch | ||
from accelerate import init_empty_weights | ||
from transformers import GlmForCausalLM, PreTrainedTokenizerFast | ||
|
||
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler | ||
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint | ||
from diffusers.utils.import_utils import is_accelerate_available | ||
|
||
|
||
CTX = init_empty_weights if is_accelerate_available() else nullcontext | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--transformer_checkpoint_path", default=None, type=str) | ||
parser.add_argument("--vae_checkpoint_path", default=None, type=str) | ||
parser.add_argument("--output_path", required=True, type=str) | ||
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving") | ||
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory") | ||
parser.add_argument("--dtype", type=str, default="bf16") | ||
|
||
args = parser.parse_args() | ||
|
||
|
||
# this is specific to `AdaLayerNormContinuous`: | ||
# diffusers implementation split the linear projection into the scale, shift while CogView4 split it tino shift, scale | ||
def swap_scale_shift(weight, dim): | ||
shift, scale = weight.chunk(2, dim=0) | ||
new_weight = torch.cat([scale, shift], dim=0) | ||
return new_weight | ||
|
||
|
||
def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path): | ||
original_state_dict = torch.load(ckpt_path, map_location="cpu") | ||
original_state_dict = original_state_dict["module"] | ||
original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()} | ||
|
||
new_state_dict = {} | ||
|
||
# Convert patch_embed | ||
new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight") | ||
new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias") | ||
new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight") | ||
new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias") | ||
|
||
# Convert time_condition_embed | ||
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop( | ||
"time_embed.0.weight" | ||
) | ||
new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop( | ||
"time_embed.0.bias" | ||
) | ||
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop( | ||
"time_embed.2.weight" | ||
) | ||
new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop( | ||
"time_embed.2.bias" | ||
) | ||
new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop( | ||
"label_emb.0.0.weight" | ||
) | ||
new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop( | ||
"label_emb.0.0.bias" | ||
) | ||
new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop( | ||
"label_emb.0.2.weight" | ||
) | ||
new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop( | ||
"label_emb.0.2.bias" | ||
) | ||
|
||
# Convert transformer blocks, for cogview4 is 28 blocks | ||
for i in range(28): | ||
block_prefix = f"transformer_blocks.{i}." | ||
old_prefix = f"transformer.layers.{i}." | ||
adaln_prefix = f"mixins.adaln.adaln_modules.{i}." | ||
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight") | ||
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias") | ||
|
||
qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight") | ||
qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias") | ||
q, k, v = qkv_weight.chunk(3, dim=0) | ||
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0) | ||
|
||
new_state_dict[block_prefix + "attn1.to_q.weight"] = q | ||
new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias | ||
new_state_dict[block_prefix + "attn1.to_k.weight"] = k | ||
new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias | ||
new_state_dict[block_prefix + "attn1.to_v.weight"] = v | ||
new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias | ||
|
||
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop( | ||
old_prefix + "attention.dense.weight" | ||
) | ||
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop( | ||
old_prefix + "attention.dense.bias" | ||
) | ||
|
||
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop( | ||
old_prefix + "mlp.dense_h_to_4h.weight" | ||
) | ||
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop( | ||
old_prefix + "mlp.dense_h_to_4h.bias" | ||
) | ||
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop( | ||
old_prefix + "mlp.dense_4h_to_h.weight" | ||
) | ||
new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias") | ||
|
||
# Convert final norm and projection | ||
new_state_dict["norm_out.linear.weight"] = swap_scale_shift( | ||
original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0 | ||
) | ||
new_state_dict["norm_out.linear.bias"] = swap_scale_shift( | ||
original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0 | ||
) | ||
new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight") | ||
new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias") | ||
|
||
return new_state_dict | ||
|
||
|
||
def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config): | ||
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] | ||
return convert_ldm_vae_checkpoint(original_state_dict, vae_config) | ||
|
||
|
||
def main(args): | ||
if args.dtype == "fp16": | ||
dtype = torch.float16 | ||
elif args.dtype == "bf16": | ||
dtype = torch.bfloat16 | ||
elif args.dtype == "fp32": | ||
dtype = torch.float32 | ||
else: | ||
raise ValueError(f"Unsupported dtype: {args.dtype}") | ||
|
||
transformer = None | ||
vae = None | ||
|
||
if args.transformer_checkpoint_path is not None: | ||
converted_transformer_state_dict = convert_cogview4_transformer_checkpoint_to_diffusers( | ||
args.transformer_checkpoint_path | ||
) | ||
transformer = CogView4Transformer2DModel( | ||
patch_size=2, | ||
in_channels=16, | ||
num_layers=28, | ||
attention_head_dim=128, | ||
num_attention_heads=32, | ||
out_channels=16, | ||
text_embed_dim=4096, | ||
time_embed_dim=512, | ||
condition_dim=256, | ||
pos_embed_max_size=128, | ||
) | ||
transformer.load_state_dict(converted_transformer_state_dict, strict=True) | ||
if dtype is not None: | ||
# Original checkpoint data type will be preserved | ||
transformer = transformer.to(dtype=dtype) | ||
|
||
if args.vae_checkpoint_path is not None: | ||
vae_config = { | ||
"in_channels": 3, | ||
"out_channels": 3, | ||
"down_block_types": ("DownEncoderBlock2D",) * 4, | ||
"up_block_types": ("UpDecoderBlock2D",) * 4, | ||
"block_out_channels": (128, 512, 1024, 1024), | ||
"layers_per_block": 3, | ||
"act_fn": "silu", | ||
"latent_channels": 16, | ||
"norm_num_groups": 32, | ||
"sample_size": 1024, | ||
"scaling_factor": 1.0, | ||
"force_upcast": True, | ||
"use_quant_conv": False, | ||
"use_post_quant_conv": False, | ||
"mid_block_add_attention": False, | ||
} | ||
converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config) | ||
vae = AutoencoderKL(**vae_config) | ||
vae.load_state_dict(converted_vae_state_dict, strict=True) | ||
if dtype is not None: | ||
vae = vae.to(dtype=dtype) | ||
|
||
text_encoder_id = "THUDM/glm-4-9b-hf" | ||
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id) | ||
text_encoder = GlmForCausalLM.from_pretrained( | ||
text_encoder_id, | ||
cache_dir=args.text_encoder_cache_dir, | ||
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32, | ||
) | ||
|
||
for param in text_encoder.parameters(): | ||
param.data = param.data.contiguous() | ||
|
||
scheduler = FlowMatchEulerDiscreteScheduler( | ||
base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear" | ||
) | ||
|
||
pipe = CogView4Pipeline( | ||
tokenizer=tokenizer, | ||
text_encoder=text_encoder, | ||
vae=vae, | ||
transformer=transformer, | ||
scheduler=scheduler, | ||
) | ||
|
||
# This is necessary for users with insufficient memory, such as those using Colab and notebooks, as it can | ||
# save some memory used for model loading. | ||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub) | ||
|
||
|
||
if __name__ == "__main__": | ||
main(args) |
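The script configures `FlowMatchEulerDiscreteScheduler` with `use_dynamic_shifting=True` and `time_shift_type="linear"`. As a sketch of what those settings mean, dynamic shifting interpolates a shift value `mu` from the latent token count, and the linear time shift warps each timestep toward `mu`. Only `base_shift=0.25`, `max_shift=0.75`, and `base_image_seq_len=256` are taken from this diff; the helper names, `max_seq_len=4096`, and the exact formulas are assumptions based on diffusers' usual flow-match scheduler and should be treated as illustrative:

```python
def calculate_mu(image_seq_len: int, base_seq_len: int = 256, max_seq_len: int = 4096,
                 base_shift: float = 0.25, max_shift: float = 0.75) -> float:
    # Dynamic shifting: interpolate the shift parameter mu linearly in the
    # number of latent tokens, so larger images get a stronger shift.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

def linear_time_shift(mu: float, t: float, sigma: float = 1.0) -> float:
    # "linear" time_shift_type: warp a timestep t in (0, 1] toward mu
    # (the exponential variant uses exp(mu) in place of mu).
    return mu / (mu + (1 / t - 1) ** sigma)
```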
Review comments
@zRzRzRzRzRzRzR Same comment as above
Based on the current algorithm comparison and the images produced in practice, this change seems to be functioning properly.