Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Setup solve for adjoints to deprecate concrete_solve #520

Merged
merged 10 commits into from
May 30, 2020
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqBase"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
authors = ["Chris Rackauckas <[email protected]>"]
version = "6.35.2"
version = "6.36.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
26 changes: 26 additions & 0 deletions src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,41 @@ struct ConstantInterpolation{T1,T2} <: AbstractDiffEqInterpolation
u::T2
end

"""
$(TYPEDEF)
"""
struct SensitivityInterpolation{T1,T2} <: AbstractDiffEqInterpolation
t::T1
u::T2
end

interp_summary(::AbstractDiffEqInterpolation) = "Unknown"
interp_summary(::HermiteInterpolation) = "3rd order Hermite"
interp_summary(::LinearInterpolation) = "1st order linear"
interp_summary(::ConstantInterpolation) = "Piecewise constant interpolation"
interp_summary(::Nothing) = "No interpolation"
interp_summary(::SensitivityInterpolation) = "Interpolation disabled due to sensitivity analysis"
interp_summary(sol::DESolution) = interp_summary(sol.interp)

const SENSITIVITY_INTERP_MESSAGE =
"""
Standard interpolation is disabled due to sensitivity analysis being
used for the gradients. Only linear and constant interpolations are
compatible with non-AD sensitivity analysis calculations. Either
utilize tooling like saveat to avoid post-solution interpolation, use
the keyword argument dense=false for linear or constant interpolations,
or use the keyword argument sensealg=SensitivityADPassThrough() to revert
to AD-based derivatives.
"""

