
Problem with Enzyme AD and SArray parameters #935

Closed
lcontento opened this issue Sep 18, 2024 · 1 comment · Fixed by #936

Comments

@lcontento

Hi! The following code does not differentiate correctly when `N >= 8`:

```julia
using StaticArrays
using Enzyme
using ForwardDiff
using Lux
using ComponentArrays

N = 8
d = Lux.Dense(N => N)
ps = (;
  weight = randn(SMatrix{N, N, Float64}),
  bias = randn(SVector{N, Float64}),
)
x = randn(SVector{N, Float64})
fun = let d = d, x = x
  ps -> sum(d(x, ps, (;))[1])
end
grad1 = ForwardDiff.gradient(fun, ComponentVector(ps))
grad2 = Enzyme.gradient(Enzyme.Reverse, fun, ps)
# Should be ~0 if both backends agree, but is large for N >= 8:
maximum(abs, grad1 .- ComponentVector(grad2))
```

I have posted an issue for Enzyme (EnzymeAD/Enzyme.jl#1855) with a Lux-free MWE, and the problem seems to be linked to using `reshape` together with an `SVector` (if I convert `x` to an `SMatrix` before feeding it to `Dense`, the problem does not appear). I wonder if a fix similar to LuxDL/LuxLib.jl#141 could be implemented in Lux as well? At first sight, it looks like adding a single method specializing `make_abstract_matrix` on `SVector`s would do the trick.
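For reference, the idea borrowed from LuxDL/LuxLib.jl#141 is to bypass `reshape` for static vectors. A minimal sketch of such a specialization might look like the following; note that the generic fallback shown in the comment and the exact signature of `make_abstract_matrix` in Lux's internals are assumptions here, not the actual patch:

```julia
using StaticArrays

# Assumed generic fallback in Lux's internals (illustrative only):
#   make_abstract_matrix(x::AbstractVector) = reshape(x, :, 1)
#
# Hypothetical specialization: build an L×1 SMatrix directly from the
# SVector, avoiding the `reshape` call that Enzyme mishandles.
make_abstract_matrix(x::SVector{L, T}) where {L, T} = SMatrix{L, 1, T, L}(x)
```

With this method, a static vector input would reach `Dense` as a static matrix, which is the same workaround as manually converting `x` to an `SMatrix` before the call.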

@avik-pal
Member

Yes, that would do it. I can patch it today.
