Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Boundary conditions should always use solution object #260

Merged
merged 4 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ function simplependulum!(du, u, p, t)
du[2] = -9.81 * sin(θ)
end
function bc!(residual, u, p, t)
residual[1] = u[:, end ÷ 2][1] + pi / 2
residual[2] = u[:, end][1] - pi / 2
residual[1] = u(pi / 4)[1] + pi / 2
residual[2] = u(pi / 2)[1] - pi / 2
end
prob = BVProblem(simplependulum!, bc!, [pi / 2, pi / 2], tspan)
sol = solve(prob, MIRK4(), dt = 0.05)
Expand Down
3 changes: 2 additions & 1 deletion lib/BoundaryValueDiffEqCore/src/BoundaryValueDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import ADTypes: AbstractADType
import ArrayInterface: matrix_colors, parameterless_type, fast_scalar_indexing
import ConcreteStructs: @concrete
import DiffEqBase: solve
import ForwardDiff: ForwardDiff, pickchunksize
import ForwardDiff: ForwardDiff, pickchunksize, Dual
import Logging
using NonlinearSolveFirstOrder: NonlinearSolvePolyAlgorithm
import LineSearch: BackTracking
Expand All @@ -28,6 +28,7 @@ include("algorithms.jl")
include("alg_utils.jl")
include("default_nlsolve.jl")
include("sparse_jacobians.jl")
include("misc_utils.jl")

function __solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
Expand Down
7 changes: 7 additions & 0 deletions lib/BoundaryValueDiffEqCore/src/misc_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Intermidiate solution evaluation
@concrete struct EvalSol{iip}
u
t
alg
k_discrete
end
2 changes: 0 additions & 2 deletions lib/BoundaryValueDiffEqCore/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,5 +368,3 @@ end
end

@inline (f::__Fix3{F})(a, b) where {F} = f.f(a, b, f.x)

# convert every vector of vector to AbstractVectorOfArray, especially if them come from get_tmp of PreallocationTools.jl
16 changes: 8 additions & 8 deletions lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit
__get_bcresid_prototype, __similar, __vec, __vec_f,
__vec_f!, __vec_bc, __vec_bc!, recursive_flatten_twopoint!,
__internal_nlsolve_problem, __extract_mesh, __extract_u0,
__has_initial_guess, __initial_guess_length,
__has_initial_guess, __initial_guess_length, EvalSol,
__initial_guess_on_mesh, __flatten_initial_guess,
__build_solution, __Fix3, __sparse_jacobian_cache,
__sparsity_detection_alg, _sparse_like, ColoredMatrix
Expand All @@ -33,7 +33,7 @@ import ArrayInterface: matrix_colors, parameterless_type, undefmatrix, fast_scal
import ConcreteStructs: @concrete
import DiffEqBase: solve
import FastClosures: @closure
import ForwardDiff: ForwardDiff, pickchunksize
import ForwardDiff: ForwardDiff, pickchunksize, Dual
import Logging
import RecursiveArrayTools: ArrayPartition, DiffEqArray
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val
Expand All @@ -58,11 +58,11 @@ include("sparse_jacobians.jl")
f1 = (u, p, t) -> [u[2], 0]

function bc1!(residual, u, p, t)
residual[1] = u[:, 1][1] - 5
residual[2] = u[:, end][1]
residual[1] = u(0.0)[1] - 5
residual[2] = u(5.0)[1]
end

bc1 = (u, p, t) -> [u[:, 1][1] - 5, u[:, end][1]]
bc1 = (u, p, t) -> [u(0.0)[1] - 5, u(5.0)[1]]

bc1_a! = (residual, ua, p) -> (residual[1] = ua[1] - 5)
bc1_b! = (residual, ub, p) -> (residual[1] = ub[1])
Expand Down Expand Up @@ -103,14 +103,14 @@ include("sparse_jacobians.jl")
f1_nlls = (u, p, t) -> [u[2], -u[1]]

bc1_nlls! = (resid, sol, p, t) -> begin
solₜ₁ = sol[:, 1]
solₜ₂ = sol[:, end]
solₜ₁ = sol(0.0)
solₜ₂ = sol(100.0)
resid[1] = solₜ₁[1]
resid[2] = solₜ₂[1] - 1
resid[3] = solₜ₂[2] + 1.729109
return nothing
end
bc1_nlls = (sol, p, t) -> [sol[:, 1][1], sol[:, end][1] - 1, sol[:, end][2] + 1.729109]
bc1_nlls = (sol, p, t) -> [sol(0.0)[1], sol(100.0)[1] - 1, sol(100.0)[2] + 1.729109]

