Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collect destructors, then run them to release lock #31

Merged
merged 1 commit into from
May 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 40 additions & 31 deletions src/thread_keys.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
//! Thread keys implementation for the standard library.

use spin::rwlock::RwLock;
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::atomic::{AtomicU32, Ordering};
use std::{mem, ptr};

use spin::rwlock::RwLock;

type Key = usize;
type Key = libc::pthread_t;
type Destructor = unsafe extern "C" fn(*mut libc::c_void);

static NEXT_KEY: AtomicUsize = AtomicUsize::new(1);
static NEXT_KEY: AtomicU32 = AtomicU32::new(1);

// This is a spin-lock RwLock which yields the thread every loop
static KEYS: RwLock<BTreeMap<Key, Option<Destructor>>, spin::Yield> = RwLock::new(BTreeMap::new());
Expand All @@ -16,31 +18,39 @@ static KEYS: RwLock<BTreeMap<Key, Option<Destructor>>, spin::Yield> = RwLock::ne
static mut LOCALS: BTreeMap<Key, *mut libc::c_void> = BTreeMap::new();

fn is_valid_key(key: Key) -> bool {
KEYS.read().contains_key(&(key as Key))
KEYS.read().contains_key(&key)
}

pub(crate) fn run_local_destructors() {
unsafe {
// We iterate all the thread-local keys set.
//
// When using `std` and the `thread_local!` macro there should be only one key registered here,
// which is the list of keys to destroy.
for (key, value) in LOCALS.iter() {
// We iterate all the thread-local keys set.
//
// When using `std` and the `thread_local!` macro there should be only one key registered here,
// which is the list of keys to destroy.
let dtors: Vec<_> = unsafe { LOCALS.iter_mut() }
.filter_map(|(key, value)| {
// We retrieve the destructor for a key from the static list.
if let Some(destructor) = KEYS.write().get_mut(key) {
// If the destructor is registered for a key, run it.
if let Some(d) = destructor {
// Hold the function pointer.
let dtor = *d;

// Set the destructor as null to avoid reentrancy.
*destructor = None;

dtor(*value);
}
if let Some(destructor) = KEYS
.write()
.get_mut(key)
// Removing the destructor from the list, so it's not called again.
.and_then(Option::take)
{
// And clearing the destructor arg as well
let arg = mem::replace(value, ptr::null_mut());
Some((destructor, arg))
} else {
None
}
})
// Collect destructors separately before running them, so that if any destructor would try
// to obtain KEYS' lock, it doesn't deadlock because we're already holding the lock here.
.collect();

for (destructor, value) in dtors {
unsafe {
destructor(value);
}
};
}
}

#[no_mangle]
Expand All @@ -51,14 +61,14 @@ pub unsafe extern "C" fn pthread_key_create(
let new_key = NEXT_KEY.fetch_add(1, Ordering::SeqCst);
KEYS.write().insert(new_key, destructor);

*key = new_key as libc::pthread_key_t;
*key = new_key;

0
}

#[no_mangle]
pub unsafe extern "C" fn pthread_key_delete(key: libc::pthread_key_t) -> libc::c_int {
match KEYS.write().remove(&(key as Key)) {
match KEYS.write().remove(&key) {
// We had a entry, so it was a valid key.
// It's officially undefined behavior if they use the key after this,
// so don't worry about cleaning up LOCALS, especially since we can't
Expand All @@ -71,11 +81,11 @@ pub unsafe extern "C" fn pthread_key_delete(key: libc::pthread_key_t) -> libc::c

#[no_mangle]
pub unsafe extern "C" fn pthread_getspecific(key: libc::pthread_key_t) -> *mut libc::c_void {
if let Some(&value) = LOCALS.get(&(key as Key)) {
value as _
if let Some(&value) = LOCALS.get(&key) {
value
} else {
// Note: we don't care if the key is invalid, we still return null
std::ptr::null_mut()
ptr::null_mut()
}
}

Expand All @@ -84,12 +94,11 @@ pub unsafe extern "C" fn pthread_setspecific(
key: libc::pthread_key_t,
value: *const libc::c_void,
) -> libc::c_int {
let key = key as Key;

if !is_valid_key(key) {
return libc::EINVAL;
}

LOCALS.insert(key, value as *mut _);
LOCALS.insert(key, value.cast_mut());

0
}