Skip to content

Commit

Permalink
Multiple Shooting supports NLS
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 19, 2023
1 parent 6540dc9 commit f78b75d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ BandedMatrices = "1"
ConcreteStructs = "0.2"
DiffEqBase = "6.94.2"
ForwardDiff = "0.10"
NonlinearSolve = "2"
NonlinearSolve = "2.2"
ODEInterface = "0.5"
PreallocationTools = "0.4"
RecursiveArrayTools = "2.38.10"
Reexport = "0.2, 1.0"
SciMLBase = "2.2"
SciMLBase = "2.4"
Setfield = "1"
SparseDiffTools = "2.6"
TruncatedStacktraces = "1"
Expand Down
10 changes: 7 additions & 3 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
if prob.problem_type isa TwoPointBVProblem
resida_len = prod(resid_size[1])
residb_len = prod(resid_size[2])
M = resida_len + residb_len
else
M = length(bcresid_prototype)
end

# We will use colored AD for this part!
Expand Down Expand Up @@ -134,8 +137,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
@views function jac_mp!(J::AbstractMatrix, us, p, resid_bc,
resid_nodes::MaybeDiffCache, ode_jac_cache, bc_jac_cache, ode_fn, bc_fn,
cur_nshoot, nodes)
J_bc = J[1:N, :]
J_c = J[(N + 1):end, :]
J_bc = J[1:M, :]
J_c = J[(M + 1):end, :]

sparse_jacobian!(J_c, alg.jac_alg.nonbc_diffmode, ode_jac_cache, ode_fn,
resid_nodes.du, us)
Expand Down Expand Up @@ -214,7 +217,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
end
loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype, jac = jac_fn,
jac_prototype)
nlprob = NonlinearProblem(loss_function!, u_at_nodes, prob.p)
nlprob = (M != N ? NonlinearLeastSquaresProblem : NonlinearProblem)(loss_function!,
u_at_nodes, prob.p)
sol_nlsolve = __solve(nlprob, alg.nlsolve; verbose, kwargs..., nlsolve_kwargs...)
u_at_nodes = sol_nlsolve.u::typeof(u0)
end
Expand Down
8 changes: 6 additions & 2 deletions test/shooting/shooting_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ end
Shooting(Tsit5();
nlsolve = LevenbergMarquardt(; damping_initial = 1e-6,
α_geodesic = 0.9, b_uphill = 2.0)),
Shooting(Tsit5(); nlsolve = GaussNewton())]
Shooting(Tsit5(); nlsolve = GaussNewton()),
MultipleShooting(10, Tsit5();
nlsolve = LevenbergMarquardt(; damping_initial = 1e-6,
α_geodesic = 0.9, b_uphill = 2.0)),
MultipleShooting(10, Tsit5(); nlsolve = GaussNewton())]

# OOP MP-BVP
f1(u, p, t) = [u[2], -u[1]]
Expand Down Expand Up @@ -146,7 +150,7 @@ end
bc1a(ua, p) = [ua[1]]
bc1b(ub, p) = [ub[1] - 1, ub[2] + 1.729109]

bvp3 = TwoPointBVProblem(BVPFunction{false}(f, (bc_a, bc_b); twopoint = Val(true),
bvp3 = TwoPointBVProblem(BVPFunction{false}(f, (bc1a, bc1b); twopoint = Val(true),
bcresid_prototype = (zeros(1), zeros(2))), u0, tspan)

for solver in SOLVERS
Expand Down

0 comments on commit f78b75d

Please sign in to comment.