[Sd3 Dreambooth LoRA] Add text encoder training for the clip encoders (#8630)

* add clip text-encoder training

* no dora

* text encoder training fixes

* add text_encoder layers to save_lora

* style

* fix imports

* fix text encoder

* review changes

* minor change

* add lora tag

* add readme notes

* add tests for clip encoders

* typo

* fixes

* Update tests/lora/test_lora_layers_sd3.py

Co-authored-by: Sayak Paul <[email protected]>

* Update examples/dreambooth/README_sd3.md

Co-authored-by: Sayak Paul <[email protected]>

* minor readme change

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
3 people authored Jun 25, 2024
1 parent 4ad7a1f commit c6e08ec
Showing 4 changed files with 351 additions and 54 deletions.
34 changes: 34 additions & 0 deletions examples/dreambooth/README_sd3.md
@@ -147,6 +147,40 @@ accelerate launch train_dreambooth_lora_sd3.py \
--push_to_hub
```

### Text Encoder Training
Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
To enable it, pass `--train_text_encoder` when launching training. Keep the following in mind:

> [!NOTE]
> SD3 has three text encoders (CLIP L/14, OpenCLIP bigG/14, and T5-v1.1-XXL).
> Enabling `--train_text_encoder` performs LoRA fine-tuning of both **CLIP encoders** only. T5 fine-tuning is not yet supported; its weights remain frozen even when text encoder training is enabled.

To perform DreamBooth LoRA with text-encoder training, run:
```bash
export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
export OUTPUT_DIR="trained-sd3-lora"

accelerate launch train_dreambooth_lora_sd3.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--output_dir=$OUTPUT_DIR \
--dataset_name="Norod78/Yarn-art-style" \
--instance_prompt="a photo of TOK yarn art dog" \
--resolution=1024 \
--train_batch_size=1 \
--train_text_encoder \
--gradient_accumulation_steps=1 \
--optimizer="prodigy" \
--learning_rate=1.0 \
--text_encoder_lr=1.0 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=1500 \
--rank=32 \
--seed="0" \
--push_to_hub
```
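The command above gives the transformer and the text encoders separate learning rates (`--learning_rate` and `--text_encoder_lr`). A minimal, stdlib-only sketch of how such a split is typically expressed as optimizer parameter groups; the names below are illustrative stand-ins, not the training script's actual variables:

```python
# Sketch: per-component learning rates as optimizer parameter groups.
# "Params" are placeholder strings here; a real optimizer would receive tensors.

def build_param_groups(transformer_params, text_encoder_params,
                       learning_rate=1.0, text_encoder_lr=1.0):
    """Mirror the --learning_rate / --text_encoder_lr split as two groups."""
    return [
        {"params": transformer_params, "lr": learning_rate},
        {"params": text_encoder_params, "lr": text_encoder_lr},
    ]

groups = build_param_groups(
    ["transformer.lora_A", "transformer.lora_B"],
    ["te1.lora_A", "te2.lora_A"],
)
print(len(groups))  # 2
```

With Prodigy, both learning rates are usually left at `1.0` and the optimizer adapts the step size per group.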

## Other notes

We default to the `logit_normal` weighting scheme for the loss, following the SD3 paper. Thanks to @bghira for helping us discover that training with the other weighting schemes supported by the training script may incur numerical instabilities.
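For reference, a logit-normal scheme draws a standard normal sample and squashes it through a sigmoid, concentrating sampling mass in the middle of the (0, 1) timestep range. A stdlib-only sketch of the idea (the actual implementation lives in the training script's weighting-scheme helper):

```python
import math
import random

def sample_logit_normal(mean=0.0, std=1.0):
    """Sample t in (0, 1): draw u ~ N(mean, std), then t = sigmoid(u)."""
    u = random.gauss(mean, std)
    return 1.0 / (1.0 + math.exp(-u))

random.seed(0)
ts = [sample_logit_normal() for _ in range(10_000)]
# All samples lie strictly inside (0, 1), clustered around t = 0.5
# rather than spread uniformly toward the endpoints.
assert all(0.0 < t < 1.0 for t in ts)
```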

2 comments on commit c6e08ec

@SushantGautam


```
[rank6]: File "...train_dreambooth_lora_sd3.py", line 1593, in main
[rank6]:     prompts, text_encoders, tokenizers
[rank6]: UnboundLocalError: local variable 'text_encoders' referenced before assignment
```
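The traceback above is the classic Python failure mode where a name is bound only on one conditional branch but referenced on a path where it was never assigned. A minimal standalone reproduction (a hypothetical simplification, not the script's actual control flow):

```python
def main(train_text_encoder: bool):
    if train_text_encoder:
        text_encoders = ["clip_l", "clip_g"]  # bound only on this branch
    # Later code references the name unconditionally, so the other path
    # raises UnboundLocalError, as in the report above.
    return len(text_encoders)

try:
    main(train_text_encoder=False)
except UnboundLocalError as err:
    print(type(err).__name__)  # UnboundLocalError
```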

@neuron-party
Contributor


Yup, same issue as above. Seems like there are some bugs when `--train_text_encoder` is set.
