Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for u0 initial guess function #133

Merged
merged 5 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

Expand All @@ -40,7 +41,7 @@ Aqua = "0.7"
ArrayInterface = "7"
BandedMatrices = "1"
ConcreteStructs = "0.2"
DiffEqBase = "6.135"
DiffEqBase = "6.138"
ForwardDiff = "0.10"
LinearAlgebra = "1.9"
LinearSolve = "2"
Expand All @@ -56,6 +57,7 @@ SciMLBase = "2.5"
Setfield = "1"
SparseArrays = "1.9"
SparseDiffTools = "2.9"
Tricks = "0.1"
TruncatedStacktraces = "1"
UnPack = "1"
julia = "1.9"
Expand Down
2 changes: 1 addition & 1 deletion src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat
@recompile_invalidations begin
using ADTypes, Adapt, BandedMatrices, DiffEqBase, ForwardDiff, LinearAlgebra,
NonlinearSolve, PreallocationTools, Preferences, RecursiveArrayTools, Reexport,
SciMLBase, Setfield, SparseArrays, SparseDiffTools
SciMLBase, Setfield, SparseArrays, SparseDiffTools, Tricks

import ADTypes: AbstractADType
import ArrayInterface: matrix_colors,
Expand Down
13 changes: 8 additions & 5 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,19 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
abstol = 1e-3, adaptive = true, kwargs...)
@set! alg.jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg)
iip = isinplace(prob)

_, T, M, n, X = __extract_problem_details(prob; dt, check_positive_dt = true)
# NOTE: Assumes the user provided initial guess is on a uniform mesh
mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))
mesh_dt = diff(mesh)

chunksize = pickchunksize(M * (n + 1))

__alloc = x -> __maybe_allocate_diffcache(vec(x), chunksize, alg.jac_alg)

fᵢ_cache = __alloc(similar(X))
fᵢ₂_cache = vec(similar(X))

# NOTE: Assumes the user provided initial guess is on a uniform mesh
mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))
mesh_dt = diff(mesh)

defect_threshold = T(0.1) # TODO: Allow user to specify these
MxNsub = 3000 # TODO: Allow user to specify these

Expand Down Expand Up @@ -100,7 +101,9 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
vecf, vecbc
end

return MIRKCache{iip, T}(alg_order(alg), stage, M, size(X), f, bc, prob,
prob_ = !(prob.u0 isa AbstractArray) ? remake(prob; u0 = X) : prob

return MIRKCache{iip, T}(alg_order(alg), stage, M, size(X), f, bc, prob_,
prob.problem_type, prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt,
k_discrete, k_interp, y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, new_stages,
resid₁_size, (; defect_threshold, MxNsub, abstol, dt, adaptive, kwargs...))
Expand Down
25 changes: 15 additions & 10 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
ode_cache_loss_fn; kwargs..., verbose, odesolve_kwargs...)
else
u_at_nodes = __multiple_shooting_initialize!(nodes, u_at_nodes, prob, alg,
cur_nshoot, all_nshoots[i - 1], ig, ode_cache_loss_fn; kwargs..., verbose,
odesolve_kwargs...)
cur_nshoot, all_nshoots[i - 1], ig, ode_cache_loss_fn, u0; kwargs...,
verbose, odesolve_kwargs...)
end

if prob.problem_type isa TwoPointBVProblem
Expand Down Expand Up @@ -362,9 +362,13 @@ end
resize!(nodes, nshoots + 1)
nodes .= range(tspan[1], tspan[2]; length = nshoots + 1)

N = length(first(u0))
u_at_nodes = similar(first(u0), (nshoots + 1) * N)
recursive_flatten!(u_at_nodes, u0)
# NOTE: We don't check `u0 isa Function` since `u0` in-principle can be a callable
# struct
u0_ = u0 isa AbstractArray ? u0 : [__initial_guess(u0, prob.p, t) for t in nodes]

