Skip to content

Commit

Permalink
Test speculative fix/workaround for kinsol segfault
Browse files Browse the repository at this point in the history
See #25

This rearranges the kinsol wrapper to a simpler form
  • Loading branch information
sjdaines committed Aug 29, 2022
1 parent d376214 commit 26d5311
Showing 1 changed file with 28 additions and 25 deletions.
53 changes: 28 additions & 25 deletions src/Kinsol.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,31 +120,31 @@ function kin_create(
psolvefun = nothing,
jvfun = nothing,
)
# use the user_data field to pass a function
# see: https://github.com/JuliaLang/julia/issues/2554
userfun = UserFunctionAndData(f, psetupfun, psolvefun, jvfun, userdata)

return _kin_create(userfun, y0; linear_solver=linear_solver, jac_upper=jac_upper, jac_lower=jac_lower, krylov_dim=krylov_dim)
end

function _kin_create(
userfun::T, y0::Vector{Float64};
linear_solver,
jac_upper,
jac_lower,
krylov_dim,
) where {T}

mem_ptr = Sundials.KINCreate()
(mem_ptr == C_NULL) && error("Failed to allocate KINSOL solver object")
kmem = Sundials.Handle(mem_ptr)

handles = []

push!(handles, userfun) # TODO prevent userfun from being garbage collected ?

# use the user_data field to pass a function
# see: https://github.com/JuliaLang/julia/issues/2554
userfun = UserFunctionAndData(f, psetupfun, psolvefun, jvfun, userdata)
# push!(handles, userfun) # TODO prevent userfun from being garbage collected ?
function getkinsolfun(userfun::T) where {T}
@cfunction(kinsolfun, Cint, (Sundials.N_Vector, Sundials.N_Vector, Ref{T}))
end
function getpsetupfun(userfun::T) where {T}
@cfunction(kinprecsetup, Cint, (Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ref{T}))
end
function getpsolvefun(userfun::T) where {T}
@cfunction(kinprecsolve, Cint, (Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ref{T}))
end
function getkinjactimesvec(userfun::T) where {T}
@cfunction(kinjactimesvec, Cint, (Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ptr{Cint}, Ref{T}))
end

flag = Sundials.@checkflag Sundials.KINInit(kmem, getkinsolfun(userfun), Sundials.NVector(y0)) true
c_kinsolfun = @cfunction(kinsolfun, Cint, (Sundials.N_Vector, Sundials.N_Vector, Ref{T}))
flag = Sundials.@checkflag Sundials.KINInit(kmem, c_kinsolfun, Sundials.NVector(y0)) true

if linear_solver == :Dense
A = Sundials.SUNDenseMatrix(length(y0), length(y0))
Expand All @@ -158,7 +158,7 @@ function kin_create(
push!(handles, Sundials.LinSolHandle(LS, Sundials.Band()))
elseif linear_solver == :FGMRES
A = nothing
prec_side = isnothing(psolvefun) ? 0 : 2 # right preconditioning only
prec_side = isnothing(userfun.psolve) ? 0 : 2 # right preconditioning only
LS = Sundials.SUNLinSol_SPFGMR(y0, prec_side, krylov_dim)
push!(handles, Sundials.LinSolHandle(LS, Sundials.SPFGMR()))
end
Expand All @@ -168,17 +168,20 @@ function kin_create(
flag = Sundials.@checkflag Sundials.KINSetLinearSolver(kmem, LS, A === nothing ? C_NULL : A) true
# flag = Sundials.@checkflag Sundials.KINDlsSetLinearSolver(kmem, LS, A === nothing ? C_NULL : A) true

if !isnothing(psolvefun)
if !isnothing(userfun.psolve)
c_kinprecsetup = isnothing(userfun.psetup) ? C_NULL : @cfunction(kinprecsetup, Cint, (Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ref{T}))
c_kinprecsolve = @cfunction(kinprecsolve, Cint, (Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ref{T}))

flag = Sundials.@checkflag Sundials.KINSetPreconditioner(kmem,
psetupfun === nothing ? C_NULL : getpsetupfun(userfun),
getpsolvefun(userfun)) true
c_kinprecsetup,
c_kinprecsolve) true
end

if !isnothing(jvfun)
flag = Sundials.@checkflag Sundials.KINSetJacTimesVecFn(kmem, getkinjactimesvec(userfun)) true
if !isnothing(userfun.jv)
c_kinjactimesvec = @cfunction(kinjactimesvec, Cint, (Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ptr{Cint}, Ref{T}))
flag = Sundials.@checkflag Sundials.KINSetJacTimesVecFn(kmem, c_kinjactimesvec) true
end


return (;kmem, handles)
end

Expand Down

0 comments on commit 26d5311

Please sign in to comment.