diff --git a/src/FINUFFT.jl b/src/FINUFFT.jl index dcce07b..f0ebfa2 100644 --- a/src/FINUFFT.jl +++ b/src/FINUFFT.jl @@ -51,6 +51,11 @@ function __init__() include("cufinufft_simple.jl") determine_cuda_status() end + # use the same lock as FFTW to ensure a thread-safe planning stage + @require FFTW="7a1cc6ca-52ef-59f5-83cd-3a7055c09341" begin + using .FFTW + FINUFFT.finufftlock = FFTW.fftwlock + end end end # module diff --git a/src/errors.jl b/src/errors.jl index 9222f55..981e7fd 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -22,6 +22,7 @@ const ERR_METHOD_NOTVALID = 17 const ERR_BINSIZE_NOTVALID = 18 const ERR_INSUFFICIENT_SHMEM = 19 const ERR_NUM_NU_PTS_INVALID = 20 +const ERR_LOCK_FUNS_INVALID = 21 struct FINUFFTError <: Exception errno::Cint @@ -79,6 +80,8 @@ function check_ret(ret) msg = "GPU shmem too small for subprob/blockgather parameters" elseif ret==ERR_NUM_NU_PTS_INVALID msg = "invalid number of nonuniform points: nj or nk negative, or too big (see defs.h)" + elseif ret==ERR_LOCK_FUNS_INVALID + msg = "fftw_(un)lock functions should be both null or both set" else msg = "error of type unknown to Julia interface! Check FINUFFT documentation" end diff --git a/src/guru.jl b/src/guru.jl index a0bbaa0..aeb5caf 100644 --- a/src/guru.jl +++ b/src/guru.jl @@ -1,5 +1,7 @@ ### Guru Interfaces +finufftlock = ReentrantLock() + """ p = finufft_default_opts() p = finufft_default_opts(dtype=Float32) @@ -25,6 +27,13 @@ function finufft_default_opts(dtype::DataType=Float64) ) end + lock_c = @cfunction(x -> lock(unsafe_pointer_to_objref(x)), Cvoid, (Ptr{Cvoid},)) + unlock_c = @cfunction(x -> unlock(unsafe_pointer_to_objref(x)), Cvoid, (Ptr{Cvoid},)) + + opts.fftw_lock_fun = lock_c + opts.fftw_unlock_fun = unlock_c + opts.fftw_lock_data = pointer_from_objref(finufftlock) + return opts end diff --git a/src/types.jl b/src/types.jl index 46506fc..3cd06ec 100644 --- a/src/types.jl +++ b/src/types.jl @@ -96,6 +96,9 @@ mutable struct nufft_opts{T} maxbatchsize :: Cint spread_nthr_atomic :: Cint spread_max_sp_size :: Cint + fftw_lock_fun :: Ptr{Cvoid} + fftw_unlock_fun :: Ptr{Cvoid} + fftw_lock_data :: Ptr{Cvoid} nufft_opts{T}() where T <: finufftReal = new{T}() end # The above must match include/nufft_opts.h in FINUFFT.