Skip to content

Commit

Permalink
fix FFTW crash after set_num_threads
Browse files Browse the repository at this point in the history
  • Loading branch information
stevengj committed Mar 21, 2017
1 parent 90cfb82 commit 3bdfeef
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
28 changes: 13 additions & 15 deletions base/fft/FFTW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,23 +128,21 @@ end

# Threads

let initialized = false
global set_num_threads
function set_num_threads(nthreads::Integer)
if !initialized
# must re-initialize FFTW if any FFTW routines have been called
ccall((:fftw_cleanup,libfftw), Void, ())
ccall((:fftwf_cleanup,libfftwf), Void, ())
stat = ccall((:fftw_init_threads,libfftw), Int32, ())
statf = ccall((:fftwf_init_threads,libfftwf), Int32, ())
if stat == 0 || statf == 0
error("could not initialize FFTW threads")
end
initialized = true
const threads_initialized = Ref(false)
function set_num_threads(nthreads::Integer)
if !threads_initialized[]
# must forget wisdom if any FFTW routines have been called
# (don't call fftw_cleanup, since that would invalidate existing plans)
forget_wisdom()
stat = ccall((:fftw_init_threads,libfftw), Int32, ())
statf = ccall((:fftwf_init_threads,libfftwf), Int32, ())
if stat == 0 || statf == 0
error("could not initialize FFTW threads")
end
ccall((:fftw_plan_with_nthreads,libfftw), Void, (Int32,), nthreads)
ccall((:fftwf_plan_with_nthreads,libfftwf), Void, (Int32,), nthreads)
threads_initialized[] = true
end
ccall((:fftw_plan_with_nthreads,libfftw), Void, (Int32,), nthreads)
ccall((:fftwf_plan_with_nthreads,libfftwf), Void, (Int32,), nthreads)
end

# pointer type for fftw_plan (opaque pointer)
Expand Down
8 changes: 8 additions & 0 deletions test/fft.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# This file is a part of Julia. License is MIT: http://julialang.org/license

# issue #19892
# (test this first to make sure it happens before set_num_threads)
let a = randn(10^5,1), p1 = plan_rfft(a)
FFTW.set_num_threads(2)
p2 = plan_rfft(a)
@test p1*a p2*a
end

# fft
a = rand(8) + im*rand(8)
@test norm(ifft(fft(a)) - a) < 1e-8
Expand Down

0 comments on commit 3bdfeef

Please sign in to comment.