bc1_nlls_a! = (resid, ua, p) -> (resid[1] = ua[1])
bc1_nlls_b! = (resid, ub, p) -> (resid[1] = ub[1] - 1;
Expand Down
71 changes: 61 additions & 10 deletions lib/BoundaryValueDiffEqMIRK/src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ end
@inline function interpolation(tvals, id::MIRKInterpolation, idxs, deriv::D,
p, continuity::Symbol = :left) where {D}
(; t, u, cache) = id
(; mesh, mesh_dt) = cache
tdir = sign(t[end] - t[1])
idx = sortperm(tvals, rev = tdir < 0)

Expand All @@ -34,7 +35,7 @@ end

for j in idx
z = similar(cache.fᵢ₂_cache)
interpolant!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt, deriv)
interpolant!(z, id, cache, tvals[j], mesh, mesh_dt, deriv)
vals[j] = idxs !== nothing ? z[idxs] : z
end
return DiffEqArray(vals, tvals)
Expand All @@ -43,41 +44,91 @@ end
@inline function interpolation!(vals, tvals, id::MIRKInterpolation, idxs,
deriv::D, p, continuity::Symbol = :left) where {D}
(; t, cache) = id
(; mesh, mesh_dt) = cache
tdir = sign(t[end] - t[1])
idx = sortperm(tvals, rev = tdir < 0)

for j in idx
z = similar(cache.fᵢ₂_cache)
interpolant!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt, deriv)
z = similar(id.u[1])
interpolant!(z, id, cache, tvals[j], mesh, mesh_dt, deriv)
vals[j] = z
end
end

@inline function interpolation(tval::Number, id::MIRKInterpolation, idxs,
deriv::D, p, continuity::Symbol = :left) where {D}
z = similar(id.cache.fᵢ₂_cache)
interpolant!(z, id.cache, tval, id.cache.mesh, id.cache.mesh_dt, deriv)
z = similar(id.u[1])
interpolant!(z, id, id.cache, tval, id.cache.mesh, id.cache.mesh_dt, deriv)
return idxs !== nothing ? z[idxs] : z
end

@inline function interpolant!(
z::AbstractArray, cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{0}})
z::AbstractArray, id, cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{0}})
i = interval(mesh, t)
dt = mesh_dt[i]
τ = (t - mesh[i]) / dt
w, w′ = interp_weights(τ, cache.alg)
sum_stages!(z, cache, w, i)
sum_stages!(z, id, cache, w, i)
end

@inline function interpolant!(
dz::AbstractArray, cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{1}})
@inline function interpolant!(dz::AbstractArray, id::MIRKInterpolation,
cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{1}})
i = interval(mesh, t)
dt = mesh_dt[i]
τ = (t - mesh[i]) / dt
w, w′ = interp_weights(τ, cache.alg)
z = similar(dz)
sum_stages!(z, dz, cache, w, w′, i)
sum_stages!(z, dz, id, cache, w, w′, i)
end

function sum_stages!(z::AbstractArray, id::MIRKInterpolation,
cache::MIRKCache, w, i::Int, dt = cache.mesh_dt[i])
(; stage, k_discrete, k_interp) = cache
(; s_star) = cache.ITU
z .= zero(z)
__maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage])
__maybe_matmul!(
z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true)
z .= z .* dt .+ id.u[i]

return z
end

@views function sum_stages!(z, z′, id::MIRKInterpolation, cache::MIRKCache,
w, w′, i::Int, dt = cache.mesh_dt[i])
(; stage, k_discrete, k_interp) = cache
(; s_star) = cache.ITU

z .= zero(z)
__maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage])
__maybe_matmul!(
z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true)
z′ .= zero(z′)
__maybe_matmul!(z′, k_discrete[i].du[:, 1:stage], w′[1:stage])
__maybe_matmul!(
z′, k_interp.u[i][:, 1:(s_star - stage)], w′[(stage + 1):s_star], true, true)
z .= z .* dt[1] .+ id.u[i]

return z, z′
end

@inline __build_interpolation(cache::MIRKCache, u::AbstractVector) = MIRKInterpolation(
cache.mesh, u, cache)

# Intermidiate solution for evaluating boundry conditions
# basically simplified version of the interpolation for MIRK
function (s::EvalSol)(tval::Number)
(; t, u, alg, k_discrete) = s
stage = alg_stage(alg)
# Quick handle for the case where tval is at the boundary
(tval == t[1]) && return first(u)
(tval == t[end]) && return last(u)
z = zero(last(u))
ii = interval(t, tval)
dt = t[ii + 1] - t[ii]
τ = (tval - t[ii]) / dt
w, _ = interp_weights(τ, alg)
__maybe_matmul!(z, k_discrete[ii].du[:, 1:stage], w[1:stage])
z .= z .* dt .+ u[ii]
return z
end
45 changes: 26 additions & 19 deletions lib/BoundaryValueDiffEqMIRK/src/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ end