(id::HermiteInterpolation)(tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation(tvals,id,idxs,deriv,p,continuity)
(id::HermiteInterpolation)(val,tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation!(val,tvals,id,idxs,deriv,p,continuity)
(id::LinearInterpolation)(tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation(tvals,id,idxs,deriv,p,continuity)
(id::LinearInterpolation)(val,tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation!(val,tvals,id,idxs,deriv,p,continuity)
(id::ConstantInterpolation)(tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation(tvals,id,idxs,deriv,p,continuity)
(id::ConstantInterpolation)(val,tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation!(val,tvals,id,idxs,deriv,p,continuity)
(id::SensitivityInterpolation)(tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation(tvals,id,idxs,deriv,p,continuity)
(id::SensitivityInterpolation)(val,tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation!(val,tvals,id,idxs,deriv,p,continuity)

@inline function interpolation(tvals,id,idxs,deriv,p,continuity::Symbol=:left)
t = id.t; u = id.u
Expand Down Expand Up @@ -72,6 +94,7 @@ interp_summary(sol::DESolution) = interp_summary(sol.interp)
vals[j] = u[i-1][idxs]
end
else
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
dt = t[i] - t[i-1]
Θ = (tval-t[i-1])/dt
idxs_internal = idxs
Expand Down Expand Up @@ -119,6 +142,7 @@ times t (sorted), with values u and derivatives ks
vals[j] = u[i-1][idxs]
end
else
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
dt = t[i] - t[i-1]
Θ = (tval-t[i-1])/dt
idxs_internal = idxs
Expand Down Expand Up @@ -169,6 +193,7 @@ times t (sorted), with values u and derivatives ks
val = u[i-1][idxs]
end
else
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
dt = t[i] - t[i-1]
Θ = (tval-t[i-1])/dt
idxs_internal = idxs
Expand Down Expand Up @@ -211,6 +236,7 @@ times t (sorted), with values u and derivatives ks
copy!(out,u[i-1][idxs])
end
else
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
dt = t[i] - t[i-1]
Θ = (tval-t[i-1])/dt
idxs_internal = idxs
Expand Down
20 changes: 8 additions & 12 deletions src/reversediff.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p::ReverseDiff.TrackedArray,args...;
sensealg=nothing,kwargs...)
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p::ReverseDiff.TrackedArray,args...;kwargs...)
ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
end

function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0,p::ReverseDiff.TrackedArray,args...;
sensealg=nothing,kwargs...)
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0,p::ReverseDiff.TrackedArray,args...;kwargs...)
ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
end

function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p,args...;
sensealg=nothing,kwargs...)
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p,args...;kwargs...)
ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
end

ReverseDiff.@grad function concrete_solve(prob,alg,u0,p,args...;
sensealg=nothing,kwargs...)
out = _concrete_solve_adjoint(prob,alg,sensealg,ReverseDiff.value(u0),ReverseDiff.value(p),args...;kwargs...)
ReverseDiff.@grad function solve_up(prob,sensealg,u0,p,args...;kwargs...)
out = _solve_adjoint(prob,sensealg,ReverseDiff.value(u0),ReverseDiff.value(p),args...;kwargs...)
Array(out[1]),out[2]
end
54 changes: 37 additions & 17 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,22 +100,42 @@ function solution_new_retcode(sol::AbstractODESolution{T,N},retcode) where {T,N}
sol.alg,sol.interp,sol.dense,sol.tslocation,sol.destats,retcode)
end

function solution_new_tslocation(sol::AbstractODESolution{T,N},tslocation) where {T,N}
ODESolution{T,N,typeof(sol.u),typeof(sol.u_analytic),typeof(sol.errors),
typeof(sol.t),typeof(sol.k),
typeof(sol.prob),typeof(sol.alg),typeof(sol.interp),typeof(sol.destats)}(
sol.u,sol.u_analytic,sol.errors,sol.t,sol.k,sol.prob,
sol.alg,sol.interp,sol.dense,tslocation,sol.destats,sol.retcode)
function solution_new_tslocation(sol::AbstractODESolution{T,N},tslocation) where {T,N}
ODESolution{T,N,typeof(sol.u),typeof(sol.u_analytic),typeof(sol.errors),
typeof(sol.t),typeof(sol.k),
typeof(sol.prob),typeof(sol.alg),typeof(sol.interp),typeof(sol.destats)}(
sol.u,sol.u_analytic,sol.errors,sol.t,sol.k,sol.prob,
sol.alg,sol.interp,sol.dense,tslocation,sol.destats,sol.retcode)
end

function solution_slice(sol::AbstractODESolution{T,N},I) where {T,N}
ODESolution{T,N,typeof(sol.u),typeof(sol.u_analytic),typeof(sol.errors),
typeof(sol.t),typeof(sol.k),
typeof(sol.prob),typeof(sol.alg),typeof(sol.interp),typeof(sol.destats)}(
sol.u[I],
sol.u_analytic === nothing ? nothing : sol.u_analytic[I],
sol.errors,sol.t[I],
sol.dense ? sol.k[I] : sol.k,
sol.prob,
sol.alg,sol.interp,false,sol.tslocation,sol.destats,sol.retcode)
end

function sensitivity_solution(sol::AbstractODESolution,u,t)
T = eltype(eltype(u))
N = length((size(sol.prob.u0)..., length(u)))
interp = if typeof(sol.interp) <: LinearInterpolation
LinearInterpolation(t,u)
elseif typeof(sol.interp) <: ConstantInterpolation
ConstantInterpolation(t,u)
else
SensitivityInterpolation(t,u)
end

function solution_slice(sol::AbstractODESolution{T,N},I) where {T,N}
ODESolution{T,N,typeof(sol.u),typeof(sol.u_analytic),typeof(sol.errors),
typeof(sol.t),typeof(sol.k),
typeof(sol.prob),typeof(sol.alg),typeof(sol.interp),typeof(sol.destats)}(
sol.u[I],
sol.u_analytic === nothing ? nothing : sol.u_analytic[I],
sol.errors,sol.t[I],
sol.dense ? sol.k[I] : sol.k,
sol.prob,
sol.alg,sol.interp,false,sol.tslocation,sol.destats,sol.retcode)
end
ODESolution{T,N,typeof(u),typeof(sol.u_analytic),typeof(sol.errors),
typeof(t),Nothing,typeof(sol.prob),typeof(sol.alg),
typeof(interp),typeof(sol.destats)}(
u,sol.u_analytic,sol.errors,t,nothing,sol.prob,
sol.alg,interp,
sol.dense,sol.tslocation,
sol.destats,sol.retcode)
end
21 changes: 21 additions & 0 deletions src/solutions/rode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,24 @@ function solution_slice(sol::AbstractRODESolution{T,N},I) where {T,N}
false,sol.tslocation,sol.destats,
sol.retcode,sol.seed)
end

function sensitivity_solution(sol::AbstractRODESolution,u,t)
T = eltype(eltype(u))
N = length((size(sol.prob.u0)..., length(u)))
interp = if typeof(sol.interp) <: LinearInterpolation
LinearInterpolation(t,u)
elseif typeof(sol.interp) <: ConstantInterpolation
ConstantInterpolation(t,u)
else
SensitivityInterpolation(t,u)
end

RODESolution{T,N,typeof(u),typeof(sol.u_analytic),
typeof(sol.errors),typeof(t),
typeof(nothing),typeof(sol.prob),typeof(sol.alg),
typeof(sol.interp),typeof(sol.destats)}(
u,sol.u_analytic,sol.errors,t,nothing,sol.prob,
sol.alg,sol.interp,
sol.dense,sol.tslocation,sol.destats,
sol.retcode,sol.seed)
end
9 changes: 9 additions & 0 deletions src/solutions/steady_state_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,12 @@ function build_solution(prob::AbstractSteadyStateProblem,

SteadyStateSolution{T,N,typeof(u),typeof(resid),typeof(prob),typeof(alg)}(u,resid,prob,alg,retcode)
end

function sensitivity_solution(sol::AbstractSteadyStateSolution,u)
T = eltype(eltype(u))
N = length((size(sol.prob.u0)...,))

SteadyStateSolution{T,N,typeof(u),typeof(sol.resid),
typeof(sol.prob),typeof(sol.alg)}(
u,sol.resid,sol.prob,sol.alg,sol.retcode)
end
Loading