N = length(first(u0_))
u_at_nodes = similar(first(u0_), (nshoots + 1) * N)
recursive_flatten!(u_at_nodes, u0_)

return u_at_nodes
end
Expand Down Expand Up @@ -401,7 +405,8 @@ end
end
else
@warn "Initialization using odesolve failed. Initializing using 0s. It is \
recommended to provide an `initial_guess` in this case."
recommended to provide an initial guess function via \
`u0 = <function>(p, t)` or `u0 = <function>(t)` in this case."
fill!(u_at_nodes, 0)
end

Expand All @@ -410,16 +415,16 @@ end

# Grid coarsening
@views function __multiple_shooting_initialize!(nodes, u_at_nodes_prev, prob, alg,
nshoots, old_nshoots, ig, odecache_; kwargs...)
@unpack f, u0, tspan, p = prob
nshoots, old_nshoots, ig, odecache_, u0; kwargs...)
@unpack f, tspan, p = prob
prev_nodes = copy(nodes)
odecache = odecache_ isa Vector ? first(odecache_) : odecache_

resize!(nodes, nshoots + 1)
nodes .= range(tspan[1], tspan[2]; length = nshoots + 1)
N = _unwrap_val(ig) ? length(first(u0)) : length(u0)
N = length(u0)

u_at_nodes = similar(_unwrap_val(ig) ? first(u0) : u0, N + nshoots * N)
u_at_nodes = similar(u0, N + nshoots * N)
u_at_nodes[1:N] .= u_at_nodes_prev[1:N]
u_at_nodes[(end - N + 1):end] .= u_at_nodes_prev[(end - N + 1):end]

Expand Down
37 changes: 30 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,40 @@
t₀, t₁ = prob.tspan
return Val(false), eltype(u0), length(u0), Int(cld(t₁ - t₀, dt)), prob.u0
end
function __extract_problem_details(prob, ::F; kwargs...) where {F <: Function}
throw(ArgumentError("passing `u0` as a function is not supported yet. Curently we only \
support AbstractArray or Vector of AbstractArrays as input! \
Use the latter format for passing in initial guess!"))
function __extract_problem_details(prob, f::F; dt = 0.0,
check_positive_dt::Bool = false) where {F <: Function}
# Problem passes in a initial guess function
check_positive_dt && dt ≤ 0 && throw(ArgumentError("dt must be positive"))
u0 = __initial_guess(f, prob.p, prob.tspan[1])
t₀, t₁ = prob.tspan
return Val(true), eltype(u0), length(u0), Int(cld(t₁ - t₀, dt)), u0
end

__initial_state_from_prob(prob::BVProblem, mesh) = __initial_state_from_prob(prob.u0, mesh)
__initial_state_from_prob(u0::AbstractArray, mesh) = [copy(vec(u0)) for _ in mesh]
function __initial_state_from_prob(u0::AbstractVector{<:AbstractVector}, _)
function __initial_guess(f::F, p::P, t::T) where {F, P, T}
if static_hasmethod(f, Tuple{P, T})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that in v1.10 this isn't needed, and it can be incorrect.

return f(p, t)
elseif static_hasmethod(f, Tuple{T})
Base.depwarn("initial guess function must take 2 inputs `(p, t)` instead of just \
`t`. The single argument version has been deprecated and will be \
removed in the next major release of SciMLBase.", :__initial_guess)
return f(t)
else
throw(ArgumentError("`initial_guess` must be a function of the form `f(p, t)`"))

Check warning on line 160 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L160

Added line #L160 was not covered by tests
end
end

function __initial_state_from_prob(prob::BVProblem, mesh)
return __initial_state_from_prob(prob, prob.u0, mesh)
end
function __initial_state_from_prob(::BVProblem, u0::AbstractArray, mesh)
return [copy(vec(u0)) for _ in mesh]
end
function __initial_state_from_prob(::BVProblem, u0::AbstractVector{<:AbstractVector}, _)
return [copy(vec(u)) for u in u0]
end
function __initial_state_from_prob(prob::BVProblem, f::F, mesh) where {F}
return [__initial_guess(f, prob.p, t) for t in mesh]
end

