More efficient projection in svd pullback #755
base: main
Conversation
The formula (I - U*U')*X can be extremely slow and memory-intensive when U and X are very tall matrices. This PR replaces it with the mathematically equivalent X - U*(U'*X) in two places.
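The two forms agree because matrix multiplication is associative; the rearranged version just never materializes the huge m-by-m projector. A minimal sketch checking this, written in Python/NumPy for illustration (the PR itself is Julia):

```python
import numpy as np

rng = np.random.default_rng(0)
m, k = 2000, 5  # very tall: m >> k
U = np.linalg.qr(rng.standard_normal((m, k)))[0]  # orthonormal columns
X = rng.standard_normal((m, k))

# Textbook form: builds the m-by-m projector explicitly,
# costing O(m^2 k) time and O(m^2) memory.
slow = (np.eye(m) - U @ U.T) @ X

# Rearranged form: the largest intermediate is m-by-k,
# costing O(m k^2) time and O(m k) memory.
fast = X - U @ (U.T @ X)

assert np.allclose(slow, fast)
```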
@sethaxen are you able to take a look at this?
Indeed, I've not inspected this one closely, but the approach of this PR looks right. I'll review in depth shortly.
Looks good to me, just a suggestion on how to remove another dense matmul.
I like the idea of using the if statements to avoid computing one of the I-xx'
projection matrices if possible. Feel free to add that to this PR.
This uses fewer matrix multiplications. The code no longer uses the helper function _mulsubtrans!! so it has been removed.
As a consequence of the simplifications in my last commits to this PR, the
I have not done any benchmarking, but I feel the removal of unnecessary matrix multiplications should speed things up in general. I've also changed the function signature to take
In that case, I'll leave it as it is now, with
The reason why that package has its own method might be that
Actually, the method from Neurogenesis is still faster, because it avoids the allocation and computation of the
One more potential improvement: The matrix currently written as
This is fine, though I would actually make the signature
Can you run JuliaFormatter on the code?
Yes, I meant to write
The latest commit that I pushed improves performance a bit, and has been passed through JuliaFormatter. Since Julia can't infer that
That's not necessary. Specializing for the
(see ChainRules.jl/src/rulesets/LinearAlgebra/symmetric.jl, lines 153 to 154 in ae37562)
The general-case method also works in this case, but is slightly slower because it creates a dense matrix M even though only the diagonal entries are nonzero.
The latest commit to this PR adds the method for when Ū and V̄t are AbstractZero. I don't think there's anything else left to do at the moment.
This change makes the pullback for svd faster and more memory-efficient when taking the SVD of very tall or very wide matrices (which is a common application). The issue is that the "textbook" formula (I - U*U')*X can be extremely slow and memory-intensive when U is tall. The mathematically equivalent form X - U*(U'*X) avoids creating and then multiplying by the large matrix U*U'.
Example:
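The PR's original example is not reproduced here; the following is a hypothetical sketch of the same kind of comparison, written in Python/NumPy for illustration rather than the actual Julia benchmark (the matrix size and reported figures in the PR refer to the Julia code):

```python
import time
import numpy as np

rng = np.random.default_rng(0)
m, k = 3000, 10  # a very tall input, as in the PR's use case
U = np.linalg.qr(rng.standard_normal((m, k)))[0]
X = rng.standard_normal((m, k))

t0 = time.perf_counter()
slow = (np.eye(m) - U @ U.T) @ X   # allocates two dense m-by-m matrices
t_slow = time.perf_counter() - t0

t0 = time.perf_counter()
fast = X - U @ (U.T @ X)           # largest intermediate is m-by-k
t_fast = time.perf_counter() - t0

assert np.allclose(slow, fast)
print(f"textbook: {t_slow:.4f}s, rearranged: {t_fast:.4f}s")
```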
Without this PR, the above runs in about 2 seconds and allocates 3 GB.
With this PR, it runs in less than 0.01 seconds and allocates 11 MB.
There is further room for improvement: when the input to svd is wide, then I - U*U' is zero. Conversely, when the input is tall, then I - V*V' is zero. This means that it would be possible to avoid some unnecessary computation by adding a couple of if-statements.
(This PR also removes the undocumented, unexported utility function _eyesubx!, which is not used elsewhere in the package.)
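The reason one of the two projectors always vanishes is that the thin SVD makes U (for wide inputs) or V (for tall inputs) a square orthogonal matrix, so U*U' or V*V' equals the identity. A quick check of this, in Python/NumPy for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)

# Wide input (5 x 50): the thin SVD's U is 5-by-5 orthogonal,
# so I - U*U' is exactly zero.
A_wide = rng.standard_normal((5, 50))
U, s, Vt = np.linalg.svd(A_wide, full_matrices=False)
assert np.allclose(np.eye(5) - U @ U.T, 0)

# Tall input (50 x 5): V (the transpose of Vt) is 5-by-5 orthogonal,
# so I - V*V' is exactly zero instead.
A_tall = rng.standard_normal((50, 5))
U2, s2, Vt2 = np.linalg.svd(A_tall, full_matrices=False)
assert np.allclose(np.eye(5) - Vt2.T @ Vt2, 0)
```

A pullback implementation can therefore branch on the input's shape and skip whichever projection term is identically zero.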