From 2640bcf60b17a54d20e68430cdf9cccb9c0a648a Mon Sep 17 00:00:00 2001
From: zRzRzRzRzRzRzR <2448370773@qq.com>
Date: Tue, 14 Jan 2025 20:22:06 +0800
Subject: [PATCH 01/68] init
---
docs/source/en/_toctree.yml | 2 +
docs/source/en/api/pipelines/cogview4.md | 34 +
.../train_dreambooth_lora_flux_advanced.py | 8 +-
.../train_dreambooth_lora_sd15_advanced.py | 16 +-
.../train_dreambooth_lora_sdxl_advanced.py | 16 +-
examples/amused/train_amused.py | 2 +-
.../train_cogvideox_image_to_video_lora.py | 2 +-
examples/cogvideo/train_cogvideox_lora.py | 2 +-
.../community/adaptive_mask_inpainting.py | 2 +-
examples/community/hd_painter.py | 2 +-
examples/community/img2img_inpainting.py | 2 +-
examples/community/llm_grounded_diffusion.py | 4 +-
examples/community/lpw_stable_diffusion_xl.py | 2 +-
.../pipeline_flux_differential_img2img.py | 4 +-
examples/community/pipeline_prompt2prompt.py | 12 +-
.../community/pipeline_sdxl_style_aligned.py | 2 +-
...pipeline_stable_diffusion_upscale_ldm3d.py | 2 +-
...diffusion_xl_controlnet_adapter_inpaint.py | 2 +-
examples/community/scheduling_ufogen.py | 3 +-
.../train_lcm_distill_lora_sd_wds.py | 2 +-
.../train_lcm_distill_lora_sdxl.py | 2 +-
.../train_lcm_distill_lora_sdxl_wds.py | 2 +-
examples/custom_diffusion/retrieve.py | 8 +-
.../train_custom_diffusion.py | 24 +-
examples/dreambooth/train_dreambooth.py | 2 +-
examples/dreambooth/train_dreambooth_lora.py | 2 +-
.../dreambooth/train_dreambooth_lora_flux.py | 2 +-
.../dreambooth/train_dreambooth_lora_sana.py | 2 +-
.../dreambooth/train_dreambooth_lora_sd3.py | 2 +-
.../dreambooth/train_dreambooth_lora_sdxl.py | 4 +-
.../flux-control/train_control_lora_flux.py | 8 +-
.../colossalai/train_dreambooth_colossalai.py | 2 +-
.../controlnet/train_controlnet_webdataset.py | 7 +-
.../diffusion_dpo/train_diffusion_dpo.py | 2 +-
.../diffusion_dpo/train_diffusion_dpo_sdxl.py | 2 +-
.../train_diffusion_orpo_sdxl_lora.py | 4 +-
.../train_diffusion_orpo_sdxl_lora_wds.py | 4 +-
.../train_dreambooth_lora_flux_miniature.py | 2 +-
examples/research_projects/gligen/demo.ipynb | 20 +-
.../train_instruct_pix2pix_lora.py | 4 +-
.../train_multi_subject_dreambooth.py | 12 +-
.../textual_inversion.py | 6 +-
.../textual_inversion/textual_inversion.py | 6 +-
.../pipeline_prompt_diffusion.py | 3 +-
.../pytorch_xla/train_text_to_image_xla.py | 4 +-
.../dreambooth/train_dreambooth.py | 2 +-
.../dreambooth/train_dreambooth_lora.py | 2 +-
.../dreambooth/train_dreambooth_lora_sdxl.py | 4 +-
.../train_text_to_image_lora_sdxl.py | 2 +-
.../train_dreambooth_lora_sd3_miniature.py | 2 +-
.../train_text_to_image_lora_sdxl.py | 2 +-
.../textual_inversion/textual_inversion.py | 6 +-
.../textual_inversion_sdxl.py | 12 +-
examples/vqgan/test_vqgan.py | 6 +-
examples/vqgan/train_vqgan.py | 12 +-
scripts/convert_amused.py | 2 +-
scripts/convert_cogview4_to_diffusers.py | 249 +++++++
scripts/convert_consistency_to_diffusers.py | 4 +-
.../convert_dance_diffusion_to_diffusers.py | 12 +-
scripts/convert_diffusers_to_original_sdxl.py | 18 +-
..._diffusers_to_original_stable_diffusion.py | 20 +-
...vert_hunyuandit_controlnet_to_diffusers.py | 6 +-
scripts/convert_hunyuandit_to_diffusers.py | 9 +-
scripts/convert_k_upscaler_to_diffusers.py | 10 +-
scripts/convert_mochi_to_diffusers.py | 118 +--
...convert_original_audioldm2_to_diffusers.py | 2 +-
.../convert_original_audioldm_to_diffusers.py | 2 +-
.../convert_original_musicldm_to_diffusers.py | 2 +-
scripts/convert_stable_audio.py | 18 +-
scripts/convert_svd_to_diffusers.py | 12 +-
scripts/convert_vq_diffusion_to_diffusers.py | 24 +-
src/diffusers/__init__.py | 2 +
src/diffusers/loaders/ip_adapter.py | 6 +-
.../loaders/lora_conversion_utils.py | 66 +-
src/diffusers/models/model_loading_utils.py | 2 +-
.../models/transformers/transformer_2d.py | 6 +-
src/diffusers/pipelines/__init__.py | 3 +
.../pipelines/audioldm2/pipeline_audioldm2.py | 2 +-
src/diffusers/pipelines/cogview4/__init__.py | 47 ++
.../pipelines/cogview4/pipeline_cogview4.py | 675 ++++++++++++++++++
.../pipelines/cogview4/pipeline_output.py | 21 +
.../controlnet/pipeline_controlnet_inpaint.py | 4 +-
.../pipeline_controlnet_inpaint_sd_xl.py | 6 +-
...pipeline_controlnet_union_inpaint_sd_xl.py | 4 +-
.../pipeline_flux_controlnet_inpainting.py | 4 +-
.../pipelines/flux/pipeline_flux_inpaint.py | 4 +-
src/diffusers/pipelines/free_noise_utils.py | 6 +-
.../kandinsky/pipeline_kandinsky_combined.py | 2 +-
.../kandinsky/pipeline_kandinsky_inpaint.py | 2 +-
.../pag/pipeline_pag_controlnet_sd_inpaint.py | 6 +-
.../pipelines/pag/pipeline_pag_sd_inpaint.py | 6 +-
.../pag/pipeline_pag_sd_xl_inpaint.py | 6 +-
.../pipeline_paint_by_example.py | 2 +-
.../pipelines/pipeline_loading_utils.py | 4 +-
src/diffusers/pipelines/pipeline_utils.py | 4 +-
src/diffusers/pipelines/shap_e/renderer.py | 12 +-
.../stable_audio/pipeline_stable_audio.py | 2 +-
.../pipeline_flax_stable_diffusion_inpaint.py | 2 +-
.../pipeline_onnx_stable_diffusion_inpaint.py | 2 +-
.../pipeline_stable_diffusion_inpaint.py | 6 +-
...eline_stable_diffusion_instruct_pix2pix.py | 2 +-
...ipeline_stable_diffusion_latent_upscale.py | 2 +-
.../pipeline_stable_diffusion_upscale.py | 2 +-
.../pipeline_stable_diffusion_3_inpaint.py | 2 +-
.../pipeline_stable_diffusion_xl_inpaint.py | 6 +-
src/diffusers/quantizers/base.py | 12 +-
.../scheduling_consistency_models.py | 3 +-
src/diffusers/schedulers/scheduling_ddpm.py | 3 +-
.../schedulers/scheduling_ddpm_parallel.py | 3 +-
src/diffusers/schedulers/scheduling_lcm.py | 3 +-
src/diffusers/schedulers/scheduling_tcd.py | 3 +-
src/diffusers/training_utils.py | 4 +-
src/diffusers/utils/deprecation_utils.py | 2 +-
.../dummy_torch_and_transformers_objects.py | 15 +
src/diffusers/utils/import_utils.py | 2 +-
src/diffusers/utils/logging.py | 3 +-
src/diffusers/utils/state_dict_utils.py | 2 +-
src/diffusers/utils/testing_utils.py | 4 +-
tests/models/test_modeling_common.py | 12 +-
.../test_models_transformer_sd3.py | 12 +-
.../unets/test_models_unet_2d_condition.py | 36 +-
tests/others/test_image_processor.py | 30 +-
tests/pipelines/amused/test_amused.py | 3 +-
tests/pipelines/amused/test_amused_img2img.py | 3 +-
tests/pipelines/amused/test_amused_inpaint.py | 3 +-
.../aura_flow/test_pipeline_aura_flow.py | 24 +-
.../blipdiffusion/test_blipdiffusion.py | 6 +-
tests/pipelines/cogvideo/test_cogvideox.py | 24 +-
.../cogvideo/test_cogvideox_fun_control.py | 24 +-
.../cogvideo/test_cogvideox_image2video.py | 24 +-
.../cogvideo/test_cogvideox_video2video.py | 24 +-
.../test_controlnet_blip_diffusion.py | 6 +-
.../controlnet_flux/test_controlnet_flux.py | 6 +-
.../test_controlnet_flux_img2img.py | 24 +-
.../test_controlnet_hunyuandit.py | 6 +-
.../test_controlnet_inpaint_sd3.py | 6 +-
.../controlnet_sd3/test_controlnet_sd3.py | 6 +-
tests/pipelines/dit/test_dit.py | 3 +-
tests/pipelines/flux/test_pipeline_flux.py | 24 +-
.../flux/test_pipeline_flux_control.py | 24 +-
.../test_pipeline_flux_control_inpaint.py | 24 +-
.../pipelines/hunyuan_dit/test_hunyuan_dit.py | 24 +-
tests/pipelines/kandinsky/test_kandinsky.py | 12 +-
.../kandinsky/test_kandinsky_combined.py | 36 +-
.../kandinsky/test_kandinsky_img2img.py | 16 +-
.../kandinsky/test_kandinsky_inpaint.py | 14 +-
.../pipelines/kandinsky2_2/test_kandinsky.py | 12 +-
.../kandinsky2_2/test_kandinsky_combined.py | 36 +-
.../kandinsky2_2/test_kandinsky_controlnet.py | 12 +-
.../test_kandinsky_controlnet_img2img.py | 14 +-
.../kandinsky2_2/test_kandinsky_img2img.py | 14 +-
.../kandinsky2_2/test_kandinsky_inpaint.py | 14 +-
tests/pipelines/kandinsky3/test_kandinsky3.py | 6 +-
.../kandinsky3/test_kandinsky3_img2img.py | 6 +-
tests/pipelines/pag/test_pag_animatediff.py | 6 +-
tests/pipelines/pag/test_pag_controlnet_sd.py | 6 +-
.../pag/test_pag_controlnet_sd_inpaint.py | 6 +-
.../pipelines/pag/test_pag_controlnet_sdxl.py | 6 +-
.../pag/test_pag_controlnet_sdxl_img2img.py | 6 +-
tests/pipelines/pag/test_pag_hunyuan_dit.py | 24 +-
tests/pipelines/pag/test_pag_kolors.py | 6 +-
tests/pipelines/pag/test_pag_pixart_sigma.py | 6 +-
tests/pipelines/pag/test_pag_sana.py | 6 +-
tests/pipelines/pag/test_pag_sd.py | 18 +-
tests/pipelines/pag/test_pag_sd3.py | 30 +-
tests/pipelines/pag/test_pag_sd3_img2img.py | 18 +-
tests/pipelines/pag/test_pag_sd_img2img.py | 18 +-
tests/pipelines/pag/test_pag_sd_inpaint.py | 12 +-
tests/pipelines/pag/test_pag_sdxl.py | 18 +-
tests/pipelines/pag/test_pag_sdxl_img2img.py | 18 +-
tests/pipelines/pag/test_pag_sdxl_inpaint.py | 18 +-
tests/pipelines/pixart_sigma/test_pixart.py | 24 +-
tests/pipelines/shap_e/test_shap_e_img2img.py | 2 +-
.../test_stable_cascade_combined.py | 12 +-
.../stable_diffusion/test_stable_diffusion.py | 48 +-
.../test_pipeline_stable_diffusion_3.py | 24 +-
.../test_stable_diffusion_xl.py | 30 +-
.../test_stable_diffusion_xl_inpaint.py | 12 +-
tests/pipelines/test_pipelines.py | 24 +-
tests/pipelines/test_pipelines_common.py | 48 +-
.../wuerstchen/test_wuerstchen_combined.py | 12 +-
tests/schedulers/test_scheduler_dpm_multi.py | 6 +-
tests/schedulers/test_scheduler_dpm_single.py | 6 +-
.../test_scheduler_edm_dpmsolver_multistep.py | 6 +-
tests/schedulers/test_scheduler_euler.py | 12 +-
tests/schedulers/test_scheduler_heun.py | 6 +-
.../single_file/single_file_testing_utils.py | 24 +-
.../test_model_autoencoder_dc_single_file.py | 18 +-
.../test_model_controlnet_single_file.py | 6 +-
...test_model_flux_transformer_single_file.py | 6 +-
.../test_model_motion_adapter_single_file.py | 24 +-
.../test_model_sd_cascade_unet_single_file.py | 24 +-
.../single_file/test_model_vae_single_file.py | 6 +-
utils/log_reports.py | 2 +-
utils/update_metadata.py | 3 +-
195 files changed, 2001 insertions(+), 978 deletions(-)
create mode 100644 docs/source/en/api/pipelines/cogview4.md
create mode 100644 scripts/convert_cogview4_to_diffusers.py
create mode 100644 src/diffusers/pipelines/cogview4/__init__.py
create mode 100644 src/diffusers/pipelines/cogview4/pipeline_cogview4.py
create mode 100644 src/diffusers/pipelines/cogview4/pipeline_output.py
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index a2b411c8fcb0..71e7a300f4cf 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -370,6 +370,8 @@
title: CogVideoX
- local: api/pipelines/cogview3
title: CogView3
+ - local: api/pipelines/cogview4
+ title: CogView4
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
diff --git a/docs/source/en/api/pipelines/cogview4.md b/docs/source/en/api/pipelines/cogview4.md
new file mode 100644
index 000000000000..cc17c3c905fb
--- /dev/null
+++ b/docs/source/en/api/pipelines/cogview4.md
@@ -0,0 +1,34 @@
+
+
+# CogView4
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
+
+## CogView4Pipeline
+
+[[autodoc]] CogView4Pipeline
+ - all
+ - __call__
+
+## CogView4PipelineOutput
+
+[[autodoc]] pipelines.cogview4.pipeline_output.CogView4PipelineOutput
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 0fcbe2000ce7..96d56138bb5a 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -818,9 +818,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
- assert all(
- isinstance(tok, str) for tok in inserting_toks
- ), "All elements in inserting_toks should be strings."
+ assert all(isinstance(tok, str) for tok in inserting_toks), (
+ "All elements in inserting_toks should be strings."
+ )
self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -1683,7 +1683,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
index 923683ae7c38..22472298d7ac 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
@@ -200,7 +200,7 @@ def save_model_card(
"diffusers",
"diffusers-training",
lora,
- "template:sd-lora" "stable-diffusion",
+ "template:sd-lorastable-diffusion",
"stable-diffusion-diffusers",
]
model_card = populate_model_card(model_card, tags=tags)
@@ -724,9 +724,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
- assert all(
- isinstance(tok, str) for tok in inserting_toks
- ), "All elements in inserting_toks should be strings."
+ assert all(isinstance(tok, str) for tok in inserting_toks), (
+ "All elements in inserting_toks should be strings."
+ )
self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -746,9 +746,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
.to(dtype=self.dtype)
* std_token_embedding
)
- self.embeddings_settings[
- f"original_embeddings_{idx}"
- ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+ self.embeddings_settings[f"original_embeddings_{idx}"] = (
+ text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+ )
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -1322,7 +1322,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 07119618543d..c534e9049ec4 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -116,7 +116,7 @@ def save_model_card(
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"""
- - text: '{validation_prompt if validation_prompt else ' ' }'
+ - text: '{validation_prompt if validation_prompt else " "}'
output:
url:
"image_{i}.png"
@@ -891,9 +891,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
- assert all(
- isinstance(tok, str) for tok in inserting_toks
- ), "All elements in inserting_toks should be strings."
+ assert all(isinstance(tok, str) for tok in inserting_toks), (
+ "All elements in inserting_toks should be strings."
+ )
self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -913,9 +913,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
.to(dtype=self.dtype)
* std_token_embedding
)
- self.embeddings_settings[
- f"original_embeddings_{idx}"
- ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+ self.embeddings_settings[f"original_embeddings_{idx}"] = (
+ text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+ )
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -1648,7 +1648,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/amused/train_amused.py b/examples/amused/train_amused.py
index ede51775dd8f..3b4cabf075b0 100644
--- a/examples/amused/train_amused.py
+++ b/examples/amused/train_amused.py
@@ -720,7 +720,7 @@ def load_model_hook(models, input_dir):
# Train!
logger.info("***** Running training *****")
logger.info(f" Num training steps = {args.max_train_steps}")
- logger.info(f" Instantaneous batch size per device = { args.train_batch_size}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
index aaee133680ea..86f2965636f3 100644
--- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py
+++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
@@ -1138,7 +1138,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index 01ea59c593a9..59e42fcb80d7 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -1159,7 +1159,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/community/adaptive_mask_inpainting.py b/examples/community/adaptive_mask_inpainting.py
index df736956485b..81f9527b4703 100644
--- a/examples/community/adaptive_mask_inpainting.py
+++ b/examples/community/adaptive_mask_inpainting.py
@@ -1103,7 +1103,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `default_mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/hd_painter.py b/examples/community/hd_painter.py
index 91ebe076104a..9d7b95b62c6e 100644
--- a/examples/community/hd_painter.py
+++ b/examples/community/hd_painter.py
@@ -686,7 +686,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py
index 292c9aa2bc47..001e4cc5b2cf 100644
--- a/examples/community/img2img_inpainting.py
+++ b/examples/community/img2img_inpainting.py
@@ -362,7 +362,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/examples/community/llm_grounded_diffusion.py b/examples/community/llm_grounded_diffusion.py
index 129793dae6b0..814694f1e366 100644
--- a/examples/community/llm_grounded_diffusion.py
+++ b/examples/community/llm_grounded_diffusion.py
@@ -1120,7 +1120,7 @@ def latent_lmd_guidance(
if verbose:
logger.info(
- f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
+ f"time index {index}, loss: {loss.item() / loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
)
try:
@@ -1184,7 +1184,7 @@ def latent_lmd_guidance(
if verbose:
logger.info(
- f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
+ f"time index {index}, loss: {loss.item() / loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
)
finally:
diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py
index 4bcef10f97c2..af1082e8410b 100644
--- a/examples/community/lpw_stable_diffusion_xl.py
+++ b/examples/community/lpw_stable_diffusion_xl.py
@@ -1773,7 +1773,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/pipeline_flux_differential_img2img.py b/examples/community/pipeline_flux_differential_img2img.py
index a66e2b1c7c8a..33eaa9de04cd 100644
--- a/examples/community/pipeline_flux_differential_img2img.py
+++ b/examples/community/pipeline_flux_differential_img2img.py
@@ -488,7 +488,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -496,7 +496,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
diff --git a/examples/community/pipeline_prompt2prompt.py b/examples/community/pipeline_prompt2prompt.py
index 736f00799eae..b9985542ccf7 100644
--- a/examples/community/pipeline_prompt2prompt.py
+++ b/examples/community/pipeline_prompt2prompt.py
@@ -907,12 +907,12 @@ def create_controller(
# reweight
if edit_type == "reweight":
- assert (
- equalizer_words is not None and equalizer_strengths is not None
- ), "To use reweight edit, please specify equalizer_words and equalizer_strengths."
- assert len(equalizer_words) == len(
- equalizer_strengths
- ), "equalizer_words and equalizer_strengths must be of same length."
+ assert equalizer_words is not None and equalizer_strengths is not None, (
+ "To use reweight edit, please specify equalizer_words and equalizer_strengths."
+ )
+ assert len(equalizer_words) == len(equalizer_strengths), (
+ "equalizer_words and equalizer_strengths must be of same length."
+ )
equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
return AttentionReweight(
prompts,
diff --git a/examples/community/pipeline_sdxl_style_aligned.py b/examples/community/pipeline_sdxl_style_aligned.py
index 9377caf7ba2e..6aebb6c18df7 100644
--- a/examples/community/pipeline_sdxl_style_aligned.py
+++ b/examples/community/pipeline_sdxl_style_aligned.py
@@ -1738,7 +1738,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
index 8a709ab46757..6c63f53e815c 100644
--- a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
+++ b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
@@ -689,7 +689,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
index 8480117866cc..6a0ed3523dab 100644
--- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
+++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
@@ -1578,7 +1578,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/scheduling_ufogen.py b/examples/community/scheduling_ufogen.py
index 4b1b92ff183a..0b832394cf97 100644
--- a/examples/community/scheduling_ufogen.py
+++ b/examples/community/scheduling_ufogen.py
@@ -288,8 +288,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
index db4177999e55..247f2863423f 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -89,7 +89,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
# Set alpha parameter
if "lora_down" in kohya_key:
- alpha_key = f'{kohya_key.split(".")[0]}.alpha'
+ alpha_key = f"{kohya_key.split('.')[0]}.alpha"
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
return kohya_ss_state_dict
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
index 38fe94ed3fe5..61d883fdfb78 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -901,7 +901,7 @@ def load_model_hook(models, input_dir):
unet_ = accelerator.unwrap_model(unet)
lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
unet_state_dict = {
- f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
+ f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")
}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
index fe36e9d3abcd..cc6f9d389db3 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -95,7 +95,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
# Set alpha parameter
if "lora_down" in kohya_key:
- alpha_key = f'{kohya_key.split(".")[0]}.alpha'
+ alpha_key = f"{kohya_key.split('.')[0]}.alpha"
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
return kohya_ss_state_dict
diff --git a/examples/custom_diffusion/retrieve.py b/examples/custom_diffusion/retrieve.py
index a28fe344d93b..27f4b4e0dc60 100644
--- a/examples/custom_diffusion/retrieve.py
+++ b/examples/custom_diffusion/retrieve.py
@@ -50,9 +50,11 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
total = 0
pbar = tqdm(desc="downloading real regularization images", total=num_class_images)
- with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open(
- f"{class_data_dir}/images.txt", "w"
- ) as f3:
+ with (
+ open(f"{class_data_dir}/caption.txt", "w") as f1,
+ open(f"{class_data_dir}/urls.txt", "w") as f2,
+ open(f"{class_data_dir}/images.txt", "w") as f3,
+ ):
while total < num_class_images:
images = class_images[count]
count += 1
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index dc21746cb159..140e64a0e075 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -731,18 +731,18 @@ def main(args):
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True, exist_ok=True)
if args.real_prior:
- assert (
- class_images_dir / "images"
- ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
- assert (
- len(list((class_images_dir / "images").iterdir())) == args.num_class_images
- ), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
- assert (
- class_images_dir / "caption.txt"
- ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
- assert (
- class_images_dir / "images.txt"
- ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
+ assert (class_images_dir / "images").exists(), (
+ f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
+ )
+ assert len(list((class_images_dir / "images").iterdir())) == args.num_class_images, (
+ f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
+ )
+ assert (class_images_dir / "caption.txt").exists(), (
+ f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
+ )
+ assert (class_images_dir / "images.txt").exists(), (
+ f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
+ )
concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
args.concepts_list[i] = concept
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index ac21373e478f..8f5509039003 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -1014,7 +1014,7 @@ def load_model_hook(models, input_dir):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError(
- f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
+ f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index e81fbe80576d..b46eb2cb4bcf 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -979,7 +979,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 7b7ae4f46588..99c90c83735c 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1275,7 +1275,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py
index 7bec9c799cae..16b76313e5cf 100644
--- a/examples/dreambooth/train_dreambooth_lora_sana.py
+++ b/examples/dreambooth/train_dreambooth_lora_sana.py
@@ -1048,7 +1048,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = SanaPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 097eaed8b504..438175c156c8 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -1355,7 +1355,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 15ba7bb14fb2..2bf67dad14d7 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -118,7 +118,7 @@ def save_model_card(
)
model_description = f"""
-# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
+# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
@@ -1271,7 +1271,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index 44c684395849..0d47e62eedea 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
torch_dtype=weight_dtype,
)
pipeline.load_lora_weights(args.output_dir)
- assert (
- pipeline.transformer.config.in_channels == initial_channels * 2
- ), f"{pipeline.transformer.config.in_channels=}"
+ assert pipeline.transformer.config.in_channels == initial_channels * 2, (
+ f"{pipeline.transformer.config.in_channels=}"
+ )
pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
@@ -954,7 +954,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
transformer_lora_state_dict = {
- f'{k.replace("transformer.", "")}': v
+ f"{k.replace('transformer.', '')}": v
for k, v in lora_state_dict.items()
if k.startswith("transformer.") and "lora" in k
}
diff --git a/examples/research_projects/colossalai/train_dreambooth_colossalai.py b/examples/research_projects/colossalai/train_dreambooth_colossalai.py
index 10c8e095a696..4e541b8d3a02 100644
--- a/examples/research_projects/colossalai/train_dreambooth_colossalai.py
+++ b/examples/research_projects/colossalai/train_dreambooth_colossalai.py
@@ -619,7 +619,7 @@ def collate_fn(examples):
optimizer.step()
lr_scheduler.step()
- logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])
+ logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated() / 2**20} MB", ranks=[0])
# Checks if the accelerator has performed an optimization step behind the scenes
progress_bar.update(1)
global_step += 1
diff --git a/examples/research_projects/controlnet/train_controlnet_webdataset.py b/examples/research_projects/controlnet/train_controlnet_webdataset.py
index 88a5d93d8edf..e829da848f9b 100644
--- a/examples/research_projects/controlnet/train_controlnet_webdataset.py
+++ b/examples/research_projects/controlnet/train_controlnet_webdataset.py
@@ -805,21 +805,20 @@ def parse_args(input_args=None):
"--control_type",
type=str,
default="canny",
- help=("The type of controlnet conditioning image to use. One of `canny`, `depth`" " Defaults to `canny`."),
+ help=("The type of controlnet conditioning image to use. One of `canny`, `depth` Defaults to `canny`."),
)
parser.add_argument(
"--transformer_layers_per_block",
type=str,
default=None,
- help=("The number of layers per block in the transformer. If None, defaults to" " `args.transformer_layers`."),
+ help=("The number of layers per block in the transformer. If None, defaults to `args.transformer_layers`."),
)
parser.add_argument(
"--old_style_controlnet",
action="store_true",
default=False,
help=(
- "Use the old style controlnet, which is a single transformer layer with"
- " a single head. Defaults to False."
+ "Use the old style controlnet, which is a single transformer layer with a single head. Defaults to False."
),
)
diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
index ab88d4967766..0b9c248ed004 100644
--- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
+++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
@@ -86,7 +86,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False):
- logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
+ logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
# create pipeline
pipeline = DiffusionPipeline.from_pretrained(
diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
index 0297a06f5b2c..f0afa12e9ceb 100644
--- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
+++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
- logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
+ logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation:
if args.mixed_precision == "fp16":
diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
index cdc096190f08..7ef2667b3339 100644
--- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
+++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
- logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
+ logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation:
if args.mixed_precision == "fp16":
@@ -683,7 +683,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
index cd1ef265d23e..c960860c8dcf 100644
--- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
+++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
@@ -89,7 +89,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
- logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
+ logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation:
if args.mixed_precision == "fp16":
@@ -790,7 +790,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
index f3b4602c7fcf..7d0d2ccc4b89 100644
--- a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
+++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
@@ -783,7 +783,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/research_projects/gligen/demo.ipynb b/examples/research_projects/gligen/demo.ipynb
index 571f1a0323a2..b467ba3a87bc 100644
--- a/examples/research_projects/gligen/demo.ipynb
+++ b/examples/research_projects/gligen/demo.ipynb
@@ -48,16 +48,12 @@
"from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n",
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
"\n",
- "pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
+ "pretrained_model_name_or_path = \"/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83\"\n",
"\n",
"tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n",
"noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
- "text_encoder = CLIPTextModel.from_pretrained(\n",
- " pretrained_model_name_or_path, subfolder=\"text_encoder\"\n",
- ")\n",
- "vae = AutoencoderKL.from_pretrained(\n",
- " pretrained_model_name_or_path, subfolder=\"vae\"\n",
- ")\n",
+ "text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"text_encoder\")\n",
+ "vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n",
"# unet = UNet2DConditionModel.from_pretrained(\n",
"# pretrained_model_name_or_path, subfolder=\"unet\"\n",
"# )\n",
@@ -71,9 +67,7 @@
"metadata": {},
"outputs": [],
"source": [
- "unet = UNet2DConditionModel.from_pretrained(\n",
- " '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n",
- ")"
+ "unet = UNet2DConditionModel.from_pretrained(\"/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO\")"
]
},
{
@@ -117,8 +111,8 @@
"# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n",
"# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n",
"\n",
- "prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n",
- "gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n",
+ "prompt = \"An oil painting of a pink dolphin jumping on the left of a steam boat on the sea\"\n",
+ "gen_boxes = [(\"a steam boat\", [232, 225, 257, 149]), (\"a jumping pink dolphin\", [21, 249, 189, 123])]\n",
"\n",
"import numpy as np\n",
"\n",
@@ -166,7 +160,7 @@
"metadata": {},
"outputs": [],
"source": [
- "diffusers.utils.make_image_grid(images, 4, len(images)//4)"
+ "diffusers.utils.make_image_grid(images, 4, len(images) // 4)"
]
},
{
diff --git a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
index fcb927c680a0..197d0f84ee04 100644
--- a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
+++ b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
@@ -15,8 +15,8 @@
# limitations under the License.
"""
- Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
- Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
+Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
+Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
"""
import argparse
diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
index 0f507b26d6a8..57c555e43fd8 100644
--- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
+++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
@@ -763,9 +763,9 @@ def main(args):
# Parse instance and class inputs, and double check that lengths match
instance_data_dir = args.instance_data_dir.split(",")
instance_prompt = args.instance_prompt.split(",")
- assert all(
- x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
- ), "Instance data dir and prompt inputs are not of the same length."
+ assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
+ "Instance data dir and prompt inputs are not of the same length."
+ )
if args.with_prior_preservation:
class_data_dir = args.class_data_dir.split(",")
@@ -788,9 +788,9 @@ def main(args):
negative_validation_prompts.append(None)
args.validation_negative_prompt = negative_validation_prompts
- assert num_of_validation_prompts == len(
- negative_validation_prompts
- ), "The length of negative prompts for validation is greater than the number of validation prompts."
+ assert num_of_validation_prompts == len(negative_validation_prompts), (
+ "The length of negative prompts for validation is greater than the number of validation prompts."
+ )
args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
index 57ad77477b0d..7aad64ecb1dd 100644
--- a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
+++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
@@ -830,9 +830,9 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = get_mask(tokenizer, accelerator)
with torch.no_grad():
- accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
- index_no_updates
- ] = orig_embeds_params[index_no_updates]
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+ orig_embeds_params[index_no_updates]
+ )
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
index e10564fa59ef..5f0710e85319 100644
--- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
+++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
@@ -886,9 +886,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad():
- accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
- index_no_updates
- ] = orig_embeds_params[index_no_updates]
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+ orig_embeds_params[index_no_updates]
+ )
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
index 19c1f30d82da..51668a61cdc2 100644
--- a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
+++ b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
@@ -663,8 +663,7 @@ def check_inputs(
self.check_image(image, prompt, prompt_embeds)
else:
raise ValueError(
- f"You have passed a list of images of length {len(image_pair)}."
- f"Make sure the list size equals to two."
+ f"You have passed a list of images of length {len(image_pair)}.Make sure the list size equals to two."
)
# Check `controlnet_conditioning_scale`
diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
index 9719585d3dfb..6ae1a9a6c611 100644
--- a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
+++ b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
@@ -173,7 +173,7 @@ def print_loss_closure(step, loss):
if not dataloader_exception:
xm.wait_device_ops()
total_time = time.time() - last_time
- print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
+ print(f"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}")
else:
print("dataloader exception happen, skip result")
return
@@ -622,7 +622,7 @@ def collate_fn(examples):
num_devices_per_host = num_devices // num_hosts
if xm.is_master_ordinal():
print("***** Running training *****")
- print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }")
+ print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host}")
print(
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
)
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
index 5f7ca2262dcc..926b52e879db 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
@@ -1057,7 +1057,7 @@ def load_model_hook(models, input_dir):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError(
- f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
+ f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
index 663dbbf99473..d5d773a48b2a 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
@@ -1021,7 +1021,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
index 2a9801038999..a28bc3ee7c6c 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
@@ -118,7 +118,7 @@ def save_model_card(
)
model_description = f"""
-# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
+# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
@@ -1336,7 +1336,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
index bab86bf21a76..880021e04a24 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
@@ -750,7 +750,7 @@ def load_model_hook(models, input_dir):
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
index e883d8ef95a7..5ad6aa29f6f3 100644
--- a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
+++ b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
@@ -765,7 +765,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index d7b52307f048..9bcef187cc83 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -767,7 +767,7 @@ def load_model_hook(models, input_dir):
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 4a28ff3ed228..c420d82baa17 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -910,9 +910,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad():
- accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
- index_no_updates
- ] = orig_embeds_params[index_no_updates]
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+ orig_embeds_params[index_no_updates]
+ )
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py
index 5f38390c3193..657fc40eec23 100644
--- a/examples/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/textual_inversion/textual_inversion_sdxl.py
@@ -965,12 +965,12 @@ def main():
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
with torch.no_grad():
- accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
- index_no_updates
- ] = orig_embeds_params[index_no_updates]
- accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
- index_no_updates_2
- ] = orig_embeds_params_2[index_no_updates_2]
+ accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
+ orig_embeds_params[index_no_updates]
+ )
+ accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = (
+ orig_embeds_params_2[index_no_updates_2]
+ )
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/vqgan/test_vqgan.py b/examples/vqgan/test_vqgan.py
index 664a7f7365b0..6fb0179140c4 100644
--- a/examples/vqgan/test_vqgan.py
+++ b/examples/vqgan/test_vqgan.py
@@ -177,7 +177,7 @@ def test_vqmodel_checkpointing(self):
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1
- --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
+ --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--output_dir {tmpdir}
--seed=0
""".split()
@@ -262,7 +262,7 @@ def test_vqmodel_checkpointing_use_ema(self):
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1
- --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
+ --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--output_dir {tmpdir}
--use_ema
--seed=0
@@ -377,7 +377,7 @@ def test_vqmodel_checkpointing_checkpoints_total_limit_removes_multiple_checkpoi
--discriminator_config_name_or_path {discriminator_config_path}
--output_dir {tmpdir}
--checkpointing_steps=2
- --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
+ --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--checkpoints_total_limit=2
--seed=0
""".split()
diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py
index 992722fa7a78..33d234da52d7 100644
--- a/examples/vqgan/train_vqgan.py
+++ b/examples/vqgan/train_vqgan.py
@@ -653,15 +653,15 @@ def main():
try:
# Gets the resolution of the timm transformation after centercrop
timm_centercrop_transform = timm_transform.transforms[1]
- assert isinstance(
- timm_centercrop_transform, transforms.CenterCrop
- ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
+ assert isinstance(timm_centercrop_transform, transforms.CenterCrop), (
+ f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
+ )
timm_model_resolution = timm_centercrop_transform.size[0]
# Gets final normalization
timm_model_normalization = timm_transform.transforms[-1]
- assert isinstance(
- timm_model_normalization, transforms.Normalize
- ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
+ assert isinstance(timm_model_normalization, transforms.Normalize), (
+ f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
+ )
except AssertionError as e:
raise NotImplementedError(e)
# Enable flash attention if asked
diff --git a/scripts/convert_amused.py b/scripts/convert_amused.py
index 21be29dfdb99..ddd1bf508b6d 100644
--- a/scripts/convert_amused.py
+++ b/scripts/convert_amused.py
@@ -468,7 +468,7 @@ def make_vqvae(old_vae):
# assert (old_output == new_output).all()
print("skipping full vae equivalence check")
- print(f"vae full diff { (old_output - new_output).float().abs().sum()}")
+ print(f"vae full diff {(old_output - new_output).float().abs().sum()}")
return new_vae
diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py
new file mode 100644
index 000000000000..e99562898b52
--- /dev/null
+++ b/scripts/convert_cogview4_to_diffusers.py
@@ -0,0 +1,249 @@
+"""
+Convert a CogView4 checkpoint to the Diffusers format.
+
+This script converts a CogView4 checkpoint to the Diffusers format, which can then be used
+with the Diffusers library.
+
+Example usage:
+ python scripts/convert_cogview4_to_diffusers.py \
+ --transformer_checkpoint_path 'your path/cogview4_6b/1/mp_rank_00_model_states.pt' \
+ --vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
+ --output_path "/raid/yiyi/CogBiew4-6B" \
+ --dtype "bf16"
+
+Arguments:
+ --transformer_checkpoint_path: Path to Transformer state dict.
+ --vae_checkpoint_path: Path to VAE state dict.
+ --output_path: The path to save the converted model.
+ --push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
+ --text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used
+ --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
+
+ Default is "bf16" because CogView4 uses bfloat16 for Training.
+
+Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
+"""
+
+import argparse
+from contextlib import nullcontext
+
+import torch
+from accelerate import init_empty_weights
+from transformers import PreTrainedTokenizerFast, GlmForCausalLM
+
+from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView4Pipeline, CogView3PlusTransformer2DModel
+from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
+parser.add_argument("--vae_checkpoint_path", default=None, type=str)
+parser.add_argument("--output_path", required=True, type=str)
+parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
+parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
+parser.add_argument("--dtype", type=str, default="bf16")
+
+args = parser.parse_args()
+
+
+# this is specific to `AdaLayerNormContinuous`:
+# diffusers implementation split the linear projection into the scale, shift while CogView4 split it tino shift, scale
+def swap_scale_shift(weight, dim):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+
+def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path):
+ original_state_dict = torch.load(ckpt_path, map_location="cpu")
+ original_state_dict = original_state_dict["module"]
+ original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}
+
+ new_state_dict = {}
+
+ # Convert patch_embed
+ new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
+ new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
+ new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
+ new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
+
+ # Convert time_condition_embed
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
+ "time_embed.0.weight"
+ )
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+ "time_embed.0.bias"
+ )
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
+ "time_embed.2.weight"
+ )
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+ "time_embed.2.bias"
+ )
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
+ "label_emb.0.0.weight"
+ )
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
+ "label_emb.0.0.bias"
+ )
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
+ "label_emb.0.2.weight"
+ )
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
+ "label_emb.0.2.bias"
+ )
+
+ # Convert transformer blocks, for cogview4 is 28 blocks
+ for i in range(28):
+ block_prefix = f"transformer_blocks.{i}."
+ old_prefix = f"transformer.layers.{i}."
+ adaln_prefix = f"mixins.adaln.adaln_modules.{i}."
+ new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
+ new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
+
+ qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
+ qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
+ q, k, v = qkv_weight.chunk(3, dim=0)
+ q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
+
+ new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+ new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
+ new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+ new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
+ new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+ new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
+
+ new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
+ old_prefix + "attention.dense.weight"
+ )
+ new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
+ old_prefix + "attention.dense.bias"
+ )
+
+ new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
+ old_prefix + "mlp.dense_h_to_4h.weight"
+ )
+ new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
+ old_prefix + "mlp.dense_h_to_4h.bias"
+ )
+ new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
+ old_prefix + "mlp.dense_4h_to_h.weight"
+ )
+ new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")
+
+ # Convert final norm and projection
+ new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+ original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
+ )
+ new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+ original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
+ )
+ new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
+ new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")
+
+ return new_state_dict
+
+
+def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
+ original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+ return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
+
+
+def main(args):
+ if args.dtype == "fp16":
+ dtype = torch.float16
+ elif args.dtype == "bf16":
+ dtype = torch.bfloat16
+ elif args.dtype == "fp32":
+ dtype = torch.float32
+ else:
+ raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+ transformer = None
+ vae = None
+
+ if args.transformer_checkpoint_path is not None:
+ converted_transformer_state_dict = convert_cogview4_transformer_checkpoint_to_diffusers(
+ args.transformer_checkpoint_path
+ )
+ transformer = CogView3PlusTransformer2DModel(
+ patch_size = 2,
+ in_channels = 16,
+ num_layers = 28,
+ attention_head_dim= 128,
+ num_attention_heads = 32,
+ out_channels = 16,
+ text_embed_dim= 4096,
+ time_embed_dim = 512,
+ condition_dim= 256,
+ pos_embed_max_size = 128,
+ )
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+ if dtype is not None:
+ # Original checkpoint data type will be preserved
+ transformer = transformer.to(dtype=dtype)
+
+ if args.vae_checkpoint_path is not None:
+ vae_config = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ("DownEncoderBlock2D",) * 4,
+ "up_block_types": ("UpDecoderBlock2D",) * 4,
+ "block_out_channels": (128, 512, 1024, 1024),
+ "layers_per_block": 3,
+ "act_fn": "silu",
+ "latent_channels": 16,
+ "norm_num_groups": 32,
+ "sample_size": 1024,
+ "scaling_factor": 1.0,
+ "force_upcast": True,
+ "use_quant_conv": False,
+ "use_post_quant_conv": False,
+ "mid_block_add_attention": False,
+ }
+ converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
+ vae = AutoencoderKL(**vae_config)
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
+ if dtype is not None:
+ vae = vae.to(dtype=dtype)
+
+ text_encoder_id = 'THUDM/glm-4-9b-hf'
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
+ text_encoder = GlmForCausalLM.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir, torch_dtype=torch.bfloat16 if dtype=="bf16" else torch.float32)
+ # Apparently, the conversion does not work anymore without this :shrug:
+ for param in text_encoder.parameters():
+ param.data = param.data.contiguous()
+
+ scheduler = CogVideoXDDIMScheduler.from_config(
+ {
+ "snr_shift_scale": 4.0,
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": False,
+ "num_train_timesteps": 1000,
+ "prediction_type": "v_prediction",
+ "rescale_betas_zero_snr": True,
+ "set_alpha_to_one": True,
+ "timestep_spacing": "trailing",
+ }
+ )
+
+ pipe = CogView4Pipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ # This is necessary for users with insufficient memory, such as those using Colab and notebooks, as it can
+ # save some memory used for model loading.
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py
index 0f8b4ddca8ef..2b918280ca05 100644
--- a/scripts/convert_consistency_to_diffusers.py
+++ b/scripts/convert_consistency_to_diffusers.py
@@ -239,7 +239,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0"
- old_prefix = f"output_blocks.{current_layer-1}.1"
+ old_prefix = f"output_blocks.{current_layer - 1}.1"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
elif layer_type == "AttnUpBlock2D":
for j in range(layers_per_block + 1):
@@ -255,7 +255,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0"
- old_prefix = f"output_blocks.{current_layer-1}.2"
+ old_prefix = f"output_blocks.{current_layer - 1}.2"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
diff --git a/scripts/convert_dance_diffusion_to_diffusers.py b/scripts/convert_dance_diffusion_to_diffusers.py
index ce69bfe2bfc8..3d64a77fae7d 100755
--- a/scripts/convert_dance_diffusion_to_diffusers.py
+++ b/scripts/convert_dance_diffusion_to_diffusers.py
@@ -260,9 +260,9 @@ def main(args):
model_name = args.model_path.split("/")[-1].split(".")[0]
if not os.path.isfile(args.model_path):
- assert (
- model_name == args.model_path
- ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
+ assert model_name == args.model_path, (
+ f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
+ )
args.model_path = download(model_name)
sample_rate = MODELS_MAP[model_name]["sample_rate"]
@@ -289,9 +289,9 @@ def main(args):
assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
for key, value in renamed_state_dict.items():
- assert (
- diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
- ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
+ assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
+ f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
+ )
if key == "time_proj.weight":
value = value.squeeze()
diff --git a/scripts/convert_diffusers_to_original_sdxl.py b/scripts/convert_diffusers_to_original_sdxl.py
index 648d0376f72e..1aa792b3f06a 100644
--- a/scripts/convert_diffusers_to_original_sdxl.py
+++ b/scripts/convert_diffusers_to_original_sdxl.py
@@ -52,18 +52,18 @@
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i > 0:
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(4):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i < 2:
@@ -75,12 +75,12 @@
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
@@ -89,7 +89,7 @@
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2*j}."
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -137,20 +137,20 @@ def convert_unet_state_dict(unet_state_dict):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"up.{3-i}.upsample."
+ sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
- sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
- sd_mid_res_prefix = f"mid.block_{i+1}."
+ sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py
index d1b7df070c43..049dda7d42a7 100644
--- a/scripts/convert_diffusers_to_original_stable_diffusion.py
+++ b/scripts/convert_diffusers_to_original_stable_diffusion.py
@@ -47,36 +47,36 @@
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0:
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
@@ -85,7 +85,7 @@
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2*j}."
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -133,20 +133,20 @@ def convert_unet_state_dict(unet_state_dict):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"up.{3-i}.upsample."
+ sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
- sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
- sd_mid_res_prefix = f"mid.block_{i+1}."
+ sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
diff --git a/scripts/convert_hunyuandit_controlnet_to_diffusers.py b/scripts/convert_hunyuandit_controlnet_to_diffusers.py
index 1c8383690890..5cef46c98983 100644
--- a/scripts/convert_hunyuandit_controlnet_to_diffusers.py
+++ b/scripts/convert_hunyuandit_controlnet_to_diffusers.py
@@ -21,9 +21,9 @@ def main(args):
model_config = HunyuanDiT2DControlNetModel.load_config(
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
)
- model_config[
- "use_style_cond_and_image_meta_size"
- ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
+ model_config["use_style_cond_and_image_meta_size"] = (
+ args.use_style_cond_and_image_meta_size
+ ) ### version <= v1.1: True; version >= v1.2: False
print(model_config)
for key in state_dict:
diff --git a/scripts/convert_hunyuandit_to_diffusers.py b/scripts/convert_hunyuandit_to_diffusers.py
index da3af8333ee3..65fcccb22a1a 100644
--- a/scripts/convert_hunyuandit_to_diffusers.py
+++ b/scripts/convert_hunyuandit_to_diffusers.py
@@ -13,15 +13,14 @@ def main(args):
state_dict = state_dict[args.load_key]
except KeyError:
raise KeyError(
- f"{args.load_key} not found in the checkpoint."
- f"Please load from the following keys:{state_dict.keys()}"
+ f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}"
)
device = "cuda"
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
- model_config[
- "use_style_cond_and_image_meta_size"
- ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
+ model_config["use_style_cond_and_image_meta_size"] = (
+ args.use_style_cond_and_image_meta_size
+ ) ### version <= v1.1: True; version >= v1.2: False
# input_size -> sample_size, text_dim -> cross_attention_dim
for key in state_dict:
diff --git a/scripts/convert_k_upscaler_to_diffusers.py b/scripts/convert_k_upscaler_to_diffusers.py
index 62abedd73785..cff845ef8099 100644
--- a/scripts/convert_k_upscaler_to_diffusers.py
+++ b/scripts/convert_k_upscaler_to_diffusers.py
@@ -142,14 +142,14 @@ def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
self_attention_prefix = f"{block_prefix}.{idx}"
- cross_attention_prefix = f"{block_prefix}.{idx }"
+ cross_attention_prefix = f"{block_prefix}.{idx}"
cross_attention_index = 1 if not attention.add_self_attention else 2
idx = (
n * attention_idx + cross_attention_index
if block_type == "up"
else n * attention_idx + cross_attention_index + 1
)
- cross_attention_prefix = f"{block_prefix}.{idx }"
+ cross_attention_prefix = f"{block_prefix}.{idx}"
diffusers_checkpoint.update(
cross_attn_to_diffusers_checkpoint(
@@ -220,9 +220,9 @@ def unet_model_from_original_config(original_config):
block_out_channels = original_config["channels"]
- assert (
- len(set(original_config["depths"])) == 1
- ), "UNet2DConditionModel currently do not support blocks with different number of layers"
+ assert len(set(original_config["depths"])) == 1, (
+ "UNet2DConditionModel currently do not support blocks with different number of layers"
+ )
layers_per_block = original_config["depths"][0]
class_labels_dim = original_config["mapping_cond_dim"]
diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py
index 9727deeb6b0c..64e4f69eac17 100644
--- a/scripts/convert_mochi_to_diffusers.py
+++ b/scripts/convert_mochi_to_diffusers.py
@@ -168,28 +168,28 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
# Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.0.weight"
+ f"blocks.0.{i + 1}.stack.0.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.0.bias"
+ f"blocks.0.{i + 1}.stack.0.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.2.weight"
+ f"blocks.0.{i + 1}.stack.2.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.2.bias"
+ f"blocks.0.{i + 1}.stack.2.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.3.weight"
+ f"blocks.0.{i + 1}.stack.3.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.3.bias"
+ f"blocks.0.{i + 1}.stack.3.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.5.weight"
+ f"blocks.0.{i + 1}.stack.5.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.5.bias"
+ f"blocks.0.{i + 1}.stack.5.bias"
)
# Convert up_blocks (MochiUpBlock3D)
@@ -197,33 +197,35 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
for block in range(3):
for i in range(down_block_layers[block]):
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.0.weight"
+ f"blocks.{block + 1}.blocks.{i}.stack.0.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.0.bias"
+ f"blocks.{block + 1}.blocks.{i}.stack.0.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.2.weight"
+ f"blocks.{block + 1}.blocks.{i}.stack.2.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.2.bias"
+ f"blocks.{block + 1}.blocks.{i}.stack.2.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.3.weight"
+ f"blocks.{block + 1}.blocks.{i}.stack.3.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.3.bias"
+ f"blocks.{block + 1}.blocks.{i}.stack.3.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.5.weight"
+ f"blocks.{block + 1}.blocks.{i}.stack.5.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.5.bias"
+ f"blocks.{block + 1}.blocks.{i}.stack.5.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.proj.weight"
+ f"blocks.{block + 1}.proj.weight"
+ )
+ new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(
+ f"blocks.{block + 1}.proj.bias"
)
- new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias")
# Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3
@@ -267,133 +269,133 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
# Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.0.weight"
+ f"layers.{i + 1}.stack.0.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.0.bias"
+ f"layers.{i + 1}.stack.0.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.2.weight"
+ f"layers.{i + 1}.stack.2.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.2.bias"
+ f"layers.{i + 1}.stack.2.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.3.weight"
+ f"layers.{i + 1}.stack.3.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.3.bias"
+ f"layers.{i + 1}.stack.3.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.5.weight"
+ f"layers.{i + 1}.stack.5.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.5.bias"
+ f"layers.{i + 1}.stack.5.bias"
)
# Convert down_blocks (MochiDownBlock3D)
down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
for block in range(3):
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.0.weight"
+ f"layers.{block + 4}.layers.0.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.0.bias"
+ f"layers.{block + 4}.layers.0.bias"
)
for i in range(down_block_layers[block]):
# Convert resnets
- new_state_dict[
- f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"
- ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight")
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = (
+ encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
+ )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.0.bias"
+ f"layers.{block + 4}.layers.{i + 1}.stack.0.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.2.weight"
+ f"layers.{block + 4}.layers.{i + 1}.stack.2.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.2.bias"
+ f"layers.{block + 4}.layers.{i + 1}.stack.2.bias"
+ )
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = (
+ encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
)
- new_state_dict[
- f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
- ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight")
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.3.bias"
+ f"layers.{block + 4}.layers.{i + 1}.stack.3.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.5.weight"
+ f"layers.{block + 4}.layers.{i + 1}.stack.5.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.5.bias"
+ f"layers.{block + 4}.layers.{i + 1}.stack.5.bias"
)
# Convert attentions
- qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight")
+ qkv_weight = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight"
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias"
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight"
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias"
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.bias"
)
# Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3
# Convert resnets
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.0.weight"
+ f"layers.{i + 7}.stack.0.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.0.bias"
+ f"layers.{i + 7}.stack.0.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.2.weight"
+ f"layers.{i + 7}.stack.2.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.2.bias"
+ f"layers.{i + 7}.stack.2.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.3.weight"
+ f"layers.{i + 7}.stack.3.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.3.bias"
+ f"layers.{i + 7}.stack.3.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.5.weight"
+ f"layers.{i + 7}.stack.5.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.5.bias"
+ f"layers.{i + 7}.stack.5.bias"
)
# Convert attentions
- qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight")
+ qkv_weight = encoder_state_dict.pop(f"layers.{i + 7}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.attn_block.attn.out.weight"
+ f"layers.{i + 7}.attn_block.attn.out.weight"
)
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.attn_block.attn.out.bias"
+ f"layers.{i + 7}.attn_block.attn.out.bias"
)
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.attn_block.norm.weight"
+ f"layers.{i + 7}.attn_block.norm.weight"
)
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.attn_block.norm.bias"
+ f"layers.{i + 7}.attn_block.norm.bias"
)
# Convert output layers
diff --git a/scripts/convert_original_audioldm2_to_diffusers.py b/scripts/convert_original_audioldm2_to_diffusers.py
index ea9c02d53815..c1534fbba643 100644
--- a/scripts/convert_original_audioldm2_to_diffusers.py
+++ b/scripts/convert_original_audioldm2_to_diffusers.py
@@ -662,7 +662,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
- key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
diff --git a/scripts/convert_original_audioldm_to_diffusers.py b/scripts/convert_original_audioldm_to_diffusers.py
index 797d19826091..c67024da0b73 100644
--- a/scripts/convert_original_audioldm_to_diffusers.py
+++ b/scripts/convert_original_audioldm_to_diffusers.py
@@ -636,7 +636,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
- key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
diff --git a/scripts/convert_original_musicldm_to_diffusers.py b/scripts/convert_original_musicldm_to_diffusers.py
index 6db9dbdfdb74..3fbce3a7c84f 100644
--- a/scripts/convert_original_musicldm_to_diffusers.py
+++ b/scripts/convert_original_musicldm_to_diffusers.py
@@ -642,7 +642,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
- key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
diff --git a/scripts/convert_stable_audio.py b/scripts/convert_stable_audio.py
index a0f9d0f87d90..b33c8b0608e7 100644
--- a/scripts/convert_stable_audio.py
+++ b/scripts/convert_stable_audio.py
@@ -95,18 +95,18 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
# get idx of the layer
idx = int(new_key.split("coder.layers.")[1].split(".")[0])
- new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}")
+ new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx - 1}")
if "encoder" in new_key:
for i in range(3):
- new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}")
- new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1")
- new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i + 1}")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.3", f"block.{idx - 1}.snake1")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.4", f"block.{idx - 1}.conv1")
else:
for i in range(2, 5):
- new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}")
- new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1")
- new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i - 1}")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.0", f"block.{idx - 1}.snake1")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.1", f"block.{idx - 1}.conv_t1")
new_key = new_key.replace("layers.0.beta", "snake1.beta")
new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
@@ -118,9 +118,9 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
new_key = new_key.replace("layers.3.weight_", "conv2.weight_")
if idx == num_autoencoder_layers + 1:
- new_key = new_key.replace(f"block.{idx-1}", "snake1")
+ new_key = new_key.replace(f"block.{idx - 1}", "snake1")
elif idx == num_autoencoder_layers + 2:
- new_key = new_key.replace(f"block.{idx-1}", "conv2")
+ new_key = new_key.replace(f"block.{idx - 1}", "conv2")
else:
new_key = new_key
diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py
index 3243ce294b26..e46410ccb3bd 100644
--- a/scripts/convert_svd_to_diffusers.py
+++ b/scripts/convert_svd_to_diffusers.py
@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint(
# TODO resnet time_mixer.mix_factor
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
- new_checkpoint[
- f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
- ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
+ new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
+ unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
+ )
if len(attentions):
paths = renew_attention_paths(attentions)
@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint(
)
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
- new_checkpoint[
- f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
- ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
+ new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
+ unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
+ )
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values():
diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py
index 7da6b4094986..fe62d18faff0 100644
--- a/scripts/convert_vq_diffusion_to_diffusers.py
+++ b/scripts/convert_vq_diffusion_to_diffusers.py
@@ -51,9 +51,9 @@
def vqvae_model_from_original_config(original_config):
- assert (
- original_config["target"] in PORTED_VQVAES
- ), f"{original_config['target']} has not yet been ported to diffusers."
+ assert original_config["target"] in PORTED_VQVAES, (
+ f"{original_config['target']} has not yet been ported to diffusers."
+ )
original_config = original_config["params"]
@@ -464,15 +464,15 @@ def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_p
def transformer_model_from_original_config(
original_diffusion_config, original_transformer_config, original_content_embedding_config
):
- assert (
- original_diffusion_config["target"] in PORTED_DIFFUSIONS
- ), f"{original_diffusion_config['target']} has not yet been ported to diffusers."
- assert (
- original_transformer_config["target"] in PORTED_TRANSFORMERS
- ), f"{original_transformer_config['target']} has not yet been ported to diffusers."
- assert (
- original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS
- ), f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
+ assert original_diffusion_config["target"] in PORTED_DIFFUSIONS, (
+ f"{original_diffusion_config['target']} has not yet been ported to diffusers."
+ )
+ assert original_transformer_config["target"] in PORTED_TRANSFORMERS, (
+ f"{original_transformer_config['target']} has not yet been ported to diffusers."
+ )
+ assert original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS, (
+ f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
+ )
original_diffusion_config = original_diffusion_config["params"]
original_transformer_config = original_transformer_config["params"]
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 5e9ab2a117d1..1b19f9161ca1 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -275,6 +275,7 @@
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CogView3PlusPipeline",
+ "CogView4Pipeline",
"CycleDiffusionPipeline",
"FluxControlImg2ImgPipeline",
"FluxControlInpaintPipeline",
@@ -764,6 +765,7 @@
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
CogView3PlusPipeline,
+ CogView4Pipeline,
CycleDiffusionPipeline,
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index 7b691d1fe16e..0870f059e8f0 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -292,8 +292,7 @@ def set_ip_adapter_scale(self, scale):
):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
- f"Cannot assign {len(scale_configs)} scale_configs to "
- f"{len(attn_processor.scale)} IP-Adapter."
+ f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
@@ -592,8 +591,7 @@ def LinearStrengthModel(start, finish, size):
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
- f"Cannot assign {len(scale_configs)} scale_configs to "
- f"{len(attn_processor.scale)} IP-Adapter."
+ f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index e064aeba43b6..fecf5170a489 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -177,9 +177,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
# Store DoRA scale if present.
if dora_present_in_unet:
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
- unet_state_dict[
- diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
- ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
+ state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ )
# Handle text encoder LoRAs.
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
@@ -199,13 +199,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
)
if lora_name.startswith(("lora_te_", "lora_te1_")):
- te_state_dict[
- diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
- ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
+ state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ )
elif lora_name.startswith("lora_te2_"):
- te2_state_dict[
- diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
- ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
+ state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ )
# Store alpha if present.
if lora_name_alpha in state_dict:
@@ -684,21 +684,21 @@ def swap_scale_shift(weight):
for lora_key in ["lora_A", "lora_B"]:
## time_text_embed.timestep_embedder <- time_in
- converted_state_dict[
- f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
- ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = (
+ original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
+ )
if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[
- f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
- ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"] = (
+ original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
+ )
- converted_state_dict[
- f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
- ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = (
+ original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
+ )
if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[
- f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
- ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"] = (
+ original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
+ )
## time_text_embed.text_embedder <- vector_in
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
@@ -720,21 +720,21 @@ def swap_scale_shift(weight):
# guidance
has_guidance = any("guidance" in k for k in original_state_dict)
if has_guidance:
- converted_state_dict[
- f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
- ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = (
+ original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
+ )
if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[
- f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
- ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"] = (
+ original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
+ )
- converted_state_dict[
- f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
- ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = (
+ original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
+ )
if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[
- f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
- ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"] = (
+ original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
+ )
# context_embedder
converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index a3d006f18994..b2b0fe7b405f 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -173,7 +173,7 @@ def load_state_dict(
) from e
except (UnicodeDecodeError, ValueError):
raise OSError(
- f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
+ f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. "
)
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index e208a1c10ed4..ef96f91afb36 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -210,9 +210,9 @@ def _init_continuous_input(self, norm_type):
def _init_vectorized_inputs(self, norm_type):
assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
- assert (
- self.config.num_vector_embeds is not None
- ), "Transformer2DModel over discrete input must provide num_embed"
+ assert self.config.num_vector_embeds is not None, (
+ "Transformer2DModel over discrete input must provide num_embed"
+ )
self.height = self.config.sample_size
self.width = self.config.sample_size
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index ce291e5ceb45..a13714481dc1 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -154,6 +154,7 @@
"CogVideoXFunControlPipeline",
]
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
+ _import_structure["cogview4"] = ["CogView4Pipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
@@ -496,6 +497,8 @@
CogVideoXVideoToVideoPipeline,
)
from .cogview3 import CogView3PlusPipeline
+ from .cogview4 import CogView4Pipeline
+
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
index b8b5d07af529..e36e36304bd8 100644
--- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
@@ -788,7 +788,7 @@ def check_inputs(
if transcription is None:
if self.text_encoder_2.config.model_type == "vits":
- raise ValueError("Cannot forward without transcription. Please make sure to" " have transcription")
+ raise ValueError("Cannot forward without transcription. Please make sure to have transcription")
elif transcription is not None and (
not isinstance(transcription, str) and not isinstance(transcription, list)
):
diff --git a/src/diffusers/pipelines/cogview4/__init__.py b/src/diffusers/pipelines/cogview4/__init__.py
new file mode 100644
index 000000000000..5a535b3feb4b
--- /dev/null
+++ b/src/diffusers/pipelines/cogview4/__init__.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["CogView4PlusPipelineOutput"]}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_cogview4"] = ["CogView4Pipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_cogview4 import CogView4Pipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
+ for name, value in _additional_imports.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
new file mode 100644
index 000000000000..48d411349703
--- /dev/null
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -0,0 +1,675 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import GlmModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL, CogView3PlusTransformer2DModel
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from .pipeline_output import CogView4PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import CogView4Pipeline
+
+ >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ >>> image.save("output.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class CogView4Pipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using CogView4.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. CogView4 uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`CogView4Transformer2DModel`]):
+ A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ tokenizer: GlmModel,
+ text_encoder: GlmModel,
+ vae: AutoencoderKL,
+ transformer: CogView3PlusTransformer2DModel,
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def _get_glm_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 1024,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 224,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ Number of images that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ max_sequence_length (`int`, defaults to `224`):
+ Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt is None:
+ negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape)
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 5.0,
+ num_images_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ output_type: str = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ CogView4Pipeline: int = 224,
+ ) -> Union[CogView4PipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. If not provided, it is set to 1024.
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. If not provided it is set to 1024.
+ num_inference_steps (`int`, *optional*, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to `1`):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `224`):
+ Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.cogview3.pipeline_CogView4.CogView3PipelineOutput`] or `tuple`:
+ [`~pipelines.cogview3.pipeline_CogView4.CogView3PipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ # 2. Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ self.do_classifier_free_guidance,
+ num_images_per_prompt=num_images_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare additional timestep conditions
+ original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype)
+ target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype)
+ crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype)
+
+ if self.do_classifier_free_guidance:
+ original_size = torch.cat([original_size, original_size])
+ target_size = torch.cat([target_size, target_size])
+ crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left])
+
+ original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
+ target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
+ crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # for DPM-solver++
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ original_size=original_size,
+ target_size=target_size,
+ crop_coords=crops_coords_top_left,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
+ else:
+ image = latents
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return CogView4PipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_output.py b/src/diffusers/pipelines/cogview4/pipeline_output.py
new file mode 100644
index 000000000000..4efec1310845
--- /dev/null
+++ b/src/diffusers/pipelines/cogview4/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class CogView4PipelineOutput(BaseOutput):
+ """
+ Output class for CogView3 pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 875dbed38c4d..e7a84d4b6dfb 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -650,7 +650,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -658,7 +658,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index 38e63f56b2f3..948728d56afc 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -743,7 +743,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -751,7 +751,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
@@ -1644,7 +1644,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
index 56f6c9149c6e..98769f247737 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
@@ -726,7 +726,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -734,7 +734,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index bfc96eeb8dab..23b16d8c2452 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -507,7 +507,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -515,7 +515,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
index 2be8e75973ef..ed5b08a03cb7 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -485,7 +485,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -493,7 +493,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py
index dc0071a494e3..8ea5eb7dd575 100644
--- a/src/diffusers/pipelines/free_noise_utils.py
+++ b/src/diffusers/pipelines/free_noise_utils.py
@@ -341,9 +341,9 @@ def _encode_prompt_free_noise(
start_tensor = negative_prompt_embeds[i].unsqueeze(0)
end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
- negative_prompt_interpolation_embeds[
- start_frame : end_frame + 1
- ] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
+ negative_prompt_interpolation_embeds[start_frame : end_frame + 1] = (
+ self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
+ )
prompt_embeds = prompt_interpolation_embeds
negative_prompt_embeds = negative_prompt_interpolation_embeds
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
index e653b8266f19..5f8db26eef54 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
@@ -360,7 +360,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
"""
_load_connected_pipes = True
- model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq"
+ model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
_exclude_from_cpu_offload = ["prior_prior"]
def __init__(
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
index cce5f0b3d5bc..769c834ec3cc 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
@@ -579,7 +579,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
index bc7a4b57affd..6d89f16765a3 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
@@ -604,7 +604,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -612,7 +612,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
@@ -1340,7 +1340,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
index 33abfb0be89f..db652989cfc1 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
@@ -683,7 +683,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -691,7 +691,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1191,7 +1191,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
index fdf3df2f4d6a..8b06bdc9c969 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
@@ -737,7 +737,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -745,7 +745,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1509,7 +1509,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index 55a9f47145a2..288f269a6563 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -575,7 +575,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 23f1279e203d..62c2a57161cf 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -285,9 +285,7 @@ def maybe_raise_or_warn(
model_cls = unwrapped_sub_model.__class__
if not issubclass(model_cls, expected_class_obj):
- raise ValueError(
- f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
- )
+ raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}")
else:
logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 527724d1de1a..8dfac6e74276 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -1395,8 +1395,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
if load_components_from_hub and not trust_remote_code:
raise ValueError(
- f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
- f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
+ f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly "
+ f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
diff --git a/src/diffusers/pipelines/shap_e/renderer.py b/src/diffusers/pipelines/shap_e/renderer.py
index 9d9f9d9b2ab1..dd25945590cd 100644
--- a/src/diffusers/pipelines/shap_e/renderer.py
+++ b/src/diffusers/pipelines/shap_e/renderer.py
@@ -983,9 +983,9 @@ def decode_to_mesh(
fields = torch.cat(fields, dim=1)
fields = fields.float()
- assert (
- len(fields.shape) == 3 and fields.shape[-1] == 1
- ), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
+ assert len(fields.shape) == 3 and fields.shape[-1] == 1, (
+ f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
+ )
fields = fields.reshape(1, *([grid_size] * 3))
@@ -1039,9 +1039,9 @@ def decode_to_mesh(
textures = textures.float()
# 3.3 augument the mesh with texture data
- assert len(textures.shape) == 3 and textures.shape[-1] == len(
- texture_channels
- ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
+ assert len(textures.shape) == 3 and textures.shape[-1] == len(texture_channels), (
+ f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
+ )
for m, texture in zip(raw_meshes, textures):
texture = texture[: len(m.verts)]
diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
index 5d773b614a5c..1b87c02df029 100644
--- a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
+++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
@@ -584,7 +584,7 @@ def __call__(
if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
raise ValueError(
- f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
+ f"The total audio length requested ({audio_end_in_s - audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
)
waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
index abcba926160a..dd659306e002 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
@@ -335,7 +335,7 @@ def _generate(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
index ddd2e27dedaf..f2e1d87be87e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -475,7 +475,7 @@ def __call__(
"Incorrect configuration settings! The config of `pipeline.unet` expects"
f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 6f4e7f358952..0f7be1a1bbcd 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -660,7 +660,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -668,7 +668,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1226,7 +1226,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index 7857bc58a8ad..e0748943ffff 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -401,7 +401,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index c6967bc393b5..42db88b03049 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -600,7 +600,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index dae4540ebe00..f9b6dcbf5ad2 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -740,7 +740,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index 67791c17a74b..9ede52153a2e 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -1152,7 +1152,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.transformer` or your `mask_image` or `image` input."
)
elif num_channels_transformer != 16:
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 920caf4d24a1..835c0af800da 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -741,7 +741,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -749,7 +749,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1509,7 +1509,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py
index 6ec3885fe373..db57db70d0d4 100644
--- a/src/diffusers/quantizers/base.py
+++ b/src/diffusers/quantizers/base.py
@@ -215,19 +215,15 @@ def _dequantize(self, model):
)
@abstractmethod
- def _process_model_before_weight_loading(self, model, **kwargs):
- ...
+ def _process_model_before_weight_loading(self, model, **kwargs): ...
@abstractmethod
- def _process_model_after_weight_loading(self, model, **kwargs):
- ...
+ def _process_model_after_weight_loading(self, model, **kwargs): ...
@property
@abstractmethod
- def is_serializable(self):
- ...
+ def is_serializable(self): ...
@property
@abstractmethod
- def is_trainable(self):
- ...
+ def is_trainable(self): ...
diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py
index 653171638ccf..c946fa1681c0 100644
--- a/src/diffusers/schedulers/scheduling_consistency_models.py
+++ b/src/diffusers/schedulers/scheduling_consistency_models.py
@@ -203,8 +203,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index eb40d79b9f60..3a4eaf4e5a72 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -279,8 +279,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
index 20ad7a4c927d..64195be141f6 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
@@ -289,8 +289,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py
index 686b686f6870..2a0cce7bf146 100644
--- a/src/diffusers/schedulers/scheduling_lcm.py
+++ b/src/diffusers/schedulers/scheduling_lcm.py
@@ -413,8 +413,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
# Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py
index 5d60383142a4..77770ab2066c 100644
--- a/src/diffusers/schedulers/scheduling_tcd.py
+++ b/src/diffusers/schedulers/scheduling_tcd.py
@@ -431,8 +431,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
# Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 2474ed5c2114..868bb6b15e0a 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -241,7 +241,7 @@ def _set_state_dict_into_text_encoder(
"""
text_encoder_state_dict = {
- f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix)
+ f"{k.replace(prefix, '')}": v for k, v in lora_state_dict.items() if k.startswith(prefix)
}
text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
@@ -576,7 +576,7 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
"""
if self.temp_stored_params is None:
- raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
+ raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
if self.foreach:
torch._foreach_copy_(
[param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py
index f482deddd2f4..4f001b3047d6 100644
--- a/src/diffusers/utils/deprecation_utils.py
+++ b/src/diffusers/utils/deprecation_utils.py
@@ -40,7 +40,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn
line_number = call_frame.lineno
function = call_frame.function
key, value = next(iter(deprecated_kwargs.items()))
- raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
+ raise TypeError(f"{function} in {filename} line {line_number - 1} got an unexpected keyword argument `{key}`")
if len(values) == 0:
return
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 9b36be9e0604..bc466046c998 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -362,6 +362,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class CogView4Pipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class CycleDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index 3014efebc82e..d0874cca0c86 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -228,7 +228,7 @@
_wandb_available = importlib.util.find_spec("wandb") is not None
try:
_wandb_version = importlib_metadata.version("wandb")
- logger.debug(f"Successfully imported wandb version {_wandb_version }")
+ logger.debug(f"Successfully imported wandb version {_wandb_version}")
except importlib_metadata.PackageNotFoundError:
_wandb_available = False
diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py
index 6f93450c410c..b96e0e222cb1 100644
--- a/src/diffusers/utils/logging.py
+++ b/src/diffusers/utils/logging.py
@@ -60,8 +60,7 @@ def _get_default_logging_level() -> int:
return log_levels[env_level_str]
else:
logging.getLogger().warning(
- f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
- f"has to be one of: { ', '.join(log_levels.keys()) }"
+ f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, has to be one of: {', '.join(log_levels.keys())}"
)
return _default_log_level
diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py
index 62b114ba67e3..8efd6e6df51e 100644
--- a/src/diffusers/utils/state_dict_utils.py
+++ b/src/diffusers/utils/state_dict_utils.py
@@ -329,7 +329,7 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names
kohya_ss_state_dict[kohya_key] = weight
if "lora_down" in kohya_key:
- alpha_key = f'{kohya_key.split(".")[0]}.alpha'
+ alpha_key = f"{kohya_key.split('.')[0]}.alpha"
kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
return kohya_ss_state_dict
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 3ae74cddcbbf..9e3527650cc8 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -796,7 +796,7 @@ def pytest_terminal_summary_main(tr, id):
f.write("slowest durations\n")
for i, rep in enumerate(dlist):
if rep.duration < durations_min:
- f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
+ f.write(f"{len(dlist) - i} durations < {durations_min} secs were omitted")
break
f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
@@ -941,7 +941,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
process.join(timeout=timeout)
if results["error"] is not None:
- test_case.fail(f'{results["error"]}')
+ test_case.fail(f"{results['error']}")
class CaptureLogger:
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 4fc14804475a..fabd5952a710 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -277,9 +277,9 @@ def test_one_request_upon_cached(self):
)
download_requests = [r.method for r in m.request_history]
- assert (
- download_requests.count("HEAD") == 3
- ), "3 HEAD requests one for config, one for model, and one for shard index file."
+ assert download_requests.count("HEAD") == 3, (
+ "3 HEAD requests one for config, one for model, and one for shard index file."
+ )
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
with requests_mock.mock(real_http=True) as m:
@@ -291,9 +291,9 @@ def test_one_request_upon_cached(self):
)
cache_requests = [r.method for r in m.request_history]
- assert (
- "HEAD" == cache_requests[0] and len(cache_requests) == 2
- ), "We should call only `model_info` to check for commit hash and knowing if shard index is present."
+ assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
+ "We should call only `model_info` to check for commit hash and knowing if shard index is present."
+ )
def test_weight_overwrite(self):
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py
index 2531381dc7c8..aef08c1f3b68 100644
--- a/tests/models/transformers/test_models_transformer_sd3.py
+++ b/tests/models/transformers/test_models_transformer_sd3.py
@@ -91,9 +91,9 @@ def test_xformers_enable_works(self):
model.enable_xformers_memory_efficient_attention()
- assert (
- model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
- ), "xformers is not enabled"
+ assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
+ "xformers is not enabled"
+ )
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
@@ -165,9 +165,9 @@ def test_xformers_enable_works(self):
model.enable_xformers_memory_efficient_attention()
- assert (
- model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
- ), "xformers is not enabled"
+ assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
+ "xformers is not enabled"
+ )
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index 57f6e4ee440b..804b01a26971 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -651,22 +651,22 @@ def test_model_xattn_mask(self, mask_dtype):
keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype)
full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample
- assert full_cond_keepallmask_out.allclose(
- full_cond_out, rtol=1e-05, atol=1e-05
- ), "a 'keep all' mask should give the same result as no mask"
+ assert full_cond_keepallmask_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
+ "a 'keep all' mask should give the same result as no mask"
+ )
trunc_cond = cond[:, :-1, :]
trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample
- assert not trunc_cond_out.allclose(
- full_cond_out, rtol=1e-05, atol=1e-05
- ), "discarding the last token from our cond should change the result"
+ assert not trunc_cond_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
+ "discarding the last token from our cond should change the result"
+ )
batch, tokens, _ = cond.shape
mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype)
masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample
- assert masked_cond_out.allclose(
- trunc_cond_out, rtol=1e-05, atol=1e-05
- ), "masking the last token from our cond should be equivalent to truncating that token out of the condition"
+ assert masked_cond_out.allclose(trunc_cond_out, rtol=1e-05, atol=1e-05), (
+ "masking the last token from our cond should be equivalent to truncating that token out of the condition"
+ )
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
@@ -694,9 +694,9 @@ def test_model_xattn_padding(self):
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
- assert trunc_mask_out.allclose(
- keeplast_out
- ), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
+ assert trunc_mask_out.allclose(keeplast_out), (
+ "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
+ )
def test_custom_diffusion_processors(self):
# enable deterministic behavior for gradient checkpointing
@@ -1111,12 +1111,12 @@ def test_load_attn_procs_raise_warning(self):
with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample
- assert not torch.allclose(
- non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4
- ), "LoRA injected UNet should produce different results."
- assert torch.allclose(
- lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
- ), "Loading from a saved checkpoint should produce identical results."
+ assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
+ "LoRA injected UNet should produce different results."
+ )
+ assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
+ "Loading from a saved checkpoint should produce identical results."
+ )
@require_peft_backend
def test_save_attn_procs_raise_warning(self):
diff --git a/tests/others/test_image_processor.py b/tests/others/test_image_processor.py
index 3397ca9e394a..071194c59ead 100644
--- a/tests/others/test_image_processor.py
+++ b/tests/others/test_image_processor.py
@@ -65,9 +65,9 @@ def test_vae_image_processor_pt(self):
)
out_np = self.to_np(out)
in_np = (input_np * 255).round() if output_type == "pil" else input_np
- assert (
- np.abs(in_np - out_np).max() < 1e-6
- ), f"decoded output does not match input for output_type {output_type}"
+ assert np.abs(in_np - out_np).max() < 1e-6, (
+ f"decoded output does not match input for output_type {output_type}"
+ )
def test_vae_image_processor_np(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
@@ -78,9 +78,9 @@ def test_vae_image_processor_np(self):
out_np = self.to_np(out)
in_np = (input_np * 255).round() if output_type == "pil" else input_np
- assert (
- np.abs(in_np - out_np).max() < 1e-6
- ), f"decoded output does not match input for output_type {output_type}"
+ assert np.abs(in_np - out_np).max() < 1e-6, (
+ f"decoded output does not match input for output_type {output_type}"
+ )
def test_vae_image_processor_pil(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
@@ -93,9 +93,9 @@ def test_vae_image_processor_pil(self):
for i, o in zip(input_pil, out):
in_np = np.array(i)
out_np = self.to_np(out) if output_type == "pil" else (self.to_np(out) * 255).round()
- assert (
- np.abs(in_np - out_np).max() < 1e-6
- ), f"decoded output does not match input for output_type {output_type}"
+ assert np.abs(in_np - out_np).max() < 1e-6, (
+ f"decoded output does not match input for output_type {output_type}"
+ )
def test_preprocess_input_3d(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
@@ -293,9 +293,9 @@ def test_vae_image_processor_resize_pt(self):
scale = 2
out_pt = image_processor.resize(image=input_pt, height=h // scale, width=w // scale)
exp_pt_shape = (b, c, h // scale, w // scale)
- assert (
- out_pt.shape == exp_pt_shape
- ), f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."
+ assert out_pt.shape == exp_pt_shape, (
+ f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."
+ )
def test_vae_image_processor_resize_np(self):
image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1)
@@ -305,6 +305,6 @@ def test_vae_image_processor_resize_np(self):
input_np = self.to_np(input_pt)
out_np = image_processor.resize(image=input_np, height=h // scale, width=w // scale)
exp_np_shape = (b, h // scale, w // scale, c)
- assert (
- out_np.shape == exp_np_shape
- ), f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
+ assert out_np.shape == exp_np_shape, (
+ f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
+ )
diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py
index f28d8708d309..f348008ae4de 100644
--- a/tests/pipelines/amused/test_amused.py
+++ b/tests/pipelines/amused/test_amused.py
@@ -124,8 +124,7 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
@unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self):
- ...
+ def test_inference_batch_single_identical(self): ...
@slow
diff --git a/tests/pipelines/amused/test_amused_img2img.py b/tests/pipelines/amused/test_amused_img2img.py
index 2699bbe7f56f..942735f15707 100644
--- a/tests/pipelines/amused/test_amused_img2img.py
+++ b/tests/pipelines/amused/test_amused_img2img.py
@@ -126,8 +126,7 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
@unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self):
- ...
+ def test_inference_batch_single_identical(self): ...
@slow
diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py
index 645379a7eab1..541b988f1798 100644
--- a/tests/pipelines/amused/test_amused_inpaint.py
+++ b/tests/pipelines/amused/test_amused_inpaint.py
@@ -130,8 +130,7 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
@unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self):
- ...
+ def test_inference_batch_single_identical(self): ...
@slow
diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
index 14bc588df905..05ed1d5b1864 100644
--- a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
+++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
@@ -138,9 +138,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -154,15 +154,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@unittest.skip("xformers attention processor does not exist for AuraFlow")
def test_xformers_attention_forwardGenerator_pass(self):
diff --git a/tests/pipelines/blipdiffusion/test_blipdiffusion.py b/tests/pipelines/blipdiffusion/test_blipdiffusion.py
index 7e85cef65129..9d4e8df170cf 100644
--- a/tests/pipelines/blipdiffusion/test_blipdiffusion.py
+++ b/tests/pipelines/blipdiffusion/test_blipdiffusion.py
@@ -193,6 +193,6 @@ def test_blipdiffusion(self):
[0.5329548, 0.8372512, 0.33269387, 0.82096875, 0.43657133, 0.3783, 0.5953028, 0.51934963, 0.42142007]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py
index 884ddfb2a95a..b1e27f67c796 100644
--- a/tests/pipelines/cogvideo/test_cogvideox.py
+++ b/tests/pipelines/cogvideo/test_cogvideox.py
@@ -293,9 +293,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -309,15 +309,15 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@slow
diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
index 2a51fc65798c..d767de23f840 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
@@ -297,9 +297,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -313,12 +313,12 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
diff --git a/tests/pipelines/cogvideo/test_cogvideox_image2video.py b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
index f7e1fe7fd6c7..32983814738b 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_image2video.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
@@ -316,9 +316,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -332,15 +332,15 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@slow
diff --git a/tests/pipelines/cogvideo/test_cogvideox_video2video.py b/tests/pipelines/cogvideo/test_cogvideox_video2video.py
index 4d836cb5e2a4..b1ac8cbd90ed 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_video2video.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_video2video.py
@@ -298,9 +298,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -314,12 +314,12 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
diff --git a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
index 99a238caf53a..0563d9eb2277 100644
--- a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
+++ b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
@@ -217,6 +217,6 @@ def test_blipdiffusion_controlnet(self):
assert image.shape == (1, 16, 16, 4)
expected_slice = np.array([0.7953, 0.7136, 0.6597, 0.4779, 0.7389, 0.4111, 0.5826, 0.4150, 0.8422])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
index 5e856b125f32..b0a99b0dbbca 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -173,9 +173,9 @@ def test_controlnet_flux(self):
[0.47387695, 0.63134766, 0.5605469, 0.61621094, 0.7207031, 0.7089844, 0.70410156, 0.6113281, 0.64160156]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ )
@unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
def test_xformers_attention_forwardGenerator_pass(self):
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
index 02270d7fbd00..6c0d947c5266 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
@@ -194,9 +194,9 @@ def test_fused_qkv_projections(self):
original_image_slice = image[0, -3:, -3:, -1]
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -210,15 +210,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
index 5500c7bd1c81..6fbaf0e33f54 100644
--- a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
+++ b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
@@ -155,9 +155,9 @@ def test_controlnet_hunyuandit(self):
[0.6953125, 0.89208984, 0.59375, 0.5078125, 0.5786133, 0.6035156, 0.5839844, 0.53564453, 0.52246094]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ )
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
index 2cd57ce56d52..d9f5dcad7d61 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
@@ -194,9 +194,9 @@ def test_controlnet_inpaint_sd3(self):
[0.51708984, 0.7421875, 0.4580078, 0.6435547, 0.65625, 0.43603516, 0.5151367, 0.65722656, 0.60839844]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ )
@unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
def test_xformers_attention_forwardGenerator_pass(self):
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
index 7527d17af32a..4fc8f07f3ee7 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -199,9 +199,9 @@ def run_pipe(self, components, use_sd35=False):
else:
expected_slice = np.array([1.0000, 0.9072, 0.4209, 0.2744, 0.5737, 0.3840, 0.6113, 0.6250, 0.6328])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ )
def test_controlnet_sd3(self):
components = self.get_dummy_components()
diff --git a/tests/pipelines/dit/test_dit.py b/tests/pipelines/dit/test_dit.py
index 30883ac4a63d..18732c0058de 100644
--- a/tests/pipelines/dit/test_dit.py
+++ b/tests/pipelines/dit/test_dit.py
@@ -149,8 +149,7 @@ def test_dit_512(self):
for word, image in zip(words, images):
expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- f"/dit/{word}_512.npy"
+ f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}_512.npy"
)
assert np.abs((expected_image - image).max()) < 1e-1
diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
index addc29e14670..045b2fdd3306 100644
--- a/tests/pipelines/flux/test_pipeline_flux.py
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -169,9 +169,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -185,15 +185,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py
index 2bd511db3d65..e9eddaa80bfb 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control.py
@@ -162,9 +162,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -178,15 +178,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
index c5ff02a525f2..37ebf4493595 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
@@ -174,9 +174,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -190,15 +190,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
index 653cb41e4bc4..dc7fdb932fac 100644
--- a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
+++ b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
@@ -269,9 +269,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -287,15 +287,15 @@ def test_fused_qkv_projections(self):
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@slow
diff --git a/tests/pipelines/kandinsky/test_kandinsky.py b/tests/pipelines/kandinsky/test_kandinsky.py
index 8553ed96e9e1..cf15fbba854a 100644
--- a/tests/pipelines/kandinsky/test_kandinsky.py
+++ b/tests/pipelines/kandinsky/test_kandinsky.py
@@ -237,12 +237,12 @@ def test_kandinsky(self):
expected_slice = np.array([1.0000, 1.0000, 0.2766, 1.0000, 0.5447, 0.1737, 1.0000, 0.4316, 0.9024])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_gpu
def test_offloads(self):
diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py
index a7f861565cc9..69b204665139 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py
@@ -96,12 +96,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.2893, 0.1464, 0.4603, 0.3529, 0.4612, 0.7701, 0.4027, 0.3051, 0.5155])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_gpu
def test_offloads(self):
@@ -202,12 +202,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.4852, 0.4136, 0.4539, 0.4781, 0.4680, 0.5217, 0.4973, 0.4089, 0.4977])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_gpu
def test_offloads(self):
@@ -312,12 +312,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.0320, 0.0860, 0.4013, 0.0518, 0.2484, 0.5847, 0.4411, 0.2321, 0.4593])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_gpu
def test_offloads(self):
diff --git a/tests/pipelines/kandinsky/test_kandinsky_img2img.py b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
index ea289c5ccd71..81b52a05391e 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
@@ -258,12 +258,12 @@ def test_kandinsky_img2img(self):
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5816, 0.5872, 0.4634, 0.5982, 0.4767, 0.4710, 0.4669, 0.4717, 0.4966])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_gpu
def test_offloads(self):
@@ -318,7 +318,7 @@ def test_kandinsky_img2img(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
prompt = "A red cartoon frog, 4k"
@@ -384,7 +384,7 @@ def test_kandinsky_img2img_ddpm(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/frog.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/frog.png"
)
prompt = "A red cartoon frog, 4k"
diff --git a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
index 740046678744..22c967bd4404 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
@@ -253,12 +253,12 @@ def test_kandinsky_inpaint(self):
expected_slice = np.array([0.8222, 0.8896, 0.4373, 0.8088, 0.4905, 0.2609, 0.6816, 0.4291, 0.5129])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@@ -316,7 +316,7 @@ def test_kandinsky_inpaint(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
mask = np.zeros((768, 768), dtype=np.float32)
mask[:250, 250:-250] = 1
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky.py b/tests/pipelines/kandinsky2_2/test_kandinsky.py
index cbd9166efada..728a1d67a464 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky.py
@@ -208,13 +208,13 @@ def test_kandinsky(self):
expected_slice = np.array([0.3420, 0.9505, 0.3919, 1.0000, 0.5188, 0.3109, 0.6139, 0.5624, 0.6811])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
index dbba0831397b..b697c46ef361 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
@@ -101,12 +101,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.3076, 0.2729, 0.5668, 0.0522, 0.3384, 0.7028, 0.4908, 0.3659, 0.6243])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_gpu
def test_offloads(self):
@@ -223,12 +223,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.4445, 0.4287, 0.4596, 0.3919, 0.3730, 0.5039, 0.4834, 0.4269, 0.5521])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_gpu
def test_offloads(self):
@@ -344,12 +344,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.5039, 0.4926, 0.4898, 0.4978, 0.4838, 0.4942, 0.4738, 0.4702, 0.4816])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_gpu
def test_offloads(self):
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
index 1f3219e0d69e..10a95d6177b2 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
@@ -210,13 +210,13 @@ def test_kandinsky_controlnet(self):
[0.6959826, 0.868279, 0.7558092, 0.68769467, 0.85805804, 0.65977496, 0.44885302, 0.5959111, 0.4251595]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
index 20944aa3d6f8..58fbbecc0569 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
@@ -218,12 +218,12 @@ def test_kandinsky_controlnet_img2img(self):
expected_slice = np.array(
[0.54985034, 0.55509365, 0.52561504, 0.5570494, 0.5593818, 0.5263979, 0.50285643, 0.5069846, 0.51196736]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1.75e-3)
@@ -254,7 +254,7 @@ def test_kandinsky_controlnet_img2img(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
init_image = init_image.resize((512, 512))
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
index 26d8b45cf900..34f089fcf1e7 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
@@ -226,12 +226,12 @@ def test_kandinsky_img2img(self):
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5712, 0.5443, 0.4725, 0.6195, 0.5184, 0.4651, 0.4473, 0.4590, 0.5016])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=2e-1)
@@ -259,7 +259,7 @@ def test_kandinsky_img2img(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
prompt = "A red cartoon frog, 4k"
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
index 25cf4bbed456..be2d90ea9c53 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
@@ -233,12 +233,12 @@ def test_kandinsky_inpaint(self):
[0.50775903, 0.49527195, 0.48824543, 0.50192237, 0.48644906, 0.49373814, 0.4780598, 0.47234827, 0.48327848]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@@ -313,7 +313,7 @@ def test_kandinsky_inpaint(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
mask = np.zeros((768, 768), dtype=np.float32)
mask[:250, 250:-250] = 1
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3.py b/tests/pipelines/kandinsky3/test_kandinsky3.py
index 941ef9093361..e80d5c61fd72 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3.py
@@ -155,9 +155,9 @@ def test_kandinsky3(self):
expected_slice = np.array([0.3768, 0.4373, 0.4865, 0.4890, 0.4299, 0.5122, 0.4921, 0.4924, 0.5599])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
index 8c817df32e0c..79468077ecff 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
@@ -180,9 +180,9 @@ def test_kandinsky3_img2img(self):
[0.576259, 0.6132097, 0.41703486, 0.603196, 0.62062526, 0.4655338, 0.5434324, 0.5660727, 0.65433365]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py
index 59ce9cc0a987..902958ce4121 100644
--- a/tests/pipelines/pag/test_pag_animatediff.py
+++ b/tests/pipelines/pag/test_pag_animatediff.py
@@ -450,9 +450,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).frames[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_controlnet_sd.py b/tests/pipelines/pag/test_pag_controlnet_sd.py
index 8a7eb6f0c675..e59b6e676676 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sd.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sd.py
@@ -171,9 +171,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
index 0a7413e99926..969737f22ee4 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
@@ -168,9 +168,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl.py b/tests/pipelines/pag/test_pag_controlnet_sdxl.py
index 6400cc2b7cab..5323bad37217 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sdxl.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sdxl.py
@@ -189,9 +189,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py b/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
index b02f4d8b4561..992de5cdbae8 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
@@ -191,9 +191,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_hunyuan_dit.py b/tests/pipelines/pag/test_pag_hunyuan_dit.py
index db0e257760ed..26852744f9e0 100644
--- a/tests/pipelines/pag/test_pag_hunyuan_dit.py
+++ b/tests/pipelines/pag/test_pag_hunyuan_dit.py
@@ -271,15 +271,15 @@ def test_fused_qkv_projections(self):
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -292,9 +292,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_kolors.py b/tests/pipelines/pag/test_pag_kolors.py
index 8cfb2c3fd16a..825da0f7b8ac 100644
--- a/tests/pipelines/pag/test_pag_kolors.py
+++ b/tests/pipelines/pag/test_pag_kolors.py
@@ -136,9 +136,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_pixart_sigma.py b/tests/pipelines/pag/test_pag_pixart_sigma.py
index 7de19e0f00fc..072dd80a4da0 100644
--- a/tests/pipelines/pag/test_pag_pixart_sigma.py
+++ b/tests/pipelines/pag/test_pag_pixart_sigma.py
@@ -120,9 +120,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe.__class__.__name__}."
+ )
out = pipe(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_sana.py b/tests/pipelines/pag/test_pag_sana.py
index 12addabeb0a8..608e023b0d33 100644
--- a/tests/pipelines/pag/test_pag_sana.py
+++ b/tests/pipelines/pag/test_pag_sana.py
@@ -266,9 +266,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_sd.py b/tests/pipelines/pag/test_pag_sd.py
index 17e3f7038439..711945308d37 100644
--- a/tests/pipelines/pag/test_pag_sd.py
+++ b/tests/pipelines/pag/test_pag_sd.py
@@ -155,9 +155,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -322,9 +322,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -339,6 +339,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sd3.py b/tests/pipelines/pag/test_pag_sd3.py
index 627d613ee20d..5183756913c2 100644
--- a/tests/pipelines/pag/test_pag_sd3.py
+++ b/tests/pipelines/pag/test_pag_sd3.py
@@ -203,9 +203,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -219,15 +219,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -240,9 +240,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_sd3_img2img.py b/tests/pipelines/pag/test_pag_sd3_img2img.py
index bffcd254e2c5..694a86577dbf 100644
--- a/tests/pipelines/pag/test_pag_sd3_img2img.py
+++ b/tests/pipelines/pag/test_pag_sd3_img2img.py
@@ -148,9 +148,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
@@ -253,9 +253,9 @@ def test_pag_cfg(self):
0.17822266,
]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(
@@ -271,6 +271,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.1508789, 0.16210938, 0.17138672, 0.16210938, 0.17089844, 0.16137695, 0.16235352, 0.16430664, 0.16455078]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sd_img2img.py b/tests/pipelines/pag/test_pag_sd_img2img.py
index f44204f82486..d540a2257140 100644
--- a/tests/pipelines/pag/test_pag_sd_img2img.py
+++ b/tests/pipelines/pag/test_pag_sd_img2img.py
@@ -160,9 +160,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -259,9 +259,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -277,6 +277,6 @@ def test_pag_uncond(self):
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sd_inpaint.py b/tests/pipelines/pag/test_pag_sd_inpaint.py
index a528b66cc72a..00d7e9f9c29d 100644
--- a/tests/pipelines/pag/test_pag_sd_inpaint.py
+++ b/tests/pipelines/pag/test_pag_sd_inpaint.py
@@ -296,9 +296,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.38793945, 0.4111328, 0.47924805, 0.39208984, 0.4165039, 0.41674805, 0.37060547, 0.36791992, 0.40625]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -313,6 +313,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.3876953, 0.40356445, 0.4934082, 0.39697266, 0.41674805, 0.41015625, 0.375, 0.36914062, 0.40649414]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sdxl.py b/tests/pipelines/pag/test_pag_sdxl.py
index 589573385677..c2e10a6325b2 100644
--- a/tests/pipelines/pag/test_pag_sdxl.py
+++ b/tests/pipelines/pag/test_pag_sdxl.py
@@ -168,9 +168,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -331,9 +331,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.3123679, 0.31725878, 0.32026544, 0.327533, 0.3266391, 0.3303998, 0.33544615, 0.34181812, 0.34102726]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -348,6 +348,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.47400922, 0.48650584, 0.4839625, 0.4724013, 0.4890427, 0.49544555, 0.51707107, 0.54299414, 0.5224372]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sdxl_img2img.py b/tests/pipelines/pag/test_pag_sdxl_img2img.py
index 7e5fc5fa28b9..83f6fac40ff0 100644
--- a/tests/pipelines/pag/test_pag_sdxl_img2img.py
+++ b/tests/pipelines/pag/test_pag_sdxl_img2img.py
@@ -214,9 +214,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -314,9 +314,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.20301354, 0.21078318, 0.2021082, 0.20277798, 0.20681083, 0.19562206, 0.20121682, 0.21562952, 0.21277016]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -331,6 +331,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.21303111, 0.22188407, 0.2124992, 0.21365267, 0.18823743, 0.17569828, 0.21113116, 0.19419771, 0.18919235]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sdxl_inpaint.py b/tests/pipelines/pag/test_pag_sdxl_inpaint.py
index efc37abd0682..3fead15e6e9b 100644
--- a/tests/pipelines/pag/test_pag_sdxl_inpaint.py
+++ b/tests/pipelines/pag/test_pag_sdxl_inpaint.py
@@ -219,9 +219,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -320,9 +320,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.41385046, 0.39608297, 0.4360491, 0.26872507, 0.32187328, 0.4242474, 0.2603805, 0.34167895, 0.46561807]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -337,6 +337,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.41597816, 0.39302617, 0.44287828, 0.2687074, 0.28315824, 0.40582314, 0.20877528, 0.2380802, 0.39447647]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py
index a92e99366ee3..d5c7e78af2a3 100644
--- a/tests/pipelines/pixart_sigma/test_pixart.py
+++ b/tests/pipelines/pixart_sigma/test_pixart.py
@@ -327,9 +327,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -343,15 +343,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@slow
diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py
index f3661355e9dd..8bd25a722cf3 100644
--- a/tests/pipelines/shap_e/test_shap_e_img2img.py
+++ b/tests/pipelines/shap_e/test_shap_e_img2img.py
@@ -264,7 +264,7 @@ def tearDown(self):
def test_shap_e_img2img(self):
input_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/shap_e/corgi.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/corgi.png"
)
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
index d256deed376c..ad09b9ce8292 100644
--- a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
+++ b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
@@ -198,12 +198,12 @@ def test_stable_cascade(self):
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_gpu
def test_offloads(self):
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index ccd5567106d2..acacf3e11880 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -287,15 +287,15 @@ def test_stable_diffusion_ays(self):
inputs["sigmas"] = sigma_schedule
output_sigmas = sd_pipe(**inputs).images
- assert (
- np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3
- ), "ays timesteps and ays sigmas should have the same outputs"
- assert (
- np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3
- ), "use ays timesteps should have different outputs"
- assert (
- np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3
- ), "use ays sigmas should have different outputs"
+ assert np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3, (
+ "ays timesteps and ays sigmas should have the same outputs"
+ )
+ assert np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3, (
+ "use ays timesteps should have different outputs"
+ )
+ assert np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3, (
+ "use ays sigmas should have different outputs"
+ )
def test_stable_diffusion_prompt_embeds(self):
components = self.get_dummy_components()
@@ -728,9 +728,9 @@ def test_freeu_enabled(self):
sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
output_freeu = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
- assert not np.allclose(
- output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
- ), "Enabling of FreeU should lead to different results."
+ assert not np.allclose(output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]), (
+ "Enabling of FreeU should lead to different results."
+ )
def test_freeu_disabled(self):
components = self.get_dummy_components()
@@ -753,9 +753,9 @@ def test_freeu_disabled(self):
prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)
).images
- assert np.allclose(
- output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]
- ), "Disabling of FreeU should lead to results similar to the default pipeline results."
+ assert np.allclose(output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]), (
+ "Disabling of FreeU should lead to results similar to the default pipeline results."
+ )
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -778,15 +778,15 @@ def test_fused_qkv_projections(self):
image = sd_pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_pipeline_interrupt(self):
components = self.get_dummy_components()
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
index a6f718ae4fbb..d079eed6e1c3 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
@@ -201,9 +201,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -217,15 +217,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_skip_guidance_layers(self):
components = self.get_dummy_components()
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index 8550f258045e..230a9edc0a18 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -242,15 +242,15 @@ def test_stable_diffusion_ays(self):
inputs["sigmas"] = sigma_schedule
output_sigmas = sd_pipe(**inputs).images
- assert (
- np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3
- ), "ays timesteps and ays sigmas should have the same outputs"
- assert (
- np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3
- ), "use ays timesteps should have different outputs"
- assert (
- np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3
- ), "use ays sigmas should have different outputs"
+ assert np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3, (
+ "ays timesteps and ays sigmas should have the same outputs"
+ )
+ assert np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3, (
+ "use ays timesteps should have different outputs"
+ )
+ assert np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3, (
+ "use ays sigmas should have different outputs"
+ )
def test_stable_diffusion_xl_prompt_embeds(self):
components = self.get_dummy_components()
@@ -855,9 +855,9 @@ def new_step(self, *args, **kwargs):
inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}}
latents = pipe_1(**inputs_1).images[0]
- assert (
- expected_steps_1 == done_steps
- ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ assert expected_steps_1 == done_steps, (
+ f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ )
with self.assertRaises(ValueError) as cm:
inputs_2 = {
@@ -884,9 +884,9 @@ def new_step(self, *args, **kwargs):
pipe_3(**inputs_3).images[0]
assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :]
- assert (
- expected_steps == done_steps
- ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ assert expected_steps == done_steps, (
+ f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ )
for steps in [7, 11, 20]:
for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
index 964c7123dd32..a07a2a8d8a84 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
@@ -576,9 +576,9 @@ def new_step(self, *args, **kwargs):
inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}}
latents = pipe_1(**inputs_1).images[0]
- assert (
- expected_steps_1 == done_steps
- ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ assert expected_steps_1 == done_steps, (
+ f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ )
inputs_2 = {
**inputs,
@@ -592,9 +592,9 @@ def new_step(self, *args, **kwargs):
pipe_3(**inputs_3).images[0]
assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :]
- assert (
- expected_steps == done_steps
- ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ assert expected_steps == done_steps, (
+ f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ )
for steps in [7, 11, 20]:
for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index 423c82e0602e..aac5503074bf 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -160,9 +160,9 @@ def test_one_request_upon_cached(self):
download_requests = [r.method for r in m.request_history]
assert download_requests.count("HEAD") == 15, "15 calls to files"
assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json"
- assert (
- len(download_requests) == 32
- ), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
+ assert len(download_requests) == 32, (
+ "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
+ )
with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
@@ -172,9 +172,9 @@ def test_one_request_upon_cached(self):
cache_requests = [r.method for r in m.request_history]
assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
assert cache_requests.count("GET") == 1, "model info is only GET"
- assert (
- len(cache_requests) == 2
- ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+ assert len(cache_requests) == 2, (
+ "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+ )
def test_less_downloads_passed_object(self):
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -210,9 +210,9 @@ def test_less_downloads_passed_object_calls(self):
assert download_requests.count("HEAD") == 13, "13 calls to files"
# 17 - 2 because no call to config or model file for `safety_checker`
assert download_requests.count("GET") == 15, "13 calls to files + model_info + model_index.json"
- assert (
- len(download_requests) == 28
- ), "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
+ assert len(download_requests) == 28, (
+ "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
+ )
with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
@@ -222,9 +222,9 @@ def test_less_downloads_passed_object_calls(self):
cache_requests = [r.method for r in m.request_history]
assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
assert cache_requests.count("GET") == 1, "model info is only GET"
- assert (
- len(cache_requests) == 2
- ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+ assert len(cache_requests) == 2, (
+ "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+ )
def test_download_only_pytorch(self):
with tempfile.TemporaryDirectory() as tmpdirname:
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index f5494fbade2e..d7dcd86a9507 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -180,12 +180,12 @@ def test_freeu(self):
inputs["output_type"] = "np"
output_no_freeu = pipe(**inputs)[0]
- assert not np.allclose(
- output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
- ), "Enabling of FreeU should lead to different results."
- assert np.allclose(
- output, output_no_freeu, atol=1e-2
- ), f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
+ assert not np.allclose(output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]), (
+ "Enabling of FreeU should lead to different results."
+ )
+ assert np.allclose(output, output_no_freeu, atol=1e-2), (
+ f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
+ )
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -206,12 +206,12 @@ def test_fused_qkv_projections(self):
and hasattr(component, "original_attn_processors")
and component.original_attn_processors is not None
):
- assert check_qkv_fusion_processors_exist(
- component
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- assert check_qkv_fusion_matches_attn_procs_length(
- component, component.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
+ assert check_qkv_fusion_processors_exist(component), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
+ assert check_qkv_fusion_matches_attn_procs_length(component, component.original_attn_processors), (
+ "Something wrong with the attention processors concerning the fused QKV projections."
+ )
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
@@ -224,15 +224,15 @@ def test_fused_qkv_projections(self):
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
class IPAdapterTesterMixin:
@@ -857,9 +857,9 @@ def test_from_pipe_consistent_forward_pass(self, expected_max_diff=1e-3):
for component in pipe_original.components.values():
if hasattr(component, "attn_processors"):
- assert all(
- type(proc) == AttnProcessor for proc in component.attn_processors.values()
- ), "`from_pipe` changed the attention processor in original pipeline."
+ assert all(type(proc) == AttnProcessor for proc in component.attn_processors.values()), (
+ "`from_pipe` changed the attention processor in original pipeline."
+ )
@require_accelerator
@require_accelerate_version_greater("0.14.0")
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
index a0e6e1417e67..1c9790807fa8 100644
--- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
+++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
@@ -191,12 +191,12 @@ def test_wuerstchen(self):
expected_slice = np.array([0.7616304, 0.0, 1.0, 0.0, 1.0, 0.0, 0.05925313, 0.0, 0.951898])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_gpu
def test_offloads(self):
diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py
index 55b3202ad0be..28c354709dc9 100644
--- a/tests/schedulers/test_scheduler_dpm_multi.py
+++ b/tests/schedulers/test_scheduler_dpm_multi.py
@@ -357,9 +357,9 @@ def test_custom_timesteps(self):
prediction_type=prediction_type,
final_sigmas_type=final_sigmas_type,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
+ )
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py
index 7cbaa5cc5e8d..0756a5ed71ff 100644
--- a/tests/schedulers/test_scheduler_dpm_single.py
+++ b/tests/schedulers/test_scheduler_dpm_single.py
@@ -345,9 +345,9 @@ def test_custom_timesteps(self):
lower_order_final=lower_order_final,
final_sigmas_type=final_sigmas_type,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}"
+ )
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py b/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
index e97d64ec5f1d..8525ce61c40d 100644
--- a/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
+++ b/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
@@ -188,9 +188,9 @@ def test_solver_order_and_type(self):
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
- assert (
- not torch.isnan(sample).any()
- ), f"Samples have nan numbers, {order}, {solver_type}, {prediction_type}, {algorithm_type}"
+ assert not torch.isnan(sample).any(), (
+ f"Samples have nan numbers, {order}, {solver_type}, {prediction_type}, {algorithm_type}"
+ )
def test_lower_order_final(self):
self.check_over_configs(lower_order_final=True)
diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py
index 4c7e02442cd0..01e173a631cd 100644
--- a/tests/schedulers/test_scheduler_euler.py
+++ b/tests/schedulers/test_scheduler_euler.py
@@ -245,9 +245,9 @@ def test_custom_timesteps(self):
interpolation_type=interpolation_type,
final_sigmas_type=final_sigmas_type,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}"
+ )
def test_custom_sigmas(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
@@ -260,9 +260,9 @@ def test_custom_sigmas(self):
prediction_type=prediction_type,
final_sigmas_type=final_sigmas_type,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
+ )
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_heun.py b/tests/schedulers/test_scheduler_heun.py
index 9e060c6d476f..90012f5525ab 100644
--- a/tests/schedulers/test_scheduler_heun.py
+++ b/tests/schedulers/test_scheduler_heun.py
@@ -216,9 +216,9 @@ def test_custom_timesteps(self):
prediction_type=prediction_type,
timestep_spacing=timestep_spacing,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"
+ )
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py
index 4e7bc0af6842..4e1713c9ceb1 100644
--- a/tests/single_file/single_file_testing_utils.py
+++ b/tests/single_file/single_file_testing_utils.py
@@ -72,9 +72,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
continue
assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
- assert isinstance(
- component, pipe.components[component_name].__class__
- ), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
+ assert isinstance(component, pipe.components[component_name].__class__), (
+ f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
+ )
for param_name, param_value in component.config.items():
if param_name in PARAMS_TO_IGNORE:
@@ -85,9 +85,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
pipe.components[component_name].config[param_name] = param_value
- assert (
- pipe.components[component_name].config[param_name] == param_value
- ), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
+ assert pipe.components[component_name].config[param_name] == param_value, (
+ f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
+ )
def test_single_file_components(self, pipe=None, single_file_pipe=None):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
@@ -253,9 +253,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
continue
assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
- assert isinstance(
- component, pipe.components[component_name].__class__
- ), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
+ assert isinstance(component, pipe.components[component_name].__class__), (
+ f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
+ )
for param_name, param_value in component.config.items():
if param_name in PARAMS_TO_IGNORE:
@@ -266,9 +266,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
pipe.components[component_name].config[param_name] = param_value
- assert (
- pipe.components[component_name].config[param_name] == param_value
- ), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
+ assert pipe.components[component_name].config[param_name] == param_value, (
+ f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
+ )
def test_single_file_components(self, pipe=None, single_file_pipe=None):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
diff --git a/tests/single_file/test_model_autoencoder_dc_single_file.py b/tests/single_file/test_model_autoencoder_dc_single_file.py
index b1faeb78776b..31b2eb6e36b0 100644
--- a/tests/single_file/test_model_autoencoder_dc_single_file.py
+++ b/tests/single_file/test_model_autoencoder_dc_single_file.py
@@ -87,9 +87,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_in_type_variant_components(self):
# `in` variant checkpoints require passing in a `config` parameter
@@ -106,9 +106,9 @@ def test_single_file_in_type_variant_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_mix_type_variant_components(self):
repo_id = "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"
@@ -121,6 +121,6 @@ def test_single_file_mix_type_variant_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
diff --git a/tests/single_file/test_model_controlnet_single_file.py b/tests/single_file/test_model_controlnet_single_file.py
index bfcb802380a6..3580d73531a3 100644
--- a/tests/single_file/test_model_controlnet_single_file.py
+++ b/tests/single_file/test_model_controlnet_single_file.py
@@ -58,9 +58,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path)
diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py
index 0ec97db26a9e..bf11faaa9c0e 100644
--- a/tests/single_file/test_model_flux_transformer_single_file.py
+++ b/tests/single_file/test_model_flux_transformer_single_file.py
@@ -58,9 +58,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
diff --git a/tests/single_file/test_model_motion_adapter_single_file.py b/tests/single_file/test_model_motion_adapter_single_file.py
index b195f25d094b..a747f16dc1db 100644
--- a/tests/single_file/test_model_motion_adapter_single_file.py
+++ b/tests/single_file/test_model_motion_adapter_single_file.py
@@ -40,9 +40,9 @@ def test_single_file_components_version_v1_5(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_components_version_v1_5_2(self):
ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt"
@@ -55,9 +55,9 @@ def test_single_file_components_version_v1_5_2(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_components_version_v1_5_3(self):
ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt"
@@ -70,9 +70,9 @@ def test_single_file_components_version_v1_5_3(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_components_version_sdxl_beta(self):
ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt"
@@ -85,6 +85,6 @@ def test_single_file_components_version_sdxl_beta(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
diff --git a/tests/single_file/test_model_sd_cascade_unet_single_file.py b/tests/single_file/test_model_sd_cascade_unet_single_file.py
index 08b04e3cd7e8..92b371c3fb41 100644
--- a/tests/single_file/test_model_sd_cascade_unet_single_file.py
+++ b/tests/single_file/test_model_sd_cascade_unet_single_file.py
@@ -60,9 +60,9 @@ def test_single_file_components_stage_b(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_single_file_components_stage_b_lite(self):
model_single_file = StableCascadeUNet.from_single_file(
@@ -77,9 +77,9 @@ def test_single_file_components_stage_b_lite(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_single_file_components_stage_c(self):
model_single_file = StableCascadeUNet.from_single_file(
@@ -94,9 +94,9 @@ def test_single_file_components_stage_c(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_single_file_components_stage_c_lite(self):
model_single_file = StableCascadeUNet.from_single_file(
@@ -111,6 +111,6 @@ def test_single_file_components_stage_c_lite(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
diff --git a/tests/single_file/test_model_vae_single_file.py b/tests/single_file/test_model_vae_single_file.py
index 9db4cddb3c9d..bba1726ae380 100644
--- a/tests/single_file/test_model_vae_single_file.py
+++ b/tests/single_file/test_model_vae_single_file.py
@@ -91,9 +91,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
diff --git a/utils/log_reports.py b/utils/log_reports.py
index dd1b258519d7..5575c9ba8415 100644
--- a/utils/log_reports.py
+++ b/utils/log_reports.py
@@ -35,7 +35,7 @@ def main(slack_channel_name=None):
if line.get("nodeid", "") != "":
test = line["nodeid"]
if line.get("duration", None) is not None:
- duration = f'{line["duration"]:.4f}'
+ duration = f"{line['duration']:.4f}"
if line.get("outcome", "") == "failed":
section_num_failed += 1
failed.append([test, duration, log.name.split("_")[0]])
diff --git a/utils/update_metadata.py b/utils/update_metadata.py
index 103a2b9ab0cc..54fce1edd5d9 100644
--- a/utils/update_metadata.py
+++ b/utils/update_metadata.py
@@ -104,8 +104,7 @@ def update_metadata(commit_sha: str):
if commit_sha is not None:
commit_message = (
- f"Update with commit {commit_sha}\n\nSee: "
- f"https://github.com/huggingface/diffusers/commit/{commit_sha}"
+ f"Update with commit {commit_sha}\n\nSee: https://github.com/huggingface/diffusers/commit/{commit_sha}"
)
else:
commit_message = "Update"
From 61636799cf424eca9fd7e8ce0728c8e4b6b2d2c3 Mon Sep 17 00:00:00 2001
From: zRzRzRzRzRzRzR <2448370773@qq.com>
Date: Wed, 15 Jan 2025 00:18:23 +0800
Subject: [PATCH 02/68] encode with glm
---
.../pipelines/cogview4/pipeline_cogview4.py | 46 ++++++++++++-------
1 file changed, 30 insertions(+), 16 deletions(-)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 48d411349703..522636a32c0d 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -28,7 +28,6 @@
from ...utils.torch_utils import randn_tensor
from .pipeline_output import CogView4PipelineOutput
-
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
@@ -38,7 +37,6 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
EXAMPLE_DOC_STRING = """
Examples:
```python
@@ -180,7 +178,7 @@ def _get_glm_embeds(
text_inputs = self.tokenizer(
prompt,
- padding="max_length",
+ padding="longest", # not use max length
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
@@ -188,19 +186,26 @@ def _get_glm_embeds(
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
+ current_length = text_input_ids.shape[1]
+ pad_length = (16 - (current_length % 16)) % 16
+ if pad_length > 0:
+ pad_ids = torch.full(
+ (text_input_ids.shape[0], pad_length),
+ fill_value=151329, # <|endoftext|> of glm-4
+ dtype=text_input_ids.dtype,
+ device=text_input_ids.device,
+ )
+ text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
- prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = self.text_encoder.model.embed_tokens(text_input_ids)[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
- # duplicate text embeddings for each generation per prompt, using mps friendly method
- _, seq_len, _ = prompt_embeds.shape
+ seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
@@ -208,11 +213,12 @@ def _get_glm_embeds(
def encode_prompt(
self,
prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
- max_sequence_length: int = 224,
+ max_sequence_length: int = 1024,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
@@ -222,6 +228,10 @@ def encode_prompt(
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_images_per_prompt (`int`, *optional*, defaults to 1):
@@ -233,7 +243,7 @@ def encode_prompt(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
- max_sequence_length (`int`, defaults to `224`):
+ max_sequence_length (`int`, defaults to `1024`):
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
device: (`torch.device`, *optional*):
torch device
@@ -249,7 +259,7 @@ def encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
- prompt_embeds = self._get_t5_prompt_embeds(
+ prompt_embeds = self._get_glm_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
@@ -258,7 +268,13 @@ def encode_prompt(
)
if do_classifier_free_guidance and negative_prompt is None:
- negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape)
+ negative_prompt_embeds = self._get_glm_embeds(
+ prompt="",
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
@@ -275,14 +291,13 @@ def encode_prompt(
" the batch size of `prompt`."
)
- negative_prompt_embeds = self._get_t5_prompt_embeds(
+ negative_prompt_embeds = self._get_glm_embeds(
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
-
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
@@ -422,7 +437,7 @@ def __call__(
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- CogView4Pipeline: int = 224,
+ max_sequence_length: int = 1024,
) -> Union[CogView4PipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -543,7 +558,6 @@ def __call__(
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
-
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
From 6090ea7f439a7a6c3d3f1d7caf3d38da8782261f Mon Sep 17 00:00:00 2001
From: zRzRzRzRzRzRzR <2448370773@qq.com>
Date: Wed, 15 Jan 2025 22:16:35 +0800
Subject: [PATCH 03/68] draft schedule
---
scripts/convert_cogview4_to_diffusers.py | 35 +-
src/diffusers/__init__.py | 2 +
.../pipelines/cogview4/pipeline_cogview4.py | 19 +-
src/diffusers/schedulers/__init__.py | 2 +
.../schedulers/scheduling_ddim_cogview4.py | 543 ++++++++++++++++++
src/diffusers/utils/dummy_pt_objects.py | 15 +
6 files changed, 589 insertions(+), 27 deletions(-)
create mode 100644 src/diffusers/schedulers/scheduling_ddim_cogview4.py
diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py
index e99562898b52..06df7ce53e2c 100644
--- a/scripts/convert_cogview4_to_diffusers.py
+++ b/scripts/convert_cogview4_to_diffusers.py
@@ -31,11 +31,10 @@
from accelerate import init_empty_weights
from transformers import PreTrainedTokenizerFast, GlmForCausalLM
-from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView4Pipeline, CogView3PlusTransformer2DModel
+from diffusers import AutoencoderKL, CogView4DDIMScheduler, CogView4Pipeline, CogView3PlusTransformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available
-
CTX = init_empty_weights if is_accelerate_available() else nullcontext
parser = argparse.ArgumentParser()
@@ -170,16 +169,16 @@ def main(args):
args.transformer_checkpoint_path
)
transformer = CogView3PlusTransformer2DModel(
- patch_size = 2,
- in_channels = 16,
- num_layers = 28,
- attention_head_dim= 128,
- num_attention_heads = 32,
- out_channels = 16,
- text_embed_dim= 4096,
- time_embed_dim = 512,
- condition_dim= 256,
- pos_embed_max_size = 128,
+ patch_size=2,
+ in_channels=16,
+ num_layers=28,
+ attention_head_dim=128,
+ num_attention_heads=32,
+ out_channels=16,
+ text_embed_dim=4096,
+ time_embed_dim=512,
+ condition_dim=256,
+ pos_embed_max_size=128,
)
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
if dtype is not None:
@@ -210,16 +209,20 @@ def main(args):
if dtype is not None:
vae = vae.to(dtype=dtype)
- text_encoder_id = 'THUDM/glm-4-9b-hf'
+ text_encoder_id = "THUDM/glm-4-9b-hf"
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
- text_encoder = GlmForCausalLM.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir, torch_dtype=torch.bfloat16 if dtype=="bf16" else torch.float32)
+ text_encoder = GlmForCausalLM.from_pretrained(
+ text_encoder_id,
+ cache_dir=args.text_encoder_cache_dir,
+ torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
+ )
# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
- scheduler = CogVideoXDDIMScheduler.from_config(
+ scheduler = CogView4DDIMScheduler.from_config(
{
- "snr_shift_scale": 4.0,
+ "shift_scale": 1.0,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 1b19f9161ca1..206763a45278 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -175,6 +175,7 @@
"CMStochasticIterativeScheduler",
"CogVideoXDDIMScheduler",
"CogVideoXDPMScheduler",
+ "CogView4DDIMScheduler",
"DDIMInverseScheduler",
"DDIMParallelScheduler",
"DDIMScheduler",
@@ -684,6 +685,7 @@
CMStochasticIterativeScheduler,
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
+ CogView4DDIMScheduler,
DDIMInverseScheduler,
DDIMParallelScheduler,
DDIMScheduler,
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 522636a32c0d..3759adadd582 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -23,7 +23,7 @@
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, CogView3PlusTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
-from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from ...schedulers import CogView4DDIMScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import CogView4PipelineOutput
@@ -151,7 +151,7 @@ def __init__(
text_encoder: GlmModel,
vae: AutoencoderKL,
transformer: CogView3PlusTransformer2DModel,
- scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+ scheduler: CogView4DDIMScheduler,
):
super().__init__()
@@ -318,7 +318,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
-
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@@ -517,8 +516,8 @@ def __call__(
Examples:
Returns:
- [`~pipelines.cogview3.pipeline_CogView4.CogView3PipelineOutput`] or `tuple`:
- [`~pipelines.cogview3.pipeline_CogView4.CogView3PipelineOutput`] if `return_dict` is True, otherwise a
+ [`~pipelines.cogview4.pipeline_CogView4.CogView3PipelineOutput`] or `tuple`:
+ [`~pipelines.cogview4.pipeline_CogView4.CogView3PipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
@@ -640,15 +639,13 @@ def __call__(
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ if not isinstance(self.scheduler, CogView4DDIMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
- noise_pred,
- old_pred_original_sample,
- t,
- timesteps[i - 1] if i > 0 else None,
- latents,
+ model_output=noise_pred,
+ timestep=t,
+ sample=latents,
**extra_step_kwargs,
return_dict=False,
)
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index bb9088538653..512d28d95c09 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -44,6 +44,7 @@
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
_import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"]
+ _import_structure["scheduling_ddim_cogview4"] = ["CogView4DDIMScheduler"]
_import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
_import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"]
_import_structure["scheduling_ddpm"] = ["DDPMScheduler"]
@@ -144,6 +145,7 @@
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
+ from .scheduling_ddim_cogview4 import CogView4DDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler
from .scheduling_ddim_parallel import DDIMParallelScheduler
from .scheduling_ddpm import DDPMScheduler
diff --git a/src/diffusers/schedulers/scheduling_ddim_cogview4.py b/src/diffusers/schedulers/scheduling_ddim_cogview4.py
new file mode 100644
index 000000000000..7c79b50fba6d
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_ddim_cogview4.py
@@ -0,0 +1,543 @@
+# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from ..utils.torch_utils import randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
+class DDIMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.Tensor
+ pred_original_sample: Optional[torch.Tensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+ Args:
+ betas (`torch.Tensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.Tensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+class CogView4DDIMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+ non-Markovian guidance.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps, as required by some model families.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ rescale_betas_zero_snr: bool = False,
+ shift_scale: int = 1.0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, *remaining_dims = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ # Check if the requested number of steps is valid
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ # Set the current number of inference steps
+ self.num_inference_steps = num_inference_steps
+
+ # Generate timesteps according to the specified spacing method
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+ .round()[::-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
+ )
+
+ # Convert the numpy array of timesteps into a PyTorch tensor
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ # ===== change for cogview4 ====
+ # The new dynamic shifting code starts here.
+
+ # Convert integer timesteps to float for further manipulation
+ times_float = self.timesteps.float() / float(self.config.num_train_timesteps)
+
+ # Apply the shift_scale factor
+ times_float = self.config.shift_scale * times_float
+
+ # Convert the shifted floats back to integer indices for timesteps
+ new_timesteps = (times_float * self.config.num_train_timesteps).round().long().clamp_min(0)
+
+ # Ensure the timesteps are in descending order and unique
+ new_timesteps = new_timesteps.unique().flip(0)
+ if len(new_timesteps) == 0:
+ # If all values somehow got collapsed, fallback to a single timestep
+ new_timesteps = torch.zeros(1, dtype=torch.long, device=device)
+
+ # Overwrite the original timesteps with our newly shifted timesteps
+ self.timesteps = new_timesteps
+ # =====
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timestep: int,
+ sample: torch.Tensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ variance_noise: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ eta (`float`):
+ The weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`, defaults to `False`):
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
+ `use_clipped_model_output` has no effect.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ variance_noise (`torch.Tensor`):
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
+ itself. Useful for methods such as [`CycleDiffusion`].
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read DDIM paper in-detail understanding
+
+ # Notation ( ->
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ pred_epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction`"
+ )
+
+ # 4. Clip or threshold "predicted x_0"
+ if self.config.thresholding:
+ pred_original_sample = self._threshold_sample(pred_original_sample)
+ elif self.config.clip_sample:
+ pred_original_sample = pred_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ if use_clipped_model_output:
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
+
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if eta > 0:
+ if variance_noise is not None and generator is not None:
+ raise ValueError(
+ "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
+ " `variance_noise` stays `None`."
+ )
+
+ if variance_noise is None:
+ variance_noise = randn_tensor(
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+ )
+ variance = std_dev_t * variance_noise
+
+ prev_sample = prev_sample + variance
+
+ if not return_dict:
+ return (
+ prev_sample,
+ pred_original_sample,
+ )
+
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+ # for the subsequent add_noise calls
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+ alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+ alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 4b6ac10385cf..f90744e2d977 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -1275,6 +1275,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class CogView4DDIMScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class DDIMInverseScheduler(metaclass=DummyObject):
_backends = ["torch"]
From c7d1227d843e62434ec88ed2fa7426924ff36550 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com>
Date: Thu, 16 Jan 2025 13:58:41 +0800
Subject: [PATCH 04/68] feat(scheduler): Add CogView scheduler implementation
---
.../schedulers/scheduling_cogview.py | 332 ++++++++++++++++++
1 file changed, 332 insertions(+)
create mode 100644 src/diffusers/schedulers/scheduling_cogview.py
diff --git a/src/diffusers/schedulers/scheduling_cogview.py b/src/diffusers/schedulers/scheduling_cogview.py
new file mode 100644
index 000000000000..103706360a8f
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_cogview.py
@@ -0,0 +1,332 @@
+# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils.torch_utils import randn_tensor
+from .scheduling_ddim import DDIMSchedulerOutput
+from .scheduling_utils import SchedulerMixin
+
+
+class CogViewScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `CogViewScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.00085):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.012):
+ The final `beta` value.
+ prediction_type (`str`, defaults to `v_prediction`):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ timestep_spacing (`str`, defaults to `leading`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps, as required by some model families.
+ num_inference_steps (`int`, defaults to 50):
+ The number of inference steps to use.
+ scale_factor (`float`, defaults to 1.0):
+ Scaling factor to apply to the model input.
+ snr_shift_scale (`float`, defaults to 1.0):
+ Scale factor for shifting the signal-to-noise ratio.
+ zero_snr (`bool`, defaults to True):
+ Whether to adjust the alphas to achieve zero terminal SNR.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.00085,
+ beta_end: float = 0.012,
+ prediction_type: str = "v_prediction",
+ timestep_spacing: str = "leading",
+ steps_offset: int = 0,
+ num_inference_steps: int = 50,
+ scale_factor: float = 1.0,
+ snr_shift_scale: float = 1.0,
+ zero_snr: bool = True,
+ ):
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ # SNR shift
+ self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod)
+ sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
+ if zero_snr:
+ sqrt_alphas_cumprod_0 = sqrt_alphas_cumprod[0]
+ sqrt_alphas_cumprod_T_1 = sqrt_alphas_cumprod[-1]
+ sqrt_alphas_cumprod -= sqrt_alphas_cumprod_T_1
+ sqrt_alphas_cumprod *= sqrt_alphas_cumprod_0 / (sqrt_alphas_cumprod_0 - sqrt_alphas_cumprod_T_1)
+ self.sqrt_alphas_cumprod = sqrt_alphas_cumprod
+ self.sigmas = torch.sqrt(1 - sqrt_alphas_cumprod**2)
+
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ return sample * self.scale_factor
+
+ def set_timesteps(
+ self,
+ num_inference_steps: Optional[int] = None,
+ device: Union[str, torch.device] = None,
+ timesteps: Optional[List[int]] = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed,
+ `num_inference_steps` must be `None`.
+
+ """
+ if num_inference_steps is not None and timesteps is not None:
+ raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
+
+ if timesteps is not None:
+ for i in range(1, len(timesteps)):
+ if timesteps[i] >= timesteps[i - 1]:
+ raise ValueError("`custom_timesteps` must be in descending order.")
+
+ if timesteps[0] >= self.config.num_train_timesteps:
+ raise ValueError(
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
+ )
+
+ timesteps = np.array(timesteps, dtype=np.int64)
+ self.custom_timesteps = True
+ else:
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+ self.custom_timesteps = False
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+ .round()[::-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timestep: int,
+ sample: torch.Tensor,
+ eta: float = 1.0,
+ generator=None,
+ variance_noise: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ eta (`float`):
+ The weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`, defaults to `False`):
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
+ `use_clipped_model_output` has no effect.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ variance_noise (`torch.Tensor`):
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
+ itself. Useful for methods such as [`CycleDiffusion`].
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read DDIM paper in-detail understanding
+
+ # Notation ( ->
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else 1.0
+ sigma_t = eta * torch.sqrt(
+ (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+ )
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ pred_epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction`"
+ )
+
+ # 4. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - sigma_t**2) ** (0.5) * pred_epsilon
+
+ # 5. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if eta > 0:
+ if variance_noise is not None and generator is not None:
+ raise ValueError(
+ "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
+ " `variance_noise` stays `None`."
+ )
+
+ if variance_noise is None:
+ variance_noise = randn_tensor(
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+ )
+ variance = sigma_t * variance_noise
+
+ prev_sample = prev_sample + variance
+
+ if not return_dict:
+ return (
+ prev_sample,
+ pred_original_sample,
+ )
+
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ apply_scale: bool = True,
+ ) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+ # for the subsequent add_noise calls
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+ self.sigmas = self.sigmas.to(dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = self.sqrt_alphas_cumprod[timesteps]
+ sigmas = self.sigmas[timesteps]
+ assert sqrt_alpha_prod.dim() == 1, f"sqrt_alpha_prod must be a 1D tensor, got {sqrt_alpha_prod.dim()}D"
+ assert sqrt_alpha_prod.shape == sigmas.shape, (
+ f"sigmas and sqrt_alpha_prod must have the same shape, got {sigmas.shape} and {sqrt_alpha_prod.shape}"
+ )
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+ sigmas = sigmas.unsqueeze(-1)
+
+ if apply_scale:
+ original_samples = original_samples * self.scale_factor
+
+ # scale noise and original samples
+ noise = noise * sigmas
+ original_samples = original_samples * sqrt_alpha_prod
+
+ noisy_samples = noise + original_samples
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
From f4457fbbf23eaa92dff3bc08f8bba7e193b9d3ea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com>
Date: Fri, 17 Jan 2025 13:17:04 +0800
Subject: [PATCH 05/68] feat(embeddings): add CogView 2D rotary positional
embedding
---
src/diffusers/models/embeddings.py | 75 ++++++++++++++++++++++++++++++
1 file changed, 75 insertions(+)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index c64b9587be77..5f7441796748 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -2611,3 +2611,78 @@ def forward(self, image_embeds: List[torch.Tensor]):
projected_image_embeds.append(image_embed)
return projected_image_embeds
+
+
+class CogViewRotary2DEmbedding(nn.Module):
+ def __init__(
+ self,
+ kv_channels: int,
+ rotary_percent: float,
+ max_h: int = 128,
+ max_w: int = 128,
+ rotary_interleaved: bool = False,
+ seq_len_interpolation_factor: float = None,
+ inner_interp: bool = False,
+ rotary_base: int = 10000,
+ ) -> None:
+ super().__init__()
+
+ dim = kv_channels
+ if rotary_percent < 1.0:
+ dim = int(dim * rotary_percent)
+ self.rotary_interleaved = rotary_interleaved
+
+ self.seq_len_interpolation_factor = seq_len_interpolation_factor
+ self.inner_interp = inner_interp
+
+ dim_h = kv_channels // 2
+ dim_w = kv_channels // 2
+
+ device = torch.cuda.current_device()
+ h_inv_freq = 1.0 / (
+ rotary_base
+ ** (torch.arange(0, dim_h, 2, dtype=torch.float32, device=device)[: (dim_h // 2)].float() / dim_h)
+ )
+ w_inv_freq = 1.0 / (
+ rotary_base
+ ** (torch.arange(0, dim_w, 2, dtype=torch.float32, device=device)[: (dim_w // 2)].float() / dim_w)
+ )
+
+ h_seq = torch.arange(max_h, device=device, dtype=h_inv_freq.dtype)
+ w_seq = torch.arange(max_w, device=device, dtype=w_inv_freq.dtype)
+
+ self.freqs_h = torch.outer(h_seq, h_inv_freq)
+ self.freqs_w = torch.outer(w_seq, w_inv_freq)
+ self.max_h = max_h
+ self.max_w = max_w
+
+ def forward(
+ self,
+ h_idx: torch.Tensor,
+ w_idx: torch.Tensor,
+ target_h: torch.Tensor = None,
+ target_w: torch.Tensor = None,
+ mask: torch.Tensor = None,
+ ) -> torch.Tensor:
+ if self.inner_interp:
+ inner_h_idx = (h_idx * self.max_h) // target_h
+ inner_w_idx = (w_idx * self.max_w) // target_w
+
+ h_emb = self.freqs_h[inner_h_idx]
+ w_emb = self.freqs_w[inner_w_idx]
+
+ else:
+ h_emb = self.freqs_h[h_idx]
+ w_emb = self.freqs_w[w_idx]
+
+ mask = (mask == 1).unsqueeze(-1)
+
+ emb = torch.cat([h_emb, w_emb], dim=-1) * mask
+
+ assert emb.ndim == 2, f"expected emb to have 2 dimensions, got {emb.ndim}"
+ if not self.rotary_interleaved:
+ emb = torch.repeat_interleave(emb, 2, dim=0)
+ else:
+ emb = torch.repeat_interleave(emb, 2, dim=1)
+
+ return emb
From 9a9321843b2cf005f8992a938674f33aa89fb19c Mon Sep 17 00:00:00 2001
From: zRzRzRzRzRzRzR <2448370773@qq.com>
Date: Fri, 17 Jan 2025 16:45:08 +0800
Subject: [PATCH 06/68] 1
---
src/diffusers/pipelines/cogview4/pipeline_cogview4.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 3759adadd582..5184a90a81d3 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -1,3 +1,4 @@
+
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
From ca000dd61f28083f601a527c3418024c0995c9a2 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Fri, 17 Jan 2025 20:43:40 +0800
Subject: [PATCH 07/68] Update pipeline_cogview4.py
---
src/diffusers/pipelines/cogview4/pipeline_cogview4.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 5184a90a81d3..3759adadd582 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -1,4 +1,3 @@
-
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
From 7ab4a3fbfcdbfc82b512ab6da70f34efbd6a1469 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Sat, 18 Jan 2025 23:59:24 +0800
Subject: [PATCH 08/68] fix the timestep init and sigma
---
scripts/convert_cogview4_to_diffusers.py | 2 +-
.../pipelines/cogview4/pipeline_cogview4.py | 117 +++++++++++++++---
.../schedulers/scheduling_ddim_cogview4.py | 28 +----
3 files changed, 101 insertions(+), 46 deletions(-)
diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py
index 06df7ce53e2c..8ac3e5854c84 100644
--- a/scripts/convert_cogview4_to_diffusers.py
+++ b/scripts/convert_cogview4_to_diffusers.py
@@ -231,7 +231,7 @@ def main(args):
"prediction_type": "v_prediction",
"rescale_betas_zero_snr": True,
"set_alpha_to_one": True,
- "timestep_spacing": "trailing",
+ "timestep_spacing": "linspace",
}
)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 3759adadd582..1690882ed886 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -16,6 +16,7 @@
import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
+import math
import torch
from transformers import GlmModel
@@ -53,7 +54,19 @@
"""
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def calculate_shift(
+ image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.15
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+def time_shift(mu: float, shift_sigma: float, sigmas: torch.Tensor):
+ return math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** shift_sigma)
+
+
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -203,7 +216,7 @@ def _get_glm_embeds(
)
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
- prompt_embeds = self.text_encoder.model.embed_tokens(text_input_ids)[0]
+ prompt_embeds = self.text_encoder.model.embed_tokens(text_input_ids.to(self.text_encoder.model.device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -573,6 +586,16 @@ def __call__(
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
+ self.transformer.config.patch_size ** 2
+ )
+ mu = calculate_shift(image_seq_len)
+ sigmas = timesteps / self.scheduler.config.num_train_timesteps
+ sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) # Append zero at the end
+
+ self.sigmas = time_shift(mu, 1.0, sigmas) # This is for noisy contr
+
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
@@ -611,17 +634,81 @@ def __call__(
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
+ # for i, t in enumerate(timesteps):
+ # if self.interrupt:
+ # continue
+ #
+ # latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ # latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ #
+ # # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ # timestep = t.expand(latent_model_input.shape[0])
+ #
+ # # predict noise model_output
+ # noise_pred = self.transformer(
+ # hidden_states=latent_model_input,
+ # encoder_hidden_states=prompt_embeds,
+ # timestep=timestep,
+ # original_size=original_size,
+ # target_size=target_size,
+ # crop_coords=crops_coords_top_left,
+ # return_dict=False,
+ # )[0]
+ # noise_pred = noise_pred.float()
+ #
+ # # perform guidance
+ # if self.do_classifier_free_guidance:
+ # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ # noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ #
+ # # compute the previous noisy sample x_t -> x_t-1
+ # if not isinstance(self.scheduler, CogView4DDIMScheduler):
+ # latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ # else:
+ # latents, old_pred_original_sample = self.scheduler.step(
+ # model_output=noise_pred,
+ # timestep=t,
+ # sample=latents,
+ # **extra_step_kwargs,
+ # return_dict=False,
+ # )
+ # latents = latents.to(prompt_embeds.dtype)
+ #
+ # # call the callback, if provided
+ # if callback_on_step_end is not None:
+ # callback_kwargs = {}
+ # for k in callback_on_step_end_tensor_inputs:
+ # callback_kwargs[k] = locals()[k]
+ # callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+ #
+ # latents = callback_outputs.pop("latents", latents)
+ # prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ # negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ #
+ # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ # progress_bar.update()
+ #
+ # if XLA_AVAILABLE:
+ # xm.mark_step()
+ # 假设 sigmas 已经计算好了,和之前的步骤一样
for i, t in enumerate(timesteps):
if self.interrupt:
continue
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ # 获取当前的 sigma 和下一个时间步的 sigma
+ sigma = sigmas[i]
+ sigma_next = sigmas[i + 1] if i + 1 < len(sigmas) else sigma # 防止越界
+
+ # 根据 sigmas 修改 latent 模型输入
+ latent_model_input = latents * sigma # 使用当前 sigma 调整 latents
+ latent_model_input = torch.cat(
+ [latent_model_input] * 2) if self.do_classifier_free_guidance else latent_model_input
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ # 广播到 batch 维度,以便与 ONNX/Core ML 兼容
timestep = t.expand(latent_model_input.shape[0])
- # predict noise model_output
+ # 预测噪声
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
@@ -633,25 +720,18 @@ def __call__(
)[0]
noise_pred = noise_pred.float()
- # perform guidance
+ # 执行引导
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
- # compute the previous noisy sample x_t -> x_t-1
- if not isinstance(self.scheduler, CogView4DDIMScheduler):
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- else:
- latents, old_pred_original_sample = self.scheduler.step(
- model_output=noise_pred,
- timestep=t,
- sample=latents,
- **extra_step_kwargs,
- return_dict=False,
- )
+ # 根据预测的噪声和 sigmas 更新 latents
+ latents = latents + (sigma_next - sigma) * noise_pred # 使用 sigmas 计算新的 latents
+
+ # 或者使用更新后的 latents 进行下一步计算
latents = latents.to(prompt_embeds.dtype)
- # call the callback, if provided
+ # 如果有回调,执行回调
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
@@ -667,7 +747,6 @@ def __call__(
if XLA_AVAILABLE:
xm.mark_step()
-
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
diff --git a/src/diffusers/schedulers/scheduling_ddim_cogview4.py b/src/diffusers/schedulers/scheduling_ddim_cogview4.py
index 7c79b50fba6d..72ab2c12c454 100644
--- a/src/diffusers/schedulers/scheduling_ddim_cogview4.py
+++ b/src/diffusers/schedulers/scheduling_ddim_cogview4.py
@@ -318,10 +318,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
# Generate timesteps according to the specified spacing method
if self.config.timestep_spacing == "linspace":
timesteps = (
- np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
- .round()[::-1]
- .copy()
- .astype(np.int64)
+ np.linspace(self.config.num_train_timesteps, 1, num_inference_steps)
+ .astype(np.int64) # Only for CogView4
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
@@ -339,28 +337,6 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
# Convert the numpy array of timesteps into a PyTorch tensor
self.timesteps = torch.from_numpy(timesteps).to(device)
- # ===== change for cogview4 ====
- # The new dynamic shifting code starts here.
-
- # Convert integer timesteps to float for further manipulation
- times_float = self.timesteps.float() / float(self.config.num_train_timesteps)
-
- # Apply the shift_scale factor
- times_float = self.config.shift_scale * times_float
-
- # Convert the shifted floats back to integer indices for timesteps
- new_timesteps = (times_float * self.config.num_train_timesteps).round().long().clamp_min(0)
-
- # Ensure the timesteps are in descending order and unique
- new_timesteps = new_timesteps.unique().flip(0)
- if len(new_timesteps) == 0:
- # If all values somehow got collapsed, fallback to a single timestep
- new_timesteps = torch.zeros(1, dtype=torch.long, device=device)
-
- # Overwrite the original timesteps with our newly shifted timesteps
- self.timesteps = new_timesteps
- # =====
-
def step(
self,
model_output: torch.Tensor,
From 56ceaa6af3451cac18747cbb8bf0f190a586de70 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Sun, 19 Jan 2025 21:11:13 +0800
Subject: [PATCH 09/68] update latent
---
.../pipelines/cogview4/pipeline_cogview4.py | 108 +++++-------------
1 file changed, 27 insertions(+), 81 deletions(-)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 1690882ed886..424324e84583 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -582,19 +582,18 @@ def __call__(
device=device,
)
if self.do_classifier_free_guidance:
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=1)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
-
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
- self.transformer.config.patch_size ** 2
+ self.transformer.config.patch_size**2
)
mu = calculate_shift(image_seq_len)
sigmas = timesteps / self.scheduler.config.num_train_timesteps
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) # Append zero at the end
- self.sigmas = time_shift(mu, 1.0, sigmas) # This is for noisy contr
+ self.sigmas = time_shift(mu, 1.0, sigmas).to(torch.long).to("cpu") # This is for noisy control of cogview4
self._num_timesteps = len(timesteps)
@@ -630,89 +629,26 @@ def __call__(
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
- # for i, t in enumerate(timesteps):
- # if self.interrupt:
- # continue
- #
- # latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
- # latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
- #
- # # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- # timestep = t.expand(latent_model_input.shape[0])
- #
- # # predict noise model_output
- # noise_pred = self.transformer(
- # hidden_states=latent_model_input,
- # encoder_hidden_states=prompt_embeds,
- # timestep=timestep,
- # original_size=original_size,
- # target_size=target_size,
- # crop_coords=crops_coords_top_left,
- # return_dict=False,
- # )[0]
- # noise_pred = noise_pred.float()
- #
- # # perform guidance
- # if self.do_classifier_free_guidance:
- # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- # noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
- #
- # # compute the previous noisy sample x_t -> x_t-1
- # if not isinstance(self.scheduler, CogView4DDIMScheduler):
- # latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- # else:
- # latents, old_pred_original_sample = self.scheduler.step(
- # model_output=noise_pred,
- # timestep=t,
- # sample=latents,
- # **extra_step_kwargs,
- # return_dict=False,
- # )
- # latents = latents.to(prompt_embeds.dtype)
- #
- # # call the callback, if provided
- # if callback_on_step_end is not None:
- # callback_kwargs = {}
- # for k in callback_on_step_end_tensor_inputs:
- # callback_kwargs[k] = locals()[k]
- # callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
- #
- # latents = callback_outputs.pop("latents", latents)
- # prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- # negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
- #
- # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
- # progress_bar.update()
- #
- # if XLA_AVAILABLE:
- # xm.mark_step()
- # 假设 sigmas 已经计算好了,和之前的步骤一样
for i, t in enumerate(timesteps):
if self.interrupt:
continue
- # 获取当前的 sigma 和下一个时间步的 sigma
- sigma = sigmas[i]
- sigma_next = sigmas[i + 1] if i + 1 < len(sigmas) else sigma # 防止越界
-
- # 根据 sigmas 修改 latent 模型输入
- latent_model_input = latents * sigma # 使用当前 sigma 调整 latents
- latent_model_input = torch.cat(
- [latent_model_input] * 2) if self.do_classifier_free_guidance else latent_model_input
+ # latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = latents # For CogView4 concat the text embed and only use prompt
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
- # 广播到 batch 维度,以便与 ONNX/Core ML 兼容
- timestep = t.expand(latent_model_input.shape[0])
+ # Use sigma instead of timestep directly
+ sigma = self.sigmas[i] # Get the corresponding sigma value
+ timestep = sigma.expand(latent_model_input.shape[0]).to(device) # Use sigma to scale the timestep
- # 预测噪声
+ # predict noise model_output using sigma
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
- timestep=timestep,
+ timestep=timestep, # Pass sigma as timestep for noise prediction
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
@@ -720,23 +656,32 @@ def __call__(
)[0]
noise_pred = noise_pred.float()
- # 执行引导
+ # perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
- # 根据预测的噪声和 sigmas 更新 latents
- latents = latents + (sigma_next - sigma) * noise_pred # 使用 sigmas 计算新的 latents
-
- # 或者使用更新后的 latents 进行下一步计算
+ # compute the previous noisy sample x_t -> x_t-1 using sigma (not timestep)
+ if not isinstance(self.scheduler, CogView4DDIMScheduler):
+ latents = self.scheduler.step(noise_pred, sigma, latents, **extra_step_kwargs, return_dict=False)[
+ 0
+ ]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ model_output=noise_pred,
+ timestep=sigma, # Use sigma here as timestep
+ sample=latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
latents = latents.to(prompt_embeds.dtype)
- # 如果有回调,执行回调
+ # call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+ callback_outputs = callback_on_step_end(self, i, sigma, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
@@ -747,6 +692,7 @@ def __call__(
if XLA_AVAILABLE:
xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
From a7179a21c488a0e28bb0c53ae1605c60f55a7b9e Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Sun, 19 Jan 2025 23:54:15 +0800
Subject: [PATCH 10/68] draft patch(not work)
---
src/diffusers/models/embeddings.py | 69 +++++++++++++++++++
.../transformers/transformer_cogview3plus.py | 14 +++-
.../pipelines/cogview4/pipeline_cogview4.py | 26 +++++--
3 files changed, 100 insertions(+), 9 deletions(-)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 5f7441796748..5d02b4b710fd 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -812,6 +812,75 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens
return (hidden_states + pos_embed).to(hidden_states.dtype)
+class CogView4PatchEmbed(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 16,
+ hidden_size: int = 2560,
+ patch_size: int = 2,
+ text_hidden_size: int = 4096,
+ pos_embed_max_size: int = 128,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_size = hidden_size
+ self.patch_size = patch_size
+ self.text_hidden_size = text_hidden_size
+ self.pos_embed_max_size = pos_embed_max_size
+ # Linear projection for image patches
+ self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+
+ # Linear projection for text embeddings
+ self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+ #TODO:这里需要改成RotaryEmbed
+ pos_embed = get_2d_sincos_pos_embed(
+ hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
+ )
+ pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
+ self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
+
+ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, channel, height, width = hidden_states.shape
+
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
+ raise ValueError("Height and width must be divisible by patch size")
+
+ height = height // self.patch_size
+ width = width // self.patch_size
+ hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
+ hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
+ hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
+
+ # Project the patches
+ hidden_states = self.proj(hidden_states)
+ prompt_encoder_hidden_states = []
+ negative_prompt_encoder_hidden_states = []
+
+ for i in range(0, batch_size, 2):
+ prompt_embeds = encoder_hidden_states[i, :, :] # [seq_len, hidden_size]
+ negative_embeds = encoder_hidden_states[i + 1, :, :] # [seq_len, hidden_size]
+ mask = negative_embeds.abs().sum(dim=-1) > 0
+ seq_len_neg = mask.sum().item() # 非零部分的数量
+ negative_embeds_valid = negative_embeds[:seq_len_neg, :] # [seq_len_neg, hidden_size]
+ prompt_encoder_hidden_states.append(prompt_embeds)
+ negative_prompt_encoder_hidden_states.append(negative_embeds_valid)
+ prompt_encoder_hidden_states = torch.stack(prompt_encoder_hidden_states, dim=0)
+ negative_prompt_encoder_hidden_states = torch.stack(negative_prompt_encoder_hidden_states, dim=0)
+ prompt_text_length = prompt_encoder_hidden_states.shape[1]
+ negative_prompt_text_length = negative_prompt_encoder_hidden_states.shape[1]
+ image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
+ prompt_text_pos_embed = torch.zeros(
+ (prompt_text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+ )
+ negative_prompt_text_pos_embed = torch.zeros(
+ (negative_prompt_text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+ )
+ prompt_pos_embed = torch.cat([prompt_text_pos_embed, image_pos_embed], dim=0)[None, ...]
+ negative_prompt_pos_embed = torch.cat([negative_prompt_text_pos_embed, image_pos_embed], dim=0)[None, ...]
+ # TODO: 拼接哼一个完整的 pos_embed 以及拼接 Rope Embed
+ pos_embed = torch.cat([prompt_pos_embed, negative_prompt_pos_embed], dim=0)
+ hidden_states = hidden_states + pos_embed.to(hidden_states.dtype)
+ return hidden_states
def get_3d_rotary_pos_embed(
embed_dim,
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 369509a3a35e..fb62ac88f974 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -28,7 +28,7 @@
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import is_torch_version, logging
-from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
+from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed, CogView4PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
@@ -166,7 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
- _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
+ _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed", "CogView4PlusPatchEmbed"]
@register_to_config
def __init__(
@@ -191,7 +191,15 @@ def __init__(
# Each of these are sincos embeddings of shape 2 * condition_dim
self.pooled_projection_dim = 3 * 2 * condition_dim
- self.patch_embed = CogView3PlusPatchEmbed(
+ # self.patch_embed = CogView3PlusPatchEmbed(
+ # in_channels=in_channels,
+ # hidden_size=self.inner_dim,
+ # patch_size=patch_size,
+ # text_hidden_size=text_embed_dim,
+ # pos_embed_max_size=pos_embed_max_size,
+ # )
+ # TODO: 兼容性适配
+ self.patch_embed = CogView4PatchEmbed(
in_channels=in_channels,
hidden_size=self.inner_dim,
patch_size=patch_size,
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 424324e84583..20ddca725510 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -311,6 +311,24 @@ def encode_prompt(
device=device,
dtype=dtype,
)
+
+ #TODO: 先pad 0 ,后续再处理不同长度的问题
+ seq_len_prompt = prompt_embeds.shape[1]
+ seq_len_neg = negative_prompt_embeds.shape[1]
+ if seq_len_neg < seq_len_prompt:
+ # 创建一个新的张量,大小为 [batch_size, seq_len_prompt, hidden_size]
+ batch_size = negative_prompt_embeds.shape[0]
+ hidden_size = negative_prompt_embeds.shape[2]
+ # 填充后的张量
+ padded_negative_prompt_embeds = torch.zeros(
+ batch_size,
+ seq_len_prompt,
+ hidden_size,
+ dtype=negative_prompt_embeds.dtype,
+ device=negative_prompt_embeds.device
+ )
+ padded_negative_prompt_embeds[:, :seq_len_neg, :] = negative_prompt_embeds
+ negative_prompt_embeds = padded_negative_prompt_embeds
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
@@ -582,7 +600,7 @@ def __call__(
device=device,
)
if self.do_classifier_free_guidance:
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=1)
+ prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -594,7 +612,6 @@ def __call__(
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) # Append zero at the end
self.sigmas = time_shift(mu, 1.0, sigmas).to(torch.long).to("cpu") # This is for noisy control of cogview4
-
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
@@ -635,11 +652,8 @@ def __call__(
for i, t in enumerate(timesteps):
if self.interrupt:
continue
-
- # latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
- latent_model_input = latents # For CogView4 concat the text embed and only use prompt
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
# Use sigma instead of timestep directly
sigma = self.sigmas[i] # Get the corresponding sigma value
timestep = sigma.expand(latent_model_input.shape[0]).to(device) # Use sigma to scale the timestep
From e6b89078484651fba18b85fe6385fae387b0a41e Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Wed, 22 Jan 2025 18:35:45 +0800
Subject: [PATCH 11/68] fix
---
src/diffusers/__init__.py | 1 +
src/diffusers/pipelines/__init__.py | 1 +
src/diffusers/pipelines/auto_pipeline.py | 3 +++
3 files changed, 5 insertions(+)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 520db1a3fde9..1463296728bd 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -176,6 +176,7 @@
"CMStochasticIterativeScheduler",
"CogVideoXDDIMScheduler",
"CogVideoXDPMScheduler",
+ "CogView4DDIMScheduler",
"DDIMInverseScheduler",
"DDIMParallelScheduler",
"DDIMScheduler",
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index ce291e5ceb45..285348e5885d 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -154,6 +154,7 @@
"CogVideoXFunControlPipeline",
]
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
+ _import_structure["cogview4"] = ["CogView4Pipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index a19329431b05..353be8635649 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -22,6 +22,8 @@
from ..utils import is_sentencepiece_available
from .aura_flow import AuraFlowPipeline
from .cogview3 import CogView3PlusPipeline
+from .cogview4 import CogView4Pipeline
+
from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
@@ -136,6 +138,7 @@
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline),
("cogview3", CogView3PlusPipeline),
+ ("cogview4", CogView4Pipeline),
]
)
From 0ab726066e4a00c5221e95fa5dbf66b974ce4b5b Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Thu, 23 Jan 2025 06:54:44 +0000
Subject: [PATCH 12/68] [WIP][cogview4]: implement initial CogView4 pipeline
Implement the basic CogView4 pipeline structure with the following changes:
- Add CogView4 pipeline implementation
- Implement DDIM scheduler for CogView4
- Add CogView3Plus transformer architecture
- Update embedding models
Current limitations:
- CFG implementation uses padding for sequence length alignment
- Need to verify transformer inference alignment with Megatron
TODO:
- Consider separate forward passes for condition/uncondition
instead of padding approach
---
src/diffusers/models/embeddings.py | 63 ++---
.../transformers/transformer_cogview3plus.py | 100 +++++--
.../pipelines/cogview4/pipeline_cogview4.py | 255 +++++++++---------
.../schedulers/scheduling_ddim_cogview4.py | 2 +-
4 files changed, 233 insertions(+), 187 deletions(-)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 5d02b4b710fd..7c40887c7970 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -812,6 +812,7 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens
return (hidden_states + pos_embed).to(hidden_states.dtype)
+
class CogView4PatchEmbed(nn.Module):
def __init__(
self,
@@ -832,55 +833,35 @@ def __init__(
# Linear projection for text embeddings
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
- #TODO:这里需要改成RotaryEmbed
- pos_embed = get_2d_sincos_pos_embed(
- hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
- )
- pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
- self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
- def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ def forward(
+ self, hidden_states: torch.Tensor, prompt_embeds: torch.Tensor, negative_prompt_embeds: torch.Tensor | None
+ ) -> torch.Tensor:
batch_size, channel, height, width = hidden_states.shape
if height % self.patch_size != 0 or width % self.patch_size != 0:
raise ValueError("Height and width must be divisible by patch size")
- height = height // self.patch_size
- width = width // self.patch_size
- hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
- hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
- hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
+ patch_height = height // self.patch_size
+ patch_width = width // self.patch_size
- # Project the patches
- hidden_states = self.proj(hidden_states)
- prompt_encoder_hidden_states = []
- negative_prompt_encoder_hidden_states = []
-
- for i in range(0, batch_size, 2):
- prompt_embeds = encoder_hidden_states[i, :, :] # [seq_len, hidden_size]
- negative_embeds = encoder_hidden_states[i + 1, :, :] # [seq_len, hidden_size]
- mask = negative_embeds.abs().sum(dim=-1) > 0
- seq_len_neg = mask.sum().item() # 非零部分的数量
- negative_embeds_valid = negative_embeds[:seq_len_neg, :] # [seq_len_neg, hidden_size]
- prompt_encoder_hidden_states.append(prompt_embeds)
- negative_prompt_encoder_hidden_states.append(negative_embeds_valid)
- prompt_encoder_hidden_states = torch.stack(prompt_encoder_hidden_states, dim=0)
- negative_prompt_encoder_hidden_states = torch.stack(negative_prompt_encoder_hidden_states, dim=0)
- prompt_text_length = prompt_encoder_hidden_states.shape[1]
- negative_prompt_text_length = negative_prompt_encoder_hidden_states.shape[1]
- image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
- prompt_text_pos_embed = torch.zeros(
- (prompt_text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+ # b, c, h, w -> b, c, patch_height, patch_size, patch_width, patch_size
+ # -> b, patch_height, patch_width, c, patch_size, patch_size
+ # -> b, patch_height * patch_width, c * patch_size * patch_size
+ hidden_states = (
+ hidden_states.reshape(batch_size, channel, patch_height, self.patch_size, patch_width, self.patch_size)
+ .permute(0, 2, 4, 1, 3, 5)
+ .reshape(batch_size, patch_height * patch_width, channel * self.patch_size * self.patch_size)
)
- negative_prompt_text_pos_embed = torch.zeros(
- (negative_prompt_text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
- )
- prompt_pos_embed = torch.cat([prompt_text_pos_embed, image_pos_embed], dim=0)[None, ...]
- negative_prompt_pos_embed = torch.cat([negative_prompt_text_pos_embed, image_pos_embed], dim=0)[None, ...]
- # TODO: 拼接哼一个完整的 pos_embed 以及拼接 Rope Embed
- pos_embed = torch.cat([prompt_pos_embed, negative_prompt_pos_embed], dim=0)
- hidden_states = hidden_states + pos_embed.to(hidden_states.dtype)
- return hidden_states
+
+ # project
+ hidden_states = self.proj(hidden_states) # embed_dim: 64 -> 4096
+ prompt_embeds = self.text_proj(prompt_embeds) # embed_dim: 4096 -> 4096
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = self.text_proj(negative_prompt_embeds) # embed_dim: 4096 -> 4096
+
+ return hidden_states, prompt_embeds, negative_prompt_embeds
+
def get_3d_rotary_pos_embed(
embed_dim,
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index fb62ac88f974..d59faf30613e 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -84,6 +84,7 @@ def forward(
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
emb: torch.Tensor,
+ **kwargs,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
@@ -103,7 +104,7 @@ def forward(
# attention
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
- hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **kwargs
)
hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
@@ -191,14 +192,15 @@ def __init__(
# Each of these are sincos embeddings of shape 2 * condition_dim
self.pooled_projection_dim = 3 * 2 * condition_dim
- # self.patch_embed = CogView3PlusPatchEmbed(
- # in_channels=in_channels,
- # hidden_size=self.inner_dim,
- # patch_size=patch_size,
- # text_hidden_size=text_embed_dim,
- # pos_embed_max_size=pos_embed_max_size,
- # )
- # TODO: 兼容性适配
+ self.max_h = 256
+ self.max_w = 256
+ self.rope = self.prepare_rope(
+ embed_dim=self.config.attention_head_dim,
+ max_h=self.max_h,
+ max_w=self.max_w,
+ rotary_base=10000
+ )
+
self.patch_embed = CogView4PatchEmbed(
in_channels=in_channels,
hidden_size=self.inner_dim,
@@ -300,10 +302,55 @@ def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
+ @staticmethod
+ def prepare_rope(embed_dim, max_h, max_w, rotary_base):
+ dim_h = embed_dim // 2
+ dim_w = embed_dim // 2
+ h_inv_freq = 1.0 / (
+ rotary_base ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
+ )
+ w_inv_freq = 1.0 / (
+ rotary_base ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
+ )
+ h_seq = torch.arange(max_h, dtype=h_inv_freq.dtype)
+ w_seq = torch.arange(max_w, dtype=w_inv_freq.dtype)
+ freqs_h = torch.outer(h_seq, h_inv_freq)
+ freqs_w = torch.outer(w_seq, w_inv_freq)
+ return (freqs_h, freqs_w)
+
+ def get_rope_embedding(self, height, width, target_h, target_w, device):
+ # Get pre-computed frequencies
+ freqs_h, freqs_w = self.rope
+
+ h_idx = torch.arange(height)
+ w_idx = torch.arange(width)
+ inner_h_idx = (h_idx * self.max_h) // target_h
+ inner_w_idx = (w_idx * self.max_w) // target_w
+
+ freqs_h = freqs_h[inner_h_idx].to(device)
+ freqs_w = freqs_w[inner_w_idx].to(device)
+
+ # Create position matrices for height and width
+ # [height, 1, dim//4] and [1, width, dim//4]
+ freqs_h = freqs_h.unsqueeze(1)
+ freqs_w = freqs_w.unsqueeze(0)
+ # Broadcast freqs_h and freqs_w to [height, width, dim//4]
+ freqs_h = freqs_h.expand(height, width, -1)
+ freqs_w = freqs_w.expand(height, width, -1)
+
+ # Concatenate along last dimension to get [height, width, dim//2]
+ freqs = torch.cat([freqs_h, freqs_w], dim=-1)
+
+ freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim]
+ freqs = freqs.reshape(height*width, -1)
+
+ return freqs.cos(), freqs.sin()
+
def forward(
self,
hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
+ prompt_embeds: torch.Tensor,
+ negative_prompt_embeds: torch.Tensor | None,
timestep: torch.LongTensor,
original_size: torch.Tensor,
target_size: torch.Tensor,
@@ -338,16 +385,27 @@ def forward(
`torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
The denoised latents using provided inputs as conditioning.
"""
- height, width = hidden_states.shape[-2:]
- text_seq_length = encoder_hidden_states.shape[1]
+ batch_size, channel, height, width = hidden_states.shape
+ patch_height, patch_width = height // self.config.patch_size, width // self.config.patch_size
+ do_cfg = negative_prompt_embeds is not None
- hidden_states = self.patch_embed(
- hidden_states, encoder_hidden_states
- ) # takes care of adding positional embeddings too.
- emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
+ if do_cfg:
+ assert batch_size == prompt_embeds.shape[0] + negative_prompt_embeds.shape[0], "batch size mismatch in CFG mode"
+ else:
+ assert batch_size == prompt_embeds.shape[0], "batch size mismatch in non-CFG mode"
+
+ hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
+ hidden_states, prompt_embeds, negative_prompt_embeds
+ )
- encoder_hidden_states = hidden_states[:, :text_seq_length]
- hidden_states = hidden_states[:, text_seq_length:]
+ encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
+
+ # prepare image_rotary__emb
+ image_rotary_emb = self.get_rope_embedding(
+ patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
+ )
+
+ emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -363,7 +421,8 @@ def custom_forward(*inputs):
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
- emb,
+ emb=emb,
+ image_rotary_emb=image_rotary_emb,
**ckpt_kwargs,
)
else:
@@ -371,9 +430,10 @@ def custom_forward(*inputs):
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
emb=emb,
+ image_rotary_emb=image_rotary_emb,
)
- hidden_states = self.norm_out(hidden_states, emb)
+ hidden_states = self.norm_out(hidden_states, emb) # 结果对应于megatron里的final_layer_input
hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels)
# unpatchify
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 20ddca725510..1b68c8f38412 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -14,9 +14,9 @@
# limitations under the License.
import inspect
+import math
from typing import Callable, Dict, List, Optional, Tuple, Union
-import math
import torch
from transformers import GlmModel
@@ -29,6 +29,7 @@
from ...utils.torch_utils import randn_tensor
from .pipeline_output import CogView4PipelineOutput
+
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
@@ -67,63 +68,63 @@ def time_shift(mu: float, shift_sigma: float, sigmas: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** shift_sigma)
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- sigmas: Optional[List[float]] = None,
- **kwargs,
-):
- r"""
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
- must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
- `num_inference_steps` and `sigmas` must be `None`.
- sigmas (`List[float]`, *optional*):
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
- `num_inference_steps` and `timesteps` must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None and sigmas is not None:
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- elif sigmas is not None:
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accept_sigmas:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" sigmas schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
+# def retrieve_timesteps(
+# scheduler,
+# num_inference_steps: Optional[int] = None,
+# device: Optional[Union[str, torch.device]] = None,
+# timesteps: Optional[List[int]] = None,
+# sigmas: Optional[List[float]] = None,
+# **kwargs,
+# ):
+# r"""
+# Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+# custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+# Args:
+# scheduler (`SchedulerMixin`):
+# The scheduler to get timesteps from.
+# num_inference_steps (`int`):
+# The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+# must be `None`.
+# device (`str` or `torch.device`, *optional*):
+# The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+# timesteps (`List[int]`, *optional*):
+# Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+# `num_inference_steps` and `sigmas` must be `None`.
+# sigmas (`List[float]`, *optional*):
+# Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+# `num_inference_steps` and `timesteps` must be `None`.
+
+# Returns:
+# `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+# second element is the number of inference steps.
+# """
+# if timesteps is not None and sigmas is not None:
+# raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+# if timesteps is not None:
+# accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+# if not accepts_timesteps:
+# raise ValueError(
+# f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+# f" timestep schedules. Please check whether you are using the correct scheduler."
+# )
+# scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+# timesteps = scheduler.timesteps
+# num_inference_steps = len(timesteps)
+# elif sigmas is not None:
+# accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+# if not accept_sigmas:
+# raise ValueError(
+# f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+# f" sigmas schedules. Please check whether you are using the correct scheduler."
+# )
+# scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+# timesteps = scheduler.timesteps
+# num_inference_steps = len(timesteps)
+# else:
+# scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+# timesteps = scheduler.timesteps
+# return timesteps, num_inference_steps
class CogView4Pipeline(DiffusionPipeline):
@@ -172,6 +173,7 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_factor = 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -210,7 +212,7 @@ def _get_glm_embeds(
if pad_length > 0:
pad_ids = torch.full(
(text_input_ids.shape[0], pad_length),
- fill_value=151329, # <|endoftext|> of glm-4
+ fill_value=self.tokenizer.pad_token_id,
dtype=text_input_ids.dtype,
device=text_input_ids.device,
)
@@ -312,23 +314,23 @@ def encode_prompt(
dtype=dtype,
)
- #TODO: 先pad 0 ,后续再处理不同长度的问题
+ # TODO: 先pad 0 ,后续再处理不同长度的问题 (lhy: 这里改为pad padding token试试)
seq_len_prompt = prompt_embeds.shape[1]
seq_len_neg = negative_prompt_embeds.shape[1]
if seq_len_neg < seq_len_prompt:
- # 创建一个新的张量,大小为 [batch_size, seq_len_prompt, hidden_size]
- batch_size = negative_prompt_embeds.shape[0]
- hidden_size = negative_prompt_embeds.shape[2]
- # 填充后的张量
- padded_negative_prompt_embeds = torch.zeros(
- batch_size,
- seq_len_prompt,
- hidden_size,
- dtype=negative_prompt_embeds.dtype,
- device=negative_prompt_embeds.device
- )
- padded_negative_prompt_embeds[:, :seq_len_neg, :] = negative_prompt_embeds
- negative_prompt_embeds = padded_negative_prompt_embeds
+ # 创建一个新的张量,大小为 [batch_size, seq_len_prompt, hidden_size]
+ batch_size, seq_len, hidden_size = negative_prompt_embeds.shape
+ # 填充后的张量
+ padded_negative_prompt = torch.full(
+ (batch_size, seq_len_prompt - seq_len_neg),
+ fill_value=self.tokenizer.pad_token_id,
+ device=negative_prompt_embeds.device,
+ )
+ padded_negative_prompt_embeds = self.text_encoder.model.embed_tokens(
+ padded_negative_prompt.to(self.text_encoder.model.device)
+ )
+ negative_prompt_embeds = torch.cat([padded_negative_prompt_embeds, negative_prompt_embeds], dim=1)
+ assert negative_prompt_embeds.shape == prompt_embeds.shape
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
@@ -382,8 +384,15 @@ def check_inputs(
prompt_embeds=None,
negative_prompt_embeds=None,
):
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if height % self.image_factor != 0 or width % self.image_factor != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.image_factor} but are {height} and {width}."
+ )
+
+ if height < 512 or height > 2048 or width < 512 or width > 2048:
+ raise ValueError(
+ f"`height` and `width` must be between 512 and 2048, but got height={height} and width={width}."
+ )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -561,7 +570,7 @@ def __call__(
original_size = original_size or (height, width)
target_size = (height, width)
- # 1. Check inputs. Raise error if not correct
+ # Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
@@ -574,7 +583,7 @@ def __call__(
self._guidance_scale = guidance_scale
self._interrupt = False
- # 2. Default call parameters
+ # Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -587,34 +596,20 @@ def __call__(
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
- do_classifier_free_guidance = guidance_scale > 1.0
- # 3. Encode input prompt
+ do_classifier_free_guidance = self.do_classifier_free_guidance
+ # Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
- if self.do_classifier_free_guidance:
- prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
-
- # 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
- image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
- self.transformer.config.patch_size**2
- )
- mu = calculate_shift(image_seq_len)
- sigmas = timesteps / self.scheduler.config.num_train_timesteps
- sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) # Append zero at the end
- self.sigmas = time_shift(mu, 1.0, sigmas).to(torch.long).to("cpu") # This is for noisy control of cogview4
- self._num_timesteps = len(timesteps)
-
- # 5. Prepare latents.
+ # Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
@@ -627,15 +622,12 @@ def __call__(
latents,
)
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
- # 7. Prepare additional timestep conditions
+ # Prepare additional timestep conditions
original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype)
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype)
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype)
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
original_size = torch.cat([original_size, original_size])
target_size = torch.cat([target_size, target_size])
crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left])
@@ -644,50 +636,63 @@ def __call__(
target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1)
- # 8. Denoising loop
+ # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device) # 记得确认把scheduler.config的timestep_spacing是linspace
+ timesteps = self.scheduler.timesteps
+ image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
+ self.transformer.config.patch_size**2
+ )
+ mu = calculate_shift(image_seq_len)
+ sigmas = timesteps / self.scheduler.config.num_train_timesteps
+ sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) # Append zero at the end
+
+ self.sigmas = time_shift(mu, 1.0, sigmas).to("cpu") # This is for noisy control of cogview4
+ self._num_timesteps = len(timesteps)
+
+ # Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
- # for DPM-solver++
- old_pred_original_sample = None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ timestep = t.reshape((1,))
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ timestep = torch.cat([timestep] * 2) if do_classifier_free_guidance else t
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
- # Use sigma instead of timestep directly
- sigma = self.sigmas[i] # Get the corresponding sigma value
- timestep = sigma.expand(latent_model_input.shape[0]).to(device) # Use sigma to scale the timestep
# predict noise model_output using sigma
noise_pred = self.transformer(
hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
timestep=timestep, # Pass sigma as timestep for noise prediction
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
return_dict=False,
)[0]
+
noise_pred = noise_pred.float()
# perform guidance
- if self.do_classifier_free_guidance:
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
- # compute the previous noisy sample x_t -> x_t-1 using sigma (not timestep)
- if not isinstance(self.scheduler, CogView4DDIMScheduler):
- latents = self.scheduler.step(noise_pred, sigma, latents, **extra_step_kwargs, return_dict=False)[
- 0
- ]
- else:
- latents, old_pred_original_sample = self.scheduler.step(
- model_output=noise_pred,
- timestep=sigma, # Use sigma here as timestep
- sample=latents,
- **extra_step_kwargs,
- return_dict=False,
- )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ noise_pred_guided = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+ ###########################
+ # Get the corresponding sigma value
+ # 这一部分应该放到schduler中(包括self.sigmas的计算也是)
+ # 最后应该调用self.scheduler.step(),只需要传入当前的t,返回下一步的latents即可
+ sigma = self.sigmas[i]
+ sigma_next = self.sigmas[i + 1]
+ dt = sigma_next - sigma
+
+ latents = latents + dt * noise_pred_guided
+ ##############################
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
diff --git a/src/diffusers/schedulers/scheduling_ddim_cogview4.py b/src/diffusers/schedulers/scheduling_ddim_cogview4.py
index 72ab2c12c454..012b43dbbad2 100644
--- a/src/diffusers/schedulers/scheduling_ddim_cogview4.py
+++ b/src/diffusers/schedulers/scheduling_ddim_cogview4.py
@@ -197,7 +197,7 @@ def __init__(
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
- timestep_spacing: str = "leading",
+ timestep_spacing: str = "linspace",
rescale_betas_zero_snr: bool = False,
shift_scale: int = 1.0,
):
From f608f82a8b0a5dabdd7a77059941cb27b3cd9591 Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Thu, 23 Jan 2025 07:27:41 +0000
Subject: [PATCH 13/68] [WIP][cogview4][refactor]: Split condition/uncondition
forward pass in CogView4 pipeline
Split the forward pass for conditional and unconditional predictions in the CogView4 pipeline to match the original implementation. The noise prediction is now done separately for each case before combining them for guidance. However, the results still need improvement.
This is a work in progress as the generated images are not yet matching expected quality.
---
.../transformers/transformer_cogview3plus.py | 65 ++++++++++---------
.../pipelines/cogview4/pipeline_cogview4.py | 19 +-----
2 files changed, 35 insertions(+), 49 deletions(-)
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index d59faf30613e..3b6c2bb8e55d 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -397,59 +397,62 @@ def forward(
hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
hidden_states, prompt_embeds, negative_prompt_embeds
)
+ emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
- encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
+ encoder_hidden_states_cond = prompt_embeds
+ encoder_hidden_states_uncond = negative_prompt_embeds
+ hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
+ emb_cond, emb_uncond = emb.chunk(2)
# prepare image_rotary__emb
image_rotary_emb = self.get_rope_embedding(
patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
)
- emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
-
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
- hidden_states,
- encoder_hidden_states,
- emb=emb,
+ ...
+ else:
+ hidden_states_cond, encoder_hidden_states_cond = block(
+ hidden_states=hidden_states_cond,
+ encoder_hidden_states=encoder_hidden_states_cond,
+ emb=emb_cond, # refactor later
image_rotary_emb=image_rotary_emb,
- **ckpt_kwargs,
)
- else:
- hidden_states, encoder_hidden_states = block(
- hidden_states=hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- emb=emb,
+ hidden_states_uncond, encoder_hidden_states_uncond = block(
+ hidden_states=hidden_states_uncond,
+ encoder_hidden_states=encoder_hidden_states_uncond,
+ emb=emb_uncond, # refactor later
image_rotary_emb=image_rotary_emb,
)
- hidden_states = self.norm_out(hidden_states, emb) # 结果对应于megatron里的final_layer_input
- hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels)
+ hidden_states_cond = self.norm_out(hidden_states_cond, emb) # 结果对应于megatron里的final_layer_input
+ hidden_states_uncond = self.norm_out(hidden_states_uncond, emb) # 结果对应于megatron里的final_layer_input
+ hidden_states_cond = self.proj_out(hidden_states_cond) # (batch_size, height*width, patch_size*patch_size*out_channels)
+ hidden_states_uncond = self.proj_out(hidden_states_uncond) # (batch_size, height*width, patch_size*patch_size*out_channels)
# unpatchify
patch_size = self.config.patch_size
height = height // patch_size
width = width // patch_size
- hidden_states = hidden_states.reshape(
- shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
+ hidden_states_cond = hidden_states_cond.reshape(
+ shape=(hidden_states_cond.shape[0], height, width, self.out_channels, patch_size, patch_size)
+ )
+ hidden_states_cond = torch.einsum("nhwcpq->nchpwq", hidden_states_cond)
+ output_cond = hidden_states_cond.reshape(
+ shape=(hidden_states_cond.shape[0], self.out_channels, height * patch_size, width * patch_size)
+ )
+
+ hidden_states_uncond = hidden_states_uncond.reshape(
+ shape=(hidden_states_uncond.shape[0], height, width, self.out_channels, patch_size, patch_size)
)
- hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
- output = hidden_states.reshape(
- shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+ hidden_states_uncond = torch.einsum("nhwcpq->nchpwq", hidden_states_uncond)
+ output_uncond = hidden_states_uncond.reshape(
+ shape=(hidden_states_uncond.shape[0], self.out_channels, height * patch_size, width * patch_size)
)
if not return_dict:
- return (output,)
+ return (output_cond, output_uncond)
- return Transformer2DModelOutput(sample=output)
+ return Transformer2DModelOutput(sample=output_cond), Transformer2DModelOutput(sample=output_uncond)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 1b68c8f38412..ce68d124fe72 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -314,23 +314,6 @@ def encode_prompt(
dtype=dtype,
)
- # TODO: 先pad 0 ,后续再处理不同长度的问题 (lhy: 这里改为pad padding token试试)
- seq_len_prompt = prompt_embeds.shape[1]
- seq_len_neg = negative_prompt_embeds.shape[1]
- if seq_len_neg < seq_len_prompt:
- # 创建一个新的张量,大小为 [batch_size, seq_len_prompt, hidden_size]
- batch_size, seq_len, hidden_size = negative_prompt_embeds.shape
- # 填充后的张量
- padded_negative_prompt = torch.full(
- (batch_size, seq_len_prompt - seq_len_neg),
- fill_value=self.tokenizer.pad_token_id,
- device=negative_prompt_embeds.device,
- )
- padded_negative_prompt_embeds = self.text_encoder.model.embed_tokens(
- padded_negative_prompt.to(self.text_encoder.model.device)
- )
- negative_prompt_embeds = torch.cat([padded_negative_prompt_embeds, negative_prompt_embeds], dim=1)
- assert negative_prompt_embeds.shape == prompt_embeds.shape
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
@@ -680,7 +663,7 @@ def __call__(
# perform guidance
if do_classifier_free_guidance:
- noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ noise_pred_cond, noise_pred_uncond = noise_pred
noise_pred_guided = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
###########################
From b86bfd4cbe9bfe4e37b478ff5349947565540edc Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Thu, 23 Jan 2025 22:10:02 +0800
Subject: [PATCH 14/68] use with -2 hidden state
---
src/diffusers/pipelines/cogview4/pipeline_cogview4.py | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index ce68d124fe72..64bd64f379bb 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -174,7 +174,7 @@ def __init__(
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_factor = 16
-
+ self.text_projector = torch.nn.Linear(4096, 4096)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _get_glm_embeds(
@@ -217,10 +217,12 @@ def _get_glm_embeds(
device=text_input_ids.device,
)
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
-
- prompt_embeds = self.text_encoder.model.embed_tokens(text_input_ids.to(self.text_encoder.model.device))[0]
+ prompt_embeds = self.text_encoder(text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True).hidden_states[-2]
+ self.text_projector.to(dtype=dtype, device=device)
+ prompt_embeds = self.text_projector(prompt_embeds)
+ breakpoint()
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
- seq_len, _ = prompt_embeds.shape
+ _, seq_len, _= prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
From c4d1e69cbb3bef92703c7b8f603d5973aaf236ce Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Thu, 23 Jan 2025 22:17:43 +0800
Subject: [PATCH 15/68] remove text_projector
---
src/diffusers/models/embeddings.py | 2 +-
src/diffusers/pipelines/cogview4/pipeline_cogview4.py | 4 ----
2 files changed, 1 insertion(+), 5 deletions(-)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 7c40887c7970..a3a058c63fee 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -859,7 +859,7 @@ def forward(
prompt_embeds = self.text_proj(prompt_embeds) # embed_dim: 4096 -> 4096
if negative_prompt_embeds is not None:
negative_prompt_embeds = self.text_proj(negative_prompt_embeds) # embed_dim: 4096 -> 4096
-
+ breakpoint()
return hidden_states, prompt_embeds, negative_prompt_embeds
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 64bd64f379bb..713411564544 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -174,7 +174,6 @@ def __init__(
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_factor = 16
- self.text_projector = torch.nn.Linear(4096, 4096)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _get_glm_embeds(
@@ -218,9 +217,6 @@ def _get_glm_embeds(
)
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
prompt_embeds = self.text_encoder(text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True).hidden_states[-2]
- self.text_projector.to(dtype=dtype, device=device)
- prompt_embeds = self.text_projector(prompt_embeds)
- breakpoint()
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _= prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
From 79161400555feef238ec4f25e16fe6d38ba2d58e Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Thu, 23 Jan 2025 22:17:51 +0800
Subject: [PATCH 16/68] 1
---
src/diffusers/models/embeddings.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index a3a058c63fee..084a34cb9c83 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -859,7 +859,6 @@ def forward(
prompt_embeds = self.text_proj(prompt_embeds) # embed_dim: 4096 -> 4096
if negative_prompt_embeds is not None:
negative_prompt_embeds = self.text_proj(negative_prompt_embeds) # embed_dim: 4096 -> 4096
- breakpoint()
return hidden_states, prompt_embeds, negative_prompt_embeds
From f8945ce71f08bbb33fcba05a4cd78f95ad951b5a Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Fri, 24 Jan 2025 11:08:12 +0000
Subject: [PATCH 17/68] [WIP] Add tensor-reload to align input from transformer
block
---
src/diffusers/models/normalization.py | 9 ++++
.../transformers/transformer_cogview3plus.py | 42 ++++++++++++++++---
.../pipelines/cogview4/pipeline_cogview4.py | 7 +++-
3 files changed, 51 insertions(+), 7 deletions(-)
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 7db4d3d17d2f..a298001b3569 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -333,9 +333,18 @@ def __init__(
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
+
+ ####################################
emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
+ # emb = self.linear(conditioning_embedding).to(x.dtype)
+ ####################################
+
scale, shift = torch.chunk(emb, 2, dim=1)
+
+ ############################
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ # x = x * (1 + scale)[:, None, :] + shift[:, None, :]
+ ############################
return x
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 3b6c2bb8e55d..3fb270bac7dc 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -232,7 +232,8 @@ def __init__(
embedding_dim=self.inner_dim,
conditioning_embedding_dim=time_embed_dim,
elementwise_affine=False,
- eps=1e-6,
+ # eps=1e-6,
+ eps=1e-5,
)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
@@ -399,8 +400,6 @@ def forward(
)
emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
- encoder_hidden_states_cond = prompt_embeds
- encoder_hidden_states_uncond = negative_prompt_embeds
hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
emb_cond, emb_uncond = emb.chunk(2)
@@ -409,6 +408,22 @@ def forward(
patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
)
+ ######################
+ # prompt_embeds = torch.load("/home/lhy/code/cogview/c_condition_embedding.pt")
+ # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/uc_condition_embedding.pt")
+ prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
+ negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_uncondition_16_32.pt")[None, ::]
+
+ hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")
+ hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")
+
+ emb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")
+ emb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")
+ ######################
+
+ encoder_hidden_states_cond = prompt_embeds
+ encoder_hidden_states_uncond = negative_prompt_embeds
+
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
...
@@ -418,16 +433,31 @@ def forward(
encoder_hidden_states=encoder_hidden_states_cond,
emb=emb_cond, # refactor later
image_rotary_emb=image_rotary_emb,
+ # image_rotary_emb=None,
)
+ ###########################
+ # hidden_states_cond, encoder_hidden_states_cond = (
+ # self.norm_out.norm(hidden_states_cond),
+ # self.norm_out.norm(encoder_hidden_states_cond),
+ # )
+ ###########################
+
hidden_states_uncond, encoder_hidden_states_uncond = block(
hidden_states=hidden_states_uncond,
encoder_hidden_states=encoder_hidden_states_uncond,
emb=emb_uncond, # refactor later
image_rotary_emb=image_rotary_emb,
+ # image_rotary_emb=None,
)
-
- hidden_states_cond = self.norm_out(hidden_states_cond, emb) # 结果对应于megatron里的final_layer_input
- hidden_states_uncond = self.norm_out(hidden_states_uncond, emb) # 结果对应于megatron里的final_layer_input
+ ###########################
+ # hidden_states_uncond, encoder_hidden_states_uncond = (
+ # self.norm_out.norm(hidden_states_uncond),
+ # self.norm_out.norm(encoder_hidden_states_uncond),
+ # )
+ ###########################
+
+ hidden_states_cond = self.norm_out(hidden_states_cond, emb_cond) # 结果对应于megatron里的final_layer_input
+ hidden_states_uncond = self.norm_out(hidden_states_uncond, emb_uncond) # 结果对应于megatron里的final_layer_input
hidden_states_cond = self.proj_out(hidden_states_cond) # (batch_size, height*width, patch_size*patch_size*out_channels)
hidden_states_uncond = self.proj_out(hidden_states_uncond) # (batch_size, height*width, patch_size*patch_size*out_channels)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 713411564544..447938e0b1c8 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -216,7 +216,9 @@ def _get_glm_embeds(
device=text_input_ids.device,
)
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
- prompt_embeds = self.text_encoder(text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True).hidden_states[-2]
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True
+ ).hidden_states[-2]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _= prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -592,6 +594,7 @@ def __call__(
# Prepare latents.
latent_channels = self.transformer.config.in_channels
+ #########################
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
@@ -602,6 +605,8 @@ def __call__(
generator,
latents,
)
+ latents = torch.ones_like(latents)
+ #########################
# Prepare additional timestep conditions
original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype)
From bf7f3225e8fc26195969f658345734db6df03a0e Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Sat, 25 Jan 2025 00:22:21 +0800
Subject: [PATCH 18/68] [WIP] for older glm
---
scripts/convert_cogview4_to_diffusers.py | 17 +++++++++++++----
src/diffusers/models/normalization.py | 9 ---------
2 files changed, 13 insertions(+), 13 deletions(-)
diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py
index 8ac3e5854c84..583608705ab5 100644
--- a/scripts/convert_cogview4_to_diffusers.py
+++ b/scripts/convert_cogview4_to_diffusers.py
@@ -8,7 +8,7 @@
python scripts/convert_cogview4_to_diffusers.py \
--transformer_checkpoint_path 'your path/cogview4_6b/1/mp_rank_00_model_states.pt' \
--vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
- --output_path "/raid/yiyi/CogBiew4-6B" \
+ --output_path "THUDM/CogView4-6B" \
--dtype "bf16"
Arguments:
@@ -209,12 +209,21 @@ def main(args):
if dtype is not None:
vae = vae.to(dtype=dtype)
- text_encoder_id = "THUDM/glm-4-9b-hf"
- tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
- text_encoder = GlmForCausalLM.from_pretrained(
+ # text_encoder_id = "THUDM/glm-4-9b-hf"
+ # tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
+ # text_encoder = GlmForCausalLM.from_pretrained(
+ # text_encoder_id,
+ # cache_dir=args.text_encoder_cache_dir,
+ # torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
+ # )
+ from transformers import AutoTokenizer,AutoModel
+ text_encoder_id = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/glm-4-9b"
+ tokenizer = AutoTokenizer.from_pretrained(text_encoder_id,trust_remote_code=True)
+ text_encoder = AutoModel.from_pretrained(
text_encoder_id,
cache_dir=args.text_encoder_cache_dir,
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
+ trust_remote_code = True
)
# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index a298001b3569..7db4d3d17d2f 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -333,18 +333,9 @@ def __init__(
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
-
- ####################################
emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
- # emb = self.linear(conditioning_embedding).to(x.dtype)
- ####################################
-
scale, shift = torch.chunk(emb, 2, dim=1)
-
- ############################
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
- # x = x * (1 + scale)[:, None, :] + shift[:, None, :]
- ############################
return x
From dd6568bf0b1b3226f9e39853cfd208b4498e21e6 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Sat, 25 Jan 2025 13:37:26 +0800
Subject: [PATCH 19/68] use with cogview4 transformers forward twice of u and
uc
---
docs/source/en/_toctree.yml | 2 +
.../en/api/models/cogview4_transformer2d.md | 30 ++
scripts/convert_cogview4_to_diffusers.py | 34 +-
src/diffusers/__init__.py | 2 +
src/diffusers/models/__init__.py | 2 +
src/diffusers/models/transformers/__init__.py | 1 +
.../transformers/transformer_cogview3plus.py | 184 ++-----
.../transformers/transformer_cogview4.py | 470 ++++++++++++++++++
.../pipelines/cogview4/pipeline_cogview4.py | 40 +-
src/diffusers/utils/dummy_pt_objects.py | 14 +
10 files changed, 600 insertions(+), 179 deletions(-)
create mode 100644 docs/source/en/api/models/cogview4_transformer2d.md
create mode 100644 src/diffusers/models/transformers/transformer_cogview4.py
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index fc3022cf7b35..09ab490783cb 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -276,6 +276,8 @@
title: ConsisIDTransformer3DModel
- local: api/models/cogview3plus_transformer2d
title: CogView3PlusTransformer2DModel
+ - local: api/models/cogview4_transformer2d
+ title: CogView4Transformer2DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
- local: api/models/flux_transformer
diff --git a/docs/source/en/api/models/cogview4_transformer2d.md b/docs/source/en/api/models/cogview4_transformer2d.md
new file mode 100644
index 000000000000..e6c976e64253
--- /dev/null
+++ b/docs/source/en/api/models/cogview4_transformer2d.md
@@ -0,0 +1,30 @@
+
+
+# CogView4Transformer2DModel
+
+A Diffusion Transformer model for 2D data from [CogView4]()
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import CogView3PlusTransformer2DModel
+
+transformer = CogView3PlusTransformer2DModel.from_pretrained("THUDM/CogView4-6B", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+```
+
+## CogView4Transformer2DModel
+
+[[autodoc]] CogView4Transformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py
index 583608705ab5..1371a16d2eec 100644
--- a/scripts/convert_cogview4_to_diffusers.py
+++ b/scripts/convert_cogview4_to_diffusers.py
@@ -31,7 +31,7 @@
from accelerate import init_empty_weights
from transformers import PreTrainedTokenizerFast, GlmForCausalLM
-from diffusers import AutoencoderKL, CogView4DDIMScheduler, CogView4Pipeline, CogView3PlusTransformer2DModel
+from diffusers import AutoencoderKL, CogView4DDIMScheduler, CogView4Pipeline, CogView4Transformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available
@@ -168,7 +168,7 @@ def main(args):
converted_transformer_state_dict = convert_cogview4_transformer_checkpoint_to_diffusers(
args.transformer_checkpoint_path
)
- transformer = CogView3PlusTransformer2DModel(
+ transformer = CogView4Transformer2DModel(
patch_size=2,
in_channels=16,
num_layers=28,
@@ -209,23 +209,27 @@ def main(args):
if dtype is not None:
vae = vae.to(dtype=dtype)
- # text_encoder_id = "THUDM/glm-4-9b-hf"
- # tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
- # text_encoder = GlmForCausalLM.from_pretrained(
- # text_encoder_id,
- # cache_dir=args.text_encoder_cache_dir,
- # torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
- # )
- from transformers import AutoTokenizer,AutoModel
- text_encoder_id = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/glm-4-9b"
- tokenizer = AutoTokenizer.from_pretrained(text_encoder_id,trust_remote_code=True)
- text_encoder = AutoModel.from_pretrained(
+ text_encoder_id = "/share/home/zyx/Models/glm-4-9b-hf"
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
+ text_encoder = GlmForCausalLM.from_pretrained(
text_encoder_id,
cache_dir=args.text_encoder_cache_dir,
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
- trust_remote_code = True
)
- # Apparently, the conversion does not work anymore without this :shrug:
+
+ # TODO: This is for Older GLM-4 as https://huggingface.co/THUDM/glm-4-9b, will use https://huggingface.co/THUDM/glm-4-9b-hf for new transformers version format.
+ # TODO: Remove it later
+
+ # from transformers import AutoTokenizer,AutoModel
+ # text_encoder_id = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/glm-4-9b"
+ # tokenizer = AutoTokenizer.from_pretrained(text_encoder_id,trust_remote_code=True)
+ # text_encoder = AutoModel.from_pretrained(
+ # text_encoder_id,
+ # cache_dir=args.text_encoder_cache_dir,
+ # torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
+ # trust_remote_code = True
+ # )
+
for param in text_encoder.parameters():
param.data = param.data.contiguous()
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 1463296728bd..5e9cec1728f5 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -92,6 +92,7 @@
"AutoencoderTiny",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
+ "CogView4Transformer2DModel",
"ConsisIDTransformer3DModel",
"ConsistencyDecoderVAE",
"ControlNetModel",
@@ -606,6 +607,7 @@
AutoencoderTiny,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
+ CogView4Transformer2DModel,
ConsisIDTransformer3DModel,
ConsistencyDecoderVAE,
ControlNetModel,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index e3f291ce2dc7..fdb206a6e674 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -68,6 +68,7 @@
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
+ _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
@@ -130,6 +131,7 @@
AuraFlowTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
+ CogView4Transformer2DModel,
ConsisIDTransformer3DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 77e1698b8fc2..04716b9ed39e 100644
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -18,6 +18,7 @@
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
+ from .transformer_cogview4 import CogView4Transformer2DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 3fb270bac7dc..0376cc2fd70d 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -28,7 +28,7 @@
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import is_torch_version, logging
-from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed, CogView4PatchEmbed
+from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
@@ -84,7 +84,6 @@ def forward(
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
emb: torch.Tensor,
- **kwargs,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
@@ -104,7 +103,7 @@ def forward(
# attention
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
- hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **kwargs
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
)
hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
@@ -167,7 +166,8 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
- _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed", "CogView4PlusPatchEmbed"]
+ _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
+ _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
@register_to_config
def __init__(
@@ -192,16 +192,7 @@ def __init__(
# Each of these are sincos embeddings of shape 2 * condition_dim
self.pooled_projection_dim = 3 * 2 * condition_dim
- self.max_h = 256
- self.max_w = 256
- self.rope = self.prepare_rope(
- embed_dim=self.config.attention_head_dim,
- max_h=self.max_h,
- max_w=self.max_w,
- rotary_base=10000
- )
-
- self.patch_embed = CogView4PatchEmbed(
+ self.patch_embed = CogView3PlusPatchEmbed(
in_channels=in_channels,
hidden_size=self.inner_dim,
patch_size=patch_size,
@@ -232,8 +223,7 @@ def __init__(
embedding_dim=self.inner_dim,
conditioning_embedding_dim=time_embed_dim,
elementwise_affine=False,
- # eps=1e-6,
- eps=1e-5,
+ eps=1e-6,
)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
@@ -303,55 +293,10 @@ def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
- @staticmethod
- def prepare_rope(embed_dim, max_h, max_w, rotary_base):
- dim_h = embed_dim // 2
- dim_w = embed_dim // 2
- h_inv_freq = 1.0 / (
- rotary_base ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
- )
- w_inv_freq = 1.0 / (
- rotary_base ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
- )
- h_seq = torch.arange(max_h, dtype=h_inv_freq.dtype)
- w_seq = torch.arange(max_w, dtype=w_inv_freq.dtype)
- freqs_h = torch.outer(h_seq, h_inv_freq)
- freqs_w = torch.outer(w_seq, w_inv_freq)
- return (freqs_h, freqs_w)
-
- def get_rope_embedding(self, height, width, target_h, target_w, device):
- # Get pre-computed frequencies
- freqs_h, freqs_w = self.rope
-
- h_idx = torch.arange(height)
- w_idx = torch.arange(width)
- inner_h_idx = (h_idx * self.max_h) // target_h
- inner_w_idx = (w_idx * self.max_w) // target_w
-
- freqs_h = freqs_h[inner_h_idx].to(device)
- freqs_w = freqs_w[inner_w_idx].to(device)
-
- # Create position matrices for height and width
- # [height, 1, dim//4] and [1, width, dim//4]
- freqs_h = freqs_h.unsqueeze(1)
- freqs_w = freqs_w.unsqueeze(0)
- # Broadcast freqs_h and freqs_w to [height, width, dim//4]
- freqs_h = freqs_h.expand(height, width, -1)
- freqs_w = freqs_w.expand(height, width, -1)
-
- # Concatenate along last dimension to get [height, width, dim//2]
- freqs = torch.cat([freqs_h, freqs_w], dim=-1)
-
- freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim]
- freqs = freqs.reshape(height*width, -1)
-
- return freqs.cos(), freqs.sin()
-
def forward(
self,
hidden_states: torch.Tensor,
- prompt_embeds: torch.Tensor,
- negative_prompt_embeds: torch.Tensor | None,
+ encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
original_size: torch.Tensor,
target_size: torch.Tensor,
@@ -386,103 +331,58 @@ def forward(
`torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
The denoised latents using provided inputs as conditioning.
"""
- batch_size, channel, height, width = hidden_states.shape
- patch_height, patch_width = height // self.config.patch_size, width // self.config.patch_size
- do_cfg = negative_prompt_embeds is not None
-
- if do_cfg:
- assert batch_size == prompt_embeds.shape[0] + negative_prompt_embeds.shape[0], "batch size mismatch in CFG mode"
- else:
- assert batch_size == prompt_embeds.shape[0], "batch size mismatch in non-CFG mode"
+ height, width = hidden_states.shape[-2:]
+ text_seq_length = encoder_hidden_states.shape[1]
- hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
- hidden_states, prompt_embeds, negative_prompt_embeds
- )
+ hidden_states = self.patch_embed(
+ hidden_states, encoder_hidden_states
+ ) # takes care of adding positional embeddings too.
emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
- hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
- emb_cond, emb_uncond = emb.chunk(2)
-
- # prepare image_rotary__emb
- image_rotary_emb = self.get_rope_embedding(
- patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
- )
-
- ######################
- # prompt_embeds = torch.load("/home/lhy/code/cogview/c_condition_embedding.pt")
- # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/uc_condition_embedding.pt")
- prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
- negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_uncondition_16_32.pt")[None, ::]
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
- hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")
- hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")
+ for index_block, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
- emb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")
- emb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")
- ######################
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
- encoder_hidden_states_cond = prompt_embeds
- encoder_hidden_states_uncond = negative_prompt_embeds
+ return custom_forward
- for index_block, block in enumerate(self.transformer_blocks):
- if torch.is_grad_enabled() and self.gradient_checkpointing:
- ...
- else:
- hidden_states_cond, encoder_hidden_states_cond = block(
- hidden_states=hidden_states_cond,
- encoder_hidden_states=encoder_hidden_states_cond,
- emb=emb_cond, # refactor later
- image_rotary_emb=image_rotary_emb,
- # image_rotary_emb=None,
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ **ckpt_kwargs,
)
- ###########################
- # hidden_states_cond, encoder_hidden_states_cond = (
- # self.norm_out.norm(hidden_states_cond),
- # self.norm_out.norm(encoder_hidden_states_cond),
- # )
- ###########################
-
- hidden_states_uncond, encoder_hidden_states_uncond = block(
- hidden_states=hidden_states_uncond,
- encoder_hidden_states=encoder_hidden_states_uncond,
- emb=emb_uncond, # refactor later
- image_rotary_emb=image_rotary_emb,
- # image_rotary_emb=None,
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=emb,
)
- ###########################
- # hidden_states_uncond, encoder_hidden_states_uncond = (
- # self.norm_out.norm(hidden_states_uncond),
- # self.norm_out.norm(encoder_hidden_states_uncond),
- # )
- ###########################
-
- hidden_states_cond = self.norm_out(hidden_states_cond, emb_cond) # 结果对应于megatron里的final_layer_input
- hidden_states_uncond = self.norm_out(hidden_states_uncond, emb_uncond) # 结果对应于megatron里的final_layer_input
- hidden_states_cond = self.proj_out(hidden_states_cond) # (batch_size, height*width, patch_size*patch_size*out_channels)
- hidden_states_uncond = self.proj_out(hidden_states_uncond) # (batch_size, height*width, patch_size*patch_size*out_channels)
+
+ hidden_states = self.norm_out(hidden_states, emb)
+ hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels)
# unpatchify
patch_size = self.config.patch_size
height = height // patch_size
width = width // patch_size
- hidden_states_cond = hidden_states_cond.reshape(
- shape=(hidden_states_cond.shape[0], height, width, self.out_channels, patch_size, patch_size)
- )
- hidden_states_cond = torch.einsum("nhwcpq->nchpwq", hidden_states_cond)
- output_cond = hidden_states_cond.reshape(
- shape=(hidden_states_cond.shape[0], self.out_channels, height * patch_size, width * patch_size)
- )
-
- hidden_states_uncond = hidden_states_uncond.reshape(
- shape=(hidden_states_uncond.shape[0], height, width, self.out_channels, patch_size, patch_size)
+ hidden_states = hidden_states.reshape(
+ shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
)
- hidden_states_uncond = torch.einsum("nhwcpq->nchpwq", hidden_states_uncond)
- output_uncond = hidden_states_uncond.reshape(
- shape=(hidden_states_uncond.shape[0], self.out_channels, height * patch_size, width * patch_size)
+ hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
)
if not return_dict:
- return (output_cond, output_uncond)
+ return (output,)
- return Transformer2DModelOutput(sample=output_cond), Transformer2DModelOutput(sample=output_uncond)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
new file mode 100644
index 000000000000..fe4adc2c09b7
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -0,0 +1,470 @@
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.attention import FeedForward
+from ...models.attention_processor import (
+ Attention,
+ AttentionProcessor,
+ CogVideoXAttnProcessor2_0,
+)
+from ...models.modeling_utils import ModelMixin
+from ...models.normalization import AdaLayerNormContinuous
+from ...utils import is_torch_version, logging
+from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView4PatchEmbed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class CogView4TransformerBlock(nn.Module):
+ r"""
+ Transformer block used in [CogView](https://github.com/THUDM/CogView3) model.
+
+ Args:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ """
+
+ def __init__(
+ self,
+ dim: int = 2560,
+ num_attention_heads: int = 64,
+ attention_head_dim: int = 40,
+ time_embed_dim: int = 512,
+ ):
+ super().__init__()
+
+ self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ out_dim=dim,
+ bias=True,
+ qk_norm="layer_norm",
+ elementwise_affine=False,
+ eps=1e-6,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ emb: torch.Tensor,
+ **kwargs,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
+ (
+ norm_hidden_states,
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ norm_encoder_hidden_states,
+ c_gate_msa,
+ c_shift_mlp,
+ c_scale_mlp,
+ c_gate_mlp,
+ ) = self.norm1(hidden_states, encoder_hidden_states, emb)
+
+ # attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **kwargs
+ )
+
+ hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
+
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+ return hidden_states, encoder_hidden_states
+
+class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
+ r"""
+ Args:
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ attention_head_dim (`int`, defaults to `40`):
+ The number of channels in each head.
+ num_attention_heads (`int`, defaults to `64`):
+ The number of heads to use for multi-head attention.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ condition_dim (`int`, defaults to `256`):
+ The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
+ crop_coords).
+ pos_embed_max_size (`int`, defaults to `128`):
+ The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
+ to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
+ means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
+ patch_size => 128 * 8 * 2 => 2048`.
+ sample_size (`int`, defaults to `128`):
+ The base resolution of input latents. If height/width is not provided during generation, this value is used
+ to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["CogView4TransformerBlock", "CogView4PatchEmbed", "CogView4PatchEmbed"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ num_layers: int = 30,
+ attention_head_dim: int = 40,
+ num_attention_heads: int = 64,
+ out_channels: int = 16,
+ text_embed_dim: int = 4096,
+ time_embed_dim: int = 512,
+ condition_dim: int = 256,
+ pos_embed_max_size: int = 128,
+ sample_size: int = 128,
+ ):
+ super().__init__()
+ self.out_channels = out_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
+ # Each of these are sincos embeddings of shape 2 * condition_dim
+ self.pooled_projection_dim = 3 * 2 * condition_dim
+
+ self.max_h = 256
+ self.max_w = 256
+ self.rope = self.prepare_rope(
+ embed_dim=self.config.attention_head_dim, max_h=self.max_h, max_w=self.max_w, rotary_base=10000
+ )
+
+ self.patch_embed = CogView4PatchEmbed(
+ in_channels=in_channels,
+ hidden_size=self.inner_dim,
+ patch_size=patch_size,
+ text_hidden_size=text_embed_dim,
+ pos_embed_max_size=pos_embed_max_size,
+ )
+
+ self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
+ embedding_dim=time_embed_dim,
+ condition_dim=condition_dim,
+ pooled_projection_dim=self.pooled_projection_dim,
+ timesteps_dim=self.inner_dim,
+ )
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogView4TransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.norm_out = AdaLayerNormContinuous(
+ embedding_dim=self.inner_dim,
+ conditioning_embedding_dim=time_embed_dim,
+ elementwise_affine=False,
+ # eps=1e-6, # For CogView4 is 1e-5
+ )
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ @staticmethod
+ def prepare_rope(embed_dim, max_h, max_w, rotary_base):
+ dim_h = embed_dim // 2
+ dim_w = embed_dim // 2
+ h_inv_freq = 1.0 / (
+ rotary_base ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
+ )
+ w_inv_freq = 1.0 / (
+ rotary_base ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
+ )
+ h_seq = torch.arange(max_h, dtype=h_inv_freq.dtype)
+ w_seq = torch.arange(max_w, dtype=w_inv_freq.dtype)
+ freqs_h = torch.outer(h_seq, h_inv_freq)
+ freqs_w = torch.outer(w_seq, w_inv_freq)
+ return (freqs_h, freqs_w)
+
+ def get_rope_embedding(self, height, width, target_h, target_w, device):
+ # Get pre-computed frequencies
+ freqs_h, freqs_w = self.rope
+
+ h_idx = torch.arange(height)
+ w_idx = torch.arange(width)
+ inner_h_idx = (h_idx * self.max_h) // target_h
+ inner_w_idx = (w_idx * self.max_w) // target_w
+
+ freqs_h = freqs_h[inner_h_idx].to(device)
+ freqs_w = freqs_w[inner_w_idx].to(device)
+
+ # Create position matrices for height and width
+ # [height, 1, dim//4] and [1, width, dim//4]
+ freqs_h = freqs_h.unsqueeze(1)
+ freqs_w = freqs_w.unsqueeze(0)
+ # Broadcast freqs_h and freqs_w to [height, width, dim//4]
+ freqs_h = freqs_h.expand(height, width, -1)
+ freqs_w = freqs_w.expand(height, width, -1)
+
+ # Concatenate along last dimension to get [height, width, dim//2]
+ freqs = torch.cat([freqs_h, freqs_w], dim=-1)
+
+ freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim]
+ freqs = freqs.reshape(height * width, -1)
+
+ return freqs.cos(), freqs.sin()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ prompt_embeds: torch.Tensor,
+ negative_prompt_embeds: torch.Tensor | None,
+ timestep: torch.LongTensor,
+ original_size: torch.Tensor,
+ target_size: torch.Tensor,
+ crop_coords: torch.Tensor,
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ The [`CogView3PlusTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input `hidden_states` of shape `(batch size, channel, height, width)`.
+ encoder_hidden_states (`torch.Tensor`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape
+ `(batch_size, sequence_len, text_embed_dim)`
+ timestep (`torch.LongTensor`):
+ Used to indicate denoising step.
+ original_size (`torch.Tensor`):
+ CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`torch.Tensor`):
+ CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crop_coords (`torch.Tensor`):
+ CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
+ The denoised latents using provided inputs as conditioning.
+ """
+ batch_size, channel, height, width = hidden_states.shape
+ patch_height, patch_width = height // self.config.patch_size, width // self.config.patch_size
+ do_cfg = negative_prompt_embeds is not None
+
+ if do_cfg:
+ assert batch_size == prompt_embeds.shape[0] + negative_prompt_embeds.shape[0], (
+ "batch size mismatch in CFG mode"
+ )
+ else:
+ assert batch_size == prompt_embeds.shape[0], "batch size mismatch in non-CFG mode"
+
+ # 1. RoPE
+ image_rotary_emb = self.get_rope_embedding(
+ patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
+ )
+
+ # 2. Conditional embeddings
+ temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
+ temb_cond, temb_uncond = temb.chunk(2)
+ hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
+ hidden_states, prompt_embeds, negative_prompt_embeds
+ )
+ hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
+ encoder_hidden_states_cond = prompt_embeds
+ encoder_hidden_states_uncond = negative_prompt_embeds
+
+ ######################
+ # prompt_embeds = torch.load("/home/lhy/code/cogview/c_condition_embedding.pt")
+ # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/uc_condition_embedding.pt")
+ # prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
+ # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_uncondition_16_32.pt")[None, ::]
+ #
+ # hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")
+ # hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")
+ #
+ # emb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")
+ # emb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")
+ ######################
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ # TODO 微调使用
+ ...
+ else:
+ hidden_states_cond, encoder_hidden_states_cond = block(
+ hidden_states=hidden_states_cond,
+ encoder_hidden_states=encoder_hidden_states_cond,
+ emb=temb_cond,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states_uncond, encoder_hidden_states_uncond = block(
+ hidden_states=hidden_states_uncond,
+ encoder_hidden_states=encoder_hidden_states_uncond,
+ emb=temb_uncond,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states_cond, encoder_hidden_states_cond = (
+ self.norm_out(hidden_states_cond, temb_cond),
+ self.norm_out(encoder_hidden_states_cond, temb_cond),
+ )
+ hidden_states_uncond, encoder_hidden_states_uncond = (
+ self.norm_out(hidden_states_uncond, temb_uncond),
+ self.norm_out(encoder_hidden_states_uncond, temb_uncond),
+ )
+ hidden_states_cond = self.proj_out(hidden_states_cond)
+ hidden_states_uncond = self.proj_out(hidden_states_uncond)
+
+ # unpatchify
+ patch_size = self.config.patch_size
+ height = height // patch_size
+ width = width // patch_size
+
+ hidden_states_cond = hidden_states_cond.reshape(
+ shape=(hidden_states_cond.shape[0], height, width, self.out_channels, patch_size, patch_size)
+ )
+ hidden_states_cond = torch.einsum("nhwcpq->nchpwq", hidden_states_cond)
+ output_cond = hidden_states_cond.reshape(
+ shape=(hidden_states_cond.shape[0], self.out_channels, height * patch_size, width * patch_size)
+ )
+
+ hidden_states_uncond = hidden_states_uncond.reshape(
+ shape=(hidden_states_uncond.shape[0], height, width, self.out_channels, patch_size, patch_size)
+ )
+ hidden_states_uncond = torch.einsum("nhwcpq->nchpwq", hidden_states_uncond)
+ output_uncond = hidden_states_uncond.reshape(
+ shape=(hidden_states_uncond.shape[0], self.out_channels, height * patch_size, width * patch_size)
+ )
+ if not return_dict:
+ return (output_cond, output_uncond)
+ return Transformer2DModelOutput(sample=output_cond), Transformer2DModelOutput(sample=output_uncond)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 447938e0b1c8..78a4a36ce82b 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -22,7 +22,7 @@
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
-from ...models import AutoencoderKL, CogView3PlusTransformer2DModel
+from ...models import AutoencoderKL, CogView4Transformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogView4DDIMScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -164,7 +164,7 @@ def __init__(
tokenizer: GlmModel,
text_encoder: GlmModel,
vae: AutoencoderKL,
- transformer: CogView3PlusTransformer2DModel,
+ transformer: CogView4Transformer2DModel,
scheduler: CogView4DDIMScheduler,
):
super().__init__()
@@ -219,8 +219,15 @@ def _get_glm_embeds(
prompt_embeds = self.text_encoder(
text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True
).hidden_states[-2]
+
+ # TODO: This is for Older GLM-4 as https://huggingface.co/THUDM/glm-4-9b, will use https://huggingface.co/THUDM/glm-4-9b-hf for new transformers version format.
+ # TODO: Remove it later
+ # prompt_embeds = self.text_encoder(
+ # text_input_ids.to(self.text_encoder.transformer.device), output_hidden_states=True
+ # ).hidden_states[-2]
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
- _, seq_len, _= prompt_embeds.shape
+ _, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
@@ -539,8 +546,8 @@ def __call__(
Examples:
Returns:
- [`~pipelines.cogview4.pipeline_CogView4.CogView3PipelineOutput`] or `tuple`:
- [`~pipelines.cogview4.pipeline_CogView4.CogView3PipelineOutput`] if `return_dict` is True, otherwise a
+ [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`:
+ [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
@@ -592,9 +599,8 @@ def __call__(
device=device,
)
- # Prepare latents.
+ # 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
- #########################
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
@@ -605,8 +611,6 @@ def __call__(
generator,
latents,
)
- latents = torch.ones_like(latents)
- #########################
# Prepare additional timestep conditions
original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype)
@@ -626,7 +630,7 @@ def __call__(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# Prepare timesteps
- self.scheduler.set_timesteps(num_inference_steps, device) # 记得确认把scheduler.config的timestep_spacing是linspace
+ self.scheduler.set_timesteps(num_inference_steps, device)
timesteps = self.scheduler.timesteps
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
self.transformer.config.patch_size**2
@@ -634,23 +638,19 @@ def __call__(
mu = calculate_shift(image_seq_len)
sigmas = timesteps / self.scheduler.config.num_train_timesteps
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) # Append zero at the end
-
- self.sigmas = time_shift(mu, 1.0, sigmas).to("cpu") # This is for noisy control of cogview4
+ self.sigmas = time_shift(mu, 1.0, sigmas).to("cpu")
self._num_timesteps = len(timesteps)
- # Denoising loop
+ # 6. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
timestep = t.reshape((1,))
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
timestep = torch.cat([timestep] * 2) if do_classifier_free_guidance else t
-
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
- # predict noise model_output using sigma
noise_pred = self.transformer(
hidden_states=latent_model_input,
prompt_embeds=prompt_embeds,
@@ -660,10 +660,7 @@ def __call__(
target_size=target_size,
crop_coords=crops_coords_top_left,
return_dict=False,
- )[0]
-
- noise_pred = noise_pred.float()
-
+ )
# perform guidance
if do_classifier_free_guidance:
noise_pred_cond, noise_pred_uncond = noise_pred
@@ -687,7 +684,6 @@ def __call__(
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, sigma, callback_kwargs)
-
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index d967e99d8370..68341674108a 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -226,6 +226,20 @@ def from_config(cls, *args, **kwargs):
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class CogView4Transformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
class ConsisIDTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
From 9e5b991c3ce934ad10be948e628b775e84888d8b Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Sat, 25 Jan 2025 14:06:49 +0800
Subject: [PATCH 20/68] Update convert_cogview4_to_diffusers.py
---
scripts/convert_cogview4_to_diffusers.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py
index 1371a16d2eec..4405a40fb761 100644
--- a/scripts/convert_cogview4_to_diffusers.py
+++ b/scripts/convert_cogview4_to_diffusers.py
@@ -209,7 +209,7 @@ def main(args):
if dtype is not None:
vae = vae.to(dtype=dtype)
- text_encoder_id = "/share/home/zyx/Models/glm-4-9b-hf"
+ text_encoder_id = "THUDM/glm-4-9b-hf"
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
text_encoder = GlmForCausalLM.from_pretrained(
text_encoder_id,
From 36b1682ec1f907c8b10395968bcb10292302b50f Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Sun, 26 Jan 2025 16:02:26 +0800
Subject: [PATCH 21/68] remove this
---
src/diffusers/models/transformers/transformer_cogview4.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index fe4adc2c09b7..2611cb4b5e88 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -224,7 +224,6 @@ def __init__(
embedding_dim=self.inner_dim,
conditioning_embedding_dim=time_embed_dim,
elementwise_affine=False,
- # eps=1e-6, # For CogView4 is 1e-5
)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
From 16c2397c5a8e0f24b5a317df08ce381e293eeba1 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Tue, 28 Jan 2025 20:30:41 +0800
Subject: [PATCH 22/68] use main example
---
examples/community/README.md | 92 +------------------
examples/community/matryoshka.py | 79 +++++++++++++++-
examples/dreambooth/README.md | 26 ------
.../dreambooth/train_dreambooth_lora_sana.py | 15 +--
.../pixart/controlnet_pixart_alpha.py | 20 +++-
.../text_to_image/train_text_to_image_lora.py | 4 +
6 files changed, 100 insertions(+), 136 deletions(-)
mode change 100644 => 100755 examples/community/README.md
diff --git a/examples/community/README.md b/examples/community/README.md
old mode 100644
new mode 100755
index 4c593a004893..c7c40c46ef2d
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -77,7 +77,6 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
-| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -4586,8 +4585,8 @@ image = pipe(
```
|  |  |  |
-| -------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ |
-| Gradient | Input | Output |
+| ------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- |
+| Gradient | Input | Output |
A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.
@@ -4635,93 +4634,6 @@ make_image_grid(image, rows=1, cols=len(image))
# 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
```
-### Stable Diffusion XL Attentive Eraser Pipeline
-
-
-**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. This pipeline uses **self-attention redirection guidance** to modify the model’s self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. If you are interested in more detailed information and have any questions, please refer to the [paper](https://arxiv.org/abs/2412.12974) and [official implementation](https://github.com/Anonym0u3/AttentiveEraser).
-
-#### Key features
-
-- **Tuning-Free**: No additional training is required, making it easy to integrate and use.
-- **Flexible Mask Support**: Works with different types of masks for targeted object removal.
-- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.
-
-#### Usage example
-To use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:
-```py
-import torch
-from diffusers import DDIMScheduler, DiffusionPipeline
-from diffusers.utils import load_image
-import torch.nn.functional as F
-from torchvision.transforms.functional import to_tensor, gaussian_blur
-
-dtype = torch.float16
-device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
-scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
-pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
- scheduler=scheduler,
- variant="fp16",
- use_safetensors=True,
- torch_dtype=dtype,
-).to(device)
-
-
-def preprocess_image(image_path, device):
- image = to_tensor((load_image(image_path)))
- image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
- if image.shape[1] != 3:
- image = image.expand(-1, 3, -1, -1)
- image = F.interpolate(image, (1024, 1024))
- image = image.to(dtype).to(device)
- return image
-
-def preprocess_mask(mask_path, device):
- mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
- mask = mask.unsqueeze_(0).float() # 0 or 1
- mask = F.interpolate(mask, (1024, 1024))
- mask = gaussian_blur(mask, kernel_size=(77, 77))
- mask[mask < 0.1] = 0
- mask[mask >= 0.1] = 1
- mask = mask.to(dtype).to(device)
- return mask
-
-prompt = "" # Set prompt to null
-seed=123
-generator = torch.Generator(device=device).manual_seed(seed)
-source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
-mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
-source_image = preprocess_image(source_image_path, device)
-mask = preprocess_mask(mask_path, device)
-
-image = pipeline(
- prompt=prompt,
- image=source_image,
- mask_image=mask,
- height=1024,
- width=1024,
- AAS=True, # enable AAS
- strength=0.8, # inpainting strength
- rm_guidance_scale=9, # removal guidance scale
- ss_steps = 9, # similarity suppression steps
- ss_scale = 0.3, # similarity suppression scale
- AAS_start_step=0, # AAS start step
- AAS_start_layer=34, # AAS start layer
- AAS_end_layer=70, # AAS end layer
- num_inference_steps=50, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
- generator=generator,
- guidance_scale=1,
-).images[0]
-image.save('./removed_img.png')
-print("Object removal completed")
-```
-
-| Source Image | Mask | Output |
-| ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
-|  |  |  |
-
# Perturbed-Attention Guidance
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py
index 4895bd150114..1d7a367ecc60 100644
--- a/examples/community/matryoshka.py
+++ b/examples/community/matryoshka.py
@@ -80,6 +80,7 @@
USE_PEFT_BACKEND,
BaseOutput,
deprecate,
+ is_torch_version,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -868,7 +869,23 @@ def forward(
for i, (resnet, attn) in enumerate(blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
- hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1013,6 +1030,17 @@ def forward(
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1021,7 +1049,12 @@ def forward(
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
- hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
else:
hidden_states = attn(
hidden_states,
@@ -1159,7 +1192,23 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
- hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1233,6 +1282,10 @@ def __init__(
]
)
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
def forward(
self,
hidden_states: torch.Tensor,
@@ -1312,8 +1365,19 @@ def forward(
# Blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
- hidden_states = self._gradient_checkpointing_func(
- block,
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -1321,6 +1385,7 @@ def forward(
timestep,
cross_attention_kwargs,
class_labels,
+ **ckpt_kwargs,
)
else:
hidden_states = block(
@@ -2659,6 +2724,10 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index eed0575c322d..f97a4d0cd0f4 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -742,29 +742,3 @@ accelerate launch train_dreambooth.py \
## Stable Diffusion XL
We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
-
-## Dataset
-
-We support 🤗 [Datasets](https://huggingface.co/docs/datasets/index), you can find a dataset on the [Hugging Face Hub](https://huggingface.co/datasets) or use your own.
-
-The quickest way to get started with your custom dataset is 🤗 Datasets' [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder).
-
-We need to create a file `metadata.jsonl` in the directory with our images:
-
-```
-{"file_name": "01.jpg", "prompt": "prompt 01"}
-{"file_name": "02.jpg", "prompt": "prompt 02"}
-```
-
-If we have a directory with image-text pairs e.g. `01.jpg` and `01.txt` then `convert_to_imagefolder.py` can create `metadata.jsonl`.
-
-```sh
-python convert_to_imagefolder.py --path my_dataset/
-```
-
-We use `--dataset_name` and `--caption_column` with training scripts.
-
-```
---dataset_name=my_dataset/
---caption_column=prompt
-```
diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py
index a6baea9967a2..ce8e768f7b5b 100644
--- a/examples/dreambooth/train_dreambooth_lora_sana.py
+++ b/examples/dreambooth/train_dreambooth_lora_sana.py
@@ -63,7 +63,6 @@
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -75,9 +74,6 @@
logger = get_logger(__name__)
-if is_torch_npu_available():
- torch.npu.config.allow_internal_format = False
-
def save_model_card(
repo_id: str,
@@ -605,7 +601,6 @@ def parse_args(input_args=None):
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation")
- parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -929,7 +924,8 @@ def main(args):
image.save(image_filename)
del pipeline
- free_memory()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
# Handle the repository creation
if accelerator.is_main_process:
@@ -992,13 +988,6 @@ def main(args):
# because Gemma2 is particularly suited for bfloat16.
text_encoder.to(dtype=torch.bfloat16)
- if args.enable_npu_flash_attention:
- if is_torch_npu_available():
- logger.info("npu flash attention enabled.")
- transformer.enable_npu_flash_attention()
- else:
- raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
-
# Initialize a text encoding pipeline and keep it to CPU for now.
text_encoding_pipeline = SanaPipeline.from_pretrained(
args.pretrained_model_name_or_path,
diff --git a/examples/research_projects/pixart/controlnet_pixart_alpha.py b/examples/research_projects/pixart/controlnet_pixart_alpha.py
index 8f2eb974398d..f825719a1364 100644
--- a/examples/research_projects/pixart/controlnet_pixart_alpha.py
+++ b/examples/research_projects/pixart/controlnet_pixart_alpha.py
@@ -8,6 +8,7 @@
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils.torch_utils import is_torch_version
class PixArtControlNetAdapterBlock(nn.Module):
@@ -150,6 +151,10 @@ def __init__(
self.transformer = transformer
self.controlnet = controlnet
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
def forward(
self,
hidden_states: torch.Tensor,
@@ -215,8 +220,18 @@ def forward(
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
exit(1)
- hidden_states = self._gradient_checkpointing_func(
- block,
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -224,6 +239,7 @@ def forward(
timestep,
cross_attention_kwargs,
None,
+ **ckpt_kwargs,
)
else:
# the control nets are only used for the blocks 1 to self.blocks_num
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 82c395c685f8..e7f2f5c4c881 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -515,6 +515,10 @@ def main():
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
+ # Freeze the unet parameters before adding adapters
+ for param in unet.parameters():
+ param.requires_grad_(False)
+
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
From 601696d0be99a4364f317000dc1f72fd782de2e5 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Tue, 28 Jan 2025 20:33:53 +0800
Subject: [PATCH 23/68] change back
---
.../train_dreambooth_lora_flux_advanced.py | 8 +-
.../train_dreambooth_lora_sd15_advanced.py | 16 +--
.../train_dreambooth_lora_sdxl_advanced.py | 16 +--
examples/amused/train_amused.py | 2 +-
.../train_cogvideox_image_to_video_lora.py | 2 +-
examples/cogvideo/train_cogvideox_lora.py | 2 +-
examples/community/README.md | 92 +++++++++++++-
.../community/adaptive_mask_inpainting.py | 2 +-
examples/community/hd_painter.py | 2 +-
examples/community/img2img_inpainting.py | 2 +-
examples/community/llm_grounded_diffusion.py | 4 +-
examples/community/lpw_stable_diffusion_xl.py | 2 +-
examples/community/matryoshka.py | 79 +-----------
.../pipeline_flux_differential_img2img.py | 4 +-
examples/community/pipeline_prompt2prompt.py | 12 +-
.../community/pipeline_sdxl_style_aligned.py | 2 +-
...pipeline_stable_diffusion_upscale_ldm3d.py | 2 +-
...diffusion_xl_controlnet_adapter_inpaint.py | 2 +-
examples/community/scheduling_ufogen.py | 3 +-
.../train_lcm_distill_lora_sd_wds.py | 2 +-
.../train_lcm_distill_lora_sdxl.py | 2 +-
.../train_lcm_distill_lora_sdxl_wds.py | 2 +-
examples/custom_diffusion/retrieve.py | 8 +-
.../train_custom_diffusion.py | 24 ++--
examples/dreambooth/README.md | 26 ++++
examples/dreambooth/train_dreambooth.py | 2 +-
examples/dreambooth/train_dreambooth_lora.py | 2 +-
.../dreambooth/train_dreambooth_lora_flux.py | 2 +-
.../dreambooth/train_dreambooth_lora_sana.py | 17 ++-
.../dreambooth/train_dreambooth_lora_sd3.py | 2 +-
.../dreambooth/train_dreambooth_lora_sdxl.py | 4 +-
.../flux-control/train_control_lora_flux.py | 8 +-
.../colossalai/train_dreambooth_colossalai.py | 2 +-
.../controlnet/train_controlnet_webdataset.py | 7 +-
.../diffusion_dpo/train_diffusion_dpo.py | 2 +-
.../diffusion_dpo/train_diffusion_dpo_sdxl.py | 2 +-
.../train_diffusion_orpo_sdxl_lora.py | 4 +-
.../train_diffusion_orpo_sdxl_lora_wds.py | 4 +-
.../train_dreambooth_lora_flux_miniature.py | 2 +-
examples/research_projects/gligen/demo.ipynb | 20 +--
.../train_instruct_pix2pix_lora.py | 4 +-
.../train_multi_subject_dreambooth.py | 12 +-
.../textual_inversion.py | 6 +-
.../textual_inversion/textual_inversion.py | 6 +-
.../pixart/controlnet_pixart_alpha.py | 20 +--
.../pipeline_prompt_diffusion.py | 3 +-
.../text_to_image/train_text_to_image_xla.py | 4 +-
.../dreambooth/train_dreambooth.py | 2 +-
.../dreambooth/train_dreambooth_lora.py | 2 +-
.../dreambooth/train_dreambooth_lora_sdxl.py | 4 +-
.../train_text_to_image_lora_sdxl.py | 2 +-
.../train_dreambooth_lora_sd3_miniature.py | 2 +-
.../text_to_image/train_text_to_image_lora.py | 4 -
.../train_text_to_image_lora_sdxl.py | 2 +-
.../textual_inversion/textual_inversion.py | 6 +-
.../textual_inversion_sdxl.py | 12 +-
examples/vqgan/test_vqgan.py | 6 +-
examples/vqgan/train_vqgan.py | 12 +-
scripts/convert_amused.py | 2 +-
scripts/convert_consistency_to_diffusers.py | 4 +-
.../convert_dance_diffusion_to_diffusers.py | 12 +-
scripts/convert_diffusers_to_original_sdxl.py | 18 +--
..._diffusers_to_original_stable_diffusion.py | 20 +--
...vert_hunyuandit_controlnet_to_diffusers.py | 6 +-
scripts/convert_hunyuandit_to_diffusers.py | 9 +-
scripts/convert_k_upscaler_to_diffusers.py | 10 +-
scripts/convert_mochi_to_diffusers.py | 118 +++++++++---------
...convert_original_audioldm2_to_diffusers.py | 2 +-
.../convert_original_audioldm_to_diffusers.py | 2 +-
.../convert_original_musicldm_to_diffusers.py | 2 +-
scripts/convert_stable_audio.py | 18 +--
scripts/convert_svd_to_diffusers.py | 12 +-
scripts/convert_vq_diffusion_to_diffusers.py | 24 ++--
73 files changed, 404 insertions(+), 362 deletions(-)
mode change 100755 => 100644 examples/community/README.md
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 0298d3a6bfe1..235113d6a348 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -818,9 +818,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
- assert all(isinstance(tok, str) for tok in inserting_toks), (
- "All elements in inserting_toks should be strings."
- )
+ assert all(
+ isinstance(tok, str) for tok in inserting_toks
+ ), "All elements in inserting_toks should be strings."
self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -1683,7 +1683,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
index 865c93bf6e87..86891d5d7f0c 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
@@ -200,7 +200,7 @@ def save_model_card(
"diffusers",
"diffusers-training",
lora,
- "template:sd-lorastable-diffusion",
+ "template:sd-lora" "stable-diffusion",
"stable-diffusion-diffusers",
]
model_card = populate_model_card(model_card, tags=tags)
@@ -724,9 +724,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
- assert all(isinstance(tok, str) for tok in inserting_toks), (
- "All elements in inserting_toks should be strings."
- )
+ assert all(
+ isinstance(tok, str) for tok in inserting_toks
+ ), "All elements in inserting_toks should be strings."
self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -746,9 +746,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
.to(dtype=self.dtype)
* std_token_embedding
)
- self.embeddings_settings[f"original_embeddings_{idx}"] = (
- text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
- )
+ self.embeddings_settings[
+ f"original_embeddings_{idx}"
+ ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -1322,7 +1322,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
- unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 71ccfb1ee6e9..6e4f40c22df9 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -116,7 +116,7 @@ def save_model_card(
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"""
- - text: '{validation_prompt if validation_prompt else " "}'
+ - text: '{validation_prompt if validation_prompt else ' ' }'
output:
url:
"image_{i}.png"
@@ -891,9 +891,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
- assert all(isinstance(tok, str) for tok in inserting_toks), (
- "All elements in inserting_toks should be strings."
- )
+ assert all(
+ isinstance(tok, str) for tok in inserting_toks
+ ), "All elements in inserting_toks should be strings."
self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -913,9 +913,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
.to(dtype=self.dtype)
* std_token_embedding
)
- self.embeddings_settings[f"original_embeddings_{idx}"] = (
- text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
- )
+ self.embeddings_settings[
+ f"original_embeddings_{idx}"
+ ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -1648,7 +1648,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/amused/train_amused.py b/examples/amused/train_amused.py
index d71d9ccbb83e..df44a0a63aeb 100644
--- a/examples/amused/train_amused.py
+++ b/examples/amused/train_amused.py
@@ -720,7 +720,7 @@ def load_model_hook(models, input_dir):
# Train!
logger.info("***** Running training *****")
logger.info(f" Num training steps = {args.max_train_steps}")
- logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Instantaneous batch size per device = { args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
index 86f2965636f3..aaee133680ea 100644
--- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py
+++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
@@ -1138,7 +1138,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index 59e42fcb80d7..01ea59c593a9 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -1159,7 +1159,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/community/README.md b/examples/community/README.md
old mode 100755
new mode 100644
index c7c40c46ef2d..4c593a004893
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -77,6 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
+| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -4585,8 +4586,8 @@ image = pipe(
```
|  |  |  |
-| ------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- |
-| Gradient | Input | Output |
+| -------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ |
+| Gradient | Input | Output |
A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.
@@ -4634,6 +4635,93 @@ make_image_grid(image, rows=1, cols=len(image))
# 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
```
+### Stable Diffusion XL Attentive Eraser Pipeline
+
+
+**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. This pipeline uses **self-attention redirection guidance** to modify the model’s self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. If you are interested in more detailed information and have any questions, please refer to the [paper](https://arxiv.org/abs/2412.12974) and [official implementation](https://github.com/Anonym0u3/AttentiveEraser).
+
+#### Key features
+
+- **Tuning-Free**: No additional training is required, making it easy to integrate and use.
+- **Flexible Mask Support**: Works with different types of masks for targeted object removal.
+- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.
+
+#### Usage example
+To use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:
+```py
+import torch
+from diffusers import DDIMScheduler, DiffusionPipeline
+from diffusers.utils import load_image
+import torch.nn.functional as F
+from torchvision.transforms.functional import to_tensor, gaussian_blur
+
+dtype = torch.float16
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
+ scheduler=scheduler,
+ variant="fp16",
+ use_safetensors=True,
+ torch_dtype=dtype,
+).to(device)
+
+
+def preprocess_image(image_path, device):
+ image = to_tensor((load_image(image_path)))
+ image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
+ if image.shape[1] != 3:
+ image = image.expand(-1, 3, -1, -1)
+ image = F.interpolate(image, (1024, 1024))
+ image = image.to(dtype).to(device)
+ return image
+
+def preprocess_mask(mask_path, device):
+ mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
+ mask = mask.unsqueeze_(0).float() # 0 or 1
+ mask = F.interpolate(mask, (1024, 1024))
+ mask = gaussian_blur(mask, kernel_size=(77, 77))
+ mask[mask < 0.1] = 0
+ mask[mask >= 0.1] = 1
+ mask = mask.to(dtype).to(device)
+ return mask
+
+prompt = "" # Set prompt to null
+seed=123
+generator = torch.Generator(device=device).manual_seed(seed)
+source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
+mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
+source_image = preprocess_image(source_image_path, device)
+mask = preprocess_mask(mask_path, device)
+
+image = pipeline(
+ prompt=prompt,
+ image=source_image,
+ mask_image=mask,
+ height=1024,
+ width=1024,
+ AAS=True, # enable AAS
+ strength=0.8, # inpainting strength
+ rm_guidance_scale=9, # removal guidance scale
+ ss_steps = 9, # similarity suppression steps
+ ss_scale = 0.3, # similarity suppression scale
+ AAS_start_step=0, # AAS start step
+ AAS_start_layer=34, # AAS start layer
+ AAS_end_layer=70, # AAS end layer
+ num_inference_steps=50, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
+ generator=generator,
+ guidance_scale=1,
+).images[0]
+image.save('./removed_img.png')
+print("Object removal completed")
+```
+
+| Source Image | Mask | Output |
+| ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
+|  |  |  |
+
# Perturbed-Attention Guidance
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
diff --git a/examples/community/adaptive_mask_inpainting.py b/examples/community/adaptive_mask_inpainting.py
index 81f9527b4703..df736956485b 100644
--- a/examples/community/adaptive_mask_inpainting.py
+++ b/examples/community/adaptive_mask_inpainting.py
@@ -1103,7 +1103,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `default_mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/hd_painter.py b/examples/community/hd_painter.py
index 9d7b95b62c6e..91ebe076104a 100644
--- a/examples/community/hd_painter.py
+++ b/examples/community/hd_painter.py
@@ -686,7 +686,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py
index 001e4cc5b2cf..292c9aa2bc47 100644
--- a/examples/community/img2img_inpainting.py
+++ b/examples/community/img2img_inpainting.py
@@ -362,7 +362,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/examples/community/llm_grounded_diffusion.py b/examples/community/llm_grounded_diffusion.py
index 814694f1e366..129793dae6b0 100644
--- a/examples/community/llm_grounded_diffusion.py
+++ b/examples/community/llm_grounded_diffusion.py
@@ -1120,7 +1120,7 @@ def latent_lmd_guidance(
if verbose:
logger.info(
- f"time index {index}, loss: {loss.item() / loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
+ f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
)
try:
@@ -1184,7 +1184,7 @@ def latent_lmd_guidance(
if verbose:
logger.info(
- f"time index {index}, loss: {loss.item() / loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
+ f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
)
finally:
diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py
index af1082e8410b..4bcef10f97c2 100644
--- a/examples/community/lpw_stable_diffusion_xl.py
+++ b/examples/community/lpw_stable_diffusion_xl.py
@@ -1773,7 +1773,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py
index 1d7a367ecc60..4895bd150114 100644
--- a/examples/community/matryoshka.py
+++ b/examples/community/matryoshka.py
@@ -80,7 +80,6 @@
USE_PEFT_BACKEND,
BaseOutput,
deprecate,
- is_torch_version,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -869,23 +868,7 @@ def forward(
for i, (resnet, attn) in enumerate(blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1030,17 +1013,6 @@ def forward(
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if torch.is_grad_enabled() and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1049,12 +1021,7 @@ def custom_forward(*inputs):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = attn(
hidden_states,
@@ -1192,23 +1159,7 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1282,10 +1233,6 @@ def __init__(
]
)
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -1365,19 +1312,8 @@ def forward(
# Blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ hidden_states = self._gradient_checkpointing_func(
+ block,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -1385,7 +1321,6 @@ def custom_forward(*inputs):
timestep,
cross_attention_kwargs,
class_labels,
- **ckpt_kwargs,
)
else:
hidden_states = block(
@@ -2724,10 +2659,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
diff --git a/examples/community/pipeline_flux_differential_img2img.py b/examples/community/pipeline_flux_differential_img2img.py
index 33eaa9de04cd..a66e2b1c7c8a 100644
--- a/examples/community/pipeline_flux_differential_img2img.py
+++ b/examples/community/pipeline_flux_differential_img2img.py
@@ -488,7 +488,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -496,7 +496,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
diff --git a/examples/community/pipeline_prompt2prompt.py b/examples/community/pipeline_prompt2prompt.py
index b9985542ccf7..736f00799eae 100644
--- a/examples/community/pipeline_prompt2prompt.py
+++ b/examples/community/pipeline_prompt2prompt.py
@@ -907,12 +907,12 @@ def create_controller(
# reweight
if edit_type == "reweight":
- assert equalizer_words is not None and equalizer_strengths is not None, (
- "To use reweight edit, please specify equalizer_words and equalizer_strengths."
- )
- assert len(equalizer_words) == len(equalizer_strengths), (
- "equalizer_words and equalizer_strengths must be of same length."
- )
+ assert (
+ equalizer_words is not None and equalizer_strengths is not None
+ ), "To use reweight edit, please specify equalizer_words and equalizer_strengths."
+ assert len(equalizer_words) == len(
+ equalizer_strengths
+ ), "equalizer_words and equalizer_strengths must be of same length."
equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
return AttentionReweight(
prompts,
diff --git a/examples/community/pipeline_sdxl_style_aligned.py b/examples/community/pipeline_sdxl_style_aligned.py
index 6aebb6c18df7..9377caf7ba2e 100644
--- a/examples/community/pipeline_sdxl_style_aligned.py
+++ b/examples/community/pipeline_sdxl_style_aligned.py
@@ -1738,7 +1738,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
index 6c63f53e815c..8a709ab46757 100644
--- a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
+++ b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
@@ -689,7 +689,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents + num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
index 6a0ed3523dab..8480117866cc 100644
--- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
+++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
@@ -1578,7 +1578,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/scheduling_ufogen.py b/examples/community/scheduling_ufogen.py
index 0b832394cf97..4b1b92ff183a 100644
--- a/examples/community/scheduling_ufogen.py
+++ b/examples/community/scheduling_ufogen.py
@@ -288,7 +288,8 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
index 28fc7c73e6eb..2045e7809310 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -89,7 +89,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
# Set alpha parameter
if "lora_down" in kohya_key:
- alpha_key = f"{kohya_key.split('.')[0]}.alpha"
+ alpha_key = f'{kohya_key.split(".")[0]}.alpha'
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
return kohya_ss_state_dict
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
index 61d883fdfb78..38fe94ed3fe5 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -901,7 +901,7 @@ def load_model_hook(models, input_dir):
unet_ = accelerator.unwrap_model(unet)
lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
unet_state_dict = {
- f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")
+ f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
index 4324f81b9695..fdb789c21628 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -95,7 +95,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
# Set alpha parameter
if "lora_down" in kohya_key:
- alpha_key = f"{kohya_key.split('.')[0]}.alpha"
+ alpha_key = f'{kohya_key.split(".")[0]}.alpha'
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
return kohya_ss_state_dict
diff --git a/examples/custom_diffusion/retrieve.py b/examples/custom_diffusion/retrieve.py
index 27f4b4e0dc60..a28fe344d93b 100644
--- a/examples/custom_diffusion/retrieve.py
+++ b/examples/custom_diffusion/retrieve.py
@@ -50,11 +50,9 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
total = 0
pbar = tqdm(desc="downloading real regularization images", total=num_class_images)
- with (
- open(f"{class_data_dir}/caption.txt", "w") as f1,
- open(f"{class_data_dir}/urls.txt", "w") as f2,
- open(f"{class_data_dir}/images.txt", "w") as f3,
- ):
+ with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open(
+ f"{class_data_dir}/images.txt", "w"
+ ) as f3:
while total < num_class_images:
images = class_images[count]
count += 1
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index 140e64a0e075..dc21746cb159 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -731,18 +731,18 @@ def main(args):
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True, exist_ok=True)
if args.real_prior:
- assert (class_images_dir / "images").exists(), (
- f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
- )
- assert len(list((class_images_dir / "images").iterdir())) == args.num_class_images, (
- f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
- )
- assert (class_images_dir / "caption.txt").exists(), (
- f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
- )
- assert (class_images_dir / "images.txt").exists(), (
- f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
- )
+ assert (
+ class_images_dir / "images"
+ ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
+ assert (
+ len(list((class_images_dir / "images").iterdir())) == args.num_class_images
+ ), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
+ assert (
+ class_images_dir / "caption.txt"
+ ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
+ assert (
+ class_images_dir / "images.txt"
+ ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
args.concepts_list[i] = concept
diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index f97a4d0cd0f4..eed0575c322d 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -742,3 +742,29 @@ accelerate launch train_dreambooth.py \
## Stable Diffusion XL
We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
+
+## Dataset
+
+We support 🤗 [Datasets](https://huggingface.co/docs/datasets/index), you can find a dataset on the [Hugging Face Hub](https://huggingface.co/datasets) or use your own.
+
+The quickest way to get started with your custom dataset is 🤗 Datasets' [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder).
+
+We need to create a file `metadata.jsonl` in the directory with our images:
+
+```
+{"file_name": "01.jpg", "prompt": "prompt 01"}
+{"file_name": "02.jpg", "prompt": "prompt 02"}
+```
+
+If we have a directory with image-text pairs e.g. `01.jpg` and `01.txt` then `convert_to_imagefolder.py` can create `metadata.jsonl`.
+
+```sh
+python convert_to_imagefolder.py --path my_dataset/
+```
+
+We use `--dataset_name` and `--caption_column` with training scripts.
+
+```
+--dataset_name=my_dataset/
+--caption_column=prompt
+```
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 43e680610ee5..b863f5641233 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -1014,7 +1014,7 @@ def load_model_hook(models, input_dir):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError(
- f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
+ f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 6a817cf09b63..83a24b778083 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -982,7 +982,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 2278784f896d..91e028251a1d 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1275,7 +1275,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py
index ce8e768f7b5b..9e69bd6a668b 100644
--- a/examples/dreambooth/train_dreambooth_lora_sana.py
+++ b/examples/dreambooth/train_dreambooth_lora_sana.py
@@ -63,6 +63,7 @@
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -74,6 +75,9 @@
logger = get_logger(__name__)
+if is_torch_npu_available():
+ torch.npu.config.allow_internal_format = False
+
def save_model_card(
repo_id: str,
@@ -601,6 +605,7 @@ def parse_args(input_args=None):
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation")
+ parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -924,8 +929,7 @@ def main(args):
image.save(image_filename)
del pipeline
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
+ free_memory()
# Handle the repository creation
if accelerator.is_main_process:
@@ -988,6 +992,13 @@ def main(args):
# because Gemma2 is particularly suited for bfloat16.
text_encoder.to(dtype=torch.bfloat16)
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ transformer.enable_npu_flash_attention()
+ else:
+ raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
+
# Initialize a text encoding pipeline and keep it to CPU for now.
text_encoding_pipeline = SanaPipeline.from_pretrained(
args.pretrained_model_name_or_path,
@@ -1052,7 +1063,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = SanaPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 191dbfbb37a3..65e7dac26bdd 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -1355,7 +1355,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 8115dd61483c..35704c574f28 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -118,7 +118,7 @@ def save_model_card(
)
model_description = f"""
-# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
+# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
@@ -1271,7 +1271,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index 2a9bfd949cde..56c5f2a89a3a 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
torch_dtype=weight_dtype,
)
pipeline.load_lora_weights(args.output_dir)
- assert pipeline.transformer.config.in_channels == initial_channels * 2, (
- f"{pipeline.transformer.config.in_channels=}"
- )
+ assert (
+ pipeline.transformer.config.in_channels == initial_channels * 2
+ ), f"{pipeline.transformer.config.in_channels=}"
pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
@@ -954,7 +954,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
transformer_lora_state_dict = {
- f"{k.replace('transformer.', '')}": v
+ f'{k.replace("transformer.", "")}': v
for k, v in lora_state_dict.items()
if k.startswith("transformer.") and "lora" in k
}
diff --git a/examples/research_projects/colossalai/train_dreambooth_colossalai.py b/examples/research_projects/colossalai/train_dreambooth_colossalai.py
index 4e541b8d3a02..10c8e095a696 100644
--- a/examples/research_projects/colossalai/train_dreambooth_colossalai.py
+++ b/examples/research_projects/colossalai/train_dreambooth_colossalai.py
@@ -619,7 +619,7 @@ def collate_fn(examples):
optimizer.step()
lr_scheduler.step()
- logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated() / 2**20} MB", ranks=[0])
+ logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])
# Checks if the accelerator has performed an optimization step behind the scenes
progress_bar.update(1)
global_step += 1
diff --git a/examples/research_projects/controlnet/train_controlnet_webdataset.py b/examples/research_projects/controlnet/train_controlnet_webdataset.py
index e820c34e6fcf..765bb495062e 100644
--- a/examples/research_projects/controlnet/train_controlnet_webdataset.py
+++ b/examples/research_projects/controlnet/train_controlnet_webdataset.py
@@ -805,20 +805,21 @@ def parse_args(input_args=None):
"--control_type",
type=str,
default="canny",
- help=("The type of controlnet conditioning image to use. One of `canny`, `depth` Defaults to `canny`."),
+ help=("The type of controlnet conditioning image to use. One of `canny`, `depth`" " Defaults to `canny`."),
)
parser.add_argument(
"--transformer_layers_per_block",
type=str,
default=None,
- help=("The number of layers per block in the transformer. If None, defaults to `args.transformer_layers`."),
+ help=("The number of layers per block in the transformer. If None, defaults to" " `args.transformer_layers`."),
)
parser.add_argument(
"--old_style_controlnet",
action="store_true",
default=False,
help=(
- "Use the old style controlnet, which is a single transformer layer with a single head. Defaults to False."
+ "Use the old style controlnet, which is a single transformer layer with"
+ " a single head. Defaults to False."
),
)
diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
index 0b9c248ed004..ab88d4967766 100644
--- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
+++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
@@ -86,7 +86,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False):
- logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
+ logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
# create pipeline
pipeline = DiffusionPipeline.from_pretrained(
diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
index f0afa12e9ceb..0297a06f5b2c 100644
--- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
+++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
- logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
+ logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
if is_final_validation:
if args.mixed_precision == "fp16":
diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
index 12eb67d4a7bb..ed245e9cef7d 100644
--- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
+++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
- logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
+ logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
if is_final_validation:
if args.mixed_precision == "fp16":
@@ -683,7 +683,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
index a5d89f77d687..66a7a3652947 100644
--- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
+++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
@@ -89,7 +89,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
- logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
+ logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
if is_final_validation:
if args.mixed_precision == "fp16":
@@ -790,7 +790,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
index cc535bbaaa85..ccaf3164a00c 100644
--- a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
+++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
@@ -783,7 +783,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/research_projects/gligen/demo.ipynb b/examples/research_projects/gligen/demo.ipynb
index b467ba3a87bc..571f1a0323a2 100644
--- a/examples/research_projects/gligen/demo.ipynb
+++ b/examples/research_projects/gligen/demo.ipynb
@@ -48,12 +48,16 @@
"from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n",
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
"\n",
- "pretrained_model_name_or_path = \"/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83\"\n",
+ "pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
"\n",
"tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n",
"noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
- "text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"text_encoder\")\n",
- "vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n",
+ "text_encoder = CLIPTextModel.from_pretrained(\n",
+ " pretrained_model_name_or_path, subfolder=\"text_encoder\"\n",
+ ")\n",
+ "vae = AutoencoderKL.from_pretrained(\n",
+ " pretrained_model_name_or_path, subfolder=\"vae\"\n",
+ ")\n",
"# unet = UNet2DConditionModel.from_pretrained(\n",
"# pretrained_model_name_or_path, subfolder=\"unet\"\n",
"# )\n",
@@ -67,7 +71,9 @@
"metadata": {},
"outputs": [],
"source": [
- "unet = UNet2DConditionModel.from_pretrained(\"/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO\")"
+ "unet = UNet2DConditionModel.from_pretrained(\n",
+ " '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n",
+ ")"
]
},
{
@@ -111,8 +117,8 @@
"# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n",
"# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n",
"\n",
- "prompt = \"An oil painting of a pink dolphin jumping on the left of a steam boat on the sea\"\n",
- "gen_boxes = [(\"a steam boat\", [232, 225, 257, 149]), (\"a jumping pink dolphin\", [21, 249, 189, 123])]\n",
+ "prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n",
+ "gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n",
"\n",
"import numpy as np\n",
"\n",
@@ -160,7 +166,7 @@
"metadata": {},
"outputs": [],
"source": [
- "diffusers.utils.make_image_grid(images, 4, len(images) // 4)"
+ "diffusers.utils.make_image_grid(images, 4, len(images)//4)"
]
},
{
diff --git a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
index 57910f969876..070cdad15564 100644
--- a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
+++ b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
@@ -15,8 +15,8 @@
# limitations under the License.
"""
-Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
-Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
+ Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
+ Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
"""
import argparse
diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
index 57c555e43fd8..0f507b26d6a8 100644
--- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
+++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
@@ -763,9 +763,9 @@ def main(args):
# Parse instance and class inputs, and double check that lengths match
instance_data_dir = args.instance_data_dir.split(",")
instance_prompt = args.instance_prompt.split(",")
- assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
- "Instance data dir and prompt inputs are not of the same length."
- )
+ assert all(
+ x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
+ ), "Instance data dir and prompt inputs are not of the same length."
if args.with_prior_preservation:
class_data_dir = args.class_data_dir.split(",")
@@ -788,9 +788,9 @@ def main(args):
negative_validation_prompts.append(None)
args.validation_negative_prompt = negative_validation_prompts
- assert num_of_validation_prompts == len(negative_validation_prompts), (
- "The length of negative prompts for validation is greater than the number of validation prompts."
- )
+ assert num_of_validation_prompts == len(
+ negative_validation_prompts
+ ), "The length of negative prompts for validation is greater than the number of validation prompts."
args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
index 75dcfccbd5b8..19432142f541 100644
--- a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
+++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
@@ -830,9 +830,9 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = get_mask(tokenizer, accelerator)
with torch.no_grad():
- accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
- orig_embeds_params[index_no_updates]
- )
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
index a881b06a94dc..7f5dc8ece9fc 100644
--- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
+++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
@@ -886,9 +886,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad():
- accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
- orig_embeds_params[index_no_updates]
- )
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/research_projects/pixart/controlnet_pixart_alpha.py b/examples/research_projects/pixart/controlnet_pixart_alpha.py
index f825719a1364..8f2eb974398d 100644
--- a/examples/research_projects/pixart/controlnet_pixart_alpha.py
+++ b/examples/research_projects/pixart/controlnet_pixart_alpha.py
@@ -8,7 +8,6 @@
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
-from diffusers.utils.torch_utils import is_torch_version
class PixArtControlNetAdapterBlock(nn.Module):
@@ -151,10 +150,6 @@ def __init__(
self.transformer = transformer
self.controlnet = controlnet
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -220,18 +215,8 @@ def forward(
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
exit(1)
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ hidden_states = self._gradient_checkpointing_func(
+ block,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -239,7 +224,6 @@ def custom_forward(*inputs):
timestep,
cross_attention_kwargs,
None,
- **ckpt_kwargs,
)
else:
# the control nets are only used for the blocks 1 to self.blocks_num
diff --git a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
index 51668a61cdc2..19c1f30d82da 100644
--- a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
+++ b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
@@ -663,7 +663,8 @@ def check_inputs(
self.check_image(image, prompt, prompt_embeds)
else:
raise ValueError(
- f"You have passed a list of images of length {len(image_pair)}.Make sure the list size equals to two."
+ f"You have passed a list of images of length {len(image_pair)}."
+ f"Make sure the list size equals to two."
)
# Check `controlnet_conditioning_scale`
diff --git a/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py
index 6ae1a9a6c611..9719585d3dfb 100644
--- a/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py
+++ b/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py
@@ -173,7 +173,7 @@ def print_loss_closure(step, loss):
if not dataloader_exception:
xm.wait_device_ops()
total_time = time.time() - last_time
- print(f"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}")
+ print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
else:
print("dataloader exception happen, skip result")
return
@@ -622,7 +622,7 @@ def collate_fn(examples):
num_devices_per_host = num_devices // num_hosts
if xm.is_master_ordinal():
print("***** Running training *****")
- print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host}")
+ print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }")
print(
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
)
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
index 043f913893b1..26caba5a42c1 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
@@ -1057,7 +1057,7 @@ def load_model_hook(models, input_dir):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError(
- f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
+ f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
index 393f991387d6..410cd74a5b7b 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
@@ -1021,7 +1021,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
index 01ef67a55da4..c02a59a0077a 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
@@ -118,7 +118,7 @@ def save_model_card(
)
model_description = f"""
-# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
+# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
@@ -1336,7 +1336,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
index c87f50e27245..abc439912664 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
@@ -750,7 +750,7 @@ def load_model_hook(models, input_dir):
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
index ebb9b129db7e..f5bee58d4534 100644
--- a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
+++ b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
@@ -765,7 +765,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index e7f2f5c4c881..82c395c685f8 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -515,10 +515,6 @@ def main():
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
- # Freeze the unet parameters before adding adapters
- for param in unet.parameters():
- param.requires_grad_(False)
-
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index e0408de4cfd5..f71e4a71bb90 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -767,7 +767,7 @@ def load_model_hook(models, input_dir):
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 51e220828cdf..757a12045f10 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -910,9 +910,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad():
- accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
- orig_embeds_params[index_no_updates]
- )
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py
index f32c729195b0..11463943c448 100644
--- a/examples/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/textual_inversion/textual_inversion_sdxl.py
@@ -965,12 +965,12 @@ def main():
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
with torch.no_grad():
- accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
- orig_embeds_params[index_no_updates]
- )
- accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = (
- orig_embeds_params_2[index_no_updates_2]
- )
+ accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds_params[index_no_updates]
+ accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
+ index_no_updates_2
+ ] = orig_embeds_params_2[index_no_updates_2]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/vqgan/test_vqgan.py b/examples/vqgan/test_vqgan.py
index d13e102e7816..aa5d4c67b642 100644
--- a/examples/vqgan/test_vqgan.py
+++ b/examples/vqgan/test_vqgan.py
@@ -177,7 +177,7 @@ def test_vqmodel_checkpointing(self):
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1
- --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
+ --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
--output_dir {tmpdir}
--seed=0
""".split()
@@ -262,7 +262,7 @@ def test_vqmodel_checkpointing_use_ema(self):
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1
- --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
+ --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
--output_dir {tmpdir}
--use_ema
--seed=0
@@ -377,7 +377,7 @@ def test_vqmodel_checkpointing_checkpoints_total_limit_removes_multiple_checkpoi
--discriminator_config_name_or_path {discriminator_config_path}
--output_dir {tmpdir}
--checkpointing_steps=2
- --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
+ --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
--checkpoints_total_limit=2
--seed=0
""".split()
diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py
index 33d234da52d7..992722fa7a78 100644
--- a/examples/vqgan/train_vqgan.py
+++ b/examples/vqgan/train_vqgan.py
@@ -653,15 +653,15 @@ def main():
try:
# Gets the resolution of the timm transformation after centercrop
timm_centercrop_transform = timm_transform.transforms[1]
- assert isinstance(timm_centercrop_transform, transforms.CenterCrop), (
- f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
- )
+ assert isinstance(
+ timm_centercrop_transform, transforms.CenterCrop
+ ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
timm_model_resolution = timm_centercrop_transform.size[0]
# Gets final normalization
timm_model_normalization = timm_transform.transforms[-1]
- assert isinstance(timm_model_normalization, transforms.Normalize), (
- f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
- )
+ assert isinstance(
+ timm_model_normalization, transforms.Normalize
+ ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
except AssertionError as e:
raise NotImplementedError(e)
# Enable flash attention if asked
diff --git a/scripts/convert_amused.py b/scripts/convert_amused.py
index ddd1bf508b6d..21be29dfdb99 100644
--- a/scripts/convert_amused.py
+++ b/scripts/convert_amused.py
@@ -468,7 +468,7 @@ def make_vqvae(old_vae):
# assert (old_output == new_output).all()
print("skipping full vae equivalence check")
- print(f"vae full diff {(old_output - new_output).float().abs().sum()}")
+ print(f"vae full diff { (old_output - new_output).float().abs().sum()}")
return new_vae
diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py
index 2b918280ca05..0f8b4ddca8ef 100644
--- a/scripts/convert_consistency_to_diffusers.py
+++ b/scripts/convert_consistency_to_diffusers.py
@@ -239,7 +239,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0"
- old_prefix = f"output_blocks.{current_layer - 1}.1"
+ old_prefix = f"output_blocks.{current_layer-1}.1"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
elif layer_type == "AttnUpBlock2D":
for j in range(layers_per_block + 1):
@@ -255,7 +255,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0"
- old_prefix = f"output_blocks.{current_layer - 1}.2"
+ old_prefix = f"output_blocks.{current_layer-1}.2"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
diff --git a/scripts/convert_dance_diffusion_to_diffusers.py b/scripts/convert_dance_diffusion_to_diffusers.py
index 3d64a77fae7d..ce69bfe2bfc8 100755
--- a/scripts/convert_dance_diffusion_to_diffusers.py
+++ b/scripts/convert_dance_diffusion_to_diffusers.py
@@ -260,9 +260,9 @@ def main(args):
model_name = args.model_path.split("/")[-1].split(".")[0]
if not os.path.isfile(args.model_path):
- assert model_name == args.model_path, (
- f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
- )
+ assert (
+ model_name == args.model_path
+ ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
args.model_path = download(model_name)
sample_rate = MODELS_MAP[model_name]["sample_rate"]
@@ -289,9 +289,9 @@ def main(args):
assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
for key, value in renamed_state_dict.items():
- assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
- f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
- )
+ assert (
+ diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
+ ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
if key == "time_proj.weight":
value = value.squeeze()
diff --git a/scripts/convert_diffusers_to_original_sdxl.py b/scripts/convert_diffusers_to_original_sdxl.py
index 1aa792b3f06a..648d0376f72e 100644
--- a/scripts/convert_diffusers_to_original_sdxl.py
+++ b/scripts/convert_diffusers_to_original_sdxl.py
@@ -52,18 +52,18 @@
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i > 0:
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(4):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i < 2:
@@ -75,12 +75,12 @@
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
@@ -89,7 +89,7 @@
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2 * j}."
+ sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -137,20 +137,20 @@ def convert_unet_state_dict(unet_state_dict):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"up.{3 - i}.upsample."
+ sd_upsample_prefix = f"up.{3-i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
- sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
- sd_mid_res_prefix = f"mid.block_{i + 1}."
+ sd_mid_res_prefix = f"mid.block_{i+1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py
index 049dda7d42a7..d1b7df070c43 100644
--- a/scripts/convert_diffusers_to_original_stable_diffusion.py
+++ b/scripts/convert_diffusers_to_original_stable_diffusion.py
@@ -47,36 +47,36 @@
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0:
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
- sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
@@ -85,7 +85,7 @@
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2 * j}."
+ sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -133,20 +133,20 @@ def convert_unet_state_dict(unet_state_dict):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"up.{3 - i}.upsample."
+ sd_upsample_prefix = f"up.{3-i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
- sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
- sd_mid_res_prefix = f"mid.block_{i + 1}."
+ sd_mid_res_prefix = f"mid.block_{i+1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
diff --git a/scripts/convert_hunyuandit_controlnet_to_diffusers.py b/scripts/convert_hunyuandit_controlnet_to_diffusers.py
index 5cef46c98983..1c8383690890 100644
--- a/scripts/convert_hunyuandit_controlnet_to_diffusers.py
+++ b/scripts/convert_hunyuandit_controlnet_to_diffusers.py
@@ -21,9 +21,9 @@ def main(args):
model_config = HunyuanDiT2DControlNetModel.load_config(
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
)
- model_config["use_style_cond_and_image_meta_size"] = (
- args.use_style_cond_and_image_meta_size
- ) ### version <= v1.1: True; version >= v1.2: False
+ model_config[
+ "use_style_cond_and_image_meta_size"
+ ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
print(model_config)
for key in state_dict:
diff --git a/scripts/convert_hunyuandit_to_diffusers.py b/scripts/convert_hunyuandit_to_diffusers.py
index 65fcccb22a1a..da3af8333ee3 100644
--- a/scripts/convert_hunyuandit_to_diffusers.py
+++ b/scripts/convert_hunyuandit_to_diffusers.py
@@ -13,14 +13,15 @@ def main(args):
state_dict = state_dict[args.load_key]
except KeyError:
raise KeyError(
- f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}"
+ f"{args.load_key} not found in the checkpoint."
+ f"Please load from the following keys:{state_dict.keys()}"
)
device = "cuda"
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
- model_config["use_style_cond_and_image_meta_size"] = (
- args.use_style_cond_and_image_meta_size
- ) ### version <= v1.1: True; version >= v1.2: False
+ model_config[
+ "use_style_cond_and_image_meta_size"
+ ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
# input_size -> sample_size, text_dim -> cross_attention_dim
for key in state_dict:
diff --git a/scripts/convert_k_upscaler_to_diffusers.py b/scripts/convert_k_upscaler_to_diffusers.py
index cff845ef8099..62abedd73785 100644
--- a/scripts/convert_k_upscaler_to_diffusers.py
+++ b/scripts/convert_k_upscaler_to_diffusers.py
@@ -142,14 +142,14 @@ def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
self_attention_prefix = f"{block_prefix}.{idx}"
- cross_attention_prefix = f"{block_prefix}.{idx}"
+ cross_attention_prefix = f"{block_prefix}.{idx }"
cross_attention_index = 1 if not attention.add_self_attention else 2
idx = (
n * attention_idx + cross_attention_index
if block_type == "up"
else n * attention_idx + cross_attention_index + 1
)
- cross_attention_prefix = f"{block_prefix}.{idx}"
+ cross_attention_prefix = f"{block_prefix}.{idx }"
diffusers_checkpoint.update(
cross_attn_to_diffusers_checkpoint(
@@ -220,9 +220,9 @@ def unet_model_from_original_config(original_config):
block_out_channels = original_config["channels"]
- assert len(set(original_config["depths"])) == 1, (
- "UNet2DConditionModel currently do not support blocks with different number of layers"
- )
+ assert (
+ len(set(original_config["depths"])) == 1
+ ), "UNet2DConditionModel currently do not support blocks with different number of layers"
layers_per_block = original_config["depths"][0]
class_labels_dim = original_config["mapping_cond_dim"]
diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py
index 64e4f69eac17..9727deeb6b0c 100644
--- a/scripts/convert_mochi_to_diffusers.py
+++ b/scripts/convert_mochi_to_diffusers.py
@@ -168,28 +168,28 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
# Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i + 1}.stack.0.weight"
+ f"blocks.0.{i+1}.stack.0.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i + 1}.stack.0.bias"
+ f"blocks.0.{i+1}.stack.0.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i + 1}.stack.2.weight"
+ f"blocks.0.{i+1}.stack.2.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i + 1}.stack.2.bias"
+ f"blocks.0.{i+1}.stack.2.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i + 1}.stack.3.weight"
+ f"blocks.0.{i+1}.stack.3.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i + 1}.stack.3.bias"
+ f"blocks.0.{i+1}.stack.3.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i + 1}.stack.5.weight"
+ f"blocks.0.{i+1}.stack.5.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i + 1}.stack.5.bias"
+ f"blocks.0.{i+1}.stack.5.bias"
)
# Convert up_blocks (MochiUpBlock3D)
@@ -197,35 +197,33 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
for block in range(3):
for i in range(down_block_layers[block]):
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.{block + 1}.blocks.{i}.stack.0.weight"
+ f"blocks.{block+1}.blocks.{i}.stack.0.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.{block + 1}.blocks.{i}.stack.0.bias"
+ f"blocks.{block+1}.blocks.{i}.stack.0.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
- f"blocks.{block + 1}.blocks.{i}.stack.2.weight"
+ f"blocks.{block+1}.blocks.{i}.stack.2.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
- f"blocks.{block + 1}.blocks.{i}.stack.2.bias"
+ f"blocks.{block+1}.blocks.{i}.stack.2.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.{block + 1}.blocks.{i}.stack.3.weight"
+ f"blocks.{block+1}.blocks.{i}.stack.3.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.{block + 1}.blocks.{i}.stack.3.bias"
+ f"blocks.{block+1}.blocks.{i}.stack.3.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
- f"blocks.{block + 1}.blocks.{i}.stack.5.weight"
+ f"blocks.{block+1}.blocks.{i}.stack.5.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
- f"blocks.{block + 1}.blocks.{i}.stack.5.bias"
+ f"blocks.{block+1}.blocks.{i}.stack.5.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
- f"blocks.{block + 1}.proj.weight"
- )
- new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(
- f"blocks.{block + 1}.proj.bias"
+ f"blocks.{block+1}.proj.weight"
)
+ new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias")
# Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3
@@ -269,133 +267,133 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
# Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i + 1}.stack.0.weight"
+ f"layers.{i+1}.stack.0.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i + 1}.stack.0.bias"
+ f"layers.{i+1}.stack.0.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i + 1}.stack.2.weight"
+ f"layers.{i+1}.stack.2.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i + 1}.stack.2.bias"
+ f"layers.{i+1}.stack.2.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i + 1}.stack.3.weight"
+ f"layers.{i+1}.stack.3.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i + 1}.stack.3.bias"
+ f"layers.{i+1}.stack.3.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i + 1}.stack.5.weight"
+ f"layers.{i+1}.stack.5.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i + 1}.stack.5.bias"
+ f"layers.{i+1}.stack.5.bias"
)
# Convert down_blocks (MochiDownBlock3D)
down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
for block in range(3):
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
- f"layers.{block + 4}.layers.0.weight"
+ f"layers.{block+4}.layers.0.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
- f"layers.{block + 4}.layers.0.bias"
+ f"layers.{block+4}.layers.0.bias"
)
for i in range(down_block_layers[block]):
# Convert resnets
- new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = (
- encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
- )
+ new_state_dict[
+ f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"
+ ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight")
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{block + 4}.layers.{i + 1}.stack.0.bias"
+ f"layers.{block+4}.layers.{i+1}.stack.0.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
- f"layers.{block + 4}.layers.{i + 1}.stack.2.weight"
+ f"layers.{block+4}.layers.{i+1}.stack.2.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
- f"layers.{block + 4}.layers.{i + 1}.stack.2.bias"
- )
- new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = (
- encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
+ f"layers.{block+4}.layers.{i+1}.stack.2.bias"
)
+ new_state_dict[
+ f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
+ ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight")
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{block + 4}.layers.{i + 1}.stack.3.bias"
+ f"layers.{block+4}.layers.{i+1}.stack.3.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
- f"layers.{block + 4}.layers.{i + 1}.stack.5.weight"
+ f"layers.{block+4}.layers.{i+1}.stack.5.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
- f"layers.{block + 4}.layers.{i + 1}.stack.5.bias"
+ f"layers.{block+4}.layers.{i+1}.stack.5.bias"
)
# Convert attentions
- qkv_weight = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.qkv.weight")
+ qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
- f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.weight"
+ f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
- f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.bias"
+ f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.weight"
+ f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.bias"
+ f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias"
)
# Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3
# Convert resnets
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i + 7}.stack.0.weight"
+ f"layers.{i+7}.stack.0.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i + 7}.stack.0.bias"
+ f"layers.{i+7}.stack.0.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i + 7}.stack.2.weight"
+ f"layers.{i+7}.stack.2.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i + 7}.stack.2.bias"
+ f"layers.{i+7}.stack.2.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i + 7}.stack.3.weight"
+ f"layers.{i+7}.stack.3.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i + 7}.stack.3.bias"
+ f"layers.{i+7}.stack.3.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i + 7}.stack.5.weight"
+ f"layers.{i+7}.stack.5.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i + 7}.stack.5.bias"
+ f"layers.{i+7}.stack.5.bias"
)
# Convert attentions
- qkv_weight = encoder_state_dict.pop(f"layers.{i + 7}.attn_block.attn.qkv.weight")
+ qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
- f"layers.{i + 7}.attn_block.attn.out.weight"
+ f"layers.{i+7}.attn_block.attn.out.weight"
)
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
- f"layers.{i + 7}.attn_block.attn.out.bias"
+ f"layers.{i+7}.attn_block.attn.out.bias"
)
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i + 7}.attn_block.norm.weight"
+ f"layers.{i+7}.attn_block.norm.weight"
)
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i + 7}.attn_block.norm.bias"
+ f"layers.{i+7}.attn_block.norm.bias"
)
# Convert output layers
diff --git a/scripts/convert_original_audioldm2_to_diffusers.py b/scripts/convert_original_audioldm2_to_diffusers.py
index 2c0695ce5595..1dc7d739ea76 100644
--- a/scripts/convert_original_audioldm2_to_diffusers.py
+++ b/scripts/convert_original_audioldm2_to_diffusers.py
@@ -662,7 +662,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
- key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
diff --git a/scripts/convert_original_audioldm_to_diffusers.py b/scripts/convert_original_audioldm_to_diffusers.py
index 44183f1aea29..4f8e4f8f9f80 100644
--- a/scripts/convert_original_audioldm_to_diffusers.py
+++ b/scripts/convert_original_audioldm_to_diffusers.py
@@ -636,7 +636,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
- key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
diff --git a/scripts/convert_original_musicldm_to_diffusers.py b/scripts/convert_original_musicldm_to_diffusers.py
index 00836fde2592..61e5d16eea9e 100644
--- a/scripts/convert_original_musicldm_to_diffusers.py
+++ b/scripts/convert_original_musicldm_to_diffusers.py
@@ -642,7 +642,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
- key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
diff --git a/scripts/convert_stable_audio.py b/scripts/convert_stable_audio.py
index b33c8b0608e7..a0f9d0f87d90 100644
--- a/scripts/convert_stable_audio.py
+++ b/scripts/convert_stable_audio.py
@@ -95,18 +95,18 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
# get idx of the layer
idx = int(new_key.split("coder.layers.")[1].split(".")[0])
- new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx - 1}")
+ new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}")
if "encoder" in new_key:
for i in range(3):
- new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i + 1}")
- new_key = new_key.replace(f"block.{idx - 1}.layers.3", f"block.{idx - 1}.snake1")
- new_key = new_key.replace(f"block.{idx - 1}.layers.4", f"block.{idx - 1}.conv1")
+ new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}")
+ new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1")
+ new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1")
else:
for i in range(2, 5):
- new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i - 1}")
- new_key = new_key.replace(f"block.{idx - 1}.layers.0", f"block.{idx - 1}.snake1")
- new_key = new_key.replace(f"block.{idx - 1}.layers.1", f"block.{idx - 1}.conv_t1")
+ new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}")
+ new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1")
+ new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1")
new_key = new_key.replace("layers.0.beta", "snake1.beta")
new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
@@ -118,9 +118,9 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
new_key = new_key.replace("layers.3.weight_", "conv2.weight_")
if idx == num_autoencoder_layers + 1:
- new_key = new_key.replace(f"block.{idx - 1}", "snake1")
+ new_key = new_key.replace(f"block.{idx-1}", "snake1")
elif idx == num_autoencoder_layers + 2:
- new_key = new_key.replace(f"block.{idx - 1}", "conv2")
+ new_key = new_key.replace(f"block.{idx-1}", "conv2")
else:
new_key = new_key
diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py
index e46410ccb3bd..3243ce294b26 100644
--- a/scripts/convert_svd_to_diffusers.py
+++ b/scripts/convert_svd_to_diffusers.py
@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint(
# TODO resnet time_mixer.mix_factor
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
- new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
- unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
- )
+ new_checkpoint[
+ f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
+ ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
if len(attentions):
paths = renew_attention_paths(attentions)
@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint(
)
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
- new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
- unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
- )
+ new_checkpoint[
+ f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
+ ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values():
diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py
index fe62d18faff0..7da6b4094986 100644
--- a/scripts/convert_vq_diffusion_to_diffusers.py
+++ b/scripts/convert_vq_diffusion_to_diffusers.py
@@ -51,9 +51,9 @@
def vqvae_model_from_original_config(original_config):
- assert original_config["target"] in PORTED_VQVAES, (
- f"{original_config['target']} has not yet been ported to diffusers."
- )
+ assert (
+ original_config["target"] in PORTED_VQVAES
+ ), f"{original_config['target']} has not yet been ported to diffusers."
original_config = original_config["params"]
@@ -464,15 +464,15 @@ def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_p
def transformer_model_from_original_config(
original_diffusion_config, original_transformer_config, original_content_embedding_config
):
- assert original_diffusion_config["target"] in PORTED_DIFFUSIONS, (
- f"{original_diffusion_config['target']} has not yet been ported to diffusers."
- )
- assert original_transformer_config["target"] in PORTED_TRANSFORMERS, (
- f"{original_transformer_config['target']} has not yet been ported to diffusers."
- )
- assert original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS, (
- f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
- )
+ assert (
+ original_diffusion_config["target"] in PORTED_DIFFUSIONS
+ ), f"{original_diffusion_config['target']} has not yet been ported to diffusers."
+ assert (
+ original_transformer_config["target"] in PORTED_TRANSFORMERS
+ ), f"{original_transformer_config['target']} has not yet been ported to diffusers."
+ assert (
+ original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS
+ ), f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
original_diffusion_config = original_diffusion_config["params"]
original_transformer_config = original_transformer_config["params"]
From 84115dc1df74781698505cf3e73f7f0470a19454 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Tue, 28 Jan 2025 20:35:15 +0800
Subject: [PATCH 24/68] reset
---
tests/models/test_modeling_common.py | 12 ++---
.../test_models_transformer_sd3.py | 12 ++---
.../unets/test_models_unet_2d_condition.py | 36 +++++++-------
tests/others/test_image_processor.py | 30 ++++++------
tests/pipelines/amused/test_amused.py | 3 +-
tests/pipelines/amused/test_amused_img2img.py | 3 +-
tests/pipelines/amused/test_amused_inpaint.py | 3 +-
.../aura_flow/test_pipeline_aura_flow.py | 24 +++++-----
.../blipdiffusion/test_blipdiffusion.py | 6 +--
tests/pipelines/cogvideo/test_cogvideox.py | 24 +++++-----
.../cogvideo/test_cogvideox_fun_control.py | 24 +++++-----
.../cogvideo/test_cogvideox_image2video.py | 24 +++++-----
.../cogvideo/test_cogvideox_video2video.py | 24 +++++-----
.../test_controlnet_blip_diffusion.py | 6 +--
.../controlnet_flux/test_controlnet_flux.py | 6 +--
.../test_controlnet_flux_img2img.py | 24 +++++-----
.../test_controlnet_hunyuandit.py | 6 +--
.../test_controlnet_inpaint_sd3.py | 6 +--
.../controlnet_sd3/test_controlnet_sd3.py | 6 +--
tests/pipelines/dit/test_dit.py | 3 +-
tests/pipelines/flux/test_pipeline_flux.py | 24 +++++-----
.../flux/test_pipeline_flux_control.py | 24 +++++-----
.../test_pipeline_flux_control_inpaint.py | 24 +++++-----
.../pipelines/hunyuan_dit/test_hunyuan_dit.py | 24 +++++-----
tests/pipelines/kandinsky/test_kandinsky.py | 12 ++---
.../kandinsky/test_kandinsky_combined.py | 36 +++++++-------
.../kandinsky/test_kandinsky_img2img.py | 16 +++----
.../kandinsky/test_kandinsky_inpaint.py | 14 +++---
.../pipelines/kandinsky2_2/test_kandinsky.py | 12 ++---
.../kandinsky2_2/test_kandinsky_combined.py | 36 +++++++-------
.../kandinsky2_2/test_kandinsky_controlnet.py | 12 ++---
.../test_kandinsky_controlnet_img2img.py | 14 +++---
.../kandinsky2_2/test_kandinsky_img2img.py | 14 +++---
.../kandinsky2_2/test_kandinsky_inpaint.py | 14 +++---
tests/pipelines/kandinsky3/test_kandinsky3.py | 6 +--
.../kandinsky3/test_kandinsky3_img2img.py | 6 +--
tests/pipelines/pag/test_pag_animatediff.py | 6 +--
tests/pipelines/pag/test_pag_controlnet_sd.py | 6 +--
.../pag/test_pag_controlnet_sd_inpaint.py | 6 +--
.../pipelines/pag/test_pag_controlnet_sdxl.py | 6 +--
.../pag/test_pag_controlnet_sdxl_img2img.py | 6 +--
tests/pipelines/pag/test_pag_hunyuan_dit.py | 24 +++++-----
tests/pipelines/pag/test_pag_kolors.py | 6 +--
tests/pipelines/pag/test_pag_pixart_sigma.py | 6 +--
tests/pipelines/pag/test_pag_sana.py | 6 +--
tests/pipelines/pag/test_pag_sd.py | 18 +++----
tests/pipelines/pag/test_pag_sd3.py | 30 ++++++------
tests/pipelines/pag/test_pag_sd3_img2img.py | 18 +++----
tests/pipelines/pag/test_pag_sd_img2img.py | 18 +++----
tests/pipelines/pag/test_pag_sd_inpaint.py | 12 ++---
tests/pipelines/pag/test_pag_sdxl.py | 18 +++----
tests/pipelines/pag/test_pag_sdxl_img2img.py | 18 +++----
tests/pipelines/pag/test_pag_sdxl_inpaint.py | 18 +++----
tests/pipelines/pixart_sigma/test_pixart.py | 24 +++++-----
tests/pipelines/shap_e/test_shap_e_img2img.py | 2 +-
.../test_stable_cascade_combined.py | 12 ++---
.../stable_diffusion/test_stable_diffusion.py | 48 +++++++++----------
.../test_pipeline_stable_diffusion_3.py | 24 +++++-----
.../test_stable_diffusion_xl.py | 30 ++++++------
.../test_stable_diffusion_xl_inpaint.py | 12 ++---
tests/pipelines/test_pipelines.py | 24 +++++-----
tests/pipelines/test_pipelines_common.py | 48 +++++++++----------
.../wuerstchen/test_wuerstchen_combined.py | 12 ++---
tests/schedulers/test_scheduler_dpm_multi.py | 6 +--
tests/schedulers/test_scheduler_dpm_single.py | 6 +--
.../test_scheduler_edm_dpmsolver_multistep.py | 6 +--
tests/schedulers/test_scheduler_euler.py | 12 ++---
tests/schedulers/test_scheduler_heun.py | 6 +--
.../single_file/single_file_testing_utils.py | 24 +++++-----
.../test_model_autoencoder_dc_single_file.py | 18 +++----
.../test_model_controlnet_single_file.py | 6 +--
...test_model_flux_transformer_single_file.py | 6 +--
.../test_model_motion_adapter_single_file.py | 24 +++++-----
.../test_model_sd_cascade_unet_single_file.py | 24 +++++-----
.../single_file/test_model_vae_single_file.py | 6 +--
utils/log_reports.py | 2 +-
utils/update_metadata.py | 3 +-
77 files changed, 591 insertions(+), 586 deletions(-)
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 9b53b44bb9bf..c3cb082b0ef1 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -292,9 +292,9 @@ def test_one_request_upon_cached(self):
)
download_requests = [r.method for r in m.request_history]
- assert download_requests.count("HEAD") == 3, (
- "3 HEAD requests one for config, one for model, and one for shard index file."
- )
+ assert (
+ download_requests.count("HEAD") == 3
+ ), "3 HEAD requests one for config, one for model, and one for shard index file."
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
with requests_mock.mock(real_http=True) as m:
@@ -306,9 +306,9 @@ def test_one_request_upon_cached(self):
)
cache_requests = [r.method for r in m.request_history]
- assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
- "We should call only `model_info` to check for commit hash and knowing if shard index is present."
- )
+ assert (
+ "HEAD" == cache_requests[0] and len(cache_requests) == 2
+ ), "We should call only `model_info` to check for commit hash and knowing if shard index is present."
def test_weight_overwrite(self):
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py
index aef08c1f3b68..2531381dc7c8 100644
--- a/tests/models/transformers/test_models_transformer_sd3.py
+++ b/tests/models/transformers/test_models_transformer_sd3.py
@@ -91,9 +91,9 @@ def test_xformers_enable_works(self):
model.enable_xformers_memory_efficient_attention()
- assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
- "xformers is not enabled"
- )
+ assert (
+ model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
+ ), "xformers is not enabled"
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
@@ -165,9 +165,9 @@ def test_xformers_enable_works(self):
model.enable_xformers_memory_efficient_attention()
- assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
- "xformers is not enabled"
- )
+ assert (
+ model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
+ ), "xformers is not enabled"
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index 804b01a26971..57f6e4ee440b 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -651,22 +651,22 @@ def test_model_xattn_mask(self, mask_dtype):
keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype)
full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample
- assert full_cond_keepallmask_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
- "a 'keep all' mask should give the same result as no mask"
- )
+ assert full_cond_keepallmask_out.allclose(
+ full_cond_out, rtol=1e-05, atol=1e-05
+ ), "a 'keep all' mask should give the same result as no mask"
trunc_cond = cond[:, :-1, :]
trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample
- assert not trunc_cond_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
- "discarding the last token from our cond should change the result"
- )
+ assert not trunc_cond_out.allclose(
+ full_cond_out, rtol=1e-05, atol=1e-05
+ ), "discarding the last token from our cond should change the result"
batch, tokens, _ = cond.shape
mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype)
masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample
- assert masked_cond_out.allclose(trunc_cond_out, rtol=1e-05, atol=1e-05), (
- "masking the last token from our cond should be equivalent to truncating that token out of the condition"
- )
+ assert masked_cond_out.allclose(
+ trunc_cond_out, rtol=1e-05, atol=1e-05
+ ), "masking the last token from our cond should be equivalent to truncating that token out of the condition"
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
@@ -694,9 +694,9 @@ def test_model_xattn_padding(self):
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
- assert trunc_mask_out.allclose(keeplast_out), (
- "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
- )
+ assert trunc_mask_out.allclose(
+ keeplast_out
+ ), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
def test_custom_diffusion_processors(self):
# enable deterministic behavior for gradient checkpointing
@@ -1111,12 +1111,12 @@ def test_load_attn_procs_raise_warning(self):
with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample
- assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
- "LoRA injected UNet should produce different results."
- )
- assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
- "Loading from a saved checkpoint should produce identical results."
- )
+ assert not torch.allclose(
+ non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4
+ ), "LoRA injected UNet should produce different results."
+ assert torch.allclose(
+ lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
+ ), "Loading from a saved checkpoint should produce identical results."
@require_peft_backend
def test_save_attn_procs_raise_warning(self):
diff --git a/tests/others/test_image_processor.py b/tests/others/test_image_processor.py
index 071194c59ead..3397ca9e394a 100644
--- a/tests/others/test_image_processor.py
+++ b/tests/others/test_image_processor.py
@@ -65,9 +65,9 @@ def test_vae_image_processor_pt(self):
)
out_np = self.to_np(out)
in_np = (input_np * 255).round() if output_type == "pil" else input_np
- assert np.abs(in_np - out_np).max() < 1e-6, (
- f"decoded output does not match input for output_type {output_type}"
- )
+ assert (
+ np.abs(in_np - out_np).max() < 1e-6
+ ), f"decoded output does not match input for output_type {output_type}"
def test_vae_image_processor_np(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
@@ -78,9 +78,9 @@ def test_vae_image_processor_np(self):
out_np = self.to_np(out)
in_np = (input_np * 255).round() if output_type == "pil" else input_np
- assert np.abs(in_np - out_np).max() < 1e-6, (
- f"decoded output does not match input for output_type {output_type}"
- )
+ assert (
+ np.abs(in_np - out_np).max() < 1e-6
+ ), f"decoded output does not match input for output_type {output_type}"
def test_vae_image_processor_pil(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
@@ -93,9 +93,9 @@ def test_vae_image_processor_pil(self):
for i, o in zip(input_pil, out):
in_np = np.array(i)
out_np = self.to_np(out) if output_type == "pil" else (self.to_np(out) * 255).round()
- assert np.abs(in_np - out_np).max() < 1e-6, (
- f"decoded output does not match input for output_type {output_type}"
- )
+ assert (
+ np.abs(in_np - out_np).max() < 1e-6
+ ), f"decoded output does not match input for output_type {output_type}"
def test_preprocess_input_3d(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
@@ -293,9 +293,9 @@ def test_vae_image_processor_resize_pt(self):
scale = 2
out_pt = image_processor.resize(image=input_pt, height=h // scale, width=w // scale)
exp_pt_shape = (b, c, h // scale, w // scale)
- assert out_pt.shape == exp_pt_shape, (
- f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."
- )
+ assert (
+ out_pt.shape == exp_pt_shape
+ ), f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."
def test_vae_image_processor_resize_np(self):
image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1)
@@ -305,6 +305,6 @@ def test_vae_image_processor_resize_np(self):
input_np = self.to_np(input_pt)
out_np = image_processor.resize(image=input_np, height=h // scale, width=w // scale)
exp_np_shape = (b, h // scale, w // scale, c)
- assert out_np.shape == exp_np_shape, (
- f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
- )
+ assert (
+ out_np.shape == exp_np_shape
+ ), f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py
index 4d950f90f773..2dfc36a6ce45 100644
--- a/tests/pipelines/amused/test_amused.py
+++ b/tests/pipelines/amused/test_amused.py
@@ -125,7 +125,8 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
@unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self): ...
+ def test_inference_batch_single_identical(self):
+ ...
@slow
diff --git a/tests/pipelines/amused/test_amused_img2img.py b/tests/pipelines/amused/test_amused_img2img.py
index 942735f15707..2699bbe7f56f 100644
--- a/tests/pipelines/amused/test_amused_img2img.py
+++ b/tests/pipelines/amused/test_amused_img2img.py
@@ -126,7 +126,8 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
@unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self): ...
+ def test_inference_batch_single_identical(self):
+ ...
@slow
diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py
index 541b988f1798..645379a7eab1 100644
--- a/tests/pipelines/amused/test_amused_inpaint.py
+++ b/tests/pipelines/amused/test_amused_inpaint.py
@@ -130,7 +130,8 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
@unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self): ...
+ def test_inference_batch_single_identical(self):
+ ...
@slow
diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
index 21b436135725..bee905f9ae13 100644
--- a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
+++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
@@ -139,9 +139,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- )
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -155,15 +155,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
@unittest.skip("xformers attention processor does not exist for AuraFlow")
def test_xformers_attention_forwardGenerator_pass(self):
diff --git a/tests/pipelines/blipdiffusion/test_blipdiffusion.py b/tests/pipelines/blipdiffusion/test_blipdiffusion.py
index 2a86a38b27ae..6d422745ce5a 100644
--- a/tests/pipelines/blipdiffusion/test_blipdiffusion.py
+++ b/tests/pipelines/blipdiffusion/test_blipdiffusion.py
@@ -195,6 +195,6 @@ def test_blipdiffusion(self):
[0.5329548, 0.8372512, 0.33269387, 0.82096875, 0.43657133, 0.3783, 0.5953028, 0.51934963, 0.42142007]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py
index 81591cd0874d..750f20f8fbe5 100644
--- a/tests/pipelines/cogvideo/test_cogvideox.py
+++ b/tests/pipelines/cogvideo/test_cogvideox.py
@@ -295,9 +295,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- )
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -311,15 +311,15 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
@slow
diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
index f5123d385749..c936bad4c3d5 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
@@ -298,9 +298,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- )
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -314,12 +314,12 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
diff --git a/tests/pipelines/cogvideo/test_cogvideox_image2video.py b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
index ec4e51bd1bad..cac47f1a83d4 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_image2video.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
@@ -317,9 +317,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- )
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -333,15 +333,15 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
@slow
diff --git a/tests/pipelines/cogvideo/test_cogvideox_video2video.py b/tests/pipelines/cogvideo/test_cogvideox_video2video.py
index b1ac8cbd90ed..4d836cb5e2a4 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_video2video.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_video2video.py
@@ -298,9 +298,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- )
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -314,12 +314,12 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
diff --git a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
index 94c87433d882..b4d3e3aaa8ed 100644
--- a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
+++ b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
@@ -219,6 +219,6 @@ def test_blipdiffusion_controlnet(self):
assert image.shape == (1, 16, 16, 4)
expected_slice = np.array([0.7953, 0.7136, 0.6597, 0.4779, 0.7389, 0.4111, 0.5826, 0.4150, 0.8422])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
index 10ad4ff1580a..8b9852dbec6e 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -174,9 +174,9 @@ def test_controlnet_flux(self):
[0.47387695, 0.63134766, 0.5605469, 0.61621094, 0.7207031, 0.7089844, 0.70410156, 0.6113281, 0.64160156]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f"Expected: {expected_slice}, got: {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
@unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
def test_xformers_attention_forwardGenerator_pass(self):
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
index 6c0d947c5266..02270d7fbd00 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
@@ -194,9 +194,9 @@ def test_fused_qkv_projections(self):
original_image_slice = image[0, -3:, -3:, -1]
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- )
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -210,15 +210,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
index ea611fb68acb..5c6054ccb605 100644
--- a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
+++ b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
@@ -157,9 +157,9 @@ def test_controlnet_hunyuandit(self):
[0.6953125, 0.89208984, 0.59375, 0.5078125, 0.5786133, 0.6035156, 0.5839844, 0.53564453, 0.52246094]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f"Expected: {expected_slice}, got: {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
index d9f5dcad7d61..2cd57ce56d52 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
@@ -194,9 +194,9 @@ def test_controlnet_inpaint_sd3(self):
[0.51708984, 0.7421875, 0.4580078, 0.6435547, 0.65625, 0.43603516, 0.5151367, 0.65722656, 0.60839844]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f"Expected: {expected_slice}, got: {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
@unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
def test_xformers_attention_forwardGenerator_pass(self):
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
index 7e5bd28bebb4..e1894d555c3c 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -200,9 +200,9 @@ def run_pipe(self, components, use_sd35=False):
else:
expected_slice = np.array([1.0000, 0.9072, 0.4209, 0.2744, 0.5737, 0.3840, 0.6113, 0.6250, 0.6328])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f"Expected: {expected_slice}, got: {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
def test_controlnet_sd3(self):
components = self.get_dummy_components()
diff --git a/tests/pipelines/dit/test_dit.py b/tests/pipelines/dit/test_dit.py
index 18732c0058de..30883ac4a63d 100644
--- a/tests/pipelines/dit/test_dit.py
+++ b/tests/pipelines/dit/test_dit.py
@@ -149,7 +149,8 @@ def test_dit_512(self):
for word, image in zip(words, images):
expected_image = load_numpy(
- f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}_512.npy"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ f"/dit/{word}_512.npy"
)
assert np.abs((expected_image - image).max()) < 1e-1
diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
index 7417a59cd9cf..bab343a5954c 100644
--- a/tests/pipelines/flux/test_pipeline_flux.py
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -173,9 +173,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- )
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -189,15 +189,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py
index 3ffad261a8a5..7fdb19327213 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control.py
@@ -163,9 +163,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- )
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -179,15 +179,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
index 37ebf4493595..c5ff02a525f2 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
@@ -174,9 +174,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- )
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -190,15 +190,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
index fd203b3758be..6c9117a55c36 100644
--- a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
+++ b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
@@ -270,9 +270,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- )
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -288,15 +288,15 @@ def test_fused_qkv_projections(self):
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
@slow
diff --git a/tests/pipelines/kandinsky/test_kandinsky.py b/tests/pipelines/kandinsky/test_kandinsky.py
index 8d4ae7046a28..1a13ec75d082 100644
--- a/tests/pipelines/kandinsky/test_kandinsky.py
+++ b/tests/pipelines/kandinsky/test_kandinsky.py
@@ -239,12 +239,12 @@ def test_kandinsky(self):
expected_slice = np.array([1.0000, 1.0000, 0.2766, 1.0000, 0.5447, 0.1737, 1.0000, 0.4316, 0.9024])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
@require_torch_gpu
def test_offloads(self):
diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py
index c1a1b3ab4522..3c8767a708d4 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py
@@ -98,12 +98,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.2893, 0.1464, 0.4603, 0.3529, 0.4612, 0.7701, 0.4027, 0.3051, 0.5155])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
@require_torch_gpu
def test_offloads(self):
@@ -206,12 +206,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.4852, 0.4136, 0.4539, 0.4781, 0.4680, 0.5217, 0.4973, 0.4089, 0.4977])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
@require_torch_gpu
def test_offloads(self):
@@ -318,12 +318,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.0320, 0.0860, 0.4013, 0.0518, 0.2484, 0.5847, 0.4411, 0.2321, 0.4593])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
@require_torch_gpu
def test_offloads(self):
diff --git a/tests/pipelines/kandinsky/test_kandinsky_img2img.py b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
index 0d5c74ecd0e6..23f13ffee223 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
@@ -260,12 +260,12 @@ def test_kandinsky_img2img(self):
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5816, 0.5872, 0.4634, 0.5982, 0.4767, 0.4710, 0.4669, 0.4717, 0.4966])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
@require_torch_gpu
def test_offloads(self):
@@ -320,7 +320,7 @@ def test_kandinsky_img2img(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
)
prompt = "A red cartoon frog, 4k"
@@ -386,7 +386,7 @@ def test_kandinsky_img2img_ddpm(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/frog.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/frog.png"
)
prompt = "A red cartoon frog, 4k"
diff --git a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
index 8cdc93a7c82e..ebb1a4d88739 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
@@ -255,12 +255,12 @@ def test_kandinsky_inpaint(self):
expected_slice = np.array([0.8222, 0.8896, 0.4373, 0.8088, 0.4905, 0.2609, 0.6816, 0.4291, 0.5129])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@@ -318,7 +318,7 @@ def test_kandinsky_inpaint(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
)
mask = np.zeros((768, 768), dtype=np.float32)
mask[:250, 250:-250] = 1
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky.py b/tests/pipelines/kandinsky2_2/test_kandinsky.py
index 728a1d67a464..cbd9166efada 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky.py
@@ -208,13 +208,13 @@ def test_kandinsky(self):
expected_slice = np.array([0.3420, 0.9505, 0.3919, 1.0000, 0.5188, 0.3109, 0.6139, 0.5624, 0.6811])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
index 413e84dd8d15..bbf2f08a7b08 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
@@ -103,12 +103,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.3076, 0.2729, 0.5668, 0.0522, 0.3384, 0.7028, 0.4908, 0.3659, 0.6243])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
@require_torch_gpu
def test_offloads(self):
@@ -227,12 +227,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.4445, 0.4287, 0.4596, 0.3919, 0.3730, 0.5039, 0.4834, 0.4269, 0.5521])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
@require_torch_gpu
def test_offloads(self):
@@ -350,12 +350,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.5039, 0.4926, 0.4898, 0.4978, 0.4838, 0.4942, 0.4738, 0.4702, 0.4816])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
@require_torch_gpu
def test_offloads(self):
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
index 10a95d6177b2..1f3219e0d69e 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
@@ -210,13 +210,13 @@ def test_kandinsky_controlnet(self):
[0.6959826, 0.868279, 0.7558092, 0.68769467, 0.85805804, 0.65977496, 0.44885302, 0.5959111, 0.4251595]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
index 58fbbecc0569..20944aa3d6f8 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
@@ -218,12 +218,12 @@ def test_kandinsky_controlnet_img2img(self):
expected_slice = np.array(
[0.54985034, 0.55509365, 0.52561504, 0.5570494, 0.5593818, 0.5263979, 0.50285643, 0.5069846, 0.51196736]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1.75e-3)
@@ -254,7 +254,7 @@ def test_kandinsky_controlnet_img2img(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
)
init_image = init_image.resize((512, 512))
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
index 34f089fcf1e7..26d8b45cf900 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
@@ -226,12 +226,12 @@ def test_kandinsky_img2img(self):
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5712, 0.5443, 0.4725, 0.6195, 0.5184, 0.4651, 0.4473, 0.4590, 0.5016])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=2e-1)
@@ -259,7 +259,7 @@ def test_kandinsky_img2img(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
)
prompt = "A red cartoon frog, 4k"
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
index be2d90ea9c53..25cf4bbed456 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
@@ -233,12 +233,12 @@ def test_kandinsky_inpaint(self):
[0.50775903, 0.49527195, 0.48824543, 0.50192237, 0.48644906, 0.49373814, 0.4780598, 0.47234827, 0.48327848]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@@ -313,7 +313,7 @@ def test_kandinsky_inpaint(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
)
mask = np.zeros((768, 768), dtype=np.float32)
mask[:250, 250:-250] = 1
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3.py b/tests/pipelines/kandinsky3/test_kandinsky3.py
index e80d5c61fd72..941ef9093361 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3.py
@@ -155,9 +155,9 @@ def test_kandinsky3(self):
expected_slice = np.array([0.3768, 0.4373, 0.4865, 0.4890, 0.4299, 0.5122, 0.4921, 0.4924, 0.5599])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
index 79468077ecff..8c817df32e0c 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
@@ -180,9 +180,9 @@ def test_kandinsky3_img2img(self):
[0.576259, 0.6132097, 0.41703486, 0.603196, 0.62062526, 0.4655338, 0.5434324, 0.5660727, 0.65433365]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py
index 902958ce4121..59ce9cc0a987 100644
--- a/tests/pipelines/pag/test_pag_animatediff.py
+++ b/tests/pipelines/pag/test_pag_animatediff.py
@@ -450,9 +450,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).frames[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_controlnet_sd.py b/tests/pipelines/pag/test_pag_controlnet_sd.py
index e59b6e676676..8a7eb6f0c675 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sd.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sd.py
@@ -171,9 +171,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
index 969737f22ee4..0a7413e99926 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
@@ -168,9 +168,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl.py b/tests/pipelines/pag/test_pag_controlnet_sdxl.py
index 5323bad37217..6400cc2b7cab 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sdxl.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sdxl.py
@@ -189,9 +189,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py b/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
index 992de5cdbae8..b02f4d8b4561 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
@@ -191,9 +191,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_hunyuan_dit.py b/tests/pipelines/pag/test_pag_hunyuan_dit.py
index 26852744f9e0..db0e257760ed 100644
--- a/tests/pipelines/pag/test_pag_hunyuan_dit.py
+++ b/tests/pipelines/pag/test_pag_hunyuan_dit.py
@@ -271,15 +271,15 @@ def test_fused_qkv_projections(self):
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -292,9 +292,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_kolors.py b/tests/pipelines/pag/test_pag_kolors.py
index 5c50d8eba5ae..cf9466988d85 100644
--- a/tests/pipelines/pag/test_pag_kolors.py
+++ b/tests/pipelines/pag/test_pag_kolors.py
@@ -138,9 +138,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_pixart_sigma.py b/tests/pipelines/pag/test_pag_pixart_sigma.py
index 072dd80a4da0..7de19e0f00fc 100644
--- a/tests/pipelines/pag/test_pag_pixart_sigma.py
+++ b/tests/pipelines/pag/test_pag_pixart_sigma.py
@@ -120,9 +120,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe.__class__.__name__}."
out = pipe(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_sana.py b/tests/pipelines/pag/test_pag_sana.py
index ee1e359383e9..a2c657297860 100644
--- a/tests/pipelines/pag/test_pag_sana.py
+++ b/tests/pipelines/pag/test_pag_sana.py
@@ -268,9 +268,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_sd.py b/tests/pipelines/pag/test_pag_sd.py
index 711945308d37..17e3f7038439 100644
--- a/tests/pipelines/pag/test_pag_sd.py
+++ b/tests/pipelines/pag/test_pag_sd.py
@@ -155,9 +155,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -322,9 +322,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
- f"output is different from expected, {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
def test_pag_uncond(self):
pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -339,6 +339,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
- f"output is different from expected, {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
diff --git a/tests/pipelines/pag/test_pag_sd3.py b/tests/pipelines/pag/test_pag_sd3.py
index 5183756913c2..627d613ee20d 100644
--- a/tests/pipelines/pag/test_pag_sd3.py
+++ b/tests/pipelines/pag/test_pag_sd3.py
@@ -203,9 +203,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- )
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -219,15 +219,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -240,9 +240,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_sd3_img2img.py b/tests/pipelines/pag/test_pag_sd3_img2img.py
index 694a86577dbf..bffcd254e2c5 100644
--- a/tests/pipelines/pag/test_pag_sd3_img2img.py
+++ b/tests/pipelines/pag/test_pag_sd3_img2img.py
@@ -148,9 +148,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
@@ -253,9 +253,9 @@ def test_pag_cfg(self):
0.17822266,
]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
- f"output is different from expected, {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(
@@ -271,6 +271,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.1508789, 0.16210938, 0.17138672, 0.16210938, 0.17089844, 0.16137695, 0.16235352, 0.16430664, 0.16455078]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
- f"output is different from expected, {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
diff --git a/tests/pipelines/pag/test_pag_sd_img2img.py b/tests/pipelines/pag/test_pag_sd_img2img.py
index d540a2257140..f44204f82486 100644
--- a/tests/pipelines/pag/test_pag_sd_img2img.py
+++ b/tests/pipelines/pag/test_pag_sd_img2img.py
@@ -160,9 +160,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -259,9 +259,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
- f"output is different from expected, {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -277,6 +277,6 @@ def test_pag_uncond(self):
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
- f"output is different from expected, {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
diff --git a/tests/pipelines/pag/test_pag_sd_inpaint.py b/tests/pipelines/pag/test_pag_sd_inpaint.py
index 00d7e9f9c29d..a528b66cc72a 100644
--- a/tests/pipelines/pag/test_pag_sd_inpaint.py
+++ b/tests/pipelines/pag/test_pag_sd_inpaint.py
@@ -296,9 +296,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.38793945, 0.4111328, 0.47924805, 0.39208984, 0.4165039, 0.41674805, 0.37060547, 0.36791992, 0.40625]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
- f"output is different from expected, {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
def test_pag_uncond(self):
pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -313,6 +313,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.3876953, 0.40356445, 0.4934082, 0.39697266, 0.41674805, 0.41015625, 0.375, 0.36914062, 0.40649414]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
- f"output is different from expected, {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
diff --git a/tests/pipelines/pag/test_pag_sdxl.py b/tests/pipelines/pag/test_pag_sdxl.py
index c2e10a6325b2..589573385677 100644
--- a/tests/pipelines/pag/test_pag_sdxl.py
+++ b/tests/pipelines/pag/test_pag_sdxl.py
@@ -168,9 +168,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -331,9 +331,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.3123679, 0.31725878, 0.32026544, 0.327533, 0.3266391, 0.3303998, 0.33544615, 0.34181812, 0.34102726]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
- f"output is different from expected, {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
def test_pag_uncond(self):
pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -348,6 +348,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.47400922, 0.48650584, 0.4839625, 0.4724013, 0.4890427, 0.49544555, 0.51707107, 0.54299414, 0.5224372]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
- f"output is different from expected, {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
diff --git a/tests/pipelines/pag/test_pag_sdxl_img2img.py b/tests/pipelines/pag/test_pag_sdxl_img2img.py
index 6006f8de0aa8..33bd47bfee10 100644
--- a/tests/pipelines/pag/test_pag_sdxl_img2img.py
+++ b/tests/pipelines/pag/test_pag_sdxl_img2img.py
@@ -216,9 +216,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -316,9 +316,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.20301354, 0.21078318, 0.2021082, 0.20277798, 0.20681083, 0.19562206, 0.20121682, 0.21562952, 0.21277016]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
- f"output is different from expected, {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -333,6 +333,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.21303111, 0.22188407, 0.2124992, 0.21365267, 0.18823743, 0.17569828, 0.21113116, 0.19419771, 0.18919235]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
- f"output is different from expected, {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
diff --git a/tests/pipelines/pag/test_pag_sdxl_inpaint.py b/tests/pipelines/pag/test_pag_sdxl_inpaint.py
index 1ff6c66a8830..8378b07e9f74 100644
--- a/tests/pipelines/pag/test_pag_sdxl_inpaint.py
+++ b/tests/pipelines/pag/test_pag_sdxl_inpaint.py
@@ -221,9 +221,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
- f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
- )
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -322,9 +322,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.41385046, 0.39608297, 0.4360491, 0.26872507, 0.32187328, 0.4242474, 0.2603805, 0.34167895, 0.46561807]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
- f"output is different from expected, {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
def test_pag_uncond(self):
pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -339,6 +339,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.41597816, 0.39302617, 0.44287828, 0.2687074, 0.28315824, 0.40582314, 0.20877528, 0.2380802, 0.39447647]
)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
- f"output is different from expected, {image_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py
index 3208be54f464..6e265b9d5eb8 100644
--- a/tests/pipelines/pixart_sigma/test_pixart.py
+++ b/tests/pipelines/pixart_sigma/test_pixart.py
@@ -328,9 +328,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- )
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -344,15 +344,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
@slow
diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py
index 72eee3e35eb1..ac7096874b31 100644
--- a/tests/pipelines/shap_e/test_shap_e_img2img.py
+++ b/tests/pipelines/shap_e/test_shap_e_img2img.py
@@ -266,7 +266,7 @@ def tearDown(self):
def test_shap_e_img2img(self):
input_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/corgi.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/shap_e/corgi.png"
)
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
index ad09b9ce8292..d256deed376c 100644
--- a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
+++ b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
@@ -198,12 +198,12 @@ def test_stable_cascade(self):
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
@require_torch_gpu
def test_offloads(self):
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 44157d040484..1e700bed03f8 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -288,15 +288,15 @@ def test_stable_diffusion_ays(self):
inputs["sigmas"] = sigma_schedule
output_sigmas = sd_pipe(**inputs).images
- assert np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3, (
- "ays timesteps and ays sigmas should have the same outputs"
- )
- assert np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3, (
- "use ays timesteps should have different outputs"
- )
- assert np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3, (
- "use ays sigmas should have different outputs"
- )
+ assert (
+ np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3
+ ), "ays timesteps and ays sigmas should have the same outputs"
+ assert (
+ np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3
+ ), "use ays timesteps should have different outputs"
+ assert (
+ np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3
+ ), "use ays sigmas should have different outputs"
def test_stable_diffusion_prompt_embeds(self):
components = self.get_dummy_components()
@@ -729,9 +729,9 @@ def test_freeu_enabled(self):
sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
output_freeu = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
- assert not np.allclose(output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]), (
- "Enabling of FreeU should lead to different results."
- )
+ assert not np.allclose(
+ output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
+ ), "Enabling of FreeU should lead to different results."
def test_freeu_disabled(self):
components = self.get_dummy_components()
@@ -754,9 +754,9 @@ def test_freeu_disabled(self):
prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)
).images
- assert np.allclose(output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]), (
- "Disabling of FreeU should lead to results similar to the default pipeline results."
- )
+ assert np.allclose(
+ output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]
+ ), "Disabling of FreeU should lead to results similar to the default pipeline results."
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -779,15 +779,15 @@ def test_fused_qkv_projections(self):
image = sd_pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
def test_pipeline_interrupt(self):
components = self.get_dummy_components()
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
index b555637ad9bf..df37090eeba2 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
@@ -202,9 +202,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- )
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -218,15 +218,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
def test_skip_guidance_layers(self):
components = self.get_dummy_components()
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index 17dd17ac7f56..f1422022a7aa 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -243,15 +243,15 @@ def test_stable_diffusion_ays(self):
inputs["sigmas"] = sigma_schedule
output_sigmas = sd_pipe(**inputs).images
- assert np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3, (
- "ays timesteps and ays sigmas should have the same outputs"
- )
- assert np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3, (
- "use ays timesteps should have different outputs"
- )
- assert np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3, (
- "use ays sigmas should have different outputs"
- )
+ assert (
+ np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3
+ ), "ays timesteps and ays sigmas should have the same outputs"
+ assert (
+ np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3
+ ), "use ays timesteps should have different outputs"
+ assert (
+ np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3
+ ), "use ays sigmas should have different outputs"
def test_stable_diffusion_xl_prompt_embeds(self):
components = self.get_dummy_components()
@@ -856,9 +856,9 @@ def new_step(self, *args, **kwargs):
inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}}
latents = pipe_1(**inputs_1).images[0]
- assert expected_steps_1 == done_steps, (
- f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
- )
+ assert (
+ expected_steps_1 == done_steps
+ ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
with self.assertRaises(ValueError) as cm:
inputs_2 = {
@@ -885,9 +885,9 @@ def new_step(self, *args, **kwargs):
pipe_3(**inputs_3).images[0]
assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :]
- assert expected_steps == done_steps, (
- f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
- )
+ assert (
+ expected_steps == done_steps
+ ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
for steps in [7, 11, 20]:
for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
index 551dec8f1cb5..c759f4c112d9 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
@@ -578,9 +578,9 @@ def new_step(self, *args, **kwargs):
inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}}
latents = pipe_1(**inputs_1).images[0]
- assert expected_steps_1 == done_steps, (
- f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
- )
+ assert (
+ expected_steps_1 == done_steps
+ ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
inputs_2 = {
**inputs,
@@ -594,9 +594,9 @@ def new_step(self, *args, **kwargs):
pipe_3(**inputs_3).images[0]
assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :]
- assert expected_steps == done_steps, (
- f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
- )
+ assert (
+ expected_steps == done_steps
+ ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
for steps in [7, 11, 20]:
for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index 2df025fc2bc1..6ce7c5d604f4 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -163,9 +163,9 @@ def test_one_request_upon_cached(self):
download_requests = [r.method for r in m.request_history]
assert download_requests.count("HEAD") == 15, "15 calls to files"
assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json"
- assert len(download_requests) == 32, (
- "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
- )
+ assert (
+ len(download_requests) == 32
+ ), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
@@ -175,9 +175,9 @@ def test_one_request_upon_cached(self):
cache_requests = [r.method for r in m.request_history]
assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
assert cache_requests.count("GET") == 1, "model info is only GET"
- assert len(cache_requests) == 2, (
- "We should call only `model_info` to check for _commit hash and `send_telemetry`"
- )
+ assert (
+ len(cache_requests) == 2
+ ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
def test_less_downloads_passed_object(self):
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -213,9 +213,9 @@ def test_less_downloads_passed_object_calls(self):
assert download_requests.count("HEAD") == 13, "13 calls to files"
# 17 - 2 because no call to config or model file for `safety_checker`
assert download_requests.count("GET") == 15, "13 calls to files + model_info + model_index.json"
- assert len(download_requests) == 28, (
- "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
- )
+ assert (
+ len(download_requests) == 28
+ ), "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
@@ -225,9 +225,9 @@ def test_less_downloads_passed_object_calls(self):
cache_requests = [r.method for r in m.request_history]
assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
assert cache_requests.count("GET") == 1, "model info is only GET"
- assert len(cache_requests) == 2, (
- "We should call only `model_info` to check for _commit hash and `send_telemetry`"
- )
+ assert (
+ len(cache_requests) == 2
+ ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
def test_download_only_pytorch(self):
with tempfile.TemporaryDirectory() as tmpdirname:
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 41d0f84fec4e..de5faa185c2f 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -184,12 +184,12 @@ def test_freeu(self):
inputs["output_type"] = "np"
output_no_freeu = pipe(**inputs)[0]
- assert not np.allclose(output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]), (
- "Enabling of FreeU should lead to different results."
- )
- assert np.allclose(output, output_no_freeu, atol=1e-2), (
- f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
- )
+ assert not np.allclose(
+ output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
+ ), "Enabling of FreeU should lead to different results."
+ assert np.allclose(
+ output, output_no_freeu, atol=1e-2
+ ), f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -210,12 +210,12 @@ def test_fused_qkv_projections(self):
and hasattr(component, "original_attn_processors")
and component.original_attn_processors is not None
):
- assert check_qkv_fusion_processors_exist(component), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- )
- assert check_qkv_fusion_matches_attn_procs_length(component, component.original_attn_processors), (
- "Something wrong with the attention processors concerning the fused QKV projections."
- )
+ assert check_qkv_fusion_processors_exist(
+ component
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_matches_attn_procs_length(
+ component, component.original_attn_processors
+ ), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
@@ -228,15 +228,15 @@ def test_fused_qkv_projections(self):
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
- "Fusion of QKV projections shouldn't affect the outputs."
- )
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- )
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
- )
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
class IPAdapterTesterMixin:
@@ -861,9 +861,9 @@ def test_from_pipe_consistent_forward_pass(self, expected_max_diff=1e-3):
for component in pipe_original.components.values():
if hasattr(component, "attn_processors"):
- assert all(type(proc) == AttnProcessor for proc in component.attn_processors.values()), (
- "`from_pipe` changed the attention processor in original pipeline."
- )
+ assert all(
+ type(proc) == AttnProcessor for proc in component.attn_processors.values()
+ ), "`from_pipe` changed the attention processor in original pipeline."
@require_accelerator
@require_accelerate_version_greater("0.14.0")
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
index 1c9790807fa8..a0e6e1417e67 100644
--- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
+++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
@@ -191,12 +191,12 @@ def test_wuerstchen(self):
expected_slice = np.array([0.7616304, 0.0, 1.0, 0.0, 1.0, 0.0, 0.05925313, 0.0, 0.951898])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
@require_torch_gpu
def test_offloads(self):
diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py
index 28c354709dc9..55b3202ad0be 100644
--- a/tests/schedulers/test_scheduler_dpm_multi.py
+++ b/tests/schedulers/test_scheduler_dpm_multi.py
@@ -357,9 +357,9 @@ def test_custom_timesteps(self):
prediction_type=prediction_type,
final_sigmas_type=final_sigmas_type,
)
- assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
- f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
- )
+ assert (
+ torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
+ ), f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py
index 0756a5ed71ff..7cbaa5cc5e8d 100644
--- a/tests/schedulers/test_scheduler_dpm_single.py
+++ b/tests/schedulers/test_scheduler_dpm_single.py
@@ -345,9 +345,9 @@ def test_custom_timesteps(self):
lower_order_final=lower_order_final,
final_sigmas_type=final_sigmas_type,
)
- assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
- f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}"
- )
+ assert (
+ torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
+ ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}"
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py b/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
index 8525ce61c40d..e97d64ec5f1d 100644
--- a/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
+++ b/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
@@ -188,9 +188,9 @@ def test_solver_order_and_type(self):
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
- assert not torch.isnan(sample).any(), (
- f"Samples have nan numbers, {order}, {solver_type}, {prediction_type}, {algorithm_type}"
- )
+ assert (
+ not torch.isnan(sample).any()
+ ), f"Samples have nan numbers, {order}, {solver_type}, {prediction_type}, {algorithm_type}"
def test_lower_order_final(self):
self.check_over_configs(lower_order_final=True)
diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py
index 01e173a631cd..4c7e02442cd0 100644
--- a/tests/schedulers/test_scheduler_euler.py
+++ b/tests/schedulers/test_scheduler_euler.py
@@ -245,9 +245,9 @@ def test_custom_timesteps(self):
interpolation_type=interpolation_type,
final_sigmas_type=final_sigmas_type,
)
- assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
- f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}"
- )
+ assert (
+ torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
+ ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}"
def test_custom_sigmas(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
@@ -260,9 +260,9 @@ def test_custom_sigmas(self):
prediction_type=prediction_type,
final_sigmas_type=final_sigmas_type,
)
- assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
- f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
- )
+ assert (
+ torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
+ ), f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_heun.py b/tests/schedulers/test_scheduler_heun.py
index 90012f5525ab..9e060c6d476f 100644
--- a/tests/schedulers/test_scheduler_heun.py
+++ b/tests/schedulers/test_scheduler_heun.py
@@ -216,9 +216,9 @@ def test_custom_timesteps(self):
prediction_type=prediction_type,
timestep_spacing=timestep_spacing,
)
- assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
- f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"
- )
+ assert (
+ torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
+ ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py
index 4e1713c9ceb1..4e7bc0af6842 100644
--- a/tests/single_file/single_file_testing_utils.py
+++ b/tests/single_file/single_file_testing_utils.py
@@ -72,9 +72,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
continue
assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
- assert isinstance(component, pipe.components[component_name].__class__), (
- f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
- )
+ assert isinstance(
+ component, pipe.components[component_name].__class__
+ ), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
for param_name, param_value in component.config.items():
if param_name in PARAMS_TO_IGNORE:
@@ -85,9 +85,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
pipe.components[component_name].config[param_name] = param_value
- assert pipe.components[component_name].config[param_name] == param_value, (
- f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
- )
+ assert (
+ pipe.components[component_name].config[param_name] == param_value
+ ), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
def test_single_file_components(self, pipe=None, single_file_pipe=None):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
@@ -253,9 +253,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
continue
assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
- assert isinstance(component, pipe.components[component_name].__class__), (
- f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
- )
+ assert isinstance(
+ component, pipe.components[component_name].__class__
+ ), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
for param_name, param_value in component.config.items():
if param_name in PARAMS_TO_IGNORE:
@@ -266,9 +266,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
pipe.components[component_name].config[param_name] = param_value
- assert pipe.components[component_name].config[param_name] == param_value, (
- f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
- )
+ assert (
+ pipe.components[component_name].config[param_name] == param_value
+ ), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
def test_single_file_components(self, pipe=None, single_file_pipe=None):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
diff --git a/tests/single_file/test_model_autoencoder_dc_single_file.py b/tests/single_file/test_model_autoencoder_dc_single_file.py
index 31b2eb6e36b0..b1faeb78776b 100644
--- a/tests/single_file/test_model_autoencoder_dc_single_file.py
+++ b/tests/single_file/test_model_autoencoder_dc_single_file.py
@@ -87,9 +87,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between pretrained loading and single file loading"
- )
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between pretrained loading and single file loading"
def test_single_file_in_type_variant_components(self):
# `in` variant checkpoints require passing in a `config` parameter
@@ -106,9 +106,9 @@ def test_single_file_in_type_variant_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between pretrained loading and single file loading"
- )
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between pretrained loading and single file loading"
def test_single_file_mix_type_variant_components(self):
repo_id = "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"
@@ -121,6 +121,6 @@ def test_single_file_mix_type_variant_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between pretrained loading and single file loading"
- )
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between pretrained loading and single file loading"
diff --git a/tests/single_file/test_model_controlnet_single_file.py b/tests/single_file/test_model_controlnet_single_file.py
index 3580d73531a3..bfcb802380a6 100644
--- a/tests/single_file/test_model_controlnet_single_file.py
+++ b/tests/single_file/test_model_controlnet_single_file.py
@@ -58,9 +58,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between single file loading and pretrained loading"
def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path)
diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py
index bf11faaa9c0e..0ec97db26a9e 100644
--- a/tests/single_file/test_model_flux_transformer_single_file.py
+++ b/tests/single_file/test_model_flux_transformer_single_file.py
@@ -58,9 +58,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between single file loading and pretrained loading"
def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
diff --git a/tests/single_file/test_model_motion_adapter_single_file.py b/tests/single_file/test_model_motion_adapter_single_file.py
index a747f16dc1db..b195f25d094b 100644
--- a/tests/single_file/test_model_motion_adapter_single_file.py
+++ b/tests/single_file/test_model_motion_adapter_single_file.py
@@ -40,9 +40,9 @@ def test_single_file_components_version_v1_5(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between pretrained loading and single file loading"
- )
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between pretrained loading and single file loading"
def test_single_file_components_version_v1_5_2(self):
ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt"
@@ -55,9 +55,9 @@ def test_single_file_components_version_v1_5_2(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between pretrained loading and single file loading"
- )
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between pretrained loading and single file loading"
def test_single_file_components_version_v1_5_3(self):
ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt"
@@ -70,9 +70,9 @@ def test_single_file_components_version_v1_5_3(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between pretrained loading and single file loading"
- )
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between pretrained loading and single file loading"
def test_single_file_components_version_sdxl_beta(self):
ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt"
@@ -85,6 +85,6 @@ def test_single_file_components_version_sdxl_beta(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between pretrained loading and single file loading"
- )
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between pretrained loading and single file loading"
diff --git a/tests/single_file/test_model_sd_cascade_unet_single_file.py b/tests/single_file/test_model_sd_cascade_unet_single_file.py
index 92b371c3fb41..08b04e3cd7e8 100644
--- a/tests/single_file/test_model_sd_cascade_unet_single_file.py
+++ b/tests/single_file/test_model_sd_cascade_unet_single_file.py
@@ -60,9 +60,9 @@ def test_single_file_components_stage_b(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between single file loading and pretrained loading"
def test_single_file_components_stage_b_lite(self):
model_single_file = StableCascadeUNet.from_single_file(
@@ -77,9 +77,9 @@ def test_single_file_components_stage_b_lite(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between single file loading and pretrained loading"
def test_single_file_components_stage_c(self):
model_single_file = StableCascadeUNet.from_single_file(
@@ -94,9 +94,9 @@ def test_single_file_components_stage_c(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between single file loading and pretrained loading"
def test_single_file_components_stage_c_lite(self):
model_single_file = StableCascadeUNet.from_single_file(
@@ -111,6 +111,6 @@ def test_single_file_components_stage_c_lite(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between single file loading and pretrained loading"
diff --git a/tests/single_file/test_model_vae_single_file.py b/tests/single_file/test_model_vae_single_file.py
index bba1726ae380..9db4cddb3c9d 100644
--- a/tests/single_file/test_model_vae_single_file.py
+++ b/tests/single_file/test_model_vae_single_file.py
@@ -91,9 +91,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between pretrained loading and single file loading"
- )
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between pretrained loading and single file loading"
def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
diff --git a/utils/log_reports.py b/utils/log_reports.py
index 5575c9ba8415..dd1b258519d7 100644
--- a/utils/log_reports.py
+++ b/utils/log_reports.py
@@ -35,7 +35,7 @@ def main(slack_channel_name=None):
if line.get("nodeid", "") != "":
test = line["nodeid"]
if line.get("duration", None) is not None:
- duration = f"{line['duration']:.4f}"
+ duration = f'{line["duration"]:.4f}'
if line.get("outcome", "") == "failed":
section_num_failed += 1
failed.append([test, duration, log.name.split("_")[0]])
diff --git a/utils/update_metadata.py b/utils/update_metadata.py
index 4fde581d4170..a97e65801c5f 100644
--- a/utils/update_metadata.py
+++ b/utils/update_metadata.py
@@ -104,7 +104,8 @@ def update_metadata(commit_sha: str):
if commit_sha is not None:
commit_message = (
- f"Update with commit {commit_sha}\n\nSee: https://github.com/huggingface/diffusers/commit/{commit_sha}"
+ f"Update with commit {commit_sha}\n\nSee: "
+ f"https://github.com/huggingface/diffusers/commit/{commit_sha}"
)
else:
commit_message = "Update"
From 95a103ffb55efdc0fc66cb55db0da949e3cc861f Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Tue, 28 Jan 2025 20:41:50 +0800
Subject: [PATCH 25/68] setback
---
src/diffusers/loaders/ip_adapter.py | 6 +-
.../loaders/lora_conversion_utils.py | 66 +++++++++----------
src/diffusers/models/model_loading_utils.py | 2 +-
.../models/transformers/transformer_2d.py | 6 +-
.../controlnet/pipeline_controlnet_inpaint.py | 4 +-
.../pipeline_controlnet_inpaint_sd_xl.py | 6 +-
...pipeline_controlnet_union_inpaint_sd_xl.py | 4 +-
.../pipeline_flux_controlnet_inpainting.py | 4 +-
.../pipelines/flux/pipeline_flux_inpaint.py | 4 +-
.../kandinsky/pipeline_kandinsky_combined.py | 2 +-
.../kandinsky/pipeline_kandinsky_inpaint.py | 2 +-
.../pag/pipeline_pag_controlnet_sd_inpaint.py | 6 +-
.../pipelines/pag/pipeline_pag_sd_inpaint.py | 6 +-
.../pag/pipeline_pag_sd_xl_inpaint.py | 6 +-
.../pipeline_paint_by_example.py | 2 +-
.../pipeline_flax_stable_diffusion_inpaint.py | 2 +-
.../pipeline_onnx_stable_diffusion_inpaint.py | 2 +-
.../pipeline_stable_diffusion_inpaint.py | 6 +-
...eline_stable_diffusion_instruct_pix2pix.py | 2 +-
...ipeline_stable_diffusion_latent_upscale.py | 2 +-
.../pipeline_stable_diffusion_upscale.py | 2 +-
.../pipeline_stable_diffusion_3_inpaint.py | 2 +-
.../pipeline_stable_diffusion_xl_inpaint.py | 6 +-
src/diffusers/quantizers/base.py | 12 ++--
src/diffusers/utils/deprecation_utils.py | 2 +-
src/diffusers/utils/dummy_pt_objects.py | 29 --------
.../dummy_torch_and_transformers_objects.py | 15 -----
src/diffusers/utils/import_utils.py | 2 +-
src/diffusers/utils/logging.py | 3 +-
src/diffusers/utils/state_dict_utils.py | 2 +-
src/diffusers/utils/testing_utils.py | 4 +-
31 files changed, 91 insertions(+), 128 deletions(-)
diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index 0870f059e8f0..7b691d1fe16e 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -292,7 +292,8 @@ def set_ip_adapter_scale(self, scale):
):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
- f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
+ f"Cannot assign {len(scale_configs)} scale_configs to "
+ f"{len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
@@ -591,7 +592,8 @@ def LinearStrengthModel(start, finish, size):
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
- f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
+ f"Cannot assign {len(scale_configs)} scale_configs to "
+ f"{len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index fecf5170a489..e064aeba43b6 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -177,9 +177,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
# Store DoRA scale if present.
if dora_present_in_unet:
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
- unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
- state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
- )
+ unet_state_dict[
+ diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
# Handle text encoder LoRAs.
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
@@ -199,13 +199,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
)
if lora_name.startswith(("lora_te_", "lora_te1_")):
- te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
- state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
- )
+ te_state_dict[
+ diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
elif lora_name.startswith("lora_te2_"):
- te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
- state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
- )
+ te2_state_dict[
+ diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
# Store alpha if present.
if lora_name_alpha in state_dict:
@@ -684,21 +684,21 @@ def swap_scale_shift(weight):
for lora_key in ["lora_A", "lora_B"]:
## time_text_embed.timestep_embedder <- time_in
- converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = (
- original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
- )
+ converted_state_dict[
+ f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
+ ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"] = (
- original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
- )
+ converted_state_dict[
+ f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
+ ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
- converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = (
- original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
- )
+ converted_state_dict[
+ f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
+ ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"] = (
- original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
- )
+ converted_state_dict[
+ f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
+ ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
## time_text_embed.text_embedder <- vector_in
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
@@ -720,21 +720,21 @@ def swap_scale_shift(weight):
# guidance
has_guidance = any("guidance" in k for k in original_state_dict)
if has_guidance:
- converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = (
- original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
- )
+ converted_state_dict[
+ f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
+ ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"] = (
- original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
- )
+ converted_state_dict[
+ f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
+ ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
- converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = (
- original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
- )
+ converted_state_dict[
+ f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
+ ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"] = (
- original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
- )
+ converted_state_dict[
+ f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
+ ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
# context_embedder
converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index d38898c34383..7e7445ef1239 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -181,7 +181,7 @@ def load_state_dict(
) from e
except (UnicodeDecodeError, ValueError):
raise OSError(
- f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. "
+ f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
)
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index 5515a7885098..a88ee6c9c9b8 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -211,9 +211,9 @@ def _init_continuous_input(self, norm_type):
def _init_vectorized_inputs(self, norm_type):
assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
- assert self.config.num_vector_embeds is not None, (
- "Transformer2DModel over discrete input must provide num_embed"
- )
+ assert (
+ self.config.num_vector_embeds is not None
+ ), "Transformer2DModel over discrete input must provide num_embed"
self.height = self.config.sample_size
self.width = self.config.sample_size
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index e7a84d4b6dfb..875dbed38c4d 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -650,7 +650,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -658,7 +658,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index 948728d56afc..38e63f56b2f3 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -743,7 +743,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -751,7 +751,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
@@ -1644,7 +1644,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
index c2006862280b..1ee63e5f7db6 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
@@ -726,7 +726,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -734,7 +734,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index ded22a5467a6..05fcb9449cfe 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -507,7 +507,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -515,7 +515,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
index ed5b08a03cb7..2be8e75973ef 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -485,7 +485,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -493,7 +493,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
index 5f8db26eef54..e653b8266f19 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
@@ -360,7 +360,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
"""
_load_connected_pipes = True
- model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
+ model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq"
_exclude_from_cpu_offload = ["prior_prior"]
def __init__(
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
index 769c834ec3cc..cce5f0b3d5bc 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
@@ -579,7 +579,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
index 6d89f16765a3..bc7a4b57affd 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
@@ -604,7 +604,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -612,7 +612,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
@@ -1340,7 +1340,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
index db652989cfc1..33abfb0be89f 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
@@ -683,7 +683,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -691,7 +691,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1191,7 +1191,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
index 8b06bdc9c969..fdf3df2f4d6a 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
@@ -737,7 +737,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -745,7 +745,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1509,7 +1509,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index 288f269a6563..55a9f47145a2 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -575,7 +575,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
index dd659306e002..abcba926160a 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
@@ -335,7 +335,7 @@ def _generate(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
index f2e1d87be87e..ddd2e27dedaf 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -475,7 +475,7 @@ def __call__(
"Incorrect configuration settings! The config of `pipeline.unet` expects"
f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 0f7be1a1bbcd..6f4e7f358952 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -660,7 +660,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -668,7 +668,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1226,7 +1226,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index e0748943ffff..7857bc58a8ad 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -401,7 +401,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents + num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index 42db88b03049..c6967bc393b5 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -600,7 +600,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents + num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index f9b6dcbf5ad2..dae4540ebe00 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -740,7 +740,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents + num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index 05a8757039cf..de9842913e98 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -1258,7 +1258,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.transformer` or your `mask_image` or `image` input."
)
elif num_channels_transformer != 16:
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 835c0af800da..920caf4d24a1 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -741,7 +741,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -749,7 +749,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1509,7 +1509,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py
index fa9ba98e6d0d..1c75b5bef933 100644
--- a/src/diffusers/quantizers/base.py
+++ b/src/diffusers/quantizers/base.py
@@ -215,15 +215,19 @@ def _dequantize(self, model):
)
@abstractmethod
- def _process_model_before_weight_loading(self, model, **kwargs): ...
+ def _process_model_before_weight_loading(self, model, **kwargs):
+ ...
@abstractmethod
- def _process_model_after_weight_loading(self, model, **kwargs): ...
+ def _process_model_after_weight_loading(self, model, **kwargs):
+ ...
@property
@abstractmethod
- def is_serializable(self): ...
+ def is_serializable(self):
+ ...
@property
@abstractmethod
- def is_trainable(self): ...
+ def is_trainable(self):
+ ...
diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py
index 4f001b3047d6..f482deddd2f4 100644
--- a/src/diffusers/utils/deprecation_utils.py
+++ b/src/diffusers/utils/deprecation_utils.py
@@ -40,7 +40,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn
line_number = call_frame.lineno
function = call_frame.function
key, value = next(iter(deprecated_kwargs.items()))
- raise TypeError(f"{function} in {filename} line {line_number - 1} got an unexpected keyword argument `{key}`")
+ raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
if len(values) == 0:
return
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 6e77873987ac..6a1978944c9f 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -275,20 +275,6 @@ def from_config(cls, *args, **kwargs):
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class CogView4Transformer2DModel(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
class ConsisIDTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -1353,21 +1339,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class CogView4DDIMScheduler(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
-
class DDIMInverseScheduler(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index bd45aa3c20ea..b899915c3046 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -377,21 +377,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CogView4Pipeline(metaclass=DummyObject):
- _backends = ["torch", "transformers"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch", "transformers"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
-
class CycleDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index 92bb2c1bdaa9..37535366ed44 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -236,7 +236,7 @@
_wandb_available = importlib.util.find_spec("wandb") is not None
try:
_wandb_version = importlib_metadata.version("wandb")
- logger.debug(f"Successfully imported wandb version {_wandb_version}")
+ logger.debug(f"Successfully imported wandb version {_wandb_version }")
except importlib_metadata.PackageNotFoundError:
_wandb_available = False
diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py
index b96e0e222cb1..6f93450c410c 100644
--- a/src/diffusers/utils/logging.py
+++ b/src/diffusers/utils/logging.py
@@ -60,7 +60,8 @@ def _get_default_logging_level() -> int:
return log_levels[env_level_str]
else:
logging.getLogger().warning(
- f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, has to be one of: {', '.join(log_levels.keys())}"
+ f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
+ f"has to be one of: { ', '.join(log_levels.keys()) }"
)
return _default_log_level
diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py
index 8efd6e6df51e..62b114ba67e3 100644
--- a/src/diffusers/utils/state_dict_utils.py
+++ b/src/diffusers/utils/state_dict_utils.py
@@ -329,7 +329,7 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names
kohya_ss_state_dict[kohya_key] = weight
if "lora_down" in kohya_key:
- alpha_key = f"{kohya_key.split('.')[0]}.alpha"
+ alpha_key = f'{kohya_key.split(".")[0]}.alpha'
kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
return kohya_ss_state_dict
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 0401da7c6044..7eda13716025 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -813,7 +813,7 @@ def pytest_terminal_summary_main(tr, id):
f.write("slowest durations\n")
for i, rep in enumerate(dlist):
if rep.duration < durations_min:
- f.write(f"{len(dlist) - i} durations < {durations_min} secs were omitted")
+ f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
break
f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
@@ -958,7 +958,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
process.join(timeout=timeout)
if results["error"] is not None:
- test_case.fail(f"{results['error']}")
+ test_case.fail(f'{results["error"]}')
class CaptureLogger:
From d932f670ea6ea614f1692fbb52e0c79930145397 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Tue, 28 Jan 2025 20:43:39 +0800
Subject: [PATCH 26/68] back
---
.../pipelines/audioldm2/pipeline_audioldm2.py | 2 +-
src/diffusers/pipelines/shap_e/renderer.py | 12 ++++++------
.../pipelines/stable_audio/pipeline_stable_audio.py | 2 +-
src/diffusers/schedulers/__init__.py | 2 --
.../schedulers/scheduling_consistency_models.py | 3 ++-
src/diffusers/schedulers/scheduling_ddpm.py | 3 ++-
src/diffusers/schedulers/scheduling_ddpm_parallel.py | 3 ++-
src/diffusers/schedulers/scheduling_lcm.py | 3 ++-
src/diffusers/schedulers/scheduling_tcd.py | 3 ++-
src/diffusers/training_utils.py | 4 ++--
10 files changed, 20 insertions(+), 17 deletions(-)
diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
index e36e36304bd8..b8b5d07af529 100644
--- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
@@ -788,7 +788,7 @@ def check_inputs(
if transcription is None:
if self.text_encoder_2.config.model_type == "vits":
- raise ValueError("Cannot forward without transcription. Please make sure to have transcription")
+ raise ValueError("Cannot forward without transcription. Please make sure to" " have transcription")
elif transcription is not None and (
not isinstance(transcription, str) and not isinstance(transcription, list)
):
diff --git a/src/diffusers/pipelines/shap_e/renderer.py b/src/diffusers/pipelines/shap_e/renderer.py
index dd25945590cd..9d9f9d9b2ab1 100644
--- a/src/diffusers/pipelines/shap_e/renderer.py
+++ b/src/diffusers/pipelines/shap_e/renderer.py
@@ -983,9 +983,9 @@ def decode_to_mesh(
fields = torch.cat(fields, dim=1)
fields = fields.float()
- assert len(fields.shape) == 3 and fields.shape[-1] == 1, (
- f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
- )
+ assert (
+ len(fields.shape) == 3 and fields.shape[-1] == 1
+ ), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
fields = fields.reshape(1, *([grid_size] * 3))
@@ -1039,9 +1039,9 @@ def decode_to_mesh(
textures = textures.float()
# 3.3 augument the mesh with texture data
- assert len(textures.shape) == 3 and textures.shape[-1] == len(texture_channels), (
- f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
- )
+ assert len(textures.shape) == 3 and textures.shape[-1] == len(
+ texture_channels
+ ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
for m, texture in zip(raw_meshes, textures):
texture = texture[: len(m.verts)]
diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
index 1b87c02df029..5d773b614a5c 100644
--- a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
+++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
@@ -584,7 +584,7 @@ def __call__(
if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
raise ValueError(
- f"The total audio length requested ({audio_end_in_s - audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
+ f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
)
waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 512d28d95c09..bb9088538653 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -44,7 +44,6 @@
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
_import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"]
- _import_structure["scheduling_ddim_cogview4"] = ["CogView4DDIMScheduler"]
_import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
_import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"]
_import_structure["scheduling_ddpm"] = ["DDPMScheduler"]
@@ -145,7 +144,6 @@
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
- from .scheduling_ddim_cogview4 import CogView4DDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler
from .scheduling_ddim_parallel import DDIMParallelScheduler
from .scheduling_ddpm import DDPMScheduler
diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py
index c946fa1681c0..653171638ccf 100644
--- a/src/diffusers/schedulers/scheduling_consistency_models.py
+++ b/src/diffusers/schedulers/scheduling_consistency_models.py
@@ -203,7 +203,8 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index f9eb9c365acd..624d5a5cd4f3 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -279,7 +279,8 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
index 64195be141f6..20ad7a4c927d 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
@@ -289,7 +289,8 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py
index 2a0cce7bf146..686b686f6870 100644
--- a/src/diffusers/schedulers/scheduling_lcm.py
+++ b/src/diffusers/schedulers/scheduling_lcm.py
@@ -413,7 +413,8 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps}."
)
# Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py
index 77770ab2066c..5d60383142a4 100644
--- a/src/diffusers/schedulers/scheduling_tcd.py
+++ b/src/diffusers/schedulers/scheduling_tcd.py
@@ -431,7 +431,8 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps}."
)
# Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 660a2042e18d..082640f37a17 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -241,7 +241,7 @@ def _set_state_dict_into_text_encoder(
"""
text_encoder_state_dict = {
- f"{k.replace(prefix, '')}": v for k, v in lora_state_dict.items() if k.startswith(prefix)
+ f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix)
}
text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
@@ -578,7 +578,7 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
"""
if self.temp_stored_params is None:
- raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
+ raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
if self.foreach:
torch._foreach_copy_(
[param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
From b04f15d7dc469c4f3f3a9f87c5e5b794d897a8b7 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Tue, 28 Jan 2025 20:45:39 +0800
Subject: [PATCH 27/68] back 4
---
src/diffusers/pipelines/free_noise_utils.py | 6 +++---
src/diffusers/pipelines/pipeline_loading_utils.py | 4 +++-
src/diffusers/pipelines/pipeline_utils.py | 4 ++--
3 files changed, 8 insertions(+), 6 deletions(-)
diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py
index 8ea5eb7dd575..dc0071a494e3 100644
--- a/src/diffusers/pipelines/free_noise_utils.py
+++ b/src/diffusers/pipelines/free_noise_utils.py
@@ -341,9 +341,9 @@ def _encode_prompt_free_noise(
start_tensor = negative_prompt_embeds[i].unsqueeze(0)
end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
- negative_prompt_interpolation_embeds[start_frame : end_frame + 1] = (
- self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
- )
+ negative_prompt_interpolation_embeds[
+ start_frame : end_frame + 1
+ ] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
prompt_embeds = prompt_interpolation_embeds
negative_prompt_embeds = negative_prompt_interpolation_embeds
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index c26fc89fb1f0..4173c49524dd 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -287,7 +287,9 @@ def maybe_raise_or_warn(
model_cls = unwrapped_sub_model.__class__
if not issubclass(model_cls, expected_class_obj):
- raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}")
+ raise ValueError(
+ f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
+ )
else:
logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 63a05352a6a1..0c1371c7556f 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -1449,8 +1449,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
if load_components_from_hub and not trust_remote_code:
raise ValueError(
- f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly "
- f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n"
+ f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
+ f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
From 5d33f3f411ba9770f113a8a4d340e5b4760fbc70 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Tue, 28 Jan 2025 21:30:38 +0800
Subject: [PATCH 28/68] Fix qkv conversion logic for CogView4 to Diffusers
format
---
scripts/convert_cogview4_to_diffusers.py | 35 ++++++++++++++++++++----
1 file changed, 29 insertions(+), 6 deletions(-)
diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py
index 4405a40fb761..bf5d8dc675aa 100644
--- a/scripts/convert_cogview4_to_diffusers.py
+++ b/scripts/convert_cogview4_to_diffusers.py
@@ -103,17 +103,40 @@ def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path):
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
+ # qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
+ # qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
+ # q, k, v = qkv_weight.chunk(3, dim=0)
+ # q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
+ #
+ # new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+ # new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
+ # new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+ # new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
+ # new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+ # new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
+
qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
+
+ num_heads = 32
+ hidden_dim = 4096
+ head_dim = qkv_weight.shape[0] // (3 * num_heads)
+ qkv_weight = qkv_weight.view(num_heads, 3, head_dim, hidden_dim)
+ qkv_bias = qkv_bias.view(num_heads, 3, head_dim)
+
+ qkv_weight = qkv_weight.permute(1, 0, 2, 3) # (3, num_heads, head_dim, hidden_dim)
+ qkv_bias = qkv_bias.permute(1, 0, 2) # (3, num_heads, head_dim)
+
q, k, v = qkv_weight.chunk(3, dim=0)
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
- new_state_dict[block_prefix + "attn1.to_q.weight"] = q
- new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
- new_state_dict[block_prefix + "attn1.to_k.weight"] = k
- new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
- new_state_dict[block_prefix + "attn1.to_v.weight"] = v
- new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
+ new_state_dict[block_prefix + "attn1.to_q.weight"] = q.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
+ new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias.squeeze(0).reshape(num_heads * head_dim)
+ new_state_dict[block_prefix + "attn1.to_k.weight"] = k.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
+ new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias.squeeze(0).reshape(num_heads * head_dim)
+ new_state_dict[block_prefix + "attn1.to_v.weight"] = v.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
+ new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias.squeeze(0).reshape(num_heads * head_dim)
+
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
old_prefix + "attention.dense.weight"
From b889b37eb8a2ef5c79d11702ccedb9d56dff171e Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Tue, 28 Jan 2025 21:30:49 +0800
Subject: [PATCH 29/68] back5
---
src/diffusers/schedulers/__init__.py | 2 ++
src/diffusers/utils/dummy_pt_objects.py | 13 +++++++++++++
2 files changed, 15 insertions(+)
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index bb9088538653..512d28d95c09 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -44,6 +44,7 @@
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
_import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"]
+ _import_structure["scheduling_ddim_cogview4"] = ["CogView4DDIMScheduler"]
_import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
_import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"]
_import_structure["scheduling_ddpm"] = ["DDPMScheduler"]
@@ -144,6 +145,7 @@
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
+ from .scheduling_ddim_cogview4 import CogView4DDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler
from .scheduling_ddim_parallel import DDIMParallelScheduler
from .scheduling_ddpm import DDPMScheduler
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 6a1978944c9f..e465f75f7129 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -1338,6 +1338,19 @@ def from_config(cls, *args, **kwargs):
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class CogView4DDIMScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
class DDIMInverseScheduler(metaclass=DummyObject):
_backends = ["torch"]
From e239c3cd54b3f609c14dbf8fa1d75028309bb6bb Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Tue, 28 Jan 2025 22:18:22 +0800
Subject: [PATCH 30/68] revert to sat to cogview4 version
---
scripts/convert_cogview4_to_diffusers.py | 35 ++++--------------------
1 file changed, 6 insertions(+), 29 deletions(-)
diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py
index bf5d8dc675aa..4405a40fb761 100644
--- a/scripts/convert_cogview4_to_diffusers.py
+++ b/scripts/convert_cogview4_to_diffusers.py
@@ -103,40 +103,17 @@ def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path):
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
- # qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
- # qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
- # q, k, v = qkv_weight.chunk(3, dim=0)
- # q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
- #
- # new_state_dict[block_prefix + "attn1.to_q.weight"] = q
- # new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
- # new_state_dict[block_prefix + "attn1.to_k.weight"] = k
- # new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
- # new_state_dict[block_prefix + "attn1.to_v.weight"] = v
- # new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
-
qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
-
- num_heads = 32
- hidden_dim = 4096
- head_dim = qkv_weight.shape[0] // (3 * num_heads)
- qkv_weight = qkv_weight.view(num_heads, 3, head_dim, hidden_dim)
- qkv_bias = qkv_bias.view(num_heads, 3, head_dim)
-
- qkv_weight = qkv_weight.permute(1, 0, 2, 3) # (3, num_heads, head_dim, hidden_dim)
- qkv_bias = qkv_bias.permute(1, 0, 2) # (3, num_heads, head_dim)
-
q, k, v = qkv_weight.chunk(3, dim=0)
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
- new_state_dict[block_prefix + "attn1.to_q.weight"] = q.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
- new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias.squeeze(0).reshape(num_heads * head_dim)
- new_state_dict[block_prefix + "attn1.to_k.weight"] = k.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
- new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias.squeeze(0).reshape(num_heads * head_dim)
- new_state_dict[block_prefix + "attn1.to_v.weight"] = v.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
- new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias.squeeze(0).reshape(num_heads * head_dim)
-
+ new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+ new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
+ new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+ new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
+ new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+ new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
old_prefix + "attention.dense.weight"
From 310da291c851fdb58aaacb71c1db8b7185b07e8b Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Tue, 28 Jan 2025 23:58:44 +0800
Subject: [PATCH 31/68] update a new convert from megatron
---
check_same_sat_megatron_diffusers.py | 264 +++++++++++++++++++++++++++
1 file changed, 264 insertions(+)
create mode 100644 check_same_sat_megatron_diffusers.py
diff --git a/check_same_sat_megatron_diffusers.py b/check_same_sat_megatron_diffusers.py
new file mode 100644
index 000000000000..0495a55f4d91
--- /dev/null
+++ b/check_same_sat_megatron_diffusers.py
@@ -0,0 +1,264 @@
+import torch
+from collections import OrderedDict
+from diffusers import CogView4Transformer2DModel
+
+def load_state_dict_sat(file_path):
+ """Load the SAT state dictionary from a given file path."""
+ # Typically, the stored SAT ckpt is in the format: {'module': {...}}
+ ckpt = torch.load(file_path, map_location="cuda")
+ return ckpt["module"]
+
+
+def extract_qkv_from_sat(state_dict, layer_idx):
+ """
+ Extract QKV weights and biases from a SAT state_dict.
+ Expects keys like:
+ model.diffusion_model.transformer.layers.{layer_idx}.attention.query_key_value
+ """
+ prefix = f"model.diffusion_model.transformer.layers.{layer_idx}.attention.query_key_value"
+ w = state_dict[f"{prefix}.weight"].clone()
+ b = state_dict[f"{prefix}.bias"].clone()
+ return (w, b)
+
+
+def load_state_dict_cogview(cogview_path):
+ """
+ Loads the CogView4 model from diffusers and returns its state_dict().
+ NOTE: You should adjust 'torch_dtype' and 'device_map' as appropriate.
+ """
+ cogview_model = CogView4Transformer2DModel.from_pretrained(
+ cogview_path, torch_dtype=torch.bfloat16, device_map="auto"
+ )
+ return cogview_model.state_dict()
+
+
+def extract_qkv_from_cogview(state_dict, layer_idx, num_heads, head_dim, hidden_dim):
+ """
+ Extract Q, K, V from CogView4 checkpoint and reshape them into the same shape as SAT’s QKV.
+ For each layer i:
+ Q prefix: transformer_blocks.{layer_idx}.attn1.to_q
+ K prefix: transformer_blocks.{layer_idx}.attn1.to_k
+ V prefix: transformer_blocks.{layer_idx}.attn1.to_v
+ Final shape must match SAT's [3*hidden_dim, hidden_dim] for weight, and [3*hidden_dim] for bias.
+ """
+ q_prefix = f"transformer_blocks.{layer_idx}.attn1.to_q"
+ k_prefix = f"transformer_blocks.{layer_idx}.attn1.to_k"
+ v_prefix = f"transformer_blocks.{layer_idx}.attn1.to_v"
+
+ # Extract
+ q_weight = state_dict[f"{q_prefix}.weight"].clone()
+ k_weight = state_dict[f"{k_prefix}.weight"].clone()
+ v_weight = state_dict[f"{v_prefix}.weight"].clone()
+
+ q_bias = state_dict[f"{q_prefix}.bias"].clone()
+ k_bias = state_dict[f"{k_prefix}.bias"].clone()
+ v_bias = state_dict[f"{v_prefix}.bias"].clone()
+
+ # Reshape weights: [hidden_dim, hidden_dim] -> [num_heads, head_dim, hidden_dim]
+ # Then concat along the first dimension (which will become 3*num_heads*head_dim)
+ q_weight = q_weight.view(num_heads, head_dim, hidden_dim)
+ k_weight = k_weight.view(num_heads, head_dim, hidden_dim)
+ v_weight = v_weight.view(num_heads, head_dim, hidden_dim)
+
+ qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) # shape: (3*num_heads, head_dim, hidden_dim)
+ qkv_weight = qkv_weight.view(3 * num_heads * head_dim, hidden_dim) # flatten
+
+ # Reshape biases: [hidden_dim] -> [num_heads, head_dim]
+ q_bias = q_bias.view(num_heads, head_dim)
+ k_bias = k_bias.view(num_heads, head_dim)
+ v_bias = v_bias.view(num_heads, head_dim)
+
+ qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0) # (3*num_heads, head_dim)
+ qkv_bias = qkv_bias.view(3 * num_heads * head_dim)
+
+ return (qkv_weight, qkv_bias)
+
+def create_sat_state_dict_from_megatron(megatron_ckpt_dict, num_layers=48, num_heads=32, hidden_size=3072):
+ """
+ Convert a loaded Megatron checkpoint's 'model' dictionary into the same
+ format used by SAT. This returns something like {'module': {...}} for
+ easy comparison with SAT.
+
+ The code below is adapted from your 'create_sat_state_dict' function,
+ but we rename it here to keep it direct.
+ """
+ from tqdm import tqdm
+
+ hidden_size_per_head = hidden_size // num_heads
+ mega_weight = megatron_ckpt_dict["model"]
+ sat_weight = {}
+
+ # --- patch_embed ---
+ sat_weight["model.diffusion_model.mixins.patch_embed.proj.weight"] = \
+ mega_weight["encoder_expand_linear.weight"].reshape(hidden_size, 64).clone()
+ sat_weight["model.diffusion_model.mixins.patch_embed.proj.bias"] = \
+ mega_weight["encoder_expand_linear.bias"].clone()
+
+ sat_weight["model.diffusion_model.mixins.patch_embed.text_proj.weight"] = \
+ mega_weight["text_projector.weight"].clone()
+ sat_weight["model.diffusion_model.mixins.patch_embed.text_proj.bias"] = \
+ mega_weight["text_projector.bias"].clone()
+
+ # --- time embedding ---
+ sat_weight["model.diffusion_model.time_embed.0.weight"] = \
+ mega_weight["time_embedding.time_embed.0.weight"].clone()
+ sat_weight["model.diffusion_model.time_embed.0.bias"] = \
+ mega_weight["time_embedding.time_embed.0.bias"].clone()
+ sat_weight["model.diffusion_model.time_embed.2.weight"] = \
+ mega_weight["time_embedding.time_embed.2.weight"].clone()
+ sat_weight["model.diffusion_model.time_embed.2.bias"] = \
+ mega_weight["time_embedding.time_embed.2.bias"].clone()
+
+ # --- label embedding ---
+ sat_weight["model.diffusion_model.label_emb.0.0.weight"] = \
+ mega_weight["label_embedding.label_embed.0.weight"].clone()
+ sat_weight["model.diffusion_model.label_emb.0.0.bias"] = \
+ mega_weight["label_embedding.label_embed.0.bias"].clone()
+ sat_weight["model.diffusion_model.label_emb.0.2.weight"] = \
+ mega_weight["label_embedding.label_embed.2.weight"].clone()
+ sat_weight["model.diffusion_model.label_emb.0.2.bias"] = \
+ mega_weight["label_embedding.label_embed.2.bias"].clone()
+
+ # --- layers ---
+ for i in tqdm(range(num_layers), desc="Converting Megatron->SAT"):
+ # attention output
+ sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.dense.weight"] = \
+ mega_weight[f"decoder.layers.{i}.self_attention.linear_proj.weight"].clone()
+ sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.dense.bias"] = \
+ mega_weight[f"decoder.layers.{i}.self_attention.linear_proj.bias"].clone()
+
+ # QKV
+ qkv_weight = mega_weight[f"decoder.layers.{i}.self_attention.linear_qkv.weight"].clone()
+ qkv_bias = mega_weight[f"decoder.layers.{i}.self_attention.linear_qkv.bias"].clone()
+
+ # Reshape QKV from Megatron format into SAT format
+ # qkv_weight: [3*hidden_size, hidden_size] -> [num_heads, 3, hidden_size_per_head, hidden_size] -> ...
+ sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.query_key_value.weight"] = \
+ qkv_weight.view(num_heads, 3, hidden_size_per_head, hidden_size) \
+ .permute(1, 0, 2, 3) \
+ .reshape(3 * hidden_size, hidden_size).clone()
+ sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.query_key_value.bias"] = \
+ qkv_bias.view(num_heads, 3, hidden_size_per_head) \
+ .permute(1, 0, 2) \
+ .reshape(3 * hidden_size) \
+ .clone()
+
+ # MLP
+ sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_h_to_4h.weight"] = \
+ mega_weight[f"decoder.layers.{i}.mlp.linear_fc1.weight"].clone()
+ sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_h_to_4h.bias"] = \
+ mega_weight[f"decoder.layers.{i}.mlp.linear_fc1.bias"].clone()
+
+ sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_4h_to_h.weight"] = \
+ mega_weight[f"decoder.layers.{i}.mlp.linear_fc2.weight"].clone()
+ sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_4h_to_h.bias"] = \
+ mega_weight[f"decoder.layers.{i}.mlp.linear_fc2.bias"].clone()
+
+ # AdaLN
+ adaln_weight = mega_weight[f"decoder.layers.{i}.adaln.weight"].clone()
+ adaln_bias = mega_weight[f"decoder.layers.{i}.adaln.bias"].clone()
+
+ sat_weight[f"model.diffusion_model.mixins.adaln.adaln_modules.{i}.1.weight"] = adaln_weight.clone()
+ sat_weight[f"model.diffusion_model.mixins.adaln.adaln_modules.{i}.1.bias"] = adaln_bias.clone()
+
+ # --- final layers ---
+ sat_weight["model.diffusion_model.mixins.final_layer.adaln.1.weight"] = \
+ mega_weight["adaln_final.weight"].clone()
+ sat_weight["model.diffusion_model.mixins.final_layer.adaln.1.bias"] = \
+ mega_weight["adaln_final.bias"].clone()
+ sat_weight["model.diffusion_model.mixins.final_layer.linear.weight"] = \
+ mega_weight["output_projector.weight"].clone()
+ sat_weight["model.diffusion_model.mixins.final_layer.linear.bias"] = \
+ mega_weight["output_projector.bias"].clone()
+
+ return OrderedDict(sat_weight)
+
+
+def load_state_dict_megatron_and_convert_to_sat(megatron_ckpt_path, num_layers, num_heads, hidden_size):
+ """
+ Load a Megatron checkpoint from , then convert it into
+ an SAT-style OrderedDict for direct QKV comparison.
+
+ Typically, = ".../iter_0287500/mp_rank_00/model_optim_rng.pt"
+ """
+ ckpt = torch.load(megatron_ckpt_path, map_location="cuda")
+ # Convert to SAT
+ sat_like_weight = create_sat_state_dict_from_megatron(
+ ckpt, num_layers=num_layers, num_heads=num_heads, hidden_size=hidden_size
+ )
+ return sat_like_weight
+
+def compute_l2_difference(tensor1, tensor2):
+ """Compute L2 norm of the difference between two tensors."""
+ return torch.norm(tensor1 - tensor2, p=2).item()
+
+
+def compare_qkv(qkv1, qkv2, name1="Model1", name2="Model2", atol=1e-6):
+ """
+ Compare QKV from two different sources (each is a tuple of (weight, bias)).
+ Returns (weight_match, bias_match, weight_l2, bias_l2).
+ """
+ w1, b1 = qkv1
+ w2, b2 = qkv2
+
+ weight_match = torch.allclose(w1, w2, atol=atol)
+ bias_match = torch.allclose(b1, b2, atol=atol)
+ weight_l2_diff = compute_l2_difference(w1, w2)
+ bias_l2_diff = compute_l2_difference(b1, b2)
+
+ if not (weight_match and bias_match):
+ print(f"[QKV Mismatch] {name1} vs {name2}")
+ print(f" Weight L2: {weight_l2_diff:.6f}, Bias L2: {bias_l2_diff:.6f}")
+ else:
+ # If everything matches well:
+ print(f"[QKV Match] {name1} vs {name2} (Weight L2={weight_l2_diff:.6f}, Bias L2={bias_l2_diff:.6f})")
+
+ return weight_match, bias_match, weight_l2_diff, bias_l2_diff
+
+if __name__ == "__main__":
+ num_layers = 28
+ num_heads = 32
+ hidden_dim = 4096
+ head_dim = hidden_dim // num_heads
+
+ sat_ckpt_path = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/pt_sat/0287500/mp_rank_00_model_states.pt"
+ sat_state_dict = load_state_dict_sat(sat_ckpt_path)
+
+ cogview_path = "/share/zyx/CogView4-6B-0128/transformer" # directory containing model index for diffusers
+ cogview_state_dict = load_state_dict_cogview(cogview_path)
+
+ megatron_ckpt_path = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/pt_ema/iter_0287500/mp_rank_00/model_optim_rng.pt"
+ mega_as_sat_state_dict = load_state_dict_megatron_and_convert_to_sat(
+ megatron_ckpt_path,
+ num_layers=num_layers,
+ num_heads=num_heads,
+ hidden_size=hidden_dim
+ )
+
+ print("\n==== Start QKV Comparison ====\n")
+ for layer_idx in range(num_layers):
+ print(f"--- Layer {layer_idx} ---")
+
+ # Extract QKV from SAT
+ sat_qkv = extract_qkv_from_sat(sat_state_dict, layer_idx)
+
+ # Extract QKV from CogView
+ cogview_qkv = extract_qkv_from_cogview(
+ cogview_state_dict, layer_idx, num_heads, head_dim, hidden_dim
+ )
+
+ # Extract QKV from Megatron->SAT
+ mega_qkv = extract_qkv_from_sat(mega_as_sat_state_dict, layer_idx)
+
+ # Compare: SAT vs CogView
+ compare_qkv(sat_qkv, cogview_qkv, name1="SAT", name2="CogView4")
+
+ # Compare: SAT vs Megatron
+ compare_qkv(sat_qkv, mega_qkv, name1="SAT", name2="Megatron")
+
+ # Compare: CogView vs Megatron (optional)
+ compare_qkv(cogview_qkv, mega_qkv, name1="CogView4", name2="Megatron")
+
+ print()
+
+ print("=== Done ===")
From 3bd6d30a0f01261a2ead0cec64075cfdc8b52b0d Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Tue, 28 Jan 2025 16:22:01 +0000
Subject: [PATCH 32/68] [WIP][cogview4]: implement CogView4 attention processor
Add CogView4AttnProcessor class for implementing scaled dot-product attention
with rotary embeddings for the CogVideoX model. This processor concatenates
encoder and hidden states, applies QKV projections and RoPE, but does not
include spatial normalization.
TODO:
- Fix incorrect QKV projection weights
- Resolve ~25% error in RoPE implementation compared to Megatron
---
src/diffusers/models/attention_processor.py | 104 +++++++++++++++++++-
src/diffusers/models/embeddings.py | 34 +++++++
2 files changed, 135 insertions(+), 3 deletions(-)
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 26625753e4b6..44f7ada5b56e 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2802,6 +2802,105 @@ def __call__(
return hidden_states
+class CogView4AttnProcessor:
+ """
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+ query and key vectors, but does not include spatial normalization.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("CogView4AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ ###############################################3
+ # TODO: 直接用qkv_weight算出qkv(注意要先分出num_heads, head_dim),再在head_dims上拆出qkv
+ linear_qkv_weight = torch.load("/home/lhy/code/cogview/linear_qkv_weight.pt")
+ linear_qkv_bias = torch.load("/home/lhy/code/cogview/linear_qkv_bias.pt")
+
+ qkv = torch.matmul(hidden_states, linear_qkv_weight.T) + linear_qkv_bias
+ qkv = qkv.view(batch_size, -1, attn.heads, head_dim * 3)
+ query, key, value = qkv.chunk(3, dim=-1)
+
+
+ # TODO: 校验rope是否apply正确(目前有25%的误差)
+ ###############################################3
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb_megatron
+
+ query[:, :, text_seq_length:, :] = apply_rotary_emb_megatron(
+ query[:, :, text_seq_length:, :], image_rotary_emb
+ )
+ key[:, :, text_seq_length:, :] = apply_rotary_emb_megatron(
+ key[:, :, text_seq_length:, :], image_rotary_emb
+ )
+
+ ##########################################
+ query = torch.load("/home/lhy/code/cogview/query_after_rope.pt")
+ key = torch.load("/home/lhy/code/cogview/key_after_rope.pt")
+ value = torch.load("/home/lhy/code/cogview/value_after_rope.pt")
+ query = query[None, :16+4096, ...]
+ key = key[None, :16+4096, ...]
+ value = value[None, :16+4096, ...]
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+ ##########################################
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+ )
+ return hidden_states, encoder_hidden_states
+
+
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -2824,9 +2923,7 @@ def __call__(
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
+ batch_size, sequence_length, _ = hidden_states.shape
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -6174,6 +6271,7 @@ def __call__(
FusedFluxAttnProcessor2_0,
FusedFluxAttnProcessor2_0_NPU,
CogVideoXAttnProcessor2_0,
+ CogView4AttnProcessor,
FusedCogVideoXAttnProcessor2_0,
XFormersAttnAddedKVProcessor,
XFormersAttnProcessor,
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index fd5d95051f11..73088dee6011 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1283,6 +1283,40 @@ def apply_1d_rope(tokens, pos, cos, sin):
x = torch.cat([t, h, w], dim=-1)
return x
+def apply_rotary_emb_megatron(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+ """Apply rotary position embeddings to input tensor.
+
+ Args:
+ x: Input tensor of shape [seq_len, batch_size, n_heads, head_dim]
+ freqs: Frequency tensor of shape [seq_len, 1, 1, head_dim//2]
+
+ Returns:
+ Tensor with rotary position embeddings applied
+ """
+ batch_size, n_heads, seq_len, rot_dim = x.shape
+
+ # Reshape x to have rot_dim as the last dimension
+ x_rot, x_pass = x.chunk(2, dim=-1)
+
+ # Apply rotary embeddings
+ # First calculate cos and sin
+ cos, sin = freqs.chunk(2, dim=-1)
+ cos, sin = torch.cos(cos), torch.sin(sin)
+
+ # Rotate x_rot
+ x_rot_cos = x_rot * cos
+ # Create rotated version of x_rot by shifting rot_dim/2 positions
+ x_rot_shifted = torch.cat([-x_rot[..., rot_dim//2:], x_rot[..., :rot_dim//2]], dim=-1)
+ x_rot_sin = x_rot_shifted * sin
+
+ # Combine
+ x_rot = x_rot_cos + x_rot_sin
+
+ # Concatenate back with x_pass
+ x_out = torch.cat([x_rot, x_pass], dim=-1)
+
+ return x_out
+
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
From f826aec021c1e5da31f5710647f9d48996f9901e Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Tue, 28 Jan 2025 16:23:57 +0000
Subject: [PATCH 33/68] [cogview4] implement CogView4 transformer block
Implement CogView4 transformer block following the Megatron architecture:
- Add multi-modulate and multi-gate mechanisms for adaptive layer normalization
- Implement dual-stream attention with encoder-decoder structure
- Add feed-forward network with GELU activation
- Support rotary position embeddings for image tokens
The implementation follows the original CogView4 architecture while adapting
it to work within the diffusers framework.
---
.../transformers/transformer_cogview4.py | 181 +++++++++++-------
1 file changed, 116 insertions(+), 65 deletions(-)
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index 2611cb4b5e88..8c50610e8f9a 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -17,13 +17,14 @@
import torch
import torch.nn as nn
+import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
from ...models.attention_processor import (
Attention,
AttentionProcessor,
- CogVideoXAttnProcessor2_0,
+ CogView4AttnProcessor,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
@@ -32,6 +33,7 @@
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -60,6 +62,8 @@ def __init__(
super().__init__()
self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
+ self.adaln = self.norm1.linear
+ self.layernorm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.attn1 = Attention(
query_dim=dim,
@@ -69,66 +73,109 @@ def __init__(
bias=True,
qk_norm="layer_norm",
elementwise_affine=False,
- eps=1e-6,
- processor=CogVideoXAttnProcessor2_0(),
+ eps=1e-5,
+ processor=CogView4AttnProcessor(),
)
- self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
- self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
-
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+ def multi_modulate(self, hidden_states, encoder_hidden_states, factors) -> torch.Tensor:
+ n_sample, n_type, h = factors[0].shape
+ shift_factor, scale_factor = factors[0].view(-1, h), factors[1].view(-1, h)
+
+ shift_factor_hidden_states, shift_factor_encoder_hidden_states = shift_factor.chunk(2, dim=0)
+ scale_factor_hidden_states, scale_factor_encoder_hidden_states = scale_factor.chunk(2, dim=0)
+
+ hidden_states = torch.addcmul(shift_factor_hidden_states, hidden_states, (1 + scale_factor_hidden_states))
+ encoder_hidden_states = torch.addcmul(
+ shift_factor_encoder_hidden_states, encoder_hidden_states, (1 + scale_factor_encoder_hidden_states)
+ )
+
+ return hidden_states, encoder_hidden_states
+
+ def multi_gate(self, hidden_states, encoder_hidden_states, factor):
+ batch_size, seq_len, hidden_dim = hidden_states.shape
+ gate_factor = factor.view(-1, hidden_dim)
+ gate_factor_hidden_states, gate_factor_encoder_hidden_states = gate_factor.chunk(2, dim=0)
+ hidden_states = gate_factor_hidden_states * hidden_states
+ encoder_hidden_states = gate_factor_encoder_hidden_states * encoder_hidden_states
+ return hidden_states, encoder_hidden_states
+
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
- emb: torch.Tensor,
+ time_embedding: torch.Tensor = None,
+ image_rotary_emb: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
- text_seq_length = encoder_hidden_states.size(1)
-
- # norm & modulate
- (
- norm_hidden_states,
- gate_msa,
- shift_mlp,
- scale_mlp,
- gate_mlp,
- norm_encoder_hidden_states,
- c_gate_msa,
- c_shift_mlp,
- c_scale_mlp,
- c_gate_mlp,
- ) = self.norm1(hidden_states, encoder_hidden_states, emb)
-
- # attention
- attn_hidden_states, attn_encoder_hidden_states = self.attn1(
- hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **kwargs
- )
-
- hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
- encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
+ batch_size, encoder_hidden_states_len, hidden_dim = encoder_hidden_states.shape
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
- # norm & modulate
- norm_hidden_states = self.norm2(hidden_states)
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ residual = hidden_states
- norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+ # time_embedding embedding, [n_sample, h]
+ assert time_embedding is not None
- # feed-forward
- norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
- ff_output = self.ff(norm_hidden_states)
+ layernorm_factor = (
+ self.adaln(time_embedding)
+ .view(
+ time_embedding.shape[0],
+ 6,
+ 2,
+ hidden_states.shape[-1],
+ )
+ .permute(1, 2, 0, 3)
+ .contiguous()
+ )
- hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
+ ##############################################################
+ # Optional Input Layer norm
+ hidden_states = self.layernorm(hidden_states)
+ hidden_states, encoder_hidden_states = self.multi_modulate(
+ hidden_states=hidden_states[:, encoder_hidden_states_len:],
+ encoder_hidden_states=hidden_states[:, :encoder_hidden_states_len],
+ factors=(layernorm_factor[0], layernorm_factor[1]),
+ )
+ hidden_states, encoder_hidden_states = self.attn1(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states, encoder_hidden_states = self.multi_gate(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ factor=layernorm_factor[2],
+ )
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ hidden_states += residual
+
+ residual = hidden_states
+ ##############################################################
+ hidden_states = self.layernorm(hidden_states)
+ hidden_states, encoder_hidden_states = self.multi_modulate(
+ hidden_states=hidden_states[:, encoder_hidden_states_len:],
+ encoder_hidden_states=hidden_states[:, :encoder_hidden_states_len],
+ factors=(layernorm_factor[3], layernorm_factor[4]),
+ )
+ hidden_states = self.ff(hidden_states)
+ encoder_hidden_states = self.ff(encoder_hidden_states)
+ hidden_states, encoder_hidden_states = self.multi_gate(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ factor=layernorm_factor[5],
+ )
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ hidden_states += residual
- if hidden_states.dtype == torch.float16:
- hidden_states = hidden_states.clip(-65504, 65504)
- if encoder_hidden_states.dtype == torch.float16:
- encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+ ##############################################################
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, :encoder_hidden_states_len],
+ hidden_states[:, encoder_hidden_states_len:],
+ )
return hidden_states, encoder_hidden_states
+
class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
r"""
Args:
@@ -335,7 +382,8 @@ def get_rope_embedding(self, height, width, target_h, target_w, device):
freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim]
freqs = freqs.reshape(height * width, -1)
- return freqs.cos(), freqs.sin()
+ return freqs
+ # return freqs.cos(), freqs.sin()
def forward(
self,
@@ -391,28 +439,31 @@ def forward(
image_rotary_emb = self.get_rope_embedding(
patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
)
+ # image_rotary_emb = torch.load("/home/lhy/code/cogview/rotary_pos_emb.pt")
+ # image_rotary_emb = image_rotary_emb[16:16+4096, 0, 0, :]
+ ######################
# 2. Conditional embeddings
- temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
- temb_cond, temb_uncond = temb.chunk(2)
- hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
- hidden_states, prompt_embeds, negative_prompt_embeds
- )
- hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
+ # temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
+ # temb_cond, temb_uncond = temb.chunk(2)
+ # hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
+ # hidden_states, prompt_embeds, negative_prompt_embeds
+ # )
+ # hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
+ # encoder_hidden_states_cond = prompt_embeds
+ # encoder_hidden_states_uncond = negative_prompt_embeds
+
+ prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
+ negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_16_32.pt")[None, ::]
+
+ hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")[None, ::]
+ hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")[None, ::]
+
+ temb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")[None, ::]
+ temb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")[None, ::]
+
encoder_hidden_states_cond = prompt_embeds
encoder_hidden_states_uncond = negative_prompt_embeds
-
- ######################
- # prompt_embeds = torch.load("/home/lhy/code/cogview/c_condition_embedding.pt")
- # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/uc_condition_embedding.pt")
- # prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
- # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_uncondition_16_32.pt")[None, ::]
- #
- # hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")
- # hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")
- #
- # emb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")
- # emb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")
######################
for index_block, block in enumerate(self.transformer_blocks):
@@ -423,13 +474,13 @@ def forward(
hidden_states_cond, encoder_hidden_states_cond = block(
hidden_states=hidden_states_cond,
encoder_hidden_states=encoder_hidden_states_cond,
- emb=temb_cond,
+ time_embedding=temb_cond,
image_rotary_emb=image_rotary_emb,
)
hidden_states_uncond, encoder_hidden_states_uncond = block(
hidden_states=hidden_states_uncond,
encoder_hidden_states=encoder_hidden_states_uncond,
- emb=temb_uncond,
+ time_embedding=temb_uncond,
image_rotary_emb=image_rotary_emb,
)
From bf1fdc8c81af94438b53adcd2170c54044e6e20b Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Wed, 29 Jan 2025 00:47:02 +0800
Subject: [PATCH 34/68] with new attn
---
.../convert_cogview4_to_diffusers_megatron.py | 384 ++++++++++++++++++
src/diffusers/models/attention_processor.py | 36 +-
.../transformers/transformer_cogview4.py | 32 +-
3 files changed, 417 insertions(+), 35 deletions(-)
create mode 100644 scripts/convert_cogview4_to_diffusers_megatron.py
diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py
new file mode 100644
index 000000000000..0c93cb099744
--- /dev/null
+++ b/scripts/convert_cogview4_to_diffusers_megatron.py
@@ -0,0 +1,384 @@
+"""
+Convert a CogView4 checkpoint to the Diffusers format.
+
+This script converts a CogView4 checkpoint to the Diffusers format, which can then be used
+with the Diffusers library.
+
+Example usage:
+ python scripts/convert_cogview4_to_diffusers.py \
+ --transformer_checkpoint_path 'your path/cogview4_6b/1/mp_rank_00_model_states.pt' \
+ --vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
+ --output_path "THUDM/CogView4-6B" \
+ --dtype "bf16"
+
+Arguments:
+ --transformer_checkpoint_path: Path to Transformer state dict.
+ --vae_checkpoint_path: Path to VAE state dict.
+ --output_path: The path to save the converted model.
+ --push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
+ --text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used.
+ --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
+
+ Default is "bf16" because CogView4 uses bfloat16 for training.
+
+Note: You must provide either --transformer_checkpoint_path or --vae_checkpoint_path.
+"""
+
+import argparse
+from contextlib import nullcontext
+import torch
+from transformers import PreTrainedTokenizerFast, GlmForCausalLM
+from tqdm import tqdm
+
+from diffusers import (
+ AutoencoderKL,
+ CogView4DDIMScheduler,
+ CogView4Pipeline,
+ CogView4Transformer2DModel,
+)
+from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--transformer_checkpoint_path",
+ default=None,
+ type=str,
+ help="Path to Megatron (not SAT) Transformer checkpoint, e.g., 'model_optim_rng.pt'.",
+)
+parser.add_argument(
+ "--vae_checkpoint_path",
+ default=None,
+ type=str,
+ help="(Optional) Path to VAE checkpoint, e.g., 'imagekl_ch16.pt'.",
+)
+parser.add_argument(
+ "--output_path",
+ required=True,
+ type=str,
+ help="Directory to save the final Diffusers format pipeline.",
+)
+parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ default=False,
+ help="Whether to push the converted model to the HuggingFace Hub.",
+)
+parser.add_argument(
+ "--text_encoder_cache_dir",
+ type=str,
+ default=None,
+ help="Specify the cache directory for the text encoder.",
+)
+parser.add_argument(
+ "--dtype",
+ type=str,
+ default="bf16",
+ choices=["fp16", "bf16", "fp32"],
+ help="Data type to save the model in.",
+)
+
+parser.add_argument(
+ "--num_layers",
+ type=int,
+ default=28,
+ help="Number of Transformer layers (e.g., 28, 48...).",
+)
+parser.add_argument(
+ "--num_heads",
+ type=int,
+ default=32,
+ help="Number of attention heads.",
+)
+parser.add_argument(
+ "--hidden_size",
+ type=int,
+ default=4096,
+ help="Transformer hidden dimension size.",
+)
+parser.add_argument(
+ "--attention_head_dim",
+ type=int,
+ default=128,
+ help="Dimension of each attention head.",
+)
+parser.add_argument(
+ "--time_embed_dim",
+ type=int,
+ default=512,
+ help="Dimension of time embeddings.",
+)
+parser.add_argument(
+ "--condition_dim",
+ type=int,
+ default=256,
+ help="Dimension of condition embeddings.",
+)
+parser.add_argument(
+ "--pos_embed_max_size",
+ type=int,
+ default=128,
+ help="Maximum size for positional embeddings.",
+)
+
+args = parser.parse_args()
+
+
+def swap_scale_shift(weight, dim):
+ """
+ Swap the scale and shift components in the weight tensor.
+
+ Args:
+ weight (torch.Tensor): The original weight tensor.
+ dim (int): The dimension along which to split.
+
+ Returns:
+ torch.Tensor: The modified weight tensor with scale and shift swapped.
+ """
+ shift, scale = weight.chunk(2, dim=dim)
+ new_weight = torch.cat([scale, shift], dim=dim)
+ return new_weight
+
+
+def convert_megatron_transformer_checkpoint_to_diffusers(
+ ckpt_path: str,
+ num_layers: int,
+ num_heads: int,
+ hidden_size: int,
+):
+ """
+ Convert a Megatron Transformer checkpoint to Diffusers format.
+
+ Args:
+ ckpt_path (str): Path to the Megatron Transformer checkpoint.
+ num_layers (int): Number of Transformer layers.
+ num_heads (int): Number of attention heads.
+ hidden_size (int): Hidden size of the Transformer.
+
+ Returns:
+ dict: The converted state dictionary compatible with Diffusers.
+ """
+ ckpt = torch.load(ckpt_path, map_location="cpu")
+ mega = ckpt["model"]
+
+ new_state_dict = {}
+
+ # Patch Embedding
+ new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(hidden_size, 64)
+ new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
+ new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
+ new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"]
+
+ # Time Condition Embedding
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = mega[
+ "time_embedding.time_embed.0.weight"
+ ]
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = mega["time_embedding.time_embed.0.bias"]
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = mega[
+ "time_embedding.time_embed.2.weight"
+ ]
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = mega["time_embedding.time_embed.2.bias"]
+
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = mega[
+ "label_embedding.label_embed.0.weight"
+ ]
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = mega[
+ "label_embedding.label_embed.0.bias"
+ ]
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = mega[
+ "label_embedding.label_embed.2.weight"
+ ]
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = mega[
+ "label_embedding.label_embed.2.bias"
+ ]
+
+ # Convert each Transformer layer
+ for i in tqdm(range(num_layers), desc="Converting layers (Megatron->Diffusers)"):
+ block_prefix = f"transformer_blocks.{i}."
+
+ # AdaLayerNorm
+ new_state_dict[block_prefix + "norm1.linear.weight"] = swap_scale_shift(
+ mega[f"decoder.layers.{i}.adaln.weight"], dim=0
+ )
+ new_state_dict[block_prefix + "norm1.linear.bias"] = swap_scale_shift(
+ mega[f"decoder.layers.{i}.adaln.bias"], dim=0
+ )
+
+ # QKV
+ qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"]
+ qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"]
+
+ # Reshape to match SAT logic
+ qkv_weight = qkv_weight.view(num_heads, 3, hidden_size // num_heads, hidden_size)
+ qkv_weight = qkv_weight.permute(1, 0, 2, 3).reshape(3 * hidden_size, hidden_size)
+
+ qkv_bias = qkv_bias.view(num_heads, 3, hidden_size // num_heads)
+ qkv_bias = qkv_bias.permute(1, 0, 2).reshape(3 * hidden_size)
+
+ # Assign to Diffusers keys
+ q, k, v = torch.chunk(qkv_weight, 3, dim=0)
+ qb, kb, vb = torch.chunk(qkv_bias, 3, dim=0)
+
+ new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+ new_state_dict[block_prefix + "attn1.to_q.bias"] = qb
+ new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+ new_state_dict[block_prefix + "attn1.to_k.bias"] = kb
+ new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+ new_state_dict[block_prefix + "attn1.to_v.bias"] = vb
+
+ # Attention Output
+ new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[
+ f"decoder.layers.{i}.self_attention.linear_proj.weight"
+ ].T
+ new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[
+ f"decoder.layers.{i}.self_attention.linear_proj.bias"
+ ]
+
+ # MLP
+ new_state_dict[block_prefix + "ff.net.0.proj.weight"] = mega[f"decoder.layers.{i}.mlp.linear_fc1.weight"]
+ new_state_dict[block_prefix + "ff.net.0.proj.bias"] = mega[f"decoder.layers.{i}.mlp.linear_fc1.bias"]
+ new_state_dict[block_prefix + "ff.net.2.weight"] = mega[f"decoder.layers.{i}.mlp.linear_fc2.weight"]
+ new_state_dict[block_prefix + "ff.net.2.bias"] = mega[f"decoder.layers.{i}.mlp.linear_fc2.bias"]
+
+ # Final Layers
+ new_state_dict["norm_out.linear.weight"] = swap_scale_shift(mega["adaln_final.weight"], dim=0)
+ new_state_dict["norm_out.linear.bias"] = swap_scale_shift(mega["adaln_final.bias"], dim=0)
+ new_state_dict["proj_out.weight"] = mega["output_projector.weight"]
+ new_state_dict["proj_out.bias"] = mega["output_projector.bias"]
+
+ return new_state_dict
+
+
+def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
+ """
+ Convert a CogView4 VAE checkpoint to Diffusers format.
+
+ Args:
+ ckpt_path (str): Path to the VAE checkpoint.
+ vae_config (dict): Configuration dictionary for the VAE.
+
+ Returns:
+ dict: The converted VAE state dictionary compatible with Diffusers.
+ """
+ original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+ return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
+
+
+def main(args):
+ """
+ Main function to convert CogView4 checkpoints to Diffusers format.
+
+ Args:
+ args (argparse.Namespace): Parsed command-line arguments.
+ """
+ # Determine the desired data type
+ if args.dtype == "fp16":
+ dtype = torch.float16
+ elif args.dtype == "bf16":
+ dtype = torch.bfloat16
+ elif args.dtype == "fp32":
+ dtype = torch.float32
+ else:
+ raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+ transformer = None
+ vae = None
+
+ # Convert Transformer checkpoint if provided
+ if args.transformer_checkpoint_path is not None:
+ converted_transformer_state_dict = convert_megatron_transformer_checkpoint_to_diffusers(
+ ckpt_path=args.transformer_checkpoint_path,
+ num_layers=args.num_layers,
+ num_heads=args.num_heads,
+ hidden_size=args.hidden_size,
+ )
+ transformer = CogView4Transformer2DModel(
+ patch_size=2,
+ in_channels=16,
+ num_layers=args.num_layers,
+ attention_head_dim=args.attention_head_dim,
+ num_attention_heads=args.num_heads,
+ out_channels=16,
+ text_embed_dim=args.hidden_size,
+ time_embed_dim=args.time_embed_dim,
+ condition_dim=args.condition_dim,
+ pos_embed_max_size=args.pos_embed_max_size,
+ )
+
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+
+ # Convert to the specified dtype
+ if dtype is not None:
+ transformer = transformer.to(dtype=dtype)
+
+ # Convert VAE checkpoint if provided
+ if args.vae_checkpoint_path is not None:
+ vae_config = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ("DownEncoderBlock2D",) * 4,
+ "up_block_types": ("UpDecoderBlock2D",) * 4,
+ "block_out_channels": (128, 512, 1024, 1024),
+ "layers_per_block": 3,
+ "act_fn": "silu",
+ "latent_channels": 16,
+ "norm_num_groups": 32,
+ "sample_size": 1024,
+ "scaling_factor": 1.0,
+ "force_upcast": True,
+ "use_quant_conv": False,
+ "use_post_quant_conv": False,
+ "mid_block_add_attention": False,
+ }
+ converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
+ vae = AutoencoderKL(**vae_config)
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
+ if dtype is not None:
+ vae = vae.to(dtype=dtype)
+
+ # Load the text encoder and tokenizer
+ text_encoder_id = "/share/home/zyx/Models/glm-4-9b-hf"
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
+ text_encoder = GlmForCausalLM.from_pretrained(
+ text_encoder_id,
+ cache_dir=args.text_encoder_cache_dir,
+ torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
+ )
+ for param in text_encoder.parameters():
+ param.data = param.data.contiguous()
+
+ # Initialize the scheduler
+ scheduler = CogView4DDIMScheduler.from_config(
+ {
+ "shift_scale": 1.0,
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": False,
+ "num_train_timesteps": 1000,
+ "prediction_type": "v_prediction",
+ "rescale_betas_zero_snr": True,
+ "set_alpha_to_one": True,
+ "timestep_spacing": "linspace",
+ }
+ )
+
+ # Create the pipeline
+ pipe = CogView4Pipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ # Save the converted pipeline
+ pipe.save_pretrained(
+ args.output_path,
+ safe_serialization=True,
+ max_shard_size="5GB",
+ push_to_hub=args.push_to_hub,
+ )
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 44f7ada5b56e..0e66bb4b8c30 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2837,18 +2837,18 @@ def __call__(
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
###############################################3
# TODO: 直接用qkv_weight算出qkv(注意要先分出num_heads, head_dim),再在head_dims上拆出qkv
- linear_qkv_weight = torch.load("/home/lhy/code/cogview/linear_qkv_weight.pt")
- linear_qkv_bias = torch.load("/home/lhy/code/cogview/linear_qkv_bias.pt")
-
- qkv = torch.matmul(hidden_states, linear_qkv_weight.T) + linear_qkv_bias
- qkv = qkv.view(batch_size, -1, attn.heads, head_dim * 3)
- query, key, value = qkv.chunk(3, dim=-1)
+ # linear_qkv_weight = torch.load("/home/lhy/code/cogview/linear_qkv_weight.pt")
+ # linear_qkv_bias = torch.load("/home/lhy/code/cogview/linear_qkv_bias.pt")
+ #
+ # qkv = torch.matmul(hidden_states, linear_qkv_weight.T) + linear_qkv_bias
+ # qkv = qkv.view(batch_size, -1, attn.heads, head_dim * 3)
+ # query, key, value = qkv.chunk(3, dim=-1)
# TODO: 校验rope是否apply正确(目前有25%的误差)
@@ -2875,15 +2875,15 @@ def __call__(
)
##########################################
- query = torch.load("/home/lhy/code/cogview/query_after_rope.pt")
- key = torch.load("/home/lhy/code/cogview/key_after_rope.pt")
- value = torch.load("/home/lhy/code/cogview/value_after_rope.pt")
- query = query[None, :16+4096, ...]
- key = key[None, :16+4096, ...]
- value = value[None, :16+4096, ...]
- query = query.transpose(1, 2)
- key = key.transpose(1, 2)
- value = value.transpose(1, 2)
+ # query = torch.load("/home/lhy/code/cogview/query_after_rope.pt")
+ # key = torch.load("/home/lhy/code/cogview/key_after_rope.pt")
+ # value = torch.load("/home/lhy/code/cogview/value_after_rope.pt")
+ # query = query[None, :16+4096, ...]
+ # key = key[None, :16+4096, ...]
+ # value = value[None, :16+4096, ...]
+ # query = query.transpose(1, 2)
+ # key = key.transpose(1, 2)
+ # value = value.transpose(1, 2)
##########################################
hidden_states = F.scaled_dot_product_attention(
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index 8c50610e8f9a..c0f557b7017b 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -444,23 +444,21 @@ def forward(
######################
# 2. Conditional embeddings
- # temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
- # temb_cond, temb_uncond = temb.chunk(2)
- # hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
- # hidden_states, prompt_embeds, negative_prompt_embeds
- # )
- # hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
- # encoder_hidden_states_cond = prompt_embeds
- # encoder_hidden_states_uncond = negative_prompt_embeds
-
- prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
- negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_16_32.pt")[None, ::]
-
- hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")[None, ::]
- hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")[None, ::]
-
- temb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")[None, ::]
- temb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")[None, ::]
+ temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
+ temb_cond, temb_uncond = temb.chunk(2)
+ hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
+ hidden_states, prompt_embeds, negative_prompt_embeds
+ )
+ hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
+
+ # prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
+ # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_16_32.pt")[None, ::]
+ #
+ # hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")[None, ::]
+ # hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")[None, ::]
+ #
+ # temb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")[None, ::]
+ # temb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")[None, ::]
encoder_hidden_states_cond = prompt_embeds
encoder_hidden_states_uncond = negative_prompt_embeds
From 6a3a07fcee38a767ddd4a1955d4c8e9a59ec0981 Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Tue, 28 Jan 2025 17:36:24 +0000
Subject: [PATCH 35/68] [bugfix] fix dimension mismatch in CogView4 attention
---
src/diffusers/models/attention_processor.py | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 44f7ada5b56e..6e86e76d1873 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2837,9 +2837,9 @@ def __call__(
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
###############################################3
# TODO: 直接用qkv_weight算出qkv(注意要先分出num_heads, head_dim),再在head_dims上拆出qkv
@@ -2850,7 +2850,6 @@ def __call__(
qkv = qkv.view(batch_size, -1, attn.heads, head_dim * 3)
query, key, value = qkv.chunk(3, dim=-1)
-
# TODO: 校验rope是否apply正确(目前有25%的误差)
###############################################3
From de274f39a0dea91d7bf6a069f423dcc231e52ca2 Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Tue, 28 Jan 2025 17:37:20 +0000
Subject: [PATCH 36/68] [cogview4][WIP]: update final normalization in CogView4
transformer
Refactored the final normalization layer in CogView4 transformer to use separate layernorm and AdaLN operations instead of combined AdaLayerNormContinuous. This matches the original implementation but needs validation.
Needs verification against reference implementation.
---
.../transformers/transformer_cogview4.py | 38 ++++++++++++++-----
1 file changed, 29 insertions(+), 9 deletions(-)
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index 8c50610e8f9a..7f73548d6c4c 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -170,8 +170,8 @@ def forward(
##############################################################
hidden_states, encoder_hidden_states = (
- hidden_states[:, :encoder_hidden_states_len],
hidden_states[:, encoder_hidden_states_len:],
+ hidden_states[:, :encoder_hidden_states_len],
)
return hidden_states, encoder_hidden_states
@@ -240,6 +240,8 @@ def __init__(
embed_dim=self.config.attention_head_dim, max_h=self.max_h, max_w=self.max_w, rotary_base=10000
)
+ self.layernorm = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-5)
+
self.patch_embed = CogView4PatchEmbed(
in_channels=in_channels,
hidden_size=self.inner_dim,
@@ -267,11 +269,15 @@ def __init__(
]
)
+ ######################################
self.norm_out = AdaLayerNormContinuous(
embedding_dim=self.inner_dim,
conditioning_embedding_dim=time_embed_dim,
elementwise_affine=False,
)
+ self.adaln_final = self.norm_out.linear
+ ######################################
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
@@ -484,14 +490,28 @@ def forward(
image_rotary_emb=image_rotary_emb,
)
- hidden_states_cond, encoder_hidden_states_cond = (
- self.norm_out(hidden_states_cond, temb_cond),
- self.norm_out(encoder_hidden_states_cond, temb_cond),
- )
- hidden_states_uncond, encoder_hidden_states_uncond = (
- self.norm_out(hidden_states_uncond, temb_uncond),
- self.norm_out(encoder_hidden_states_uncond, temb_uncond),
- )
+ #################################################
+ # hidden_states_cond, encoder_hidden_states_cond = (
+ # self.norm_out(hidden_states_cond, temb_cond),
+ # self.norm_out(encoder_hidden_states_cond, temb_cond),
+ # )
+ # hidden_states_uncond, encoder_hidden_states_uncond = (
+ # self.norm_out(hidden_states_uncond, temb_uncond),
+ # self.norm_out(encoder_hidden_states_uncond, temb_uncond),
+ # )
+
+ hidden_states_cond = self.layernorm(hidden_states_cond)
+ hidden_states_uncond = self.layernorm(hidden_states_uncond)
+ encoder_hidden_states_cond = self.layernorm(encoder_hidden_states_cond)
+ encoder_hidden_states_uncond = self.layernorm(encoder_hidden_states_uncond)
+
+ shift_cond, scale_cond = self.adaln_final(temb_cond).chunk(2, dim=-1)
+ shift_uncond, scale_uncond = self.adaln_final(temb_uncond).chunk(2, dim=-1)
+
+ hidden_states_cond = hidden_states_cond * (1 + scale_cond) + shift_cond
+ hidden_states_uncond = hidden_states_uncond * (1 + scale_uncond) + shift_uncond
+ #################################################
+
hidden_states_cond = self.proj_out(hidden_states_cond)
hidden_states_uncond = self.proj_out(hidden_states_uncond)
From 46277b2d251a006e04b9f3ec64efa36144376b60 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Wed, 5 Feb 2025 14:26:18 +0800
Subject: [PATCH 37/68] 1
---
check_same_sat_megatron_diffusers.py | 264 ---------------------------
1 file changed, 264 deletions(-)
delete mode 100644 check_same_sat_megatron_diffusers.py
diff --git a/check_same_sat_megatron_diffusers.py b/check_same_sat_megatron_diffusers.py
deleted file mode 100644
index 0495a55f4d91..000000000000
--- a/check_same_sat_megatron_diffusers.py
+++ /dev/null
@@ -1,264 +0,0 @@
-import torch
-from collections import OrderedDict
-from diffusers import CogView4Transformer2DModel
-
-def load_state_dict_sat(file_path):
- """Load the SAT state dictionary from a given file path."""
- # Typically, the stored SAT ckpt is in the format: {'module': {...}}
- ckpt = torch.load(file_path, map_location="cuda")
- return ckpt["module"]
-
-
-def extract_qkv_from_sat(state_dict, layer_idx):
- """
- Extract QKV weights and biases from a SAT state_dict.
- Expects keys like:
- model.diffusion_model.transformer.layers.{layer_idx}.attention.query_key_value
- """
- prefix = f"model.diffusion_model.transformer.layers.{layer_idx}.attention.query_key_value"
- w = state_dict[f"{prefix}.weight"].clone()
- b = state_dict[f"{prefix}.bias"].clone()
- return (w, b)
-
-
-def load_state_dict_cogview(cogview_path):
- """
- Loads the CogView4 model from diffusers and returns its state_dict().
- NOTE: You should adjust 'torch_dtype' and 'device_map' as appropriate.
- """
- cogview_model = CogView4Transformer2DModel.from_pretrained(
- cogview_path, torch_dtype=torch.bfloat16, device_map="auto"
- )
- return cogview_model.state_dict()
-
-
-def extract_qkv_from_cogview(state_dict, layer_idx, num_heads, head_dim, hidden_dim):
- """
- Extract Q, K, V from CogView4 checkpoint and reshape them into the same shape as SAT’s QKV.
- For each layer i:
- Q prefix: transformer_blocks.{layer_idx}.attn1.to_q
- K prefix: transformer_blocks.{layer_idx}.attn1.to_k
- V prefix: transformer_blocks.{layer_idx}.attn1.to_v
- Final shape must match SAT's [3*hidden_dim, hidden_dim] for weight, and [3*hidden_dim] for bias.
- """
- q_prefix = f"transformer_blocks.{layer_idx}.attn1.to_q"
- k_prefix = f"transformer_blocks.{layer_idx}.attn1.to_k"
- v_prefix = f"transformer_blocks.{layer_idx}.attn1.to_v"
-
- # Extract
- q_weight = state_dict[f"{q_prefix}.weight"].clone()
- k_weight = state_dict[f"{k_prefix}.weight"].clone()
- v_weight = state_dict[f"{v_prefix}.weight"].clone()
-
- q_bias = state_dict[f"{q_prefix}.bias"].clone()
- k_bias = state_dict[f"{k_prefix}.bias"].clone()
- v_bias = state_dict[f"{v_prefix}.bias"].clone()
-
- # Reshape weights: [hidden_dim, hidden_dim] -> [num_heads, head_dim, hidden_dim]
- # Then concat along the first dimension (which will become 3*num_heads*head_dim)
- q_weight = q_weight.view(num_heads, head_dim, hidden_dim)
- k_weight = k_weight.view(num_heads, head_dim, hidden_dim)
- v_weight = v_weight.view(num_heads, head_dim, hidden_dim)
-
- qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) # shape: (3*num_heads, head_dim, hidden_dim)
- qkv_weight = qkv_weight.view(3 * num_heads * head_dim, hidden_dim) # flatten
-
- # Reshape biases: [hidden_dim] -> [num_heads, head_dim]
- q_bias = q_bias.view(num_heads, head_dim)
- k_bias = k_bias.view(num_heads, head_dim)
- v_bias = v_bias.view(num_heads, head_dim)
-
- qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0) # (3*num_heads, head_dim)
- qkv_bias = qkv_bias.view(3 * num_heads * head_dim)
-
- return (qkv_weight, qkv_bias)
-
-def create_sat_state_dict_from_megatron(megatron_ckpt_dict, num_layers=48, num_heads=32, hidden_size=3072):
- """
- Convert a loaded Megatron checkpoint's 'model' dictionary into the same
- format used by SAT. This returns something like {'module': {...}} for
- easy comparison with SAT.
-
- The code below is adapted from your 'create_sat_state_dict' function,
- but we rename it here to keep it direct.
- """
- from tqdm import tqdm
-
- hidden_size_per_head = hidden_size // num_heads
- mega_weight = megatron_ckpt_dict["model"]
- sat_weight = {}
-
- # --- patch_embed ---
- sat_weight["model.diffusion_model.mixins.patch_embed.proj.weight"] = \
- mega_weight["encoder_expand_linear.weight"].reshape(hidden_size, 64).clone()
- sat_weight["model.diffusion_model.mixins.patch_embed.proj.bias"] = \
- mega_weight["encoder_expand_linear.bias"].clone()
-
- sat_weight["model.diffusion_model.mixins.patch_embed.text_proj.weight"] = \
- mega_weight["text_projector.weight"].clone()
- sat_weight["model.diffusion_model.mixins.patch_embed.text_proj.bias"] = \
- mega_weight["text_projector.bias"].clone()
-
- # --- time embedding ---
- sat_weight["model.diffusion_model.time_embed.0.weight"] = \
- mega_weight["time_embedding.time_embed.0.weight"].clone()
- sat_weight["model.diffusion_model.time_embed.0.bias"] = \
- mega_weight["time_embedding.time_embed.0.bias"].clone()
- sat_weight["model.diffusion_model.time_embed.2.weight"] = \
- mega_weight["time_embedding.time_embed.2.weight"].clone()
- sat_weight["model.diffusion_model.time_embed.2.bias"] = \
- mega_weight["time_embedding.time_embed.2.bias"].clone()
-
- # --- label embedding ---
- sat_weight["model.diffusion_model.label_emb.0.0.weight"] = \
- mega_weight["label_embedding.label_embed.0.weight"].clone()
- sat_weight["model.diffusion_model.label_emb.0.0.bias"] = \
- mega_weight["label_embedding.label_embed.0.bias"].clone()
- sat_weight["model.diffusion_model.label_emb.0.2.weight"] = \
- mega_weight["label_embedding.label_embed.2.weight"].clone()
- sat_weight["model.diffusion_model.label_emb.0.2.bias"] = \
- mega_weight["label_embedding.label_embed.2.bias"].clone()
-
- # --- layers ---
- for i in tqdm(range(num_layers), desc="Converting Megatron->SAT"):
- # attention output
- sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.dense.weight"] = \
- mega_weight[f"decoder.layers.{i}.self_attention.linear_proj.weight"].clone()
- sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.dense.bias"] = \
- mega_weight[f"decoder.layers.{i}.self_attention.linear_proj.bias"].clone()
-
- # QKV
- qkv_weight = mega_weight[f"decoder.layers.{i}.self_attention.linear_qkv.weight"].clone()
- qkv_bias = mega_weight[f"decoder.layers.{i}.self_attention.linear_qkv.bias"].clone()
-
- # Reshape QKV from Megatron format into SAT format
- # qkv_weight: [3*hidden_size, hidden_size] -> [num_heads, 3, hidden_size_per_head, hidden_size] -> ...
- sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.query_key_value.weight"] = \
- qkv_weight.view(num_heads, 3, hidden_size_per_head, hidden_size) \
- .permute(1, 0, 2, 3) \
- .reshape(3 * hidden_size, hidden_size).clone()
- sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.query_key_value.bias"] = \
- qkv_bias.view(num_heads, 3, hidden_size_per_head) \
- .permute(1, 0, 2) \
- .reshape(3 * hidden_size) \
- .clone()
-
- # MLP
- sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_h_to_4h.weight"] = \
- mega_weight[f"decoder.layers.{i}.mlp.linear_fc1.weight"].clone()
- sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_h_to_4h.bias"] = \
- mega_weight[f"decoder.layers.{i}.mlp.linear_fc1.bias"].clone()
-
- sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_4h_to_h.weight"] = \
- mega_weight[f"decoder.layers.{i}.mlp.linear_fc2.weight"].clone()
- sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_4h_to_h.bias"] = \
- mega_weight[f"decoder.layers.{i}.mlp.linear_fc2.bias"].clone()
-
- # AdaLN
- adaln_weight = mega_weight[f"decoder.layers.{i}.adaln.weight"].clone()
- adaln_bias = mega_weight[f"decoder.layers.{i}.adaln.bias"].clone()
-
- sat_weight[f"model.diffusion_model.mixins.adaln.adaln_modules.{i}.1.weight"] = adaln_weight.clone()
- sat_weight[f"model.diffusion_model.mixins.adaln.adaln_modules.{i}.1.bias"] = adaln_bias.clone()
-
- # --- final layers ---
- sat_weight["model.diffusion_model.mixins.final_layer.adaln.1.weight"] = \
- mega_weight["adaln_final.weight"].clone()
- sat_weight["model.diffusion_model.mixins.final_layer.adaln.1.bias"] = \
- mega_weight["adaln_final.bias"].clone()
- sat_weight["model.diffusion_model.mixins.final_layer.linear.weight"] = \
- mega_weight["output_projector.weight"].clone()
- sat_weight["model.diffusion_model.mixins.final_layer.linear.bias"] = \
- mega_weight["output_projector.bias"].clone()
-
- return OrderedDict(sat_weight)
-
-
-def load_state_dict_megatron_and_convert_to_sat(megatron_ckpt_path, num_layers, num_heads, hidden_size):
- """
- Load a Megatron checkpoint from , then convert it into
- an SAT-style OrderedDict for direct QKV comparison.
-
- Typically, = ".../iter_0287500/mp_rank_00/model_optim_rng.pt"
- """
- ckpt = torch.load(megatron_ckpt_path, map_location="cuda")
- # Convert to SAT
- sat_like_weight = create_sat_state_dict_from_megatron(
- ckpt, num_layers=num_layers, num_heads=num_heads, hidden_size=hidden_size
- )
- return sat_like_weight
-
-def compute_l2_difference(tensor1, tensor2):
- """Compute L2 norm of the difference between two tensors."""
- return torch.norm(tensor1 - tensor2, p=2).item()
-
-
-def compare_qkv(qkv1, qkv2, name1="Model1", name2="Model2", atol=1e-6):
- """
- Compare QKV from two different sources (each is a tuple of (weight, bias)).
- Returns (weight_match, bias_match, weight_l2, bias_l2).
- """
- w1, b1 = qkv1
- w2, b2 = qkv2
-
- weight_match = torch.allclose(w1, w2, atol=atol)
- bias_match = torch.allclose(b1, b2, atol=atol)
- weight_l2_diff = compute_l2_difference(w1, w2)
- bias_l2_diff = compute_l2_difference(b1, b2)
-
- if not (weight_match and bias_match):
- print(f"[QKV Mismatch] {name1} vs {name2}")
- print(f" Weight L2: {weight_l2_diff:.6f}, Bias L2: {bias_l2_diff:.6f}")
- else:
- # If everything matches well:
- print(f"[QKV Match] {name1} vs {name2} (Weight L2={weight_l2_diff:.6f}, Bias L2={bias_l2_diff:.6f})")
-
- return weight_match, bias_match, weight_l2_diff, bias_l2_diff
-
-if __name__ == "__main__":
- num_layers = 28
- num_heads = 32
- hidden_dim = 4096
- head_dim = hidden_dim // num_heads
-
- sat_ckpt_path = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/pt_sat/0287500/mp_rank_00_model_states.pt"
- sat_state_dict = load_state_dict_sat(sat_ckpt_path)
-
- cogview_path = "/share/zyx/CogView4-6B-0128/transformer" # directory containing model index for diffusers
- cogview_state_dict = load_state_dict_cogview(cogview_path)
-
- megatron_ckpt_path = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/pt_ema/iter_0287500/mp_rank_00/model_optim_rng.pt"
- mega_as_sat_state_dict = load_state_dict_megatron_and_convert_to_sat(
- megatron_ckpt_path,
- num_layers=num_layers,
- num_heads=num_heads,
- hidden_size=hidden_dim
- )
-
- print("\n==== Start QKV Comparison ====\n")
- for layer_idx in range(num_layers):
- print(f"--- Layer {layer_idx} ---")
-
- # Extract QKV from SAT
- sat_qkv = extract_qkv_from_sat(sat_state_dict, layer_idx)
-
- # Extract QKV from CogView
- cogview_qkv = extract_qkv_from_cogview(
- cogview_state_dict, layer_idx, num_heads, head_dim, hidden_dim
- )
-
- # Extract QKV from Megatron->SAT
- mega_qkv = extract_qkv_from_sat(mega_as_sat_state_dict, layer_idx)
-
- # Compare: SAT vs CogView
- compare_qkv(sat_qkv, cogview_qkv, name1="SAT", name2="CogView4")
-
- # Compare: SAT vs Megatron
- compare_qkv(sat_qkv, mega_qkv, name1="SAT", name2="Megatron")
-
- # Compare: CogView vs Megatron (optional)
- compare_qkv(cogview_qkv, mega_qkv, name1="CogView4", name2="Megatron")
-
- print()
-
- print("=== Done ===")
From ebbaa5bc53182aeb160d53cacb032361411868a2 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Wed, 5 Feb 2025 14:28:19 +0800
Subject: [PATCH 38/68] put back
---
src/diffusers/pipelines/__init__.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 285348e5885d..048dfabb0923 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -497,6 +497,7 @@
CogVideoXVideoToVideoPipeline,
)
from .cogview3 import CogView3PlusPipeline
+ from .consisid import ConsisIDPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
From f1ccdd2c134912ca3de830a73b49e51048b78ea3 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Wed, 5 Feb 2025 23:17:05 +0800
Subject: [PATCH 39/68] Update transformer_cogview4.py
---
.../transformers/transformer_cogview4.py | 43 +++++++++----------
1 file changed, 21 insertions(+), 22 deletions(-)
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index a6d73e76c606..10f00b4b9c6c 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -176,6 +176,12 @@ def forward(
return hidden_states, encoder_hidden_states
+def swap_scale_shift(weight, dim):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+
class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
r"""
Args:
@@ -276,7 +282,10 @@ def __init__(
elementwise_affine=False,
)
self.adaln_final = self.norm_out.linear
- ######################################
+ # with torch.no_grad():
+ # w = self.norm_out.linear.weight.data.clone()
+ # w_swapped = swap_scale_shift(w, dim=0)
+ # self.adaln_final.weight.data.copy_(w_swapped)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
@@ -445,6 +454,7 @@ def forward(
image_rotary_emb = self.get_rope_embedding(
patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
)
+ ## TODO: @Oleehy Remove it after debugging
# image_rotary_emb = torch.load("/home/lhy/code/cogview/rotary_pos_emb.pt")
# image_rotary_emb = image_rotary_emb[16:16+4096, 0, 0, :]
@@ -457,6 +467,7 @@ def forward(
)
hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
+ # Todo: @Oleehy Remove it after debugging
# prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
# negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_16_32.pt")[None, ::]
#
@@ -488,27 +499,15 @@ def forward(
image_rotary_emb=image_rotary_emb,
)
- #################################################
- # hidden_states_cond, encoder_hidden_states_cond = (
- # self.norm_out(hidden_states_cond, temb_cond),
- # self.norm_out(encoder_hidden_states_cond, temb_cond),
- # )
- # hidden_states_uncond, encoder_hidden_states_uncond = (
- # self.norm_out(hidden_states_uncond, temb_uncond),
- # self.norm_out(encoder_hidden_states_uncond, temb_uncond),
- # )
-
- hidden_states_cond = self.layernorm(hidden_states_cond)
- hidden_states_uncond = self.layernorm(hidden_states_uncond)
- encoder_hidden_states_cond = self.layernorm(encoder_hidden_states_cond)
- encoder_hidden_states_uncond = self.layernorm(encoder_hidden_states_uncond)
-
- shift_cond, scale_cond = self.adaln_final(temb_cond).chunk(2, dim=-1)
- shift_uncond, scale_uncond = self.adaln_final(temb_uncond).chunk(2, dim=-1)
-
- hidden_states_cond = hidden_states_cond * (1 + scale_cond) + shift_cond
- hidden_states_uncond = hidden_states_uncond * (1 + scale_uncond) + shift_uncond
- #################################################
+ # Todo: @Oleehy Check if this is the right implementation
+ hidden_states_cond, encoder_hidden_states_cond = (
+ self.norm_out(hidden_states_cond, temb_cond),
+ self.norm_out(encoder_hidden_states_cond, temb_cond),
+ )
+ hidden_states_uncond, encoder_hidden_states_uncond = (
+ self.norm_out(hidden_states_uncond, temb_uncond),
+ self.norm_out(encoder_hidden_states_uncond, temb_uncond),
+ )
hidden_states_cond = self.proj_out(hidden_states_cond)
hidden_states_uncond = self.proj_out(hidden_states_uncond)
From 030a467e93ec459c18c115e8846f25e759a4203f Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Thu, 6 Feb 2025 12:23:26 +0800
Subject: [PATCH 40/68] change time_shift
---
.../pipelines/cogview4/pipeline_cogview4.py | 26 ++++++++++++-------
1 file changed, 17 insertions(+), 9 deletions(-)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 78a4a36ce82b..4abd8918b973 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -55,18 +55,25 @@
"""
+def time_shift(self, mu: float, shift_sigma: float, sigmas: torch.Tensor):
+ return mu / (mu + (1 / sigmas - 1) ** shift_sigma)
+
+
def calculate_shift(
- image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.15
+ self,
+ image_seq_len,
+ base_seq_len: int = 256,
):
- m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
- b = base_shift - m * base_seq_len
- mu = image_seq_len * m + b
- return mu
+ if isinstance(image_seq_len, int):
+ mu = math.sqrt(image_seq_len / base_seq_len)
+ elif isinstance(image_seq_len, torch.Tensor):
+ mu = torch.sqrt(image_seq_len / base_seq_len)
+ else:
+ raise ValueError(f'Invalid type for image_seq_len: {type(image_seq_len)}')
+ mu = mu * 0.75 + 0.25
-def time_shift(mu: float, shift_sigma: float, sigmas: torch.Tensor):
- return math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** shift_sigma)
-
+ return mu
# def retrieve_timesteps(
# scheduler,
@@ -598,7 +605,8 @@ def __call__(
max_sequence_length=max_sequence_length,
device=device,
)
-
+ torch.save(prompt_embeds, '/share/home/zyx/prompt_embeds.pt')
+ torch.save(negative_prompt_embeds, '/share/home/zyx/negative_prompt_embeds.pt')
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
From ad405752d0de98e856cbb1d28f5ab0b144d94d1a Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Thu, 6 Feb 2025 12:23:59 +0800
Subject: [PATCH 41/68] Update pipeline_cogview4.py
---
src/diffusers/pipelines/cogview4/pipeline_cogview4.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 4abd8918b973..91210c21f039 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -55,12 +55,11 @@
"""
-def time_shift(self, mu: float, shift_sigma: float, sigmas: torch.Tensor):
+def time_shift(mu: float, shift_sigma: float, sigmas: torch.Tensor):
return mu / (mu + (1 / sigmas - 1) ** shift_sigma)
def calculate_shift(
- self,
image_seq_len,
base_seq_len: int = 256,
):
From 81d39eea92e38e3e9aa4171d4691b15edbd5fd62 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Thu, 6 Feb 2025 17:04:53 +0800
Subject: [PATCH 42/68] change timesteps
---
.../pipelines/cogview4/pipeline_cogview4.py | 59 -------------------
1 file changed, 59 deletions(-)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 91210c21f039..0ec386e44bfd 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -74,65 +74,6 @@ def calculate_shift(
return mu
-# def retrieve_timesteps(
-# scheduler,
-# num_inference_steps: Optional[int] = None,
-# device: Optional[Union[str, torch.device]] = None,
-# timesteps: Optional[List[int]] = None,
-# sigmas: Optional[List[float]] = None,
-# **kwargs,
-# ):
-# r"""
-# Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
-# custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
-# Args:
-# scheduler (`SchedulerMixin`):
-# The scheduler to get timesteps from.
-# num_inference_steps (`int`):
-# The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
-# must be `None`.
-# device (`str` or `torch.device`, *optional*):
-# The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
-# timesteps (`List[int]`, *optional*):
-# Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
-# `num_inference_steps` and `sigmas` must be `None`.
-# sigmas (`List[float]`, *optional*):
-# Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
-# `num_inference_steps` and `timesteps` must be `None`.
-
-# Returns:
-# `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
-# second element is the number of inference steps.
-# """
-# if timesteps is not None and sigmas is not None:
-# raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
-# if timesteps is not None:
-# accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
-# if not accepts_timesteps:
-# raise ValueError(
-# f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
-# f" timestep schedules. Please check whether you are using the correct scheduler."
-# )
-# scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
-# timesteps = scheduler.timesteps
-# num_inference_steps = len(timesteps)
-# elif sigmas is not None:
-# accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
-# if not accept_sigmas:
-# raise ValueError(
-# f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
-# f" sigmas schedules. Please check whether you are using the correct scheduler."
-# )
-# scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
-# timesteps = scheduler.timesteps
-# num_inference_steps = len(timesteps)
-# else:
-# scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
-# timesteps = scheduler.timesteps
-# return timesteps, num_inference_steps
-
-
class CogView4Pipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using CogView4.
From 45f9e88d4f628320f211a04ff7609fdf182b1b86 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Thu, 6 Feb 2025 17:06:15 +0800
Subject: [PATCH 43/68] fix
---
docs/source/en/api/models/cogview4_transformer2d.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/docs/source/en/api/models/cogview4_transformer2d.md b/docs/source/en/api/models/cogview4_transformer2d.md
index e6c976e64253..4bf14bdd4991 100644
--- a/docs/source/en/api/models/cogview4_transformer2d.md
+++ b/docs/source/en/api/models/cogview4_transformer2d.md
@@ -16,9 +16,9 @@ A Diffusion Transformer model for 2D data from [CogView4]()
The model can be loaded with the following code snippet.
```python
-from diffusers import CogView3PlusTransformer2DModel
+from diffusers import CogView4Transformer2DModel
-transformer = CogView3PlusTransformer2DModel.from_pretrained("THUDM/CogView4-6B", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+transformer = CogView4Transformer2DModel.from_pretrained("THUDM/CogView4-6B", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## CogView4Transformer2DModel
From 1dbeaa83c789642c4491c38819ab78aa38edd7c7 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Thu, 6 Feb 2025 17:07:45 +0800
Subject: [PATCH 44/68] change text_encoder_id
---
scripts/convert_cogview4_to_diffusers_megatron.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py
index 0c93cb099744..b9e08bb4ae4f 100644
--- a/scripts/convert_cogview4_to_diffusers_megatron.py
+++ b/scripts/convert_cogview4_to_diffusers_megatron.py
@@ -336,7 +336,7 @@ def main(args):
vae = vae.to(dtype=dtype)
# Load the text encoder and tokenizer
- text_encoder_id = "/share/home/zyx/Models/glm-4-9b-hf"
+ text_encoder_id = "THUDM/glm-4-9b-hf"
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
text_encoder = GlmForCausalLM.from_pretrained(
text_encoder_id,
From f20960021bf3a1f0002195c2daee48b690552cac Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Thu, 6 Feb 2025 10:24:42 +0000
Subject: [PATCH 45/68] [cogview4][rope] align RoPE implementation with
Megatron
- Implement apply_rope method in attention processor to match Megatron's implementation
- Update position embeddings to ensure compatibility with Megatron-style rotary embeddings
- Ensure consistent rotary position encoding across attention layers
This change improves compatibility with Megatron-based models and provides
better alignment with the original implementation's positional encoding approach.
---
src/diffusers/models/attention_processor.py | 50 ++++++++++++++++++++-
src/diffusers/models/embeddings.py | 26 +++--------
2 files changed, 56 insertions(+), 20 deletions(-)
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 964b165f4835..10581afe8e0a 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2846,7 +2846,7 @@ def __call__(
# TODO: 直接用qkv_weight算出qkv(注意要先分出num_heads, head_dim),再在head_dims上拆出qkv
# linear_qkv_weight = torch.load("/home/lhy/code/cogview/linear_qkv_weight.pt")
# linear_qkv_bias = torch.load("/home/lhy/code/cogview/linear_qkv_bias.pt")
- #
+
# qkv = torch.matmul(hidden_states, linear_qkv_weight.T) + linear_qkv_bias
# qkv = qkv.view(batch_size, -1, attn.heads, head_dim * 3)
# query, key, value = qkv.chunk(3, dim=-1)
@@ -2867,6 +2867,54 @@ def __call__(
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb_megatron
+ ########################## check tensor
+ # def apply_rope_megatron_tmp(input, freqs, transpose_output=False):
+ # # 这个实现目前似乎存在0.2%的误差
+ # s, b, h, d = input.shape
+ # d2 = freqs.shape[3] # d2 corresponds to the frequency dimension size
+
+ # # Initialize output tensor with the correct shape
+ # if transpose_output:
+ # output = torch.empty((b, s, h, d), device=input.device)
+ # output = output.transpose(0, 1) # Transpose to (s, b, h, d)
+ # else:
+ # output = torch.empty((s, b, h, d), device=input.device)
+
+ # # Apply the ROPE transformation for each element
+ # for s_id in range(s):
+ # for b_id in range(b):
+ # for h_id in range(h):
+ # for d_id in range(d2):
+ # v_cos, v_sin = torch.cos(freqs[s_id, 0, 0, d_id]), torch.sin(freqs[s_id, 0, 0, d_id])
+ # input_val = input[s_id, b_id, h_id, d_id]
+ # if d_id + d2 // 2 < d2:
+ # input_val_rotate = -input[s_id, b_id, h_id, d_id + d2 // 2]
+ # else:
+ # input_val_rotate = input[s_id, b_id, h_id, d_id + (d2 // 2 - d2)]
+ # output[s_id, b_id, h_id, d_id] = input_val * v_cos + input_val_rotate * v_sin
+
+ # return output
+
+ # query_before_rope_megatron = torch.load("/home/lhy/code/cogview/query_before_rope.pt")
+ # query_after_rope_megatron = torch.load("/home/lhy/code/cogview/query_after_rope.pt")
+
+ # q_pos_emb = torch.load("/home/lhy/code/cogview/q_pos_emb.pt")[16:16+4096][:, 0, 0, :]
+ # k_pos_emb = torch.load("/home/lhy/code/cogview/k_pos_emb.pt")[16:16+4096][:, 0, 0, :]
+
+ # diff_query_before_rope = torch.norm(query_before_rope_megatron[:4112, ...] - query.transpose(1, 2)[0])
+ # diff_q_emb = torch.norm(q_pos_emb - image_rotary_emb)
+ # diff_k_emb = torch.norm(k_pos_emb - image_rotary_emb)
+
+ # input = query.permute(2, 0, 1, 3)[16:, ...]
+ # freqs = image_rotary_emb[:, None, None, :]
+ # output = apply_rope_megatron_tmp(input.to("cpu"), freqs.to("cpu"), transpose_output=True)
+
+ # out_foo = apply_rotary_emb_megatron(
+ # query[:, :, text_seq_length:, :], image_rotary_emb
+ # )
+ # diff_after_rope = torch.norm(query_after_rope_megatron[16:16+4096].transpose(0, 1) - out_foo[0])
+ ##########################
+
query[:, :, text_seq_length:, :] = apply_rotary_emb_megatron(
query[:, :, text_seq_length:, :], image_rotary_emb
)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 73088dee6011..0c457dd0f2d6 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1287,33 +1287,21 @@ def apply_rotary_emb_megatron(x: torch.Tensor, freqs: torch.Tensor) -> torch.Ten
"""Apply rotary position embeddings to input tensor.
Args:
- x: Input tensor of shape [seq_len, batch_size, n_heads, head_dim]
- freqs: Frequency tensor of shape [seq_len, 1, 1, head_dim//2]
+ x: Input tensor of shape [batch_size, n_heads, seq_len, head_dim]
+ freqs: Frequency tensor of shape [seq_len, head_dim]
Returns:
Tensor with rotary position embeddings applied
"""
batch_size, n_heads, seq_len, rot_dim = x.shape
+ assert rot_dim % 2 == 0 and rot_dim == freqs.shape[-1]
- # Reshape x to have rot_dim as the last dimension
- x_rot, x_pass = x.chunk(2, dim=-1)
+ x_dim_first_half, x_dim_second_half = x.chunk(2, dim=-1)
+ x_rot_shifted = torch.cat([-x_dim_second_half, x_dim_first_half], dim=-1)
- # Apply rotary embeddings
- # First calculate cos and sin
- cos, sin = freqs.chunk(2, dim=-1)
- cos, sin = torch.cos(cos), torch.sin(sin)
+ cos, sin = torch.cos(freqs), torch.sin(freqs)
- # Rotate x_rot
- x_rot_cos = x_rot * cos
- # Create rotated version of x_rot by shifting rot_dim/2 positions
- x_rot_shifted = torch.cat([-x_rot[..., rot_dim//2:], x_rot[..., :rot_dim//2]], dim=-1)
- x_rot_sin = x_rot_shifted * sin
-
- # Combine
- x_rot = x_rot_cos + x_rot_sin
-
- # Concatenate back with x_pass
- x_out = torch.cat([x_rot, x_pass], dim=-1)
+ x_out = cos * x + sin * x_rot_shifted
return x_out
From 992f5a3c866fd86ed420c0b381fc4d8b0856e3b3 Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Thu, 6 Feb 2025 10:26:14 +0000
Subject: [PATCH 46/68] [cogview4][bugfix] apply silu activation to time
embeddings in CogView4
Applied silu activation to time embeddings before splitting into conditional
and unconditional parts in CogView4Transformer2DModel. This matches the
original implementation and helps ensure correct time conditioning behavior.
---
.../transformers/transformer_cogview4.py | 31 ++++++++++---------
1 file changed, 17 insertions(+), 14 deletions(-)
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index a6d73e76c606..37389ff84575 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -448,27 +448,32 @@ def forward(
# image_rotary_emb = torch.load("/home/lhy/code/cogview/rotary_pos_emb.pt")
# image_rotary_emb = image_rotary_emb[16:16+4096, 0, 0, :]
- ######################
# 2. Conditional embeddings
temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
+ temb = F.silu(temb)
temb_cond, temb_uncond = temb.chunk(2)
hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
hidden_states, prompt_embeds, negative_prompt_embeds
)
hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
+ ######################
+ # reload for debug
+ ## 这里大概有2%~4%的误差
# prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
# negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_16_32.pt")[None, ::]
- #
+
+ ## 这里0误差
# hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")[None, ::]
# hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")[None, ::]
- #
+
+ ## 目前temb部分有很大的误差
# temb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")[None, ::]
# temb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")[None, ::]
+ ######################
encoder_hidden_states_cond = prompt_embeds
encoder_hidden_states_uncond = negative_prompt_embeds
- ######################
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -480,35 +485,33 @@ def forward(
encoder_hidden_states=encoder_hidden_states_cond,
time_embedding=temb_cond,
image_rotary_emb=image_rotary_emb,
+ # image_rotary_emb=None,
)
hidden_states_uncond, encoder_hidden_states_uncond = block(
hidden_states=hidden_states_uncond,
encoder_hidden_states=encoder_hidden_states_uncond,
time_embedding=temb_uncond,
image_rotary_emb=image_rotary_emb,
+ # image_rotary_emb=None,
)
- #################################################
- # hidden_states_cond, encoder_hidden_states_cond = (
- # self.norm_out(hidden_states_cond, temb_cond),
- # self.norm_out(encoder_hidden_states_cond, temb_cond),
- # )
- # hidden_states_uncond, encoder_hidden_states_uncond = (
- # self.norm_out(hidden_states_uncond, temb_uncond),
- # self.norm_out(encoder_hidden_states_uncond, temb_uncond),
- # )
hidden_states_cond = self.layernorm(hidden_states_cond)
hidden_states_uncond = self.layernorm(hidden_states_uncond)
encoder_hidden_states_cond = self.layernorm(encoder_hidden_states_cond)
encoder_hidden_states_uncond = self.layernorm(encoder_hidden_states_uncond)
+ #################################################
+ # reload weight&bias for debug
+ # self.adaln_final.weight = torch.load("/home/lhy/code/cogview/adaln_final_weight.pt")
+ # self.adaln_final.bias = torch.load("/home/lhy/code/cogview/adaln_final_bias.pt")
+ #################################################
+
shift_cond, scale_cond = self.adaln_final(temb_cond).chunk(2, dim=-1)
shift_uncond, scale_uncond = self.adaln_final(temb_uncond).chunk(2, dim=-1)
hidden_states_cond = hidden_states_cond * (1 + scale_cond) + shift_cond
hidden_states_uncond = hidden_states_uncond * (1 + scale_uncond) + shift_uncond
- #################################################
hidden_states_cond = self.proj_out(hidden_states_cond)
hidden_states_uncond = self.proj_out(hidden_states_uncond)
From 03a1c3b05d30ada44c719db8483e692ce3608b36 Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Thu, 6 Feb 2025 10:32:38 +0000
Subject: [PATCH 47/68] [cogview4][chore] clean up pipeline code
- Remove commented out code and debug statements
- Remove unused retrieve_timesteps function
- Clean up code formatting and documentation
This commit focuses on code cleanup in the CogView4 pipeline implementation, removing unnecessary commented code and improving readability without changing functionality.
---
src/diffusers/models/attention_processor.py | 72 -------------------
.../transformers/transformer_cogview4.py | 24 +------
2 files changed, 2 insertions(+), 94 deletions(-)
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 10581afe8e0a..fbe96869f7a9 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2842,18 +2842,6 @@ def __call__(
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)
- ###############################################3
- # TODO: 直接用qkv_weight算出qkv(注意要先分出num_heads, head_dim),再在head_dims上拆出qkv
- # linear_qkv_weight = torch.load("/home/lhy/code/cogview/linear_qkv_weight.pt")
- # linear_qkv_bias = torch.load("/home/lhy/code/cogview/linear_qkv_bias.pt")
-
- # qkv = torch.matmul(hidden_states, linear_qkv_weight.T) + linear_qkv_bias
- # qkv = qkv.view(batch_size, -1, attn.heads, head_dim * 3)
- # query, key, value = qkv.chunk(3, dim=-1)
-
- # TODO: 校验rope是否apply正确(目前有25%的误差)
- ###############################################3
-
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
@@ -2867,54 +2855,6 @@ def __call__(
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb_megatron
- ########################## check tensor
- # def apply_rope_megatron_tmp(input, freqs, transpose_output=False):
- # # 这个实现目前似乎存在0.2%的误差
- # s, b, h, d = input.shape
- # d2 = freqs.shape[3] # d2 corresponds to the frequency dimension size
-
- # # Initialize output tensor with the correct shape
- # if transpose_output:
- # output = torch.empty((b, s, h, d), device=input.device)
- # output = output.transpose(0, 1) # Transpose to (s, b, h, d)
- # else:
- # output = torch.empty((s, b, h, d), device=input.device)
-
- # # Apply the ROPE transformation for each element
- # for s_id in range(s):
- # for b_id in range(b):
- # for h_id in range(h):
- # for d_id in range(d2):
- # v_cos, v_sin = torch.cos(freqs[s_id, 0, 0, d_id]), torch.sin(freqs[s_id, 0, 0, d_id])
- # input_val = input[s_id, b_id, h_id, d_id]
- # if d_id + d2 // 2 < d2:
- # input_val_rotate = -input[s_id, b_id, h_id, d_id + d2 // 2]
- # else:
- # input_val_rotate = input[s_id, b_id, h_id, d_id + (d2 // 2 - d2)]
- # output[s_id, b_id, h_id, d_id] = input_val * v_cos + input_val_rotate * v_sin
-
- # return output
-
- # query_before_rope_megatron = torch.load("/home/lhy/code/cogview/query_before_rope.pt")
- # query_after_rope_megatron = torch.load("/home/lhy/code/cogview/query_after_rope.pt")
-
- # q_pos_emb = torch.load("/home/lhy/code/cogview/q_pos_emb.pt")[16:16+4096][:, 0, 0, :]
- # k_pos_emb = torch.load("/home/lhy/code/cogview/k_pos_emb.pt")[16:16+4096][:, 0, 0, :]
-
- # diff_query_before_rope = torch.norm(query_before_rope_megatron[:4112, ...] - query.transpose(1, 2)[0])
- # diff_q_emb = torch.norm(q_pos_emb - image_rotary_emb)
- # diff_k_emb = torch.norm(k_pos_emb - image_rotary_emb)
-
- # input = query.permute(2, 0, 1, 3)[16:, ...]
- # freqs = image_rotary_emb[:, None, None, :]
- # output = apply_rope_megatron_tmp(input.to("cpu"), freqs.to("cpu"), transpose_output=True)
-
- # out_foo = apply_rotary_emb_megatron(
- # query[:, :, text_seq_length:, :], image_rotary_emb
- # )
- # diff_after_rope = torch.norm(query_after_rope_megatron[16:16+4096].transpose(0, 1) - out_foo[0])
- ##########################
-
query[:, :, text_seq_length:, :] = apply_rotary_emb_megatron(
query[:, :, text_seq_length:, :], image_rotary_emb
)
@@ -2922,18 +2862,6 @@ def __call__(
key[:, :, text_seq_length:, :], image_rotary_emb
)
- ##########################################
- # query = torch.load("/home/lhy/code/cogview/query_after_rope.pt")
- # key = torch.load("/home/lhy/code/cogview/key_after_rope.pt")
- # value = torch.load("/home/lhy/code/cogview/value_after_rope.pt")
- # query = query[None, :16+4096, ...]
- # key = key[None, :16+4096, ...]
- # value = value[None, :16+4096, ...]
- # query = query.transpose(1, 2)
- # key = key.transpose(1, 2)
- # value = value.transpose(1, 2)
- ##########################################
-
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index 37389ff84575..9e064524a929 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -445,8 +445,6 @@ def forward(
image_rotary_emb = self.get_rope_embedding(
patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
)
- # image_rotary_emb = torch.load("/home/lhy/code/cogview/rotary_pos_emb.pt")
- # image_rotary_emb = image_rotary_emb[16:16+4096, 0, 0, :]
# 2. Conditional embeddings
temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
@@ -457,21 +455,6 @@ def forward(
)
hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
- ######################
- # reload for debug
- ## 这里大概有2%~4%的误差
- # prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
- # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_16_32.pt")[None, ::]
-
- ## 这里0误差
- # hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")[None, ::]
- # hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")[None, ::]
-
- ## 目前temb部分有很大的误差
- # temb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")[None, ::]
- # temb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")[None, ::]
- ######################
-
encoder_hidden_states_cond = prompt_embeds
encoder_hidden_states_uncond = negative_prompt_embeds
@@ -485,17 +468,14 @@ def forward(
encoder_hidden_states=encoder_hidden_states_cond,
time_embedding=temb_cond,
image_rotary_emb=image_rotary_emb,
- # image_rotary_emb=None,
)
hidden_states_uncond, encoder_hidden_states_uncond = block(
hidden_states=hidden_states_uncond,
encoder_hidden_states=encoder_hidden_states_uncond,
time_embedding=temb_uncond,
image_rotary_emb=image_rotary_emb,
- # image_rotary_emb=None,
)
-
hidden_states_cond = self.layernorm(hidden_states_cond)
hidden_states_uncond = self.layernorm(hidden_states_uncond)
encoder_hidden_states_cond = self.layernorm(encoder_hidden_states_cond)
@@ -503,8 +483,8 @@ def forward(
#################################################
# reload weight&bias for debug
- # self.adaln_final.weight = torch.load("/home/lhy/code/cogview/adaln_final_weight.pt")
- # self.adaln_final.bias = torch.load("/home/lhy/code/cogview/adaln_final_bias.pt")
+ self.adaln_final.weight = torch.load("/home/lhy/code/cogview/adaln_final_weight.pt")
+ self.adaln_final.bias = torch.load("/home/lhy/code/cogview/adaln_final_bias.pt")
#################################################
shift_cond, scale_cond = self.adaln_final(temb_cond).chunk(2, dim=-1)
From 3dab073a62dedc909f78ffc12592ac26e4a0fedf Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Thu, 6 Feb 2025 14:19:52 +0000
Subject: [PATCH 48/68] [cogview4][scheduler] Implement CogView4 scheduler and
pipeline
---
.../pipelines/cogview4/pipeline_cogview4.py | 47 +-
.../schedulers/scheduling_ddim_cogview4.py | 409 ++++--------------
2 files changed, 80 insertions(+), 376 deletions(-)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 0ec386e44bfd..57c45ebdfeca 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -55,25 +55,6 @@
"""
-def time_shift(mu: float, shift_sigma: float, sigmas: torch.Tensor):
- return mu / (mu + (1 / sigmas - 1) ** shift_sigma)
-
-
-def calculate_shift(
- image_seq_len,
- base_seq_len: int = 256,
-):
- if isinstance(image_seq_len, int):
- mu = math.sqrt(image_seq_len / base_seq_len)
- elif isinstance(image_seq_len, torch.Tensor):
- mu = torch.sqrt(image_seq_len / base_seq_len)
- else:
- raise ValueError(f'Invalid type for image_seq_len: {type(image_seq_len)}')
-
- mu = mu * 0.75 + 0.25
-
- return mu
-
class CogView4Pipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using CogView4.
@@ -545,9 +526,7 @@ def __call__(
max_sequence_length=max_sequence_length,
device=device,
)
- torch.save(prompt_embeds, '/share/home/zyx/prompt_embeds.pt')
- torch.save(negative_prompt_embeds, '/share/home/zyx/negative_prompt_embeds.pt')
- # 5. Prepare latents.
+ # Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
@@ -578,18 +557,13 @@ def __call__(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# Prepare timesteps
- self.scheduler.set_timesteps(num_inference_steps, device)
- timesteps = self.scheduler.timesteps
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
self.transformer.config.patch_size**2
)
- mu = calculate_shift(image_seq_len)
- sigmas = timesteps / self.scheduler.config.num_train_timesteps
- sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) # Append zero at the end
- self.sigmas = time_shift(mu, 1.0, sigmas).to("cpu")
- self._num_timesteps = len(timesteps)
+ self.scheduler.set_timesteps(num_inference_steps, image_seq_len, device)
+ timesteps = self.scheduler.timesteps
- # 6. Denoising loop
+ # Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -614,16 +588,7 @@ def __call__(
noise_pred_cond, noise_pred_uncond = noise_pred
noise_pred_guided = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
- ###########################
- # Get the corresponding sigma value
- # 这一部分应该放到schduler中(包括self.sigmas的计算也是)
- # 最后应该调用self.scheduler.step(),只需要传入当前的t,返回下一步的latents即可
- sigma = self.sigmas[i]
- sigma_next = self.sigmas[i + 1]
- dt = sigma_next - sigma
-
- latents = latents + dt * noise_pred_guided
- ##############################
+ latents = self.scheduler.step(noise_pred_guided, latents, t).prev_sample
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
@@ -631,7 +596,7 @@ def __call__(
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, sigma, callback_kwargs)
+ callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
diff --git a/src/diffusers/schedulers/scheduling_ddim_cogview4.py b/src/diffusers/schedulers/scheduling_ddim_cogview4.py
index 012b43dbbad2..6924c30ec286 100644
--- a/src/diffusers/schedulers/scheduling_ddim_cogview4.py
+++ b/src/diffusers/schedulers/scheduling_ddim_cogview4.py
@@ -1,4 +1,5 @@
-# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,154 +29,37 @@
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+# Copied from diffusers.schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput
@dataclass
-# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
-class DDIMSchedulerOutput(BaseOutput):
+class CogView4DDIMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
- prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
- pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
- The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
- `pred_original_sample` can be used to preview progress or for guidance.
"""
- prev_sample: torch.Tensor
- pred_original_sample: Optional[torch.Tensor] = None
-
-
-# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
-def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
- """
- Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
- (1-beta) over time from t = [0,1].
-
- Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
- to that part of the diffusion process.
-
-
- Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
-
- Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
- """
- if alpha_transform_type == "cosine":
-
- def alpha_bar_fn(t):
- return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
-
- elif alpha_transform_type == "exp":
-
- def alpha_bar_fn(t):
- return math.exp(t * -12.0)
-
- else:
- raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
-
- betas = []
- for i in range(num_diffusion_timesteps):
- t1 = i / num_diffusion_timesteps
- t2 = (i + 1) / num_diffusion_timesteps
- betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
- return torch.tensor(betas, dtype=torch.float32)
-
-
-def rescale_zero_terminal_snr(betas):
- """
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-
-
- Args:
- betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
-
- Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
- """
- # Convert betas to alphas_bar_sqrt
- alphas = 1.0 - betas
- alphas_cumprod = torch.cumprod(alphas, dim=0)
- alphas_bar_sqrt = alphas_cumprod.sqrt()
-
- # Store old values.
- alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
- alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
-
- # Shift so the last timestep is zero.
- alphas_bar_sqrt -= alphas_bar_sqrt_T
-
- # Scale so the first timestep is back to the old value.
- alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
-
- # Convert alphas_bar_sqrt to betas
- alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
- alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
- alphas = torch.cat([alphas_bar[0:1], alphas])
- betas = 1 - alphas
-
- return betas
+ prev_sample: torch.FloatTensor
class CogView4DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
- `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
- non-Markovian guidance.
+ CogView4 DDIM Scheduler.
+
+ This scheduler is a modified version of the DDIM scheduler specifically designed for use with the CogView4 model.
+ It implements the denoising process using a deterministic approach based on the DDIM (Denoising Diffusion Implicit Models)
+ framework.
- This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
- methods the library implements for all schedulers such as loading and saving.
+ The scheduler maintains the core DDIM functionality while being optimized for the CogView4 architecture and its specific
+ requirements for image generation tasks.
Args:
- num_train_timesteps (`int`, defaults to 1000):
- The number of diffusion steps to train the model.
- beta_start (`float`, defaults to 0.0001):
- The starting `beta` value of inference.
- beta_end (`float`, defaults to 0.02):
- The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
- The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
- trained_betas (`np.ndarray`, *optional*):
- Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
- clip_sample (`bool`, defaults to `True`):
- Clip the predicted sample for numerical stability.
- clip_sample_range (`float`, defaults to 1.0):
- The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
- set_alpha_to_one (`bool`, defaults to `True`):
- Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
- there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
- otherwise it uses the alpha value at step 0.
- steps_offset (`int`, defaults to 0):
- An offset added to the inference steps, as required by some model families.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
- Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
- thresholding (`bool`, defaults to `False`):
- Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
- as Stable Diffusion.
- dynamic_thresholding_ratio (`float`, defaults to 0.995):
- The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
- sample_max_value (`float`, defaults to 1.0):
- The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
- timestep_spacing (`str`, defaults to `"leading"`):
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
- rescale_betas_zero_snr (`bool`, defaults to `False`):
- Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
- dark samples instead of limiting it to samples with medium brightness. Loosely related to
- [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ num_train_timesteps (int, optional): The number of diffusion steps to train the model. Defaults to 1000.
+ beta_start (float, optional): The starting value of beta for the noise schedule. Defaults to 0.0001.
+ beta_end (float, optional): The ending value of beta for the noise schedule. Defaults to 0.02.
+ set_alpha_to_one (bool, optional): Whether to set the final alpha cumprod value to 1. Defaults to True.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -187,44 +71,13 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
- clip_sample: bool = True,
set_alpha_to_one: bool = True,
- steps_offset: int = 0,
- prediction_type: str = "epsilon",
- thresholding: bool = False,
- dynamic_thresholding_ratio: float = 0.995,
- clip_sample_range: float = 1.0,
- sample_max_value: float = 1.0,
- timestep_spacing: str = "linspace",
- rescale_betas_zero_snr: bool = False,
- shift_scale: int = 1.0,
):
- if trained_betas is not None:
- self.betas = torch.tensor(trained_betas, dtype=torch.float32)
- elif beta_schedule == "linear":
- self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
- elif beta_schedule == "scaled_linear":
- # this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
- elif beta_schedule == "squaredcos_cap_v2":
- # Glide cosine schedule
- self.betas = betas_for_alpha_bar(num_train_timesteps)
- else:
- raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
-
- # Rescale for zero SNR
- if rescale_betas_zero_snr:
- self.betas = rescale_zero_terminal_snr(self.betas)
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
- # At every step in ddim, we are looking into the previous alphas_cumprod
- # For the final step, there is no previous alphas_cumprod because we are already at 0
- # `set_alpha_to_one` decides whether we set this parameter simply to one or
- # whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# standard deviation of the initial noise distribution
@@ -232,7 +85,6 @@ def __init__(
# setable values
self.num_inference_steps = None
- self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
@@ -251,57 +103,37 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
"""
return sample
- def _get_variance(self, timestep, prev_timestep):
- alpha_prod_t = self.alphas_cumprod[timestep]
- alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
- beta_prod_t = 1 - alpha_prod_t
- beta_prod_t_prev = 1 - alpha_prod_t_prev
-
- variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
-
- return variance
-
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
- def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
- """
- "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
- prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
- s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
- pixels from saturation at each step. We find that dynamic thresholding results in significantly better
- photorealism as well as better image-text alignment, especially when using very large guidance weights."
-
- https://arxiv.org/abs/2205.11487
- """
- dtype = sample.dtype
- batch_size, channels, *remaining_dims = sample.shape
-
- if dtype not in (torch.float32, torch.float64):
- sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
-
- # Flatten sample for doing quantile calculation along each image
- sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
-
- abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+ @staticmethod
+ def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ ):
+ if isinstance(image_seq_len, int):
+ mu = math.sqrt(image_seq_len / base_seq_len)
+ elif isinstance(image_seq_len, torch.Tensor):
+ mu = torch.sqrt(image_seq_len / base_seq_len)
+ else:
+ raise ValueError(f"Invalid type for image_seq_len: {type(image_seq_len)}")
- s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
- s = torch.clamp(
- s, min=1, max=self.config.sample_max_value
- ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
- s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
- sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+ mu = mu * 0.75 + 0.25
- sample = sample.reshape(batch_size, channels, *remaining_dims)
- sample = sample.to(dtype)
+ return mu
- return sample
+ @staticmethod
+ def time_shift(mu: float, shift_sigma: float, sigmas: torch.Tensor):
+ return mu / (mu + (1 / sigmas - 1) ** shift_sigma)
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ def set_timesteps(self, num_inference_steps: int, image_seq_len: int, device: Union[str, torch.device] = None):
"""
- Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+ Sets the discrete timesteps used for the diffusion chain. Supporting to be called in every batch.
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
+ image_seq_len (`int`):
+ The length of the image sequence.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
# Check if the requested number of steps is valid
@@ -315,158 +147,65 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
# Set the current number of inference steps
self.num_inference_steps = num_inference_steps
- # Generate timesteps according to the specified spacing method
- if self.config.timestep_spacing == "linspace":
- timesteps = (
- np.linspace(self.config.num_train_timesteps, 1, num_inference_steps)
- .astype(np.int64) # Only for CogView4
- )
- elif self.config.timestep_spacing == "leading":
- step_ratio = self.config.num_train_timesteps // self.num_inference_steps
- timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
- timesteps += self.config.steps_offset
- elif self.config.timestep_spacing == "trailing":
- step_ratio = self.config.num_train_timesteps / self.num_inference_steps
- timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
- timesteps -= 1
- else:
- raise ValueError(
- f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
- )
+ timesteps = np.linspace(self.config.num_train_timesteps, 1, num_inference_steps).astype(np.int64)
+ self.timestep2idx = {timestep: i for i, timestep in enumerate(timesteps)}
# Convert the numpy array of timesteps into a PyTorch tensor
- self.timesteps = torch.from_numpy(timesteps).to(device)
+ timesteps = torch.from_numpy(timesteps).to(device)
+
+ mu = self.calculate_shift(image_seq_len)
+ sigmas = timesteps / self.config.num_train_timesteps
+ sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) # Append zero at the end
+
+ self.timesteps = timesteps
+ self.sigmas = self.time_shift(mu, 1.0, sigmas).to("cpu")
+ self._num_timesteps = len(timesteps)
def step(
self,
model_output: torch.Tensor,
- timestep: int,
sample: torch.Tensor,
- eta: float = 0.0,
- use_clipped_model_output: bool = False,
- generator=None,
- variance_noise: Optional[torch.Tensor] = None,
+ timestep: int,
return_dict: bool = True,
- ) -> Union[DDIMSchedulerOutput, Tuple]:
+ ) -> Union[CogView4DDIMSchedulerOutput, Tuple]:
"""
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
- process from the learned model outputs (most often the predicted noise).
+ Predict the sample from the previous timestep by applying the flow matching update.
+
+ This method implements the flow matching step for the CogView4 DDIM scheduler. It takes the model output
+ (predicted noise) and current sample, and computes the previous sample by following the flow matching
+ update rule.
Args:
model_output (`torch.Tensor`):
- The direct output from learned diffusion model.
- timestep (`float`):
- The current discrete timestep in the diffusion chain.
+ The output from the diffusion model, typically the predicted noise or velocity.
sample (`torch.Tensor`):
- A current instance of a sample created by the diffusion process.
- eta (`float`):
- The weight of noise for added noise in diffusion step.
- use_clipped_model_output (`bool`, defaults to `False`):
- If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
- because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
- clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
- `use_clipped_model_output` has no effect.
- generator (`torch.Generator`, *optional*):
- A random number generator.
- variance_noise (`torch.Tensor`):
- Alternative to generating noise with `generator` by directly providing the noise for the variance
- itself. Useful for methods such as [`CycleDiffusion`].
+ The current sample at the current timestep.
+ timestep (`int`):
+ The current timestep in the diffusion process.
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
+ Whether to return a `CogView4DDIMSchedulerOutput` or a tuple.
Returns:
- [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
- If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
- tuple is returned where the first element is the sample tensor.
-
+ `CogView4DDIMSchedulerOutput` or `tuple`:
+ If `return_dict` is True, returns a `CogView4DDIMSchedulerOutput` containing the predicted
+ sample at the previous timestep. Otherwise, returns a tuple with the predicted sample.
"""
+
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
+ idx = self.timestep2idx[timestep.item()]
+ sigma = self.sigmas[idx]
+ sigma_next = self.sigmas[idx + 1]
+ dt = sigma_next - sigma
- # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
- # Ideally, read DDIM paper in-detail understanding
-
- # Notation ( ->
- # - pred_noise_t -> e_theta(x_t, t)
- # - pred_original_sample -> f_theta(x_t, t) or x_0
- # - std_dev_t -> sigma_t
- # - eta -> η
- # - pred_sample_direction -> "direction pointing to x_t"
- # - pred_prev_sample -> "x_t-1"
-
- # 1. get previous step value (=t-1)
- prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
-
- # 2. compute alphas, betas
- alpha_prod_t = self.alphas_cumprod[timestep]
- alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
-
- beta_prod_t = 1 - alpha_prod_t
-
- # 3. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
- if self.config.prediction_type == "epsilon":
- pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
- pred_epsilon = model_output
- elif self.config.prediction_type == "sample":
- pred_original_sample = model_output
- pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
- elif self.config.prediction_type == "v_prediction":
- pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
- pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
- else:
- raise ValueError(
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
- " `v_prediction`"
- )
-
- # 4. Clip or threshold "predicted x_0"
- if self.config.thresholding:
- pred_original_sample = self._threshold_sample(pred_original_sample)
- elif self.config.clip_sample:
- pred_original_sample = pred_original_sample.clamp(
- -self.config.clip_sample_range, self.config.clip_sample_range
- )
-
- # 5. compute variance: "sigma_t(η)" -> see formula (16)
- # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
- variance = self._get_variance(timestep, prev_timestep)
- std_dev_t = eta * variance ** (0.5)
-
- if use_clipped_model_output:
- # the pred_epsilon is always re-derived from the clipped x_0 in Glide
- pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
-
- # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
- pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
-
- # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
- prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
-
- if eta > 0:
- if variance_noise is not None and generator is not None:
- raise ValueError(
- "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
- " `variance_noise` stays `None`."
- )
-
- if variance_noise is None:
- variance_noise = randn_tensor(
- model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
- )
- variance = std_dev_t * variance_noise
-
- prev_sample = prev_sample + variance
+ prev_sample = sample + dt * model_output
if not return_dict:
- return (
- prev_sample,
- pred_original_sample,
- )
+ return (prev_sample,)
- return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+ return CogView4DDIMSchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
From 63982d65e0cacb9235247fd25d5d7968bbfcdd09 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Fri, 7 Feb 2025 01:09:41 +0800
Subject: [PATCH 49/68] now It work
---
.../transformers/transformer_cogview4.py | 32 ++++++-------------
.../pipelines/cogview4/pipeline_cogview4.py | 1 -
2 files changed, 9 insertions(+), 24 deletions(-)
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index 3afdd15ec9c7..da16ad5e5c0a 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import Any, Dict, Union
+from typing import Dict, Union
import torch
import torch.nn as nn
@@ -281,12 +281,6 @@ def __init__(
conditioning_embedding_dim=time_embed_dim,
elementwise_affine=False,
)
- self.adaln_final = self.norm_out.linear
- # with torch.no_grad():
- # w = self.norm_out.linear.weight.data.clone()
- # w_swapped = swap_scale_shift(w, dim=0)
- # self.adaln_final.weight.data.copy_(w_swapped)
-
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
@@ -485,22 +479,14 @@ def forward(
image_rotary_emb=image_rotary_emb,
)
- hidden_states_cond = self.layernorm(hidden_states_cond)
- hidden_states_uncond = self.layernorm(hidden_states_uncond)
- encoder_hidden_states_cond = self.layernorm(encoder_hidden_states_cond)
- encoder_hidden_states_uncond = self.layernorm(encoder_hidden_states_uncond)
-
- #################################################
- # reload weight&bias for debug
- self.adaln_final.weight = torch.load("/home/lhy/code/cogview/adaln_final_weight.pt")
- self.adaln_final.bias = torch.load("/home/lhy/code/cogview/adaln_final_bias.pt")
- #################################################
-
- shift_cond, scale_cond = self.adaln_final(temb_cond).chunk(2, dim=-1)
- shift_uncond, scale_uncond = self.adaln_final(temb_uncond).chunk(2, dim=-1)
-
- hidden_states_cond = hidden_states_cond * (1 + scale_cond) + shift_cond
- hidden_states_uncond = hidden_states_uncond * (1 + scale_uncond) + shift_uncond
+ hidden_states_cond, encoder_hidden_states_cond = (
+ self.norm_out(hidden_states_cond, temb_cond),
+ self.norm_out(encoder_hidden_states_cond, temb_cond),
+ )
+ hidden_states_uncond, encoder_hidden_states_uncond = (
+ self.norm_out(hidden_states_uncond, temb_uncond),
+ self.norm_out(encoder_hidden_states_uncond, temb_uncond),
+ )
hidden_states_cond = self.proj_out(hidden_states_cond)
hidden_states_uncond = self.proj_out(hidden_states_uncond)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 57c45ebdfeca..c13e9d0241c7 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -14,7 +14,6 @@
# limitations under the License.
import inspect
-import math
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
From d4748e049a7dd81cdd9e021be33842eb654d978d Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Fri, 7 Feb 2025 14:03:48 +0800
Subject: [PATCH 50/68] add timestep
---
src/diffusers/pipelines/cogview4/pipeline_cogview4.py | 7 +------
1 file changed, 1 insertion(+), 6 deletions(-)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index c13e9d0241c7..64a029d7e2ff 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -147,12 +147,6 @@ def _get_glm_embeds(
text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True
).hidden_states[-2]
- # TODO: This is for Older GLM-4 as https://huggingface.co/THUDM/glm-4-9b, will use https://huggingface.co/THUDM/glm-4-9b-hf for new transformers version format.
- # TODO: Remove it later
- # prompt_embeds = self.text_encoder(
- # text_input_ids.to(self.text_encoder.transformer.device), output_hidden_states=True
- # ).hidden_states[-2]
-
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -569,6 +563,7 @@ def __call__(
if self.interrupt:
continue
timestep = t.reshape((1,))
+ timestep = torch.cat([timestep] * num_images_per_prompt)
timestep = torch.cat([timestep] * 2) if do_classifier_free_guidance else t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
From 95f851d9e5d542e6acc37360f117c220f3f3cf23 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Fri, 7 Feb 2025 17:33:12 +0800
Subject: [PATCH 51/68] batch
---
.../transformers/transformer_cogview4.py | 26 ++++++++-----------
1 file changed, 11 insertions(+), 15 deletions(-)
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index da16ad5e5c0a..2c8e37cadc2c 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import Dict, Union
+from typing import Any, Dict, Union
import torch
import torch.nn as nn
@@ -80,13 +80,17 @@ def __init__(
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def multi_modulate(self, hidden_states, encoder_hidden_states, factors) -> torch.Tensor:
- n_sample, n_type, h = factors[0].shape
+ _, _, h = factors[0].shape
shift_factor, scale_factor = factors[0].view(-1, h), factors[1].view(-1, h)
shift_factor_hidden_states, shift_factor_encoder_hidden_states = shift_factor.chunk(2, dim=0)
scale_factor_hidden_states, scale_factor_encoder_hidden_states = scale_factor.chunk(2, dim=0)
-
+ shift_factor_hidden_states = shift_factor_hidden_states.unsqueeze(1)
+ scale_factor_hidden_states = scale_factor_hidden_states.unsqueeze(1)
hidden_states = torch.addcmul(shift_factor_hidden_states, hidden_states, (1 + scale_factor_hidden_states))
+
+ shift_factor_encoder_hidden_states = shift_factor_encoder_hidden_states.unsqueeze(1)
+ scale_factor_encoder_hidden_states = scale_factor_encoder_hidden_states.unsqueeze(1)
encoder_hidden_states = torch.addcmul(
shift_factor_encoder_hidden_states, encoder_hidden_states, (1 + scale_factor_encoder_hidden_states)
)
@@ -94,11 +98,14 @@ def multi_modulate(self, hidden_states, encoder_hidden_states, factors) -> torch
return hidden_states, encoder_hidden_states
def multi_gate(self, hidden_states, encoder_hidden_states, factor):
- batch_size, seq_len, hidden_dim = hidden_states.shape
+ _, _, hidden_dim = hidden_states.shape
gate_factor = factor.view(-1, hidden_dim)
gate_factor_hidden_states, gate_factor_encoder_hidden_states = gate_factor.chunk(2, dim=0)
+ gate_factor_hidden_states = gate_factor_hidden_states.unsqueeze(1)
+ gate_factor_encoder_hidden_states = gate_factor_encoder_hidden_states.unsqueeze(1)
hidden_states = gate_factor_hidden_states * hidden_states
encoder_hidden_states = gate_factor_encoder_hidden_states * encoder_hidden_states
+
return hidden_states, encoder_hidden_states
def forward(
@@ -111,12 +118,7 @@ def forward(
) -> torch.Tensor:
batch_size, encoder_hidden_states_len, hidden_dim = encoder_hidden_states.shape
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
residual = hidden_states
-
- # time_embedding embedding, [n_sample, h]
- assert time_embedding is not None
-
layernorm_factor = (
self.adaln(time_embedding)
.view(
@@ -128,9 +130,6 @@ def forward(
.permute(1, 2, 0, 3)
.contiguous()
)
-
- ##############################################################
- # Optional Input Layer norm
hidden_states = self.layernorm(hidden_states)
hidden_states, encoder_hidden_states = self.multi_modulate(
hidden_states=hidden_states[:, encoder_hidden_states_len:],
@@ -151,7 +150,6 @@ def forward(
hidden_states += residual
residual = hidden_states
- ##############################################################
hidden_states = self.layernorm(hidden_states)
hidden_states, encoder_hidden_states = self.multi_modulate(
hidden_states=hidden_states[:, encoder_hidden_states_len:],
@@ -167,8 +165,6 @@ def forward(
)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states += residual
-
- ##############################################################
hidden_states, encoder_hidden_states = (
hidden_states[:, encoder_hidden_states_len:],
hidden_states[:, :encoder_hidden_states_len],
From cb5628272f14f22ba2a181c47c9044d56d76dd60 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Fri, 7 Feb 2025 17:43:50 +0800
Subject: [PATCH 52/68] change convert scipt
---
scripts/convert_cogview4_to_diffusers.py | 16 ++--------------
.../convert_cogview4_to_diffusers_megatron.py | 8 ++------
2 files changed, 4 insertions(+), 20 deletions(-)
diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py
index 4405a40fb761..36a0f61cb0a7 100644
--- a/scripts/convert_cogview4_to_diffusers.py
+++ b/scripts/convert_cogview4_to_diffusers.py
@@ -1,5 +1,6 @@
"""
-Convert a CogView4 checkpoint to the Diffusers format.
+Convert a CogView4 checkpoint from SAT(https://github.com/THUDM/SwissArmyTransformer) to the Diffusers format.
+(deprecated Since 2025-02-07 and will remove it in later CogView4 version)
This script converts a CogView4 checkpoint to the Diffusers format, which can then be used
with the Diffusers library.
@@ -217,19 +218,6 @@ def main(args):
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
)
- # TODO: This is for Older GLM-4 as https://huggingface.co/THUDM/glm-4-9b, will use https://huggingface.co/THUDM/glm-4-9b-hf for new transformers version format.
- # TODO: Remove it later
-
- # from transformers import AutoTokenizer,AutoModel
- # text_encoder_id = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/glm-4-9b"
- # tokenizer = AutoTokenizer.from_pretrained(text_encoder_id,trust_remote_code=True)
- # text_encoder = AutoModel.from_pretrained(
- # text_encoder_id,
- # cache_dir=args.text_encoder_cache_dir,
- # torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
- # trust_remote_code = True
- # )
-
for param in text_encoder.parameters():
param.data = param.data.contiguous()
diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py
index b9e08bb4ae4f..4043f8ef4105 100644
--- a/scripts/convert_cogview4_to_diffusers_megatron.py
+++ b/scripts/convert_cogview4_to_diffusers_megatron.py
@@ -1,12 +1,9 @@
"""
-Convert a CogView4 checkpoint to the Diffusers format.
-
-This script converts a CogView4 checkpoint to the Diffusers format, which can then be used
-with the Diffusers library.
+Convert a CogView4 checkpoint from Megatron to the Diffusers format.
Example usage:
python scripts/convert_cogview4_to_diffusers.py \
- --transformer_checkpoint_path 'your path/cogview4_6b/1/mp_rank_00_model_states.pt' \
+ --transformer_checkpoint_path 'your path/cogview4_6b/mp_rank_00/model_optim_rng.pt' \
--vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
--output_path "THUDM/CogView4-6B" \
--dtype "bf16"
@@ -25,7 +22,6 @@
"""
import argparse
-from contextlib import nullcontext
import torch
from transformers import PreTrainedTokenizerFast, GlmForCausalLM
from tqdm import tqdm
From fedf3255a5744153fbacbc181f6b36d73517e9ec Mon Sep 17 00:00:00 2001
From: Aryan
Date: Mon, 10 Feb 2025 06:23:43 +0100
Subject: [PATCH 53/68] refactor pt. 1; make style
---
scripts/convert_cogview4_to_diffusers.py | 3 +-
.../convert_cogview4_to_diffusers_megatron.py | 4 +-
src/diffusers/__init__.py | 4 +-
src/diffusers/models/attention_processor.py | 75 ----
src/diffusers/models/embeddings.py | 71 ----
.../transformers/transformer_cogview4.py | 394 +++++++++---------
src/diffusers/pipelines/auto_pipeline.py | 1 -
.../pipelines/cogview4/pipeline_cogview4.py | 64 +--
.../schedulers/scheduling_cogview.py | 6 +-
.../schedulers/scheduling_ddim_cogview4.py | 19 +-
src/diffusers/utils/dummy_pt_objects.py | 2 +
11 files changed, 241 insertions(+), 402 deletions(-)
diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py
index 36a0f61cb0a7..5d09d1206e7c 100644
--- a/scripts/convert_cogview4_to_diffusers.py
+++ b/scripts/convert_cogview4_to_diffusers.py
@@ -30,12 +30,13 @@
import torch
from accelerate import init_empty_weights
-from transformers import PreTrainedTokenizerFast, GlmForCausalLM
+from transformers import GlmForCausalLM, PreTrainedTokenizerFast
from diffusers import AutoencoderKL, CogView4DDIMScheduler, CogView4Pipeline, CogView4Transformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available
+
CTX = init_empty_weights if is_accelerate_available() else nullcontext
parser = argparse.ArgumentParser()
diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py
index 4043f8ef4105..61a9fb3eb2fd 100644
--- a/scripts/convert_cogview4_to_diffusers_megatron.py
+++ b/scripts/convert_cogview4_to_diffusers_megatron.py
@@ -22,9 +22,10 @@
"""
import argparse
+
import torch
-from transformers import PreTrainedTokenizerFast, GlmForCausalLM
from tqdm import tqdm
+from transformers import GlmForCausalLM, PreTrainedTokenizerFast
from diffusers import (
AutoencoderKL,
@@ -34,6 +35,7 @@
)
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
+
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_checkpoint_path",
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index e685ea1a19e5..f8448e030a72 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -287,8 +287,8 @@
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CogView3PlusPipeline",
- "ConsisIDPipeline",
"CogView4Pipeline",
+ "ConsisIDPipeline",
"CycleDiffusionPipeline",
"FluxControlImg2ImgPipeline",
"FluxControlInpaintPipeline",
@@ -783,8 +783,8 @@
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
CogView3PlusPipeline,
- ConsisIDPipeline,
CogView4Pipeline,
+ ConsisIDPipeline,
CycleDiffusionPipeline,
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index fbe96869f7a9..8bba5a82bc2f 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2803,80 +2803,6 @@ def __call__(
return hidden_states
-class CogView4AttnProcessor:
- """
- Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
- query and key vectors, but does not include spatial normalization.
- """
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("CogView4AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- text_seq_length = encoder_hidden_states.size(1)
-
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
- batch_size, sequence_length, _ = hidden_states.shape
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim)
- key = key.view(batch_size, -1, attn.heads, head_dim)
- value = value.view(batch_size, -1, attn.heads, head_dim)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- query = query.transpose(1, 2)
- key = key.transpose(1, 2)
- value = value.transpose(1, 2)
-
- # Apply RoPE if needed
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb_megatron
-
- query[:, :, text_seq_length:, :] = apply_rotary_emb_megatron(
- query[:, :, text_seq_length:, :], image_rotary_emb
- )
- key[:, :, text_seq_length:, :] = apply_rotary_emb_megatron(
- key[:, :, text_seq_length:, :], image_rotary_emb
- )
-
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
-
- encoder_hidden_states, hidden_states = hidden_states.split(
- [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
- )
- return hidden_states, encoder_hidden_states
-
-
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -6247,7 +6173,6 @@ def __call__(
FusedFluxAttnProcessor2_0,
FusedFluxAttnProcessor2_0_NPU,
CogVideoXAttnProcessor2_0,
- CogView4AttnProcessor,
FusedCogVideoXAttnProcessor2_0,
XFormersAttnAddedKVProcessor,
XFormersAttnProcessor,
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 0c457dd0f2d6..39e1833cac37 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -813,55 +813,6 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens
return (hidden_states + pos_embed).to(hidden_states.dtype)
-class CogView4PatchEmbed(nn.Module):
- def __init__(
- self,
- in_channels: int = 16,
- hidden_size: int = 2560,
- patch_size: int = 2,
- text_hidden_size: int = 4096,
- pos_embed_max_size: int = 128,
- ):
- super().__init__()
- self.in_channels = in_channels
- self.hidden_size = hidden_size
- self.patch_size = patch_size
- self.text_hidden_size = text_hidden_size
- self.pos_embed_max_size = pos_embed_max_size
- # Linear projection for image patches
- self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
-
- # Linear projection for text embeddings
- self.text_proj = nn.Linear(text_hidden_size, hidden_size)
-
- def forward(
- self, hidden_states: torch.Tensor, prompt_embeds: torch.Tensor, negative_prompt_embeds: torch.Tensor | None
- ) -> torch.Tensor:
- batch_size, channel, height, width = hidden_states.shape
-
- if height % self.patch_size != 0 or width % self.patch_size != 0:
- raise ValueError("Height and width must be divisible by patch size")
-
- patch_height = height // self.patch_size
- patch_width = width // self.patch_size
-
- # b, c, h, w -> b, c, patch_height, patch_size, patch_width, patch_size
- # -> b, patch_height, patch_width, c, patch_size, patch_size
- # -> b, patch_height * patch_width, c * patch_size * patch_size
- hidden_states = (
- hidden_states.reshape(batch_size, channel, patch_height, self.patch_size, patch_width, self.patch_size)
- .permute(0, 2, 4, 1, 3, 5)
- .reshape(batch_size, patch_height * patch_width, channel * self.patch_size * self.patch_size)
- )
-
- # project
- hidden_states = self.proj(hidden_states) # embed_dim: 64 -> 4096
- prompt_embeds = self.text_proj(prompt_embeds) # embed_dim: 4096 -> 4096
- if negative_prompt_embeds is not None:
- negative_prompt_embeds = self.text_proj(negative_prompt_embeds) # embed_dim: 4096 -> 4096
- return hidden_states, prompt_embeds, negative_prompt_embeds
-
-
def get_3d_rotary_pos_embed(
embed_dim,
crops_coords,
@@ -1283,28 +1234,6 @@ def apply_1d_rope(tokens, pos, cos, sin):
x = torch.cat([t, h, w], dim=-1)
return x
-def apply_rotary_emb_megatron(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
- """Apply rotary position embeddings to input tensor.
-
- Args:
- x: Input tensor of shape [batch_size, n_heads, seq_len, head_dim]
- freqs: Frequency tensor of shape [seq_len, head_dim]
-
- Returns:
- Tensor with rotary position embeddings applied
- """
- batch_size, n_heads, seq_len, rot_dim = x.shape
- assert rot_dim % 2 == 0 and rot_dim == freqs.shape[-1]
-
- x_dim_first_half, x_dim_second_half = x.chunk(2, dim=-1)
- x_rot_shifted = torch.cat([-x_dim_second_half, x_dim_first_half], dim=-1)
-
- cos, sin = torch.cos(freqs), torch.sin(freqs)
-
- x_out = cos * x + sin * x_rot_shifted
-
- return x_out
-
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index 2c8e37cadc2c..2b96110cf3ce 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from typing import Any, Dict, Union
+from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -21,15 +20,11 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
-from ...models.attention_processor import (
- Attention,
- AttentionProcessor,
- CogView4AttnProcessor,
-)
+from ...models.attention_processor import Attention
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
-from ...utils import is_torch_version, logging
-from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView4PatchEmbed
+from ...utils import logging
+from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
@@ -37,28 +32,125 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class CogView4TransformerBlock(nn.Module):
- r"""
- Transformer block used in [CogView](https://github.com/THUDM/CogView3) model.
+class CogView4PatchEmbed(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 16,
+ hidden_size: int = 2560,
+ patch_size: int = 2,
+ text_hidden_size: int = 4096,
+ pos_embed_max_size: int = 128,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_size = hidden_size
+ self.patch_size = patch_size
+ self.text_hidden_size = text_hidden_size
+ self.pos_embed_max_size = pos_embed_max_size
+ # Linear projection for image patches
+ self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
- Args:
- dim (`int`):
- The number of channels in the input and output.
- num_attention_heads (`int`):
- The number of heads to use for multi-head attention.
- attention_head_dim (`int`):
- The number of channels in each head.
- time_embed_dim (`int`):
- The number of channels in timestep embedding.
+ # Linear projection for text embeddings
+ self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+
+ def forward(
+ self, hidden_states: torch.Tensor, prompt_embeds: torch.Tensor, negative_prompt_embeds: torch.Tensor | None
+ ) -> torch.Tensor:
+ batch_size, channel, height, width = hidden_states.shape
+
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
+ raise ValueError("Height and width must be divisible by patch size")
+
+ patch_height = height // self.patch_size
+ patch_width = width // self.patch_size
+
+ # b, c, h, w -> b, c, patch_height, patch_size, patch_width, patch_size
+ # -> b, patch_height, patch_width, c, patch_size, patch_size
+ # -> b, patch_height * patch_width, c * patch_size * patch_size
+ hidden_states = (
+ hidden_states.reshape(batch_size, channel, patch_height, self.patch_size, patch_width, self.patch_size)
+ .permute(0, 2, 4, 1, 3, 5)
+ .reshape(batch_size, patch_height * patch_width, channel * self.patch_size * self.patch_size)
+ )
+
+ # project
+ hidden_states = self.proj(hidden_states) # embed_dim: 64 -> 4096
+ prompt_embeds = self.text_proj(prompt_embeds) # embed_dim: 4096 -> 4096
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = self.text_proj(negative_prompt_embeds) # embed_dim: 4096 -> 4096
+ return hidden_states, prompt_embeds, negative_prompt_embeds
+
+
+class CogView4AttnProcessor:
+ """
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+ query and key vectors, but does not include spatial normalization.
"""
- def __init__(
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("CogView4AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
self,
- dim: int = 2560,
- num_attention_heads: int = 64,
- attention_head_dim: int = 40,
- time_embed_dim: int = 512,
- ):
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ # 1. QKV projections
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ # 2. QK normalization
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # 3. Rotational positional embeddings applied to latent stream
+ if image_rotary_emb is not None:
+
+ def apply_rotary_emb(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+ cos, sin = freqs
+ x_real, x_imag = x.chunk(2, dim=-1)
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+ x_out = cos * x.float() + sin * x_rotated.float()
+ return x_out
+
+ query[:, :, text_seq_length:, :] = apply_rotary_emb(query[:, :, text_seq_length:, :], image_rotary_emb)
+ key[:, :, text_seq_length:, :] = apply_rotary_emb(key[:, :, text_seq_length:, :], image_rotary_emb)
+
+ # 4. Attention
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ # 5. Output projection
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+ )
+ return hidden_states, encoder_hidden_states
+
+
+class CogView4TransformerBlock(nn.Module):
+ def __init__(
+ self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512
+ ) -> None:
super().__init__()
self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
@@ -112,17 +204,16 @@ def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
- time_embedding: torch.Tensor = None,
- image_rotary_emb: torch.Tensor = None,
- **kwargs,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, encoder_hidden_states_len, hidden_dim = encoder_hidden_states.shape
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
residual = hidden_states
layernorm_factor = (
- self.adaln(time_embedding)
+ self.adaln(temb)
.view(
- time_embedding.shape[0],
+ temb.shape[0],
6,
2,
hidden_states.shape[-1],
@@ -172,10 +263,48 @@ def forward(
return hidden_states, encoder_hidden_states
-def swap_scale_shift(weight, dim):
- shift, scale = weight.chunk(2, dim=0)
- new_weight = torch.cat([scale, shift], dim=0)
- return new_weight
+class CogView4RotaryPosEmbed(nn.Module):
+ def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.rope_axes_dim = rope_axes_dim
+
+ dim_h, dim_w = dim // 2, dim // 2
+ h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h))
+ w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w))
+ h_seq = torch.arange(self.rope_axes_dim[0])
+ w_seq = torch.arange(self.rope_axes_dim[1])
+ self.freqs_h = torch.outer(h_seq, h_inv_freq)
+ self.freqs_w = torch.outer(w_seq, w_inv_freq)
+
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ batch_size, num_channels, height, width = hidden_states.shape
+ height, width = height // self.patch_size, width // self.patch_size
+
+ h_idx = torch.arange(height)
+ w_idx = torch.arange(width)
+ inner_h_idx = h_idx * self.rope_axes_dim[0] // height
+ inner_w_idx = w_idx * self.rope_axes_dim[1] // width
+
+ self.freqs_h = self.freqs_h.to(hidden_states.device)
+ self.freqs_w = self.freqs_w.to(hidden_states.device)
+ freqs_h = self.freqs_h[inner_h_idx]
+ freqs_w = self.freqs_w[inner_w_idx]
+
+ # Create position matrices for height and width
+ # [height, 1, dim//4] and [1, width, dim//4]
+ freqs_h = freqs_h.unsqueeze(1)
+ freqs_w = freqs_w.unsqueeze(0)
+ # Broadcast freqs_h and freqs_w to [height, width, dim//4]
+ freqs_h = freqs_h.expand(height, width, -1)
+ freqs_w = freqs_w.expand(height, width, -1)
+
+ # Concatenate along last dimension to get [height, width, dim//2]
+ freqs = torch.cat([freqs_h, freqs_w], dim=-1)
+ freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim]
+ freqs = freqs.reshape(height * width, -1)
+ return (freqs.cos(), freqs.sin())
class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
@@ -212,6 +341,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["CogView4TransformerBlock", "CogView4PatchEmbed", "CogView4PatchEmbed"]
+ _skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"]
@register_to_config
def __init__(
@@ -227,26 +357,23 @@ def __init__(
condition_dim: int = 256,
pos_embed_max_size: int = 128,
sample_size: int = 128,
+ rope_axes_dim: Tuple[int, int] = (256, 256),
):
super().__init__()
- self.out_channels = out_channels
- self.inner_dim = num_attention_heads * attention_head_dim
# CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
# Each of these are sincos embeddings of shape 2 * condition_dim
- self.pooled_projection_dim = 3 * 2 * condition_dim
+ pooled_projection_dim = 3 * 2 * condition_dim
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels
- self.max_h = 256
- self.max_w = 256
- self.rope = self.prepare_rope(
- embed_dim=self.config.attention_head_dim, max_h=self.max_h, max_w=self.max_w, rotary_base=10000
- )
-
- self.layernorm = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-5)
+ # 1. RoPE
+ self.rope = CogView4RotaryPosEmbed(attention_head_dim, patch_size, rope_axes_dim, theta=10000.0)
+ # 2. Patch & Text-timestep embedding
self.patch_embed = CogView4PatchEmbed(
in_channels=in_channels,
- hidden_size=self.inner_dim,
+ hidden_size=inner_dim,
patch_size=patch_size,
text_hidden_size=text_embed_dim,
pos_embed_max_size=pos_embed_max_size,
@@ -255,146 +382,29 @@ def __init__(
self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
embedding_dim=time_embed_dim,
condition_dim=condition_dim,
- pooled_projection_dim=self.pooled_projection_dim,
- timesteps_dim=self.inner_dim,
+ pooled_projection_dim=pooled_projection_dim,
+ timesteps_dim=inner_dim,
)
+ # 3. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
- CogView4TransformerBlock(
- dim=self.inner_dim,
- num_attention_heads=num_attention_heads,
- attention_head_dim=attention_head_dim,
- time_embed_dim=time_embed_dim,
- )
+ CogView4TransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
for _ in range(num_layers)
]
)
- ######################################
- self.norm_out = AdaLayerNormContinuous(
- embedding_dim=self.inner_dim,
- conditioning_embedding_dim=time_embed_dim,
- elementwise_affine=False,
- )
- self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+ # 4. Output projection
+ self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
- @staticmethod
- def prepare_rope(embed_dim, max_h, max_w, rotary_base):
- dim_h = embed_dim // 2
- dim_w = embed_dim // 2
- h_inv_freq = 1.0 / (
- rotary_base ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
- )
- w_inv_freq = 1.0 / (
- rotary_base ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
- )
- h_seq = torch.arange(max_h, dtype=h_inv_freq.dtype)
- w_seq = torch.arange(max_w, dtype=w_inv_freq.dtype)
- freqs_h = torch.outer(h_seq, h_inv_freq)
- freqs_w = torch.outer(w_seq, w_inv_freq)
- return (freqs_h, freqs_w)
-
- def get_rope_embedding(self, height, width, target_h, target_w, device):
- # Get pre-computed frequencies
- freqs_h, freqs_w = self.rope
-
- h_idx = torch.arange(height)
- w_idx = torch.arange(width)
- inner_h_idx = (h_idx * self.max_h) // target_h
- inner_w_idx = (w_idx * self.max_w) // target_w
-
- freqs_h = freqs_h[inner_h_idx].to(device)
- freqs_w = freqs_w[inner_w_idx].to(device)
-
- # Create position matrices for height and width
- # [height, 1, dim//4] and [1, width, dim//4]
- freqs_h = freqs_h.unsqueeze(1)
- freqs_w = freqs_w.unsqueeze(0)
- # Broadcast freqs_h and freqs_w to [height, width, dim//4]
- freqs_h = freqs_h.expand(height, width, -1)
- freqs_w = freqs_w.expand(height, width, -1)
-
- # Concatenate along last dimension to get [height, width, dim//2]
- freqs = torch.cat([freqs_h, freqs_w], dim=-1)
-
- freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim]
- freqs = freqs.reshape(height * width, -1)
-
- return freqs
- # return freqs.cos(), freqs.sin()
-
def forward(
self,
hidden_states: torch.Tensor,
prompt_embeds: torch.Tensor,
- negative_prompt_embeds: torch.Tensor | None,
+ negative_prompt_embeds: Optional[torch.Tensor],
timestep: torch.LongTensor,
original_size: torch.Tensor,
target_size: torch.Tensor,
@@ -429,21 +439,18 @@ def forward(
`torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
The denoised latents using provided inputs as conditioning.
"""
- batch_size, channel, height, width = hidden_states.shape
- patch_height, patch_width = height // self.config.patch_size, width // self.config.patch_size
+ batch_size, num_channels, height, width = hidden_states.shape
do_cfg = negative_prompt_embeds is not None
if do_cfg:
- assert batch_size == prompt_embeds.shape[0] + negative_prompt_embeds.shape[0], (
- "batch size mismatch in CFG mode"
- )
+ assert (
+ batch_size == prompt_embeds.shape[0] + negative_prompt_embeds.shape[0]
+ ), "batch size mismatch in CFG mode"
else:
assert batch_size == prompt_embeds.shape[0], "batch size mismatch in non-CFG mode"
# 1. RoPE
- image_rotary_emb = self.get_rope_embedding(
- patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
- )
+ image_rotary_emb = self.rope(hidden_states)
# 2. Conditional embeddings
temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
@@ -459,20 +466,18 @@ def forward(
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
- # TODO 微调使用
- ...
+ hidden_states_cond, encoder_hidden_states_cond = self._gradient_checkpointing_func(
+ block, hidden_states_cond, encoder_hidden_states_cond, temb_cond, image_rotary_emb
+ )
+ hidden_states_uncond, encoder_hidden_states_uncond = self._gradient_checkpointing_func(
+ block, hidden_states_uncond, encoder_hidden_states_uncond, temb_uncond, image_rotary_emb
+ )
else:
hidden_states_cond, encoder_hidden_states_cond = block(
- hidden_states=hidden_states_cond,
- encoder_hidden_states=encoder_hidden_states_cond,
- time_embedding=temb_cond,
- image_rotary_emb=image_rotary_emb,
+ hidden_states_cond, encoder_hidden_states_cond, temb_cond, image_rotary_emb
)
hidden_states_uncond, encoder_hidden_states_uncond = block(
- hidden_states=hidden_states_uncond,
- encoder_hidden_states=encoder_hidden_states_uncond,
- time_embedding=temb_uncond,
- image_rotary_emb=image_rotary_emb,
+ hidden_states_uncond, encoder_hidden_states_uncond, temb_uncond, image_rotary_emb
)
hidden_states_cond, encoder_hidden_states_cond = (
@@ -493,20 +498,21 @@ def forward(
width = width // patch_size
hidden_states_cond = hidden_states_cond.reshape(
- shape=(hidden_states_cond.shape[0], height, width, self.out_channels, patch_size, patch_size)
+ shape=(hidden_states_cond.shape[0], height, width, -1, patch_size, patch_size)
)
hidden_states_cond = torch.einsum("nhwcpq->nchpwq", hidden_states_cond)
output_cond = hidden_states_cond.reshape(
- shape=(hidden_states_cond.shape[0], self.out_channels, height * patch_size, width * patch_size)
+ shape=(hidden_states_cond.shape[0], -1, height * patch_size, width * patch_size)
)
hidden_states_uncond = hidden_states_uncond.reshape(
- shape=(hidden_states_uncond.shape[0], height, width, self.out_channels, patch_size, patch_size)
+ hidden_states_uncond.shape[0], height, width, -1, patch_size, patch_size
)
hidden_states_uncond = torch.einsum("nhwcpq->nchpwq", hidden_states_uncond)
output_uncond = hidden_states_uncond.reshape(
- shape=(hidden_states_uncond.shape[0], self.out_channels, height * patch_size, width * patch_size)
+ hidden_states_uncond.shape[0], -1, height * patch_size, width * patch_size
)
+
if not return_dict:
return (output_cond, output_uncond)
return Transformer2DModelOutput(sample=output_cond), Transformer2DModelOutput(sample=output_uncond)
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index 353be8635649..d9adb3d3fcfe 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -23,7 +23,6 @@
from .aura_flow import AuraFlowPipeline
from .cogview3 import CogView3PlusPipeline
from .cogview4 import CogView4Pipeline
-
from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 64a029d7e2ff..1f0ffe7d2154 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
@@ -100,7 +99,6 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
- self.image_factor = 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _get_glm_embeds(
@@ -266,24 +264,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = latents * self.scheduler.init_noise_sigma
return latents
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
- def prepare_extra_step_kwargs(self, generator, eta):
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
- # and should be between [0, 1]
-
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
- extra_step_kwargs = {}
- if accepts_eta:
- extra_step_kwargs["eta"] = eta
-
- # check if the scheduler accepts generator
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
- if accepts_generator:
- extra_step_kwargs["generator"] = generator
- return extra_step_kwargs
-
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
self,
@@ -295,15 +275,8 @@ def check_inputs(
prompt_embeds=None,
negative_prompt_embeds=None,
):
- if height % self.image_factor != 0 or width % self.image_factor != 0:
- raise ValueError(
- f"`height` and `width` have to be divisible by {self.image_factor} but are {height} and {width}."
- )
-
- if height < 512 or height > 2048 or width < 512 or width > 2048:
- raise ValueError(
- f"`height` and `width` must be between 512 and 2048, but got height={height} and width={width}."
- )
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -508,6 +481,7 @@ def __call__(
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = self.do_classifier_free_guidance
+
# Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
@@ -519,14 +493,15 @@ def __call__(
max_sequence_length=max_sequence_length,
device=device,
)
- # Prepare latents.
+
+ # Prepare latents
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
- prompt_embeds.dtype,
+ torch.float32,
device,
generator,
latents,
@@ -546,9 +521,6 @@ def __call__(
target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1)
- # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
# Prepare timesteps
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
self.transformer.config.patch_size**2
@@ -557,33 +529,38 @@ def __call__(
timesteps = self.scheduler.timesteps
# Denoising loop
+ transformer_dtype = self.transformer.dtype
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
- timestep = t.reshape((1,))
- timestep = torch.cat([timestep] * num_images_per_prompt)
- timestep = torch.cat([timestep] * 2) if do_classifier_free_guidance else t
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ latent_model_input = latent_model_input.to(transformer_dtype)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0])
+
noise_pred = self.transformer(
hidden_states=latent_model_input,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
- timestep=timestep, # Pass sigma as timestep for noise prediction
+ timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
return_dict=False,
)
+
# perform guidance
if do_classifier_free_guidance:
noise_pred_cond, noise_pred_uncond = noise_pred
- noise_pred_guided = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
- latents = self.scheduler.step(noise_pred_guided, latents, t).prev_sample
- latents = latents.to(prompt_embeds.dtype)
+ latents = self.scheduler.step(noise_pred, latents, t).prev_sample
# call the callback, if provided
if callback_on_step_end is not None:
@@ -602,9 +579,8 @@ def __call__(
xm.mark_step()
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
- 0
- ]
+ latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+ image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
else:
image = latents
diff --git a/src/diffusers/schedulers/scheduling_cogview.py b/src/diffusers/schedulers/scheduling_cogview.py
index 103706360a8f..9a0740f56d63 100644
--- a/src/diffusers/schedulers/scheduling_cogview.py
+++ b/src/diffusers/schedulers/scheduling_cogview.py
@@ -311,9 +311,9 @@ def add_noise(
sqrt_alpha_prod = self.sqrt_alphas_cumprod[timesteps]
sigmas = self.sigmas[timesteps]
assert sqrt_alpha_prod.dim() == 1, f"sqrt_alpha_prod must be a 1D tensor, got {sqrt_alpha_prod.dim()}D"
- assert sqrt_alpha_prod.shape == sigmas.shape, (
- f"sigmas and sqrt_alpha_prod must have the same shape, got {sigmas.shape} and {sqrt_alpha_prod.shape}"
- )
+ assert (
+ sqrt_alpha_prod.shape == sigmas.shape
+ ), f"sigmas and sqrt_alpha_prod must have the same shape, got {sigmas.shape} and {sqrt_alpha_prod.shape}"
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sigmas = sigmas.unsqueeze(-1)
diff --git a/src/diffusers/schedulers/scheduling_ddim_cogview4.py b/src/diffusers/schedulers/scheduling_ddim_cogview4.py
index 6924c30ec286..d596c14caf5e 100644
--- a/src/diffusers/schedulers/scheduling_ddim_cogview4.py
+++ b/src/diffusers/schedulers/scheduling_ddim_cogview4.py
@@ -18,14 +18,13 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
-from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@@ -49,11 +48,11 @@ class CogView4DDIMScheduler(SchedulerMixin, ConfigMixin):
CogView4 DDIM Scheduler.
This scheduler is a modified version of the DDIM scheduler specifically designed for use with the CogView4 model.
- It implements the denoising process using a deterministic approach based on the DDIM (Denoising Diffusion Implicit Models)
- framework.
+ It implements the denoising process using a deterministic approach based on the DDIM (Denoising Diffusion Implicit
+ Models) framework.
- The scheduler maintains the core DDIM functionality while being optimized for the CogView4 architecture and its specific
- requirements for image generation tasks.
+ The scheduler maintains the core DDIM functionality while being optimized for the CogView4 architecture and its
+ specific requirements for image generation tasks.
Args:
num_train_timesteps (int, optional): The number of diffusion steps to train the model. Defaults to 1000.
@@ -172,8 +171,8 @@ def step(
Predict the sample from the previous timestep by applying the flow matching update.
This method implements the flow matching step for the CogView4 DDIM scheduler. It takes the model output
- (predicted noise) and current sample, and computes the previous sample by following the flow matching
- update rule.
+ (predicted noise) and current sample, and computes the previous sample by following the flow matching update
+ rule.
Args:
model_output (`torch.Tensor`):
@@ -187,8 +186,8 @@ def step(
Returns:
`CogView4DDIMSchedulerOutput` or `tuple`:
- If `return_dict` is True, returns a `CogView4DDIMSchedulerOutput` containing the predicted
- sample at the previous timestep. Otherwise, returns a tuple with the predicted sample.
+ If `return_dict` is True, returns a `CogView4DDIMSchedulerOutput` containing the predicted sample at
+ the previous timestep. Otherwise, returns a tuple with the predicted sample.
"""
if self.num_inference_steps is None:
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index e465f75f7129..005365bb8841 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -1338,6 +1338,7 @@ def from_config(cls, *args, **kwargs):
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+
class CogView4DDIMScheduler(metaclass=DummyObject):
_backends = ["torch"]
@@ -1352,6 +1353,7 @@ def from_config(cls, *args, **kwargs):
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+
class DDIMInverseScheduler(metaclass=DummyObject):
_backends = ["torch"]
From 4c01c9d7ffb243dfaa01cc195dab216d0d95303a Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 12 Feb 2025 09:59:06 +0100
Subject: [PATCH 54/68] refactor pt. 2
---
.../transformers/transformer_cogview4.py | 171 +++++++++---------
.../pipelines/cogview4/pipeline_cogview4.py | 38 +---
2 files changed, 92 insertions(+), 117 deletions(-)
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index 2b96110cf3ce..034db2701e6f 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -26,7 +26,6 @@
from ...utils import logging
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
-from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -81,6 +80,53 @@ def forward(
return hidden_states, prompt_embeds, negative_prompt_embeds
+class CogView4AdaLayerNormZero(nn.Module):
+ def __init__(self, embedding_dim: int, dim: int) -> None:
+ super().__init__()
+
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
+
+ def forward(
+ self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ norm_hidden_states = self.norm(hidden_states)
+ norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
+
+ emb = self.linear(temb)
+ (
+ shift_msa,
+ c_shift_msa,
+ scale_msa,
+ c_scale_msa,
+ gate_msa,
+ c_gate_msa,
+ shift_mlp,
+ c_shift_mlp,
+ scale_mlp,
+ c_scale_mlp,
+ gate_mlp,
+ c_gate_mlp,
+ ) = emb.chunk(12, dim=1)
+
+ hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
+ encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
+
+ return (
+ hidden_states,
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ encoder_hidden_states,
+ c_gate_msa,
+ c_shift_mlp,
+ c_scale_mlp,
+ c_gate_mlp,
+ )
+
+
class CogView4AttnProcessor:
"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -89,7 +135,7 @@ class CogView4AttnProcessor:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("CogView4AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
@@ -153,10 +199,8 @@ def __init__(
) -> None:
super().__init__()
- self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
- self.adaln = self.norm1.linear
- self.layernorm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
-
+ # 1. Attention
+ self.norm1 = CogView4AdaLayerNormZero(time_embed_dim, dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
@@ -169,37 +213,11 @@ def __init__(
processor=CogView4AttnProcessor(),
)
+ # 2. Feedforward
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
- def multi_modulate(self, hidden_states, encoder_hidden_states, factors) -> torch.Tensor:
- _, _, h = factors[0].shape
- shift_factor, scale_factor = factors[0].view(-1, h), factors[1].view(-1, h)
-
- shift_factor_hidden_states, shift_factor_encoder_hidden_states = shift_factor.chunk(2, dim=0)
- scale_factor_hidden_states, scale_factor_encoder_hidden_states = scale_factor.chunk(2, dim=0)
- shift_factor_hidden_states = shift_factor_hidden_states.unsqueeze(1)
- scale_factor_hidden_states = scale_factor_hidden_states.unsqueeze(1)
- hidden_states = torch.addcmul(shift_factor_hidden_states, hidden_states, (1 + scale_factor_hidden_states))
-
- shift_factor_encoder_hidden_states = shift_factor_encoder_hidden_states.unsqueeze(1)
- scale_factor_encoder_hidden_states = scale_factor_encoder_hidden_states.unsqueeze(1)
- encoder_hidden_states = torch.addcmul(
- shift_factor_encoder_hidden_states, encoder_hidden_states, (1 + scale_factor_encoder_hidden_states)
- )
-
- return hidden_states, encoder_hidden_states
-
- def multi_gate(self, hidden_states, encoder_hidden_states, factor):
- _, _, hidden_dim = hidden_states.shape
- gate_factor = factor.view(-1, hidden_dim)
- gate_factor_hidden_states, gate_factor_encoder_hidden_states = gate_factor.chunk(2, dim=0)
- gate_factor_hidden_states = gate_factor_hidden_states.unsqueeze(1)
- gate_factor_encoder_hidden_states = gate_factor_encoder_hidden_states.unsqueeze(1)
- hidden_states = gate_factor_hidden_states * hidden_states
- encoder_hidden_states = gate_factor_encoder_hidden_states * encoder_hidden_states
-
- return hidden_states, encoder_hidden_states
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -207,59 +225,40 @@ def forward(
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
- batch_size, encoder_hidden_states_len, hidden_dim = encoder_hidden_states.shape
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
- residual = hidden_states
- layernorm_factor = (
- self.adaln(temb)
- .view(
- temb.shape[0],
- 6,
- 2,
- hidden_states.shape[-1],
- )
- .permute(1, 2, 0, 3)
- .contiguous()
- )
- hidden_states = self.layernorm(hidden_states)
- hidden_states, encoder_hidden_states = self.multi_modulate(
- hidden_states=hidden_states[:, encoder_hidden_states_len:],
- encoder_hidden_states=hidden_states[:, :encoder_hidden_states_len],
- factors=(layernorm_factor[0], layernorm_factor[1]),
- )
- hidden_states, encoder_hidden_states = self.attn1(
- hidden_states=hidden_states,
- encoder_hidden_states=encoder_hidden_states,
+ # 1. Timestep conditioning
+ (
+ norm_hidden_states,
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ norm_encoder_hidden_states,
+ c_gate_msa,
+ c_shift_mlp,
+ c_scale_mlp,
+ c_gate_mlp,
+ ) = self.norm1(hidden_states, encoder_hidden_states, temb)
+
+ # 2. Attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
- hidden_states, encoder_hidden_states = self.multi_gate(
- hidden_states=hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- factor=layernorm_factor[2],
- )
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
- hidden_states += residual
-
- residual = hidden_states
- hidden_states = self.layernorm(hidden_states)
- hidden_states, encoder_hidden_states = self.multi_modulate(
- hidden_states=hidden_states[:, encoder_hidden_states_len:],
- encoder_hidden_states=hidden_states[:, :encoder_hidden_states_len],
- factors=(layernorm_factor[3], layernorm_factor[4]),
- )
- hidden_states = self.ff(hidden_states)
- encoder_hidden_states = self.ff(encoder_hidden_states)
- hidden_states, encoder_hidden_states = self.multi_gate(
- hidden_states=hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- factor=layernorm_factor[5],
- )
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
- hidden_states += residual
- hidden_states, encoder_hidden_states = (
- hidden_states[:, encoder_hidden_states_len:],
- hidden_states[:, :encoder_hidden_states_len],
- )
+ hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
+ encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
+
+ # 3. Feedforward
+ norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
+ 1 + c_scale_mlp.unsqueeze(1)
+ ) + c_shift_mlp.unsqueeze(1)
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output_context = self.ff(norm_encoder_hidden_states)
+ hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
+ encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
+
return hidden_states, encoder_hidden_states
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 1f0ffe7d2154..c72c2f2dba5f 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -78,12 +78,7 @@ class CogView4Pipeline(DiffusionPipeline):
_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
-
- _callback_tensor_inputs = [
- "latents",
- "prompt_embeds",
- "negative_prompt_embeds",
- ]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -159,9 +154,9 @@ def encode_prompt(
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
- max_sequence_length: int = 1024,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 1024,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -184,12 +179,12 @@ def encode_prompt(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
- max_sequence_length (`int`, defaults to `1024`):
- Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
+ max_sequence_length (`int`, defaults to `1024`):
+ Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
"""
device = device or self._execution_device
@@ -200,24 +195,10 @@ def encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
- prompt_embeds = self._get_glm_embeds(
- prompt=prompt,
- num_images_per_prompt=num_images_per_prompt,
- max_sequence_length=max_sequence_length,
- device=device,
- dtype=dtype,
- )
-
- if do_classifier_free_guidance and negative_prompt is None:
- negative_prompt_embeds = self._get_glm_embeds(
- prompt="",
- num_images_per_prompt=num_images_per_prompt,
- max_sequence_length=max_sequence_length,
- device=device,
- dtype=dtype,
- )
+ prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, device, dtype)
if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
@@ -233,11 +214,7 @@ def encode_prompt(
)
negative_prompt_embeds = self._get_glm_embeds(
- prompt=negative_prompt,
- num_images_per_prompt=num_images_per_prompt,
- max_sequence_length=max_sequence_length,
- device=device,
- dtype=dtype,
+ negative_prompt, num_images_per_prompt, max_sequence_length, device, dtype
)
return prompt_embeds, negative_prompt_embeds
@@ -347,7 +324,6 @@ def __call__(
timesteps: Optional[List[int]] = None,
guidance_scale: float = 5.0,
num_images_per_prompt: int = 1,
- eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
From c1b80045b23237918021de581add4ad85a2dc7fd Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 12 Feb 2025 10:39:49 +0100
Subject: [PATCH 55/68] refactor pt. 3
---
.../transformers/transformer_cogview4.py | 149 ++++--------------
.../pipelines/cogview4/pipeline_cogview4.py | 49 +++---
2 files changed, 56 insertions(+), 142 deletions(-)
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index 034db2701e6f..48fd43d8fb6f 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -46,38 +46,23 @@ def __init__(
self.patch_size = patch_size
self.text_hidden_size = text_hidden_size
self.pos_embed_max_size = pos_embed_max_size
- # Linear projection for image patches
- self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
- # Linear projection for text embeddings
+ self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
- def forward(
- self, hidden_states: torch.Tensor, prompt_embeds: torch.Tensor, negative_prompt_embeds: torch.Tensor | None
- ) -> torch.Tensor:
+ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, channel, height, width = hidden_states.shape
+ post_patch_height = height // self.patch_size
+ post_patch_width = width // self.patch_size
- if height % self.patch_size != 0 or width % self.patch_size != 0:
- raise ValueError("Height and width must be divisible by patch size")
-
- patch_height = height // self.patch_size
- patch_width = width // self.patch_size
-
- # b, c, h, w -> b, c, patch_height, patch_size, patch_width, patch_size
- # -> b, patch_height, patch_width, c, patch_size, patch_size
- # -> b, patch_height * patch_width, c * patch_size * patch_size
- hidden_states = (
- hidden_states.reshape(batch_size, channel, patch_height, self.patch_size, patch_width, self.patch_size)
- .permute(0, 2, 4, 1, 3, 5)
- .reshape(batch_size, patch_height * patch_width, channel * self.patch_size * self.patch_size)
+ hidden_states = hidden_states.reshape(
+ batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size
)
+ hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+ hidden_states = self.proj(hidden_states)
+ encoder_hidden_states = self.text_proj(encoder_hidden_states)
- # project
- hidden_states = self.proj(hidden_states) # embed_dim: 64 -> 4096
- prompt_embeds = self.text_proj(prompt_embeds) # embed_dim: 4096 -> 4096
- if negative_prompt_embeds is not None:
- negative_prompt_embeds = self.text_proj(negative_prompt_embeds) # embed_dim: 4096 -> 4096
- return hidden_states, prompt_embeds, negative_prompt_embeds
+ return hidden_states, encoder_hidden_states
class CogView4AdaLayerNormZero(nn.Module):
@@ -347,10 +332,10 @@ def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
+ out_channels: int = 16,
num_layers: int = 30,
attention_head_dim: int = 40,
num_attention_heads: int = 64,
- out_channels: int = 16,
text_embed_dim: int = 4096,
time_embed_dim: int = 512,
condition_dim: int = 256,
@@ -402,116 +387,46 @@ def __init__(
def forward(
self,
hidden_states: torch.Tensor,
- prompt_embeds: torch.Tensor,
- negative_prompt_embeds: Optional[torch.Tensor],
+ encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
original_size: torch.Tensor,
target_size: torch.Tensor,
crop_coords: torch.Tensor,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
- """
- The [`CogView3PlusTransformer2DModel`] forward method.
-
- Args:
- hidden_states (`torch.Tensor`):
- Input `hidden_states` of shape `(batch size, channel, height, width)`.
- encoder_hidden_states (`torch.Tensor`):
- Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape
- `(batch_size, sequence_len, text_embed_dim)`
- timestep (`torch.LongTensor`):
- Used to indicate denoising step.
- original_size (`torch.Tensor`):
- CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
- target_size (`torch.Tensor`):
- CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
- crop_coords (`torch.Tensor`):
- CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
- tuple.
-
- Returns:
- `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
- The denoised latents using provided inputs as conditioning.
- """
batch_size, num_channels, height, width = hidden_states.shape
- do_cfg = negative_prompt_embeds is not None
-
- if do_cfg:
- assert (
- batch_size == prompt_embeds.shape[0] + negative_prompt_embeds.shape[0]
- ), "batch size mismatch in CFG mode"
- else:
- assert batch_size == prompt_embeds.shape[0], "batch size mismatch in non-CFG mode"
# 1. RoPE
image_rotary_emb = self.rope(hidden_states)
- # 2. Conditional embeddings
+ # 2. Patch & Timestep embeddings
+ p = self.config.patch_size
+ post_patch_height = height // p
+ post_patch_width = width // p
+
+ hidden_states, encoder_hidden_states = self.patch_embed(hidden_states, encoder_hidden_states)
+
temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
temb = F.silu(temb)
- temb_cond, temb_uncond = temb.chunk(2)
- hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
- hidden_states, prompt_embeds, negative_prompt_embeds
- )
- hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
- encoder_hidden_states_cond = prompt_embeds
- encoder_hidden_states_uncond = negative_prompt_embeds
-
- for index_block, block in enumerate(self.transformer_blocks):
+ # 3. Transformer blocks
+ for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
- hidden_states_cond, encoder_hidden_states_cond = self._gradient_checkpointing_func(
- block, hidden_states_cond, encoder_hidden_states_cond, temb_cond, image_rotary_emb
- )
- hidden_states_uncond, encoder_hidden_states_uncond = self._gradient_checkpointing_func(
- block, hidden_states_uncond, encoder_hidden_states_uncond, temb_uncond, image_rotary_emb
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
)
else:
- hidden_states_cond, encoder_hidden_states_cond = block(
- hidden_states_cond, encoder_hidden_states_cond, temb_cond, image_rotary_emb
- )
- hidden_states_uncond, encoder_hidden_states_uncond = block(
- hidden_states_uncond, encoder_hidden_states_uncond, temb_uncond, image_rotary_emb
+ hidden_states, encoder_hidden_states = block(
+ hidden_states, encoder_hidden_states, temb, image_rotary_emb
)
- hidden_states_cond, encoder_hidden_states_cond = (
- self.norm_out(hidden_states_cond, temb_cond),
- self.norm_out(encoder_hidden_states_cond, temb_cond),
- )
- hidden_states_uncond, encoder_hidden_states_uncond = (
- self.norm_out(hidden_states_uncond, temb_uncond),
- self.norm_out(encoder_hidden_states_uncond, temb_uncond),
- )
-
- hidden_states_cond = self.proj_out(hidden_states_cond)
- hidden_states_uncond = self.proj_out(hidden_states_uncond)
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
- # unpatchify
- patch_size = self.config.patch_size
- height = height // patch_size
- width = width // patch_size
-
- hidden_states_cond = hidden_states_cond.reshape(
- shape=(hidden_states_cond.shape[0], height, width, -1, patch_size, patch_size)
- )
- hidden_states_cond = torch.einsum("nhwcpq->nchpwq", hidden_states_cond)
- output_cond = hidden_states_cond.reshape(
- shape=(hidden_states_cond.shape[0], -1, height * patch_size, width * patch_size)
- )
-
- hidden_states_uncond = hidden_states_uncond.reshape(
- hidden_states_uncond.shape[0], height, width, -1, patch_size, patch_size
- )
- hidden_states_uncond = torch.einsum("nhwcpq->nchpwq", hidden_states_uncond)
- output_uncond = hidden_states_uncond.reshape(
- hidden_states_uncond.shape[0], -1, height * patch_size, width * patch_size
- )
+ # 4. Unpatchify
+ hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
+ output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
if not return_dict:
- return (output_cond, output_uncond)
- return Transformer2DModelOutput(sample=output_cond), Transformer2DModelOutput(sample=output_uncond)
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index c72c2f2dba5f..843c1a253fff 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -453,16 +453,11 @@ def __call__(
device = self._execution_device
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- do_classifier_free_guidance = self.do_classifier_free_guidance
-
# Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
- do_classifier_free_guidance,
+ self.do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@@ -484,18 +479,13 @@ def __call__(
)
# Prepare additional timestep conditions
- original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype)
- target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype)
- crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype)
-
- if do_classifier_free_guidance:
- original_size = torch.cat([original_size, original_size])
- target_size = torch.cat([target_size, target_size])
- crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left])
+ original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device)
+ target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
+ crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
- original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
- target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
- crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1)
+ original_size = original_size.repeat(batch_size * num_images_per_prompt, 1)
+ target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
+ crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
# Prepare timesteps
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
@@ -513,28 +503,37 @@ def __call__(
if self.interrupt:
continue
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
latent_model_input = latent_model_input.to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
- noise_pred = self.transformer(
+ noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
return_dict=False,
- )
+ )[0]
# perform guidance
- if do_classifier_free_guidance:
- noise_pred_cond, noise_pred_uncond = noise_pred
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=negative_prompt_embeds,
+ timestep=timestep,
+ original_size=original_size,
+ target_size=target_size,
+ crop_coords=crops_coords_top_left,
+ return_dict=False,
+ )[0]
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+ else:
+ noise_pred = noise_pred_cond
latents = self.scheduler.step(noise_pred, latents, t).prev_sample
From 9d55d0a3e2817eae50744b2a974631aca24cdecf Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 12 Feb 2025 12:01:26 +0100
Subject: [PATCH 56/68] add tests
---
src/diffusers/pipelines/__init__.py | 1 +
.../pipelines/cogview4/pipeline_cogview4.py | 5 +-
.../test_models_transformer_cogview4.py | 83 +++++++
tests/pipelines/cogview4/__init__.py | 0
tests/pipelines/cogview4/test_cogview4.py | 228 ++++++++++++++++++
5 files changed, 315 insertions(+), 2 deletions(-)
create mode 100644 tests/models/transformers/test_models_transformer_cogview4.py
create mode 100644 tests/pipelines/cogview4/__init__.py
create mode 100644 tests/pipelines/cogview4/test_cogview4.py
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 048dfabb0923..19c054b9ec40 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -497,6 +497,7 @@
CogVideoXVideoToVideoPipeline,
)
from .cogview3 import CogView3PlusPipeline
+ from .cogview4 import CogView4Pipeline
from .consisid import ConsisIDPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 843c1a253fff..e225cd7b76fa 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -16,7 +16,7 @@
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
-from transformers import GlmModel
+from transformers import AutoTokenizer, GlmModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
@@ -82,7 +82,7 @@ class CogView4Pipeline(DiffusionPipeline):
def __init__(
self,
- tokenizer: GlmModel,
+ tokenizer: AutoTokenizer,
text_encoder: GlmModel,
vae: AutoencoderKL,
transformer: CogView4Transformer2DModel,
@@ -493,6 +493,7 @@ def __call__(
)
self.scheduler.set_timesteps(num_inference_steps, image_seq_len, device)
timesteps = self.scheduler.timesteps
+ self._num_timesteps = len(timesteps)
# Denoising loop
transformer_dtype = self.transformer.dtype
diff --git a/tests/models/transformers/test_models_transformer_cogview4.py b/tests/models/transformers/test_models_transformer_cogview4.py
new file mode 100644
index 000000000000..e311ce77ea50
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_cogview4.py
@@ -0,0 +1,83 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import CogView4Transformer2DModel
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = CogView4Transformer2DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 4
+ height = 8
+ width = 8
+ embedding_dim = 8
+ sequence_length = 8
+
+ hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
+ target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
+ crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ "original_size": original_size,
+ "target_size": target_size,
+ "crop_coords": crop_coords,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 8, 8)
+
+ @property
+ def output_shape(self):
+ return (4, 8, 8)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 2,
+ "in_channels": 4,
+ "num_layers": 2,
+ "attention_head_dim": 4,
+ "num_attention_heads": 4,
+ "out_channels": 4,
+ "text_embed_dim": 8,
+ "time_embed_dim": 8,
+ "condition_dim": 4,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"CogView4Transformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/pipelines/cogview4/__init__.py b/tests/pipelines/cogview4/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/cogview4/test_cogview4.py b/tests/pipelines/cogview4/test_cogview4.py
new file mode 100644
index 000000000000..56bdb138b108
--- /dev/null
+++ b/tests/pipelines/cogview4/test_cogview4.py
@@ -0,0 +1,228 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, GlmConfig, GlmForCausalLM
+
+from diffusers import AutoencoderKL, CogView4DDIMScheduler, CogView4Pipeline, CogView4Transformer2DModel
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class CogView4PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = CogView4Pipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = CogView4Transformer2DModel(
+ patch_size=2,
+ in_channels=4,
+ num_layers=2,
+ attention_head_dim=4,
+ num_attention_heads=4,
+ out_channels=4,
+ text_embed_dim=32,
+ time_embed_dim=8,
+ condition_dim=4,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ sample_size=128,
+ )
+
+ torch.manual_seed(0)
+ scheduler = CogView4DDIMScheduler()
+
+ torch.manual_seed(0)
+ text_encoder_config = GlmConfig(
+ hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8
+ )
+ text_encoder = GlmForCausalLM(text_encoder_config)
+ # TODO(aryan): change this to THUDM/CogView4 once released
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs)[0]
+ generated_image = image[0]
+
+ self.assertEqual(generated_image.shape, (3, 16, 16))
+ expected_image = torch.randn(3, 16, 16)
+ max_diff = np.abs(generated_image - expected_image).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
From 5e6de42509685d192556c51118726dfd1c11c467 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 12 Feb 2025 12:06:06 +0100
Subject: [PATCH 57/68] make fix-copies
---
src/diffusers/pipelines/__init__.py | 1 +
.../pipelines/cogview4/pipeline_cogview4.py | 2 +-
.../schedulers/scheduling_ddim_cogview4.py | 1 -
src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++
.../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++
5 files changed, 32 insertions(+), 2 deletions(-)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 19c054b9ec40..3d6cdfba3b29 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -155,6 +155,7 @@
]
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
_import_structure["cogview4"] = ["CogView4Pipeline"]
+ _import_structure["consisid"] = ["ConsisIDPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index e225cd7b76fa..f3936b2ff856 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -237,11 +237,11 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
+
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
- # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
self,
prompt,
diff --git a/src/diffusers/schedulers/scheduling_ddim_cogview4.py b/src/diffusers/schedulers/scheduling_ddim_cogview4.py
index d596c14caf5e..9c14be220ab7 100644
--- a/src/diffusers/schedulers/scheduling_ddim_cogview4.py
+++ b/src/diffusers/schedulers/scheduling_ddim_cogview4.py
@@ -28,7 +28,6 @@
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
-# Copied from diffusers.schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput
@dataclass
class CogView4DDIMSchedulerOutput(BaseOutput):
"""
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 005365bb8841..49f165936a3c 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -276,6 +276,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class CogView4Transformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class ConsisIDTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index b899915c3046..d8785392476d 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -362,6 +362,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class CogView4Pipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class ConsisIDPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
From 2046cf205a262f378e878f6379e87911cf72c574 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 12 Feb 2025 12:18:38 +0100
Subject: [PATCH 58/68] update toctree.yml
---
docs/source/en/_toctree.yml | 2 ++
1 file changed, 2 insertions(+)
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 4666c8e0311b..7a1088f63521 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -384,6 +384,8 @@
title: CogVideoX
- local: api/pipelines/cogview3
title: CogView3
+ - local: api/pipelines/cogview4
+ title: CogView4
- local: api/pipelines/consisid
title: ConsisID
- local: api/pipelines/consistency_models
From 39e1198029b8df98cdda202066073734f00d7d6d Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 13 Feb 2025 13:21:15 +0100
Subject: [PATCH 59/68] use flow match scheduler instead of custom
---
scripts/convert_cogview4_to_diffusers.py | 17 +-
.../convert_cogview4_to_diffusers_megatron.py | 22 +-
src/diffusers/__init__.py | 2 -
.../pipelines/cogview4/pipeline_cogview4.py | 124 +++++++--
src/diffusers/schedulers/__init__.py | 2 -
.../schedulers/scheduling_ddim_cogview4.py | 256 ------------------
.../scheduling_flow_match_euler_discrete.py | 17 +-
src/diffusers/utils/dummy_pt_objects.py | 15 -
tests/pipelines/cogview4/test_cogview4.py | 10 +-
9 files changed, 136 insertions(+), 329 deletions(-)
delete mode 100644 src/diffusers/schedulers/scheduling_ddim_cogview4.py
diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py
index 5d09d1206e7c..484c817dd938 100644
--- a/scripts/convert_cogview4_to_diffusers.py
+++ b/scripts/convert_cogview4_to_diffusers.py
@@ -32,7 +32,7 @@
from accelerate import init_empty_weights
from transformers import GlmForCausalLM, PreTrainedTokenizerFast
-from diffusers import AutoencoderKL, CogView4DDIMScheduler, CogView4Pipeline, CogView4Transformer2DModel
+from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available
@@ -222,19 +222,8 @@ def main(args):
for param in text_encoder.parameters():
param.data = param.data.contiguous()
- scheduler = CogView4DDIMScheduler.from_config(
- {
- "shift_scale": 1.0,
- "beta_end": 0.012,
- "beta_schedule": "scaled_linear",
- "beta_start": 0.00085,
- "clip_sample": False,
- "num_train_timesteps": 1000,
- "prediction_type": "v_prediction",
- "rescale_betas_zero_snr": True,
- "set_alpha_to_one": True,
- "timestep_spacing": "linspace",
- }
+ scheduler = FlowMatchEulerDiscreteScheduler(
+ base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
)
pipe = CogView4Pipeline(
diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py
index 61a9fb3eb2fd..de5354952493 100644
--- a/scripts/convert_cogview4_to_diffusers_megatron.py
+++ b/scripts/convert_cogview4_to_diffusers_megatron.py
@@ -27,12 +27,7 @@
from tqdm import tqdm
from transformers import GlmForCausalLM, PreTrainedTokenizerFast
-from diffusers import (
- AutoencoderKL,
- CogView4DDIMScheduler,
- CogView4Pipeline,
- CogView4Transformer2DModel,
-)
+from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
@@ -345,19 +340,8 @@ def main(args):
param.data = param.data.contiguous()
# Initialize the scheduler
- scheduler = CogView4DDIMScheduler.from_config(
- {
- "shift_scale": 1.0,
- "beta_end": 0.012,
- "beta_schedule": "scaled_linear",
- "beta_start": 0.00085,
- "clip_sample": False,
- "num_train_timesteps": 1000,
- "prediction_type": "v_prediction",
- "rescale_betas_zero_snr": True,
- "set_alpha_to_one": True,
- "timestep_spacing": "linspace",
- }
+ scheduler = FlowMatchEulerDiscreteScheduler(
+ base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
)
# Create the pipeline
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index bf5af6b6c6cb..a9e7c823db41 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -188,7 +188,6 @@
"CMStochasticIterativeScheduler",
"CogVideoXDDIMScheduler",
"CogVideoXDPMScheduler",
- "CogView4DDIMScheduler",
"DDIMInverseScheduler",
"DDIMParallelScheduler",
"DDIMScheduler",
@@ -707,7 +706,6 @@
CMStochasticIterativeScheduler,
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
- CogView4DDIMScheduler,
DDIMInverseScheduler,
DDIMParallelScheduler,
DDIMScheduler,
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index f3936b2ff856..3059fca138ec 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
+import numpy as np
import torch
from transformers import AutoTokenizer, GlmModel
@@ -22,7 +24,7 @@
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, CogView4Transformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
-from ...schedulers import CogView4DDIMScheduler
+from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import CogView4PipelineOutput
@@ -53,6 +55,82 @@
"""
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ base_shift: float = 0.25,
+ max_shift: float = 0.75,
+):
+ # m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ # b = base_shift - m * base_seq_len
+ # mu = image_seq_len * m + b
+ # return mu
+
+ m = (image_seq_len / base_seq_len) ** 0.5
+ mu = m * max_shift + base_shift
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
class CogView4Pipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using CogView4.
@@ -86,7 +164,7 @@ def __init__(
text_encoder: GlmModel,
vae: AutoencoderKL,
transformer: CogView4Transformer2DModel,
- scheduler: CogView4DDIMScheduler,
+ scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
@@ -219,8 +297,10 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ if latents is not None:
+ return latents.to(device)
+
shape = (
batch_size,
num_channels_latents,
@@ -232,14 +312,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
-
- if latents is None:
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
- else:
- latents = latents.to(device)
-
- # scale the initial noise by the standard deviation required by the scheduler
- latents = latents * self.scheduler.init_noise_sigma
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
def check_inputs(
@@ -322,6 +395,7 @@ def __call__(
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
guidance_scale: float = 5.0,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -359,6 +433,10 @@ def __call__(
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
guidance_scale (`float`, *optional*, defaults to `5.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -491,9 +569,22 @@ def __call__(
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
self.transformer.config.patch_size**2
)
- self.scheduler.set_timesteps(num_inference_steps, image_seq_len, device)
- timesteps = self.scheduler.timesteps
- self._num_timesteps = len(timesteps)
+
+ timesteps = (
+ np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps)
+ if timesteps is None
+ else np.array(timesteps)
+ )
+ timesteps = timesteps.astype(np.int64)
+ sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("base_shift", 0.25),
+ self.scheduler.config.get("max_shift", 0.75),
+ )
+ _, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
+ timesteps = torch.from_numpy(timesteps).to(device)
# Denoising loop
transformer_dtype = self.transformer.dtype
@@ -504,8 +595,7 @@ def __call__(
if self.interrupt:
continue
- latent_model_input = self.scheduler.scale_model_input(latents, t)
- latent_model_input = latent_model_input.to(transformer_dtype)
+ latent_model_input = latents.to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
@@ -536,7 +626,7 @@ def __call__(
else:
noise_pred = noise_pred_cond
- latents = self.scheduler.step(noise_pred, latents, t).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
# call the callback, if provided
if callback_on_step_end is not None:
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 512d28d95c09..bb9088538653 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -44,7 +44,6 @@
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
_import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"]
- _import_structure["scheduling_ddim_cogview4"] = ["CogView4DDIMScheduler"]
_import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
_import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"]
_import_structure["scheduling_ddpm"] = ["DDPMScheduler"]
@@ -145,7 +144,6 @@
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
- from .scheduling_ddim_cogview4 import CogView4DDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler
from .scheduling_ddim_parallel import DDIMParallelScheduler
from .scheduling_ddpm import DDPMScheduler
diff --git a/src/diffusers/schedulers/scheduling_ddim_cogview4.py b/src/diffusers/schedulers/scheduling_ddim_cogview4.py
deleted file mode 100644
index 9c14be220ab7..000000000000
--- a/src/diffusers/schedulers/scheduling_ddim_cogview4.py
+++ /dev/null
@@ -1,256 +0,0 @@
-# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
-# and https://github.com/hojonathanho/diffusion
-
-import math
-from dataclasses import dataclass
-from typing import Optional, Tuple, Union
-
-import numpy as np
-import torch
-
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
-
-
-@dataclass
-class CogView4DDIMSchedulerOutput(BaseOutput):
- """
- Output class for the scheduler's `step` function output.
-
- Args:
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
- Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
- denoising loop.
- """
-
- prev_sample: torch.FloatTensor
-
-
-class CogView4DDIMScheduler(SchedulerMixin, ConfigMixin):
- """
- CogView4 DDIM Scheduler.
-
- This scheduler is a modified version of the DDIM scheduler specifically designed for use with the CogView4 model.
- It implements the denoising process using a deterministic approach based on the DDIM (Denoising Diffusion Implicit
- Models) framework.
-
- The scheduler maintains the core DDIM functionality while being optimized for the CogView4 architecture and its
- specific requirements for image generation tasks.
-
- Args:
- num_train_timesteps (int, optional): The number of diffusion steps to train the model. Defaults to 1000.
- beta_start (float, optional): The starting value of beta for the noise schedule. Defaults to 0.0001.
- beta_end (float, optional): The ending value of beta for the noise schedule. Defaults to 0.02.
- set_alpha_to_one (bool, optional): Whether to set the final alpha cumprod value to 1. Defaults to True.
- """
-
- _compatibles = [e.name for e in KarrasDiffusionSchedulers]
- order = 1
-
- @register_to_config
- def __init__(
- self,
- num_train_timesteps: int = 1000,
- beta_start: float = 0.0001,
- beta_end: float = 0.02,
- set_alpha_to_one: bool = True,
- ):
- self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
-
- self.alphas = 1.0 - self.betas
- self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
-
- self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
-
- # standard deviation of the initial noise distribution
- self.init_noise_sigma = 1.0
-
- # setable values
- self.num_inference_steps = None
-
- def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
- """
- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
- current timestep.
-
- Args:
- sample (`torch.Tensor`):
- The input sample.
- timestep (`int`, *optional*):
- The current timestep in the diffusion chain.
-
- Returns:
- `torch.Tensor`:
- A scaled input sample.
- """
- return sample
-
- @staticmethod
- def calculate_shift(
- image_seq_len,
- base_seq_len: int = 256,
- ):
- if isinstance(image_seq_len, int):
- mu = math.sqrt(image_seq_len / base_seq_len)
- elif isinstance(image_seq_len, torch.Tensor):
- mu = torch.sqrt(image_seq_len / base_seq_len)
- else:
- raise ValueError(f"Invalid type for image_seq_len: {type(image_seq_len)}")
-
- mu = mu * 0.75 + 0.25
-
- return mu
-
- @staticmethod
- def time_shift(mu: float, shift_sigma: float, sigmas: torch.Tensor):
- return mu / (mu + (1 / sigmas - 1) ** shift_sigma)
-
- def set_timesteps(self, num_inference_steps: int, image_seq_len: int, device: Union[str, torch.device] = None):
- """
- Sets the discrete timesteps used for the diffusion chain. Supporting to be called in every batch.
-
- Args:
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model.
- image_seq_len (`int`):
- The length of the image sequence.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
- """
-
- # Check if the requested number of steps is valid
- if num_inference_steps > self.config.num_train_timesteps:
- raise ValueError(
- f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:"
- f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
- f" maximal {self.config.num_train_timesteps} timesteps."
- )
-
- # Set the current number of inference steps
- self.num_inference_steps = num_inference_steps
-
- timesteps = np.linspace(self.config.num_train_timesteps, 1, num_inference_steps).astype(np.int64)
- self.timestep2idx = {timestep: i for i, timestep in enumerate(timesteps)}
-
- # Convert the numpy array of timesteps into a PyTorch tensor
- timesteps = torch.from_numpy(timesteps).to(device)
-
- mu = self.calculate_shift(image_seq_len)
- sigmas = timesteps / self.config.num_train_timesteps
- sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) # Append zero at the end
-
- self.timesteps = timesteps
- self.sigmas = self.time_shift(mu, 1.0, sigmas).to("cpu")
- self._num_timesteps = len(timesteps)
-
- def step(
- self,
- model_output: torch.Tensor,
- sample: torch.Tensor,
- timestep: int,
- return_dict: bool = True,
- ) -> Union[CogView4DDIMSchedulerOutput, Tuple]:
- """
- Predict the sample from the previous timestep by applying the flow matching update.
-
- This method implements the flow matching step for the CogView4 DDIM scheduler. It takes the model output
- (predicted noise) and current sample, and computes the previous sample by following the flow matching update
- rule.
-
- Args:
- model_output (`torch.Tensor`):
- The output from the diffusion model, typically the predicted noise or velocity.
- sample (`torch.Tensor`):
- The current sample at the current timestep.
- timestep (`int`):
- The current timestep in the diffusion process.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether to return a `CogView4DDIMSchedulerOutput` or a tuple.
-
- Returns:
- `CogView4DDIMSchedulerOutput` or `tuple`:
- If `return_dict` is True, returns a `CogView4DDIMSchedulerOutput` containing the predicted sample at
- the previous timestep. Otherwise, returns a tuple with the predicted sample.
- """
-
- if self.num_inference_steps is None:
- raise ValueError(
- "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
- )
- idx = self.timestep2idx[timestep.item()]
- sigma = self.sigmas[idx]
- sigma_next = self.sigmas[idx + 1]
- dt = sigma_next - sigma
-
- prev_sample = sample + dt * model_output
-
- if not return_dict:
- return (prev_sample,)
-
- return CogView4DDIMSchedulerOutput(prev_sample=prev_sample)
-
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
- def add_noise(
- self,
- original_samples: torch.Tensor,
- noise: torch.Tensor,
- timesteps: torch.IntTensor,
- ) -> torch.Tensor:
- # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
- # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
- # for the subsequent add_noise calls
- self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
- alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
- timesteps = timesteps.to(original_samples.device)
-
- sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
- sqrt_alpha_prod = sqrt_alpha_prod.flatten()
- while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
- sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-
- sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
- while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
- noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
- return noisy_samples
-
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
- def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
- # Make sure alphas_cumprod and timestep have same device and dtype as sample
- self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
- alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
- timesteps = timesteps.to(sample.device)
-
- sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
- sqrt_alpha_prod = sqrt_alpha_prod.flatten()
- while len(sqrt_alpha_prod.shape) < len(sample.shape):
- sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-
- sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
- while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
- velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
- return velocity
-
- def __len__(self):
- return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index 185c9fbabb89..b31aa09a0e08 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -88,7 +88,7 @@ def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
- use_dynamic_shifting=False,
+ use_dynamic_shifting: bool = False,
base_shift: Optional[float] = 0.5,
max_shift: Optional[float] = 1.15,
base_image_seq_len: Optional[int] = 256,
@@ -98,6 +98,7 @@ def __init__(
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
+ time_shift_type: str = "exponential",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -105,6 +106,9 @@ def __init__(
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
+ if time_shift_type not in {"exponential", "linear"}:
+ raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.")
+
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
@@ -211,7 +215,10 @@ def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
- return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+ if self.config.time_shift_type == "exponential":
+ return self._time_shift_exponential(mu, sigma, t)
+ elif self.config.time_shift_type == "linear":
+ return self._time_shift_linear(mu, sigma, t)
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
r"""
@@ -473,5 +480,11 @@ def _convert_to_beta(
)
return sigmas
+ def _time_shift_exponential(self, mu, sigma, t):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+ def _time_shift_linear(self, mu, sigma, t):
+ return mu / (mu + (1 / t - 1) ** sigma)
+
def __len__(self):
return self.config.num_train_timesteps
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 0f85e58d7a4f..9dd1e690742f 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -1384,21 +1384,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class CogView4DDIMScheduler(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
-
class DDIMInverseScheduler(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/tests/pipelines/cogview4/test_cogview4.py b/tests/pipelines/cogview4/test_cogview4.py
index 56bdb138b108..2a97a0799d76 100644
--- a/tests/pipelines/cogview4/test_cogview4.py
+++ b/tests/pipelines/cogview4/test_cogview4.py
@@ -19,7 +19,7 @@
import torch
from transformers import AutoTokenizer, GlmConfig, GlmForCausalLM
-from diffusers import AutoencoderKL, CogView4DDIMScheduler, CogView4Pipeline, CogView4Transformer2DModel
+from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -76,7 +76,13 @@ def get_dummy_components(self):
)
torch.manual_seed(0)
- scheduler = CogView4DDIMScheduler()
+ scheduler = FlowMatchEulerDiscreteScheduler(
+ base_shift=0.25,
+ max_shift=0.75,
+ base_image_seq_len=256,
+ use_dynamic_shifting=True,
+ time_shift_type="linear",
+ )
torch.manual_seed(0)
text_encoder_config = GlmConfig(
From b4c9fde226c71a0feea970781f0f7247d5d3295e Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 13 Feb 2025 13:22:56 +0100
Subject: [PATCH 60/68] remove scheduling_cogview.py
---
.../schedulers/scheduling_cogview.py | 332 ------------------
1 file changed, 332 deletions(-)
delete mode 100644 src/diffusers/schedulers/scheduling_cogview.py
diff --git a/src/diffusers/schedulers/scheduling_cogview.py b/src/diffusers/schedulers/scheduling_cogview.py
deleted file mode 100644
index 9a0740f56d63..000000000000
--- a/src/diffusers/schedulers/scheduling_cogview.py
+++ /dev/null
@@ -1,332 +0,0 @@
-# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
-
-from typing import List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils.torch_utils import randn_tensor
-from .scheduling_ddim import DDIMSchedulerOutput
-from .scheduling_utils import SchedulerMixin
-
-
-class CogViewScheduler(SchedulerMixin, ConfigMixin):
- """
- `CogViewScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.
-
- This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
- methods the library implements for all schedulers such as loading and saving.
-
- Args:
- num_train_timesteps (`int`, defaults to 1000):
- The number of diffusion steps to train the model.
- beta_start (`float`, defaults to 0.00085):
- The starting `beta` value of inference.
- beta_end (`float`, defaults to 0.012):
- The final `beta` value.
- prediction_type (`str`, defaults to `v_prediction`):
- Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
- timestep_spacing (`str`, defaults to `leading`):
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
- steps_offset (`int`, defaults to 0):
- An offset added to the inference steps, as required by some model families.
- num_inference_steps (`int`, defaults to 50):
- The number of inference steps to use.
- scale_factor (`float`, defaults to 1.0):
- Scaling factor to apply to the model input.
- snr_shift_scale (`float`, defaults to 1.0):
- Scale factor for shifting the signal-to-noise ratio.
- zero_snr (`bool`, defaults to True):
- Whether to adjust the alphas to achieve zero terminal SNR.
- """
-
- @register_to_config
- def __init__(
- self,
- num_train_timesteps: int = 1000,
- beta_start: float = 0.00085,
- beta_end: float = 0.012,
- prediction_type: str = "v_prediction",
- timestep_spacing: str = "leading",
- steps_offset: int = 0,
- num_inference_steps: int = 50,
- scale_factor: float = 1.0,
- snr_shift_scale: float = 1.0,
- zero_snr: bool = True,
- ):
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
-
- self.alphas = 1.0 - self.betas
- self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
- # SNR shift
- self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod)
- sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
- if zero_snr:
- sqrt_alphas_cumprod_0 = sqrt_alphas_cumprod[0]
- sqrt_alphas_cumprod_T_1 = sqrt_alphas_cumprod[-1]
- sqrt_alphas_cumprod -= sqrt_alphas_cumprod_T_1
- sqrt_alphas_cumprod *= sqrt_alphas_cumprod_0 / (sqrt_alphas_cumprod_0 - sqrt_alphas_cumprod_T_1)
- self.sqrt_alphas_cumprod = sqrt_alphas_cumprod
- self.sigmas = torch.sqrt(1 - sqrt_alphas_cumprod**2)
-
- def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
- """
- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
- current timestep.
-
- Args:
- sample (`torch.Tensor`):
- The input sample.
- timestep (`int`, *optional*):
- The current timestep in the diffusion chain.
-
- Returns:
- `torch.Tensor`:
- A scaled input sample.
- """
- return sample * self.scale_factor
-
- def set_timesteps(
- self,
- num_inference_steps: Optional[int] = None,
- device: Union[str, torch.device] = None,
- timesteps: Optional[List[int]] = None,
- ):
- """
- Sets the discrete timesteps used for the diffusion chain (to be run before inference).
-
- Args:
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
- timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed,
- `num_inference_steps` must be `None`.
-
- """
- if num_inference_steps is not None and timesteps is not None:
- raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
-
- if timesteps is not None:
- for i in range(1, len(timesteps)):
- if timesteps[i] >= timesteps[i - 1]:
- raise ValueError("`custom_timesteps` must be in descending order.")
-
- if timesteps[0] >= self.config.num_train_timesteps:
- raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
- )
-
- timesteps = np.array(timesteps, dtype=np.int64)
- self.custom_timesteps = True
- else:
- if num_inference_steps > self.config.num_train_timesteps:
- raise ValueError(
- f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
- f" maximal {self.config.num_train_timesteps} timesteps."
- )
-
- self.num_inference_steps = num_inference_steps
- self.custom_timesteps = False
-
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
- if self.config.timestep_spacing == "linspace":
- timesteps = (
- np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
- .round()[::-1]
- .copy()
- .astype(np.int64)
- )
- elif self.config.timestep_spacing == "leading":
- step_ratio = self.config.num_train_timesteps // self.num_inference_steps
- # creates integer timesteps by multiplying by ratio
- # casting to int to avoid issues when num_inference_step is power of 3
- timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
- timesteps += self.config.steps_offset
- elif self.config.timestep_spacing == "trailing":
- step_ratio = self.config.num_train_timesteps / self.num_inference_steps
- # creates integer timesteps by multiplying by ratio
- # casting to int to avoid issues when num_inference_step is power of 3
- timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
- timesteps -= 1
- else:
- raise ValueError(
- f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
- )
-
- self.timesteps = torch.from_numpy(timesteps).to(device)
-
- def step(
- self,
- model_output: torch.Tensor,
- timestep: int,
- sample: torch.Tensor,
- eta: float = 1.0,
- generator=None,
- variance_noise: Optional[torch.Tensor] = None,
- return_dict: bool = True,
- ) -> Union[DDIMSchedulerOutput, Tuple]:
- """
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
- process from the learned model outputs (most often the predicted noise).
-
- Args:
- model_output (`torch.Tensor`):
- The direct output from learned diffusion model.
- timestep (`float`):
- The current discrete timestep in the diffusion chain.
- sample (`torch.Tensor`):
- A current instance of a sample created by the diffusion process.
- eta (`float`):
- The weight of noise for added noise in diffusion step.
- use_clipped_model_output (`bool`, defaults to `False`):
- If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
- because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
- clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
- `use_clipped_model_output` has no effect.
- generator (`torch.Generator`, *optional*):
- A random number generator.
- variance_noise (`torch.Tensor`):
- Alternative to generating noise with `generator` by directly providing the noise for the variance
- itself. Useful for methods such as [`CycleDiffusion`].
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
-
- Returns:
- [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
- If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
- tuple is returned where the first element is the sample tensor.
-
- """
- if self.num_inference_steps is None:
- raise ValueError(
- "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
- )
-
- # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
- # Ideally, read DDIM paper in-detail understanding
-
- # Notation ( ->
- # - pred_noise_t -> e_theta(x_t, t)
- # - pred_original_sample -> f_theta(x_t, t) or x_0
- # - std_dev_t -> sigma_t
- # - eta -> η
- # - pred_sample_direction -> "direction pointing to x_t"
- # - pred_prev_sample -> "x_t-1"
-
- # 1. get previous step value (=t-1)
- prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
-
- # 2. compute alphas, betas
- alpha_prod_t = self.alphas_cumprod[timestep]
- alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else 1.0
- sigma_t = eta * torch.sqrt(
- (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
- )
-
- beta_prod_t = 1 - alpha_prod_t
-
- # 3. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
- if self.config.prediction_type == "epsilon":
- pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
- pred_epsilon = model_output
- elif self.config.prediction_type == "sample":
- pred_original_sample = model_output
- pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
- elif self.config.prediction_type == "v_prediction":
- pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
- pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
- else:
- raise ValueError(
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
- " `v_prediction`"
- )
-
- # 4. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
- pred_sample_direction = (1 - alpha_prod_t_prev - sigma_t**2) ** (0.5) * pred_epsilon
-
- # 5. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
- prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
-
- if eta > 0:
- if variance_noise is not None and generator is not None:
- raise ValueError(
- "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
- " `variance_noise` stays `None`."
- )
-
- if variance_noise is None:
- variance_noise = randn_tensor(
- model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
- )
- variance = sigma_t * variance_noise
-
- prev_sample = prev_sample + variance
-
- if not return_dict:
- return (
- prev_sample,
- pred_original_sample,
- )
-
- return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
-
- def add_noise(
- self,
- original_samples: torch.Tensor,
- noise: torch.Tensor,
- timesteps: torch.IntTensor,
- apply_scale: bool = True,
- ) -> torch.Tensor:
- # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
- # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
- # for the subsequent add_noise calls
- self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
- self.sigmas = self.sigmas.to(dtype=original_samples.dtype)
- timesteps = timesteps.to(original_samples.device)
-
- sqrt_alpha_prod = self.sqrt_alphas_cumprod[timesteps]
- sigmas = self.sigmas[timesteps]
- assert sqrt_alpha_prod.dim() == 1, f"sqrt_alpha_prod must be a 1D tensor, got {sqrt_alpha_prod.dim()}D"
- assert (
- sqrt_alpha_prod.shape == sigmas.shape
- ), f"sigmas and sqrt_alpha_prod must have the same shape, got {sigmas.shape} and {sqrt_alpha_prod.shape}"
- while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
- sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
- sigmas = sigmas.unsqueeze(-1)
-
- if apply_scale:
- original_samples = original_samples * self.scale_factor
-
- # scale noise and original samples
- noise = noise * sigmas
- original_samples = original_samples * sqrt_alpha_prod
-
- noisy_samples = noise + original_samples
- return noisy_samples
-
- def __len__(self):
- return self.config.num_train_timesteps
From a137e1736fc24a9aeddc5a50d6f2c6408b1396fa Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 13 Feb 2025 14:19:24 +0100
Subject: [PATCH 61/68] add tiktoken to test dependencies
---
setup.py | 2 ++
src/diffusers/dependency_versions_table.py | 1 +
2 files changed, 3 insertions(+)
diff --git a/setup.py b/setup.py
index 0acdcbbb9c52..1da12e158b36 100644
--- a/setup.py
+++ b/setup.py
@@ -130,6 +130,7 @@
"regex!=2019.12.17",
"requests",
"tensorboard",
+ "tiktoken>=0.7.0",
"torch>=1.4",
"torchvision",
"transformers>=4.41.2",
@@ -226,6 +227,7 @@ def run(self):
"safetensors",
"sentencepiece",
"scipy",
+ "tiktoken",
"torchvision",
"transformers",
"phonemizer",
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 7999368f1417..17d5da60347d 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -38,6 +38,7 @@
"regex": "regex!=2019.12.17",
"requests": "requests",
"tensorboard": "tensorboard",
+ "tiktoken": "tiktoken>=0.7.0",
"torch": "torch>=1.4",
"torchvision": "torchvision",
"transformers": "transformers>=4.41.2",
From da420fba175f7b30c79664a35c631c019b06f3e6 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 14 Feb 2025 01:43:50 +0530
Subject: [PATCH 62/68] Update src/diffusers/models/embeddings.py
Co-authored-by: YiYi Xu
---
src/diffusers/models/embeddings.py | 75 ------------------------------
1 file changed, 75 deletions(-)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 2a77239ada12..c42fbbc9f0a3 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -2611,78 +2611,3 @@ def forward(self, image_embeds: List[torch.Tensor]):
projected_image_embeds.append(image_embed)
return projected_image_embeds
-
-
-class CogViewRotary2DEmbedding(nn.Module):
- def __init__(
- self,
- kv_channels: int,
- rotary_percent: float,
- max_h: int = 128,
- max_w: int = 128,
- rotary_interleaved: bool = False,
- seq_len_interpolation_factor: float = None,
- inner_interp: bool = False,
- rotary_base: int = 10000,
- ) -> None:
- super().__init__()
-
- dim = kv_channels
- if rotary_percent < 1.0:
- dim = int(dim * rotary_percent)
- self.rotary_interleaved = rotary_interleaved
-
- self.seq_len_interpolation_factor = seq_len_interpolation_factor
- self.inner_interp = inner_interp
-
- dim_h = kv_channels // 2
- dim_w = kv_channels // 2
-
- device = torch.cuda.current_device()
- h_inv_freq = 1.0 / (
- rotary_base
- ** (torch.arange(0, dim_h, 2, dtype=torch.float32, device=device)[: (dim_h // 2)].float() / dim_h)
- )
- w_inv_freq = 1.0 / (
- rotary_base
- ** (torch.arange(0, dim_w, 2, dtype=torch.float32, device=device)[: (dim_w // 2)].float() / dim_w)
- )
-
- h_seq = torch.arange(max_h, device=device, dtype=h_inv_freq.dtype)
- w_seq = torch.arange(max_w, device=device, dtype=w_inv_freq.dtype)
-
- self.freqs_h = torch.outer(h_seq, h_inv_freq)
- self.freqs_w = torch.outer(w_seq, w_inv_freq)
- self.max_h = max_h
- self.max_w = max_w
-
- def forward(
- self,
- h_idx: torch.Tensor,
- w_idx: torch.Tensor,
- target_h: torch.Tensor = None,
- target_w: torch.Tensor = None,
- mask: torch.Tensor = None,
- ) -> torch.Tensor:
- if self.inner_interp:
- inner_h_idx = (h_idx * self.max_h) // target_h
- inner_w_idx = (w_idx * self.max_w) // target_w
-
- h_emb = self.freqs_h[inner_h_idx]
- w_emb = self.freqs_w[inner_w_idx]
-
- else:
- h_emb = self.freqs_h[h_idx]
- w_emb = self.freqs_w[w_idx]
-
- mask = (mask == 1).unsqueeze(-1)
-
- emb = torch.cat([h_emb, w_emb], dim=-1) * mask
-
- assert emb.ndim == 2, f"expected emb to have 2 dimensions, got {emb.ndim}"
- if not self.rotary_interleaved:
- emb = torch.repeat_interleave(emb, 2, dim=0)
- else:
- emb = torch.repeat_interleave(emb, 2, dim=1)
-
- return emb
From 4003b9c506c54119cf50bb9b3e1fc9a7a5975156 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 13 Feb 2025 21:30:50 +0100
Subject: [PATCH 63/68] apply suggestions from review
---
.../pipelines/cogview4/pipeline_cogview4.py | 22 ++++++++++---------
1 file changed, 12 insertions(+), 10 deletions(-)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 3059fca138ec..0571af20f12c 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -177,7 +177,6 @@ def __init__(
def _get_glm_embeds(
self,
prompt: Union[str, List[str]] = None,
- num_images_per_prompt: int = 1,
max_sequence_length: int = 1024,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
@@ -186,7 +185,6 @@ def _get_glm_embeds(
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
- batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
@@ -219,9 +217,6 @@ def _get_glm_embeds(
).hidden_states[-2]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
- _, seq_len, _ = prompt_embeds.shape
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
def encode_prompt(
@@ -273,7 +268,11 @@ def encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
- prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, device, dtype)
+ prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype)
+
+ seq_len = prompt_embeds.size(1)
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
@@ -291,9 +290,11 @@ def encode_prompt(
" the batch size of `prompt`."
)
- negative_prompt_embeds = self._get_glm_embeds(
- negative_prompt, num_images_per_prompt, max_sequence_length, device, dtype
- )
+ negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype)
+
+ seq_len = negative_prompt_embeds.size(1)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
@@ -575,7 +576,7 @@ def __call__(
if timesteps is None
else np.array(timesteps)
)
- timesteps = timesteps.astype(np.int64)
+ timesteps = timesteps.astype(np.float32)
sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
mu = calculate_shift(
image_seq_len,
@@ -585,6 +586,7 @@ def __call__(
)
_, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
timesteps = torch.from_numpy(timesteps).to(device)
+ self._num_timesteps = len(timesteps)
# Denoising loop
transformer_dtype = self.transformer.dtype
From 35c0ec6ea7625670d91d684a824ec11cee3c182d Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 13 Feb 2025 22:28:33 +0100
Subject: [PATCH 64/68] use diffusers apply_rotary_emb
---
src/diffusers/models/embeddings.py | 2 +-
.../transformers/transformer_cogview4.py | 34 ++++++-------------
.../pipelines/cogview4/pipeline_cogview4.py | 3 +-
3 files changed, 13 insertions(+), 26 deletions(-)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index c42fbbc9f0a3..390b752abe15 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1199,7 +1199,7 @@ def apply_rotary_emb(
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
- # Used for Stable Audio and OmniGen
+ # Used for Stable Audio, OmniGen and CogView4
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index 48fd43d8fb6f..f622791b572f 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -38,14 +38,9 @@ def __init__(
hidden_size: int = 2560,
patch_size: int = 2,
text_hidden_size: int = 4096,
- pos_embed_max_size: int = 128,
):
super().__init__()
- self.in_channels = in_channels
- self.hidden_size = hidden_size
self.patch_size = patch_size
- self.text_hidden_size = text_hidden_size
- self.pos_embed_max_size = pos_embed_max_size
self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
@@ -150,16 +145,14 @@ def __call__(
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
+ from ..embeddings import apply_rotary_emb
- def apply_rotary_emb(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
- cos, sin = freqs
- x_real, x_imag = x.chunk(2, dim=-1)
- x_rotated = torch.cat([-x_imag, x_real], dim=-1)
- x_out = cos * x.float() + sin * x_rotated.float()
- return x_out
-
- query[:, :, text_seq_length:, :] = apply_rotary_emb(query[:, :, text_seq_length:, :], image_rotary_emb)
- key[:, :, text_seq_length:, :] = apply_rotary_emb(key[:, :, text_seq_length:, :], image_rotary_emb)
+ query[:, :, text_seq_length:, :] = apply_rotary_emb(
+ query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+ )
+ key[:, :, text_seq_length:, :] = apply_rotary_emb(
+ key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+ )
# 4. Attention
hidden_states = F.scaled_dot_product_attention(
@@ -345,7 +338,7 @@ def __init__(
):
super().__init__()
- # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
+ # CogView4 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
# Each of these are sincos embeddings of shape 2 * condition_dim
pooled_projection_dim = 3 * 2 * condition_dim
inner_dim = num_attention_heads * attention_head_dim
@@ -355,13 +348,7 @@ def __init__(
self.rope = CogView4RotaryPosEmbed(attention_head_dim, patch_size, rope_axes_dim, theta=10000.0)
# 2. Patch & Text-timestep embedding
- self.patch_embed = CogView4PatchEmbed(
- in_channels=in_channels,
- hidden_size=inner_dim,
- patch_size=patch_size,
- text_hidden_size=text_embed_dim,
- pos_embed_max_size=pos_embed_max_size,
- )
+ self.patch_embed = CogView4PatchEmbed(in_channels, inner_dim, patch_size, text_embed_dim)
self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
embedding_dim=time_embed_dim,
@@ -420,10 +407,11 @@ def forward(
hidden_states, encoder_hidden_states, temb, image_rotary_emb
)
+ # 4. Output norm & projection
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
- # 4. Unpatchify
+ # 5. Unpatchify
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 0571af20f12c..02dbf128a122 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -570,13 +570,12 @@ def __call__(
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
self.transformer.config.patch_size**2
)
-
timesteps = (
np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps)
if timesteps is None
else np.array(timesteps)
)
- timesteps = timesteps.astype(np.float32)
+ timesteps = timesteps.astype(np.int64).astype(np.float32)
sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
mu = calculate_shift(
image_seq_len,
From d328c5ed9f0df487efb531d31adeb817f8cd9528 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 14 Feb 2025 16:39:27 +0100
Subject: [PATCH 65/68] update flow match scheduler to accept timesteps
---
.../pipelines/cogview4/pipeline_cogview4.py | 33 +++++++------
.../scheduling_flow_match_euler_discrete.py | 46 +++++++++++++++----
2 files changed, 54 insertions(+), 25 deletions(-)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 02dbf128a122..097d1b6aed41 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -60,18 +60,12 @@ def calculate_shift(
base_seq_len: int = 256,
base_shift: float = 0.25,
max_shift: float = 0.75,
-):
- # m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
- # b = base_shift - m * base_seq_len
- # mu = image_seq_len * m + b
- # return mu
-
+) -> float:
m = (image_seq_len / base_seq_len) ** 0.5
mu = m * max_shift + base_shift
return mu
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -103,10 +97,19 @@ def retrieve_timesteps(
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+
if timesteps is not None and sigmas is not None:
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps and not accepts_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif timesteps is not None and sigmas is None:
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -115,9 +118,8 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
- elif sigmas is not None:
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accept_sigmas:
+ elif timesteps is None and sigmas is not None:
+ if not accepts_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
@@ -583,8 +585,9 @@ def __call__(
self.scheduler.config.get("base_shift", 0.25),
self.scheduler.config.get("max_shift", 0.75),
)
- _, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
- timesteps = torch.from_numpy(timesteps).to(device)
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
+ )
self._num_timesteps = len(timesteps)
# Denoising loop
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index b31aa09a0e08..9b26178830d4 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -78,6 +78,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
Whether to use exponential sigmas for step sizes in the noise schedule during sampling.
use_beta_sigmas (`bool`, defaults to False):
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
+ time_shift_type (`str`, defaults to "exponential"):
+ The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
"""
_compatibles = []
@@ -247,6 +249,7 @@ def set_timesteps(
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[float] = None,
+ timesteps: Optional[List[float]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -256,41 +259,64 @@ def set_timesteps(
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ sigmas (`List[float]`, *optional*):
+ Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
+ automatically.
+ mu (`float`, *optional*):
+ Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
+ shifting.
+ timesteps (`List[float]`, *optional*):
+ Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
+ automatically.
"""
if self.config.use_dynamic_shifting and mu is None:
raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
- if sigmas is None:
- timesteps = np.linspace(
- self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
- )
+ self.num_inference_steps = num_inference_steps
+ # 1. Prepare default sigmas
+ is_timesteps_provided = timesteps is not None
+ if sigmas is None:
+ if timesteps is None:
+ timesteps = np.linspace(
+ self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
+ )
+ else:
+ timesteps = np.array(timesteps).astype(np.float32)
sigmas = timesteps / self.config.num_train_timesteps
else:
sigmas = np.array(sigmas).astype(np.float32)
- num_inference_steps = len(sigmas)
- self.num_inference_steps = num_inference_steps
+ # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
+ # "exponential" or "linear" type is applied
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
+ # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas)
+ # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+ # 5. Convert sigmas and timesteps to tensors and move to specified device
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
- timesteps = sigmas * self.config.num_train_timesteps
+ if not is_timesteps_provided:
+ timesteps = sigmas * self.config.num_train_timesteps
+ else:
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
+ # 6. Append the terminal sigma value.
+ # If a model requires inverted sigma schedule for denoising but
+ # timesteps without inversion, the `invert_sigmas` flag can be set to `True`. This case is only
+ # required in Mochi
if self.config.invert_sigmas:
sigmas = 1.0 - sigmas
timesteps = sigmas * self.config.num_train_timesteps
@@ -298,7 +324,7 @@ def set_timesteps(
else:
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
- self.timesteps = timesteps.to(device=device)
+ self.timesteps = timesteps
self.sigmas = sigmas
self._step_index = None
self._begin_index = None
From 4c37ef025c81daeb79acfb10bf1870fe53face37 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 14 Feb 2025 16:44:44 +0100
Subject: [PATCH 66/68] fix comment
---
.../schedulers/scheduling_flow_match_euler_discrete.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index 359aa0ab5c15..7c070074e0e1 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -314,9 +314,8 @@ def set_timesteps(
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
# 6. Append the terminal sigma value.
- # If a model requires inverted sigma schedule for denoising but
- # timesteps without inversion, the `invert_sigmas` flag can be set to `True`. This case is only
- # required in Mochi
+ # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
+ # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi
if self.config.invert_sigmas:
sigmas = 1.0 - sigmas
timesteps = sigmas * self.config.num_train_timesteps
From 90c240b6a9e40e0effc67dbafc25074925813f2e Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 14 Feb 2025 21:42:27 +0100
Subject: [PATCH 67/68] apply review sugestions
---
.../scheduling_flow_match_euler_discrete.py | 28 +++++++++++++++----
1 file changed, 23 insertions(+), 5 deletions(-)
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index 7c070074e0e1..43926f1fdcb5 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -245,7 +245,7 @@ def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
def set_timesteps(
self,
- num_inference_steps: int = None,
+ num_inference_steps: Optional[int] = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[float] = None,
@@ -255,7 +255,7 @@ def set_timesteps(
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
- num_inference_steps (`int`):
+ num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
@@ -270,22 +270,40 @@ def set_timesteps(
automatically.
"""
if self.config.use_dynamic_shifting and mu is None:
- raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
+ raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
+
+ if sigmas is not None and timesteps is not None:
+ if len(sigmas) != len(timesteps):
+ raise ValueError("`sigmas` and `timesteps` should have the same length")
+
+ if num_inference_steps is not None:
+ if (sigmas is not None and len(sigmas) != num_inference_steps) or (
+ timesteps is not None and len(timesteps) != num_inference_steps
+ ):
+ raise ValueError(
+ "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
+ )
+ else:
+ num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)
self.num_inference_steps = num_inference_steps
# 1. Prepare default sigmas
is_timesteps_provided = timesteps is not None
+
+ if is_timesteps_provided:
+ timesteps = np.array(timesteps).astype(np.float32)
+
if sigmas is None:
if timesteps is None:
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
)
- else:
- timesteps = np.array(timesteps).astype(np.float32)
sigmas = timesteps / self.config.num_train_timesteps
+ num_inference_steps = len(sigmas)
else:
sigmas = np.array(sigmas).astype(np.float32)
+ num_inference_steps = len(sigmas)
# 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
# "exponential" or "linear" type is applied
From 2f12b7a196d52762b479ac5e851785d86eb3efa3 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Sat, 15 Feb 2025 03:12:43 +0530
Subject: [PATCH 68/68] Update
src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
Co-authored-by: YiYi Xu
---
src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index 43926f1fdcb5..e3bff7582cd9 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -300,7 +300,6 @@ def set_timesteps(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
)
sigmas = timesteps / self.config.num_train_timesteps
- num_inference_steps = len(sigmas)
else:
sigmas = np.array(sigmas).astype(np.float32)
num_inference_steps = len(sigmas)