Something wrong with empty (Named-)Tuples and generators #1294
Some background:
```julia
function msolve(prob; ps=prob.p, dt=0.01, salg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(), noisemixing=true))
    prob = remake(prob, p=ps)
    s = solve(prob, EM(), sensealg=salg, dt=dt)
    s[end][end]
end
```
`msolve` essentially takes an SDE problem
While you're working on an MWE, can you provide the full message and stacktrace of the latest error along with the code to run it? The gist in the linked issue appears to be out of date.

I think I nailed it down to the

I believe the above methods (

The current fixed code is here. To reproduce the error, uncomment the

Haven't run this, but sometimes Zygote is confused by re-using the name
The problem persists with binding to
```julia
julia> test()
∂(prob2 = remake(prob, p = ps)) = nothing
∂(prob2 = remake(prob, p = ps)) = nothing
∂(prob2 = remake(prob, p = ps)) = nothing
ERROR: MethodError: no method matching +(::Tuple{}, ::NamedTuple{(), Tuple{}})
```
Without solving where these come from, they should both probably be
Or perhaps adding methods to … These might be worth doing anyway. (JuliaDiff/ChainRulesCore.jl#565 is something similar.) If `_project`

Sorry, I meant to wrap around just the
I think I distilled it into an MWE:
```julia
using Zygote
using StochasticDiffEq, SciMLSensitivity
import Lux

function mwe()
    x0 = rand(1)
    p0 = rand(1)
    drift(du,u,p,t) = (du .= 1)
    noise(du,u,p,t) = (du .= 1)
    prob = SDEProblem(drift, noise, x0, 1., p0)
    sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())
    Zygote.gradient(p0) do p
        sum(Zygote.@showgrad(solve(remake(prob, p=p), EM(), dt=.1, sensealg=sensealg)[end][1]) for i in 1:3)
    end
end
```
With `@showgrad` in the correct position, this now returns
```julia
julia> mwe()
∂(remake(prob, p = p)) = (f = nothing, g = nothing, u0 = [-0.09612757465640165], tspan = nothing, p = [0.0], noise = nothing, kwargs = nothing, noise_rate_prototype = nothing, seed = nothing)
∂(remake(prob, p = p)) = (f = nothing, g = nothing, u0 = [-0.06807801678762193], tspan = nothing, p = [0.0], noise = nothing, kwargs = nothing, noise_rate_prototype = nothing, seed = nothing)
∂(remake(prob, p = p)) = (f = nothing, g = nothing, u0 = [0.16420559144402358], tspan = nothing, p = [0.0], noise = nothing, kwargs = nothing, noise_rate_prototype = nothing, seed = nothing)
ERROR: MethodError: no method matching +(::Tuple{}, ::NamedTuple{(), Tuple{}})
```
I was surprised to see that the
That sounds a bit like piracy, which is bad. Does there seem to be a fix involving sending …? Also, what Julia version does this MWE work on? (Failed to install everything on nightly.)
I tried
```julia
Zygote._project(x, dx::NamedTuple{()}) = nothing
Zygote._project(x, dx::NamedTuple{(), Tuple{}}) = nothing
Zygote._project(x, dx::Tuple{}) = nothing
```
all without effect.

Edit:
```julia
Zygote.wrap_chainrules_output(x::NamedTuple{(), Tuple{}}) = nothing
```
seems to fix it.
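The failing addition itself needs no AD package to reproduce: plain Julia has no `+` method for an empty `Tuple` and an empty `NamedTuple`. The sketch below shows that, then illustrates the idea behind the workaround above, namely normalizing empty tangents to `nothing` (which Zygote uses to mean "no gradient") before accumulating. `drop_empty`, `acc`, and `acc_safe` are hypothetical helpers for illustration only; they are not Zygote functions, and the actual workaround in this thread overloads `Zygote.wrap_chainrules_output`.

```julia
# Reproduce the MethodError from the thread with Base Julia only.
err = try
    () + NamedTuple()   # Tuple{} + NamedTuple{(), Tuple{}}: no such method
    nothing
catch e
    e
end

# Hypothetical helper: treat empty structural tangents as "no gradient".
drop_empty(dx) = dx
drop_empty(::Tuple{}) = nothing
drop_empty(::NamedTuple{(), Tuple{}}) = nothing

# Hypothetical accumulator in the spirit of Zygote.accum, where `nothing` means zero.
acc(::Nothing, ::Nothing) = nothing
acc(x, ::Nothing) = x
acc(::Nothing, y) = y
acc(x, y) = x + y

# Normalize both sides first, so empty tangents never reach `+`.
acc_safe(x, y) = acc(drop_empty(x), drop_empty(y))

println(err isa MethodError)              # the raw addition fails
println(acc_safe((), NamedTuple()))       # both empty: no gradient
println(acc_safe([1.0], NamedTuple()))    # empty side is ignored
```

As the rest of the thread shows, this kind of normalization only hides the symptom; the empty tangents turned out to come from a pirated `rrule` in Lux.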
Looking into
```julia
# Zygote Fixes
function Zygote.accum(x::ComponentArray, ys::ComponentArray...)
    return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x))
end
```
It's a pirate and touching the problematic
Does adding that definition into your own code without importing Lux also break things?

On the thing which seems to fix things: do you mind tweaking the definition to this and pasting the stacktrace it generates here?
```julia
function Zygote.wrap_chainrules_output(x::NamedTuple{(), Tuple{}})
    display(stacktrace())
    println()
end
```
The definition without Lux does not break it, so I guess that's not the problem. Here are the stacktraces you asked for. One should probably start at the end, since the problem only occurs after the 3rd iteration.
```
59-element Vector{Base.StackTraces.StackFrame}:
 wrap_chainrules_output at REPL[5]:2 [inlined]
 map at tuple.jl:223 [inlined]
 wrap_chainrules_output at chainrules.jl:106 [inlined]
 ZBack at chainrules.jl:206 [inlined]
 Pullback at namedtuple.jl:280 [inlined]
 (::typeof(∂(merge)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at interface2.jl:0
 Pullback at remake.jl:32 [inlined]
 (::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0
 Pullback at remake.jl:28 [inlined]
 (::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0
 Pullback at none:0 [inlined]
 (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
 Pullback at reduce.jl:95 [inlined]
 (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
 Pullback at reduce.jl:62 [inlined]
 (::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0
 Pullback at reduce.jl:48 [inlined]
 (::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0
 Pullback at reduce.jl:44 [inlined]
 (::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0
 Pullback at reduce.jl:162 [inlined]
 Pullback at reduce.jl:162 [inlined]
 (::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0
 Pullback at reduce.jl:294 [inlined]
 ⋮
 (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
 (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45
 gradient(f::Function, args::Vector{Float64}) at interface.jl:97
 mwe() at mwe1294.jl:17
 top-level scope at mwe1294.jl:21
 eval at boot.jl:368 [inlined]
 include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428
 _include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488
 include(fname::String) at client.jl:476
 top-level scope at REPL[6]:1
 top-level scope at initialization.jl:52
 eval at boot.jl:368 [inlined]
 eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151
 repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247
 start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232
 run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369
 run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355
 (::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419
 #invokelatest#2 at essentials.jl:729 [inlined]
 invokelatest at essentials.jl:726 [inlined]
 run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404
 exec_options(opts::Base.JLOptions) at client.jl:318
 _start() at client.jl:522
```
```
57-element Vector{Base.StackTraces.StackFrame}:
59-element Vector{Base.StackTraces.StackFrame}:
57-element Vector{Base.StackTraces.StackFrame}:
59-element Vector{Base.StackTraces.StackFrame}:
57-element Vector{Base.StackTraces.StackFrame}:
```
Thanks! The last stacktrace includes https://github.com/avik-pal/Lux.jl/blob/11ac3e476161eedea23194b31e48e8d128950e00/src/autodiff.jl#L53-L63, which is very much piracy. Is that the last stacktrace printed before the error? If so, can you see if that rrule overload breaks things?

After removing the lines in question, the test runs through 🎷
In a project of mine, I want to take derivatives of some neural SDE solution (computed by the custom wrapper `msolve`) with respect to the Lux NN parameters:

fails with a
After following the suggestion of @ToucheSir in #1290 and replacing the generator with
```julia
sum(_ -> msolve(prob, ps=ps), 1:n)
```
the error changes to

I hotfixed this with

and the code runs through.
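As a plain-Julia illustration of the #1290 workaround, the two formulations below compute the same primal value; they differ only in the object Zygote has to differentiate through. The generator form first builds a `Base.Generator` (a struct with fields `f` and `iter`) that `sum` then iterates, which matches the `mapfoldl` frames in the stacktrace above, while the function form passes a function and a range to `sum` directly. `model` is a hypothetical stand-in for `msolve`, and no AD package is loaded here.

```julia
# Hypothetical stand-in for msolve: any scalar-valued function of the parameters.
model(p) = sum(abs2, p)

p = [1.0, 2.0]
n = 3

# Original form: a generator expression handed to `sum`.
gen = (model(p) for _ in 1:n)
total_gen = sum(gen)

# Form suggested in #1290: `sum` with a mapping function over a range.
total_fn = sum(_ -> model(p), 1:n)

println(total_gen == total_fn)        # same value either way
println(typeof(gen) <: Base.Generator)
println(fieldnames(Base.Generator))   # the struct Zygote sees in the generator form
```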
Searching for occurrences of `(:data, :itr)`, I could only make out Zygote.jl/src/lib/base.jl, line 155 (at commit de078c8), and the respective function below it.
I have no clue how this all works together, but many thanks to @mcabbott and @ToucheSir for helping me find the fix.
Feel free to correct the issue title and let me know if I can be of any further help fixing this (regarding the Zygote internals, though, I am quite out of my depth).