function __get_bcresid_prototype(prob::BVProblem, u)
return __get_bcresid_prototype(prob.problem_type, prob, u)
Expand Down
82 changes: 82 additions & 0 deletions test/misc/initial_guess.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
using BoundaryValueDiffEq, OrdinaryDiffEq, Test, LinearAlgebra

@testset "Initial Guess" begin
# Problem taken from https://github.com/SciML/BoundaryValueDiffEq.jl/issues/117#issuecomment-1780981510
function affine_connection(a, Xc, Yc)
MR = 3.0
Mr = 2.0
Zc = similar(Xc)
θ = a[1]
sinθ, cosθ = sincos(θ)
Γ¹₂₂ = (MR + Mr * cosθ) * sinθ / Mr
Γ²₁₂ = -Mr * sinθ / (MR + Mr * cosθ)

Zc[1] = Xc[2] * Γ¹₂₂ * Yc[2]
Zc[2] = Γ²₁₂ * (Xc[1] * Yc[2] + Xc[2] * Yc[1])
return Zc
end

function chart_log_problem!(du, u, p, t)
mid = div(length(u), 2)
a = u[1:mid]
dx = u[(mid + 1):end]
ddx = -affine_connection(a, dx, dx)
du[1:mid] .= dx
du[(mid + 1):end] .= ddx
return du
end

function bc1!(residual, u, p, t)
a1, a2 = p[1:2], p[3:4]
mid = div(length(u[1]), 2)
residual[1:mid] = u[1][1:mid] - a1
residual[(mid + 1):end] = u[end][1:mid] - a2
return residual
end

function initial_guess_1(p, t)
a1, a2 = p[1:2], p[3:4]
return vcat(t * a1 + (1 - t) * a2, zero(a1))
end

function initial_guess_2(t)
a1, a2 = [0.5, -1.2], [-0.5, 0.3]
return vcat(t * a1 + (1 - t) * a2, zero(a1))
end

dt = 0.05
p = [0.5, -1.2, -0.5, 0.3]
tspan = (0.0, 1.0)

bvp1 = BVProblem(chart_log_problem!, bc1!, initial_guess_1, tspan, p)

algs = [Shooting(Tsit5()), MultipleShooting(10, Tsit5()), MIRK4(), MIRK5(), MIRK6()]

for alg in algs
if alg isa Shooting || alg isa MultipleShooting
sol = solve(bvp1, alg)
else
sol = solve(bvp1, alg; dt)
end
@test SciMLBase.successful_retcode(sol)
resid = zeros(4)
bc1!(resid, sol, p, sol.t)
@test norm(resid) < 1e-10
end

bvp2 = BVProblem(chart_log_problem!, bc1!, initial_guess_2, tspan, p)

for alg in algs
if alg isa Shooting || alg isa MultipleShooting
sol = solve(bvp2, alg)
@test_deprecated solve(bvp2, alg)
else
sol = solve(bvp2, alg; dt)
@test_deprecated solve(bvp2, alg; dt)
end
@test SciMLBase.successful_retcode(sol)
resid = zeros(4)
bc1!(resid, sol, p, sol.t)
@test norm(resid) < 1e-10
end
end
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ const GROUP = uppercase(get(ENV, "GROUP", "ALL"))
@time @safetestset "Non Vector Inputs" begin
include("misc/non_vector_inputs.jl")
end

@time @safetestset "Type Stability" begin
include("misc/type_stability.jl")
end

@time @safetestset "ODE Interface Tests" begin
include("misc/odeinterface_ex7.jl")
end

@time @safetestset "Initial Guess Function" begin
include("misc/initial_guess.jl")
end
@time @safetestset "Aqua: Quality Assurance" begin
include("misc/aqua.jl")
end
Expand Down
Loading