Skip to content

Commit

Permalink
MIRK supports NLS
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 19, 2023
1 parent f78b75d commit 0b76dae
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 10 deletions.
18 changes: 12 additions & 6 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati
N = length(cache.mesh)

resid_bc = cache.bcresid_prototype
L = length(resid_bc)
resid_collocation = similar(y, cache.M * (N - 1))

sd_bc = jac_alg.bc_diffmode isa AbstractSparseADType ? SymbolicsSparsityDetection() :
Expand All @@ -292,24 +293,26 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati

jac = if iip
function jac_internal!(J, x, p)
sparse_jacobian!(@view(J[1:(cache.M), :]), jac_alg.bc_diffmode, cache_bc,
sparse_jacobian!(@view(J[1:L, :]), jac_alg.bc_diffmode, cache_bc,
loss_bc, resid_bc, x)
sparse_jacobian!(@view(J[(cache.M + 1):end, :]), jac_alg.nonbc_diffmode,
sparse_jacobian!(@view(J[(L + 1):end, :]), jac_alg.nonbc_diffmode,
cache_collocation, loss_collocation, resid_collocation, x)
return J
end
else
J_ = jac_prototype
function jac_internal(x, p)
sparse_jacobian!(@view(J_[1:(cache.M), :]), jac_alg.bc_diffmode, cache_bc,
sparse_jacobian!(@view(J_[1:L, :]), jac_alg.bc_diffmode, cache_bc,
loss_bc, x)
sparse_jacobian!(@view(J_[(cache.M + 1):end, :]), jac_alg.nonbc_diffmode,
sparse_jacobian!(@view(J_[(L + 1):end, :]), jac_alg.nonbc_diffmode,
cache_collocation, loss_collocation, x)
return J_
end
end

return NonlinearProblem(NonlinearFunction{iip}(loss; jac, jac_prototype), y, cache.p)
nlf = NonlinearFunction{iip}(loss; resid_prototype = vcat(resid_bc, resid_collocation),
jac, jac_prototype)
return (L == cache.M ? NonlinearProblem : NonlinearLeastSquaresProblem)(nlf, y, cache.p)
end

function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocation,
Expand All @@ -320,6 +323,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati
resid = vcat(cache.bcresid_prototype[1:prod(cache.resid_size[1])],
similar(y, cache.M * (N - 1)),
cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end])
L = length(cache.bcresid_prototype)

sd = if jac_alg.diffmode isa AbstractSparseADType
PrecomputedJacobianColorvec(__generate_sparse_jacobian_prototype(cache,
Expand All @@ -345,5 +349,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati
end
end

return NonlinearProblem(NonlinearFunction{iip}(loss; jac, jac_prototype), y, cache.p)
nlf = NonlinearFunction{iip}(loss; resid_prototype = copy(resid), jac, jac_prototype)

return (L == cache.M ? NonlinearProblem : NonlinearLeastSquaresProblem)(nlf, y, cache.p)
end
6 changes: 4 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,13 @@ end

__append_similar!(::Nothing, n, _) = nothing

# NOTE: We use `last` since the `first` might not conform to the same structure. For eg,
# in the case of residuals
function __append_similar!(x::AbstractVector{<:AbstractArray}, n, _)
N = n - length(x)
N == 0 && return x
N < 0 && throw(ArgumentError("Cannot append a negative number of elements"))
append!(x, [similar(first(x)) for _ in 1:N])
append!(x, [similar(last(x)) for _ in 1:N])
return x
end

Expand All @@ -114,7 +116,7 @@ function __append_similar!(x::AbstractVector{<:MaybeDiffCache}, n, M)
N == 0 && return x
N < 0 && throw(ArgumentError("Cannot append a negative number of elements"))
chunksize = pickchunksize(M * (N + length(x)))
append!(x, [__maybe_allocate_diffcache(first(x), chunksize) for _ in 1:N])
append!(x, [__maybe_allocate_diffcache(last(x), chunksize) for _ in 1:N])
return x
end

Expand Down
84 changes: 84 additions & 0 deletions test/mirk/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
using BoundaryValueDiffEq, LinearAlgebra, Test

@testset "Overconstrained BVP" begin
SOLVERS = [mirk(;
nlsolve = LevenbergMarquardt(; damping_initial = 1e-6,
α_geodesic = 0.9, b_uphill = 2.0)) for mirk in (MIRK4, MIRK5, MIRK6)]

# OOP MP-BVP
f1(u, p, t) = [u[2], -u[1]]

function bc1(sol, p, t)
solₜ₁ = sol[1]
solₜ₂ = sol[end]
return [solₜ₁[1], solₜ₂[1] - 1, solₜ₂[2] + 1.729109]
end

tspan = (0.0, 100.0)
u0 = [0.0, 1.0]

bvp1 = BVProblem(BVPFunction{false}(f1, bc1; bcresid_prototype = zeros(3)), u0, tspan)

for solver in SOLVERS
@time sol = solve(bvp1, solver; verbose = false, dt = 1.0, abstol = 1e-3,
reltol = 1e-3, nlsolve_kwargs = (; maxiters = 50, abstol = 1e-2, reltol = 1e-2))
@test norm(bc1(sol, nothing, tspan)) < 1e-2
end

# IIP MP-BVP
function f1!(du, u, p, t)
du[1] = u[2]
du[2] = -u[1]
return nothing
end

function bc1!(resid, sol, p, t)
solₜ₁ = sol[1]
solₜ₂ = sol[end]
# We know that this overconstrained system has a solution
resid[1] = solₜ₁[1]
resid[2] = solₜ₂[1] - 1
resid[3] = solₜ₂[2] + 1.729109
return nothing
end

bvp2 = BVProblem(BVPFunction{true}(f1!, bc1!; bcresid_prototype = zeros(3)), u0, tspan)

for solver in SOLVERS
@time sol = solve(bvp2, solver; verbose = false, dt = 1.0, abstol = 1e-3,
reltol = 1e-3, nlsolve_kwargs = (; maxiters = 50, abstol = 1e-3, reltol = 1e-3))
resid_f = Array{Float64}(undef, 3)
bc1!(resid_f, sol, nothing, sol.t)
@test norm(resid_f) < 1e-2
end

# OOP TP-BVP
bc1a(ua, p) = [ua[1]]
bc1b(ub, p) = [ub[1] - 1, ub[2] + 1.729109]

bvp3 = TwoPointBVProblem(BVPFunction{false}(f1, (bc1a, bc1b); twopoint = Val(true),
bcresid_prototype = (zeros(1), zeros(2))), u0, tspan)

for solver in SOLVERS
@time sol = solve(bvp3, solver; verbose = false, dt = 1.0, abstol = 1e-3,
reltol = 1e-3, nlsolve_kwargs = (; maxiters = 50, abstol = 1e-3, reltol = 1e-3))
@test norm(vcat(bc1a(sol[1], nothing), bc1b(sol[end], nothing))) < 1e-2
end

# IIP TP-BVP
bc1a!(resid, ua, p) = (resid[1] = ua[1])
bc1b!(resid, ub, p) = (resid[1] = ub[1] - 1; resid[2] = ub[2] + 1.729109)

bvp4 = TwoPointBVProblem(BVPFunction{true}(f1!, (bc1a!, bc1b!); twopoint = Val(true),
bcresid_prototype = (zeros(1), zeros(2))), u0, tspan)

for solver in SOLVERS
@time sol = solve(bvp3, solver; verbose = false, dt = 1.0, abstol = 1e-3,
reltol = 1e-3, nlsolve_kwargs = (; maxiters = 50, abstol = 1e-3, reltol = 1e-3))
resida = Array{Float64}(undef, 1)
residb = Array{Float64}(undef, 2)
bc1a!(resida, sol(0.0), nothing)
bc1b!(residb, sol(100.0), nothing)
@test norm(vcat(resida, residb)) < 1e-2
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ const GROUP = uppercase(get(ENV, "GROUP", "ALL"))
@time @safetestset "Interpolation Tests" begin
include("mirk/interpolation_test.jl")
end
@time @safetestset "MIRK Nonlinear Least Squares Tests" begin
include("mirk/nonlinear_least_squares.jl")
end
end
end

Expand Down
4 changes: 2 additions & 2 deletions test/shooting/shooting_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ end
bc1a(ua, p) = [ua[1]]
bc1b(ub, p) = [ub[1] - 1, ub[2] + 1.729109]

bvp3 = TwoPointBVProblem(BVPFunction{false}(f, (bc1a, bc1b); twopoint = Val(true),
bvp3 = TwoPointBVProblem(BVPFunction{false}(f1, (bc1a, bc1b); twopoint = Val(true),
bcresid_prototype = (zeros(1), zeros(2))), u0, tspan)

for solver in SOLVERS
Expand All @@ -164,7 +164,7 @@ end
bc1a!(resid, ua, p) = (resid[1] = ua[1])
bc1b!(resid, ub, p) = (resid[1] = ub[1] - 1; resid[2] = ub[2] + 1.729109)

bvp4 = TwoPointBVProblem(BVPFunction{true}(f!, (bc1a!, bc1b!); twopoint = Val(true),
bvp4 = TwoPointBVProblem(BVPFunction{true}(f1!, (bc1a!, bc1b!); twopoint = Val(true),
bcresid_prototype = (zeros(1), zeros(2))), u0, tspan)

for solver in SOLVERS
Expand Down

0 comments on commit 0b76dae

Please sign in to comment.