Skip to content

Commit

Permalink
Simplify NVector wrapper
Browse files Browse the repository at this point in the history
NVector is now a mutable struct with finalizer.

(cf previous implementation with an immutable struct containing a
Ref{NVector} to which a finalizer was attached)
  • Loading branch information
sjdaines committed Jan 13, 2023
1 parent c027e5b commit 6a465ec
Showing 1 changed file with 16 additions and 20 deletions.
36 changes: 16 additions & 20 deletions src/nvector_wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,26 @@
NB: can be supplied to ccall when an N_Vector is required
(ie implements cconvert / unsafe_convert).
"""
struct NVector <: DenseVector{realtype}
ref_nv::Ref{N_Vector} # reference to N_Vector
mutable struct NVector <: DenseVector{realtype}
n_v::N_Vector # reference (C pointer) to N_Vector
v::Vector{realtype} # array that is referenced by N_Vector

function NVector(v::Vector{realtype})
# note that N_VMake_Serial() creates N_Vector doesn't own the data,
# so calling N_VDestroy_Serial() would not deallocate v
nv = new(Ref{N_Vector}(N_VMake_Serial(length(v), v)), v)
finalizer(release_handle, nv.ref_nv)
nv = new(N_VMake_Serial(length(v), v), v)
finalizer(release_handle, nv)
return nv
end

function NVector(nv::N_Vector)
function NVector(n_v::N_Vector)
# wrap N_Vector into NVector and get non-owning access to `nv` data
# via `v`, but don't register finalizer for `nv`
return new(Ref{N_Vector}(nv), asarray(nv))
return new(n_v, asarray(n_v))
end
end

release_handle(ref_nv::Ref{N_Vector}) = N_VDestroy_Serial(ref_nv[])
release_handle(nv::NVector) = N_VDestroy_Serial(nv.n_v)

Base.size(nv::NVector, d...) = size(nv.v, d...)
Base.stride(nv::NVector, d::Integer) = stride(nv.v, d)
Expand All @@ -41,7 +41,7 @@ Base.setindex!(nv::NVector, X, i::Real) = setindex!(nv.v, X, i)
Base.setindex!(nv::NVector, X, i::AbstractArray) = setindex!(nv.v, X, i)
Base.setindex!(nv::NVector, X, inds...) = setindex!(nv.v, X, inds...)

Base.pointer(nv::NVector) = Sundials.N_VGetArrayPointer_Serial(nv.ref_nv[])
Base.pointer(nv::NVector) = Sundials.N_VGetArrayPointer_Serial(nv.n_v)

##################################################################
#
Expand All @@ -65,23 +65,19 @@ Base.convert(::Type{Vector}, nv::NVector) = nv.v


"""
Base.cconvert(::Type{N_Vector}, v::Vector{realtype}) -> nv::NVector
Base.unsafe_convert(::Type{N_Vector}, nv::NVector)
Base.cconvert(::Type{N_Vector}, v::Vector{realtype}) -> NVector
Base.unsafe_convert(::Type{N_Vector}, nv::NVector) -> N_Vector
Convert NVector to N_Vector, for use by ccall
Convert v to N_Vector, for use by ccall
(NB: actually implemented to convert any Julia Vector, although only NVector is needed?)
This replaces incorrect use of Base.convert, which fails with Julia >= 1.8 if used with a temporary NVector
see https://discourse.julialang.org/t/how-to-keep-a-reference-for-c-structure-to-avoid-gc/9310/21
(NB: actually implemented to convert any Julia Vector v, although only v::NVector is needed?)
Conversion happens in two steps within ccall:
- cconvert to convert to temporary NVector, which is preserved from garbage collection
- unsafe_convert to get the pointer from the temporary NVector
- cconvert to convert to temporary NVector, which is preserved (by ccall) from garbage collection
- unsafe_convert to get the N_Vector pointer from the temporary NVector
"""
Base.cconvert(::Type{N_Vector}, v::Vector{realtype}) = convert(NVector, v)
Base.unsafe_convert(::Type{N_Vector}, nv::NVector) = nv.ref_nv[]
Base.cconvert(::Type{N_Vector}, v::Vector{realtype}) = convert(NVector, v) # will just return v if v is an NVector
Base.unsafe_convert(::Type{N_Vector}, nv::NVector) = nv.n_v


Base.similar(nv::NVector) = NVector(similar(nv.v))
Expand Down

0 comments on commit 6a465ec

Please sign in to comment.