Skip to content

Commit

Permalink
Redesign default ODE solver to be fully type-grounded
Browse files Browse the repository at this point in the history
This accomplishes a few things:

* Faster precompile times by precompiling less
* Full inference of results when using the automatic algorithm
* Hopefully faster load times by also precompiling less

This is done the same way as

* linearsolve SciML/LinearSolve.jl#307
* nonlinearsolve SciML/NonlinearSolve.jl#238

and is thus the more modern SciML way of doing it. It avoids dispatch by having a single algorithm that always generates the full cache and instead of dispatching between algorithms always branches for the choice.

It turns out, the mechanism already existed for this in OrdinaryDiffEq... it's CompositeAlgorithm, the same bones as AutoSwitch! As such, this reuses quite a bit of code from the auto-switch algorithms but instead of just having two choices it (currently) has 6 that it chooses between. This means that it has stiffness detection and switching behavior, but also in a size-dependent way.

There are still some optimizations to do though. Like LinearSolve.jl, it would be more efficient to have a way to initialize the caches to size zero and then have a way to re-initialize them to the correct size. Right now, it'll generate the same Jacobian N times and it shouldn't need to do that.

fix typo / test

fix precompilation choices

Update src/composite_algs.jl

Co-authored-by: Nathanael Bosch <[email protected]>

Update src/composite_algs.jl

switch CompositeCache away from tuple so it can start undef

Default Cache

fix precompile

remove fallbacks

remove fallbacks
  • Loading branch information
ChrisRackauckas authored and oscardssmith committed May 14, 2024
1 parent b8bdf0f commit aa1c546
Show file tree
Hide file tree
Showing 10 changed files with 389 additions and 158 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Expand Down Expand Up @@ -75,7 +76,6 @@ RecursiveArrayTools = "2.36, 3"
Reexport = "1.0"
SciMLBase = "2.27.1"
SciMLOperators = "0.3"
SciMLStructures = "1"
SimpleNonlinearSolve = "1"
SimpleUnPack = "1"
SparseArrays = "1.9"
Expand Down
21 changes: 16 additions & 5 deletions src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ using LinearSolve, SimpleNonlinearSolve

using LineSearches

import EnumX

import FillArrays: Trues

# Interfaces
Expand Down Expand Up @@ -141,6 +143,7 @@ include("nlsolve/functional.jl")
include("nlsolve/newton.jl")

include("generic_rosenbrock.jl")
include("composite_algs.jl")

include("caches/basic_caches.jl")
include("caches/low_order_rk_caches.jl")
Expand Down Expand Up @@ -234,7 +237,6 @@ include("constants.jl")
include("solve.jl")
include("initdt.jl")
include("interp_func.jl")
include("composite_algs.jl")

import PrecompileTools

Expand All @@ -253,9 +255,14 @@ PrecompileTools.@compile_workload begin
Tsit5(), Vern7()
]

