Hi! The following code does not differentiate correctly when N >= 8
```julia
using StaticArrays
using Enzyme
using ForwardDiff
using Lux
using ComponentArrays

N = 8
d = Lux.Dense(N => N)
ps = (;
    weight = randn(SMatrix{N, N, Float64}),
    bias = randn(SVector{N, Float64}),
)
x = randn(SVector{N, Float64})
fun = let d = d, x = x
    ps -> sum(d(x, ps, (;))[1])
end
grad1 = ForwardDiff.gradient(fun, ComponentVector(ps))
grad2 = Enzyme.gradient(Enzyme.Reverse, fun, ps)
maximum(abs, grad1 .- ComponentVector(grad2))
```
I have posted an issue for Enzyme (EnzymeAD/Enzyme.jl#1855) with a Lux-free MWE, and the problem seems to be linked to using `reshape` together with an `SVector` (if I convert `x` to an `SMatrix` before feeding it to `Dense`, the problem does not appear). I wonder if a fix similar to LuxDL/LuxLib.jl#141 could be implemented in Lux as well? At first sight it looks to me that just adding one line specializing `make_abstract_matrix` on `SVector`s would do the trick.
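For illustration, the one-line specialization might look something like the sketch below. This is only a guess at the fix: the exact module, name, and signature of `make_abstract_matrix` inside Lux are assumptions here, and the idea is simply to materialize an `SVector` as an `N×1` `SMatrix` eagerly rather than going through `Base.reshape`, whose `ReshapedArray` wrapper appears to be what trips Enzyme.

```julia
using StaticArrays

# Hypothetical sketch (assumed signature): convert an SVector into an
# N×1 SMatrix directly, avoiding the lazy ReshapedArray wrapper that
# reshape would produce and that Enzyme seems to mishandle for N >= 8.
make_abstract_matrix(x::SVector{N, T}) where {N, T} = SMatrix{N, 1, T}(x)
```

Since `SMatrix{N, 1, T}(x)` is a plain bitwise copy of the static vector's data, this would keep the fast static-array code path while sidestepping the `reshape`-related differentiation bug.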