Skip to content

Commit

Permalink
Temporarily use Cholesky
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 2, 2023
1 parent 24bf3c1 commit 7115ee8
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 20 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ BandedMatrices = "1"
ConcreteStructs = "0.2"
DiffEqBase = "6.135"
ForwardDiff = "0.10"
LinearSolve = "2"
NonlinearSolve = "2.5"
ODEInterface = "0.5"
OrdinaryDiffEq = "6"
Expand All @@ -58,6 +59,7 @@ julia = "1.9"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -66,4 +68,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["StaticArrays", "Random", "DiffEqDevTools", "OrdinaryDiffEq", "Test", "SafeTestsets", "ODEInterface", "Aqua"]
test = ["StaticArrays", "Random", "DiffEqDevTools", "OrdinaryDiffEq", "Test", "SafeTestsets", "ODEInterface", "Aqua", "LinearSolve"]
20 changes: 10 additions & 10 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,11 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
jac = if iip
(J, u, p) -> __mirk_mpoint_jacobian!(J, u, p, jac_alg.bc_diffmode,
jac_alg.nonbc_diffmode, cache_bc, cache_collocation, loss_bcₚ,
loss_collocationₚ, resid_bc, resid_collocation, cache.M)
loss_collocationₚ, resid_bc, resid_collocation, cache.M, L)
else
(u, p) -> __mirk_mpoint_jacobian(u, p, jac_prototype, jac_alg.bc_diffmode,
jac_alg.nonbc_diffmode, cache_bc, cache_collocation, loss_bcₚ,
loss_collocationₚ, cache.M)
loss_collocationₚ, cache.M, L)
end

nlf = NonlinearFunction{iip}(loss; resid_prototype = vcat(resid_bc, resid_collocation),
Expand All @@ -319,17 +319,17 @@ end

function __mirk_mpoint_jacobian!(J, x, p, bc_diffmode, nonbc_diffmode, bc_diffcache,
nonbc_diffcache, loss_bc::BC, loss_collocation::C, resid_bc, resid_collocation,
M::Int) where {BC, C}
sparse_jacobian!(@view(J[1:M, :]), bc_diffmode, bc_diffcache, loss_bc, resid_bc, x)
sparse_jacobian!(@view(J[(M + 1):end, :]), nonbc_diffmode, nonbc_diffcache,
M::Int, L::Int) where {BC, C}
sparse_jacobian!(@view(J[1:L, :]), bc_diffmode, bc_diffcache, loss_bc, resid_bc, x)
sparse_jacobian!(@view(J[(L + 1):end, :]), nonbc_diffmode, nonbc_diffcache,
loss_collocation, resid_collocation, x)
return nothing
end

function __mirk_mpoint_jacobian(x, p, J, bc_diffmode, nonbc_diffmode, bc_diffcache,
nonbc_diffcache, loss_bc::BC, loss_collocation::C, M::Int) where {BC, C}
sparse_jacobian!(@view(J[1:M, :]), bc_diffmode, bc_diffcache, loss_bc, x)
sparse_jacobian!(@view(J[(M + 1):end, :]), nonbc_diffmode, nonbc_diffcache,
nonbc_diffcache, loss_bc::BC, loss_collocation::C, M::Int, L::Int) where {BC, C}
sparse_jacobian!(@view(J[1:L, :]), bc_diffmode, bc_diffcache, loss_bc, x)
sparse_jacobian!(@view(J[(L + 1):end, :]), nonbc_diffmode, nonbc_diffcache,
loss_collocation, x)
return J
end
Expand All @@ -341,9 +341,9 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo

lossₚ = iip ? ((du, u) -> loss(du, u, cache.p)) : (u -> loss(u, cache.p))

resid = vcat(cache.bcresid_prototype[1:prod(cache.resid_size[1])],
resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]),
similar(y, cache.M * (N - 1)),
cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end])
@view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]))
L = length(cache.bcresid_prototype)

sd = if jac_alg.diffmode isa AbstractSparseADType
Expand Down
8 changes: 4 additions & 4 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_

jac_fn = (J, u, p) -> __multiple_shooting_mpoint_jacobian!(J, u, p,
similar(bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache,
ode_fn, bc_fn, alg, N)
ode_fn, bc_fn, alg, N, M)

loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype, jac = jac_fn,
jac_prototype)
Expand Down Expand Up @@ -186,9 +186,9 @@ end

function __multiple_shooting_mpoint_jacobian!(J, us, p, resid_bc, resid_nodes,
ode_jac_cache, bc_jac_cache, ode_fn::F1, bc_fn::F2, alg::MultipleShooting,
N::Int) where {F1, F2}
J_bc = @view(J[1:N, :])
J_c = @view(J[(N + 1):end, :])
N::Int, M::Int) where {F1, F2}
J_bc = @view(J[1:M, :])
J_c = @view(J[(M + 1):end, :])

sparse_jacobian!(J_c, alg.jac_alg.nonbc_diffmode, ode_jac_cache, ode_fn,
resid_nodes.du, us)
Expand Down
8 changes: 3 additions & 5 deletions test/shooting/shooting_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using BoundaryValueDiffEq, LinearAlgebra, OrdinaryDiffEq, Test
using BoundaryValueDiffEq, LinearAlgebra, LinearSolve, OrdinaryDiffEq, Test

@testset "Basic Shooting Tests" begin
SOLVERS = [Shooting(Tsit5()), MultipleShooting(10, Tsit5())]
Expand Down Expand Up @@ -83,12 +83,10 @@ end
@testset "Overconstrained BVP" begin
SOLVERS = [
Shooting(Tsit5();
nlsolve = LevenbergMarquardt(; damping_initial = 1e-6,
α_geodesic = 0.9, b_uphill = 2.0)),
nlsolve = LevenbergMarquardt(; linsolve = CholeskyFactorization())),
Shooting(Tsit5(); nlsolve = GaussNewton()),
MultipleShooting(10, Tsit5();
nlsolve = LevenbergMarquardt(; damping_initial = 1e-6,
α_geodesic = 0.9, b_uphill = 2.0)),
nlsolve = LevenbergMarquardt(; linsolve = CholeskyFactorization())),
MultipleShooting(10, Tsit5(); nlsolve = GaussNewton())]

# OOP MP-BVP
Expand Down

0 comments on commit 7115ee8

Please sign in to comment.