Skip to content

Commit

Permalink
Setup solve for adjoints to deprecate concrete_solve
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed May 29, 2020
1 parent 00ad2e3 commit a170fd9
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 71 deletions.
20 changes: 8 additions & 12 deletions src/reversediff.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p::ReverseDiff.TrackedArray,args...;
sensealg=nothing,kwargs...)
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p::ReverseDiff.TrackedArray,args...;kwargs...)
ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
end

function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0,p::ReverseDiff.TrackedArray,args...;
sensealg=nothing,kwargs...)
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0,p::ReverseDiff.TrackedArray,args...;kwargs...)
ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
end

function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p,args...;
sensealg=nothing,kwargs...)
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p,args...;kwargs...)
ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
end

ReverseDiff.@grad function concrete_solve(prob,alg,u0,p,args...;
sensealg=nothing,kwargs...)
out = _concrete_solve_adjoint(prob,alg,sensealg,ReverseDiff.value(u0),ReverseDiff.value(p),args...;kwargs...)
ReverseDiff.@grad function solve_up(prob,sensealg,u0,p,args...;kwargs...)
out = _solve_adjoint(prob,sensealg,ReverseDiff.value(u0),ReverseDiff.value(p),args...;kwargs...)
Array(out[1]),out[2]
end
119 changes: 72 additions & 47 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,17 @@ function solve_call(_prob,args...;merge_callbacks = true, kwargs...)
else
__solve(_prob,args...; kwargs...)#::T
end
end

function solve(prob::DEProblem,args...;sensealg=nothing,
u0 = nothing, p = nothing,kwargs...)
u0 = u0 !== nothing ? u0 : prob.u0
p = p !== nothing ? p : prob.p
solve_up(prob,sensealg,u0,p,args...;kwargs...)
end

function solve(prob::DEProblem,args...;kwargs...)
_prob = get_concrete_problem(prob,kwargs)
function solve_up(prob::DEProblem,sensealg,u0,p,args...;kwargs...)
_prob = get_concrete_problem(prob;u0=u0,p=p,kwargs...)
if haskey(kwargs,:alg) && (isempty(args) || args[1] === nothing)
alg = kwargs[:alg]
isadaptive(alg) &&
Expand Down Expand Up @@ -93,21 +99,21 @@ function solve(prob::EnsembleProblem,args...;kwargs...)
end
end

function solve(prob::AbstractNoiseProblem,args...;kwargs...)
function solve(prob::AbstractNoiseProblem,args...; kwargs...)
__solve(prob,args...;kwargs...)
end

function get_concrete_problem(prob::AbstractJumpProblem,kwargs)
function get_concrete_problem(prob::AbstractJumpProblem; kwargs...)
prob
end

function get_concrete_problem(prob::AbstractSteadyStateProblem, kwargs)
function get_concrete_problem(prob::AbstractSteadyStateProblem; kwargs...)
u0 = get_concrete_u0(prob, Inf, kwargs)
u0 = promote_u0(u0, prob.p, nothing)
remake(prob; u0 = u0)
end

function get_concrete_problem(prob::AbstractEnsembleProblem, kwargs)
function get_concrete_problem(prob::AbstractEnsembleProblem; kwargs...)
prob
end

Expand All @@ -118,45 +124,45 @@ end

function discretize end

function get_concrete_problem(prob, kwargs)
tspan = get_concrete_tspan(prob, kwargs)
function get_concrete_problem(prob; kwargs...)
p = get_concrete_p(prob, kwargs)
tspan = get_concrete_tspan(prob, kwargs, p)
u0 = get_concrete_u0(prob, tspan[1], kwargs)
u0_promote = promote_u0(u0, prob.p, tspan[1])
tspan_promote = promote_tspan(u0, prob.p, tspan, prob, kwargs)
u0_promote = promote_u0(u0, p, tspan[1])
tspan_promote = promote_tspan(u0, p, tspan, prob, kwargs)
if isconcreteu0(prob, tspan[1], kwargs) && typeof(u0_promote) === typeof(u0) &&
prob.tspan == tspan && typeof(tspan) === typeof(tspan_promote)
return prob
else
return remake(prob; u0 = u0_promote, tspan = tspan_promote)
return remake(prob; u0 = u0_promote, p=p, tspan = tspan_promote)
end
end

