From 817d2507a19892b2ed192d1e5a980db2e34a8969 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 1 Jan 2024 15:02:34 -0500 Subject: [PATCH] Redesign default ODE solver to be fully type-grounded This accomplishes a few things: * Faster precompile times by precompiling less * Full inference of results when using the automatic algorithm * Hopefully faster load times by also precompiling less This is done the same way as * linearsolve https://github.com/SciML/LinearSolve.jl/pull/307 * nonlinearsolve https://github.com/SciML/NonlinearSolve.jl/pull/238 and is thus the more modern SciML way of doing it. It avoids dispatch by having a single algorithm that always generates the full cache and instead of dispatching between algorithms always branches for the choice. It turns out, the mechanism already existed for this in OrdinaryDiffEq... it's CompositeAlgorithm, the same bones as AutoSwitch! As such, this reuses quite a bit of code from the auto-switch algorithms but instead of just having two choices it (currently) has 6 that it chooses between. This means that it has stiffness detection and switching behavior, but also in a size-dependent way. There are still some optimizations to do though. Like LinearSolve.jl, it would be more efficient to have a way to initialize the caches to size zero and then have a way to re-initialize them to the correct size. Right now, it'll generate the same Jacobian N times and it shouldn't need to do that. fix typo / test fix precompilation choices Update src/composite_algs.jl Co-authored-by: Nathanael Bosch Update src/composite_algs.jl switch CompositeCache away from tuple so it can start undef Default Cache fix precompile remove fallbacks remove fallbacks --- Project.toml | 223 +++++++++++---------- src/OrdinaryDiffEq.jl | 21 +- src/alg_utils.jl | 29 ++- src/algorithms.jl | 37 ++++ src/caches/basic_caches.jl | 55 +++-- src/caches/verner_caches.jl | 30 ++- src/composite_algs.jl | 160 ++++++++++++--- src/perform_step/composite_perform_step.jl | 149 ++++++++------ src/perform_step/verner_rk_perform_step.jl | 42 ++-- src/solve.jl | 22 +- 10 files changed, 500 insertions(+), 268 deletions(-) diff --git a/Project.toml b/Project.toml index bfe917a7cb..163e27d399 100644 --- a/Project.toml +++ b/Project.toml @@ -1,111 +1,112 @@ -name = "OrdinaryDiffEq" -uuid = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" -authors = ["Chris Rackauckas ", "Yingbo Ma "] -version = "6.75.0" - -[deps] -ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" -DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18" -FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" -FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" -IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" -InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" -Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" -MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" -NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" -Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" -PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" -PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -Preferences = "21216c6a-2e73-6563-6e65-726566657250" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" -Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" -SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" -StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" -TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" - -[compat] -ADTypes = "0.2, 1" -Adapt = "3.0, 4" -ArrayInterface = "7" -DataStructures = "0.18" -DiffEqBase = "6.147" -DocStringExtensions = "0.9" -ExponentialUtilities = "1.22" -FastBroadcast = "0.2" -FastClosures = "0.3" -FillArrays = "1.9" -FiniteDiff = "2" -ForwardDiff = "0.10.3" -FunctionWrappersWrappers = "0.1" -IfElse = "0.1" -InteractiveUtils = "1.9" -LineSearches = "7" -LinearAlgebra = "1.9" -LinearSolve = "2.1.10" -Logging = "1.9" -MacroTools = "0.5" -MuladdMacro = "0.2.1" -NLsolve = "4" -NonlinearSolve = "3.7.3" -Polyester = "0.7" -PreallocationTools = "0.4.15" -PrecompileTools = "1" -Preferences = "1.3" -RecursiveArrayTools = "2.36, 3" -Reexport = "1.0" -SciMLBase = "2.27.1" -SciMLOperators = "0.3" -SimpleNonlinearSolve = "1" -SimpleUnPack = "1" -SparseArrays = "1.9" -SparseDiffTools = "2.3" -StaticArrayInterface = "1.2" -StaticArrays = "1.0" -TruncatedStacktraces = "1.2" -julia = "1.10" - -[extras] -AlgebraicMultigrid = "2169fc97-5a83-5252-b627-83903c6c433c" -Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" -DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d" -ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4" -IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895" -InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" -NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" -ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab" -Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" - -[targets] -test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve"] +name = "OrdinaryDiffEq" +uuid = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +authors = ["Chris Rackauckas ", "Yingbo Ma "] +version = "6.75.0" + +[deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18" +FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" +FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" +IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" +NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" +Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" +PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" +SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" +StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" + +[compat] +ADTypes = "0.2, 1" +Adapt = "3.0, 4" +ArrayInterface = "7" +DataStructures = "0.18" +DiffEqBase = "6.147" +DocStringExtensions = "0.9" +ExponentialUtilities = "1.22" +FastBroadcast = "0.2" +FastClosures = "0.3" +FillArrays = "1.9" +FiniteDiff = "2" +ForwardDiff = "0.10.3" +FunctionWrappersWrappers = "0.1" +IfElse = "0.1" +InteractiveUtils = "1.9" +LineSearches = "7" +LinearAlgebra = "1.9" +LinearSolve = "2.1.10" +Logging = "1.9" +MacroTools = "0.5" +MuladdMacro = "0.2.1" +NLsolve = "4" +NonlinearSolve = "3.7.3" +Polyester = "0.7" +PreallocationTools = "0.4.15" +PrecompileTools = "1" +Preferences = "1.3" +RecursiveArrayTools = "2.36, 3" +Reexport = "1.0" +SciMLBase = "2.27.1" +SciMLOperators = "0.3" +SimpleNonlinearSolve = "1" +SimpleUnPack = "1" +SparseArrays = "1.9" +SparseDiffTools = "2.3" +StaticArrayInterface = "1.2" +StaticArrays = "1.0" +TruncatedStacktraces = "1.2" +julia = "1.10" + +[extras] +AlgebraicMultigrid = "2169fc97-5a83-5252-b627-83903c6c433c" +Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" +DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d" +ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4" +IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" +NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" +ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[targets] +test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve"] diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index f6502e861b..baea651d7d 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -26,6 +26,8 @@ using LinearSolve, SimpleNonlinearSolve using LineSearches +import EnumX + import FillArrays: Trues # Interfaces @@ -139,6 +141,7 @@ include("nlsolve/functional.jl") include("nlsolve/newton.jl") include("generic_rosenbrock.jl") +include("composite_algs.jl") include("caches/basic_caches.jl") include("caches/low_order_rk_caches.jl") @@ -232,7 +235,6 @@ include("constants.jl") include("solve.jl") include("initdt.jl") include("interp_func.jl") -include("composite_algs.jl") import PrecompileTools @@ -251,9 +253,14 @@ PrecompileTools.@compile_workload begin Tsit5(), Vern7() ] - stiff = [Rosenbrock23(), Rosenbrock23(autodiff = false), - Rodas5P(), Rodas5P(autodiff = false), - FBDF(), FBDF(autodiff = false) + stiff = [Rosenbrock23(), + Rodas5P(), + FBDF() + ] + + default_ode = [ + DefaultODEAlgorithm(autodiff=false), + DefaultODEAlgorithm() ] autoswitch = [ @@ -282,7 +289,11 @@ PrecompileTools.@compile_workload begin append!(solver_list, stiff) end - if Preferences.@load_preference("PrecompileAutoSwitch", true) + if Preferences.@load_preference("PrecompileDefault", true) + append!(solver_list, default_ode) + end + + if Preferences.@load_preference("PrecompileAutoSwitch", false) append!(solver_list, autoswitch) end diff --git a/src/alg_utils.jl b/src/alg_utils.jl index b4acdaa7ec..2c78247d1d 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -168,6 +168,8 @@ isimplicit(alg::CompositeAlgorithm) = any(isimplicit.(alg.algs)) isdtchangeable(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = true isdtchangeable(alg::CompositeAlgorithm) = all(isdtchangeable.(alg.algs)) +isdtchangeable(alg::DefaultSolverAlgorithm) = true + function isdtchangeable(alg::Union{LawsonEuler, NorsettEuler, LieEuler, MagnusGauss4, CayleyEuler, ETDRK2, ETDRK3, ETDRK4, HochOst4, ETD2}) false @@ -180,12 +182,14 @@ ismultistep(alg::ETD2) = true isadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false isadaptive(alg::OrdinaryDiffEqAdaptiveAlgorithm) = true isadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = all(isadaptive.(alg.algs)) +isadaptive(alg::DefaultSolverAlgorithm) = true isadaptive(alg::DImplicitEuler) = true isadaptive(alg::DABDF2) = true isadaptive(alg::DFBDF) = true anyadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = isadaptive(alg) anyadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = any(isadaptive, alg.algs) +anyadaptive(alg::DefaultSolverAlgorithm) = true isautoswitch(alg) = false isautoswitch(alg::CompositeAlgorithm) = alg.choice_function isa AutoSwitch @@ -195,9 +199,11 @@ function qmin_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) end qmin_default(alg::CompositeAlgorithm) = maximum(qmin_default.(alg.algs)) qmin_default(alg::DP8) = 1 // 3 +qmin_default(alg::DefaultSolverAlgorithm) = 1 // 5 qmax_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = 10 qmax_default(alg::CompositeAlgorithm) = minimum(qmax_default.(alg.algs)) +qmax_default(alg::DefaultSolverAlgorithm) = 10 qmax_default(alg::DP8) = 6 qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8 @@ -283,7 +289,7 @@ end function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob) algs = map(alg -> DiffEqBase.prepare_alg(alg, u0, p, prob), alg.algs) - CompositeAlgorithm(algs, alg.choice_function) + CompositeAlgorithm(algs, alg.choice_function,) end has_autodiff(alg::OrdinaryDiffEqAlgorithm) = false @@ -366,6 +372,7 @@ end alg_extrapolates(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false alg_extrapolates(alg::CompositeAlgorithm) = any(alg_extrapolates.(alg.algs)) +alg_extrapolates(alg::DefaultSolverAlgorithm) = false alg_extrapolates(alg::ImplicitEuler) = true alg_extrapolates(alg::DImplicitEuler) = true alg_extrapolates(alg::DABDF2) = true @@ -726,6 +733,7 @@ alg_order(alg::QPRK98) = 9 alg_maximum_order(alg) = alg_order(alg) alg_maximum_order(alg::CompositeAlgorithm) = maximum(alg_order(x) for x in alg.algs) +alg_maximum_order(alg::DefaultSolverAlgorithm) = 7 alg_maximum_order(alg::ExtrapolationMidpointDeuflhard) = 2(alg.max_order + 1) alg_maximum_order(alg::ImplicitDeuflhardExtrapolation) = 2(alg.max_order + 1) alg_maximum_order(alg::ExtrapolationMidpointHairerWanner) = 2(alg.max_order + 1) @@ -862,6 +870,7 @@ function gamma_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) isadaptive(alg) ? 9 // 10 : 0 end gamma_default(alg::CompositeAlgorithm) = maximum(gamma_default, alg.algs) +gamma_default(alg::DefaultSolverAlgorithm) = 9 // 10 gamma_default(alg::RKC) = 8 // 10 gamma_default(alg::IRKC) = 8 // 10 function gamma_default(alg::ExtrapolationMidpointDeuflhard) @@ -982,14 +991,18 @@ function unwrap_alg(integrator, is_stiff) if !iscomp return alg elseif alg.choice_function isa AutoSwitchCache - if is_stiff === nothing - throwautoswitch(alg) - end - num = is_stiff ? 2 : 1 - if num == 1 - return alg.algs[1] + if alg.choice_function.algtrait isa DefaultODESolver + alg.algs[alg.choice_function.current] else - return alg.algs[2] + if is_stiff === nothing + throwautoswitch(alg) + end + num = is_stiff ? 2 : 1 + if num == 1 + return alg.algs[1] + else + return alg.algs[2] + end end else return _eval_index(identity, alg.algs, integrator.cache.current) diff --git a/src/algorithms.jl b/src/algorithms.jl index 0463c89c64..723f59facd 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -3186,6 +3186,9 @@ end struct CompositeAlgorithm{T, F} <: OrdinaryDiffEqCompositeAlgorithm algs::T choice_function::F + function CompositeAlgorithm(algs, choice_function) + new{typeof(algs), typeof(choice_function)}(algs, choice_function) + end end TruncatedStacktraces.@truncate_stacktrace CompositeAlgorithm 1 @@ -3194,6 +3197,40 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!) Base.Experimental.silence!(CompositeAlgorithm) end +mutable struct AutoSwitchCache{Trait, nAlg, sAlg, tolType, T} + algtrait::Trait + count::Int + successive_switches::Int + nonstiffalg::nAlg + stiffalg::sAlg + is_stiffalg::Bool + maxstiffstep::Int + maxnonstiffstep::Int + nonstifftol::tolType + stifftol::tolType + dtfac::T + stiffalgfirst::Bool + switch_max::Int + current::Int +end + +struct AutoSwitch{Trait, nAlg, sAlg, tolType, T} + algtrait::Trait + nonstiffalg::nAlg + stiffalg::sAlg + maxstiffstep::Int + maxnonstiffstep::Int + nonstifftol::tolType + stifftol::tolType + dtfac::T + stiffalgfirst::Bool + switch_max::Int +end + +struct DefaultODESolver end +const DefaultSolverAlgorithm = Union{CompositeAlgorithm{<:Tuple, <:AutoSwitch{DefaultODESolver}}, +CompositeAlgorithm{<:Tuple, <:AutoSwitchCache{DefaultODESolver}}} + ################################################################################ """ MEBDF2: Multistep Method diff --git a/src/caches/basic_caches.jl b/src/caches/basic_caches.jl index fa492dbe7a..5853f6449f 100644 --- a/src/caches/basic_caches.jl +++ b/src/caches/basic_caches.jl @@ -12,24 +12,26 @@ end TruncatedStacktraces.@truncate_stacktrace CompositeCache 1 -if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!) - Base.Experimental.silence!(CompositeCache) +mutable struct DefaultCache{T1, T2, T3, T4, T5, T6, A, F} <: OrdinaryDiffEqCache + args::A + choice_function::F + current::Int + cache1::T1 + cache2::T2 + cache3::T3 + cache4::T4 + cache5::T5 + cache6::T6 + function DefaultCache{T1, T2, T3, T4, T5, T6, F}(args, choice_function, current) where {T1, T2, T3, T4, T5, T6, F} + new{T1, T2, T3, T4, T5, T6, typeof(args), F}(args, choice_function, current) + end end -function alg_cache(alg::CompositeAlgorithm{Tuple{T1, T2}, F}, u, rate_prototype, - ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, - uprev2, f, t, dt, reltol, p, calck, - ::Val{V}) where {T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits, - tTypeNoUnits} - caches = ( - alg_cache(alg.algs[1], u, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, - tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)), - alg_cache(alg.algs[2], u, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, - tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V))) - CompositeCache(caches, alg.choice_function, 1) +TruncatedStacktraces.@truncate_stacktrace DefaultCache 1 + +if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!) + Base.Experimental.silence!(CompositeCache) + Base.Experimental.silence!(DefaultCache) end function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits}, @@ -38,7 +40,26 @@ function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoU ::Val{V}) where {V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} caches = __alg_cache(alg.algs, u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)) - CompositeCache(caches, alg.choice_function, 1) + CompositeCache{typeof(caches), typeof(alg.choice_function)}( + caches, alg.choice_function, 1) +end + +function alg_cache(alg::CompositeAlgorithm{Tuple{A1, A2, A3, A4, A5, A6}}, u, + rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{V}) where {V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, A1, A2, A3, A4, A5, A6} + + args = (u, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, + reltol, p, calck, Val(V)) + argT = map(typeof, args) + T1 = Base.promote_op(alg_cache, A1, argT...) + T2 = Base.promote_op(alg_cache, A2, argT...) + T3 = Base.promote_op(alg_cache, A3, argT...) + T4 = Base.promote_op(alg_cache, A4, argT...) + T5 = Base.promote_op(alg_cache, A5, argT...) + T6 = Base.promote_op(alg_cache, A6, argT...) + DefaultCache{T1, T2, T3, T4, T5, T6, typeof(alg.choice_function)}(args, alg.choice_function, 1) end # map + closure approach doesn't infer diff --git a/src/caches/verner_caches.jl b/src/caches/verner_caches.jl index 08b1de0919..6de34d0559 100644 --- a/src/caches/verner_caches.jl +++ b/src/caches/verner_caches.jl @@ -20,6 +20,7 @@ stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern6Cache 1 @@ -44,11 +45,12 @@ function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern6Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, utilde, tmp, rtmp, atmp, tab, - alg.stage_limiter!, alg.step_limiter!, alg.thread) + alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end struct Vern6ConstantCache{TabType} <: OrdinaryDiffEqConstantCache tab::TabType + lazy::Bool end function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, @@ -56,7 +58,7 @@ function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tab = Vern6Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - Vern6ConstantCache(tab) + Vern6ConstantCache(tab, alg.lazy) end @cache struct Vern7Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter, @@ -81,6 +83,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern7Cache 1 @@ -105,16 +108,18 @@ function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern7Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, utilde, tmp, rtmp, atmp, - alg.stage_limiter!, alg.step_limiter!, alg.thread) + alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end -struct Vern7ConstantCache <: OrdinaryDiffEqConstantCache end +struct Vern7ConstantCache <: OrdinaryDiffEqConstantCache + lazy::Bool +end function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - Vern7ConstantCache() + Vern7ConstantCache(alg.lazy) end @cache struct Vern8Cache{uType, rateType, uNoUnitsType, TabType, StageLimiter, StepLimiter, @@ -143,6 +148,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern8Cache 1 @@ -171,11 +177,12 @@ function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern8Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, utilde, - tmp, rtmp, atmp, tab, alg.stage_limiter!, alg.step_limiter!, alg.thread) + tmp, rtmp, atmp, tab, alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end struct Vern8ConstantCache{TabType} <: OrdinaryDiffEqConstantCache tab::TabType + lazy::Bool end function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, @@ -183,7 +190,7 @@ function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tab = Vern8Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - Vern8ConstantCache(tab) + Vern8ConstantCache(tab, alg.lazy) end @cache struct Vern9Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter, @@ -214,6 +221,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern9Cache 1 @@ -245,14 +253,16 @@ function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits}, rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern9Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16, utilde, tmp, rtmp, atmp, alg.stage_limiter!, alg.step_limiter!, - alg.thread) + alg.thread, alg.lazy) end -struct Vern9ConstantCache <: OrdinaryDiffEqConstantCache end +struct Vern9ConstantCache <: OrdinaryDiffEqConstantCache + lazy::Bool +end function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - Vern9ConstantCache() + Vern9ConstantCache(alg.lazy) end diff --git a/src/composite_algs.jl b/src/composite_algs.jl index 9f1d4dcf6b..1f76765b68 100644 --- a/src/composite_algs.jl +++ b/src/composite_algs.jl @@ -1,30 +1,8 @@ -mutable struct AutoSwitchCache{nAlg, sAlg, tolType, T} - count::Int - successive_switches::Int - nonstiffalg::nAlg - stiffalg::sAlg - is_stiffalg::Bool - maxstiffstep::Int - maxnonstiffstep::Int - nonstifftol::tolType - stifftol::tolType - dtfac::T - stiffalgfirst::Bool - switch_max::Int -end +### AutoSwitch +### Designed to switch between two solvers, stiff and non-stiff -struct AutoSwitch{nAlg, sAlg, tolType, T} - nonstiffalg::nAlg - stiffalg::sAlg - maxstiffstep::Int - maxnonstiffstep::Int - nonstifftol::tolType - stifftol::tolType - dtfac::T - stiffalgfirst::Bool - switch_max::Int -end -function AutoSwitch(nonstiffalg, stiffalg; maxstiffstep = 10, maxnonstiffstep = 3, +function AutoSwitch(nonstiffalg, stiffalg, algtrait = nothing; + maxstiffstep = 10, maxnonstiffstep = 3, nonstifftol = 9 // 10, stifftol = 9 // 10, dtfac = 2, stiffalgfirst = false, switch_max = 5) @@ -41,7 +19,6 @@ function is_stiff(integrator, alg, ntol, stol, is_stiffalg) if !bool integrator.alg.choice_function.successive_switches += 1 - integrator.do_error_check = false else integrator.alg.choice_function.successive_switches = 0 end @@ -53,7 +30,11 @@ function is_stiff(integrator, alg, ntol, stol, is_stiffalg) end function (AS::AutoSwitchCache)(integrator) - integrator.iter == 0 && return Int(AS.stiffalgfirst) + 1 + if AS.current == 0 + AS.current = Int(AS.stiffalgfirst) + 1 + return AS.current + end + dt = integrator.dt # Successive stiffness test positives are counted by a positive integer, # and successive stiffness test negatives are counted by a negative integer @@ -68,17 +49,134 @@ function (AS::AutoSwitchCache)(integrator) integrator.dt = dt / AS.dtfac AS.is_stiffalg = false end - return Int(AS.is_stiffalg) + 1 + AS.current = Int(AS.is_stiffalg) + 1 + return AS.current end -function AutoAlgSwitch(nonstiffalg, stiffalg; kwargs...) - AS = AutoSwitch(nonstiffalg, stiffalg; kwargs...) +function AutoAlgSwitch(nonstiffalg::OrdinaryDiffEqAlgorithm, stiffalg::OrdinaryDiffEqAlgorithm, algtrait = nothing; kwargs...) + AS = AutoSwitch(nonstiffalg, stiffalg, algtrait; kwargs...) CompositeAlgorithm((nonstiffalg, stiffalg), AS) end +function AutoAlgSwitch(nonstiffalg::Tuple, stiffalg::Tuple, algtrait; kwargs...) + AS = AutoSwitch(nonstiffalg, stiffalg, algtrait; kwargs...) + CompositeAlgorithm((nonstiffalg..., stiffalg...), AS) +end + AutoTsit5(alg; kwargs...) = AutoAlgSwitch(Tsit5(), alg; kwargs...) AutoDP5(alg; kwargs...) = AutoAlgSwitch(DP5(), alg; kwargs...) AutoVern6(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern6(lazy = lazy), alg; kwargs...) AutoVern7(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern7(lazy = lazy), alg; kwargs...) AutoVern8(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern8(lazy = lazy), alg; kwargs...) AutoVern9(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern9(lazy = lazy), alg; kwargs...) + +### Default ODE Solver + +EnumX.@enumx DefaultSolverChoice begin + Tsit5 = 1 + Vern7 = 2 + Rosenbrock23 = 3 + Rodas5P = 4 + FBDF = 5 + KrylovFBDF = 6 +end + +const NUM_NONSTIFF = 2 +const NUM_STIFF = 4 +const LOW_TOL = 1e-6 +const MED_TOL = 1e-2 +const EXTREME_TOL = 1e-9 +const SMALLSIZE = 50 +const MEDIUMSIZE = 500 +const STABILITY_SIZES = (alg_stability_size(Tsit5()), alg_stability_size(Vern7())) +const DEFAULTBETA2S = (beta2_default(Tsit5()), beta2_default(Vern7()), beta2_default(Rosenbrock23()), beta2_default(Rodas5P()), beta2_default(FBDF()), beta2_default(FBDF())) +const DEFAULTBETA1S = (beta1_default(Tsit5(),DEFAULTBETA2S[1]), beta1_default(Vern7(),DEFAULTBETA2S[2]), + beta1_default(Rosenbrock23(), DEFAULTBETA2S[3]), beta1_default(Rodas5P(), DEFAULTBETA2S[4]), + beta1_default(FBDF(), DEFAULTBETA2S[5]), beta1_default(FBDF(), DEFAULTBETA2S[6])) + +callbacks_exists(integrator) = !isempty(integrator.opts.callbacks) +current_nonstiff(current) = ifelse(current <= NUM_NONSTIFF,current,current-NUM_STIFF) + +function DefaultODEAlgorithm(; lazy = true, stiffalgfirst = false, kwargs...) + nonstiff = (Tsit5(), Vern7(lazy = lazy)) + stiff = (Rosenbrock23(;kwargs...), Rodas5P(;kwargs...), FBDF(;kwargs...), FBDF(;linsolve = LinearSolve.KrylovJL_GMRES())) + AutoAlgSwitch(nonstiff, stiff, DefaultODESolver(); stiffalgfirst) +end + +function is_stiff(integrator, alg, ntol, stol, is_stiffalg, current) + eigen_est, dt = integrator.eigen_est, integrator.dt + stiffness = abs(eigen_est * dt / STABILITY_SIZES[nonstiffchoice(integrator.opts.reltol)]) # `abs` here is just for safety + tol = is_stiffalg ? stol : ntol + os = oneunit(stiffness) + bool = stiffness > os * tol + + if !bool + integrator.alg.choice_function.successive_switches += 1 + else + integrator.alg.choice_function.successive_switches = 0 + end + + integrator.do_error_check = (integrator.alg.choice_function.successive_switches > + integrator.alg.choice_function.switch_max || !bool) || + is_stiffalg + bool +end + +function nonstiffchoice(reltol) + x = if reltol < LOW_TOL + DefaultSolverChoice.Vern7 + else + DefaultSolverChoice.Tsit5 + end + Int(x) +end + +function stiffchoice(reltol, len) + x = if len > MEDIUMSIZE + DefaultSolverChoice.KrylovFBDF + elseif len > SMALLSIZE + DefaultSolverChoice.FBDF + else + if reltol < LOW_TOL + DefaultSolverChoice.Rodas5P + else + DefaultSolverChoice.Rosenbrock23 + end + end + Int(x) +end + +function (AS::AutoSwitchCache{DefaultODESolver})(integrator) + + len = length(integrator.u) + reltol = integrator.opts.reltol + + # Chooose the starting method + if AS.current == 0 + choice = if AS.stiffalgfirst || integrator.f.mass_matrix != I + stiffchoice(reltol, len) + else + nonstiffchoice(reltol) + end + AS.current = choice + return AS.current + end + + dt = integrator.dt + # Successive stiffness test positives are counted by a positive integer, + # and successive stiffness test negatives are counted by a negative integer + AS.count = is_stiff(integrator, AS.nonstiffalg, AS.nonstifftol, AS.stifftol, + AS.is_stiffalg, AS.current) ? + AS.count < 0 ? 1 : AS.count + 1 : + AS.count > 0 ? -1 : AS.count - 1 + if (!AS.is_stiffalg && AS.count > AS.maxstiffstep) + integrator.dt = dt * AS.dtfac + AS.is_stiffalg = true + AS.current = stiffchoice(reltol, len) + elseif (AS.is_stiffalg && AS.count < -AS.maxnonstiffstep) + integrator.dt = dt / AS.dtfac + AS.is_stiffalg = false + AS.current = nonstiffchoice(reltol) + end + return AS.current +end diff --git a/src/perform_step/composite_perform_step.jl b/src/perform_step/composite_perform_step.jl index 75057a3adc..36cf00230d 100644 --- a/src/perform_step/composite_perform_step.jl +++ b/src/perform_step/composite_perform_step.jl @@ -1,38 +1,50 @@ -#= - -Maybe do generated functions to reduce dispatch times? - -f(x) = x -g(x,i) = f(x[i]) -g{i}(x,::Type{Val{i}}) = f(x[i]) -@generated function gg(tup::Tuple, num) - N = length(tup.parameters) - :(@nif $(N+1) i->(i == num) i->(f(tup[i])) i->error("unreachable")) - end -h(i) = g((1,1.0,"foo"), i) -h2{i}(::Type{Val{i}}) = g((1,1.0,"foo"), Val{i}) -h3(i) = gg((1,1.0,"foo"), i) -@benchmark h(1) -mean time: 31.822 ns (0.00% GC) -@benchmark h2(Val{1}) -mean time: 1.585 ns (0.00% GC) -@benchmark h3(1) -mean time: 6.423 ns (0.00% GC) - -@generated function foo(tup::Tuple, num) - N = length(tup.parameters) - :(@nif $(N+1) i->(i == num) i->(tup[i]) i->error("unreachable")) -end - -@code_typed foo((1,1.0), 1) - -@generated function perform_step!(integrator, cache::CompositeCache, repeat_step=false) - N = length(cache.parameters) - :(@nif $(N+1) i->(i == num) i->(tup[i]) i->error("unreachable")) +function initialize!(integrator, cache::DefaultCache) + cache.current = cache.choice_function(integrator) + algs = integrator.alg.algs + if cache.current == 1 + if !isdefined(cache, :cache1) + cache.cache1 = alg_cache(algs[1], cache.args...) + end + initialize!(integrator, cache.cache1) + elseif cache.current == 2 + if !isdefined(cache, :cache2) + cache.cache2 = alg_cache(algs[2], cache.args...) + end + initialize!(integrator, cache.cache2) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[2]) + elseif cache.current == 3 + if !isdefined(cache, :cache3) + cache.cache3 = alg_cache(algs[3], cache.args...) + end + initialize!(integrator, cache.cache3) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[3]) + elseif cache.current == 4 + if !isdefined(cache, :cache4) + cache.cache4 = alg_cache(algs[4], cache.args...) + end + initialize!(integrator, cache.cache4) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[4]) + elseif cache.current == 5 + if !isdefined(cache, :cache5) + cache.cache5 = alg_cache(algs[5], cache.args...) + end + initialize!(integrator, cache.cache5) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[5]) + elseif cache.current == 6 + if !isdefined(cache, :cache6) + cache.cache6 = alg_cache(algs[6], cache.args...) + end + initialize!(integrator, cache.cache6) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[6]) + end + resize!(integrator.k, integrator.kshortsize) end -=# - function initialize!(integrator, cache::CompositeCache) cache.current = cache.choice_function(integrator) if cache.current == 1 @@ -69,28 +81,35 @@ the behaviour is consistent. In particular, prevents dt ⟶ 0 if starting with non-adaptive alg and opts.adaptive=true, and dt=cst if starting with adaptive alg and opts.adaptive=false. """ -function ensure_behaving_adaptivity!(integrator, cache::CompositeCache) +function ensure_behaving_adaptivity!(integrator, cache::Union{DefaultCache, CompositeCache}) if anyadaptive(integrator.alg) && !isadaptive(integrator.alg) integrator.opts.adaptive = isadaptive(integrator.alg.algs[cache.current]) end end -function perform_step!(integrator, cache::CompositeCache, repeat_step = false) +function perform_step!(integrator, cache::DefaultCache, repeat_step = false) if cache.current == 1 - perform_step!(integrator, @inbounds(cache.caches[1]), repeat_step) + perform_step!(integrator, @inbounds(cache.cache1), repeat_step) elseif cache.current == 2 - perform_step!(integrator, @inbounds(cache.caches[2]), repeat_step) - else - perform_step!(integrator, @inbounds(cache.caches[cache.current]), repeat_step) + perform_step!(integrator, @inbounds(cache.cache2), repeat_step) + elseif cache.current == 3 + perform_step!(integrator, @inbounds(cache.cache3), repeat_step) + elseif cache.current == 4 + perform_step!(integrator, @inbounds(cache.cache4), repeat_step) + elseif cache.current == 5 + perform_step!(integrator, @inbounds(cache.cache5), repeat_step) + elseif cache.current == 6 + perform_step!(integrator, @inbounds(cache.cache6), repeat_step) end end -function perform_step!(integrator, cache::CompositeCache{Tuple{T1, T2}, F}, - repeat_step = false) where {T1, T2, F} +function perform_step!(integrator, cache::CompositeCache, repeat_step = false) if cache.current == 1 perform_step!(integrator, @inbounds(cache.caches[1]), repeat_step) elseif cache.current == 2 perform_step!(integrator, @inbounds(cache.caches[2]), repeat_step) + else + perform_step!(integrator, @inbounds(cache.caches[cache.current]), repeat_step) end end @@ -122,6 +141,24 @@ function choose_algorithm!(integrator, end function choose_algorithm!(integrator, cache::CompositeCache) + new_current = cache.choice_function(integrator) + old_current = cache.current + @inbounds if new_current != old_current + cache.current = new_current + initialize!(integrator, @inbounds(cache.caches[new_current])) + + controller.beta2 = beta2_default(alg2) + controller.beta1 = beta2_default(alg2) + DEFAULTBETA2S + + reset_alg_dependent_opts!(integrator, integrator.alg.algs[old_current], + integrator.alg.algs[new_current]) + transfer_cache!(integrator, integrator.cache.caches[old_current], + integrator.cache.caches[new_current]) + end +end + +function choose_algorithm!(integrator, cache::CompositeCache{<:Any, <:AutoSwitchCache{DefaultODESolver}}) new_current = cache.choice_function(integrator) old_current = cache.current @inbounds if new_current != old_current @@ -130,26 +167,23 @@ function choose_algorithm!(integrator, cache::CompositeCache) initialize!(integrator, @inbounds(cache.caches[1])) elseif new_current == 2 initialize!(integrator, @inbounds(cache.caches[2])) + elseif new_current == 3 + initialize!(integrator, @inbounds(cache.caches[3])) + elseif new_current == 4 + initialize!(integrator, @inbounds(cache.caches[4])) + elseif new_current == 5 + initialize!(integrator, @inbounds(cache.caches[5])) + elseif new_current == 6 + initialize!(integrator, @inbounds(cache.caches[6])) else initialize!(integrator, @inbounds(cache.caches[new_current])) end - if old_current == 1 && new_current == 2 - reset_alg_dependent_opts!(integrator, integrator.alg.algs[1], - integrator.alg.algs[2]) - transfer_cache!(integrator, integrator.cache.caches[1], - integrator.cache.caches[2]) - elseif old_current == 2 && new_current == 1 - reset_alg_dependent_opts!(integrator, integrator.alg.algs[2], - integrator.alg.algs[1]) - transfer_cache!(integrator, integrator.cache.caches[2], - integrator.cache.caches[1]) - else - reset_alg_dependent_opts!(integrator, integrator.alg.algs[old_current], - integrator.alg.algs[new_current]) - transfer_cache!(integrator, integrator.cache.caches[old_current], - integrator.cache.caches[new_current]) - end + + # dtchangable, qmin_default, qmax_default, and isadaptive ignored since all same + integrator.opts.controller.beta1 = DEFAULTBETA1S[new_current] + integrator.opts.controller.beta2 = DEFAULTBETA2S[new_current] end + nothing end """ @@ -170,6 +204,7 @@ function reset_alg_dependent_opts!(integrator, alg1, alg2) integrator.opts.qmax == qmax_default(alg2) end reset_alg_dependent_opts!(integrator.opts.controller, alg1, alg2) + nothing end # Write how to transfer the cache variables from one cache to the other diff --git a/src/perform_step/verner_rk_perform_step.jl b/src/perform_step/verner_rk_perform_step.jl index 0a44e217da..7c76e26719 100644 --- a/src/perform_step/verner_rk_perform_step.jl +++ b/src/perform_step/verner_rk_perform_step.jl @@ -2,7 +2,7 @@ function initialize!(integrator, cache::Vern6ConstantCache) integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) # Pre-start fsal integrator.stats.nf += 1 alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) + cache.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -13,7 +13,7 @@ function initialize!(integrator, cache::Vern6ConstantCache) end integrator.k[integrator.kshortsize] = integrator.fsallast - if !alg.lazy + if !cache.lazy @inbounds for i in 10:12 integrator.k[i] = zero(integrator.fsalfirst) end @@ -63,7 +63,7 @@ end integrator.k[9] = k9 alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c10, a1001, a1004, a1005, a1006, a1007, a1008, a1009, c11, a1101, a1104, a1105, a1106, a1107, a1108, a1109, a1110, c12, a1201, a1204, a1205, a1206, a1207, a1208, a1209, a1210, a1211 = cache.tab.extra @@ -94,7 +94,7 @@ end function initialize!(integrator, cache::Vern6Cache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) + cache.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) integrator.fsalfirst = cache.k1 integrator.fsallast = cache.k9 @unpack k = integrator @@ -109,7 +109,7 @@ function initialize!(integrator, cache::Vern6Cache) k[8] = cache.k8 k[9] = cache.k9 # Set the pointers - if !alg.lazy + if !cache.lazy k[10] = similar(cache.k1) k[11] = similar(cache.k1) k[12] = similar(cache.k1) @@ -182,7 +182,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c10, a1001, a1004, a1005, a1006, a1007, a1008, a1009, c11, a1101, a1104, a1105, a1106, a1107, a1108, a1109, a1110, c12, a1201, a1204, a1205, a1206, a1207, a1208, a1209, a1210, a1211 = cache.tab.extra @@ -214,7 +214,7 @@ end function initialize!(integrator, cache::Vern7ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -277,7 +277,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @OnDemandTableauExtract Vern7ExtraStages T T2 @@ -329,7 +329,7 @@ function initialize!(integrator, cache::Vern7Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) resize!(k, integrator.kshortsize) k[1] = k1 k[2] = k2 @@ -342,7 +342,7 @@ function initialize!(integrator, cache::Vern7Cache) k[9] = k9 k[10] = k10 # Setup pointers - if !alg.lazy + if !cache.lazy k[11] = similar(cache.k1) k[12] = similar(cache.k1) k[13] = similar(cache.k1) @@ -433,7 +433,7 @@ end integrator.EEst = integrator.opts.internalnorm(atmp, t) end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack tmp = cache @@ -489,7 +489,7 @@ end function initialize!(integrator, cache::Vern8ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) + cache.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -575,7 +575,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c14, a1401, a1406, a1407, a1408, a1409, a1410, a1411, a1412, c15, a1501, a1506, a1507, a1508, a1509, a1510, a1511, a1512, a1514, c16, a1601, a1606, a1607, a1608, a1609, a1610, a1611, a1612, a1614, a1615, c17, a1701, a1706, a1707, a1708, a1709, a1710, a1711, a1712, a1714, a1715, a1716, c18, a1801, a1806, a1807, a1808, a1809, a1810, a1811, a1812, a1814, a1815, a1816, a1817, c19, a1901, a1906, a1907, a1908, a1909, a1910, a1911, a1912, a1914, a1915, a1916, a1917, c20, a2001, a2006, a2007, a2008, a2009, a2010, a2011, a2012, a2014, a2015, a2016, a2017, c21, a2101, a2106, a2107, a2108, a2109, a2110, a2111, a2112, a2114, a2115, a2116, a2117 = cache.tab.extra @@ -642,7 +642,7 @@ function initialize!(integrator, cache::Vern8Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) + cache.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) resize!(k, integrator.kshortsize) k[1] = k1 k[2] = k2 @@ -658,7 +658,7 @@ function initialize!(integrator, cache::Vern8Cache) k[12] = k12 k[13] = k13 # Setup pointers - if !alg.lazy + if !cache.lazy for i in 14:21 k[i] = similar(cache.k1) end @@ -764,7 +764,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c14, a1401, a1406, a1407, a1408, a1409, a1410, a1411, a1412, c15, a1501, a1506, a1507, a1508, a1509, a1510, a1511, a1512, a1514, c16, a1601, a1606, a1607, a1608, a1609, a1610, a1611, a1612, a1614, a1615, c17, a1701, a1706, a1707, a1708, a1709, a1710, a1711, a1712, a1714, a1715, a1716, c18, a1801, a1806, a1807, a1808, a1809, a1810, a1811, a1812, a1814, a1815, a1816, a1817, c19, a1901, a1906, a1907, a1908, a1909, a1910, a1911, a1912, a1914, a1915, a1916, a1917, c20, a2001, a2006, a2007, a2008, a2009, a2010, a2011, a2012, a2014, a2015, a2016, a2017, c21, a2101, a2106, a2107, a2108, a2109, a2110, a2111, a2112, a2114, a2115, a2116, a2117 = cache.tab.extra @@ -851,7 +851,7 @@ end function initialize!(integrator, cache::Vern9ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -945,7 +945,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @OnDemandTableauExtract Vern9ExtraStages T T2 @@ -1032,7 +1032,7 @@ function initialize!(integrator, cache::Vern9Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) resize!(k, integrator.kshortsize) # k2, k3,k4,k5,k6,k7 are not used in the code (not even in interpolations), we dont need their pointers. # So we mapped k[2] (from integrator) with k8 (from cache), k[3] with k9 and so on. @@ -1047,7 +1047,7 @@ function initialize!(integrator, cache::Vern9Cache) k[9] = k15 k[10] = k16 # Setup pointers - if !alg.lazy + if !cache.lazy for i in 11:20 k[i] = similar(cache.k1) end @@ -1174,7 +1174,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack tmp = cache diff --git a/src/solve.jl b/src/solve.jl index e29136450e..2b5c26efb6 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -136,7 +136,7 @@ function DiffEqBase.__init( if alg isa CompositeAlgorithm && alg.choice_function isa AutoSwitch auto = alg.choice_function _alg = CompositeAlgorithm(alg.algs, - AutoSwitchCache(0, 0, + AutoSwitchCache(auto.algtrait, 0, 0, auto.nonstiffalg, auto.stiffalg, auto.stiffalgfirst, @@ -146,7 +146,7 @@ function DiffEqBase.__init( auto.stifftol, auto.dtfac, auto.stiffalgfirst, - auto.switch_max)) + auto.switch_max, 0)) else _alg = alg end @@ -415,12 +415,18 @@ function DiffEqBase.__init( differential_vars = prob isa DAEProblem ? prob.differential_vars : get_differential_vars(f, u) - id = InterpolationData( - f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars, false) - sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, - dense = dense, k = ks, interp = id, - alg_choice = alg_choice, - calculate_error = false, stats = stats) + if _alg isa OrdinaryDiffEqCompositeAlgorithm + id = CompositeInterpolationData(f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars) + sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, + dense = dense, k = ks, interp = id, + alg_choice = alg_choice, + calculate_error = false, stats = stats) + else + id = InterpolationData(f, timeseries, ts, ks, dense, cache, differential_vars) + sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, + dense = dense, k = ks, interp = id, + calculate_error = false, stats = stats) + end if recompile_flag == true FType = typeof(f)