From 1b460568b97bf53f0dc2ef2e5cb3fcbb2b67cedc Mon Sep 17 00:00:00 2001
From: ljleb
Date: Mon, 18 Dec 2023 02:54:22 -0500
Subject: [PATCH] dont merge sdxl kek

---
 sd_meh/merge_methods.py | 15 ++++++++-------
 sd_meh/utils.py         |  8 +++-----
 2 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py
index 6961f55..96f2895 100644
--- a/sd_meh/merge_methods.py
+++ b/sd_meh/merge_methods.py
@@ -1,6 +1,8 @@
 import functools
 import math
 import operator
+import textwrap
+
 import torch
 from torch import Tensor
 from typing import Tuple
@@ -245,13 +247,14 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs):
     if alpha_is_float:
         # cancel reflection. without this, eigenvalues often have a complex component
         # and then we can't obtain a valid dtype for the merge
-        u[:, -1] *= torch.nan_to_num(1 / (torch.det(u) * torch.det(v_t)))
+        u[:, -1] /= (torch.det(u) * torch.det(v_t))
 
     transform = rotation = u @ v_t
-    print(f"shape: {a.shape} -> {a_neurons.shape} -> {transform.shape}")
-    det = torch.det(transform)
-    if torch.abs(det.abs() - 1) > 1e-6:
-        print("determinant error:", det)
+    if not torch.isfinite(u).all():
+        raise ValueError(textwrap.dedent(f"""determinant error: {torch.det(rotation)}.
+        This can happen when merging on the CPU with the "rotate" method.
+        Consider merging on a cuda device, or try setting alpha to 1 for the problematic blocks.
+        See this related discussion for more info: https://github.com/s1dlx/meh/pull/50#discussion_r1429469484"""))
 
     if alpha_is_float:
         transform = fractional_matrix_power(transform, alpha)
@@ -277,6 +280,4 @@ def fractional_matrix_power(matrix: Tensor, power: float):
     eigenvalues, eigenvectors = torch.linalg.eig(matrix)
     eigenvalues.pow_(power)
     result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors)
-    if ((error := result.imag) > 1e-4).any():
-        print("image error:", error)
     return result.real.to(dtype=matrix.dtype)
diff --git a/sd_meh/utils.py b/sd_meh/utils.py
index f507ae8..27e135c 100644
--- a/sd_meh/utils.py
+++ b/sd_meh/utils.py
@@ -17,12 +17,10 @@ def compute_weights(weights, base):
     if not weights:
         return [base] * NUM_TOTAL_BLOCKS
 
-    if "," not in weights:
-        return weights
-
     w_alpha = list(map(float, weights.split(",")))
-    if len(w_alpha) == NUM_TOTAL_BLOCKS:
-        return w_alpha
+    w_alpha[len(w_alpha):NUM_TOTAL_BLOCKS] = [w_alpha[-1]] * max(0, NUM_TOTAL_BLOCKS - len(w_alpha))
+    w_alpha[NUM_TOTAL_BLOCKS:] = ()
+    return w_alpha
 
 
 def assemble_weights_and_bases(preset, weights, base, greek_letter):