Skip to content

Commit

Permalink
Initial commit of partial fix for Zygote AD failure with symbolic ind…
Browse files Browse the repository at this point in the history
…exing of ODESolution. (SciML/DifferentialEquations.jl#746)
  • Loading branch information
lamorton committed Jul 18, 2021
1 parent 7ecbed1 commit 2e9b742
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TreeViews = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ArrayInterface = "2.6, 3.0"
Expand Down
2 changes: 2 additions & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using StaticArrays
import Logging, ArrayInterface
import IteratorInterfaceExtensions
import CommonSolve: solve, init, solve!
import ZygoteRules, Zygote

function __solve end
function __init end
Expand Down Expand Up @@ -521,6 +522,7 @@ include("solutions/rode_solutions.jl")
include("solutions/optimization_solutions.jl")
include("solutions/dae_solutions.jl")
include("solutions/solution_interface.jl")
include("solutions/zygote.jl")

include("ensemble/ensemble_solutions.jl")
include("ensemble/ensemble_problems.jl")
Expand Down
34 changes: 34 additions & 0 deletions src/solutions/zygote.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
ZygoteRules.@adjoint function Base.getindex(VA::ODESolution, sym,j::Int)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
if i === nothing
getter = getobserved(VA)
grz = Zygote.pullback(getter,sym,VA.u[j],VA.prob.p,VA.t[j])[2](Δ)
du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
dp = grz[3] # pullback for p
dprob = remake(VA.prob,p=dp)
T = eltype(eltype(VA.u))
N = length(VA.prob.p)
Δ′ = ODESolution{T,N,typeof(du),Nothing,Nothing,Nothing,Nothing, typeof(dprob),Nothing,Nothing,Nothing}(du,nothing,nothing,nothing,nothing,dprob,nothing,nothing,VA.dense,0,nothing,VA.retcode)
(Δ′,nothing,nothing)
else
Δ′ = [ m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] : zero(VA.u[1]) for m in 1:length(VA.u)]
(Δ′,nothing,nothing)
end
end
VA[sym,j], ODESolution_getindex_pullback
end

ZygoteRules.@adjoint function Base.getindex(VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
if i === nothing
throw("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")
else
Δ′ = [ [i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] for (x,j) in zip(VA.u, 1:length(VA))]
(Δ′,nothing)
end
end
VA[sym], ODESolution_getindex_pullback
end

0 comments on commit 2e9b742

Please sign in to comment.