function get_concrete_problem(prob::DDEProblem, kwargs)
tspan = get_concrete_tspan(prob, kwargs)
function get_concrete_problem(prob::DDEProblem; kwargs...)
p = get_concrete_p(prob, kwargs)
tspan = get_concrete_tspan(prob, kwargs, p)

u0 = get_concrete_u0(prob, tspan[1], kwargs)

if prob.constant_lags isa Function
constant_lags = prob.constant_lags(prob.p)
constant_lags = prob.constant_lags(p)
else
constant_lags = prob.constant_lags
end

u0 = promote_u0(u0, prob.p, tspan[1])
tspan = promote_tspan(u0, prob.p, tspan, prob, kwargs)
u0 = promote_u0(u0, p, tspan[1])
tspan = promote_tspan(u0, p, tspan, prob, kwargs)

remake(prob; u0 = u0, tspan = tspan, constant_lags = constant_lags)
remake(prob; u0 = u0, tspan = tspan, p=p, constant_lags = constant_lags)
end

function get_concrete_tspan(prob, kwargs)
function get_concrete_tspan(prob, kwargs, p)
if prob.tspan isa Function
tspan = prob.tspan(prob.p)
elseif prob.tspan === (nothing, nothing)
if haskey(kwargs, :tspan)
tspan = prob.tspan(p)
elseif haskey(kwargs, :tspan)
tspan = kwargs[:tspan]
else
error("No tspan is set in the problem or chosen in the init/solve call")
end
elseif prob.tspan === (nothing, nothing)
error("No tspan is set in the problem or chosen in the init/solve call")
else
tspan = prob.tspan
end
Expand All @@ -171,7 +177,7 @@ end
function get_concrete_u0(prob, t0, kwargs)
if eval_u0(prob.u0)
u0 = prob.u0(prob.p, t0)
elseif prob.u0 === nothing
elseif haskey(kwargs,:u0)
u0 = kwargs[:u0]
else
u0 = prob.u0
Expand All @@ -180,6 +186,14 @@ function get_concrete_u0(prob, t0, kwargs)
handle_distribution_u0(u0)
end

function get_concrete_p(prob, kwargs)
if haskey(kwargs,:p)
p = kwargs[:p]
else
p = prob.p
end
end

handle_distribution_u0(_u0) = _u0
eval_u0(u0::Function) = true
eval_u0(u0) = false
Expand Down Expand Up @@ -218,38 +232,49 @@ end

################### Concrete Solve

function _concrete_solve end
@deprecate concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},
u0=prob.u0,p=prob.p,args...;kwargs...) solve(prob,alg,args...;u0=u0,p=p,kwargs...)

function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},
u0=prob.u0,p=prob.p,args...;kwargs...)
_concrete_solve(prob,alg,u0,p,args...;kwargs...)
end
struct SensitivityADPassThrough <: DiffEqBase.DEAlgorithm end

function _concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},
u0=prob.u0,p=prob.p,args...;kwargs...)
sol = solve(remake(prob,u0=u0,p=p),alg,args...;kwargs...)
RecursiveArrayTools.DiffEqArray(sol.u,sol.t)
ZygoteRules.@adjoint function solve_up(prob,sensealg::Union{Nothing,AbstractSensitivityAlgorithm},
u0,p,args...;
kwargs...)
_solve_adjoint(prob,sensealg,u0,p,args...;kwargs...)
end

function _concrete_solve(prob::DiffEqBase.SteadyStateProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},
u0=prob.u0,p=prob.p,args...;kwargs...)
sol = solve(remake(prob,u0=u0,p=p),alg,args...;kwargs...)
RecursiveArrayTools.VectorOfArray(sol.u)
function ChainRulesCore.frule(::typeof(solve_up),prob,
sensealg::Union{Nothing,AbstractSensitivityAlgorithm},
u0,p,args...;
kwargs...)
_solve_forward(prob,sensealg,u0,p,args...;kwargs...)
end

