diff --git a/src/nvector_wrapper.jl b/src/nvector_wrapper.jl index 727e25a..3f46413 100644 --- a/src/nvector_wrapper.jl +++ b/src/nvector_wrapper.jl @@ -9,26 +9,26 @@ NB: can be supplied to ccall when an N_Vector is required (ie implements cconvert / unsafe_convert). """ -struct NVector <: DenseVector{realtype} - ref_nv::Ref{N_Vector} # reference to N_Vector +mutable struct NVector <: DenseVector{realtype} + n_v::N_Vector # reference (C pointer) to N_Vector v::Vector{realtype} # array that is referenced by N_Vector function NVector(v::Vector{realtype}) # note that N_VMake_Serial() creates N_Vector doesn't own the data, # so calling N_VDestroy_Serial() would not deallocate v - nv = new(Ref{N_Vector}(N_VMake_Serial(length(v), v)), v) - finalizer(release_handle, nv.ref_nv) + nv = new(N_VMake_Serial(length(v), v), v) + finalizer(release_handle, nv) return nv end - function NVector(nv::N_Vector) + function NVector(n_v::N_Vector) # wrap N_Vector into NVector and get non-owning access to `nv` data # via `v`, but don't register finalizer for `nv` - return new(Ref{N_Vector}(nv), asarray(nv)) + return new(n_v, asarray(n_v)) end end -release_handle(ref_nv::Ref{N_Vector}) = N_VDestroy_Serial(ref_nv[]) +release_handle(nv::NVector) = N_VDestroy_Serial(nv.n_v) Base.size(nv::NVector, d...) = size(nv.v, d...) Base.stride(nv::NVector, d::Integer) = stride(nv.v, d) @@ -41,7 +41,7 @@ Base.setindex!(nv::NVector, X, i::Real) = setindex!(nv.v, X, i) Base.setindex!(nv::NVector, X, i::AbstractArray) = setindex!(nv.v, X, i) Base.setindex!(nv::NVector, X, inds...) = setindex!(nv.v, X, inds...) -Base.pointer(nv::NVector) = Sundials.N_VGetArrayPointer_Serial(nv.ref_nv[]) +Base.pointer(nv::NVector) = Sundials.N_VGetArrayPointer_Serial(nv.n_v) ################################################################## # @@ -65,23 +65,19 @@ Base.convert(::Type{Vector}, nv::NVector) = nv.v """ - Base.cconvert(::Type{N_Vector}, v::Vector{realtype}) -> nv::NVector - Base.unsafe_convert(::Type{N_Vector}, nv::NVector) + Base.cconvert(::Type{N_Vector}, v::Vector{realtype}) -> NVector + Base.unsafe_convert(::Type{N_Vector}, nv::NVector) -> N_Vector -Convert NVector to N_Vector, for use by ccall +Convert v to N_Vector, for use by ccall -(NB: actually implemented to convert any Julia Vector, although only NVector is needed?) - -This replaces incorrect use of Base.convert, which fails with Julia >= 1.8 if used with a temporary NVector - -see https://discourse.julialang.org/t/how-to-keep-a-reference-for-c-structure-to-avoid-gc/9310/21 +(NB: actually implemented to convert any Julia Vector v, although only v::NVector is needed?) Conversion happens in two steps within ccall: - - cconvert to convert to temporary NVector, which is preserved from garbage collection - - unsafe_convert to get the pointer from the temporary NVector + - cconvert to convert to temporary NVector, which is preserved (by ccall) from garbage collection + - unsafe_convert to get the N_Vector pointer from the temporary NVector """ -Base.cconvert(::Type{N_Vector}, v::Vector{realtype}) = convert(NVector, v) -Base.unsafe_convert(::Type{N_Vector}, nv::NVector) = nv.ref_nv[] +Base.cconvert(::Type{N_Vector}, v::Vector{realtype}) = convert(NVector, v) # will just return v if v is an NVector +Base.unsafe_convert(::Type{N_Vector}, nv::NVector) = nv.n_v Base.similar(nv::NVector) = NVector(similar(nv.v))