Dense instead of sparse matrix returned during differentiation #1537
I suspect the issue is with missing rule(s) in ChainRules, since Zygote has almost no machinery of its own for differentiating sparse arrays. What happens if you set …

No error thrown; the same dense array is returned.
OK, I dug into this a bit more and have a reduced MWE:

```julia
julia> Zygote.unbroadcast(A, collect(A)) |> summary
"10×10 SparseMatrixCSC{Float32, Int64} with 41 stored entries"

julia> Zygote.unbroadcast(Acu, cu(collect(Acu))) |> summary
"10×10 CuArray{Float32, 2, CUDA.DeviceMemory}"
```
Relevant code: Zygote.jl/src/lib/broadcast.jl, line 59 (commit 9b6dd08).
The gist is that the gradient actually computed from the broadcast is dense. However, the CPU path knows how to "project" that dense gradient back onto a sparse one, so the end result is sparse. This is done via the projection machinery in ChainRulesCore.jl, which has projection rules for CPU sparse matrices but not for GPU ones.

So ultimately, I'd say there are two ways to look at this. One is that returning a sparse gradient here isn't saving much: the dense gradient has already been computed, and projecting it back arguably costs extra compute. Re-sparsifying the dense gradients yourself afterwards might work well enough. The other is that this is a missing projection rule in ChainRulesCore, in which case it may be worth a feature request.
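As a rough illustration of what "projecting" a dense gradient back onto a sparse one means (and of the re-sparsify-it-yourself workaround), here is a minimal CPU-only sketch that uses only the SparseArrays standard library. `project_to_pattern` is a hypothetical helper, not ChainRulesCore's actual `ProjectTo` implementation; it assumes the gradient should keep exactly the stored entries of the primal sparse matrix and drop everything else.

```julia
using SparseArrays

# Hypothetical helper (not part of Zygote or ChainRulesCore): keep only the
# entries of a dense gradient `G` that lie on the stored sparsity pattern of `A`.
function project_to_pattern(A::SparseMatrixCSC, G::AbstractMatrix)
    size(A) == size(G) || throw(DimensionMismatch("A and G must have the same size"))
    I, J, _ = findnz(A)                    # row/column indices of A's stored entries
    V = [G[i, j] for (i, j) in zip(I, J)]  # gather gradient values at those positions
    return sparse(I, J, V, size(A)...)     # rebuild a sparse matrix with A's pattern
end
```

This indexes `G` element by element, which is fine for CPU arrays but would hit slow scalar indexing on a `CuArray`; a GPU version would need a vectorized equivalent, which this sketch does not attempt.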
Thanks for digging into this! I may open a feature request on ChainRulesCore then. Cheers!
Hi there,
I found inconsistent behaviour when differentiating a function that takes an `AbstractCuSparseArray` compared with the same function taking an `AbstractSparseArray`. Maybe someone here has an idea of where this could come from?