diff --git a/sd_meh/merge.py b/sd_meh/merge.py index 760391e..38d3989 100644 --- a/sd_meh/merge.py +++ b/sd_meh/merge.py @@ -141,6 +141,7 @@ def merge_models( work_device: Optional[str] = None, prune: bool = False, threads: int = 1, + cache: Optional[Dict] = None, ) -> Dict: thetas = load_thetas(models, prune, device, precision) @@ -169,6 +170,7 @@ def merge_models( device=device, work_device=work_device, threads=threads, + cache=cache, ) return un_prune_model(merged, thetas, models, device, prune, precision) @@ -221,6 +223,7 @@ def simple_merge( device: str = "cpu", work_device: Optional[str] = None, threads: int = 1, + cache: Optional[Dict] = None, ) -> Dict: futures = [] with tqdm(thetas["model_a"].keys(), desc="stage 1") as progress: @@ -238,6 +241,7 @@ def simple_merge( weights_clip, device, work_device, + cache, ) futures.append(future) @@ -367,6 +371,7 @@ def merge_key( weights_clip: bool = False, device: str = "cpu", work_device: Optional[str] = None, + cache: Optional[Dict] = None, ) -> Optional[Tuple[str, Dict]]: if work_device is None: work_device = device @@ -410,7 +415,7 @@ def merge_key( except AttributeError as e: raise ValueError(f"{merge_mode} not implemented, aborting merge!") from e - merge_args = get_merge_method_args(current_bases, thetas, key, work_device) + merge_args = get_merge_method_args(current_bases, thetas, key, work_device, cache) # dealing wiht pix2pix and inpainting models if (a_size := merge_args["a"].size()) != (b_size := merge_args["b"].size()): @@ -460,11 +465,16 @@ def get_merge_method_args( thetas: Dict, key: str, work_device: str, + cache: Optional[Dict], ) -> Dict: + if cache is not None and key not in cache: + cache[key] = {} + merge_method_args = { "a": thetas["model_a"][key].to(work_device), "b": thetas["model_b"][key].to(work_device), **current_bases, + "cache": cache[key] if cache is not None else None, } if "model_c" in thetas: diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index c10c459..32cab22 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -1,8 +1,11 @@ +import functools import math -from typing import Tuple +import operator +import textwrap import torch from torch import Tensor +from typing import Tuple __all__ = [ "weighted_sum", @@ -17,6 +20,7 @@ "similarity_add_difference", "distribution_crossover", "ties_add_difference", + "rotate", ] @@ -209,3 +213,93 @@ def filter_top_k(a: Tensor, k: float): k_value, _ = torch.kthvalue(torch.abs(a.flatten()).float(), k) top_k_filter = (torch.abs(a) >= k_value).float() return a * top_k_filter + + +def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): + if alpha == 0 and beta == 0: + return a + + is_conv = len(a.shape) == 4 and a.shape[-1] != 1 + if len(a.shape) == 0 or is_conv or torch.allclose(a.half(), b.half()): + return weighted_sum(a, b, beta) + + if len(a.shape) == 4: + shape_2d = (-1, functools.reduce(operator.mul, a.shape[1:])) + else: + shape_2d = (-1, a.shape[-1]) + + a_neurons = a.reshape(*shape_2d).double() + b_neurons = b.reshape(*shape_2d).double() + + a_centroid = a_neurons.mean(0) + b_centroid = b_neurons.mean(0) + new_centroid = weighted_sum(a_centroid, b_centroid, alpha) + if len(a.shape) == 1 or len(a.shape) == 2 and a.shape[0] == 1: + return new_centroid.reshape_as(a) + + a_neurons -= a_centroid + b_neurons -= b_centroid + + alpha_is_float = alpha != round(alpha) + + if kwargs["cache"] is not None and "rotation" in kwargs["cache"]: + rotation = transform = kwargs["cache"]["rotation"].to(a.device) + else: + svd_driver = "gesvd" if a.is_cuda else None + u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, driver=svd_driver) + + if alpha_is_float: + # cancel reflection. without this, eigenvalues often have a complex component + # and then we can't obtain a valid dtype for the merge + u[:, -1] /= torch.det(u) * torch.det(v_t) + + rotation = transform = u @ v_t + if not torch.isfinite(u).all(): + raise ValueError( + textwrap.dedent( + f"""determinant error: {torch.det(rotation)}. + This can happen when merging on the CPU with the "rotate" method. + Consider merging on a cuda device, or try setting alpha to 1 for the problematic blocks. + See this related discussion for more info: https://github.com/s1dlx/meh/pull/50#discussion_r1429469484""" + ) + ) + + if kwargs["cache"] is not None: + kwargs["cache"]["rotation"] = rotation.cpu() + + if alpha_is_float: + transform = fractional_matrix_power(transform, alpha, kwargs["cache"]) + elif alpha == 0: + transform = torch.eye( + len(transform), + dtype=transform.dtype, + device=transform.device, + ) + elif alpha != 1: + transform = torch.linalg.matrix_power(transform, round(alpha)) + + if beta != 0: + # interpolate the relationship between the neurons + a_neurons = weighted_sum(a_neurons, b_neurons @ rotation.T, beta) + + a_neurons @= transform + a_neurons += new_centroid + return a_neurons.reshape_as(a).to(a.dtype) + + +def fractional_matrix_power(matrix: Tensor, power: float, cache: dict): + if cache is not None and "eigenvalues" in cache: + eigenvalues = cache["eigenvalues"].to(matrix.device) + eigenvectors = cache["eigenvectors"].to(matrix.device) + eigenvectors_inv = cache["eigenvectors_inv"].to(matrix.device) + else: + eigenvalues, eigenvectors = torch.linalg.eig(matrix) + eigenvectors_inv = torch.linalg.inv(eigenvectors) + if cache is not None: + cache["eigenvalues"] = eigenvalues.cpu() + cache["eigenvectors"] = eigenvectors.cpu() + cache["eigenvectors_inv"] = eigenvectors_inv.cpu() + + eigenvalues.pow_(power) + result = eigenvectors @ torch.diag(eigenvalues) @ eigenvectors_inv + return result.real.to(dtype=matrix.dtype) diff --git a/sd_meh/rebasin.py b/sd_meh/rebasin.py index 2fbb418..010d67f 100644 --- a/sd_meh/rebasin.py +++ b/sd_meh/rebasin.py @@ -2200,11 +2200,9 @@ def apply_permutation(ps: PermutationSpec, perm, params): def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha): for k in model_a: try: - perm_params = get_permuted_param( - ps, perm, k, model_a - ) + perm_params = get_permuted_param(ps, perm, k, model_a) model_a[k] = model_a[k] * (1 - new_alpha) + new_alpha * perm_params - except RuntimeError: # dealing with pix2pix and inpainting models + except RuntimeError: # dealing with pix2pix and inpainting models continue return model_a