From 5f4815ec5d2d1ea88c3b20c751e0d012a3b7b184 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Tue, 23 Nov 2021 19:12:55 +0100 Subject: [PATCH] Make API globals thread safe using atomics While the GIL is held when the API pointer is updated, this can still race with other threads checking the current value of the API pointer (without holding the GIL) and should therefore using atomics. The loads and stores are performed using acquire-release semantics as we want to dereference the pointer and hence any stores to the referenced memory need to be visible to us. The get function should also be unsafe as the offset it uses cannot be verified which might create an invalid pointer invoking undefined behaviour as per the contract of pointer::offset. Finally, the initialization code is moved into a separate cold function to improve code locality for the fast path. --- src/npyffi/array.rs | 34 +++++++++++++++++++++------------- src/npyffi/ufunc.rs | 32 ++++++++++++++++++++------------ 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/src/npyffi/array.rs b/src/npyffi/array.rs index 2c91a2174..6d2e18fda 100644 --- a/src/npyffi/array.rs +++ b/src/npyffi/array.rs @@ -2,7 +2,8 @@ use libc::FILE; use pyo3::ffi::{self, PyObject, PyTypeObject}; use std::os::raw::*; -use std::{cell::Cell, ptr}; +use std::ptr::null_mut; +use std::sync::atomic::{AtomicPtr, Ordering}; use crate::npyffi::*; @@ -12,7 +13,7 @@ const CAPSULE_NAME: &str = "_ARRAY_API"; /// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html) /// pointer to [Numpy Array API](https://numpy.org/doc/stable/reference/c-api/array.html). /// -/// You can acceess raw c APIs via this variable and its Deref implementation. +/// You can acceess raw C APIs via this variable. /// /// See [PyArrayAPI](struct.PyArrayAPI.html) for what methods you can use via this variable. /// @@ -31,28 +32,35 @@ pub static PY_ARRAY_API: PyArrayAPI = PyArrayAPI::new(); /// See [PY_ARRAY_API] for more. pub struct PyArrayAPI { - api: Cell<*const *const c_void>, + api: AtomicPtr<*const c_void>, } impl PyArrayAPI { const fn new() -> Self { Self { - api: Cell::new(ptr::null_mut()), + api: AtomicPtr::new(null_mut()), } } - fn get(&self, offset: isize) -> *const *const c_void { - if self.api.get().is_null() { - Python::with_gil(|py| { - let api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME); - self.api.set(api); - }); + #[cold] + fn init(&self) -> *const *const c_void { + Python::with_gil(|py| { + let mut api = self.api.load(Ordering::Relaxed) as *const *const c_void; + if api.is_null() { + api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME); + self.api.store(api as *mut _, Ordering::Release); + } + api + }) + } + unsafe fn get(&self, offset: isize) -> *const *const c_void { + let mut api = self.api.load(Ordering::Acquire) as *const *const c_void; + if api.is_null() { + api = self.init(); } - unsafe { self.api.get().offset(offset) } + api.offset(offset) } } -unsafe impl Sync for PyArrayAPI {} - impl PyArrayAPI { impl_api![0; PyArray_GetNDArrayCVersion() -> c_uint]; impl_api![40; PyArray_SetNumericOps(dict: *mut PyObject) -> c_int]; diff --git a/src/npyffi/ufunc.rs b/src/npyffi/ufunc.rs index 73798a35c..306cacabd 100644 --- a/src/npyffi/ufunc.rs +++ b/src/npyffi/ufunc.rs @@ -1,7 +1,8 @@ //! Low-Level binding for [UFunc API](https://numpy.org/doc/stable/reference/c-api/ufunc.html) use std::os::raw::*; -use std::{cell::Cell, ptr}; +use std::ptr::null_mut; +use std::sync::atomic::{AtomicPtr, Ordering}; use pyo3::ffi::PyObject; use pyo3::Python; @@ -18,28 +19,35 @@ const CAPSULE_NAME: &str = "_UFUNC_API"; pub static PY_UFUNC_API: PyUFuncAPI = PyUFuncAPI::new(); pub struct PyUFuncAPI { - api: Cell<*const *const c_void>, + api: AtomicPtr<*const c_void>, } impl PyUFuncAPI { const fn new() -> Self { Self { - api: Cell::new(ptr::null_mut()), + api: AtomicPtr::new(null_mut()), } } - fn get(&self, offset: isize) -> *const *const c_void { - if self.api.get().is_null() { - Python::with_gil(|py| { - let api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME); - self.api.set(api); - }); + #[cold] + fn init(&self) -> *const *const c_void { + Python::with_gil(|py| { + let mut api = self.api.load(Ordering::Relaxed) as *const *const c_void; + if api.is_null() { + api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME); + self.api.store(api as *mut _, Ordering::Release); + } + api + }) + } + unsafe fn get(&self, offset: isize) -> *const *const c_void { + let mut api = self.api.load(Ordering::Acquire) as *const *const c_void; + if api.is_null() { + api = self.init(); } - unsafe { self.api.get().offset(offset) } + api.offset(offset) } } -unsafe impl Sync for PyUFuncAPI {} - impl PyUFuncAPI { impl_api![1; PyUFunc_FromFuncAndData(func: *mut PyUFuncGenericFunction, data: *mut *mut c_void, types: *mut c_char, ntypes: c_int, nin: c_int, nout: c_int, identity: c_int, name: *const c_char, doc: *const c_char, unused: c_int) -> *mut PyObject]; impl_api![2; PyUFunc_RegisterLoopForType(ufunc: *mut PyUFuncObject, usertype: c_int, function: PyUFuncGenericFunction, arg_types: *mut c_int, data: *mut c_void) -> c_int];