From 8f07e895f5a50a274a8d8bbc442db33f27b62d5e Mon Sep 17 00:00:00 2001 From: saftle Date: Mon, 15 Jan 2024 23:01:51 -0700 Subject: [PATCH] Basic SDXL Support Most of this code was taken from the SDXL branch upstream. This is to integrate it into the current dev branch. It does not have rebasin or pruning support. --- merge_models.py | 3 +++ sd_meh/merge.py | 38 +++++++++++++++++++++++++++++++++----- sd_meh/rebasin.py | 6 ++---- sd_meh/utils.py | 17 +++++++++++------ 4 files changed, 49 insertions(+), 15 deletions(-) diff --git a/merge_models.py b/merge_models.py index 60ae7c4..845f2bb 100644 --- a/merge_models.py +++ b/merge_models.py @@ -113,6 +113,7 @@ type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"], case_sensitive=False), default="INFO", ) +@click.option("-xl", "--sdxl", "sdxl", is_flag=True) def main( model_a, model_b, @@ -140,6 +141,7 @@ def main( presets_alpha_lambda, presets_beta_lambda, logging_level, + sdxl, ): if logging_level: logging.basicConfig(format="%(levelname)s: %(message)s", level=logging_level) @@ -162,6 +164,7 @@ def main( block_weights_preset_beta_b, presets_alpha_lambda, presets_beta_lambda, + sdxl, ) merged = merge_models( diff --git a/sd_meh/merge.py b/sd_meh/merge.py index c1cd05a..92eb5d2 100644 --- a/sd_meh/merge.py +++ b/sd_meh/merge.py @@ -28,6 +28,10 @@ NUM_OUTPUT_BLOCKS = 12 NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS +NUM_INPUT_BLOCKS_XL = 9 +NUM_OUTPUT_BLOCKS_XL = 9 +NUM_TOTAL_BLOCKS_XL = NUM_INPUT_BLOCKS_XL + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS_XL + KEY_POSITION_IDS = ".".join( [ "cond_stage_model", @@ -144,6 +148,11 @@ def merge_models( ) -> Dict: thetas = load_thetas(models, prune, device, precision) + sdxl = ( + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" + in thetas["model_a"].keys() + ) + if "model_d" in models: logging.info(f"substract from a") a_sub = simple_merge( @@ -156,6 +165,7 @@ def merge_models( device=device, work_device=work_device, threads=threads, + sdxl=sdxl, ) thetas["model_a"] = a_sub @@ -172,6 +182,7 @@ def merge_models( device=device, work_device=work_device, threads=threads, + sdxl=sdxl, ) thetas["model_b"] = b_sub @@ -189,6 +200,7 @@ def merge_models( device=device, work_device=work_device, threads=threads, + sdxl=sdxl, ) thetas["model_c"] = c_sub @@ -212,6 +224,7 @@ def merge_models( device=device, work_device=work_device, threads=threads, + sdxl=sdxl, ) # clip only after the last re-basin iteration if weights_clip: @@ -235,6 +248,7 @@ def merge_models( device=device, work_device=work_device, threads=threads, + sdxl=sdxl, ) # clip only after the last re-basin iteration if weights_clip: @@ -257,6 +271,7 @@ def merge_models( device=device, work_device=work_device, threads=threads, + sdxl=sdxl, ) else: merged = simple_merge( @@ -269,6 +284,7 @@ def merge_models( device=device, work_device=work_device, threads=threads, + sdxl=sdxl, ) if "model_d" in models: @@ -283,6 +299,7 @@ def merge_models( device=device, work_device=work_device, threads=threads, + sdxl=sdxl, ) return un_prune_model(merged, thetas, models, device, prune, precision) @@ -335,6 +352,7 @@ def simple_merge( device: str = "cpu", work_device: Optional[str] = None, threads: int = 1, + sdxl: bool = False, ) -> Dict: futures = [] with tqdm(thetas["model_a"].keys(), desc="stage 1") as progress: @@ -352,6 +370,7 @@ def simple_merge( weights_clip, device, work_device, + sdxl, ) futures.append(future) @@ -384,6 +403,7 @@ def rebasin_merge( device="cpu", work_device=None, threads: int = 1, + sdxl: bool = False, ): # WARNING: not sure how this does when 3 models are involved... @@ -413,6 +433,7 @@ def rebasin_merge( device, work_device, threads, + sdxl, ) log_vram("simple merge done") @@ -476,6 +497,7 @@ def merge_key( weights_clip: bool = False, device: str = "cpu", work_device: Optional[str] = None, + sdxl: bool = False, ) -> Optional[Tuple[str, Dict]]: if work_device is None: work_device = device @@ -500,16 +522,22 @@ def merge_key( if "time_embed" in key: weight_index = 0 # before input blocks elif ".out." in key: - weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks + weight_index = ( + NUM_TOTAL_BLOCKS_XL - 1 if sdxl else NUM_TOTAL_BLOCKS - 1 + ) # after output blocks elif m := re_inp.search(key): weight_index = int(m.groups()[0]) elif re_mid.search(key): - weight_index = NUM_INPUT_BLOCKS + weight_index = NUM_INPUT_BLOCKS_XL if sdxl else NUM_INPUT_BLOCKS elif m := re_out.search(key): - weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + int(m.groups()[0]) + weight_index = ( + (NUM_INPUT_BLOCKS_XL if sdxl else NUM_INPUT_BLOCKS) + + NUM_MID_BLOCK + + int(m.groups()[0]) + ) - if weight_index >= NUM_TOTAL_BLOCKS: - raise ValueError(f"illegal block index {key}") + if weight_index >= (NUM_TOTAL_BLOCKS_XL if sdxl else NUM_TOTAL_BLOCKS): + raise ValueError(f"illegal block index {weight_index} for key {key}") if weight_index >= 0: current_bases = {k: w[weight_index] for k, w in weights.items()} diff --git a/sd_meh/rebasin.py b/sd_meh/rebasin.py index 2fbb418..010d67f 100644 --- a/sd_meh/rebasin.py +++ b/sd_meh/rebasin.py @@ -2200,11 +2200,9 @@ def apply_permutation(ps: PermutationSpec, perm, params): def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha): for k in model_a: try: - perm_params = get_permuted_param( - ps, perm, k, model_a - ) + perm_params = get_permuted_param(ps, perm, k, model_a) model_a[k] = model_a[k] * (1 - new_alpha) + new_alpha * perm_params - except RuntimeError: # dealing with pix2pix and inpainting models + except RuntimeError: # dealing with pix2pix and inpainting models continue return model_a diff --git a/sd_meh/utils.py b/sd_meh/utils.py index f507ae8..ee2f2e4 100644 --- a/sd_meh/utils.py +++ b/sd_meh/utils.py @@ -2,7 +2,7 @@ import logging from sd_meh import merge_methods -from sd_meh.merge import NUM_TOTAL_BLOCKS +from sd_meh.merge import NUM_TOTAL_BLOCKS, NUM_TOTAL_BLOCKS_XL from sd_meh.presets import BLOCK_WEIGHTS_PRESETS MERGE_METHODS = dict(inspect.getmembers(merge_methods, inspect.isfunction)) @@ -13,25 +13,25 @@ ] -def compute_weights(weights, base): +def compute_weights(weights, base, sdxl=False): if not weights: - return [base] * NUM_TOTAL_BLOCKS + return [base] * (NUM_TOTAL_BLOCKS_XL if sdxl else NUM_TOTAL_BLOCKS) if "," not in weights: return weights w_alpha = list(map(float, weights.split(","))) - if len(w_alpha) == NUM_TOTAL_BLOCKS: + if len(w_alpha) == (NUM_TOTAL_BLOCKS_XL if sdxl else NUM_TOTAL_BLOCKS): return w_alpha -def assemble_weights_and_bases(preset, weights, base, greek_letter): +def assemble_weights_and_bases(preset, weights, base, greek_letter, sdxl=False): logging.info(f"Assembling {greek_letter} w&b") if preset: logging.info(f"Using {preset} preset") base, *weights = BLOCK_WEIGHTS_PRESETS[preset] bases = {greek_letter: base} - weights = {greek_letter: compute_weights(weights, base)} + weights = {greek_letter: compute_weights(weights, base, sdxl)} logging.info(f"base_{greek_letter}: {bases[greek_letter]}") logging.info(f"{greek_letter} weights: {weights[greek_letter]}") @@ -70,12 +70,14 @@ def weights_and_bases( block_weights_preset_beta_b, presets_alpha_lambda, presets_beta_lambda, + sdxl: bool = False, ): weights, bases = assemble_weights_and_bases( block_weights_preset_alpha, weights_alpha, base_alpha, "alpha", + sdxl, ) if block_weights_preset_alpha_b: @@ -85,6 +87,7 @@ def weights_and_bases( None, None, "alpha", + sdxl, ) weights, bases = interpolate_presets( weights, @@ -101,6 +104,7 @@ def weights_and_bases( weights_beta, base_beta, "beta", + sdxl, ) if block_weights_preset_beta_b: @@ -110,6 +114,7 @@ def weights_and_bases( None, None, "beta", + sdxl, ) weights, bases = interpolate_presets( weights,