sd3 training
kohya-ss committed Jun 23, 2024
1 parent a518e3c commit d53ea22
Showing 8 changed files with 1,909 additions and 44 deletions.
25 changes: 25 additions & 0 deletions README.md
@@ -1,5 +1,30 @@
This repository contains training, generation and utility scripts for Stable Diffusion.

## SD3 training

SD3 training is done with `sd3_train.py`.

`optimizer_type = "adafactor"` is recommended for GPUs with 24GB VRAM. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are currently required.

`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them.

t5xxl doesn't seem to work with `fp16`, so use `bf16` or `fp32`.

The `t5xxl_device` and `t5xxl_dtype` options set the device and dtype for `t5xxl`.

```toml
learning_rate = 1e-5 # seems to be too high
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
cache_text_encoder_outputs = true
cache_text_encoder_outputs_to_disk = true
vae_batch_size = 1
cache_latents = true
cache_latents_to_disk = true
```
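
The `optimizer_args` above turn off Adafactor's automatic step-size schedule so that `learning_rate` is applied as a fixed value. A minimal sketch of the equivalent optimizer construction, assuming the script forwards these arguments to `transformers`' Adafactor (the toy `model` below is a stand-in):

```python
import torch
from transformers.optimization import Adafactor

model = torch.nn.Linear(4, 4)  # stand-in for the model being trained

optimizer = Adafactor(
    model.parameters(),
    lr=1e-5,                # learning_rate from the config above
    scale_parameter=False,  # do not scale updates by parameter RMS
    relative_step=False,    # use the fixed lr instead of a time-dependent step size
    warmup_init=False,      # only meaningful when relative_step=True
)
```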

---

[__Change History__](#change-history) has been moved to the bottom of the page.

20 changes: 17 additions & 3 deletions library/sai_model_spec.py
@@ -6,8 +6,10 @@
from typing import List, Optional, Tuple, Union
import safetensors
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

r"""
@@ -55,11 +57,14 @@
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_SD3_M = "stable-diffusion-3-medium"
ARCH_SD3_UNKNOWN = "stable-diffusion-3"

ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"

IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
IMPL_DIFFUSERS = "diffusers"

PRED_TYPE_EPSILON = "epsilon"
@@ -113,7 +118,11 @@ def build_metadata(
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
sd3: Optional[str] = None,
):
"""
sd3: only supports "m"
"""
# if state_dict is None, hash is not calculated

metadata = {}
@@ -126,6 +135,11 @@

if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif sd3 is not None:
if sd3 == "m":
arch = ARCH_SD3_M
else:
arch = ARCH_SD3_UNKNOWN
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
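
For illustration, a hedged sketch of how the new `sd3` argument selects the architecture string. Parameter names other than `sd3` follow the existing signature as far as this diff shows it; `timestamp` and the positional `state_dict` are assumptions here:

```python
import time
from library import sai_model_spec

# Hypothetical call; passing None for state_dict skips hash calculation.
metadata = sai_model_spec.build_metadata(
    None,
    v2=False,
    v_parameterization=False,
    sdxl=False,
    lora=False,
    textual_inversion=False,
    timestamp=time.time(),
    title="my-sd3-checkpoint",
    sd3="m",  # "m" -> ARCH_SD3_M; any other non-None value -> ARCH_SD3_UNKNOWN
)
assert metadata["modelspec.architecture"] == sai_model_spec.ARCH_SD3_M
```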
@@ -142,7 +156,7 @@ def build_metadata(
metadata["modelspec.architecture"] = arch

if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True  # default is stable diffusion ckpt if not lora and not textual_inversion

if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
@@ -236,7 +250,7 @@ def build_metadata(
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
logger.error(f"Internal error: some metadata values are None: {metadata}")

return metadata


@@ -250,7 +264,7 @@ def get_title(metadata: dict) -> Optional[str]:
def load_metadata_from_safetensors(model: str) -> dict:
if not model.endswith(".safetensors"):
return {}

with safetensors.safe_open(model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:
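
A hedged usage sketch of this helper (hypothetical path; per the code above, non-safetensors paths return an empty dict):

```python
from library.sai_model_spec import load_metadata_from_safetensors

meta = load_metadata_from_safetensors("models/sd3_medium.safetensors")  # hypothetical file
print(meta.get("modelspec.architecture"))  # e.g. "stable-diffusion-3-medium"
```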
102 changes: 92 additions & 10 deletions library/sd3_models.py
@@ -1,11 +1,13 @@
# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref
# the original code is licensed under the MIT License

# and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution!

from functools import partial
import math
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union
import einops
import numpy as np
import torch
@@ -106,6 +108,8 @@ def __init__(self, t5xxl=True):
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
self.t5xxl = T5XXLTokenizer() if t5xxl else None
# t5xxl has 99999999 max length, clip has 77
self.model_max_length = self.clip_l.max_length # 77

def tokenize_with_weights(self, text: str):
return (
@@ -870,6 +874,10 @@ def __init__(
self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels)
# self.initialize_weights()

@property
def model_type(self):
return "m" # only support medium

def enable_gradient_checkpointing(self):
self.gradient_checkpointing = True
for block in self.joint_blocks:
@@ -1013,6 +1021,10 @@ def create_mmdit_sd3_medium_configs(attn_mode: str):
# endregion

# region VAE
# TODO support xformers

VAE_SCALE_FACTOR = 1.5305
VAE_SHIFT_FACTOR = 0.0609


def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
@@ -1222,6 +1234,14 @@ def __init__(self, dtype=torch.float32, device=None):
self.encoder = VAEEncoder(dtype=dtype, device=device)
self.decoder = VAEDecoder(dtype=dtype, device=device)

@property
def device(self):
return next(self.parameters()).device

@property
def dtype(self):
return next(self.parameters()).dtype

@torch.autocast("cuda", dtype=torch.float16)
def decode(self, latent):
return self.decoder(latent)
@@ -1234,6 +1254,43 @@ def encode(self, image):
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)

@staticmethod
def process_in(latent):
return (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR

@staticmethod
def process_out(latent):
return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR
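
`process_in` and `process_out` are exact inverses: they map between raw VAE latents and the shifted/scaled latents the diffusion model sees. A small self-contained check (the 16-channel shape reflects SD3's latent layout):

```python
import torch

VAE_SCALE_FACTOR = 1.5305
VAE_SHIFT_FACTOR = 0.0609

latent = torch.randn(1, 16, 64, 64)
normalized = (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR   # process_in
restored = normalized / VAE_SCALE_FACTOR + VAE_SHIFT_FACTOR   # process_out
assert torch.allclose(restored, latent, atol=1e-6)
```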


class VAEOutput:
def __init__(self, latent):
self.latent = latent

@property
def latent_dist(self):
return self

def sample(self):
return self.latent


class VAEWrapper:
def __init__(self, vae):
self.vae = vae

@property
def device(self):
return self.vae.device

@property
def dtype(self):
return self.vae.dtype

# latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
def encode(self, image):
return VAEOutput(self.vae.encode(image))
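
`VAEWrapper` and `VAEOutput` appear to exist so this VAE can be called with the diffusers-style `vae.encode(x).latent_dist.sample()` convention (as the comment above suggests), letting existing latent-caching code run unchanged. A hedged sketch, where `sd3_vae` stands for an instantiated `SDVAE`:

```python
# Hypothetical usage; img_tensors is assumed to be a batch of images in [-1, 1].
wrapped = VAEWrapper(sd3_vae)
latents = wrapped.encode(img_tensors).latent_dist.sample().to("cpu")
```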


# endregion

@@ -1370,15 +1427,39 @@ def forward(self, *args, **kwargs):


class ClipTokenWeightEncoder:
    # def encode_token_weights(self, token_weight_pairs):
    #     tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
    #     out, pooled = self([tokens])
    #     if pooled is not None:
    #         first_pooled = pooled[0:1]
    #     else:
    #         first_pooled = pooled
    #     output = [out[0:1]]
    #     return torch.cat(output, dim=-2), first_pooled

    # fix to support batched inputs
    # : Union[List[Tuple[torch.Tensor, torch.Tensor]], List[List[Tuple[torch.Tensor, torch.Tensor]]]]
    def encode_token_weights(self, list_of_token_weight_pairs):
        has_batch = isinstance(list_of_token_weight_pairs[0][0], list)

        if has_batch:
            list_of_tokens = []
            for pairs in list_of_token_weight_pairs:
                tokens = [a[0] for a in pairs[0]]  # I'm not sure why this is [0]
                list_of_tokens.append(tokens)
        else:
            list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]

        out, pooled = self(list_of_tokens)
        if has_batch:
            return out, pooled
        else:
            if pooled is not None:
                first_pooled = pooled[0:1]
            else:
                first_pooled = pooled
            output = [out[0:1]]
            return torch.cat(output, dim=-2), first_pooled
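
The batch detection relies purely on input nesting: a single prompt arrives as one list of (token, weight) pairs wrapped in a list, while a batch is a list of such wrapped lists. A minimal shape sketch with dummy token ids:

```python
single = [[(101, 1.0), (2054, 1.0), (102, 1.0)]]  # one prompt
batch = [
    [[(101, 1.0), (102, 1.0)]],  # prompt 1 (note the extra wrapping list)
    [[(101, 1.0), (102, 1.0)]],  # prompt 2
]

assert not isinstance(single[0][0], list)  # tuple -> unbatched path
assert isinstance(batch[0][0], list)       # list  -> batched path
```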


class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
@@ -1694,6 +1775,7 @@ def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermed
x = self.embed_tokens(input_ids)
past_bias = None
for i, l in enumerate(self.block):
# uncomment to debug layerwise output: fp16 may cause issues
# print(i, x.mean(), x.std())
x, past_bias = l(x, past_bias)
if i == intermediate_output:
