Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use adjoints from SteadyStateProblems also for NonlinearProblems #684

Merged
merged 4 commits into from
Jul 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ ArrayInterfaceCore = "0.1.1"
ArrayInterfaceTracker = "0.1"
Cassette = "0.3.6"
ChainRulesCore = "0.10.7, 1"
DiffEqBase = "6.90"
DiffEqBase = "6.93"
DiffEqCallbacks = "2.17"
DiffEqNoiseProcess = "4.1.4, 5.0"
DiffEqOperators = "4.34"
Expand All @@ -63,7 +63,7 @@ RandomNumbers = "1.5.3"
RecursiveArrayTools = "2.4.2"
Reexport = "0.2, 1.0"
ReverseDiff = "1.9"
SciMLBase = "1.24"
SciMLBase = "1.42.3"
StochasticDiffEq = "6.20"
Tracker = "0.2"
Zygote = "0.6"
Expand All @@ -78,6 +78,7 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationFlux = "253f991c-a7b2-45f8-8852-8b9a9df78a86"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Expand All @@ -90,4 +91,4 @@ SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "SparseArrays"]
test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "NonlinearSolve", "SparseArrays"]
6 changes: 3 additions & 3 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
quad = false,
noiseterm = false, needs_jac = false) where {G, DG1, DG2}
prob = sol.prob
if prob isa DiffEqBase.SteadyStateProblem
if prob isa Union{SteadyStateProblem, NonlinearProblem}
@unpack u0, p = prob
tspan = (nothing, nothing)
#elseif prob isa SDEProblem
Expand Down Expand Up @@ -123,7 +123,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
end
end

if prob isa DiffEqBase.SteadyStateProblem
if prob isa Union{SteadyStateProblem, NonlinearProblem}
y = copy(sol.u)
else
y = copy(sol.u[end])
Expand All @@ -138,7 +138,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
@assert sensealg.autojacvec !== nothing

if sensealg.autojacvec isa ReverseDiffVJP
if prob isa DiffEqBase.SteadyStateProblem
if prob isa Union{SteadyStateProblem, NonlinearProblem}
if DiffEqBase.isinplace(prob)
tape = ReverseDiff.GradientTape((y, _p)) do u, p
du1 = p !== nothing && p !== DiffEqBase.NullParameters() ?
Expand Down
10 changes: 5 additions & 5 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::Bool, dgrad, dy,
prob = getprob(S)

@unpack J, uf, f_cache, jac_config = S.diffcache
if !(prob isa DiffEqBase.SteadyStateProblem)
if !(prob isa Union{SteadyStateProblem, NonlinearProblem})
if W === nothing
if DiffEqBase.has_jac(f)
f.jac(J, y, p, t) # Calculate the Jacobian into J
Expand Down Expand Up @@ -403,7 +403,7 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ReverseDiffVJP, dg
_p = p
end

if typeof(prob) <: SteadyStateProblem ||
if prob isa Union{SteadyStateProblem, NonlinearProblem} ||
(eltype(λ) <: eltype(prob.u0) && typeof(t) <: eltype(prob.u0) &&
compile_tape(sensealg.autojacvec))
tape = S.diffcache.paramjac_config
Expand Down Expand Up @@ -442,7 +442,7 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ReverseDiffVJP, dg
end
end

if prob isa DiffEqBase.SteadyStateProblem
if prob isa Union{SteadyStateProblem, NonlinearProblem}
tu, tp = ReverseDiff.input_hook(tape)
else
if W === nothing
Expand All @@ -454,13 +454,13 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ReverseDiffVJP, dg
output = ReverseDiff.output_hook(tape)
ReverseDiff.unseed!(tu) # clear any "leftover" derivatives from previous calls
ReverseDiff.unseed!(tp)
if !(prob isa DiffEqBase.SteadyStateProblem)
if !(prob isa Union{SteadyStateProblem, NonlinearProblem})
ReverseDiff.unseed!(tt)
end
W !== nothing && ReverseDiff.unseed!(tW)
ReverseDiff.value!(tu, y)
typeof(p) <: DiffEqBase.NullParameters || ReverseDiff.value!(tp, p)
if !(prob isa DiffEqBase.SteadyStateProblem)
if !(prob isa Union{SteadyStateProblem, NonlinearProblem})
ReverseDiff.value!(tt, [t])
end
W !== nothing && ReverseDiff.value!(tW, W)
Expand Down
8 changes: 7 additions & 1 deletion src/steadystate_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ function SteadyStateAdjointSensitivityFunction(g,
sol,
dgdu,
dgdp,
f,
colorvec,
needs_jac)
@unpack f, p, u0 = sol.prob
@unpack p, u0 = sol.prob

diffcache, y = adjointdiffcache(g,
sensealg,
Expand Down Expand Up @@ -62,6 +63,10 @@ end
kwargs...) where {DG1, DG2, G}
@unpack f, p, u0 = sol.prob

if sol.prob isa NonlinearProblem
f = convert(ODEFunction, f)
end

dgdu === nothing && dgdp === nothing && g === nothing &&
error("Either `dgdu`, `dgdp`, or `g` must be specified.")

Expand All @@ -78,6 +83,7 @@ end
sol,
dgdu,
dgdp,
f,
f.colorvec,
needs_jac)
@unpack diffcache, y, sol, λ, vjp, linsolve = sense
Expand Down
43 changes: 42 additions & 1 deletion test/steady_state.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using Test, LinearAlgebra
using SciMLSensitivity, SteadyStateDiffEq, DiffEqBase, NLsolve
using OrdinaryDiffEq
using NonlinearSolve
using ForwardDiff, Calculus
using Zygote
using Random
Random.seed!(12345)

Expand Down Expand Up @@ -272,7 +274,6 @@ Random.seed!(12345)
end
end

using Zygote
@testset "concrete_solve derivatives steady state solver" begin
function g1(u, p, t)
sum(u)
Expand Down Expand Up @@ -385,3 +386,43 @@ using Zygote
@test res1oop[1]≈dp2oop[1] rtol=1e-10
end
end

@testset "NonlinearProblem" begin
u0 = [0.0]
p = [2.0, 1.0]
prob = NonlinearProblem((du, u, p) -> du[1] = u[1] - p[1] + p[2], u0, p)
prob2 = NonlinearProblem{false}((u, p) -> u .- p[1] .+ p[2], u0, p)

solve1 = solve(remake(prob, p = p), NewtonRaphson())
solve2 = solve(prob2, NewtonRaphson())
@test solve1.u == solve2.u

prob3 = SteadyStateProblem((u, p, t) -> -u .+ p[1] .- p[2], [0.0], p)
solve3 = solve(prob3, DynamicSS(Rodas5()))
@test solve1.u≈solve3.u rtol=1e-6

prob4 = SteadyStateProblem((du, u, p, t) -> du[1] = -u[1] + p[1] - p[2], [0.0], p)
solve4 = solve(prob4, DynamicSS(Rodas5()))
@test solve3.u≈solve4.u rtol=1e-10

function test_loss(p, prob; alg = NewtonRaphson())
_prob = remake(prob, p = p)
sol = sum(solve(_prob, alg,
sensealg = SteadyStateAdjoint(autojacvec = ReverseDiffVJP())))
return sol
end

test_loss(p, prob)
test_loss(p, prob2)
test_loss(p, prob3, alg = DynamicSS(Rodas5()))
test_loss(p, prob4, alg = DynamicSS(Rodas5()))

dp1 = Zygote.gradient(p -> test_loss(p, prob), p)[1]
dp2 = Zygote.gradient(p -> test_loss(p, prob2), p)[1]
dp3 = Zygote.gradient(p -> test_loss(p, prob3, alg = DynamicSS(Rodas5())), p)[1]
dp4 = Zygote.gradient(p -> test_loss(p, prob4, alg = DynamicSS(Rodas5())), p)[1]

@test dp1≈dp2 rtol=1e-10
@test dp1≈dp3 rtol=1e-10
@test dp1≈dp4 rtol=1e-10
end