Zygote AD of symbol-indexing into ODESolution #746
@YingboMa @DhairyaLGandhi, I think we will need a pretty specific overload here?
No, https://github.com/SciML/SciMLBase.jl/blob/master/src/solutions/solution_interface.jl#L37-L53 is what it is. If it's a known symbol, it is essentially converted to an integer index and used for ordinary indexing. If it's not a known symbol, a function is called that essentially fakes indexing. I assume that part would have to be differentiated?
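The two branches described above can be sketched with a hypothetical toy type (`ToySolution` and its fields are illustrative assumptions, not SciMLBase's `ODESolution` or its actual `getindex` code). Only the second branch calls a user-supplied function, which is the part AD would have to see through; the first is plain indexing.

```julia
# Hypothetical toy stand-in for an ODE solution (not SciMLBase's ODESolution).
struct ToySolution
    u::Vector{Vector{Float64}}        # state vector at each saved time point
    t::Vector{Float64}                # saved time points
    p::Vector{Float64}                # parameters
    syms::Vector{Symbol}              # names of the state variables
    observed::Dict{Symbol,Function}   # observed name => (u, p, t) -> value
end

function Base.getindex(sol::ToySolution, sym::Symbol)
    i = findfirst(==(sym), sol.syms)
    if i !== nothing
        # Known state symbol: ordinary indexing with the integer index i at every time point.
        return [u[i] for u in sol.u]
    end
    # Not a state: evaluate the observed function at every saved time point ("fake" indexing).
    g = sol.observed[sym]
    return [g(u, sol.p, t) for (u, t) in zip(sol.u, sol.t)]
end

sol = ToySolution([[1.0, 2.0], [3.0, 4.0]], [0.0, 1.0], [10.0], [:x, :y],
                  Dict{Symbol,Function}(:z => ((u, p, t) -> p[1] * u[1] + u[2])))
sol[:x]  # [1.0, 3.0]
sol[:z]  # [12.0, 34.0]
```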
Okay, if it's calling into a function, then I understand needing to differentiate it. Thanks for the link!
I think we only need to differentiate it half of the time though? It might be fixed by a few more
Right, so we would basically want to teach Zygote which symbols it needs to ignore and only differentiate when needed.
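Now that ModelingToolkit#151 landed, the traceback is different, but the problem persists. The error is now raised from Zygote's generic array `getindex` adjoint, which tries to `view` the solution with the symbolic `Num` index:
julia> gr = Zygote.gradient(f1, popt0)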
ERROR: ArgumentError: invalid index: x(t) of type Num
Stacktrace:
[1] to_index(i::Num)
@ Base ./indices.jl:300
[2] to_index(A::Matrix{Float64}, i::Num)
@ Base ./indices.jl:277
[3] to_indices
@ ./indices.jl:333 [inlined]
[4] to_indices
@ ./indices.jl:325 [inlined]
[5] view
@ ./subarray.jl:176 [inlined]
[6] (::Zygote.var"#408#410"{2, Float64, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, ModelingToolkit.var"#f#165"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x19b71b7a, 0x460a6fa9, 0x626990a7, 0x75215ee8, 0x03e07fc2)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf42d3854, 0x96c39041, 0xfccc2855, 0xb95c5d08, 0xe37149b7)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, ModelingToolkit.var"#_jac#169"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf639a50e, 0x1c43f489, 0xc1fc55da, 0x039b3c3f, 0x12cc3d06)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4ff8155b, 0x50a871bb, 0x5fd810b0, 0x2752d931, 0xe4cfe9e5)}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#167#generated_observed#172"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Base.Iterators.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:jac,), Tuple{Bool}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, ModelingToolkit.var"#f#165"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x19b71b7a, 0x460a6fa9, 0x626990a7, 0x75215ee8, 0x03e07fc2)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf42d3854, 0x96c39041, 0xfccc2855, 
0xb95c5d08, 0xe37149b7)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, ModelingToolkit.var"#_jac#169"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf639a50e, 0x1c43f489, 0xc1fc55da, 0x039b3c3f, 0x12cc3d06)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4ff8155b, 0x50a871bb, 0x5fd810b0, 0x2752d931, 0xe4cfe9e5)}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#167#generated_observed#172"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats}, Tuple{Num}})(dy::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/0da6K/src/lib/array.jl:43
[7] (::Zygote.var"#2248#back#404"{Zygote.var"#408#410"{2, Float64, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, ModelingToolkit.var"#f#165"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x19b71b7a, 0x460a6fa9, 0x626990a7, 0x75215ee8, 0x03e07fc2)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf42d3854, 0x96c39041, 0xfccc2855, 0xb95c5d08, 0xe37149b7)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, ModelingToolkit.var"#_jac#169"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf639a50e, 0x1c43f489, 0xc1fc55da, 0x039b3c3f, 0x12cc3d06)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4ff8155b, 0x50a871bb, 0x5fd810b0, 0x2752d931, 0xe4cfe9e5)}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#167#generated_observed#172"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Base.Iterators.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:jac,), Tuple{Bool}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, ModelingToolkit.var"#f#165"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x19b71b7a, 0x460a6fa9, 0x626990a7, 0x75215ee8, 0x03e07fc2)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf42d3854, 
0x96c39041, 0xfccc2855, 0xb95c5d08, 0xe37149b7)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, ModelingToolkit.var"#_jac#169"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf639a50e, 0x1c43f489, 0xc1fc55da, 0x039b3c3f, 0x12cc3d06)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4ff8155b, 0x50a871bb, 0x5fd810b0, 0x2752d931, 0xe4cfe9e5)}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#167#generated_observed#172"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats}, Tuple{Num}}})(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[8] Pullback
@ ./REPL[38]:1 [inlined]
[9] (::typeof(∂(f1)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0
[10] (::Zygote.var"#46#47"{typeof(∂(f1))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:41
[11] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:59
[12] top-level scope
@ REPL[41]:1
These may be relevant to the proposed fix: JuliaDiff/ChainRulesCore.jl#239 and FluxML/Zygote.jl#811
Yup. I also figured out the issue with @bgctw's prototype: the method has to be defined explicitly as `function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)`.
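Presumably the qualification matters because of how Julia extends functions from other modules: an unqualified `function rrule(...)` (without `import ChainRulesCore: rrule`) either defines a separate `rrule` in your own module or fails with a "must be explicitly imported to be extended" error, so the AD engine never sees the rule. The package-free sketch below uses a hypothetical module `FakeRules` in place of ChainRulesCore, with `rule` standing in for `rrule` and `lookup` for the engine's dispatch; `MySolution` is likewise hypothetical.

```julia
# Hypothetical stand-in for ChainRulesCore: `rule` plays the role of `rrule`,
# `lookup` plays the role of the AD engine asking for a custom rule.
module FakeRules
    rule(f, args...) = nothing          # fallback: "no custom rule for this call"
    lookup(f, args...) = rule(f, args...)
end

# Hypothetical solution type standing in for ODESolution.
struct MySolution
    u::Vector{Float64}
end

# Unqualified definition: creates a separate function `rule` in the current module.
# FakeRules.rule is untouched, so the "engine" still finds no rule.
rule(::typeof(getindex), sol::MySolution, i::Int) = (sol.u[i], Δ -> (nothing, nothing, nothing))
unqualified_seen = FakeRules.lookup(getindex, MySolution([1.0, 2.0]), 1) !== nothing  # false

# Qualified definition: adds a method to FakeRules.rule, so the lookup now dispatches to it.
# The pullback returns one entry for the function itself plus one per argument,
# mirroring the rrule convention (ChainRulesCore would use NoTangent() instead of nothing).
FakeRules.rule(::typeof(getindex), sol::MySolution, i::Int) =
    (sol.u[i], Δ -> (nothing, [k == i ? Δ : 0.0 for k in eachindex(sol.u)], nothing))
qualified_seen = FakeRules.lookup(getindex, MySolution([1.0, 2.0]), 1) !== nothing  # true
```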
Looks like handling the observed variables is tricky because they implicitly depend on the equations that relate them to the solution array.
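Spelled out with the chain rule (a standard identity, not something taken from the thread): write the observed variable at saved time $t_j$ as $y_j = g(u_j, p, t_j)$, where $g$ stands for the observed function generated from the equations, and let $\Delta_j = \partial L / \partial y_j$ be the incoming cotangent. The pullback of the observed lookup then has a per-time-point part for the states and a summed part for the parameters:

```math
\bar{u}_j = \left(\frac{\partial g}{\partial u}(u_j, p, t_j)\right)^{\!\top} \Delta_j,
\qquad
\bar{p}_{\mathrm{obs}} = \sum_{j} \left(\frac{\partial g}{\partial p}(u_j, p, t_j)\right)^{\!\top} \Delta_j .
```

The $\bar{u}_j$ still have to be propagated through the ODE solve (sensitivity or adjoint machinery) to capture how the states themselves depend on $p$; $\bar{p}_{\mathrm{obs}}$ covers only the explicit dependence of $g$ on $p$.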
Here's a partial solution for the states, but not for the observed variables. The array construction is a bit of a kludge.

```julia
using Zygote, ZygoteRules, ModelingToolkit
using SciMLBase  # because ModelingToolkit doesn't reexport issymbollike

ZygoteRules.@adjoint function Base.getindex(VA::ODESolution, sym::Num)
    function ODESolution_getindex_pullback(Δ)
        # convert symbol to index
        i = SciMLBase.issymbollike(sym) ? SciMLBase.sym_to_index(sym, VA) : sym
        if i === nothing
            # observed (non-state) symbol: differentiating through the
            # observed function is not implemented yet
            error("gradient of observed symbols is not defined yet")
        else
            # similar to VectorOfArray: Δ[j] in component i at time point j, zero elsewhere
            Δ′ = [[i == k ? Δ[j] : zero(eltype(x)) for k in eachindex(x)] for (j, x) in enumerate(VA.u)]
            return (Δ′, nothing)  # no cotangent for the symbol itself
        end
    end
    return VA[sym], ODESolution_getindex_pullback
end
```
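In the state branch, the cotangent is one-hot: at saved time point j it carries Δ[j] in the component that `sym` maps to, and zero elsewhere. A package-free sketch of just that construction, using a hypothetical helper name `state_index_cotangent` and plain vectors in place of an `ODESolution`:

```julia
# Cotangent of `sol[sym]` w.r.t. the saved states when `sym` is the i-th state.
# `u` holds the state vector at each saved time point; Δ[j] is the incoming
# cotangent for the j-th entry of the returned time series.
function state_index_cotangent(u::Vector{Vector{T}}, i::Int, Δ::AbstractVector) where {T}
    # At time point j the cotangent is Δ[j] in component i and zero elsewhere.
    return [[k == i ? T(Δ[j]) : zero(T) for k in eachindex(uj)] for (j, uj) in enumerate(u)]
end

u = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
du = state_index_cotangent(u, 2, [0.1, 0.2, 0.3])
# du == [[0.0, 0.1], [0.0, 0.2], [0.0, 0.3]]
```

This mirrors the Δ′ comprehension in the prototype, written with `enumerate` so the time index and the state vector come from the same iteration.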
Yeah, I think it needs to pass
Ahh, that's probably right. I guess I also need a dummy
Yeah, I think so.
Here's a functioning prototype for the simpler case where a single time slice is chosen as well:

```julia
ZygoteRules.@adjoint function Base.getindex(VA::ODESolution, sym::Num, j::Int)
    function ODESolution_getindex_pullback(Δ)
        # convert symbol to index
        i = SciMLBase.issymbollike(sym) ? SciMLBase.sym_to_index(sym, VA) : sym
        if i === nothing
            # observed symbol: pull Δ back through the observed function at time point j
            getter = SciMLBase.getobserved(VA)
            grz = Zygote.pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
            # only time point j receives a state cotangent
            du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
            dp = grz[3]  # cotangent for p
            dprob = remake(VA.prob, p = dp)
            T = eltype(eltype(VA.u))
            # N is the array dimension of the solution (the N in ODESolution{T,N,...}), not length(p)
            N = ndims(VA)
            Δ′ = ODESolution{T,N,typeof(du),Nothing,Nothing,typeof(VA.t),typeof(VA.k),typeof(dprob),typeof(VA.alg),typeof(VA.interp),typeof(VA.destats)}(du, nothing, nothing, VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.destats, VA.retcode)
            (Δ′, nothing, nothing)
        else
            # state symbol: Δ in component i at time point j, zero elsewhere
            Δ′ = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] : zero(VA.u[1]) for m in 1:length(VA.u)]
            (Δ′, nothing, nothing)
        end
    end
    VA[sym, j], ODESolution_getindex_pullback
end
```

I'm stuck on how to treat the derivatives for the parameters in the general case (all time points at once): they pick up an extra dimension because the incoming Δ is an array over time.
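One way the chain rule removes that extra dimension: the parameters are shared by every time point, so the per-time-point parameter cotangents are summed, while the state cotangents keep one entry per saved time point. The toy sketch below uses a hypothetical observed function `g(u, p, t) = p[1]*u[1] + u[2] + p[2]*t` with hand-written partial derivatives (in the prototype above these would come from `Zygote.pullback` on the observed getter). It covers only the explicit dependence of `g` on `p`; how the states themselves depend on `p` still has to come from the ODE sensitivity machinery.

```julia
# Hypothetical observed function y = g(u, p, t); stands in for the generated observed code.
g(u, p, t) = p[1] * u[1] + u[2] + p[2] * t

# Hand-written partial derivatives of g (in practice these would come from AD).
dg_du(u, p, t) = [p[1], 1.0]
dg_dp(u, p, t) = [u[1], t]

function observed_pullback(us, ts, p, Δ)
    # Per-time-point state cotangents: one vector per saved time point.
    du = [Δ[j] .* dg_du(us[j], p, ts[j]) for j in eachindex(us)]
    # Parameters are shared across time points, so their cotangents are summed.
    dp = sum(Δ[j] .* dg_dp(us[j], p, ts[j]) for j in eachindex(us))
    return du, dp
end

us = [[1.0, 2.0], [3.0, 4.0]]
ts = [0.0, 1.0]
p = [10.0, 0.5]
Δ = [1.0, 2.0]
du, dp = observed_pullback(us, ts, p, Δ)
# du == [[10.0, 1.0], [20.0, 2.0]], dp == [7.0, 2.0]
```

A quick finite-difference check on `L(p) = sum(Δ[j] * g(us[j], p, ts[j]))` confirms that the summed `dp` is the gradient of that loss with respect to `p` at fixed states.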
Let's break this problem down into steps. Could you open a PR with what you have and throw an error for the case that isn't handled yet? Then it can be improved incrementally.
ModelingToolkit allows indexing into an ODESolution via a symbol. However, this currently causes problems when using Optimization with Zygote gradients.
I tried working on the issue, but I need to learn more and would need guidance with AD and DifferentialEquations.
There is a related discourse topic and an issue at ModelingToolkit.
The following example demonstrates the issue.