Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use DifferentiationInterface for the jacobian #258

Open
wants to merge 25 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c47f611
Use DifferentiationInterface for the jacobian
ErikQQY Nov 21, 2024
ad04fc9
Fix incorrect dense choice
ErikQQY Nov 21, 2024
7c9ed85
Use SparseConnectivityTracer
ErikQQY Nov 21, 2024
309b809
FIRK and Ascher use SparseConnectivityTracer
ErikQQY Nov 21, 2024
cb29d91
Proper usage of SparseConnectivityTracer
ErikQQY Nov 22, 2024
e46534c
Fix incorrect diffmode usage
ErikQQY Nov 22, 2024
f712ebb
Done MIRKN
ErikQQY Nov 25, 2024
cd42ede
Done FIRK and single shooting
ErikQQY Nov 28, 2024
9cdaa8a
Merge branch 'master' into qqy/di
ErikQQY Nov 28, 2024
8524567
Remove SparseDiffTools everywhere
ErikQQY Nov 28, 2024
3af4f5e
Should reexport ADTypes
ErikQQY Nov 28, 2024
351b91e
Small tweaks
ErikQQY Nov 28, 2024
4dd08ce
Need to use OrdinaryDiffEq in test
ErikQQY Nov 28, 2024
907a106
Fix oop single shooting
ErikQQY Nov 29, 2024
286d635
And multiple shooting is done
ErikQQY Nov 29, 2024
24923f6
Dont forget multiple shooting for TwoPointBVProblem
ErikQQY Nov 29, 2024
c201ac8
Bump DI for shooting methods
ErikQQY Nov 30, 2024
1880155
Fix some CI complainings
ErikQQY Dec 5, 2024
b8b02ba
Merge branch 'master' of https://github.com/SciML/BoundaryValueDiffEq…
ErikQQY Dec 23, 2024
45cb9c6
Merge branch 'master' into qqy/di
ErikQQY Dec 23, 2024
0297d87
Using sparsity_pattern
ErikQQY Dec 24, 2024
162d14b
Fix some conflicts from merging
ErikQQY Dec 24, 2024
2f06695
Fix some incorrect utils in MIRKN
ErikQQY Dec 24, 2024
65a68e8
Fix some incorrect using in extension
ErikQQY Dec 25, 2024
e7670cc
Unify GreedyColoringAlgorithm usage
ErikQQY Jan 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 1 addition & 33 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,19 @@ version = "5.12.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BoundaryValueDiffEqAscher = "7227322d-7511-4e07-9247-ad6ff830280e"
BoundaryValueDiffEqCore = "56b672f2-a5fe-4263-ab2d-da677488eb3a"
BoundaryValueDiffEqFIRK = "85d9eb09-370e-4000-bb32-543851f73618"
BoundaryValueDiffEqMIRK = "1a22d4ce-7765-49ea-b6f2-13c8438986a6"
BoundaryValueDiffEqMIRKN = "9255f1d6-53bf-473e-b6bd-23f1ff009da4"
BoundaryValueDiffEqShooting = "ed55bfe0-3725-4db6-871e-a1dc9f42a757"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
NonlinearSolveFirstOrder = "5959db7a-ea39-4486-b5fe-2dd0bf03d60d"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[weakdeps]
ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
Expand All @@ -41,45 +25,29 @@ ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
BoundaryValueDiffEqODEInterfaceExt = "ODEInterface"

[compat]
ADTypes = "1.9"
Adapt = "4.1.1"
Aqua = "0.8.7"
ArrayInterface = "7.17"
BandedMatrices = "1.7.5"
BoundaryValueDiffEqAscher = "1"
BoundaryValueDiffEqCore = "1.1"
BoundaryValueDiffEqFIRK = "1.1"
BoundaryValueDiffEqMIRK = "1.1"
BoundaryValueDiffEqMIRKN = "1"
BoundaryValueDiffEqShooting = "1.1"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.158.3"
DiffEqDevTools = "2.44"
FastAlmostBandedMatrices = "0.1.4"
FastClosures = "0.3.2"
ForwardDiff = "0.10.38"
Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
JET = "0.9.12"
LineSearch = "0.1.4"
LinearAlgebra = "1.10"
LinearSolve = "2.36.2"
Logging = "1.10"
NonlinearSolveFirstOrder = "1"
ODEInterface = "0.5"
OrdinaryDiffEq = "6.90.1"
Pkg = "1.10.0"
PreallocationTools = "0.4.24"
PrecompileTools = "1.2"
Preferences = "1.4"
Random = "1.10"
ReTestItems = "1.23.1"
RecursiveArrayTools = "3.27.0"
Reexport = "1.2"
SciMLBase = "2.60.0"
Setfield = "1.1.1"
SparseArrays = "1.10"
SparseDiffTools = "2.23"
SciMLBase = "2.64.0"
StaticArrays = "1.9.8"
Test = "1.10"
julia = "1.10"
Expand Down
33 changes: 15 additions & 18 deletions ext/BoundaryValueDiffEqODEInterfaceExt.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
module BoundaryValueDiffEqODEInterfaceExt

