Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/sdxl' into oft
Browse files: browse the repository at this point in the history
  • Loading branch information
ljleb committed Dec 18, 2023
2 parents f506831 + f8d8f48 commit ac54d26
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 13 deletions.
3 changes: 3 additions & 0 deletions merge_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"], case_sensitive=False),
default="INFO",
)
@click.option("-xl", "--sdxl", "sdxl", is_flag=True)
def main(
model_a,
model_b,
Expand All @@ -138,6 +139,7 @@ def main(
presets_alpha_lambda,
presets_beta_lambda,
logging_level,
sdxl,
):
if logging_level:
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging_level)
Expand All @@ -158,6 +160,7 @@ def main(
block_weights_preset_beta_b,
presets_alpha_lambda,
presets_beta_lambda,
sdxl,
)

merged = merge_models(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sd-meh"
version = "0.9.4"
version = "0.10.0"
description = "stable diffusion merging execution helper"
authors = ["s1dlx <[email protected]>"]
license = "MIT"
Expand Down
2 changes: 1 addition & 1 deletion sd_meh/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.9.4"
__version__ = "0.10.0"
32 changes: 27 additions & 5 deletions sd_meh/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
NUM_OUTPUT_BLOCKS = 12
NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS

NUM_INPUT_BLOCKS_XL = 9
NUM_OUTPUT_BLOCKS_XL = 9
NUM_TOTAL_BLOCKS_XL = NUM_INPUT_BLOCKS_XL + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS_XL

KEY_POSITION_IDS = ".".join(
[
"cond_stage_model",
Expand Down Expand Up @@ -144,6 +148,11 @@ def merge_models(
) -> Dict:
thetas = load_thetas(models, prune, device, precision)

sdxl = (
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight"
in thetas["model_a"].keys()
)

logging.info(f"start merging with {merge_mode} method")
if re_basin:
merged = rebasin_merge(
Expand All @@ -157,6 +166,7 @@ def merge_models(
device=device,
work_device=work_device,
threads=threads,
sdxl=sdxl,
)
else:
merged = simple_merge(
Expand All @@ -169,6 +179,7 @@ def merge_models(
device=device,
work_device=work_device,
threads=threads,
sdxl=sdxl,
)

return un_prune_model(merged, thetas, models, device, prune, precision)
Expand Down Expand Up @@ -221,6 +232,7 @@ def simple_merge(
device: str = "cpu",
work_device: Optional[str] = None,
threads: int = 1,
sdxl: bool = False,
) -> Dict:
futures = []
with tqdm(thetas["model_a"].keys(), desc="stage 1") as progress:
Expand All @@ -238,6 +250,7 @@ def simple_merge(
weights_clip,
device,
work_device,
sdxl,
)
futures.append(future)

Expand Down Expand Up @@ -270,6 +283,7 @@ def rebasin_merge(
device="cpu",
work_device=None,
threads: int = 1,
sdxl: bool = False,
):
# WARNING: not sure how this does when 3 models are involved...

Expand Down Expand Up @@ -299,6 +313,7 @@ def rebasin_merge(
device,
work_device,
threads,
sdxl,
)

log_vram("simple merge done")
Expand Down Expand Up @@ -367,6 +382,7 @@ def merge_key(
weights_clip: bool = False,
device: str = "cpu",
work_device: Optional[str] = None,
sdxl: bool = False,
) -> Optional[Tuple[str, Dict]]:
if work_device is None:
work_device = device
Expand All @@ -391,16 +407,22 @@ def merge_key(
if "time_embed" in key:
weight_index = 0 # before input blocks
elif ".out." in key:
weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks
weight_index = (
NUM_TOTAL_BLOCKS_XL - 1 if sdxl else NUM_TOTAL_BLOCKS - 1
) # after output blocks
elif m := re_inp.search(key):
weight_index = int(m.groups()[0])
elif re_mid.search(key):
weight_index = NUM_INPUT_BLOCKS
weight_index = NUM_INPUT_BLOCKS_XL if sdxl else NUM_INPUT_BLOCKS
elif m := re_out.search(key):
weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + int(m.groups()[0])
weight_index = (
(NUM_INPUT_BLOCKS_XL if sdxl else NUM_INPUT_BLOCKS)
+ NUM_MID_BLOCK
+ int(m.groups()[0])
)

if weight_index >= NUM_TOTAL_BLOCKS:
raise ValueError(f"illegal block index {key}")
if weight_index >= (NUM_TOTAL_BLOCKS_XL if sdxl else NUM_TOTAL_BLOCKS):
raise ValueError(f"illegal block index {weight_index} for key {key}")

if weight_index >= 0:
current_bases = {k: w[weight_index] for k, w in weights.items()}
Expand Down
17 changes: 11 additions & 6 deletions sd_meh/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging

from sd_meh import merge_methods
from sd_meh.merge import NUM_TOTAL_BLOCKS
from sd_meh.merge import NUM_TOTAL_BLOCKS, NUM_TOTAL_BLOCKS_XL
from sd_meh.presets import BLOCK_WEIGHTS_PRESETS

MERGE_METHODS = dict(inspect.getmembers(merge_methods, inspect.isfunction))
Expand All @@ -13,25 +13,25 @@
]


def compute_weights(weights, base):
def compute_weights(weights, base, sdxl: bool = False):
if not weights:
return [base] * NUM_TOTAL_BLOCKS
return [base] * (NUM_TOTAL_BLOCKS_XL if sdxl else NUM_TOTAL_BLOCKS)

if "," not in weights:
return weights

w_alpha = list(map(float, weights.split(",")))
if len(w_alpha) == NUM_TOTAL_BLOCKS:
if len(w_alpha) == (NUM_TOTAL_BLOCKS_XL if sdxl else NUM_TOTAL_BLOCKS):
return w_alpha


def assemble_weights_and_bases(preset, weights, base, greek_letter):
def assemble_weights_and_bases(preset, weights, base, greek_letter, sdxl: bool = False):
logging.info(f"Assembling {greek_letter} w&b")
if preset:
logging.info(f"Using {preset} preset")
base, *weights = BLOCK_WEIGHTS_PRESETS[preset]
bases = {greek_letter: base}
weights = {greek_letter: compute_weights(weights, base)}
weights = {greek_letter: compute_weights(weights, base, sdxl)}

logging.info(f"base_{greek_letter}: {bases[greek_letter]}")
logging.info(f"{greek_letter} weights: {weights[greek_letter]}")
Expand Down Expand Up @@ -70,12 +70,14 @@ def weights_and_bases(
block_weights_preset_beta_b,
presets_alpha_lambda,
presets_beta_lambda,
sdxl: bool = False,
):
weights, bases = assemble_weights_and_bases(
block_weights_preset_alpha,
weights_alpha,
base_alpha,
"alpha",
sdxl,
)

if block_weights_preset_alpha_b:
Expand All @@ -85,6 +87,7 @@ def weights_and_bases(
None,
None,
"alpha",
sdxl,
)
weights, bases = interpolate_presets(
weights,
Expand All @@ -101,6 +104,7 @@ def weights_and_bases(
weights_beta,
base_beta,
"beta",
sdxl,
)

if block_weights_preset_beta_b:
Expand All @@ -110,6 +114,7 @@ def weights_and_bases(
None,
None,
"beta",
sdxl,
)
weights, bases = interpolate_presets(
weights,
Expand Down

0 comments on commit ac54d26

Please sign in to comment.