function __perform_mirk_iteration(
cache::MIRKCache, abstol, adaptive::Bool; nlsolve_kwargs = (;), kwargs...)
nlprob = __construct_nlproblem(cache, vec(cache.y₀))
nlprob = __construct_nlproblem(cache, vec(cache.y₀), copy(cache.y₀))
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
sol_nlprob = __solve(
nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs..., alias_u0 = true)
Expand Down Expand Up @@ -206,9 +206,12 @@ function __perform_mirk_iteration(
end

# Constructing the Nonlinear Problem
function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {iip}
function __construct_nlproblem(
cache::MIRKCache{iip}, y::AbstractVector, y₀::AbstractVectorOfArray) where {iip}
pt = cache.problem_type

eval_sol = EvalSol{iip}(y₀.u, cache.mesh, cache.alg, cache.k_discrete)

loss_bc = if iip
@closure (du, u, p) -> __mirk_loss_bc!(
du, u, p, pt, cache.bc, cache.y, cache.mesh, cache)
Expand All @@ -226,66 +229,70 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {

loss = if iip
@closure (du, u, p) -> __mirk_loss!(
du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache)
du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache, eval_sol)
else
@closure (u, p) -> __mirk_loss(u, p, cache.y, pt, cache.bc, cache.mesh, cache)
@closure (u, p) -> __mirk_loss(
u, p, cache.y, pt, cache.bc, cache.mesh, cache, eval_sol)
end

return __construct_nlproblem(cache, y, loss_bc, loss_collocation, loss, pt)
end

@views function __mirk_loss!(
resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, mesh, cache) where {BC}
@views function __mirk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC,
residual, mesh, cache, EvalSol) where {BC}
y_ = recursive_unflatten!(y, u)
resids = [get_tmp(r, u) for r in residual]
soly_ = VectorOfArray(y_)
eval_bc_residual!(resids[1], pt, bc!, soly_, p, mesh)
Φ!(resids[2:end], cache, y_, u, p)
EvalSol.u[1:end] .= y_
EvalSol.k_discrete[1:end] .= cache.k_discrete
eval_bc_residual!(resids[1], pt, bc!, EvalSol, p, mesh)
recursive_flatten!(resid, resids)
return nothing
end

@views function __mirk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2},
residual, mesh, cache) where {BC1, BC2}
residual, mesh, cache, _) where {BC1, BC2}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
resids = [get_tmp(r, u) for r in residual]
Φ!(resids[2:end], cache, y_, u, p)
soly_ = VectorOfArray(y_)
resida = resids[1][1:prod(cache.resid_size[1])]
residb = resids[1][(prod(cache.resid_size[1]) + 1):end]
eval_bc_residual!((resida, residb), pt, bc!, soly_, p, mesh)
Φ!(resids[2:end], cache, y_, u, p)
recursive_flatten_twopoint!(resid, resids, cache.resid_size)
return nothing
end

@views function __mirk_loss(u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache) where {BC}
@views function __mirk_loss(
u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache, EvalSol) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
resid_bc = eval_bc_residual(pt, bc, soly_, p, mesh)
resid_co = Φ(cache, y_, u, p)
EvalSol.u[1:end] .= y_
EvalSol.k_discrete[1:end] .= cache.k_discrete
resid_bc = eval_bc_residual(pt, bc, EvalSol, p, mesh)
return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
end

@views function __mirk_loss(
u, p, y, pt::TwoPointBVProblem, bc::Tuple{BC1, BC2}, mesh, cache) where {BC1, BC2}
@views function __mirk_loss(u, p, y, pt::TwoPointBVProblem, bc::Tuple{BC1, BC2},
mesh, cache, EvalSol) where {BC1, BC2}
y_ = recursive_unflatten!(y, u)
resid_co = Φ(cache, y_, u, p)
soly_ = VectorOfArray(y_)
resid_bca, resid_bcb = eval_bc_residual(pt, bc, soly_, p, mesh)
resid_co = Φ(cache, y_, u, p)
return vcat(resid_bca, mapreduce(vec, vcat, resid_co), resid_bcb)
end

@views function __mirk_loss_bc!(
resid, u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
soly_ = EvalSol{true}(y_, mesh, cache.alg, cache.k_discrete)
eval_bc_residual!(resid, pt, bc!, soly_, p, mesh)
return nothing
end

@views function __mirk_loss_bc(u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
soly_ = EvalSol{false}(y_, mesh, cache.alg, cache.k_discrete)
return eval_bc_residual(pt, bc!, soly_, p, mesh)
end

Expand Down
4 changes: 2 additions & 2 deletions lib/BoundaryValueDiffEqMIRK/test/ensemble_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
end

function bc!(residual, u, p, t)
residual[1] = u[:, 1][1] - 1.0
residual[2] = u[:, end][1]
residual[1] = u(0.0)[1] - 1.0
residual[2] = u(1.0)[1]
end

prob_func(prob, i, repeat) = remake(prob, p = [rand()])
Expand Down
Loading
Loading