function ChainRulesCore.frule(::typeof(concrete_solve),prob,alg,u0,p,args...;
sensealg=nothing,kwargs...)
_concrete_solve_forward(prob,alg,sensealg,u0,p,args...;kwargs...)
function ChainRulesCore.rrule(::typeof(solve_up),prob,
sensealg::Union{Nothing,AbstractSensitivityAlgorithm},
u0,p,args...;
kwargs...)
_solve_adjoint(prob,sensealg,u0,p,args...;kwargs...)
end

function ChainRulesCore.rrule(::typeof(concrete_solve),prob,alg,u0,p,args...;
sensealg=nothing,kwargs...)
_concrete_solve_adjoint(prob,alg,sensealg,u0,p,args...;kwargs...)
###
### Legacy Dispatches to be Non-Breaking
###

function _solve_adjoint(prob,sensealg,u0,p,args...;kwargs...)
if isempty(args)
_concrete_solve_adjoint(prob,nothing,sensealg,u0,p;kwargs...)
else
_concrete_solve_adjoint(prob,args[1],sensealg,u0,p,Base.tail(args)...;kwargs...)
end
end

ZygoteRules.@adjoint function concrete_solve(prob,alg,u0,p,args...;
sensealg=nothing,kwargs...)
_concrete_solve_adjoint(prob,alg,sensealg,u0,p,args...;kwargs...)
function _solve_forward(prob,sensealg,u0,p,args...;kwargs...)
if isempty(args)
_concrete_solve_forward(prob,nothing,sensealg,u0,p;kwargs...)
else
_concrete_solve_forward(prob,args[1],sensealg,u0,p,Base.tail(args)...;kwargs...)
end
end

function _concrete_solve_adjoint(args...;kwargs...)
Expand Down
22 changes: 10 additions & 12 deletions src/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,20 @@ end
end
@inline ODE_DEFAULT_NORM(u::Tracker.TrackedReal,t::Tracker.TrackedReal) = abs(u)

function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::Tracker.TrackedArray,p::Tracker.TrackedArray,args...;
sensealg=nothing,kwargs...)
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::Tracker.TrackedArray,p::Tracker.TrackedArray,args...;kwargs...)
Tracker.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
end

function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0,p::Tracker.TrackedArray,args...;
sensealg=nothing,kwargs...)
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::Tracker.TrackedArray,p,args...;kwargs...)
Tracker.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
end

function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::Tracker.TrackedArray,p,args...;
sensealg=nothing,kwargs...)
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0,p::Tracker.TrackedArray,args...;kwargs...)
Tracker.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
end

Tracker.@grad function concrete_solve(prob,alg,u0,p,args...;
sensealg=nothing,kwargs...)
_concrete_solve_adjoint(prob,alg,sensealg,Tracker.data(u0),Tracker.data(p),args...;kwargs...)
Tracker.@grad function solve_up(prob,sensealg::Union{Nothing,AbstractSensitivityAlgorithm},
u0,p,args...;
kwargs...)
_solve_adjoint(prob,sensealg,Tracker.data(u0),Tracker.data(p),args...;kwargs...)
end
2 changes: 2 additions & 0 deletions src/zygote.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#=
ZygoteRules.@adjoint function ODESolution(u,args...)
function ODESolutionAdjoint(ȳ)
(ȳ,ntuple(_->nothing, length(args))...)
Expand Down Expand Up @@ -32,6 +33,7 @@ ZygoteRules.@adjoint function getindex(sol::DESolution, i, j...)
end
sol[i,j...],DESolution_getindex_adjoint
end
=#

ZygoteRules.@adjoint function (f::ODEFunction)(u,p,t)
if f.vjp === nothing
Expand Down

0 comments on commit a170fd9

Please sign in to comment.