From cb9aff653e8805faa5c45983db3b11ba76725365 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 13 Oct 2023 12:37:58 -0400 Subject: [PATCH] Multiple Shooting supports NLS --- Project.toml | 4 ++-- src/solve/multiple_shooting.jl | 12 +++++++----- test/shooting/shooting_tests.jl | 6 +++++- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 3981b677..3e66a2b5 100644 --- a/Project.toml +++ b/Project.toml @@ -35,12 +35,12 @@ ArrayInterface = "7" ConcreteStructs = "0.2" DiffEqBase = "6.94.2" ForwardDiff = "0.10" -NonlinearSolve = "2" +NonlinearSolve = "2.2" ODEInterface = "0.5" PreallocationTools = "0.4" RecursiveArrayTools = "2.38.10" Reexport = "0.2, 1.0" -SciMLBase = "2.2" +SciMLBase = "2.4" Setfield = "1" SparseDiffTools = "2.6" TruncatedStacktraces = "1" diff --git a/src/solve/multiple_shooting.jl b/src/solve/multiple_shooting.jl index 07caed19..3be10245 100644 --- a/src/solve/multiple_shooting.jl +++ b/src/solve/multiple_shooting.jl @@ -6,6 +6,7 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;), has_initial_guess = known(ig) bcresid_prototype, resid_size = __get_bcresid_prototype(prob, u0) + M = length(bcresid_prototype) iip, bc, u0, u0_size = isinplace(prob), prob.f.bc, deepcopy(u0), size(u0) __alg = concretize_jacobian_algorithm(_alg, prob) @@ -128,8 +129,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;), # This is mostly a safety measure fill!(J, 0) - J_bc = J[1:N, :] - J_c = J[(N + 1):end, :] + J_bc = J[1:M, :] + J_c = J[(M + 1):end, :] sparse_jacobian!(J_c, alg.jac_alg.nonbc_diffmode, ode_jac_cache, ode_fn, resid_nodes.du, us) @@ -152,8 +153,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;), # This is mostly a safety measure fill!(J, 0) - J_bc = J[1:N, :] - J_c = J[(N + 1):end, :] + J_bc = J[1:M, :] + J_c = J[(M + 1):end, :] sparse_jacobian!(J_c, alg.jac_alg.nonbc_diffmode, ode_jac_cache, ode_fn, resid_nodes.du, us) @@ -229,7 +230,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;), loss_function! = NonlinearFunction{true}((args...) -> loss!(args..., cur_nshoot, nodes); resid_prototype, jac = jac_fn, jac_prototype) - nlprob = NonlinearProblem(loss_function!, u_at_nodes, prob.p) + nlprob = (M != N ? NonlinearLeastSquaresProblem : NonlinearProblem)(loss_function!, + u_at_nodes, prob.p) sol_nlsolve = __solve(nlprob, alg.nlsolve; nlsolve_kwargs..., verbose, kwargs...) u_at_nodes = sol_nlsolve.u::typeof(u0) end diff --git a/test/shooting/shooting_tests.jl b/test/shooting/shooting_tests.jl index 589d944f..4c95cbf9 100644 --- a/test/shooting/shooting_tests.jl +++ b/test/shooting/shooting_tests.jl @@ -85,7 +85,11 @@ end Shooting(Tsit5(); nlsolve = LevenbergMarquardt(; damping_initial = 1e-6, α_geodesic = 0.9, b_uphill = 2.0)), - Shooting(Tsit5(); nlsolve = GaussNewton())] + Shooting(Tsit5(); nlsolve = GaussNewton()), + MultipleShooting(10, Tsit5(); + nlsolve = LevenbergMarquardt(; damping_initial = 1e-6, + α_geodesic = 0.9, b_uphill = 2.0)), + MultipleShooting(10, Tsit5(); nlsolve = GaussNewton())] # OOP MP-BVP f1(u, p, t) = [u[2], -u[1]]