Skip to content

Commit

Permalink
Merge pull request #366 from SciML/events
Browse files Browse the repository at this point in the history
ReverseDiffAdjoint events
  • Loading branch information
ChrisRackauckas authored Dec 21, 2020
2 parents fc51133 + 523bb64 commit 8275566
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
43 changes: 29 additions & 14 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,20 @@ function DiffEqBase._concrete_solve_adjoint(prob,alg,sensealg::TrackerAdjoint,

local sol
function tracker_adjoint_forwardpass(_u0,_p)

if (convert_tspan(sensealg) === nothing && (
(haskey(kwargs,:callback) && has_continuous_callback(kwargs[:callback])) ||
(haskey(prob.kwargs,:callback) && has_continuous_callback(prob.kwargs[:callback]))
)) || (convert_tspan(sensealg) !== nothing && convert_tspan(sensealg))
_tspan = convert.(eltype(_p),prob.tspan)
else
_tspan = prob.tspan
end

if DiffEqBase.isinplace(prob)
# use Array{TrackedReal} for mutation to work
# Recurse to all Array{TrackedArray}
_prob = remake(prob,u0=map(identity,_u0),p=_p)
_prob = remake(prob,u0=map(identity,_u0),p=_p,tspan=_tspan)
else
# use TrackedArray for efficiency of the tape
function _f(args...)
Expand All @@ -273,11 +283,12 @@ function DiffEqBase._concrete_solve_adjoint(prob,alg,sensealg::TrackerAdjoint,
Tracker.collect(out)
end
end
_prob = remake(prob,f=DiffEqBase.parameterless_type(prob.f)(_f,_g),u0=_u0,p=_p)
_prob = remake(prob,f=DiffEqBase.parameterless_type(prob.f)(_f,_g),u0=_u0,p=_p,tspan=_tspan)
else
_prob = remake(prob,f=DiffEqBase.parameterless_type(prob.f)(_f),u0=_u0,p=_p)
_prob = remake(prob,f=DiffEqBase.parameterless_type(prob.f)(_f),u0=_u0,p=_p,tspan=_tspan)
end
end
@show _prob.tspan
sol = solve(_prob,alg,args...;sensealg=DiffEqBase.SensitivityADPassThrough(),kwargs...)

if typeof(sol.u[1]) <: Array
Expand Down Expand Up @@ -313,18 +324,28 @@ function DiffEqBase._concrete_solve_adjoint(prob,alg,sensealg::ReverseDiffAdjoin
local sol

function reversediff_adjoint_forwardpass(_u0,_p)

if (convert_tspan(sensealg) === nothing && (
(haskey(kwargs,:callback) && has_continuous_callback(kwargs[:callback])) ||
(haskey(prob.kwargs,:callback) && has_continuous_callback(prob.kwargs[:callback]))
)) || (convert_tspan(sensealg) !== nothing && convert_tspan(sensealg))
_tspan = convert.(eltype(_p),prob.tspan)
else
_tspan = prob.tspan
end

if DiffEqBase.isinplace(prob)
# use Array{TrackedReal} for mutation to work
# Recurse to all Array{TrackedArray}
_prob = remake(prob,u0=map(identity,_u0),p=_p)
_prob = remake(prob,u0=map(identity,_u0),p=_p,tspan=_tspan)
else
# use TrackedArray for efficiency of the tape
_f(args...) = reduce(vcat,prob.f(args...))
if prob isa SDEProblem
_g(args...) = reduce(vcat,prob.g(args...))
_prob = remake(prob,f=DiffEqBase.parameterless_type(prob.f)(_f,_g),u0=_u0,p=_p)
_prob = remake(prob,f=DiffEqBase.parameterless_type(prob.f)(_f,_g),u0=_u0,p=_p,tspan=_tspan)
else
_prob = remake(prob,f=DiffEqBase.parameterless_type(prob.f)(_f),u0=_u0,p=_p)
_prob = remake(prob,f=DiffEqBase.parameterless_type(prob.f)(_f),u0=_u0,p=_p,tspan=_tspan)
end
end

Expand All @@ -336,7 +357,7 @@ function DiffEqBase._concrete_solve_adjoint(prob,alg,sensealg::ReverseDiffAdjoin
u = map(ReverseDiff.value,sol.u)
end

Array(sol)
sol
end

tape = ReverseDiff.GradientTape(reversediff_adjoint_forwardpass,(u0, p))
Expand All @@ -346,13 +367,7 @@ function DiffEqBase._concrete_solve_adjoint(prob,alg,sensealg::ReverseDiffAdjoin
typeof(p) <: DiffEqBase.NullParameters || ReverseDiff.value!(tp, p)
ReverseDiff.forward_pass!(tape)
function reversediff_adjoint_backpass(ybar)
if prob isa SDEProblem
for i in eachindex(ybar)
@views ReverseDiff.increment_deriv!(output[:,i], ybar[i])
end
else
ReverseDiff.increment_deriv!(output, ybar)
end
ReverseDiff.increment_deriv!(output, ybar)
ReverseDiff.reverse_pass!(tape)
(nothing,nothing,ReverseDiff.deriv(tu),ReverseDiff.deriv(tp),nothing,ntuple(_->nothing, length(args))...)
end
Expand Down
1 change: 1 addition & 0 deletions src/sensitivity_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ struct ReverseDiffNoise{compile} <: NoiseChoice
end

@inline convert_tspan(::ForwardDiffSensitivity{CS,CTS}) where {CS,CTS} = CTS
@inline convert_tspan(::Any) = nothing
@inline alg_autodiff(alg::DiffEqBase.AbstractSensitivityAlgorithm{CS,AD,FDT}) where {CS,AD,FDT} = AD
@inline get_chunksize(alg::DiffEqBase.AbstractSensitivityAlgorithm{CS,AD,FDT}) where {CS,AD,FDT} = CS
@inline diff_type(alg::DiffEqBase.AbstractSensitivityAlgorithm{CS,AD,FDT}) where {CS,AD,FDT} = FDT
Expand Down

0 comments on commit 8275566

Please sign in to comment.