feat: neuron rotation #50

Draft: ljleb wants to merge 44 commits into base branch dev.

Changes from 36 commits (44 commits total):
- 18bff99 oft (ljleb, Nov 10, 2023)
- 47bdac9 device (ljleb, Nov 10, 2023)
- 0681458 rename (ljleb, Nov 10, 2023)
- 1fe0882 fix black (ljleb, Nov 10, 2023)
- 2810d89 cayley interpolation for alpha (ljleb, Nov 10, 2023)
- 09bed88 refact (ljleb, Nov 10, 2023)
- f11c054 add method to __all__ (ljleb, Nov 10, 2023)
- 1f497e9 include 1D 'rotation' (ljleb, Nov 10, 2023)
- 36fccaa ignore alpha for now (ljleb, Nov 10, 2023)
- f18208d refact (ljleb, Nov 10, 2023)
- 1dafe83 implement fractional rotations (ljleb, Nov 11, 2023)
- 149ab16 fix transform direction (ljleb, Nov 11, 2023)
- 1f71391 fix eye (ljleb, Nov 11, 2023)
- b464fd3 rewrite with out= (ljleb, Nov 11, 2023)
- e1dc59c it works; opt now (ljleb, Nov 12, 2023)
- cbb6a06 optimize: 45m -> 7m (ljleb, Nov 12, 2023)
- ce62946 rm print (ljleb, Nov 12, 2023)
- 8172927 fix precision issues (ljleb, Nov 12, 2023)
- 19fcc0a fix precision issues (ljleb, Nov 12, 2023)
- f954270 black (ljleb, Nov 12, 2023)
- e94e252 dont change (ljleb, Nov 12, 2023)
- 1f380c8 imps (ljleb, Nov 12, 2023)
- ea95b66 beta is deformation (ljleb, Nov 12, 2023)
- 1751f59 simplify (ljleb, Nov 12, 2023)
- c69bb95 @ (ljleb, Nov 12, 2023)
- 0d5160b backup (ljleb, Nov 13, 2023)
- 1920496 deal with conv attention shape, rotate centroids (ljleb, Nov 13, 2023)
- 5a1c776 black (ljleb, Nov 17, 2023)
- f61d6aa wip (ljleb, Nov 18, 2023)
- 6ac82d6 refact (ljleb, Nov 18, 2023)
- 7fff089 backup (ljleb, Nov 20, 2023)
- 6ddc503 remove approx (ljleb, Nov 21, 2023)
- a6742b3 dont edit (ljleb, Nov 21, 2023)
- d84b776 fix fp16 and fp32 merges (ljleb, Dec 7, 2023)
- d812ea8 reduced svd (ljleb, Dec 8, 2023)
- 38d4db6 black (ljleb, Dec 8, 2023)
- 3c90395 dont ellipsis (ljleb, Dec 17, 2023)
- f506831 print more info for debug (ljleb, Dec 18, 2023)
- 1b46056 dont merge sdxl kek (ljleb, Dec 18, 2023)
- aeb8c99 black (ljleb, Dec 18, 2023)
- de71102 revert utils.py (ljleb, Dec 18, 2023)
- a01e016 cache impl (ljleb, Jan 27, 2024)
- 003017e cache eigen inv (ljleb, Jan 27, 2024)
- 81515bd Update merge.py (ljleb, Jan 29, 2024)
88 changes: 86 additions & 2 deletions sd_meh/merge_methods.py
@@ -1,8 +1,9 @@
import functools
import math
from typing import Tuple

import operator
import torch
from torch import Tensor
from typing import Tuple

__all__ = [
    "weighted_sum",
@@ -17,6 +18,7 @@
    "similarity_add_difference",
    "distribution_crossover",
    "ties_add_difference",
    "rotate",
]


@@ -209,3 +211,85 @@ def filter_top_k(a: Tensor, k: float):
    k_value, _ = torch.kthvalue(torch.abs(a.flatten()).float(), k)
    top_k_filter = (torch.abs(a) >= k_value).float()
    return a * top_k_filter


def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs):
    if alpha == 0 and beta == 0:
        return a

    # scalars, spatial conv kernels and (near-)identical tensors fall back to a weighted sum
    is_conv = len(a.shape) == 4 and a.shape[-1] != 1
    if len(a.shape) == 0 or is_conv or torch.allclose(a.half(), b.half()):
        return weighted_sum(a, b, beta)

    if len(a.shape) == 4:
        shape_2d = (-1, functools.reduce(operator.mul, a.shape[1:]))
    else:
        shape_2d = (-1, a.shape[-1])

    a_neurons = a.reshape(*shape_2d).double()
    b_neurons = b.reshape(*shape_2d).double()

    # alpha interpolates the centroids along an ellipse arc
    a_centroid = a_neurons.mean(0)
    b_centroid = b_neurons.mean(0)
    new_centroid = sample_ellipsis(a_centroid, b_centroid, 2 * torch.pi * alpha)
    if len(a.shape) == 1 or len(a.shape) == 2 and a.shape[0] == 1:
        return new_centroid.reshape_as(a)

    a_neurons -= a_centroid
    b_neurons -= b_centroid

    svd_driver = "gesvd" if a.is_cuda else None
Review comment from @mariaWitch (Dec 7, 2023):

This is actually a lot more complex than meets the eye. We should be determining the SVD driver based on the size of the matrix: different drivers perform faster on smaller or bigger matrices, and in some instances the CPU will outperform the GPU. What exactly is our average matrix size when we call svd?
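A minimal sketch (editor's illustration, not code from this PR) of the kind of size-based driver selection suggested here; the helper name and the cutoff value are assumptions, not benchmarked results:

```python
from typing import Optional

import torch


def pick_svd_driver(m: torch.Tensor, jacobi_max_dim: int = 1024) -> Optional[str]:
    # torch.linalg.svd only honors `driver=` for CUDA inputs, so CPU inputs
    # must stay on the default driver.
    if not m.is_cuda:
        return None
    n = min(m.shape[-2:])
    # Jacobi (gesvdj) tends to win on small/medium matrices; gesvd is the
    # robust choice for large ones. The 1024 cutoff is a placeholder to be
    # replaced by actual benchmarks.
    return "gesvdj" if n <= jacobi_max_dim else "gesvd"


# usage: u, s, v_t = torch.linalg.svd(m, full_matrices=False, driver=pick_svd_driver(m))
```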

Reply from @ljleb (Collaborator, Author) (Dec 7, 2023):

If we include all keys, it goes from $320^2$ to ~$20K^2$. As this upper bound isn't really practical, if we exclude all conv layers (which have the largest neurons), the upper bound is ~$5K^2$. I can list all sizes in a bit; they are all square matrices.

I've never done this before at all, this is all new to me. Appreciate the help. IIUC, this only matters on CUDA devices?

Reply from @ljleb (Collaborator, Author) (Dec 7, 2023):

All matrix sizes that currently go through SVD are listed below:

  • 320x320: 47 keys
  • 640x640: 48 keys
  • 768x768: 94 keys
  • 960x960: 2 keys
  • 1280x1280: 83 keys
  • 2560x2560: 10 keys
  • 3072x3072: 12 keys
  • 5120x5120: 6 keys

Follow-up review comment:

I did some benchmarking between JAX's SVD functions jitted through XLA and PyTorch's different drivers on a Colab using a V100 (a 3080 is about equal to this in PyTorch performance), and these were the results.
[benchmark results image not captured in this excerpt]
Basically, unless you need full accuracy, even with full_matrices set to True, gesvdj is going to be faster. However, the speed you gain comes at the cost of some accuracy, and gesvdj may occasionally fail to converge, requiring a fallback to gesvd.
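A minimal sketch (editor's illustration, not from this PR) of the try-fast-then-fall-back pattern described above; the convergence check used here is an assumed stand-in, not the criterion PyTorch itself reports:

```python
import torch


def svd_with_fallback(m: torch.Tensor):
    # Sketch only: try the faster Jacobi driver first on CUDA, and redo the
    # decomposition with the robust gesvd driver if the result looks degenerate.
    if not m.is_cuda:
        return torch.linalg.svd(m, full_matrices=False)
    u, s, v_t = torch.linalg.svd(m, full_matrices=False, driver="gesvdj")
    if not torch.isfinite(s).all():
        # assumed convergence check; gesvdj can fail on hard inputs
        u, s, v_t = torch.linalg.svd(m, full_matrices=False, driver="gesvd")
    return u, s, v_t
```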


Reply from @ljleb (Collaborator, Author):

By the way, full_matrices=False doesn't produce a reduced SVD when $m = n$ ($m$ and $n$ being the width and height of the SVD input). That's why it didn't seem to affect generation speed. We might want to remove it, as it doesn't really change anything: the input to the SVD is always a square covariance matrix here.
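For illustration (not part of the PR), a quick check of that claim: for a square input, the reduced and full decompositions have identical shapes.

```python
import torch

m = torch.randn(320, 320, dtype=torch.float64)
u_full, s_full, vt_full = torch.linalg.svd(m, full_matrices=True)
u_red, s_red, vt_red = torch.linalg.svd(m, full_matrices=False)

# For a square matrix the "reduced" SVD has the same shapes as the full one.
assert u_full.shape == u_red.shape == (320, 320)
assert vt_full.shape == vt_red.shape == (320, 320)
```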

Follow-up review comment:

So I did complete a full merge on CUDA and didn't receive the error. I think it has something to do with moving models between the CPU and GPU while the WebUI keeps models loaded in memory. Is there a sanity check when the models are loaded to ensure that they have been moved to the CPU if work_device is set to CPU?

Reply from @ljleb (Collaborator, Author) (Dec 12, 2023):

Before merging, when assembling the merge args, the weights are sent to the requested device:

meh/sd_meh/merge.py, lines 465 to 466 in 2780321:

    "a": thetas["model_a"][key].to(work_device),
    "b": thetas["model_b"][key].to(work_device),

Note that if work_device is None, it takes the value of device:

meh/sd_meh/merge.py, lines 371 to 372 in 2780321:

    if work_device is None:
        work_device = device

So IIUC, it shouldn't be a device issue.
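To illustrate the sanity check being asked about, a hypothetical helper (not something sd_meh is shown to provide in this thread) that verifies a state dict is on the expected work device:

```python
import torch


def assert_thetas_on_device(thetas: dict, expected: str) -> None:
    # Hypothetical check: confirm every tensor in a state dict sits on the
    # expected work device before merge methods run.
    expected_type = torch.device(expected).type
    for key, tensor in thetas.items():
        if tensor.device.type != expected_type:
            raise RuntimeError(
                f"{key} is on {tensor.device}, expected a {expected_type} device"
            )
```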

Reply from @ljleb (Collaborator, Author) (Dec 18, 2023):

I think I found the culprit.

It seems that on CPU there sometimes isn't enough precision, which leads to $U$ or $V^T$ having a determinant of 0. This is not what SVD should output: $U$ and $V^T$ should always be orthogonal transforms, which implies $|\det U| = |\det V^T| = 1$.

When the determinant of $U$ or $V^T$ is 0, this line divides by 0:

        u[:, -1] /= torch.det(u) * torch.det(v_t)

So the last column of u is sometimes filled with infinities. Then, when computing the eigenvalues of the matrix, an error is raised.

Reply from @ljleb (Collaborator, Author):


As noted below, while this prevents the entire merge from raising an error, rotations with invalid determinants still result in a broken merge. I went the other direction and raised an error instead.
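A minimal sketch (editor's illustration; the helper name and tolerance are assumptions) of raising on an invalid determinant instead of letting infinities propagate into the eigendecomposition:

```python
import torch


def check_svd_orthogonality(u: torch.Tensor, v_t: torch.Tensor, atol: float = 1e-6) -> None:
    # Orthogonal factors must satisfy |det| == 1; a determinant near 0 signals
    # a degenerate SVD result (e.g. from insufficient precision on CPU).
    for name, factor in (("U", u), ("V^T", v_t)):
        det = torch.det(factor)
        if torch.abs(det.abs() - 1) > atol:
            raise ValueError(
                f"SVD produced a non-orthogonal {name} (det={det.item():.3e}); "
                "consider float64 inputs or the gesvd driver"
            )
```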

    u, _, v_t = torch.linalg.svd(
        a_neurons.T @ b_neurons, full_matrices=False, driver=svd_driver
    )

    alpha_is_float = alpha != round(alpha)
    if alpha_is_float:
        # cancel reflection. without this, eigenvalues often have a complex component
        # and then we can't obtain a valid dtype for the merge
        u[:, -1] /= torch.det(u) * torch.det(v_t)

    transform = rotation = u @ v_t
    print("shape:", transform.shape)
    det = torch.det(transform)
    if torch.abs(det.abs() - 1) > 1e-6:
        print("determinant error:", det)

    if alpha_is_float:
        transform = fractional_matrix_power(transform, alpha)
    elif alpha == 0:
        transform = torch.eye(
            len(transform),
            dtype=transform.dtype,
            device=transform.device,
        )
    elif alpha != 1:
        transform = torch.linalg.matrix_power(transform, round(alpha))

    if beta != 0:
        # interpolate the relationship between the neurons
        a_neurons = weighted_sum(a_neurons, b_neurons @ rotation.T, beta)

    a_neurons @= transform
    a_neurons += new_centroid
    return a_neurons.reshape_as(a).to(a.dtype)


def fractional_matrix_power(matrix: Tensor, power: float):
    eigenvalues, eigenvectors = torch.linalg.eig(matrix)
    eigenvalues.pow_(power)
    result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors)
    if ((error := result.imag) > 1e-4).any():
        print("image error:", error)
    return result.real.to(dtype=matrix.dtype)


def sample_ellipsis(a, b, t):
    return torch.column_stack((a, b)) @ torch.tensor(
        [
            math.sin(t),
            math.cos(t),
        ],
        dtype=a.dtype,
        device=a.device,
    )
6 changes: 2 additions & 4 deletions sd_meh/rebasin.py
@@ -2200,11 +2200,9 @@ def apply_permutation(ps: PermutationSpec, perm, params):
 def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha):
     for k in model_a:
         try:
-            perm_params = get_permuted_param(
-                ps, perm, k, model_a
-            )
+            perm_params = get_permuted_param(ps, perm, k, model_a)
             model_a[k] = model_a[k] * (1 - new_alpha) + new_alpha * perm_params
-        except RuntimeError: # dealing with pix2pix and inpainting models
+        except RuntimeError:  # dealing with pix2pix and inpainting models
             continue
     return model_a
