Skip to content

Commit

Permalink
MIRK supports NLS
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 13, 2023
1 parent 239c471 commit 3c1f8f9
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 9 deletions.
18 changes: 12 additions & 6 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati
N = length(cache.mesh)

resid_bc = cache.bcresid_prototype
L = length(resid_bc)
resid_collocation = similar(y, cache.M * (N - 1))

sd_bc = jac_alg.bc_diffmode isa AbstractSparseADType ? SymbolicsSparsityDetection() :
Expand All @@ -279,24 +280,26 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati

jac = if iip
function jac_internal!(J, x, p)
sparse_jacobian!(@view(J[1:(cache.M), :]), jac_alg.bc_diffmode, cache_bc,
sparse_jacobian!(@view(J[1:L, :]), jac_alg.bc_diffmode, cache_bc,
loss_bc, resid_bc, x)
sparse_jacobian!(@view(J[(cache.M + 1):end, :]), jac_alg.nonbc_diffmode,
sparse_jacobian!(@view(J[(L + 1):end, :]), jac_alg.nonbc_diffmode,
cache_collocation, loss_collocation, resid_collocation, x)
return J
end
else
J_ = jac_prototype
function jac_internal(x, p)
sparse_jacobian!(@view(J_[1:(cache.M), :]), jac_alg.bc_diffmode, cache_bc,
sparse_jacobian!(@view(J_[1:L, :]), jac_alg.bc_diffmode, cache_bc,
loss_bc, x)
sparse_jacobian!(@view(J_[(cache.M + 1):end, :]), jac_alg.nonbc_diffmode,
sparse_jacobian!(@view(J_[(L + 1):end, :]), jac_alg.nonbc_diffmode,
cache_collocation, loss_collocation, x)
return J_
end
end

return NonlinearProblem(NonlinearFunction{iip}(loss; jac, jac_prototype), y, cache.p)
nlf = NonlinearFunction{iip}(loss; resid_prototype = vcat(resid_bc, resid_collocation),
jac, jac_prototype)
return (L == cache.M ? NonlinearProblem : NonlinearLeastSquaresProblem)(nlf, y, cache.p)
end

function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, loss,
Expand All @@ -305,6 +308,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati
N = length(cache.mesh)

resid = ArrayPartition(cache.bcresid_prototype, similar(y, cache.M * (N - 1)))
L = length(cache.bcresid_prototype)

# TODO: We can splitup the computation here as well similar to the Multiple Shooting
# TODO: code. That way for the BC part the actual jacobian computation is even cheaper
Expand All @@ -331,5 +335,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati
end
end

return NonlinearProblem(NonlinearFunction{iip}(loss; jac, jac_prototype), y, cache.p)
nlf = NonlinearFunction{iip}(loss; resid_prototype = copy(resid), jac, jac_prototype)

return (L == cache.M ? NonlinearProblem : NonlinearLeastSquaresProblem)(nlf, y, cache.p)
end
2 changes: 1 addition & 1 deletion src/sparse_jacobians.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function __generate_sparse_jacobian_prototype(::MIRKCache, ::TwoPointBVProblem,
idx += 1
end

J = _sparse_like(Is, Js, y, M * N, M * N)
J = _sparse_like(Is, Js, y, M * (N - 1) + length(resida) + length(residb), M * N)

col_colorvec = Vector{Int}(undef, size(J, 2))
for i in eachindex(col_colorvec)
Expand Down
6 changes: 4 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,13 @@ end

__append_similar!(::Nothing, n, _) = nothing

# NOTE: We use `last` since the `first` might not conform to the same structure. For eg,
# in the case of residuals
function __append_similar!(x::AbstractVector{<:AbstractArray}, n, _)
N = n - length(x)
N == 0 && return x
N < 0 && throw(ArgumentError("Cannot append a negative number of elements"))
append!(x, [similar(first(x)) for _ in 1:N])
append!(x, [similar(last(x)) for _ in 1:N])
return x
end

Expand All @@ -106,7 +108,7 @@ function __append_similar!(x::AbstractVector{<:MaybeDiffCache}, n, M)
N == 0 && return x
N < 0 && throw(ArgumentError("Cannot append a negative number of elements"))
chunksize = pickchunksize(M * (N + length(x)))
append!(x, [__maybe_allocate_diffcache(first(x), chunksize) for _ in 1:N])
append!(x, [__maybe_allocate_diffcache(last(x), chunksize) for _ in 1:N])
return x
end

Expand Down
84 changes: 84 additions & 0 deletions test/mirk/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
using BoundaryValueDiffEq, LinearAlgebra, Test

@testset "Overconstrained BVP" begin
SOLVERS = [mirk(;
nlsolve = LevenbergMarquardt(; damping_initial = 1e-6,
α_geodesic = 0.9, b_uphill = 2.0)) for mirk in (MIRK4, MIRK5, MIRK6)]

# OOP MP-BVP
f1(u, p, t) = [u[2], -u[1]]

function bc1(sol, p, t)
solₜ₁ = sol[1]
solₜ₂ = sol[end]
return [solₜ₁[1], solₜ₂[1] - 1, solₜ₂[2] + 1.729109]
end

tspan = (0.0, 100.0)
u0 = [0.0, 1.0]

bvp1 = BVProblem(BVPFunction{false}(f1, bc1; bcresid_prototype = zeros(3)), u0, tspan)

for solver in SOLVERS
@time sol = solve(bvp1, solver; verbose = false, dt = 1.0, abstol = 1e-3,
reltol = 1e-3, nlsolve_kwargs = (; maxiters = 50, abstol = 1e-2, reltol = 1e-2))
@test norm(bc1(sol, nothing, tspan)) < 1e-2
end

# IIP MP-BVP
function f1!(du, u, p, t)
du[1] = u[2]
du[2] = -u[1]
return nothing
end

function bc1!(resid, sol, p, t)
solₜ₁ = sol[1]
solₜ₂ = sol[end]
# We know that this overconstrained system has a solution
resid[1] = solₜ₁[1]
resid[2] = solₜ₂[1] - 1
resid[3] = solₜ₂[2] + 1.729109
return nothing
end

bvp2 = BVProblem(BVPFunction{true}(f1!, bc1!; bcresid_prototype = zeros(3)), u0, tspan)

for solver in SOLVERS
@time sol = solve(bvp2, solver; verbose = false, dt = 1.0, abstol = 1e-3,
reltol = 1e-3, nlsolve_kwargs = (; maxiters = 50, abstol = 1e-3, reltol = 1e-3))
resid_f = Array{Float64}(undef, 3)
bc1!(resid_f, sol, nothing, sol.t)
@test norm(resid_f) < 1e-2
end

# OOP TP-BVP
bc1a(ua, p) = [ua[1]]
bc1b(ub, p) = [ub[1] - 1, ub[2] + 1.729109]

bvp3 = TwoPointBVProblem(BVPFunction{false}(f1, (bc1a, bc1b); twopoint = Val(true),
bcresid_prototype = (zeros(1), zeros(2))), u0, tspan)

for solver in SOLVERS
@time sol = solve(bvp3, solver; verbose = false, dt = 1.0, abstol = 1e-3,
reltol = 1e-3, nlsolve_kwargs = (; maxiters = 50, abstol = 1e-3, reltol = 1e-3))
@test norm(vcat(bc1a(sol[1], nothing), bc1b(sol[end], nothing))) < 1e-2
end

# IIP TP-BVP
bc1a!(resid, ua, p) = (resid[1] = ua[1])
bc1b!(resid, ub, p) = (resid[1] = ub[1] - 1; resid[2] = ub[2] + 1.729109)

bvp4 = TwoPointBVProblem(BVPFunction{true}(f1!, (bc1a!, bc1b!); twopoint = Val(true),
bcresid_prototype = (zeros(1), zeros(2))), u0, tspan)

for solver in SOLVERS
@time sol = solve(bvp3, solver; verbose = false, dt = 1.0, abstol = 1e-3,
reltol = 1e-3, nlsolve_kwargs = (; maxiters = 50, abstol = 1e-3, reltol = 1e-3))
resida = Array{Float64}(undef, 1)
residb = Array{Float64}(undef, 2)
bc1a!(resida, sol(0.0), nothing)
bc1b!(residb, sol(100.0), nothing)
@test norm(vcat(resida, residb)) < 1e-2
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ const GROUP = uppercase(get(ENV, "GROUP", "ALL"))
@time @safetestset "Interpolation Tests" begin
include("mirk/interpolation_test.jl")
end
@time @safetestset "MIRK Nonlinear Least Squares Tests" begin
include("mirk/nonlinear_least_squares.jl")
end
end
end

Expand Down

0 comments on commit 3c1f8f9

Please sign in to comment.