dont merge sdxl kek
ljleb committed Dec 18, 2023
1 parent f506831 · commit 1b46056
Showing 2 changed files with 11 additions and 12 deletions.
sd_meh/merge_methods.py (15 changes: 8 additions & 7 deletions)
@@ -1,6 +1,8 @@
 import functools
 import math
 import operator
+import textwrap
+
 import torch
 from torch import Tensor
 from typing import Tuple
@@ -245,13 +247,14 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs):
     if alpha_is_float:
         # cancel reflection. without this, eigenvalues often have a complex component
         # and then we can't obtain a valid dtype for the merge
-        u[:, -1] *= torch.nan_to_num(1 / (torch.det(u) * torch.det(v_t)))
+        u[:, -1] /= (torch.det(u) * torch.det(v_t))

     transform = rotation = u @ v_t
-    print(f"shape: {a.shape} -> {a_neurons.shape} -> {transform.shape}")
-    det = torch.det(transform)
-    if torch.abs(det.abs() - 1) > 1e-6:
-        print("determinant error:", det)
+    if not torch.isfinite(u).all():
+        raise ValueError(textwrap.dedent(f"""determinant error: {torch.det(rotation)}.
+            This can happen when merging on the CPU with the "rotate" method.
+            Consider merging on a cuda device, or try setting alpha to 1 for the problematic blocks.
+            See this related discussion for more info: https://github.com/s1dlx/meh/pull/50#discussion_r1429469484"""))

     if alpha_is_float:
         transform = fractional_matrix_power(transform, alpha)
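
Side note: the changed line above is a reflection-cancellation step on the SVD factors. A minimal standalone sketch of the idea, with an assumed 8x8 setup (the real rotate() builds u and v_t from the models' neuron matrices):

import torch

# stand-ins for the neuron matrices that rotate() aligns (assumed shapes)
a_neurons = torch.randn(8, 8, dtype=torch.float64)
b_neurons = torch.randn(8, 8, dtype=torch.float64)

u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons)

# det(u) * det(v_t) is +/-1 for orthogonal factors; -1 means u @ v_t
# contains a reflection. Dividing the last column of u by the product
# flips one direction and forces det(u @ v_t) to +1, a pure rotation.
# Unlike the old nan_to_num version, a degenerate (non-finite) result is
# left visible, so a check like the isfinite one above can raise.
u[:, -1] /= torch.det(u) * torch.det(v_t)

rotation = u @ v_t
assert torch.allclose(torch.det(rotation), torch.tensor(1.0, dtype=torch.float64))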
@@ -277,6 +280,4 @@ def fractional_matrix_power(matrix: Tensor, power: float):
     eigenvalues, eigenvectors = torch.linalg.eig(matrix)
     eigenvalues.pow_(power)
     result = eigenvectors @ torch.diag(eigenvectors.new_ones(0).numel() and eigenvalues or eigenvalues) @ torch.linalg.inv(eigenvectors)
-    if ((error := result.imag) > 1e-4).any():
-        print("image error:", error)
     return result.real.to(dtype=matrix.dtype)
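
For reference, fractional_matrix_power (unchanged here apart from the deleted warning) raises a matrix to a real power through its eigendecomposition. A hedged usage sketch, assuming a 2x2 rotation whose fractional power should rotate by the corresponding fraction of the angle:

import math
import torch

def fractional_matrix_power(matrix, power):
    # eigendecompose, raise the (possibly complex) eigenvalues to the
    # power, recompose, then drop the imaginary residue
    eigenvalues, eigenvectors = torch.linalg.eig(matrix)
    eigenvalues.pow_(power)
    result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors)
    return result.real.to(dtype=matrix.dtype)

theta = math.pi / 2  # a 90-degree rotation
rot = torch.tensor([[math.cos(theta), -math.sin(theta)],
                    [math.sin(theta),  math.cos(theta)]])
half = fractional_matrix_power(rot, 0.5)
# half should be a 45-degree rotation: applying it twice recovers rot
assert torch.allclose(half @ half, rot, atol=1e-6)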
sd_meh/utils.py (8 changes: 3 additions & 5 deletions)
@@ -17,12 +17,10 @@ def compute_weights(weights, base):
     if not weights:
         return [base] * NUM_TOTAL_BLOCKS

-    if "," not in weights:
-        return weights
-
     w_alpha = list(map(float, weights.split(",")))
-    if len(w_alpha) == NUM_TOTAL_BLOCKS:
-        return w_alpha
+    w_alpha[len(w_alpha):NUM_TOTAL_BLOCKS] = [w_alpha[-1]] * max(0, NUM_TOTAL_BLOCKS - len(w_alpha))
+    w_alpha[NUM_TOTAL_BLOCKS:] = ()
+    return w_alpha


 def assemble_weights_and_bases(preset, weights, base, greek_letter):
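
The net effect in compute_weights: a comma-separated list shorter than NUM_TOTAL_BLOCKS is now padded by repeating its last value, and a longer one is truncated, instead of short inputs being returned unparsed. A small sketch, assuming NUM_TOTAL_BLOCKS = 25 as a stand-in for the constant defined elsewhere in sd_meh:

NUM_TOTAL_BLOCKS = 25  # assumed here; sd_meh defines the real value

def compute_weights(weights, base):
    if not weights:
        return [base] * NUM_TOTAL_BLOCKS

    w_alpha = list(map(float, weights.split(",")))
    # pad a short list by repeating its last entry...
    w_alpha[len(w_alpha):NUM_TOTAL_BLOCKS] = [w_alpha[-1]] * max(0, NUM_TOTAL_BLOCKS - len(w_alpha))
    # ...and drop any excess beyond NUM_TOTAL_BLOCKS
    w_alpha[NUM_TOTAL_BLOCKS:] = ()
    return w_alpha

print(compute_weights("0.1,0.5", 0.0))
# [0.1, 0.5, 0.5, ..., 0.5]  (25 entries; the last value fills the tail)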

0 comments on commit 1b46056
