From 12baf760f69a697998b9a000e767f8b638f189bf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 25 Oct 2024 14:48:13 -0400 Subject: [PATCH] feat: support automatic sparsity detection for PETSc --- ext/NonlinearSolvePETScExt.jl | 55 ++++++++++++++++++++++++++--------- src/internal/helpers.jl | 7 +++-- 2 files changed, 47 insertions(+), 15 deletions(-) diff --git a/ext/NonlinearSolvePETScExt.jl b/ext/NonlinearSolvePETScExt.jl index 94ae54f43..347ba8d1b 100644 --- a/ext/NonlinearSolvePETScExt.jl +++ b/ext/NonlinearSolvePETScExt.jl @@ -6,16 +6,19 @@ using NonlinearSolveBase: NonlinearSolveBase, get_tolerance using NonlinearSolve: NonlinearSolve, PETScSNES using PETSc: PETSc using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode +using SparseArrays: AbstractSparseMatrix function SciMLBase.__solve( prob::NonlinearProblem, alg::PETScSNES, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing, show_trace::Val{ShT} = Val(false), kwargs...) where {ShT} + # XXX: https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/ termination_condition === nothing || error("`PETScSNES` does not support termination conditions!") _f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0) T = eltype(prob.u0) + @assert T āˆˆ PETSc.scalar_types if alg.petsclib === missing petsclibidx = findfirst(PETSc.petsclibs) do petsclib @@ -35,7 +38,10 @@ function SciMLBase.__solve( abstol = get_tolerance(abstol, T) reltol = get_tolerance(reltol, T) + nf = Ref{Int}(0) + f! = @closure (cfx, cx, user_ctx) -> begin + nf[] += 1 fx = cfx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cfx; read = false) : cfx x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx _f!(fx, x) @@ -49,25 +55,47 @@ function SciMLBase.__solve( alg.snes_options..., snes_monitor = ShT, snes_rtol = reltol, snes_atol = abstol, snes_max_it = maxiters) + PETSc.setfunction!(snes, f!, PETSc.VecSeq(zero(u0))) + if alg.autodiff === missing && prob.f.jac === nothing _jac! = nothing + njac = Ref{Int}(-1) else autodiff = alg.autodiff === missing ? nothing : alg.autodiff - _jac! = NonlinearSolve.__construct_extension_jac(prob, alg, u0, resid; autodiff) - end + _jac!, J_init = NonlinearSolve.__construct_extension_jac( + prob, alg, u0, resid; autodiff, initial_jacobian = Val(true)) - PETSc.setfunction!(snes, f!, PETSc.VecSeq(zero(u0))) + njac = Ref{Int}(0) - if _jac! !== nothing # XXX: Sparsity Handling??? - PJ = PETSc.MatSeqDense(zeros(T, length(resid), length(u0))) - jac! = @closure (cx, J, _, user_ctx) -> begin - x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx - _jac!(J, x) - Base.finalize(x) - PETSc.assemble(J) - return + if J_init isa AbstractSparseMatrix + PJ = PETSc.MatSeqAIJ(J_init) + jac! = @closure (cx, J, _, user_ctx) -> begin + njac[] += 1 + x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx + if J isa PETSc.AbstractMat + _jac!(user_ctx.jacobian, x) + copyto!(J, user_ctx.jacobian) + PETSc.assemble(J) + else + _jac!(J, x) + end + Base.finalize(x) + return + end + PETSc.setjacobian!(snes, jac!, PJ, PJ) + snes.user_ctx = (; jacobian = J_init) + else + PJ = PETSc.MatSeqDense(J_init) + jac! = @closure (cx, J, _, user_ctx) -> begin + njac[] += 1 + x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx + _jac!(J, x) + Base.finalize(x) + J isa PETSc.AbstractMat && PETSc.assemble(J) + return + end + PETSc.setjacobian!(snes, jac!, PJ, PJ) end - PETSc.setjacobian!(snes, jac!, PJ, PJ) end res = PETSc.solve!(u0, snes) @@ -79,7 +107,8 @@ function SciMLBase.__solve( objective = maximum(abs, resid) # XXX: Return Code from PETSc retcode = ifelse(objective ā‰¤ abstol, ReturnCode.Success, ReturnCode.Failure) - return SciMLBase.build_solution(prob, alg, u_, resid_; retcode, original = snes) + return SciMLBase.build_solution(prob, alg, u_, resid_; retcode, original = snes, + stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1)) end end diff --git a/src/internal/helpers.jl b/src/internal/helpers.jl index 30e4596bd..6a154e0cc 100644 --- a/src/internal/helpers.jl +++ b/src/internal/helpers.jl @@ -109,7 +109,8 @@ function __construct_extension_f(prob::AbstractNonlinearProblem; alias_u0::Bool end function __construct_extension_jac(prob, alg, u0, fu; can_handle_oop::Val = False, - can_handle_scalar::Val = False, autodiff = nothing, kwargs...) + can_handle_scalar::Val = False, autodiff = nothing, initial_jacobian = False, + kwargs...) autodiff = select_jacobian_autodiff(prob, autodiff) Jā‚š = JacobianCache( @@ -120,7 +121,9 @@ function __construct_extension_jac(prob, alg, u0, fu; can_handle_oop::Val = Fals š‰ = (can_handle_oop === False && !isinplace(prob)) ? @closure((J, u)->copyto!(J, š“™(u))) : š“™ - return š‰ + initial_jacobian === False && return š‰ + + return š‰, Jā‚š(nothing) end function reinit_cache! end