From 0b76dae5b96c061aa1e5c2239ef79b1f9c3283d5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 13 Oct 2023 16:46:13 -0400 Subject: [PATCH] MIRK supports NLS --- src/solve/mirk.jl | 18 ++++-- src/utils.jl | 6 +- test/mirk/nonlinear_least_squares.jl | 84 ++++++++++++++++++++++++++++ test/runtests.jl | 3 + test/shooting/shooting_tests.jl | 4 +- 5 files changed, 105 insertions(+), 10 deletions(-) create mode 100644 test/mirk/nonlinear_least_squares.jl diff --git a/src/solve/mirk.jl b/src/solve/mirk.jl index 5cf2aa98..28072ccb 100644 --- a/src/solve/mirk.jl +++ b/src/solve/mirk.jl @@ -272,6 +272,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati N = length(cache.mesh) resid_bc = cache.bcresid_prototype + L = length(resid_bc) resid_collocation = similar(y, cache.M * (N - 1)) sd_bc = jac_alg.bc_diffmode isa AbstractSparseADType ? SymbolicsSparsityDetection() : @@ -292,24 +293,26 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati jac = if iip function jac_internal!(J, x, p) - sparse_jacobian!(@view(J[1:(cache.M), :]), jac_alg.bc_diffmode, cache_bc, + sparse_jacobian!(@view(J[1:L, :]), jac_alg.bc_diffmode, cache_bc, loss_bc, resid_bc, x) - sparse_jacobian!(@view(J[(cache.M + 1):end, :]), jac_alg.nonbc_diffmode, + sparse_jacobian!(@view(J[(L + 1):end, :]), jac_alg.nonbc_diffmode, cache_collocation, loss_collocation, resid_collocation, x) return J end else J_ = jac_prototype function jac_internal(x, p) - sparse_jacobian!(@view(J_[1:(cache.M), :]), jac_alg.bc_diffmode, cache_bc, + sparse_jacobian!(@view(J_[1:L, :]), jac_alg.bc_diffmode, cache_bc, loss_bc, x) - sparse_jacobian!(@view(J_[(cache.M + 1):end, :]), jac_alg.nonbc_diffmode, + sparse_jacobian!(@view(J_[(L + 1):end, :]), jac_alg.nonbc_diffmode, cache_collocation, loss_collocation, x) return J_ end end - return NonlinearProblem(NonlinearFunction{iip}(loss; jac, jac_prototype), y, cache.p) + nlf = NonlinearFunction{iip}(loss; resid_prototype = vcat(resid_bc, resid_collocation), + jac, jac_prototype) + return (L == cache.M ? NonlinearProblem : NonlinearLeastSquaresProblem)(nlf, y, cache.p) end function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, @@ -320,6 +323,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati resid = vcat(cache.bcresid_prototype[1:prod(cache.resid_size[1])], similar(y, cache.M * (N - 1)), cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]) + L = length(cache.bcresid_prototype) sd = if jac_alg.diffmode isa AbstractSparseADType PrecomputedJacobianColorvec(__generate_sparse_jacobian_prototype(cache, @@ -345,5 +349,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati end end - return NonlinearProblem(NonlinearFunction{iip}(loss; jac, jac_prototype), y, cache.p) + nlf = NonlinearFunction{iip}(loss; resid_prototype = copy(resid), jac, jac_prototype) + + return (L == cache.M ? NonlinearProblem : NonlinearLeastSquaresProblem)(nlf, y, cache.p) end diff --git a/src/utils.jl b/src/utils.jl index bd562daa..753d52db 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -101,11 +101,13 @@ end __append_similar!(::Nothing, n, _) = nothing +# NOTE: We use `last` since the `first` might not conform to the same structure. For eg, +# in the case of residuals function __append_similar!(x::AbstractVector{<:AbstractArray}, n, _) N = n - length(x) N == 0 && return x N < 0 && throw(ArgumentError("Cannot append a negative number of elements")) - append!(x, [similar(first(x)) for _ in 1:N]) + append!(x, [similar(last(x)) for _ in 1:N]) return x end @@ -114,7 +116,7 @@ function __append_similar!(x::AbstractVector{<:MaybeDiffCache}, n, M) N == 0 && return x N < 0 && throw(ArgumentError("Cannot append a negative number of elements")) chunksize = pickchunksize(M * (N + length(x))) - append!(x, [__maybe_allocate_diffcache(first(x), chunksize) for _ in 1:N]) + append!(x, [__maybe_allocate_diffcache(last(x), chunksize) for _ in 1:N]) return x end diff --git a/test/mirk/nonlinear_least_squares.jl b/test/mirk/nonlinear_least_squares.jl new file mode 100644 index 00000000..cd1aef6c --- /dev/null +++ b/test/mirk/nonlinear_least_squares.jl @@ -0,0 +1,84 @@ +using BoundaryValueDiffEq, LinearAlgebra, Test + +@testset "Overconstrained BVP" begin + SOLVERS = [mirk(; + nlsolve = LevenbergMarquardt(; damping_initial = 1e-6, + α_geodesic = 0.9, b_uphill = 2.0)) for mirk in (MIRK4, MIRK5, MIRK6)] + + # OOP MP-BVP + f1(u, p, t) = [u[2], -u[1]] + + function bc1(sol, p, t) + solₜ₁ = sol[1] + solₜ₂ = sol[end] + return [solₜ₁[1], solₜ₂[1] - 1, solₜ₂[2] + 1.729109] + end + + tspan = (0.0, 100.0) + u0 = [0.0, 1.0] + + bvp1 = BVProblem(BVPFunction{false}(f1, bc1; bcresid_prototype = zeros(3)), u0, tspan) + + for solver in SOLVERS + @time sol = solve(bvp1, solver; verbose = false, dt = 1.0, abstol = 1e-3, + reltol = 1e-3, nlsolve_kwargs = (; maxiters = 50, abstol = 1e-2, reltol = 1e-2)) + @test norm(bc1(sol, nothing, tspan)) < 1e-2 + end + + # IIP MP-BVP + function f1!(du, u, p, t) + du[1] = u[2] + du[2] = -u[1] + return nothing + end + + function bc1!(resid, sol, p, t) + solₜ₁ = sol[1] + solₜ₂ = sol[end] + # We know that this overconstrained system has a solution + resid[1] = solₜ₁[1] + resid[2] = solₜ₂[1] - 1 + resid[3] = solₜ₂[2] + 1.729109 + return nothing + end + + bvp2 = BVProblem(BVPFunction{true}(f1!, bc1!; bcresid_prototype = zeros(3)), u0, tspan) + + for solver in SOLVERS + @time sol = solve(bvp2, solver; verbose = false, dt = 1.0, abstol = 1e-3, + reltol = 1e-3, nlsolve_kwargs = (; maxiters = 50, abstol = 1e-3, reltol = 1e-3)) + resid_f = Array{Float64}(undef, 3) + bc1!(resid_f, sol, nothing, sol.t) + @test norm(resid_f) < 1e-2 + end + + # OOP TP-BVP + bc1a(ua, p) = [ua[1]] + bc1b(ub, p) = [ub[1] - 1, ub[2] + 1.729109] + + bvp3 = TwoPointBVProblem(BVPFunction{false}(f1, (bc1a, bc1b); twopoint = Val(true), + bcresid_prototype = (zeros(1), zeros(2))), u0, tspan) + + for solver in SOLVERS + @time sol = solve(bvp3, solver; verbose = false, dt = 1.0, abstol = 1e-3, + reltol = 1e-3, nlsolve_kwargs = (; maxiters = 50, abstol = 1e-3, reltol = 1e-3)) + @test norm(vcat(bc1a(sol[1], nothing), bc1b(sol[end], nothing))) < 1e-2 + end + + # IIP TP-BVP + bc1a!(resid, ua, p) = (resid[1] = ua[1]) + bc1b!(resid, ub, p) = (resid[1] = ub[1] - 1; resid[2] = ub[2] + 1.729109) + + bvp4 = TwoPointBVProblem(BVPFunction{true}(f1!, (bc1a!, bc1b!); twopoint = Val(true), + bcresid_prototype = (zeros(1), zeros(2))), u0, tspan) + + for solver in SOLVERS + @time sol = solve(bvp3, solver; verbose = false, dt = 1.0, abstol = 1e-3, + reltol = 1e-3, nlsolve_kwargs = (; maxiters = 50, abstol = 1e-3, reltol = 1e-3)) + resida = Array{Float64}(undef, 1) + residb = Array{Float64}(undef, 2) + bc1a!(resida, sol(0.0), nothing) + bc1b!(residb, sol(100.0), nothing) + @test norm(vcat(resida, residb)) < 1e-2 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index e36c8280..b08cae50 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,6 +34,9 @@ const GROUP = uppercase(get(ENV, "GROUP", "ALL")) @time @safetestset "Interpolation Tests" begin include("mirk/interpolation_test.jl") end + @time @safetestset "MIRK Nonlinear Least Squares Tests" begin + include("mirk/nonlinear_least_squares.jl") + end end end diff --git a/test/shooting/shooting_tests.jl b/test/shooting/shooting_tests.jl index 5630f95b..fcd8845c 100644 --- a/test/shooting/shooting_tests.jl +++ b/test/shooting/shooting_tests.jl @@ -150,7 +150,7 @@ end bc1a(ua, p) = [ua[1]] bc1b(ub, p) = [ub[1] - 1, ub[2] + 1.729109] - bvp3 = TwoPointBVProblem(BVPFunction{false}(f, (bc1a, bc1b); twopoint = Val(true), + bvp3 = TwoPointBVProblem(BVPFunction{false}(f1, (bc1a, bc1b); twopoint = Val(true), bcresid_prototype = (zeros(1), zeros(2))), u0, tspan) for solver in SOLVERS @@ -164,7 +164,7 @@ end bc1a!(resid, ua, p) = (resid[1] = ua[1]) bc1b!(resid, ub, p) = (resid[1] = ub[1] - 1; resid[2] = ub[2] + 1.729109) - bvp4 = TwoPointBVProblem(BVPFunction{true}(f!, (bc1a!, bc1b!); twopoint = Val(true), + bvp4 = TwoPointBVProblem(BVPFunction{true}(f1!, (bc1a!, bc1b!); twopoint = Val(true), bcresid_prototype = (zeros(1), zeros(2))), u0, tspan) for solver in SOLVERS