From ed4234d5bf104a47afd80aa24afe1be85573921c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 3 Nov 2023 23:09:23 -0400 Subject: [PATCH 1/6] Add documentation about the solvers --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index b240e092..e158c4c2 100644 --- a/README.md +++ b/README.md @@ -48,9 +48,12 @@ Precompilation can be controlled via `Preferences.jl` - `PrecompileMIRK` -- Precompile the MIRK2 - MIRK6 algorithms (default: `true`). - `PrecompileShooting` -- Precompile the single shooting algorithms (default: `true`). This is triggered when `OrdinaryDiffEq` is loaded. - `PrecompileMultipleShooting` -- Precompile the multiple shooting algorithms (default: `true`). This is triggered when `OrdinaryDiffEq` is loaded. +<<<<<<< HEAD - `PrecompileMIRKNLLS` -- Precompile the MIRK2 - MIRK6 algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above). - `PrecompileShootingNLLS` -- Precompile the single shooting algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above). This is triggered when `OrdinaryDiffEq` is loaded. - `PrecompileMultipleShootingNLLS` -- Precompile the multiple shooting algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above). This is triggered when `OrdinaryDiffEq` is loaded. +======= +>>>>>>> 5d53c01 (Add documentation about the solvers) To set these preferences before loading the package, do the following (replacing `PrecompileShooting` with the preference you want to set, or pass in multiple pairs to set them together): From 6526a574f44b712147fe3eb7371231928149da94 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 3 Nov 2023 21:46:44 -0400 Subject: [PATCH 2/6] Support for u0 initial guess function --- Project.toml | 3 +++ src/BoundaryValueDiffEq.jl | 27 +++++++++++++++++++++++++- src/solve/mirk.jl | 14 +++++++++----- src/solve/multiple_shooting.jl | 25 ++++++++++++++---------- src/utils.jl | 35 +++++++++++++++++++++++++++------- 5 files changed, 81 insertions(+), 23 deletions(-) diff --git a/Project.toml b/Project.toml index 4da652ec..83ec55e6 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" @@ -22,6 +23,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" +Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" @@ -56,6 +58,7 @@ SciMLBase = "2.5" Setfield = "1" SparseArrays = "1.9" SparseDiffTools = "2.9" +Tricks = "0.1" TruncatedStacktraces = "1" UnPack = "1" julia = "1.9" diff --git a/src/BoundaryValueDiffEq.jl b/src/BoundaryValueDiffEq.jl index 0113cc87..5d51a7db 100644 --- a/src/BoundaryValueDiffEq.jl +++ b/src/BoundaryValueDiffEq.jl @@ -5,7 +5,7 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat @recompile_invalidations begin using ADTypes, Adapt, BandedMatrices, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve, PreallocationTools, Preferences, RecursiveArrayTools, Reexport, - SciMLBase, Setfield, SparseArrays, SparseDiffTools + SciMLBase, Setfield, SparseArrays, SparseDiffTools, Tricks import ADTypes: AbstractADType import ArrayInterface: matrix_colors, @@ -22,6 +22,31 @@ end @reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools, SciMLBase +# TODO: Upstream +# For BVPs we want to propagate even a function u0 +function DiffEqBase.get_concrete_u0(prob::BVProblem, isadapt, t0, kwargs) + if haskey(kwargs, :u0) + u0 = kwargs[:u0] + else + u0 = prob.u0 + end + + isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) + + _u0 = DiffEqBase.handle_distribution_u0(u0) + + if isinplace(prob) && (_u0 isa Number || _u0 isa DiffEqBase.SArray) + throw(DiffEqBase.IncompatibleInitialConditionError()) + end + + if _u0 isa Tuple + throw(DiffEqBase.TupleStateError()) + end + + return _u0 +end +# End of Upstream + include("types.jl") include("utils.jl") include("algorithms.jl") diff --git a/src/solve/mirk.jl b/src/solve/mirk.jl index 45857e73..468484ca 100644 --- a/src/solve/mirk.jl +++ b/src/solve/mirk.jl @@ -35,7 +35,13 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, abstol = 1e-3, adaptive = true, kwargs...) @set! alg.jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg) iip = isinplace(prob) + _, T, M, n, X = __extract_problem_details(prob; dt, check_positive_dt = true) + # NOTE: Assumes the user provided initial guess is on a uniform mesh + mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1)) + + mesh_dt = diff(mesh) + chunksize = pickchunksize(M * (n + 1)) __alloc = x -> __maybe_allocate_diffcache(vec(x), chunksize, alg.jac_alg) @@ -43,10 +49,6 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, fᵢ_cache = __alloc(similar(X)) fᵢ₂_cache = vec(similar(X)) - # NOTE: Assumes the user provided initial guess is on a uniform mesh - mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1)) - mesh_dt = diff(mesh) - defect_threshold = T(0.1) # TODO: Allow user to specify these MxNsub = 3000 # TODO: Allow user to specify these @@ -100,7 +102,9 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, vecf, vecbc end - return MIRKCache{iip, T}(alg_order(alg), stage, M, size(X), f, bc, prob, + prob_ = !(prob.u0 isa AbstractArray) ? remake(prob; u0 = X) : prob + + return MIRKCache{iip, T}(alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type, prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt, k_discrete, k_interp, y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, new_stages, resid₁_size, (; defect_threshold, MxNsub, abstol, dt, adaptive, kwargs...)) diff --git a/src/solve/multiple_shooting.jl b/src/solve/multiple_shooting.jl index 9326c3db..f26f9924 100644 --- a/src/solve/multiple_shooting.jl +++ b/src/solve/multiple_shooting.jl @@ -49,8 +49,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;), ode_cache_loss_fn; kwargs..., verbose, odesolve_kwargs...) else u_at_nodes = __multiple_shooting_initialize!(nodes, u_at_nodes, prob, alg, - cur_nshoot, all_nshoots[i - 1], ig, ode_cache_loss_fn; kwargs..., verbose, - odesolve_kwargs...) + cur_nshoot, all_nshoots[i - 1], ig, ode_cache_loss_fn, u0; kwargs..., + verbose, odesolve_kwargs...) end if prob.problem_type isa TwoPointBVProblem @@ -362,9 +362,13 @@ end resize!(nodes, nshoots + 1) nodes .= range(tspan[1], tspan[2]; length = nshoots + 1) - N = length(first(u0)) - u_at_nodes = similar(first(u0), (nshoots + 1) * N) - recursive_flatten!(u_at_nodes, u0) + # NOTE: We don't check `u0 isa Function` since `u0` in-principle can be a callable + # struct + u0_ = u0 isa AbstractArray ? u0 : [__initial_guess(u0, prob.p, t) for t in nodes] + + N = length(first(u0_)) + u_at_nodes = similar(first(u0_), (nshoots + 1) * N) + recursive_flatten!(u_at_nodes, u0_) return u_at_nodes end @@ -401,7 +405,8 @@ end end else @warn "Initialization using odesolve failed. Initializing using 0s. It is \ - recommended to provide an `initial_guess` in this case." + recommended to provide an initial guess function via \ + `u0 = (p, t)` or `u0 = (t)` in this case." fill!(u_at_nodes, 0) end @@ -410,16 +415,16 @@ end # Grid coarsening @views function __multiple_shooting_initialize!(nodes, u_at_nodes_prev, prob, alg, - nshoots, old_nshoots, ig, odecache_; kwargs...) - @unpack f, u0, tspan, p = prob + nshoots, old_nshoots, ig, odecache_, u0; kwargs...) + @unpack f, tspan, p = prob prev_nodes = copy(nodes) odecache = odecache_ isa Vector ? first(odecache_) : odecache_ resize!(nodes, nshoots + 1) nodes .= range(tspan[1], tspan[2]; length = nshoots + 1) - N = _unwrap_val(ig) ? length(first(u0)) : length(u0) + N = length(u0) - u_at_nodes = similar(_unwrap_val(ig) ? first(u0) : u0, N + nshoots * N) + u_at_nodes = similar(u0, N + nshoots * N) u_at_nodes[1:N] .= u_at_nodes_prev[1:N] u_at_nodes[(end - N + 1):end] .= u_at_nodes_prev[(end - N + 1):end] diff --git a/src/utils.jl b/src/utils.jl index b302267b..c380f15e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -139,17 +139,38 @@ function __extract_problem_details(prob, u0::AbstractArray; dt = 0.0, t₀, t₁ = prob.tspan return Val(false), eltype(u0), length(u0), Int(cld(t₁ - t₀, dt)), prob.u0 end -function __extract_problem_details(prob, ::F; kwargs...) where {F <: Function} - throw(ArgumentError("passing `u0` as a function is not supported yet. Curently we only \ - support AbstractArray or Vector of AbstractArrays as input! \ - Use the latter format for passing in initial guess!")) +function __extract_problem_details(prob, f::F; dt = 0.0, + check_positive_dt::Bool = false) where {F <: Function} + # Problem passes in a initial guess function + check_positive_dt && dt ≤ 0 && throw(ArgumentError("dt must be positive")) + u0 = __initial_guess(f, prob.p, prob.tspan[1]) + t₀, t₁ = prob.tspan + return Val(true), eltype(u0), length(u0), Int(cld(t₁ - t₀, dt)), u0 end -__initial_state_from_prob(prob::BVProblem, mesh) = __initial_state_from_prob(prob.u0, mesh) -__initial_state_from_prob(u0::AbstractArray, mesh) = [copy(vec(u0)) for _ in mesh] -function __initial_state_from_prob(u0::AbstractVector{<:AbstractVector}, _) +function __initial_guess(f::F, p::P, t::T) where {F, P, T} + if static_hasmethod(f, Tuple{P, T}) + return f(p, t) + elseif static_hasmethod(f, Tuple{T}) + return f(t) + else + throw(ArgumentError("`initial_guess` must be a function of the form `f(p, t)` or \ + `f(t)`")) + end +end + +function __initial_state_from_prob(prob::BVProblem, mesh) + return __initial_state_from_prob(prob, prob.u0, mesh) +end +function __initial_state_from_prob(::BVProblem, u0::AbstractArray, mesh) + return [copy(vec(u0)) for _ in mesh] +end +function __initial_state_from_prob(::BVProblem, u0::AbstractVector{<:AbstractVector}, _) return [copy(vec(u)) for u in u0] end +function __initial_state_from_prob(prob::BVProblem, f::F, mesh) where {F} + return [__initial_guess(f, prob.p, t) for t in mesh] +end function __get_bcresid_prototype(prob::BVProblem, u) return __get_bcresid_prototype(prob.problem_type, prob, u) From abf3be05828fa12fe3a4f68b956dd462e4326c47 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 4 Nov 2023 10:20:28 -0400 Subject: [PATCH 3/6] Setup for a specialized Multiple Shooting Algorithm --- src/algorithms.jl | 18 ++++++++++++------ src/solve/multiple_shooting.jl | 16 +++++++++++++++- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index c735088c..0fbe46ab 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -62,7 +62,8 @@ end """ MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(), - grid_coarsening = true, jac_alg = BVPJacobianAlgorithm()) + grid_coarsening = true, jac_alg = BVPJacobianAlgorithm(), + static_auto_nodes::Val = Val(false)) Multiple Shooting method, reduces BVP to an initial value problem and solves the IVP. Significantly more stable than Single Shooting. @@ -97,12 +98,15 @@ Significantly more stable than Single Shooting. - `Function`: Takes the current number of shooting points and returns the next number of shooting points. For example, if `nshoots = 10` and `grid_coarsening = n -> n ÷ 2`, then the grid will be coarsened to `[5, 2]`. + - `static_auto_nodes`: Automatically detect the timepoints used in the boundary condition + and use a faster version of the algorithm! This particular keyword argument should be + considered experimental and should be used with care! !!! note For type-stability, the chunksizes for ForwardDiff ADTypes in `BVPJacobianAlgorithm` must be provided. """ -@concrete struct MultipleShooting{J <: BVPJacobianAlgorithm} +@concrete struct MultipleShooting{S, J <: BVPJacobianAlgorithm} ode_alg nlsolve jac_alg::J @@ -110,9 +114,9 @@ Significantly more stable than Single Shooting. grid_coarsening end -function concretize_jacobian_algorithm(alg::MultipleShooting, prob) +function concretize_jacobian_algorithm(alg::MultipleShooting{S}, prob) where {S} jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg) - return MultipleShooting(alg.ode_alg, alg.nlsolve, jac_alg, alg.nshoots, + return MultipleShooting{S}(alg.ode_alg, alg.nlsolve, jac_alg, alg.nshoots, alg.grid_coarsening) end @@ -122,16 +126,18 @@ function update_nshoots(alg::MultipleShooting, nshoots::Int) end function MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(), - grid_coarsening = true, jac_alg = BVPJacobianAlgorithm()) + grid_coarsening = true, jac_alg = BVPJacobianAlgorithm(), + static_auto_nodes::Val{S} = Val(false)) where {S} @assert grid_coarsening isa Bool || grid_coarsening isa Function || grid_coarsening isa AbstractVector{<:Integer} || grid_coarsening isa NTuple{N, <:Integer} where {N} + @assert S isa Bool grid_coarsening isa Tuple && (grid_coarsening = Vector(grid_coarsening...)) if grid_coarsening isa AbstractVector sort!(grid_coarsening; rev = true) @assert all(grid_coarsening .> 0) && 1 ∉ grid_coarsening end - return MultipleShooting(ode_alg, nlsolve, jac_alg, nshoots, grid_coarsening) + return MultipleShooting{S}(ode_alg, nlsolve, jac_alg, nshoots, grid_coarsening) end for order in (2, 3, 4, 5, 6) diff --git a/src/solve/multiple_shooting.jl b/src/solve/multiple_shooting.jl index f26f9924..d3ba2619 100644 --- a/src/solve/multiple_shooting.jl +++ b/src/solve/multiple_shooting.jl @@ -1,4 +1,18 @@ -function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;), +function __solve(prob::BVProblem, _alg::MultipleShooting{true}; odesolve_kwargs = (;), + nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...) + # For TwoPointBVPs there is nothing to do. Forward to general multiple shooting + prob.problem_type isa TwoPointBVProblem && + return __solve_internal(prob, _alg; kwargs...) + + # Extract the time-points used in BC + _prob = ODEProblem{isinplace(prob)}(prob.f, prob.u0, prob.tspan, prob.p) +end + +function __solve(prob::BVProblem, _alg::MultipleShooting{false}; kwargs...) + return __solve_internal(prob, _alg; kwargs...) +end + +function __solve_internal(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;), nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...) @unpack f, tspan = prob From 9d7a87a6381261cd8b98a2eca8e02b7ceffea38c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 4 Nov 2023 15:16:19 -0400 Subject: [PATCH 4/6] Clean up rebase --- Project.toml | 2 +- README.md | 7 ++++--- src/BoundaryValueDiffEq.jl | 25 ------------------------- src/solve/multiple_shooting.jl | 3 ++- 4 files changed, 7 insertions(+), 30 deletions(-) diff --git a/Project.toml b/Project.toml index 83ec55e6..eefb3c31 100644 --- a/Project.toml +++ b/Project.toml @@ -42,7 +42,7 @@ Aqua = "0.7" ArrayInterface = "7" BandedMatrices = "1" ConcreteStructs = "0.2" -DiffEqBase = "6.135" +DiffEqBase = "6.138" ForwardDiff = "0.10" LinearAlgebra = "1.9" LinearSolve = "2" diff --git a/README.md b/README.md index e158c4c2..659d1db6 100644 --- a/README.md +++ b/README.md @@ -48,12 +48,13 @@ Precompilation can be controlled via `Preferences.jl` - `PrecompileMIRK` -- Precompile the MIRK2 - MIRK6 algorithms (default: `true`). - `PrecompileShooting` -- Precompile the single shooting algorithms (default: `true`). This is triggered when `OrdinaryDiffEq` is loaded. - `PrecompileMultipleShooting` -- Precompile the multiple shooting algorithms (default: `true`). This is triggered when `OrdinaryDiffEq` is loaded. -<<<<<<< HEAD + <<<<<<< HEAD - `PrecompileMIRKNLLS` -- Precompile the MIRK2 - MIRK6 algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above). - `PrecompileShootingNLLS` -- Precompile the single shooting algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above). This is triggered when `OrdinaryDiffEq` is loaded. - `PrecompileMultipleShootingNLLS` -- Precompile the multiple shooting algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above). This is triggered when `OrdinaryDiffEq` is loaded. -======= ->>>>>>> 5d53c01 (Add documentation about the solvers) + ======= + +> > > > > > > 5d53c01 (Add documentation about the solvers) To set these preferences before loading the package, do the following (replacing `PrecompileShooting` with the preference you want to set, or pass in multiple pairs to set them together): diff --git a/src/BoundaryValueDiffEq.jl b/src/BoundaryValueDiffEq.jl index 5d51a7db..ef62dc20 100644 --- a/src/BoundaryValueDiffEq.jl +++ b/src/BoundaryValueDiffEq.jl @@ -22,31 +22,6 @@ end @reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools, SciMLBase -# TODO: Upstream -# For BVPs we want to propagate even a function u0 -function DiffEqBase.get_concrete_u0(prob::BVProblem, isadapt, t0, kwargs) - if haskey(kwargs, :u0) - u0 = kwargs[:u0] - else - u0 = prob.u0 - end - - isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) - - _u0 = DiffEqBase.handle_distribution_u0(u0) - - if isinplace(prob) && (_u0 isa Number || _u0 isa DiffEqBase.SArray) - throw(DiffEqBase.IncompatibleInitialConditionError()) - end - - if _u0 isa Tuple - throw(DiffEqBase.TupleStateError()) - end - - return _u0 -end -# End of Upstream - include("types.jl") include("utils.jl") include("algorithms.jl") diff --git a/src/solve/multiple_shooting.jl b/src/solve/multiple_shooting.jl index d3ba2619..8b275f5f 100644 --- a/src/solve/multiple_shooting.jl +++ b/src/solve/multiple_shooting.jl @@ -2,7 +2,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting{true}; odesolve_kwargs nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...) # For TwoPointBVPs there is nothing to do. Forward to general multiple shooting prob.problem_type isa TwoPointBVProblem && - return __solve_internal(prob, _alg; kwargs...) + return __solve_internal(prob, _alg; odesolve_kwargs, nlsolve_kwargs, ensemblealg, + verbose, kwargs...) # Extract the time-points used in BC _prob = ODEProblem{isinplace(prob)}(prob.f, prob.u0, prob.tspan, prob.p) From af6382eb294912ec53a6494b4cfdffb633fe34b5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 4 Nov 2023 19:20:10 -0400 Subject: [PATCH 5/6] Finish an initial prototype --- README.md | 3 - src/algorithms.jl | 32 ++++-- src/solve/multiple_shooting.jl | 177 +++++++++++++++++++++++++++++++-- src/utils.jl | 42 ++++++++ 4 files changed, 234 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 659d1db6..be5693b3 100644 --- a/README.md +++ b/README.md @@ -52,9 +52,6 @@ Precompilation can be controlled via `Preferences.jl` - `PrecompileMIRKNLLS` -- Precompile the MIRK2 - MIRK6 algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above). - `PrecompileShootingNLLS` -- Precompile the single shooting algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above). This is triggered when `OrdinaryDiffEq` is loaded. - `PrecompileMultipleShootingNLLS` -- Precompile the multiple shooting algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above). This is triggered when `OrdinaryDiffEq` is loaded. - ======= - -> > > > > > > 5d53c01 (Add documentation about the solvers) To set these preferences before loading the package, do the following (replacing `PrecompileShooting` with the preference you want to set, or pass in multiple pairs to set them together): diff --git a/src/algorithms.jl b/src/algorithms.jl index 0fbe46ab..01226802 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -63,7 +63,7 @@ end """ MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(), grid_coarsening = true, jac_alg = BVPJacobianAlgorithm(), - static_auto_nodes::Val = Val(false)) + auto_static_nodes::Val = Val(false)) Multiple Shooting method, reduces BVP to an initial value problem and solves the IVP. Significantly more stable than Single Shooting. @@ -98,9 +98,13 @@ Significantly more stable than Single Shooting. - `Function`: Takes the current number of shooting points and returns the next number of shooting points. For example, if `nshoots = 10` and `grid_coarsening = n -> n ÷ 2`, then the grid will be coarsened to `[5, 2]`. - - `static_auto_nodes`: Automatically detect the timepoints used in the boundary condition + +## Experimental Features + + - `auto_static_nodes`: Automatically detect the timepoints used in the boundary condition and use a faster version of the algorithm! This particular keyword argument should be - considered experimental and should be used with care! + considered experimental and should be used with care! (Note that we ignore + `grid_coarsening` if this is set to `Val(true)`. We plan to support this in the future.) !!! note For type-stability, the chunksizes for ForwardDiff ADTypes in `BVPJacobianAlgorithm` @@ -125,13 +129,23 @@ function update_nshoots(alg::MultipleShooting, nshoots::Int) alg.grid_coarsening) end +function __without_static_nodes(ms::MultipleShooting{S}) where {S} + return MultipleShooting{false}(ms.ode_alg, ms.nlsolve, ms.jac_alg, ms.nshoots, + ms.grid_coarsening) +end + function MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(), - grid_coarsening = true, jac_alg = BVPJacobianAlgorithm(), - static_auto_nodes::Val{S} = Val(false)) where {S} - @assert grid_coarsening isa Bool || grid_coarsening isa Function || - grid_coarsening isa AbstractVector{<:Integer} || - grid_coarsening isa NTuple{N, <:Integer} where {N} - @assert S isa Bool + grid_coarsening = missing, jac_alg = BVPJacobianAlgorithm(), + auto_static_nodes::Val{S} = Val(false)) where {S} + @assert S isa Bool "`auto_static_nodes` must be either `Val(true)` or `Val(false)`." + if S + @assert grid_coarsening === missing||(grid_coarsening isa Bool && !grid_coarsening) "`auto_static_nodes` doesn't support grid_coarsening." + else + grid_coarsening === missing && (grid_coarsening = false) + @assert grid_coarsening isa Bool || grid_coarsening isa Function || + grid_coarsening isa AbstractVector{<:Integer} || + grid_coarsening isa NTuple{N, <:Integer} where {N} + end grid_coarsening isa Tuple && (grid_coarsening = Vector(grid_coarsening...)) if grid_coarsening isa AbstractVector sort!(grid_coarsening; rev = true) diff --git a/src/solve/multiple_shooting.jl b/src/solve/multiple_shooting.jl index 8b275f5f..a7db9293 100644 --- a/src/solve/multiple_shooting.jl +++ b/src/solve/multiple_shooting.jl @@ -2,11 +2,78 @@ function __solve(prob::BVProblem, _alg::MultipleShooting{true}; odesolve_kwargs nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...) # For TwoPointBVPs there is nothing to do. Forward to general multiple shooting prob.problem_type isa TwoPointBVProblem && - return __solve_internal(prob, _alg; odesolve_kwargs, nlsolve_kwargs, ensemblealg, - verbose, kwargs...) + return __solve_internal(prob, __without_static_nodes(_alg); odesolve_kwargs, + nlsolve_kwargs, ensemblealg, verbose, kwargs...) + + ig, T, N, Nig, u0 = __extract_problem_details(prob; dt = 0.1) + + if _unwrap_val(ig) && prob.u0 isa AbstractVector + if verbose + @warn "Static Nodes for Multiple-Shooting is not supported when Vector of \ + initial guesses are provided. Falling back to using the generic method!" + end + return __solve_internal(prob, __without_static_nodes(_alg); odesolve_kwargs, + nlsolve_kwargs, ensemblealg, verbose, kwargs...) + end + + has_initial_guess = _unwrap_val(ig) + + bcresid_prototype, resid_size = __get_bcresid_prototype(prob, u0) + iip, bc, u0, u0_size = isinplace(prob), prob.f.bc, deepcopy(u0), size(u0) # Extract the time-points used in BC - _prob = ODEProblem{isinplace(prob)}(prob.f, prob.u0, prob.tspan, prob.p) + _prob = ODEProblem{iip}(prob.f, prob.u0, prob.tspan, prob.p) + _fake_ode_sol = __construct_fake_ode_solution(_prob, _alg.ode_alg) + if iip + bc(bcresid_prototype, _fake_ode_sol, prob.p, _fake_ode_sol.sol.t) + else + bc(_fake_ode_sol, prob.p, _fake_ode_sol.sol.t) + end + __finalize_nodes!(_fake_ode_sol) + + __alg = concretize_jacobian_algorithm(_alg, prob) + alg = if has_initial_guess && Nig != __alg.nshoots + verbose && + @warn "Initial guess length != `nshoots + 1`! Adapting to `nshoots = $(Nig)`" + update_nshoots(__alg, Nig) + else + __alg + end + nshoots = alg.nshoots + M = length(bcresid_prototype) + + internal_ode_kwargs = (; verbose, kwargs..., odesolve_kwargs..., save_end = true) + + function solve_internal_odes!(resid_nodes::T1, us::T2, p::T3, cur_nshoot::Int, + nodes::T4, odecache::C) where {T1, T2, T3, T4, C} + return __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoot, + odecache, nodes, u0_size, N, ensemblealg) + end + + ode_cache_loss_fn = __multiple_shooting_init_odecache(ensemblealg, prob, + alg.ode_alg, u0, nshoots; internal_ode_kwargs...) + + nodes = typeof(first(tspan))[] + u_at_nodes = __multiple_shooting_initialize!(nodes, prob, alg, ig, nshoots, + ode_cache_loss_fn; kwargs..., verbose, odesolve_kwargs..., + static_nodes = _fake_ode_sol.nodes) + + __solve_nlproblem!(prob.problem_type, alg, bcresid_prototype, u_at_nodes, nodes, + nshoots, M, N, prod(resid_size), solve_internal_odes!, bc, prob, prob.f, + u0_size, u0, ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; verbose, + kwargs..., nlsolve_kwargs...) + + if prob.problem_type isa TwoPointBVProblem + diffmode_shooting = __get_non_sparse_ad(alg.jac_alg.diffmode) + else + diffmode_shooting = __get_non_sparse_ad(alg.jac_alg.bc_diffmode) + end + shooting_alg = Shooting(alg.ode_alg, alg.nlsolve, + BVPJacobianAlgorithm(diffmode_shooting)) + + single_shooting_prob = remake(prob; u0 = reshape(@view(u_at_nodes[1:N]), u0_size)) + return __solve(single_shooting_prob, shooting_alg; odesolve_kwargs, nlsolve_kwargs, + verbose, kwargs...) end function __solve(prob::BVProblem, _alg::MultipleShooting{false}; kwargs...) @@ -145,10 +212,71 @@ function __solve_nlproblem!(::TwoPointBVProblem, alg::MultipleShooting, bcresid_ return nothing end -function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_prototype, - u_at_nodes, nodes, cur_nshoot::Int, M::Int, N::Int, resid_len::Int, - solve_internal_odes!::S, bc::BC, prob, f::F, u0_size, u0, ode_cache_loss_fn, - ensemblealg, internal_ode_kwargs; kwargs...) where {BC, F, S} +function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting{true}, + bcresid_prototype, u_at_nodes, nodes, cur_nshoot::Int, M::Int, N::Int, + resid_len::Int, solve_internal_odes!::S, bc::BC, prob, f::F, u0_size, u0, + ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; kwargs...) where {BC, F, S} + if __any_sparse_ad(alg.jac_alg) + J_proto = __generate_sparse_jacobian_prototype(alg, prob.problem_type, + bcresid_prototype, u0, N, cur_nshoot) + end + resid_prototype = vcat(bcresid_prototype, similar(u_at_nodes, cur_nshoot * N)) + + __resid_nodes = resid_prototype[(end - cur_nshoot * N + 1):end] + resid_nodes = __maybe_allocate_diffcache(__resid_nodes, + pickchunksize((cur_nshoot + 1) * N), alg.jac_alg.bc_diffmode) + + loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss!(du, u, p, cur_nshoot, + nodes, prob, solve_internal_odes!, resid_len, N, f, bc, u0_size, prob.tspan, + alg.ode_alg, u0, ode_cache_loss_fn) + + # ODE Part + sd_ode = alg.jac_alg.nonbc_diffmode isa AbstractSparseADType ? + __sparsity_detection_alg(J_proto) : NoSparsityDetection() + ode_jac_cache = sparse_jacobian_cache(alg.jac_alg.nonbc_diffmode, sd_ode, + nothing, similar(u_at_nodes, cur_nshoot * N), u_at_nodes) + ode_cache_ode_jac_fn = __multiple_shooting_init_jacobian_odecache(ensemblealg, prob, + ode_jac_cache, alg.jac_alg.nonbc_diffmode, alg.ode_alg, cur_nshoot, u0; + internal_ode_kwargs...) + + # BC Part + sd_bc = alg.jac_alg.bc_diffmode isa AbstractSparseADType ? + SymbolicsSparsityDetection() : NoSparsityDetection() + bc_jac_cache = sparse_jacobian_cache(alg.jac_alg.bc_diffmode, + sd_bc, nothing, similar(bcresid_prototype), u_at_nodes) + ode_cache_bc_jac_fn = __multiple_shooting_init_jacobian_odecache(ensemblealg, prob, + bc_jac_cache, alg.jac_alg.bc_diffmode, alg.ode_alg, cur_nshoot, u0; + internal_ode_kwargs...) + + jac_prototype = vcat(init_jacobian(bc_jac_cache), init_jacobian(ode_jac_cache)) + + # Define the functions now + ode_fn = (du, u) -> solve_internal_odes!(du, u, prob.p, cur_nshoot, nodes, + ode_cache_ode_jac_fn) + bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc_static_node!(du, u, prob.p, + cur_nshoot, nodes, + prob, solve_internal_odes!, N, f, bc, u0_size, prob.tspan, alg.ode_alg, u0, + ode_cache_bc_jac_fn) + + jac_fn = (J, u, p) -> __multiple_shooting_mpoint_jacobian!(J, u, p, + similar(bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache, + ode_fn, bc_fn, alg, N, M) + + loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype, jac = jac_fn, + jac_prototype) + + # NOTE: u_at_nodes is updated inplace + nlprob = (M != N ? NonlinearLeastSquaresProblem : NonlinearProblem)(loss_function!, + u_at_nodes, prob.p) + __solve(nlprob, alg.nlsolve; kwargs..., alias_u0 = true) + + return nothing +end + +function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting{false}, + bcresid_prototype, u_at_nodes, nodes, cur_nshoot::Int, M::Int, N::Int, + resid_len::Int, solve_internal_odes!::S, bc::BC, prob, f::F, u0_size, u0, + ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; kwargs...) where {BC, F, S} if __any_sparse_ad(alg.jac_alg) J_proto = __generate_sparse_jacobian_prototype(alg, prob.problem_type, bcresid_prototype, u0, N, cur_nshoot) @@ -348,6 +476,29 @@ end return nothing end +@views function __multiple_shooting_mpoint_loss_bc_static_node!(resid_bc, us, p, + cur_nshoots::Int, nodes, prob, solve_internal_odes!::S, N, f::F, bc::BC, u0_size, + tspan, ode_alg, u0, ode_cache) where {S, F, BC} + iip = isinplace(prob) + + # NOTE: We placed the nodes at the points `bc` is evaluated so we don't need to + # recompute the solution + _ts = nodes + _us = [reshape(us[((i - 1) * prod(u0_size) + 1):(i * prod(u0_size))], u0_size) + for i in eachindex(_ts)] + + odeprob = ODEProblem{iip}(f, u0, tspan, p) + total_solution = SciMLBase.build_solution(odeprob, ode_alg, _ts, _us) + + if iip + eval_bc_residual!(resid_bc, StandardBVProblem(), bc, total_solution, p) + else + resid_bc .= eval_bc_residual(StandardBVProblem(), bc, total_solution, p) + end + + return nothing +end + @views function __multiple_shooting_mpoint_loss!(resid, us, p, cur_nshoots::Int, nodes, prob, solve_internal_odes!::S, resid_len, N, f::F, bc::BC, u0_size, tspan, ode_alg, u0, ode_cache) where {S, F, BC} @@ -390,12 +541,22 @@ end # No initial guess @views function __multiple_shooting_initialize!(nodes, prob, alg::MultipleShooting, - ::Val{false}, nshoots::Int, odecache_; verbose, kwargs...) + ::Val{false}, nshoots::Int, odecache_; verbose, static_nodes = nothing, kwargs...) @unpack f, u0, tspan, p = prob @unpack ode_alg = alg resize!(nodes, nshoots + 1) nodes .= range(tspan[1], tspan[2]; length = nshoots + 1) + + if static_nodes !== nothing + idx = 1 + for snode in static_nodes + sidx = searchsortedfirst(nodes[idx:end], snode) + nodes[idx + sidx - 1] = snode + idx = sidx + 1 + end + end + N = length(u0) # Ensures type stability in case the parameters are dual numbers diff --git a/src/utils.jl b/src/utils.jl index c380f15e..9da06b62 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -249,3 +249,45 @@ function __restructure_sol(sol::Vector{<:AbstractArray}, u_size) end # TODO: Add dispatch for a ODESolution Type as well + +# Fake ODE Solution to capture calls to the solution object +@concrete struct __FakeODESolution2 + sol + nodes +end + +__FakeODESolutionXXX = __FakeODESolution2 + +function __construct_fake_ode_solution(prob::ODEProblem, alg) + nodes = Vector{promote_type(typeof(prob.tspan[1]), typeof(prob.tspan[2]))}() + return __FakeODESolutionXXX(SciMLBase.build_solution(prob, alg, + [prob.tspan[1], prob.tspan[2]], [prob.u0, prob.u0]), nodes) +end + +function __finalize_nodes!(sol::__FakeODESolutionXXX) + sort!(sol.nodes) + unique!(sol.nodes) + return sol +end + +function (s::__FakeODESolutionXXX)(t::T, args...; kwargs...) where {T <: Number} + push!(s.nodes, t) + return s.sol(t, args...; kwargs...) +end + +function (s::__FakeODESolutionXXX)(t::T, args...; kwargs...) where {T <: AbstractVector} + append!(s.nodes, t) + return s.sol(t, args...; kwargs...) +end + +function Base.getindex(::__FakeODESolutionXXX, args...) + throw(ArgumentError("`static_auto_nodes = Val(true)` doesn't support indexing into \ + the solution object. Please rewrite your code to call the \ + solution object with the time points you want to evaluate at \ + or use `static_auto_nodes = Val(false)`")) +end + +function Base.show(io::IO, sol::__FakeODESolutionXXX) + print(io, "ODESolution evaluated @ nodes: $(sol.nodes)") + return +end From 59827dc9d4b2957373d0d83e74ea1a6bed4aa6d7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 5 Nov 2023 12:45:55 -0500 Subject: [PATCH 6/6] Update README.md Co-authored-by: Qingyu Qu <52615090+ErikQQY@users.noreply.github.com> --- .gitignore | 3 ++- README.md | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 559b95aa..4d4f66e4 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ wip /.benchmarkci /benchmark/*.json *.json -*.json.tmp \ No newline at end of file +*.json.tmp +*.pdf diff --git a/README.md b/README.md index be5693b3..b240e092 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,6 @@ Precompilation can be controlled via `Preferences.jl` - `PrecompileMIRK` -- Precompile the MIRK2 - MIRK6 algorithms (default: `true`). - `PrecompileShooting` -- Precompile the single shooting algorithms (default: `true`). This is triggered when `OrdinaryDiffEq` is loaded. - `PrecompileMultipleShooting` -- Precompile the multiple shooting algorithms (default: `true`). This is triggered when `OrdinaryDiffEq` is loaded. - <<<<<<< HEAD - `PrecompileMIRKNLLS` -- Precompile the MIRK2 - MIRK6 algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above). - `PrecompileShootingNLLS` -- Precompile the single shooting algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above). This is triggered when `OrdinaryDiffEq` is loaded. - `PrecompileMultipleShootingNLLS` -- Precompile the multiple shooting algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above). This is triggered when `OrdinaryDiffEq` is loaded.