
feat: neuron rotation #50

Draft · ljleb wants to merge 44 commits into base: dev
Conversation

ljleb
Collaborator

@ljleb ljleb commented Nov 10, 2023

For each key:

  1. extract an orthogonal matrix, where model $A$ is the base and model $B$ is the target orientation
  2. apply the orthogonal transform immediately

Note: this is pretty slow. On my RTX 3080 it takes me ~3 minutes to merge two models.
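Step 1 above is an orthogonal Procrustes problem, which is solved with an SVD of the cross-covariance. A minimal numpy sketch (the PR itself uses torch; the function name is illustrative):

```python
import numpy as np

def extract_rotation(a, b):
    """Orthogonal Procrustes: find orthogonal Q minimizing ||Q @ a - b||_F.

    a, b: (n, m) weight matrices with n-dimensional neurons as columns.
    The minimizer is Q = U @ Vt, where U, S, Vt = svd(b @ a.T).
    """
    u, _, v_t = np.linalg.svd(b @ a.T)
    return u @ v_t
```

For example, if `b` is an exactly rotated copy of `a`, the extracted `Q` recovers that rotation.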

@ljleb ljleb changed the title OFT extract + apply feat: OFT extract + apply Nov 10, 2023
@ljleb ljleb changed the title feat: OFT extract + apply feat: OFT extract + apply, aka rotate Nov 10, 2023
@ljleb
Collaborator Author

ljleb commented Nov 10, 2023

A good idea would be to change the rotation rate of the orthogonal matrix $Q$ using $Q^{\alpha}$, i.e.:
$\alpha = 0$ : $I$
$\alpha = 1$ : $Q$
$\alpha = 2$ : $Q^2$
$\alpha = 0.5$ : $Q^{0.5}$

I have not tested this, but I believe we will run into the problem that it will be even slower. According to GPT, we have to compute a matrix log and then a matrix power. I'll run some tests to see how well that works.

@ljleb
Collaborator Author

ljleb commented Nov 10, 2023

GPT-4 found something called the Cayley transform, which seems to do what we want.

2D with an arbitrary matrix $A$ and a rotation power $t$:
[attached image]

3D with an arbitrary matrix $A$ and a rotation power $t$:
[attached image]

link to discussion with GPT4: https://chat.openai.com/share/96a9b2ae-3a5f-47ce-8b22-bb07e5f6d1a9
reference: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
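A minimal numpy sketch of the Cayley map (hedged: scaling the skew-symmetric generator by $t$ gives a smooth orthogonal path with $t = 0 \mapsto I$, which is not in general identical to $Q^t$ of the $t = 1$ matrix):

```python
import numpy as np

def cayley_rotation(a, t):
    # Cayley map: for a skew-symmetric matrix S, (I + S) @ inv(I - S) is
    # orthogonal; I - S is always invertible since S has purely imaginary
    # eigenvalues. Scaling S by t interpolates from the identity (t = 0).
    n = a.shape[0]
    s = t * (a - a.T) / 2.0  # skew-symmetric part of an arbitrary input
    i = np.eye(n)
    return (i + s) @ np.linalg.inv(i - s)
```

The result is orthogonal for every `t`, so no re-orthogonalization step is needed along the path.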

@ljleb
Collaborator Author

ljleb commented Nov 10, 2023

> I have not tested this, but I believe we will run into the problem that it will be even slower. According to GPT, we have to compute a matrix log and then a matrix power. I'll run some tests to see how well that works.

It turns out that it is not that much slower. I'm hitting ~6 minutes when merging using --device cuda. It goes down to ~3 minutes if we exclude the cases where a layer has 1 dimension.

@ljleb
Collaborator Author

ljleb commented Nov 10, 2023

I ran more tests and Cayley seems to break down in the 1D case. I need to spend more time on alpha to make it work.

@ljleb ljleb changed the title feat: OFT extract + apply, aka rotate feat: neuron rotation Nov 12, 2023
@ljleb
Collaborator Author

ljleb commented Nov 12, 2023

I found that you can apply a fractional power to the eigenvalues of a matrix to implement a fractional matrix exponent. This is fairly slow and requires double precision to work, otherwise the output gets an imaginary component because of precision errors.

With a non-integer alpha, it takes ~15 minutes to merge with --work_device cuda and ~12 minutes with --device cuda; with an integer alpha it is still ~5 minutes.
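The fractional power via eigenvalues can be sketched as follows (numpy for illustration; the double-precision cast reflects the precision issue described above):

```python
import numpy as np

def fractional_matrix_power(q, alpha):
    # Q**alpha via eigendecomposition: Q = V diag(w) V^-1, so
    # Q**alpha = V diag(w**alpha) V^-1. An orthogonal Q has complex
    # eigenvalues on the unit circle; double precision keeps the residual
    # imaginary component of the recombined result near zero.
    vals, vecs = np.linalg.eig(q.astype(np.float64))
    powered = vecs @ np.diag(vals ** alpha) @ np.linalg.inv(vecs)
    return powered.real
```

For instance, the half power of a 90° rotation is a 45° rotation, so applying it twice reproduces the original matrix.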

@ljleb ljleb added the 🔥 New feature or request label Nov 17, 2023
@ljleb
Collaborator Author

ljleb commented Nov 17, 2023

Notes on a couple of trade-offs I had to look into:

  • $\alpha$ is used for two different purposes:
    1. rotating two sets of weights about their centroids in an $n$-dimensional space with $\frac{n(n-1)}{2}$ rotation planes
    2. rotating the centroid of $A$ towards the centroid of $B$ with respect to the origin of the vector space, along an ellipse
    • this implies that $\alpha \equiv 0 \pmod 4$ positions the weights about the centroid of $A$. The purpose is a smooth transition for $0 \leq \alpha \leq 1$; as a result, the range $4k + 1 < \alpha < 4k + 4$ doesn't make much sense. To fix this, the method would need a third parameter $\gamma$ to control these two settings separately
  • $\beta$ is used to morph the neurons of $A$ into the shape of $B$, independently of the rotation of the neurons and of the position of their centroid: 0 = shape of $A$, 1 = shape of $B$
  • the neurons of some of the conv layers have > 20k floats. To solve the Procrustes problem for these cases, the algorithm has to compute the SVD of a 20k x 20k matrix, which is not practical. For this reason, I excluded the conv layers from the keys to be rotated. In the merge I tested this against, fully excluding the conv layers gave more aesthetic results than splitting the conv layer neurons into smaller chunks and running the SVD on those.
    • As a result, merging 1.5 models now only takes ~3 minutes on an RTX 3080
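The chunking alternative mentioned in the last bullet could be sketched like this (hypothetical helper, not what the PR ships, since it excludes conv layers instead):

```python
import numpy as np

def split_neurons(w, max_dim=4096):
    # Hypothetical chunking: split an (n, m) weight along the neuron axis
    # into pieces of at most max_dim columns, so each Procrustes SVD runs
    # on a covariance of size at most max_dim x max_dim instead of m x m.
    return [w[:, i:i + max_dim] for i in range(0, w.shape[1], max_dim)]
```

Concatenating the pieces back along the same axis recovers the original weight, so the split is lossless; the quality trade-off comes from rotating each chunk independently.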

@ljleb ljleb requested a review from s1dlx November 17, 2023 01:53
@ljleb ljleb marked this pull request as draft November 17, 2023 18:23
@ljleb
Collaborator Author

ljleb commented Nov 17, 2023

The models I tested against were not completely different; in particular, the text encoder was the same. This skewed my small benchmarks for expected merge times. It seems to take 9 minutes to merge two v1 models with all keys different. Turning this into a draft until we get the merge time lower or determine that the merge method is valuable enough to outweigh the 9 minutes.

if len(a.shape) == 1 or len(a.shape) == 2 and a.shape[0] == 1:
    return new_centroid.reshape_as(a)

svd_driver = "gesvd" if a.is_cuda else None
@mariaWitch mariaWitch Dec 7, 2023

This is actually a lot more complex than meets the eye. We should be determining the SVD driver based on the size of the matrix: different drivers perform faster on smaller/bigger matrices, and in some instances the CPU will outperform the GPU. What exactly is our average matrix size when we call svd?

Collaborator Author

@ljleb ljleb Dec 7, 2023

If we include all keys, it goes from $320^2$ to ~ $20K^2$. As this upper bound isn't really practical, if we exclude all conv layers (which have the largest neurons), the upper bound is ~ $5K^2$. I can list all sizes in a bit; they are all square matrices.

I've never done this before at all, this is all new to me. Appreciate the help. IIUC, this only matters on cuda devices?

Collaborator Author

@ljleb ljleb Dec 7, 2023

All matrix sizes that currently go through SVD are listed below:

  • 320x320: 47 keys
  • 640x640: 48 keys
  • 768x768: 94 keys
  • 960x960: 2 keys
  • 1280x1280: 83 keys
  • 2560x2560: 10 keys
  • 3072x3072: 12 keys
  • 5120x5120: 6 keys


I did some benchmarking between jax's svd functions jitted through XLA and pytorch's different drivers on a Colab using a V100 (a 3080 is about equal to this in PyTorch performance), and these were the results.
[attached image]
Basically, unless you need full accuracy, even with full_matrices set to True, gesvdj is going to be faster. However, the speed comes at the cost of some accuracy, and gesvdj may not always converge, requiring a fallback to gesvd.
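The convergence caveat suggests a fallback pattern like the following (pure-Python sketch; `fast_svd` and `slow_svd` stand in for `torch.linalg.svd` calls with `driver="gesvdj"` and `driver="gesvd"` respectively):

```python
def svd_with_fallback(matrix, fast_svd, slow_svd):
    # Try the faster Jacobi driver first; if it fails to converge (torch
    # raises RuntimeError), retry with the slower, more robust QR-based
    # driver on the same input.
    try:
        return fast_svd(matrix)
    except RuntimeError:
        return slow_svd(matrix)
```

This keeps the common case fast while guaranteeing every key still gets a valid decomposition.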


Collaborator Author

By the way, full_matrices=False doesn't produce a reduced SVD when $m = n$ ($m$ and $n$ being the width and height of the SVD input). That's why it didn't seem to affect generation speed. We might want to remove it, as it doesn't change anything: the input to the SVD is always a square covariance matrix here.
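This is easy to check directly (numpy shown for illustration; `torch.linalg.svd` behaves the same way for this flag on square inputs):

```python
import numpy as np

# For a square input, the "reduced" factors already have the full shapes,
# so full_matrices=False is a no-op.
a = np.random.default_rng(0).standard_normal((4, 4))
u_full, s_full, vt_full = np.linalg.svd(a, full_matrices=True)
u_red, s_red, vt_red = np.linalg.svd(a, full_matrices=False)
assert u_full.shape == u_red.shape == (4, 4)
assert vt_full.shape == vt_red.shape == (4, 4)
assert np.allclose(s_full, s_red)
```

The flag only trims factors when the input is rectangular, i.e. $m \neq n$.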


So I did complete a full merge on CUDA and didn't receive the error. I think it has something to do with moving models between the CPU and GPU while the WebUI keeps models loaded in memory. Is there sanity checking when the models are loaded, to ensure they have been moved to the CPU when work_device is set to cpu?

Collaborator Author

@ljleb ljleb Dec 12, 2023

Before merging, when assembling the merge args, the weights are sent to the requested device:

meh/sd_meh/merge.py

Lines 465 to 466 in 2780321

"a": thetas["model_a"][key].to(work_device),
"b": thetas["model_b"][key].to(work_device),

note that if work_device is None, it takes the value of device:

meh/sd_meh/merge.py

Lines 371 to 372 in 2780321

if work_device is None:
work_device = device

So IIUC, it shouldn't be a device issue.

Collaborator Author

@ljleb ljleb Dec 18, 2023

I think I found the culprit.

It seems that on CPU there sometimes isn't enough precision, which leads to $U$ or $V^T$ having a determinant of 0. This is not what SVD should output: $U$ and $V^T$ should always be orthogonal transforms, which implies $|\det U| = |\det V^T| = 1$.

When the determinant of $U$ or $V^T$ is 0, then this line divides by 0:

        u[:, -1] /= torch.det(u) * torch.det(v_t)

So the last column of u sometimes gets filled with infinities. Then, when trying to compute the eigenvalues of the matrix, an error is raised.
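A guard for this failure mode might look like this (numpy sketch with a hypothetical helper name; the real code operates on torch tensors):

```python
import numpy as np

def fix_reflection(u, v_t, eps=1e-6):
    # The sign fix divides u's last column by det(u) * det(v_t), which is
    # +/-1 for exact orthogonal factors. A numerically-zero determinant
    # means the SVD output is degenerate, so raise instead of emitting infs
    # that would later crash the eigendecomposition.
    det = np.linalg.det(u) * np.linalg.det(v_t)
    if abs(det) < eps:
        raise ValueError("SVD factor has near-zero determinant")
    u = u.copy()
    u[:, -1] /= det
    return u
```

When the determinant product is $-1$ this flips the last column, turning an improper orthogonal transform into a proper rotation.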

Collaborator Author

As noted below, while this prevents the entire merge from raising an error, rotations with invalid determinants still result in a broken merge. I went the other direction and raised an error instead.

@Enferlain

Enferlain commented Dec 17, 2023

Wanted to try with the bayesian merger extension. Added the changed parts to merge_methods.py

At the end of stage 1 got this error

stage 1:  96%|████████████████████████████████████████████████████████████████▏  | 1608/1680 [1:03:11<02:49,  2.36s/it]
*** API error: POST: http://127.0.0.1:7860/bbwm/merge-models {'error': 'RuntimeError', 'detail': '', 'body': '', 'errors': 'torch.linalg.eig: input tensor should not contain infs or NaNs.'}
    Traceback (most recent call last):
      File "D:\stable-diffusion-webui\venv\lib\site-packages\anyio\streams\memory.py", line 98, in receive
        return self.receive_nowait()
      File "D:\stable-diffusion-webui\venv\lib\site-packages\anyio\streams\memory.py", line 93, in receive_nowait
        raise WouldBlock
    anyio.WouldBlock

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "D:\stable-diffusion-webui\venv\lib\site-packages\starlette\middleware\base.py", line 78, in call_next
        message = await recv_stream.receive()
      File "D:\stable-diffusion-webui\venv\lib\site-packages\anyio\streams\memory.py", line 118, in receive
        raise EndOfStream
    anyio.EndOfStream

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "D:\stable-diffusion-webui\modules\api\api.py", line 186, in exception_handling
        return await call_next(request)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\starlette\middleware\base.py", line 84, in call_next
        raise app_exc
      File "D:\stable-diffusion-webui\venv\lib\site-packages\starlette\middleware\base.py", line 70, in coro
        await self.app(scope, receive_or_disconnect, send_no_error)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\starlette\middleware\base.py", line 108, in __call__
        response = await self.dispatch_func(request, call_next)
      File "D:\stable-diffusion-webui\modules\api\api.py", line 150, in log_and_time
        res: Response = await call_next(req)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\starlette\middleware\base.py", line 84, in call_next
        raise app_exc
      File "D:\stable-diffusion-webui\venv\lib\site-packages\starlette\middleware\base.py", line 70, in coro
        await self.app(scope, receive_or_disconnect, send_no_error)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\starlette\middleware\cors.py", line 84, in __call__
        await self.app(scope, receive, send)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\starlette\middleware\gzip.py", line 24, in __call__
        await responder(scope, receive, send)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\starlette\middleware\gzip.py", line 44, in __call__
        await self.app(scope, receive, self.send_with_gzip)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\starlette\middleware\exceptions.py", line 79, in __call__
        raise exc
      File "D:\stable-diffusion-webui\venv\lib\site-packages\starlette\middleware\exceptions.py", line 68, in __call__
        await self.app(scope, receive, sender)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\fastapi\middleware\asyncexitstack.py", line 21, in __call__
        raise e
      File "D:\stable-diffusion-webui\venv\lib\site-packages\fastapi\middleware\asyncexitstack.py", line 18, in __call__
        await self.app(scope, receive, send)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\starlette\routing.py", line 718, in __call__
        await route.handle(scope, receive, send)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\starlette\routing.py", line 276, in handle
        await self.app(scope, receive, send)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\starlette\routing.py", line 66, in app
        response = await func(request)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\fastapi\routing.py", line 237, in app
        raw_response = await run_endpoint_function(
      File "D:\stable-diffusion-webui\venv\lib\site-packages\fastapi\routing.py", line 163, in run_endpoint_function
        return await dependant.call(**values)
      File "D:\stable-diffusion-webui\extensions\sd-webui-bayesian-merger\scripts\api.py", line 78, in merge_models_api
        merged = merge_models(
      File "D:\stable-diffusion-webui\venv\lib\site-packages\sd_meh\merge.py", line 176, in merge_models
        merged = simple_merge(
      File "D:\stable-diffusion-webui\venv\lib\site-packages\sd_meh\merge.py", line 262, in simple_merge
        res.result()
      File "C:\Users\Imi\AppData\Local\Programs\Python\Python310\lib\concurrent\futures\_base.py", line 451, in result
        return self.__get_result()
      File "C:\Users\Imi\AppData\Local\Programs\Python\Python310\lib\concurrent\futures\_base.py", line 403, in __get_result
        raise self._exception
      File "C:\Users\Imi\AppData\Local\Programs\Python\Python310\lib\concurrent\futures\thread.py", line 58, in run
        result = self.fn(*self.args, **self.kwargs)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\sd_meh\merge.py", line 371, in simple_merge_key
        with merge_key_context(key, thetas, *args, **kwargs) as result:
      File "C:\Users\Imi\AppData\Local\Programs\Python\Python310\lib\contextlib.py", line 135, in __enter__
        return next(self.gen)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\sd_meh\merge.py", line 475, in merge_key_context
        result = merge_key(*args, **kwargs)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\sd_meh\merge.py", line 447, in merge_key
        merged_key = merge_method(**merge_args).to(device)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\sd_meh\merge_methods.py", line 259, in rotate
        transform = fractional_matrix_power(transform, alpha)
      File "D:\stable-diffusion-webui\venv\lib\site-packages\sd_meh\merge_methods.py", line 279, in fractional_matrix_power
        eigenvalues, eigenvectors = torch.linalg.eig(matrix)
    RuntimeError: torch.linalg.eig: input tensor should not contain infs or NaNs.

(The fix is to use cuda instead of cpu as the device.)

@ljleb
Collaborator Author

ljleb commented Dec 18, 2023

See the discussion here #50 (comment). This can happen when merging on the CPU with fractional alpha.

Labels
🔥 New feature or request
4 participants