using BoundaryValueDiffEqCore, SciMLBase, ODEInterface, RecursiveArrayTools,
ConcreteStructs, Setfield, PreallocationTools

using BoundaryValueDiffEq: BVPM2, BVPSOL, COLNEW
import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, __extract_u0,
__initial_guess_length, __extract_mesh,
__flatten_initial_guess, __get_bcresid_prototype,
__has_initial_guess, __initial_guess
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val
import ODEInterface: OptionsODE, OPT_ATOL, OPT_RTOL, OPT_METHODCHOICE, OPT_DIAGNOSTICOUTPUT,
OPT_ERRORCONTROL, OPT_SINGULARTERM, OPT_MAXSTEPS, OPT_BVPCLASS,
OPT_SOLMETHOD, OPT_RHS_CALLMODE, OPT_COLLOCATIONPTS, OPT_ADDGRIDPOINTS,
OPT_MAXSUBINTERVALS, RHS_CALL_INSITU, evalSolution
import ODEInterface: Bvpm2, bvpm2_init, bvpm2_solve, bvpm2_destroy, bvpm2_get_x
import ODEInterface: bvpsol
import ODEInterface: colnew

import FastClosures: @closure
import ForwardDiff
using BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, __extract_u0,
__initial_guess_length, __extract_mesh,
__flatten_initial_guess, __get_bcresid_prototype,
__has_initial_guess, __initial_guess
using SciMLBase: SciMLBase, BVProblem
using ODEInterface: OptionsODE, OPT_ATOL, OPT_RTOL, OPT_METHODCHOICE, OPT_DIAGNOSTICOUTPUT,
OPT_ERRORCONTROL, OPT_SINGULARTERM, OPT_MAXSTEPS, OPT_BVPCLASS,
OPT_SOLMETHOD, OPT_RHS_CALLMODE, OPT_COLLOCATIONPTS, OPT_ADDGRIDPOINTS,
OPT_MAXSUBINTERVALS, RHS_CALL_INSITU, evalSolution
using ODEInterface: Bvpm2, bvpm2_init, bvpm2_solve, bvpm2_destroy, bvpm2_get_x
using ODEInterface: bvpsol
using ODEInterface: colnew

using FastClosures: @closure
using ForwardDiff: ForwardDiff

#------
# BVPM2
Expand Down
12 changes: 7 additions & 5 deletions lib/BoundaryValueDiffEqAscher/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BoundaryValueDiffEqAscher"
uuid = "7227322d-7511-4e07-9247-ad6ff830280e"
authors = ["Qingyu Qu <[email protected]>"]
version = "1.1.0"
version = "1.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -12,11 +12,11 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BoundaryValueDiffEqCore = "56b672f2-a5fe-4263-ab2d-da677488eb3a"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Expand All @@ -25,7 +25,8 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"

[compat]
ADTypes = "1.9"
Expand All @@ -37,14 +38,14 @@ BoundaryValueDiffEqCore = "1.1.0"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.158.3"
DiffEqDevTools = "2.44"
DifferentiationInterface = "0.6.24"
FastClosures = "0.3.2"
ForwardDiff = "0.10.38"
Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
JET = "0.9.12"
LinearAlgebra = "1.10"
LinearSolve = "2.36.2"
Logging = "1.10"
PreallocationTools = "0.4.24"
PrecompileTools = "1.2"
Preferences = "1.4"
Expand All @@ -55,7 +56,8 @@ Reexport = "1.2"
SciMLBase = "2.59.1"
Setfield = "1.1.1"
SparseArrays = "1.10"
SparseDiffTools = "2.23"
SparseConnectivityTracer = "0.6.9"
SparseMatrixColorings = "0.4.10"
StaticArrays = "1.9.8"
Test = "1.10"
julia = "1.10"
Expand Down
43 changes: 23 additions & 20 deletions lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
module BoundaryValueDiffEqAscher

