Zygote AD of symbol-indexing into ODESolution #746

Open
bgctw opened this issue Apr 21, 2021 · 17 comments

Comments

@bgctw

bgctw commented Apr 21, 2021

ModelingToolkit allows indexing into an ODESolution via a symbol. Currently, however, this breaks optimization with Zygote gradients.

I tried working on the issue, but I need to learn more and would appreciate guidance on AD and DifferentialEquations.
There is a related Discourse topic and an issue in ModelingToolkit.

The following example demonstrates the issue:

```julia
using ModelingToolkit, OrdinaryDiffEq
using DiffEqBase
using SciMLBase  # issymbollike and sym_to_index are not reexported by ModelingToolkit

@parameters α β δ γ
@variables t x(t) y(t) dx(t)
D = Differential(t)
eqs = [
  dx ~ α*x - β*x*y,  # testing observed variables
  D(x) ~ dx,
  D(y) ~ -δ*y + γ*x*y
]
@named lv = ODESystem(eqs)
syss = structural_simplify(lv)
parms = [α => 1.5, β => 1.0, δ => 3.0, γ => 1.0]
x0 = [x => 1.0, y => 1.0]
tsteps = 0.0:0.1:10.0
prob = ODEProblem(syss, x0, extrema(tsteps), parms, jac = true)
soltrue = solve(prob, Tsit5(), saveat = tsteps);
popt0 = [1.1]

using ChainRulesCore
function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
  function ODESolution_getindex_pullback(Δ)
    # convert symbol to state index (nothing for observed variables)
    i = SciMLBase.issymbollike(sym) ? SciMLBase.sym_to_index(sym, VA) : sym
    # TODO: care for observed variables (i === nothing)
    # similar to VectorOfArray: sol[sym] is a time series, so Δ has one entry
    # per time point; Δ[j] goes into component i of u[j], zeros elsewhere
    Δ′ = [[k == i ? Δ[j] : zero(u[k]) for k in 1:length(u)] for (j, u) in enumerate(VA.u)]
    # one cotangent per primal argument: getindex itself, VA, sym
    (NoTangent(), Δ′, NoTangent())
  end
  VA[sym], ODESolution_getindex_pullback  # return the pullback, do not call it
end

f1(p) = soltrue[x][1] * p[1] # note the indexing by [x]
f1(popt0)
using Zygote
gr = Zygote.gradient(f1, popt0) # calls the failing rule for VectorOfArrays instead of above rule
```

@ChrisRackauckas
Member

@YingboMa @DhairyaLGandhi, I think we will need a pretty specific overload here?

@ChrisRackauckas ChrisRackauckas transferred this issue from SciML/OrdinaryDiffEq.jl Apr 22, 2021
@DhairyaLGandhi
Member

Possibly `literal_getindex` with a symbol?

@ChrisRackauckas
Member

No, https://github.com/SciML/SciMLBase.jl/blob/master/src/solutions/solution_interface.jl#L37-L53 is what it is.

If it's a known symbol, then it's essentially `@nograd`: just a translation to an index.

If it's not a known symbol, a function is called that essentially fakes indexing. I assume that part would have to be differentiated?
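
A minimal pure-Julia sketch of the two paths described above. `ToySolution`, `state_index`, `toy_getindex`, and `lv_observed` are hypothetical stand-ins, not the SciMLBase API (the real logic is in the linked `solution_interface.jl`). The point is only that the known-symbol branch is a lookup with nothing to differentiate, while the observed branch calls a function of `(u, p, t)` that an AD rule would have to differentiate.

```julia
# Hypothetical stand-in for an ODE solution; NOT the SciMLBase ODESolution type.
struct ToySolution
    u::Vector{Vector{Float64}}   # state vector at each saved time point
    t::Vector{Float64}           # saved time points
    p::Vector{Float64}           # parameter vector
    syms::Vector{Symbol}         # state names, in the order used by u
    observed::Function           # (sym, u, p, t) -> value of an observed variable
end

# Known state: resolving the symbol is a pure lookup, nothing to differentiate.
state_index(sol::ToySolution, sym::Symbol) = findfirst(==(sym), sol.syms)

function toy_getindex(sol::ToySolution, sym::Symbol)
    i = state_index(sol, sym)
    if i !== nothing
        # Plain indexing: component i of every time slice.
        return [u[i] for u in sol.u]
    else
        # Observed variable: computed from (u, p, t) at each time point;
        # this is the part an AD rule would actually have to differentiate.
        return [sol.observed(sym, u, sol.p, t) for (u, t) in zip(sol.u, sol.t)]
    end
end

# Usage with the example's observed variable dx = α*x - β*x*y, p = [α, β].
lv_observed(sym, u, p, t) =
    sym === :dx ? p[1] * u[1] - p[2] * u[1] * u[2] : error("unknown symbol $sym")

sol = ToySolution([[1.0, 1.0], [2.0, 0.5]], [0.0, 0.1], [1.5, 1.0], [:x, :y], lv_observed)
ys  = toy_getindex(sol, :y)    # [1.0, 0.5]
dxs = toy_getindex(sol, :dx)   # [0.5, 2.0]
```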

@DhairyaLGandhi
Member

Okay, if it's calling into a function, then I understand needing to differentiate it. Thanks for the link!

@ChrisRackauckas
Member

I think we only need to differentiate it half of the time, though? It might be fixed by a few more `@nograd`s on the symbol-handling code.

@DhairyaLGandhi
Member

Right, so we would basically want to teach Zygote which symbols it needs to ignore, and only differentiate when needed.

@lamorton

lamorton commented Jul 9, 2021

Now that ModelingToolkit#151 landed, the traceback is different but the problem persists. Instead of going into the ZygoteRule from RecursiveArrayTools, the indexing gets handled (incorrectly) by Zygote itself.

```julia
julia> gr = Zygote.gradient(f1, popt0)
ERROR: ArgumentError: invalid index: x(t) of type Num
Stacktrace:
  [1] to_index(i::Num)
    @ Base ./indices.jl:300
  [2] to_index(A::Matrix{Float64}, i::Num)
    @ Base ./indices.jl:277
  [3] to_indices
    @ ./indices.jl:333 [inlined]
  [4] to_indices
    @ ./indices.jl:325 [inlined]
  [5] view
    @ ./subarray.jl:176 [inlined]
  [6] (::Zygote.var"#408#410"{2, Float64, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, ModelingToolkit.var"#f#165"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x19b71b7a, 0x460a6fa9, 0x626990a7, 0x75215ee8, 0x03e07fc2)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf42d3854, 0x96c39041, 0xfccc2855, 0xb95c5d08, 0xe37149b7)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, ModelingToolkit.var"#_jac#169"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf639a50e, 0x1c43f489, 0xc1fc55da, 0x039b3c3f, 0x12cc3d06)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4ff8155b, 0x50a871bb, 0x5fd810b0, 0x2752d931, 0xe4cfe9e5)}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#167#generated_observed#172"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Base.Iterators.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:jac,), Tuple{Bool}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, ModelingToolkit.var"#f#165"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x19b71b7a, 0x460a6fa9, 0x626990a7, 0x75215ee8, 0x03e07fc2)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf42d3854, 0x96c39041, 0xfccc2855, 0xb95c5d08, 0xe37149b7)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, ModelingToolkit.var"#_jac#169"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf639a50e, 0x1c43f489, 0xc1fc55da, 0x039b3c3f, 0x12cc3d06)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4ff8155b, 0x50a871bb, 0x5fd810b0, 0x2752d931, 0xe4cfe9e5)}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#167#generated_observed#172"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats}, Tuple{Num}})(dy::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/lib/array.jl:43
  [7] (::Zygote.var"#2248#back#404"{Zygote.var"#408#410"{2, Float64, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, ModelingToolkit.var"#f#165"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x19b71b7a, 0x460a6fa9, 0x626990a7, 0x75215ee8, 0x03e07fc2)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf42d3854, 0x96c39041, 0xfccc2855, 0xb95c5d08, 0xe37149b7)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, ModelingToolkit.var"#_jac#169"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf639a50e, 0x1c43f489, 0xc1fc55da, 0x039b3c3f, 0x12cc3d06)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4ff8155b, 0x50a871bb, 0x5fd810b0, 0x2752d931, 0xe4cfe9e5)}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#167#generated_observed#172"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Base.Iterators.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:jac,), Tuple{Bool}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, ModelingToolkit.var"#f#165"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x19b71b7a, 0x460a6fa9, 0x626990a7, 0x75215ee8, 0x03e07fc2)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf42d3854, 0x96c39041, 0xfccc2855, 0xb95c5d08, 0xe37149b7)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, ModelingToolkit.var"#_jac#169"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xf639a50e, 0x1c43f489, 0xc1fc55da, 0x039b3c3f, 0x12cc3d06)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4ff8155b, 0x50a871bb, 0x5fd810b0, 0x2752d931, 0xe4cfe9e5)}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#167#generated_observed#172"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats}, Tuple{Num}}})(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [8] Pullback
    @ ./REPL[38]:1 [inlined]
  [9] (::typeof(∂(f1)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#46#47"{typeof(∂(f1))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:41
 [11] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:59
 [12] top-level scope
    @ REPL[41]:1
```
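
Frames [2] to [6] of the trace show Zygote's generic `getindex` pullback calling `view` on a dense `Matrix{Float64}` with the `Num` itself as the index, which `Base.to_index` rejects. The failing step can be reproduced in plain Julia; here a `Symbol` stands in for the `Num` (an assumption for illustration only), and it is rejected with the same kind of `ArgumentError`.

```julia
# A Symbol stands in for the symbolic Num index (illustrative assumption).
# The zero cotangent built for the solution is a plain dense array, so the
# symbolic index ends up in `view`, which only accepts integer-like indices,
# ranges, colons, or arrays of these.
A = zeros(2, 3)             # dense zeros shaped like the solution (states × time points)
err = try
    view(A, :x, 1)          # a symbolic row index is not a valid array index
catch e
    e
end
println(typeof(err))        # ArgumentError
```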

@lamorton

lamorton commented Jul 9, 2021

These may be relevant to the proposed fix: JuliaDiff/ChainRulesCore.jl#239 and FluxML/Zygote.jl#811

@lamorton

lamorton commented Jul 9, 2021

Yup. I also figured out the issue with @bgctw's prototype: you need to explicitly use `Base.getindex` here:

```julia
function ChainRulesCore.rrule(::typeof(Base.getindex), VA::ODESolution, sym)
```

@lamorton

Looks like handling the observed variables is tricky, because they implicitly depend on the equations that relate them to the solution array.

@lamorton

Here's a partial solution for the states, but not for the observed variables. The array construction is a bit of a kludge.

```julia
using Zygote, ZygoteRules
using SciMLBase # because ModelingToolkit doesn't reexport issymbollike
ZygoteRules.@adjoint function Base.getindex(VA::ODESolution, sym::Num)
  function ODESolution_getindex_pullback(Δ)
    # convert symbol to index
    i = SciMLBase.issymbollike(sym) ? SciMLBase.sym_to_index(sym, VA) : sym
    if i === nothing
      # TODO: pull Δ back through the observed function (e.g. via Zygote.pullback)
      error("gradient of observed symbol is not defined yet")
    else
      # similar to VectorOfArray: return zero for non-matching indices
      Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] for (j, x) in enumerate(VA.u)]
      (Δ′, nothing)
    end
  end
  VA[sym], ODESolution_getindex_pullback
end
```
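
The comprehension in the adjoint above builds a cotangent with the same nested layout as `VA.u`. Below is a standalone sketch of that construction using a hypothetical helper, `state_cotangent`, together with a check that it really is the adjoint of "take component `i` of every time slice": the inner product of Δ with the selected series must equal the inner product of the cotangent with `u`.

```julia
# Hypothetical helper mirroring the comprehension in the adjoint above.
# sol[sym] is a time series, so Δ has one entry per saved time point.
# The returned cotangent has the same nested layout as u: Δ[j] is placed in
# component i of time slice j, and every other entry is zero.
function state_cotangent(u::Vector{Vector{Float64}}, i::Int, Δ::AbstractVector{<:Real})
    length(Δ) == length(u) ||
        throw(DimensionMismatch("Δ must have one entry per time point"))
    return [[k == i ? Float64(Δ[j]) : 0.0 for k in eachindex(uj)] for (j, uj) in enumerate(u)]
end

u  = [[1.0, 1.0], [2.0, 0.5], [3.0, 0.25]]   # 3 time points, states (x, y)
Δ  = [10.0, 20.0, 30.0]                      # incoming cotangent for sol[y]
du = state_cotangent(u, 2, Δ)
# du == [[0.0, 10.0], [0.0, 20.0], [0.0, 30.0]]

# Adjoint check: <Δ, y-series> equals <du, u> summed over all time slices.
lhs = sum(Δ .* [uj[2] for uj in u])
rhs = sum(sum(a .* b) for (a, b) in zip(du, u))
```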

@ChrisRackauckas
Member

Yeah, I think it needs to pass Δ into the pullback of the observed function itself?
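
For the observed variable in the original example, `dx ~ α*x - β*x*y`, passing Δ through its pullback is just the chain rule. The sketch below writes that pullback by hand for a single time slice (`observed_dx` and `observed_dx_pullback` are hypothetical names; in practice the pullback would come from Zygote differentiating the generated observed function). It shows why the cotangent reaches both the states and the parameters.

```julia
# Hypothetical observed variable from the original example, dx ~ α*x - β*x*y,
# evaluated at one time slice with u = [x, y] and p = [α, β].
observed_dx(u, p) = p[1] * u[1] - p[2] * u[1] * u[2]

# Hand-written pullback (chain rule by hand): returns the value and a closure
# that maps the incoming cotangent Δ to cotangents for u and for p.
function observed_dx_pullback(u, p)
    x, y = u
    α, β = p
    function back(Δ)
        du = [Δ * (α - β * y), Δ * (-β * x)]   # Δ * (∂dx/∂x, ∂dx/∂y)
        dp = [Δ * x, Δ * (-x * y)]             # Δ * (∂dx/∂α, ∂dx/∂β)
        return du, dp
    end
    return observed_dx(u, p), back
end

val, back = observed_dx_pullback([2.0, 0.5], [1.5, 1.0])
du, dp = back(1.0)
# val == 2.0, du == [1.0, -2.0], dp == [2.0, -1.0]

# Finite-difference sanity check of ∂dx/∂x.
h = 1e-6
fd = (observed_dx([2.0 + h, 0.5], [1.5, 1.0]) - val) / h
```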

@lamorton

lamorton commented Jul 12, 2021

Ahh, that's probably right. I guess I also need a dummy ODESolution just to hang the derivatives for sol.prob.p on?

@ChrisRackauckas
Member

Yeah, I think so.

@lamorton

lamorton commented Jul 18, 2021

Here's a working prototype for the simpler case where a single time slice is also selected:

```julia
ZygoteRules.@adjoint function Base.getindex(VA::ODESolution, sym::Num, j::Int)
  function ODESolution_getindex_pullback(Δ)
    # convert symbol to index
    i = SciMLBase.issymbollike(sym) ? SciMLBase.sym_to_index(sym, VA) : sym
    if i === nothing
      # observed variable: pull Δ back through the observed function at time slice j
      getter = SciMLBase.getobserved(VA)
      grz = Zygote.pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
      du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
      dp = grz[3] # pullback for p
      dprob = remake(VA.prob, p = dp)
      T = eltype(eltype(VA.u))
      N = ndims(VA)  # array dimension of the solution, not the number of parameters
      Δ′ = ODESolution{T,N,typeof(du),Nothing,Nothing,typeof(VA.t),typeof(VA.k),typeof(dprob),typeof(VA.alg),typeof(VA.interp),typeof(VA.destats)}(du,nothing,nothing,VA.t,VA.k,dprob,VA.alg,VA.interp,VA.dense,0,VA.destats,VA.retcode)
      (Δ′, nothing, nothing)
    else
      # state variable, similar to VectorOfArray: return zero for non-matching indices
      Δ′ = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] : zero(VA.u[1]) for m in 1:length(VA.u)]
      (Δ′, nothing, nothing)
    end
  end
  VA[sym, j], ODESolution_getindex_pullback
end
```

I'm stuck on how to treat the parameter derivatives in the general case: they pick up an extra dimension because the input Δ is an array over time.
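
One way to resolve the extra dimension, assuming the loss depends on the whole observed time series: every time slice uses the same `p`, so by the chain rule the per-slice parameter cotangents are summed over time rather than stacked, while the state cotangents stay per slice. A hand-written sketch for the example's `dx` (all names here are hypothetical):

```julia
# Hypothetical: the whole observed series sol[dx] = [g(u_j, p) for each saved time point j].
# Each slice shares the same p, so the parameter cotangent is the SUM over
# time of the per-slice terms, while the state cotangent stays per slice:
#   du[j] = Δ[j] * ∂g/∂u(u_j, p),   dp = Σ_j Δ[j] * ∂g/∂p(u_j, p)
g(u, p)     = p[1] * u[1] - p[2] * u[1] * u[2]   # dx = α*x - β*x*y
dg_du(u, p) = [p[1] - p[2] * u[2], -p[2] * u[1]]
dg_dp(u, p) = [u[1], -u[1] * u[2]]

function observed_series_pullback(us::Vector{Vector{Float64}}, p::Vector{Float64}, Δ::Vector{Float64})
    du = [Δ[j] .* dg_du(us[j], p) for j in eachindex(us)]    # one state vector per time point
    dp = sum(Δ[j] .* dg_dp(us[j], p) for j in eachindex(us))  # accumulated over time, same shape as p
    return du, dp
end

us = [[1.0, 1.0], [2.0, 0.5]]   # states (x, y) at two time points
p  = [1.5, 1.0]                 # (α, β)
Δ  = [1.0, 1.0]                 # cotangent for each element of sol[dx]
du, dp = observed_series_pullback(us, p, Δ)
# du == [[0.5, -1.0], [1.0, -2.0]], dp == [3.0, -2.0]
```

In an actual rule, the per-slice terms would come from the observed function's pullback at each time point, and the summed `dp` is what would be hung on the dummy solution's `prob.p`.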

@ChrisRackauckas
Member

Let's break this problem down into steps. Could you open a PR with what you have and throw an error for the unhandled case? Then it can be improved incrementally.

4 participants