From a170fd982eb063fe7ed6ed71ccb3ba93f82c75bd Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 29 May 2020 01:46:39 -0400 Subject: [PATCH] Setup solve for adjoints to deprecate concrete_solve Fixes https://github.com/SciML/DifferentialEquations.jl/issues/610 Is non-breaking --- src/reversediff.jl | 20 +++----- src/solve.jl | 119 +++++++++++++++++++++++++++------------------ src/tracker.jl | 22 ++++----- src/zygote.jl | 2 + 4 files changed, 92 insertions(+), 71 deletions(-) diff --git a/src/reversediff.jl b/src/reversediff.jl index ea8f8ce45..74976da1d 100644 --- a/src/reversediff.jl +++ b/src/reversediff.jl @@ -1,20 +1,16 @@ -function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p::ReverseDiff.TrackedArray,args...; - sensealg=nothing,kwargs...) - ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...) +function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p::ReverseDiff.TrackedArray,args...;kwargs...) + ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...) end -function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0,p::ReverseDiff.TrackedArray,args...; - sensealg=nothing,kwargs...) - ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...) +function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0,p::ReverseDiff.TrackedArray,args...;kwargs...) + ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...) end -function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p,args...; - sensealg=nothing,kwargs...) - ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...) +function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p,args...;kwargs...) + ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...) end -ReverseDiff.@grad function concrete_solve(prob,alg,u0,p,args...; - sensealg=nothing,kwargs...) - out = _concrete_solve_adjoint(prob,alg,sensealg,ReverseDiff.value(u0),ReverseDiff.value(p),args...;kwargs...) +ReverseDiff.@grad function solve_up(prob,sensealg,u0,p,args...;kwargs...) + out = _solve_adjoint(prob,sensealg,ReverseDiff.value(u0),ReverseDiff.value(p),args...;kwargs...) Array(out[1]),out[2] end diff --git a/src/solve.jl b/src/solve.jl index 4860738ff..7c334e3b3 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -59,11 +59,17 @@ function solve_call(_prob,args...;merge_callbacks = true, kwargs...) else __solve(_prob,args...; kwargs...)#::T end +end +function solve(prob::DEProblem,args...;sensealg=nothing, + u0 = nothing, p = nothing,kwargs...) + u0 = u0 !== nothing ? u0 : prob.u0 + p = p !== nothing ? p : prob.p + solve_up(prob,sensealg,u0,p,args...;kwargs...) end -function solve(prob::DEProblem,args...;kwargs...) - _prob = get_concrete_problem(prob,kwargs) +function solve_up(prob::DEProblem,sensealg,u0,p,args...;kwargs...) + _prob = get_concrete_problem(prob;u0=u0,p=p,kwargs...) if haskey(kwargs,:alg) && (isempty(args) || args[1] === nothing) alg = kwargs[:alg] isadaptive(alg) && @@ -93,21 +99,21 @@ function solve(prob::EnsembleProblem,args...;kwargs...) end end -function solve(prob::AbstractNoiseProblem,args...;kwargs...) +function solve(prob::AbstractNoiseProblem,args...; kwargs...) __solve(prob,args...;kwargs...) end -function get_concrete_problem(prob::AbstractJumpProblem,kwargs) +function get_concrete_problem(prob::AbstractJumpProblem; kwargs...) prob end -function get_concrete_problem(prob::AbstractSteadyStateProblem, kwargs) +function get_concrete_problem(prob::AbstractSteadyStateProblem; kwargs...) u0 = get_concrete_u0(prob, Inf, kwargs) u0 = promote_u0(u0, prob.p, nothing) remake(prob; u0 = u0) end -function get_concrete_problem(prob::AbstractEnsembleProblem, kwargs) +function get_concrete_problem(prob::AbstractEnsembleProblem; kwargs...) prob end @@ -118,45 +124,45 @@ end function discretize end -function get_concrete_problem(prob, kwargs) - tspan = get_concrete_tspan(prob, kwargs) +function get_concrete_problem(prob; kwargs...) + p = get_concrete_p(prob, kwargs) + tspan = get_concrete_tspan(prob, kwargs, p) u0 = get_concrete_u0(prob, tspan[1], kwargs) - u0_promote = promote_u0(u0, prob.p, tspan[1]) - tspan_promote = promote_tspan(u0, prob.p, tspan, prob, kwargs) + u0_promote = promote_u0(u0, p, tspan[1]) + tspan_promote = promote_tspan(u0, p, tspan, prob, kwargs) if isconcreteu0(prob, tspan[1], kwargs) && typeof(u0_promote) === typeof(u0) && prob.tspan == tspan && typeof(tspan) === typeof(tspan_promote) return prob else - return remake(prob; u0 = u0_promote, tspan = tspan_promote) + return remake(prob; u0 = u0_promote, p=p, tspan = tspan_promote) end end -function get_concrete_problem(prob::DDEProblem, kwargs) - tspan = get_concrete_tspan(prob, kwargs) +function get_concrete_problem(prob::DDEProblem; kwargs...) + p = get_concrete_p(prob, kwargs) + tspan = get_concrete_tspan(prob, kwargs, p) u0 = get_concrete_u0(prob, tspan[1], kwargs) if prob.constant_lags isa Function - constant_lags = prob.constant_lags(prob.p) + constant_lags = prob.constant_lags(p) else constant_lags = prob.constant_lags end - u0 = promote_u0(u0, prob.p, tspan[1]) - tspan = promote_tspan(u0, prob.p, tspan, prob, kwargs) + u0 = promote_u0(u0, p, tspan[1]) + tspan = promote_tspan(u0, p, tspan, prob, kwargs) - remake(prob; u0 = u0, tspan = tspan, constant_lags = constant_lags) + remake(prob; u0 = u0, tspan = tspan, p=p, constant_lags = constant_lags) end -function get_concrete_tspan(prob, kwargs) +function get_concrete_tspan(prob, kwargs, p) if prob.tspan isa Function - tspan = prob.tspan(prob.p) - elseif prob.tspan === (nothing, nothing) - if haskey(kwargs, :tspan) + tspan = prob.tspan(p) + elseif haskey(kwargs, :tspan) tspan = kwargs[:tspan] - else - error("No tspan is set in the problem or chosen in the init/solve call") - end + elseif prob.tspan === (nothing, nothing) + error("No tspan is set in the problem or chosen in the init/solve call") else tspan = prob.tspan end @@ -171,7 +177,7 @@ end function get_concrete_u0(prob, t0, kwargs) if eval_u0(prob.u0) u0 = prob.u0(prob.p, t0) - elseif prob.u0 === nothing + elseif haskey(kwargs,:u0) u0 = kwargs[:u0] else u0 = prob.u0 @@ -180,6 +186,14 @@ function get_concrete_u0(prob, t0, kwargs) handle_distribution_u0(u0) end +function get_concrete_p(prob, kwargs) + if haskey(kwargs,:p) + p = kwargs[:p] + else + p = prob.p + end +end + handle_distribution_u0(_u0) = _u0 eval_u0(u0::Function) = true eval_u0(u0) = false @@ -218,38 +232,49 @@ end ################### Concrete Solve -function _concrete_solve end +@deprecate concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing}, + u0=prob.u0,p=prob.p,args...;kwargs...) solve(prob,alg,args...;u0=u0,p=p,kwargs...) -function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing}, - u0=prob.u0,p=prob.p,args...;kwargs...) - _concrete_solve(prob,alg,u0,p,args...;kwargs...) -end +struct SensitivityADPassThrough <: DiffEqBase.DEAlgorithm end -function _concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing}, - u0=prob.u0,p=prob.p,args...;kwargs...) - sol = solve(remake(prob,u0=u0,p=p),alg,args...;kwargs...) - RecursiveArrayTools.DiffEqArray(sol.u,sol.t) +ZygoteRules.@adjoint function solve_up(prob,sensealg::Union{Nothing,AbstractSensitivityAlgorithm}, + u0,p,args...; + kwargs...) + _solve_adjoint(prob,sensealg,u0,p,args...;kwargs...) end -function _concrete_solve(prob::DiffEqBase.SteadyStateProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing}, - u0=prob.u0,p=prob.p,args...;kwargs...) - sol = solve(remake(prob,u0=u0,p=p),alg,args...;kwargs...) - RecursiveArrayTools.VectorOfArray(sol.u) +function ChainRulesCore.frule(::typeof(solve_up),prob, + sensealg::Union{Nothing,AbstractSensitivityAlgorithm}, + u0,p,args...; + kwargs...) + _solve_forward(prob,sensealg,u0,p,args...;kwargs...) end -function ChainRulesCore.frule(::typeof(concrete_solve),prob,alg,u0,p,args...; - sensealg=nothing,kwargs...) - _concrete_solve_forward(prob,alg,sensealg,u0,p,args...;kwargs...) +function ChainRulesCore.rrule(::typeof(solve_up),prob, + sensealg::Union{Nothing,AbstractSensitivityAlgorithm}, + u0,p,args...; + kwargs...) + _solve_adjoint(prob,sensealg,u0,p,args...;kwargs...) end -function ChainRulesCore.rrule(::typeof(concrete_solve),prob,alg,u0,p,args...; - sensealg=nothing,kwargs...) - _concrete_solve_adjoint(prob,alg,sensealg,u0,p,args...;kwargs...) +### +### Legacy Dispatches to be Non-Breaking +### + +function _solve_adjoint(prob,sensealg,u0,p,args...;kwargs...) + if isempty(args) + _concrete_solve_adjoint(prob,nothing,sensealg,u0,p;kwargs...) + else + _concrete_solve_adjoint(prob,args[1],sensealg,u0,p,Base.tail(args)...;kwargs...) + end end -ZygoteRules.@adjoint function concrete_solve(prob,alg,u0,p,args...; - sensealg=nothing,kwargs...) - _concrete_solve_adjoint(prob,alg,sensealg,u0,p,args...;kwargs...) +function _solve_forward(prob,sensealg,u0,p,args...;kwargs...) + if isempty(args) + _concrete_solve_forward(prob,nothing,sensealg,u0,p;kwargs...) + else + _concrete_solve_forward(prob,args[1],sensealg,u0,p,Base.tail(args)...;kwargs...) + end end function _concrete_solve_adjoint(args...;kwargs...) diff --git a/src/tracker.jl b/src/tracker.jl index 18d4c71f9..a6345b5c3 100644 --- a/src/tracker.jl +++ b/src/tracker.jl @@ -30,22 +30,20 @@ end end @inline ODE_DEFAULT_NORM(u::Tracker.TrackedReal,t::Tracker.TrackedReal) = abs(u) -function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::Tracker.TrackedArray,p::Tracker.TrackedArray,args...; - sensealg=nothing,kwargs...) - Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...) +function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::Tracker.TrackedArray,p::Tracker.TrackedArray,args...;kwargs...) + Tracker.track(solve_up,prob,sensealg,u0,p,args...;kwargs...) end -function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0,p::Tracker.TrackedArray,args...; - sensealg=nothing,kwargs...) - Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...) +function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::Tracker.TrackedArray,p,args...;kwargs...) + Tracker.track(solve_up,prob,sensealg,u0,p,args...;kwargs...) end -function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::Tracker.TrackedArray,p,args...; - sensealg=nothing,kwargs...) - Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...) +function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0,p::Tracker.TrackedArray,args...;kwargs...) + Tracker.track(solve_up,prob,sensealg,u0,p,args...;kwargs...) end -Tracker.@grad function concrete_solve(prob,alg,u0,p,args...; - sensealg=nothing,kwargs...) - _concrete_solve_adjoint(prob,alg,sensealg,Tracker.data(u0),Tracker.data(p),args...;kwargs...) +Tracker.@grad function solve_up(prob,sensealg::Union{Nothing,AbstractSensitivityAlgorithm}, + u0,p,args...; + kwargs...) + _solve_adjoint(prob,sensealg,Tracker.data(u0),Tracker.data(p),args...;kwargs...) end diff --git a/src/zygote.jl b/src/zygote.jl index 7162a56b6..ed23501f6 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -1,3 +1,4 @@ +#= ZygoteRules.@adjoint function ODESolution(u,args...) function ODESolutionAdjoint(ȳ) (ȳ,ntuple(_->nothing, length(args))...) @@ -32,6 +33,7 @@ ZygoteRules.@adjoint function getindex(sol::DESolution, i, j...) end sol[i,j...],DESolution_getindex_adjoint end +=# ZygoteRules.@adjoint function (f::ODEFunction)(u,p,t) if f.vjp === nothing