using ADTypes
using AlmostBlockDiagonals
using BoundaryValueDiffEqCore
using ConcreteStructs
using FastClosures
using ForwardDiff
using ADTypes: ADTypes, AutoSparse, AutoForwardDiff
using AlmostBlockDiagonals: AlmostBlockDiagonals, IntermediateAlmostBlockDiagonal
using BoundaryValueDiffEqCore: BVPJacobianAlgorithm, __extract_problem_details,
concrete_jacobian_algorithm, __Fix3,
__concrete_nonlinearsolve_algorithm,
__internal_nlsolve_problem, BoundaryValueDiffEqAlgorithm,
__vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!,
__extract_mesh, get_dense_ad
using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase
using DifferentiationInterface: DifferentiationInterface, Constant
using FastClosures: @closure
using ForwardDiff: ForwardDiff
using LinearAlgebra
using PreallocationTools
using RecursiveArrayTools
using Reexport
using SciMLBase
using Setfield
using PreallocationTools: PreallocationTools, DiffCache
using RecursiveArrayTools: VectorOfArray, recursivecopy
using Reexport: @reexport
using SciMLBase: SciMLBase, AbstractDiffEqInterpolation, StandardBVProblem, __solve,
_unwrap_val
using Setfield: @set!
using SparseConnectivityTracer: SparseConnectivityTracer
using SparseMatrixColorings: SparseMatrixColorings, GreedyColoringAlgorithm, LargestFirst

import BoundaryValueDiffEqCore: BVPJacobianAlgorithm, __extract_problem_details,
concrete_jacobian_algorithm, __Fix3,
__concrete_nonlinearsolve_algorithm,
BoundaryValueDiffEqAlgorithm, __sparse_jacobian_cache,
__vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!,
__extract_mesh
const DI = DifferentiationInterface

import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val

@reexport using ADTypes, DiffEqBase, BoundaryValueDiffEqCore, SparseDiffTools, SciMLBase
@reexport using ADTypes, BoundaryValueDiffEqCore, SciMLBase

include("types.jl")
include("utils.jl")
Expand Down
49 changes: 33 additions & 16 deletions lib/BoundaryValueDiffEqAscher/src/ascher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,31 +315,48 @@ function __construct_nlproblem(cache::AscherCache{iip, T}) where {iip, T}
else
@closure (z, p) -> @views Φ(cache, z, pt)
end

lz = reduce(vcat, cache.z)
sd = alg.jac_alg.diffmode isa AutoSparse ? SymbolicsSparsityDetection() :
NoSparsityDetection()
ad = alg.jac_alg.diffmode
lossₚ = (iip ? __Fix3 : Base.Fix2)(loss, cache.p)
jac_cache = __sparse_jacobian_cache(Val(iip), ad, sd, lossₚ, lz, lz)
jac_prototype = init_jacobian(jac_cache)
resid_prototype = zero(lz)
diffmode = if alg.jac_alg.diffmode isa AutoSparse
AutoSparse(get_dense_ad(alg.jac_alg.diffmode);
sparsity_detector = SparseConnectivityTracer.TracerSparsityDetector(),
coloring_algorithm = GreedyColoringAlgorithm(LargestFirst()))
else
alg.jac_alg.diffmode
end

jac_cache = if iip
DI.prepare_jacobian(loss, resid_prototype, diffmode, lz, Constant(cache.p))
else
DI.prepare_jacobian(loss, diffmode, lz, Constant(cache.p))
end

jac_prototype = if iip
DI.jacobian(loss, resid_prototype, jac_cache, diffmode, lz, Constant(cache.p))
else
DI.jacobian(loss, jac_cache, diffmode, lz, Constant(cache.p))
end