stiff = [Rosenbrock23(), Rosenbrock23(autodiff = false),
Rodas5P(), Rodas5P(autodiff = false),
FBDF(), FBDF(autodiff = false)
stiff = [Rosenbrock23(),
Rodas5P(),
FBDF()
]

default_ode = [
DefaultODEAlgorithm(autodiff=false),
DefaultODEAlgorithm()
]

autoswitch = [
Expand Down Expand Up @@ -284,7 +291,11 @@ PrecompileTools.@compile_workload begin
append!(solver_list, stiff)
end

if Preferences.@load_preference("PrecompileAutoSwitch", true)
if Preferences.@load_preference("PrecompileDefault", true)
append!(solver_list, default_ode)
end

if Preferences.@load_preference("PrecompileAutoSwitch", false)
append!(solver_list, autoswitch)
end

Expand Down
29 changes: 21 additions & 8 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ isimplicit(alg::CompositeAlgorithm) = any(isimplicit.(alg.algs))

isdtchangeable(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = true
isdtchangeable(alg::CompositeAlgorithm) = all(isdtchangeable.(alg.algs))
isdtchangeable(alg::DefaultSolverAlgorithm) = true

function isdtchangeable(alg::Union{LawsonEuler, NorsettEuler, LieEuler, MagnusGauss4,
CayleyEuler, ETDRK2, ETDRK3, ETDRK4, HochOst4, ETD2})
false
Expand All @@ -184,12 +186,14 @@ ismultistep(alg::ETD2) = true
isadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false
isadaptive(alg::OrdinaryDiffEqAdaptiveAlgorithm) = true
isadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = all(isadaptive.(alg.algs))
isadaptive(alg::DefaultSolverAlgorithm) = true
isadaptive(alg::DImplicitEuler) = true
isadaptive(alg::DABDF2) = true
isadaptive(alg::DFBDF) = true

anyadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = isadaptive(alg)
anyadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = any(isadaptive, alg.algs)
anyadaptive(alg::DefaultSolverAlgorithm) = true

isautoswitch(alg) = false
isautoswitch(alg::CompositeAlgorithm) = alg.choice_function isa AutoSwitch
Expand All @@ -199,9 +203,11 @@ function qmin_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm})
end
qmin_default(alg::CompositeAlgorithm) = maximum(qmin_default.(alg.algs))
qmin_default(alg::DP8) = 1 // 3
qmin_default(alg::DefaultSolverAlgorithm) = 1 // 5

qmax_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = 10
qmax_default(alg::CompositeAlgorithm) = minimum(qmax_default.(alg.algs))
qmax_default(alg::DefaultSolverAlgorithm) = 10
qmax_default(alg::DP8) = 6
qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8

Expand Down Expand Up @@ -287,7 +293,7 @@ end

function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob)
algs = map(alg -> DiffEqBase.prepare_alg(alg, u0, p, prob), alg.algs)
CompositeAlgorithm(algs, alg.choice_function)
CompositeAlgorithm(algs, alg.choice_function,)
end

has_autodiff(alg::OrdinaryDiffEqAlgorithm) = false
Expand Down Expand Up @@ -370,6 +376,7 @@ end

alg_extrapolates(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false
alg_extrapolates(alg::CompositeAlgorithm) = any(alg_extrapolates.(alg.algs))
alg_extrapolates(alg::DefaultSolverAlgorithm) = false
alg_extrapolates(alg::ImplicitEuler) = true
alg_extrapolates(alg::DImplicitEuler) = true
alg_extrapolates(alg::DABDF2) = true
Expand Down Expand Up @@ -733,6 +740,7 @@ alg_order(alg::QPRK98) = 9

alg_maximum_order(alg) = alg_order(alg)
alg_maximum_order(alg::CompositeAlgorithm) = maximum(alg_order(x) for x in alg.algs)
alg_maximum_order(alg::DefaultSolverAlgorithm) = 7
alg_maximum_order(alg::ExtrapolationMidpointDeuflhard) = 2(alg.max_order + 1)
alg_maximum_order(alg::ImplicitDeuflhardExtrapolation) = 2(alg.max_order + 1)
alg_maximum_order(alg::ExtrapolationMidpointHairerWanner) = 2(alg.max_order + 1)
Expand Down Expand Up @@ -869,6 +877,7 @@ function gamma_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm})
isadaptive(alg) ? 9 // 10 : 0
end
gamma_default(alg::CompositeAlgorithm) = maximum(gamma_default, alg.algs)
gamma_default(alg::DefaultSolverAlgorithm) = 9 // 10
gamma_default(alg::RKC) = 8 // 10
gamma_default(alg::IRKC) = 8 // 10
function gamma_default(alg::ExtrapolationMidpointDeuflhard)
Expand Down Expand Up @@ -989,14 +998,18 @@ function unwrap_alg(integrator, is_stiff)
if !iscomp
return alg
elseif alg.choice_function isa AutoSwitchCache
if is_stiff === nothing
throwautoswitch(alg)
end
num = is_stiff ? 2 : 1
if num == 1
return alg.algs[1]
if alg.choice_function.algtrait isa DefaultODESolver
alg.algs[alg.choice_function.current]
else
return alg.algs[2]
if is_stiff === nothing
throwautoswitch(alg)
end
num = is_stiff ? 2 : 1
if num == 1
return alg.algs[1]
else
return alg.algs[2]
end
end
else
return _eval_index(identity, alg.algs, integrator.cache.current)
Expand Down
37 changes: 37 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3242,6 +3242,9 @@ end
struct CompositeAlgorithm{T, F} <: OrdinaryDiffEqCompositeAlgorithm
algs::T
choice_function::F
function CompositeAlgorithm(algs, choice_function)
new{typeof(algs), typeof(choice_function)}(algs, choice_function)
end
end

