Skip to content

Commit

Permalink
Multiple Shooting supports NLS
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 13, 2023
1 parent 7a04a8b commit cb9aff6
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ ArrayInterface = "7"
ConcreteStructs = "0.2"
DiffEqBase = "6.94.2"
ForwardDiff = "0.10"
NonlinearSolve = "2"
NonlinearSolve = "2.2"
ODEInterface = "0.5"
PreallocationTools = "0.4"
RecursiveArrayTools = "2.38.10"
Reexport = "0.2, 1.0"
SciMLBase = "2.2"
SciMLBase = "2.4"
Setfield = "1"
SparseDiffTools = "2.6"
TruncatedStacktraces = "1"
Expand Down
12 changes: 7 additions & 5 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
has_initial_guess = known(ig)

bcresid_prototype, resid_size = __get_bcresid_prototype(prob, u0)
M = length(bcresid_prototype)
iip, bc, u0, u0_size = isinplace(prob), prob.f.bc, deepcopy(u0), size(u0)

__alg = concretize_jacobian_algorithm(_alg, prob)
Expand Down Expand Up @@ -128,8 +129,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
# This is mostly a safety measure
fill!(J, 0)

J_bc = J[1:N, :]
J_c = J[(N + 1):end, :]
J_bc = J[1:M, :]
J_c = J[(M + 1):end, :]

sparse_jacobian!(J_c, alg.jac_alg.nonbc_diffmode, ode_jac_cache, ode_fn,
resid_nodes.du, us)
Expand All @@ -152,8 +153,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
# This is mostly a safety measure
fill!(J, 0)

J_bc = J[1:N, :]
J_c = J[(N + 1):end, :]
J_bc = J[1:M, :]
J_c = J[(M + 1):end, :]

sparse_jacobian!(J_c, alg.jac_alg.nonbc_diffmode, ode_jac_cache, ode_fn,
resid_nodes.du, us)
Expand Down Expand Up @@ -229,7 +230,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),

loss_function! = NonlinearFunction{true}((args...) -> loss!(args..., cur_nshoot,
nodes); resid_prototype, jac = jac_fn, jac_prototype)
nlprob = NonlinearProblem(loss_function!, u_at_nodes, prob.p)
nlprob = (M != N ? NonlinearLeastSquaresProblem : NonlinearProblem)(loss_function!,
u_at_nodes, prob.p)
sol_nlsolve = __solve(nlprob, alg.nlsolve; nlsolve_kwargs..., verbose, kwargs...)
u_at_nodes = sol_nlsolve.u::typeof(u0)
end
Expand Down
6 changes: 5 additions & 1 deletion test/shooting/shooting_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ end
Shooting(Tsit5();
nlsolve = LevenbergMarquardt(; damping_initial = 1e-6,
α_geodesic = 0.9, b_uphill = 2.0)),
Shooting(Tsit5(); nlsolve = GaussNewton())]
Shooting(Tsit5(); nlsolve = GaussNewton()),
MultipleShooting(10, Tsit5();
nlsolve = LevenbergMarquardt(; damping_initial = 1e-6,
α_geodesic = 0.9, b_uphill = 2.0)),
MultipleShooting(10, Tsit5(); nlsolve = GaussNewton())]

# OOP MP-BVP
f1(u, p, t) = [u[2], -u[1]]
Expand Down

0 comments on commit cb9aff6

Please sign in to comment.