jac = if iip
@closure (J, u, p) -> __ascher_mpoint_jacobian!(J, u, ad, jac_cache, lossₚ, lz)
@closure (J, u, p) -> __ascher_mpoint_jacobian!(
J, u, diffmode, jac_cache, loss, lz, cache.p)
else
@closure (u, p) -> __ascher_mpoint_jacobian(jac_prototype, u, ad, jac_cache, lossₚ)
@closure (u, p) -> __ascher_mpoint_jacobian(
jac_prototype, u, diffmode, jac_cache, loss, cache.p)
end
resid_prototype = zero(lz)
_nlf = NonlinearFunction{iip}(

nlf = NonlinearFunction{iip}(
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)
nlprob::NonlinearProblem = NonlinearProblem(_nlf, lz, cache.p)
return nlprob
return __internal_nlsolve_problem(cache.prob, similar(lz), lz, nlf, lz, cache.p)
end

function __ascher_mpoint_jacobian!(J, x, diffmode, diffcache, loss, resid)
sparse_jacobian!(J, diffmode, diffcache, loss, resid, x)
function __ascher_mpoint_jacobian!(J, x, diffmode, diffcache, loss, resid, p)
DI.jacobian!(loss, resid, J, diffcache, diffmode, x, Constant(p))
return nothing
end
function __ascher_mpoint_jacobian(J, x, diffmode, diffcache, loss)
sparse_jacobian!(J, diffmode, diffcache, loss, x)
function __ascher_mpoint_jacobian(J, x, diffmode, diffcache, loss, p)
DI.jacobian!(loss, J, diffcache, diffmode, x, Constant(p))
return J
end

Expand Down
3 changes: 2 additions & 1 deletion lib/BoundaryValueDiffEqAscher/src/collocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ function Φ(cache::AscherCache{iip, T}, z, pt::StandardBVProblem) where {iip, T}

@views gblock!(cache, h, g[i], izeta, w[i], v[i])

if i >= n
if i == n
izsave = izeta
# build equation for a side condition.
# other nonlinear case
Expand Down Expand Up @@ -848,6 +848,7 @@ function dmzsol!(cache::AscherCache, v, z, dmz)
for i in 1:n
for j in 1:ncomp
fact = __get_value(z[i][j])
println("fact: ", fact)
for l in 1:kdy
kk, jj = __locate_stage(l, ncy)
dmz[i][kk][jj] = dmz[i][kk][jj] + fact * v[i][l, j]
Expand Down
2 changes: 0 additions & 2 deletions lib/BoundaryValueDiffEqCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[compat]
ADTypes = "1.9"
Expand All @@ -43,7 +42,6 @@ Reexport = "1.2"
SciMLBase = "2.59.1"
Setfield = "1"
SparseArrays = "1.10"
SparseDiffTools = "2.23"
Test = "1.10"
julia = "1.10"

Expand Down
41 changes: 20 additions & 21 deletions lib/BoundaryValueDiffEqCore/src/BoundaryValueDiffEqCore.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,34 @@
module BoundaryValueDiffEqCore

using ADTypes, Adapt, ArrayInterface, ForwardDiff, LinearAlgebra, LineSearch,
NonlinearSolveFirstOrder, RecursiveArrayTools, Reexport, SciMLBase, Setfield,
SparseDiffTools

using PreallocationTools: PreallocationTools, DiffCache

# Special Matrix Types
using SparseArrays

import ADTypes: AbstractADType
import ArrayInterface: matrix_colors, parameterless_type, fast_scalar_indexing
import ConcreteStructs: @concrete
import DiffEqBase: solve
import ForwardDiff: ForwardDiff, pickchunksize
import Logging
using Adapt: adapt
using ADTypes: ADTypes, AbstractADType, AutoSparse, AutoForwardDiff, AutoFiniteDiff,
NoSparsityDetector, KnownJacobianSparsityDetector
using ArrayInterface: matrix_colors, parameterless_type, fast_scalar_indexing
using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase, solve
using ForwardDiff: ForwardDiff, pickchunksize
using Logging
using NonlinearSolveFirstOrder: NonlinearSolvePolyAlgorithm
import LineSearch: BackTracking
import RecursiveArrayTools: VectorOfArray, DiffEqArray
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val
using LinearAlgebra
using LineSearch: BackTracking
using PreallocationTools: PreallocationTools, DiffCache
using RecursiveArrayTools: AbstractVectorOfArray, VectorOfArray, DiffEqArray
using Reexport: @reexport
using SciMLBase: SciMLBase, AbstractBVProblem, AbstractDiffEqInterpolation,
StandardBVProblem, StandardSecondOrderBVProblem, __solve, _unwrap_val
using Setfield: @set!, @set
using SparseArrays: sparse

@reexport using ADTypes, NonlinearSolveFirstOrder, SparseDiffTools, SciMLBase
@reexport using NonlinearSolveFirstOrder, SciMLBase

include("types.jl")
include("utils.jl")
include("algorithms.jl")
include("alg_utils.jl")
include("default_nlsolve.jl")
include("sparse_jacobians.jl")

function __solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...; kwargs...)
function SciMLBase.__solve(
prob::AbstractBVProblem, alg::BoundaryValueDiffEqAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
return solve!(cache)
end
Expand Down
Loading
Loading