TruncatedStacktraces.@truncate_stacktrace CompositeAlgorithm 1
Expand All @@ -3250,6 +3253,40 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(CompositeAlgorithm)
end

mutable struct AutoSwitchCache{Trait, nAlg, sAlg, tolType, T}
algtrait::Trait
count::Int
successive_switches::Int
nonstiffalg::nAlg
stiffalg::sAlg
is_stiffalg::Bool
maxstiffstep::Int
maxnonstiffstep::Int
nonstifftol::tolType
stifftol::tolType
dtfac::T
stiffalgfirst::Bool
switch_max::Int
current::Int
end

struct AutoSwitch{Trait, nAlg, sAlg, tolType, T}
algtrait::Trait
nonstiffalg::nAlg
stiffalg::sAlg
maxstiffstep::Int
maxnonstiffstep::Int
nonstifftol::tolType
stifftol::tolType
dtfac::T
stiffalgfirst::Bool
switch_max::Int
end

struct DefaultODESolver end
const DefaultSolverAlgorithm = Union{CompositeAlgorithm{<:Tuple, <:AutoSwitch{DefaultODESolver}},
CompositeAlgorithm{<:Tuple, <:AutoSwitchCache{DefaultODESolver}}}

################################################################################
"""
MEBDF2: Multistep Method
Expand Down
55 changes: 38 additions & 17 deletions src/caches/basic_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,26 @@ end

TruncatedStacktraces.@truncate_stacktrace CompositeCache 1

if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(CompositeCache)
mutable struct DefaultCache{T1, T2, T3, T4, T5, T6, A, F} <: OrdinaryDiffEqCache
args::A
choice_function::F
current::Int
cache1::T1
cache2::T2
cache3::T3
cache4::T4
cache5::T5
cache6::T6
function DefaultCache{T1, T2, T3, T4, T5, T6, F}(args, choice_function, current) where {T1, T2, T3, T4, T5, T6, F}
new{T1, T2, T3, T4, T5, T6, typeof(args), F}(args, choice_function, current)
end
end

function alg_cache(alg::CompositeAlgorithm{Tuple{T1, T2}, F}, u, rate_prototype,
::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev,
uprev2, f, t, dt, reltol, p, calck,
::Val{V}) where {T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits,
tTypeNoUnits}
caches = (
alg_cache(alg.algs[1], u, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits,
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)),
alg_cache(alg.algs[2], u, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits,
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)))
CompositeCache(caches, alg.choice_function, 1)
TruncatedStacktraces.@truncate_stacktrace DefaultCache 1

if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(CompositeCache)
Base.Experimental.silence!(DefaultCache)
end

function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -38,7 +40,26 @@ function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoU
::Val{V}) where {V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
caches = __alg_cache(alg.algs, u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits,
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V))
CompositeCache(caches, alg.choice_function, 1)
CompositeCache{typeof(caches), typeof(alg.choice_function)}(
caches, alg.choice_function, 1)
end

function alg_cache(alg::CompositeAlgorithm{Tuple{A1, A2, A3, A4, A5, A6}}, u,
rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits},
uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{V}) where {V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, A1, A2, A3, A4, A5, A6}

args = (u, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt,
reltol, p, calck, Val(V))
argT = map(typeof, args)
T1 = Base.promote_op(alg_cache, A1, argT...)
T2 = Base.promote_op(alg_cache, A2, argT...)
T3 = Base.promote_op(alg_cache, A3, argT...)
T4 = Base.promote_op(alg_cache, A4, argT...)
T5 = Base.promote_op(alg_cache, A5, argT...)
T6 = Base.promote_op(alg_cache, A6, argT...)
DefaultCache{T1, T2, T3, T4, T5, T6, typeof(alg.choice_function)}(args, alg.choice_function, 1)
end

# map + closure approach doesn't infer
Expand Down
Loading

0 comments on commit aa1c546

Please sign in to comment.