Skip to content

Commit

Permalink
Make API globals thread safe using atomics
Browse files Browse the repository at this point in the history
While the GIL is held when the API pointer is updated, this can still race with
other threads checking the current value of the API pointer (without holding the
GIL) and should therefore using atomics.

The loads and stores are performed using acquire-release semantics as we want to
dereference the pointer and hence any stores to the referenced memory need to be
visible to us.

The get function should also be unsafe as the offset it uses cannot be verified
which might create an invalid pointer invoking undefined behaviour as per the
contract of pointer::offset.

Finally, the initialization code is moved into a separate cold function to
improve code locality for the fast path.
  • Loading branch information
adamreichold committed Jan 6, 2022
1 parent 615d5c3 commit 5f4815e
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 25 deletions.
34 changes: 21 additions & 13 deletions src/npyffi/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
use libc::FILE;
use pyo3::ffi::{self, PyObject, PyTypeObject};
use std::os::raw::*;
use std::{cell::Cell, ptr};
use std::ptr::null_mut;
use std::sync::atomic::{AtomicPtr, Ordering};

use crate::npyffi::*;

Expand All @@ -12,7 +13,7 @@ const CAPSULE_NAME: &str = "_ARRAY_API";
/// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html)
/// pointer to [Numpy Array API](https://numpy.org/doc/stable/reference/c-api/array.html).
///
/// You can acceess raw c APIs via this variable and its Deref implementation.
/// You can acceess raw C APIs via this variable.
///
/// See [PyArrayAPI](struct.PyArrayAPI.html) for what methods you can use via this variable.
///
Expand All @@ -31,28 +32,35 @@ pub static PY_ARRAY_API: PyArrayAPI = PyArrayAPI::new();

/// See [PY_ARRAY_API] for more.
pub struct PyArrayAPI {
api: Cell<*const *const c_void>,
api: AtomicPtr<*const c_void>,
}

impl PyArrayAPI {
const fn new() -> Self {
Self {
api: Cell::new(ptr::null_mut()),
api: AtomicPtr::new(null_mut()),
}
}
fn get(&self, offset: isize) -> *const *const c_void {
if self.api.get().is_null() {
Python::with_gil(|py| {
let api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
self.api.set(api);
});
#[cold]
fn init(&self) -> *const *const c_void {
Python::with_gil(|py| {
let mut api = self.api.load(Ordering::Relaxed) as *const *const c_void;
if api.is_null() {
api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
self.api.store(api as *mut _, Ordering::Release);
}
api
})
}
unsafe fn get(&self, offset: isize) -> *const *const c_void {
let mut api = self.api.load(Ordering::Acquire) as *const *const c_void;
if api.is_null() {
api = self.init();
}
unsafe { self.api.get().offset(offset) }
api.offset(offset)
}
}

unsafe impl Sync for PyArrayAPI {}

impl PyArrayAPI {
impl_api![0; PyArray_GetNDArrayCVersion() -> c_uint];
impl_api![40; PyArray_SetNumericOps(dict: *mut PyObject) -> c_int];
Expand Down
32 changes: 20 additions & 12 deletions src/npyffi/ufunc.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
//! Low-Level binding for [UFunc API](https://numpy.org/doc/stable/reference/c-api/ufunc.html)
use std::os::raw::*;
use std::{cell::Cell, ptr};
use std::ptr::null_mut;
use std::sync::atomic::{AtomicPtr, Ordering};

use pyo3::ffi::PyObject;
use pyo3::Python;
Expand All @@ -18,28 +19,35 @@ const CAPSULE_NAME: &str = "_UFUNC_API";
pub static PY_UFUNC_API: PyUFuncAPI = PyUFuncAPI::new();

pub struct PyUFuncAPI {
api: Cell<*const *const c_void>,
api: AtomicPtr<*const c_void>,
}

impl PyUFuncAPI {
const fn new() -> Self {
Self {
api: Cell::new(ptr::null_mut()),
api: AtomicPtr::new(null_mut()),
}
}
fn get(&self, offset: isize) -> *const *const c_void {
if self.api.get().is_null() {
Python::with_gil(|py| {
let api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
self.api.set(api);
});
#[cold]
fn init(&self) -> *const *const c_void {
Python::with_gil(|py| {
let mut api = self.api.load(Ordering::Relaxed) as *const *const c_void;
if api.is_null() {
api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
self.api.store(api as *mut _, Ordering::Release);
}
api
})
}
unsafe fn get(&self, offset: isize) -> *const *const c_void {
let mut api = self.api.load(Ordering::Acquire) as *const *const c_void;
if api.is_null() {
api = self.init();
}
unsafe { self.api.get().offset(offset) }
api.offset(offset)
}
}

unsafe impl Sync for PyUFuncAPI {}

impl PyUFuncAPI {
impl_api![1; PyUFunc_FromFuncAndData(func: *mut PyUFuncGenericFunction, data: *mut *mut c_void, types: *mut c_char, ntypes: c_int, nin: c_int, nout: c_int, identity: c_int, name: *const c_char, doc: *const c_char, unused: c_int) -> *mut PyObject];
impl_api![2; PyUFunc_RegisterLoopForType(ufunc: *mut PyUFuncObject, usertype: c_int, function: PyUFuncGenericFunction, arg_types: *mut c_int, data: *mut c_void) -> c_int];
Expand Down

0 comments on commit 5f4815e

Please sign in to comment.