From 004e0db60e9e394de4cd0833f6b9742aaf73bf61 Mon Sep 17 00:00:00 2001 From: Pierre Paleo Date: Tue, 11 Jan 2022 09:17:53 +0100 Subject: [PATCH 1/5] Remove unsed line --- src/silx/math/fft/npfft.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/silx/math/fft/npfft.py b/src/silx/math/fft/npfft.py index 20351de88b..2379a56c0d 100644 --- a/src/silx/math/fft/npfft.py +++ b/src/silx/math/fft/npfft.py @@ -59,7 +59,6 @@ def __init__( if normalize != "ortho": self.normalize = None self.set_fft_functions() - #~ self.allocate_arrays() # not needed for this backend self.compute_plans() From 4fcd4274062ab0a5ffde0d336dbc0896791f2518 Mon Sep 17 00:00:00 2001 From: Pierre Paleo Date: Tue, 11 Jan 2022 09:18:01 +0100 Subject: [PATCH 2/5] Fix wrong allocation --- src/silx/math/fft/npfft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/silx/math/fft/npfft.py b/src/silx/math/fft/npfft.py index 2379a56c0d..e87e8ce330 100644 --- a/src/silx/math/fft/npfft.py +++ b/src/silx/math/fft/npfft.py @@ -79,7 +79,7 @@ def set_fft_functions(self): def _allocate(self, shape, dtype): - return np.zeros(self.queue, shape, dtype=dtype) + return np.zeros(shape, dtype=dtype) def compute_plans(self): From 7bb1b6bf20c04242f24f0f331793edd7d8375404 Mon Sep 17 00:00:00 2001 From: Pierre Paleo Date: Tue, 11 Jan 2022 09:30:50 +0100 Subject: [PATCH 3/5] Use python3-like super() --- src/silx/math/fft/clfft.py | 2 +- src/silx/math/fft/cufft.py | 2 +- src/silx/math/fft/fftw.py | 2 +- src/silx/math/fft/npfft.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/silx/math/fft/clfft.py b/src/silx/math/fft/clfft.py index dad8ec15a7..f664a08e40 100644 --- a/src/silx/math/fft/clfft.py +++ b/src/silx/math/fft/clfft.py @@ -74,7 +74,7 @@ def __init__( if not(__have_clfft__) or not(__have_clfft__): raise ImportError("Please install pyopencl and gpyfft >= %s to use the OpenCL back-end" % __required_gpyfft_version__) - super(CLFFT, self).__init__( + super().__init__( shape=shape, dtype=dtype, template=template, diff --git a/src/silx/math/fft/cufft.py b/src/silx/math/fft/cufft.py index 848f3e64f6..082aefb010 100644 --- a/src/silx/math/fft/cufft.py +++ b/src/silx/math/fft/cufft.py @@ -61,7 +61,7 @@ def __init__( if not(__have_cufft__) or not(__have_cufft__): raise ImportError("Please install pycuda and scikit-cuda to use the CUDA back-end") - super(CUFFT, self).__init__( + super().__init__( shape=shape, dtype=dtype, template=template, diff --git a/src/silx/math/fft/fftw.py b/src/silx/math/fft/fftw.py index ff6966c435..19c27b7d9a 100644 --- a/src/silx/math/fft/fftw.py +++ b/src/silx/math/fft/fftw.py @@ -66,7 +66,7 @@ def __init__( ): if not(__have_fftw__): raise ImportError("Please install pyfftw >= %s to use the FFTW back-end" % __required_pyfftw_version__) - super(FFTW, self).__init__( + super().__init__( shape=shape, dtype=dtype, template=template, diff --git a/src/silx/math/fft/npfft.py b/src/silx/math/fft/npfft.py index e87e8ce330..e010cf5085 100644 --- a/src/silx/math/fft/npfft.py +++ b/src/silx/math/fft/npfft.py @@ -42,7 +42,7 @@ def __init__( axes=None, normalize="rescale", ): - super(NPFFT, self).__init__( + super().__init__( shape=shape, dtype=dtype, template=template, From e3c7a688d11b2a6c33c9896719786704eb334e42 Mon Sep 17 00:00:00 2001 From: Pierre Paleo Date: Tue, 11 Jan 2022 09:31:16 +0100 Subject: [PATCH 4/5] Use late imports to create contexts only when needed --- src/silx/math/fft/fft.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/src/silx/math/fft/fft.py b/src/silx/math/fft/fft.py index eb0d73b9e9..ac38e964bc 100644 --- a/src/silx/math/fft/fft.py +++ b/src/silx/math/fft/fft.py @@ -24,9 +24,7 @@ # # ###########################################################################*/ from .fftw import FFTW -from .clfft import CLFFT from .npfft import NPFFT -from .cufft import CUFFT def FFT( @@ -71,20 +69,23 @@ def FFT( :param str backend: FFT Backend to use. Value can be "numpy", "fftw", "opencl", "cuda". """ - backends = { - "numpy": NPFFT, - "np": NPFFT, - "fftw": FFTW, - "opencl": CLFFT, - "clfft": CLFFT, - "cuda": CUFFT, - "cufft": CUFFT, - } - + backends = ["numpy", "fftw", "opencl", "cuda"] backend = backend.lower() - if backend not in backends: + if backend in ["numpy", "np"]: + fft_cls = NPFFT + elif backend == "fftw": + fft_cls = FFTW + elif backend in ["opencl", "clfft"]: + # Late import for creating context only if needed + from .clfft import CLFFT + fft_cls = CLFFT + elif backend in ["cuda", "cufft"]: + # Late import for creating context only if needed + from .cufft import CUFFT + fft_cls = CUFFT + else: raise ValueError("Unknown backend %s, available are %s" % (backend, backends)) - F = backends[backend]( + F = fft_cls( shape=shape, dtype=dtype, template=template, From 27a5770b3e2caff64c39abe7195eedc7db786d55 Mon Sep 17 00:00:00 2001 From: Pierre Paleo Date: Tue, 11 Jan 2022 10:33:58 +0100 Subject: [PATCH 5/5] test_fft: fix cuda context creation for Cuda >= 11 --- src/silx/math/fft/test/test_fft.py | 36 ++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/src/silx/math/fft/test/test_fft.py b/src/silx/math/fft/test/test_fft.py index 19becb8dab..6c6d72256f 100644 --- a/src/silx/math/fft/test/test_fft.py +++ b/src/silx/math/fft/test/test_fft.py @@ -40,9 +40,41 @@ from silx.math.fft.cufft import __have_cufft__ from silx.math.fft.fftw import __have_fftw__ +if __have_cufft__: + import atexit + import pycuda.driver as cuda + from pycuda.tools import clear_context_caches -logger = logging.getLogger(__name__) +def get_cuda_context(device_id=None, cleanup_at_exit=True): + """ + Create or get a CUDA context. + """ + current_ctx = cuda.Context.get_current() + # If a context already exists, use this one + # TODO what if the device used is different from device_id ? + if current_ctx is not None: + return current_ctx + # Otherwise create a new context + cuda.init() + if device_id is None: + device_id = 0 + # Use the Context obtained by retaining the device's primary context, + # which is the one used by the CUDA runtime API (ex. scikit-cuda). + # Unlike Context.make_context(), the newly-created context is not made current. + context = cuda.Device(device_id).retain_primary_context() + context.push() + # Register a clean-up function at exit + def _finish_up(context): + if context is not None: + context.pop() + context = None + clear_context_caches() + if cleanup_at_exit: + atexit.register(_finish_up, context) + return context + +logger = logging.getLogger(__name__) class TransformInfos(object): def __init__(self): @@ -113,7 +145,7 @@ def calc_mae(arr1, arr2): @unittest.skipIf(not __have_cufft__, "cuda back-end requires pycuda and scikit-cuda") def test_cuda(self): - import pycuda.autoinit + get_cuda_context() # Error is higher when using cuda. fast_math mode ? self.tol[np.dtype("float32")] *= 2