diff --git a/rayon-core/Cargo.toml b/rayon-core/Cargo.toml index 09cfbafbb..e17d7e422 100644 --- a/rayon-core/Cargo.toml +++ b/rayon-core/Cargo.toml @@ -19,10 +19,12 @@ num_cpus = "1.2" lazy_static = "1" crossbeam-deque = "0.6.3" crossbeam-queue = "0.1.2" +crossbeam-utils = "0.6.5" [dev-dependencies] rand = "0.6" rand_xorshift = "0.1" +scoped-tls = "1.0" [target.'cfg(unix)'.dev-dependencies] libc = "0.2" @@ -49,3 +51,7 @@ path = "tests/scope_join.rs" [[test]] name = "simple_panic" path = "tests/simple_panic.rs" + +[[test]] +name = "scoped_threadpool" +path = "tests/scoped_threadpool.rs" diff --git a/rayon-core/src/lib.rs b/rayon-core/src/lib.rs index f1b7fda6c..0eb8026a6 100644 --- a/rayon-core/src/lib.rs +++ b/rayon-core/src/lib.rs @@ -34,6 +34,7 @@ use std::str::FromStr; extern crate crossbeam_deque; extern crate crossbeam_queue; +extern crate crossbeam_utils; #[cfg(any(debug_assertions, rayon_unstable))] #[macro_use] extern crate lazy_static; @@ -46,6 +47,8 @@ extern crate rand_xorshift; #[macro_use] mod log; +#[macro_use] +mod private; mod job; mod join; @@ -64,6 +67,7 @@ mod test; #[cfg(rayon_unstable)] pub mod internal; pub use join::{join, join_context}; +pub use registry::ThreadBuilder; pub use scope::{scope, Scope}; pub use scope::{scope_fifo, ScopeFifo}; pub use spawn::{spawn, spawn_fifo}; @@ -71,6 +75,8 @@ pub use thread_pool::current_thread_has_pending_tasks; pub use thread_pool::current_thread_index; pub use thread_pool::ThreadPool; +use registry::{CustomSpawn, DefaultSpawn, ThreadSpawn}; + /// Returns the number of threads in the current registry. If this /// code is executing within a Rayon thread-pool, then this will be /// the number of threads for the thread-pool of the current @@ -123,8 +129,7 @@ enum ErrorKind { /// /// [`ThreadPool`]: struct.ThreadPool.html /// [`build_global()`]: struct.ThreadPoolBuilder.html#method.build_global -#[derive(Default)] -pub struct ThreadPoolBuilder { +pub struct ThreadPoolBuilder { /// The number of threads in the rayon thread pool. /// If zero will use the RAYON_NUM_THREADS environment variable. /// If RAYON_NUM_THREADS is invalid or zero will use the default. @@ -146,6 +151,9 @@ pub struct ThreadPoolBuilder { /// Closure invoked on worker thread exit. exit_handler: Option>, + /// Closure invoked to spawn threads. + spawn_handler: S, + /// If false, worker threads will execute spawned jobs in a /// "depth-first" fashion. If true, they will do a "breadth-first" /// fashion. Depth-first is the default. @@ -174,12 +182,35 @@ type StartHandler = Fn(usize) + Send + Sync; /// Note that this same closure may be invoked multiple times in parallel. type ExitHandler = Fn(usize) + Send + Sync; +// NB: We can't `#[derive(Default)]` because `S` is left ambiguous. +impl Default for ThreadPoolBuilder { + fn default() -> Self { + ThreadPoolBuilder { + num_threads: 0, + panic_handler: None, + get_thread_name: None, + stack_size: None, + start_handler: None, + exit_handler: None, + spawn_handler: DefaultSpawn, + breadth_first: false, + } + } +} + impl ThreadPoolBuilder { /// Creates and returns a valid rayon thread pool builder, but does not initialize it. - pub fn new() -> ThreadPoolBuilder { - ThreadPoolBuilder::default() + pub fn new() -> Self { + Self::default() } +} +/// Note: the `S: ThreadSpawn` constraint is an internal implementation detail for the +/// default spawn and those set by [`spawn_handler`](#method.spawn_handler). +impl ThreadPoolBuilder +where + S: ThreadSpawn, +{ /// Create a new `ThreadPool` initialized using this configuration. pub fn build(self) -> Result { ThreadPool::build(self) @@ -207,6 +238,154 @@ impl ThreadPoolBuilder { registry.wait_until_primed(); Ok(()) } +} + +impl ThreadPoolBuilder { + /// Create a scoped `ThreadPool` initialized using this configuration. + /// + /// This is a convenience function for building a pool using [`crossbeam::scope`] + /// to spawn threads in a [`spawn_handler`](#method.spawn_handler). + /// The threads in this pool will start by calling `wrapper`, which should + /// do initialization and continue by calling `ThreadBuilder::run()`. + /// + /// [`crossbeam::scope`]: https://docs.rs/crossbeam/0.7/crossbeam/fn.scope.html + /// + /// # Examples + /// + /// A scoped pool may be useful in combination with scoped thread-local variables. + /// + /// ``` + /// #[macro_use] + /// extern crate scoped_tls; + /// # use rayon_core as rayon; + /// + /// scoped_thread_local!(static POOL_DATA: Vec); + /// + /// fn main() -> Result<(), rayon::ThreadPoolBuildError> { + /// let pool_data = vec![1, 2, 3]; + /// + /// // We haven't assigned any TLS data yet. + /// assert!(!POOL_DATA.is_set()); + /// + /// rayon::ThreadPoolBuilder::new() + /// .build_scoped( + /// // Borrow `pool_data` in TLS for each thread. + /// |thread| POOL_DATA.set(&pool_data, || thread.run()), + /// // Do some work that needs the TLS data. + /// |pool| pool.install(|| assert!(POOL_DATA.is_set())), + /// )?; + /// + /// // Once we've returned, `pool_data` is no longer borrowed. + /// drop(pool_data); + /// Ok(()) + /// } + /// ``` + pub fn build_scoped(self, wrapper: W, with_pool: F) -> Result + where + W: Fn(ThreadBuilder) + Sync, // expected to call `run()` + F: FnOnce(&ThreadPool) -> R, + { + let result = crossbeam_utils::thread::scope(|scope| { + let wrapper = &wrapper; + let pool = self + .spawn_handler(|thread| { + let mut builder = scope.builder(); + if let Some(name) = thread.name() { + builder = builder.name(name.to_string()); + } + if let Some(size) = thread.stack_size() { + builder = builder.stack_size(size); + } + builder.spawn(move |_| wrapper(thread))?; + Ok(()) + }) + .build()?; + Ok(with_pool(&pool)) + }); + + match result { + Ok(result) => result, + Err(err) => unwind::resume_unwinding(err), + } + } +} + +impl ThreadPoolBuilder { + /// Set a custom function for spawning threads. + /// + /// Note that the threads will not exit until after the pool is dropped. It + /// is up to the caller to wait for thread termination if that is important + /// for any invariants. For instance, threads created in [`crossbeam::scope`] + /// will be joined before that scope returns, and this will block indefinitely + /// if the pool is leaked. Furthermore, the global thread pool doesn't terminate + /// until the entire process exits! + /// + /// [`crossbeam::scope`]: https://docs.rs/crossbeam/0.7/crossbeam/fn.scope.html + /// + /// # Examples + /// + /// A minimal spawn handler just needs to call `run()` from an independent thread. + /// + /// ``` + /// # use rayon_core as rayon; + /// fn main() -> Result<(), rayon::ThreadPoolBuildError> { + /// let pool = rayon::ThreadPoolBuilder::new() + /// .spawn_handler(|thread| { + /// std::thread::spawn(|| thread.run()); + /// Ok(()) + /// }) + /// .build()?; + /// + /// pool.install(|| println!("Hello from my custom thread!")); + /// Ok(()) + /// } + /// ``` + /// + /// The default spawn handler sets the name and stack size if given, and propagates + /// any errors from the thread builder. + /// + /// ``` + /// # use rayon_core as rayon; + /// fn main() -> Result<(), rayon::ThreadPoolBuildError> { + /// let pool = rayon::ThreadPoolBuilder::new() + /// .spawn_handler(|thread| { + /// let mut b = std::thread::Builder::new(); + /// if let Some(name) = thread.name() { + /// b = b.name(name.to_owned()); + /// } + /// if let Some(stack_size) = thread.stack_size() { + /// b = b.stack_size(stack_size); + /// } + /// b.spawn(|| thread.run())?; + /// Ok(()) + /// }) + /// .build()?; + /// + /// pool.install(|| println!("Hello from my fully custom thread!")); + /// Ok(()) + /// } + /// ``` + pub fn spawn_handler(self, spawn: F) -> ThreadPoolBuilder> + where + F: FnMut(ThreadBuilder) -> io::Result<()>, + { + ThreadPoolBuilder { + spawn_handler: CustomSpawn::new(spawn), + // ..self + num_threads: self.num_threads, + panic_handler: self.panic_handler, + get_thread_name: self.get_thread_name, + stack_size: self.stack_size, + start_handler: self.start_handler, + exit_handler: self.exit_handler, + breadth_first: self.breadth_first, + } + } + + /// Returns a reference to the current spawn handler. + fn get_spawn_handler(&mut self) -> &mut S { + &mut self.spawn_handler + } /// Get the number of threads that will be used for the thread /// pool. See `num_threads()` for more information. @@ -276,7 +455,7 @@ impl ThreadPoolBuilder { /// replacement of the now deprecated `RAYON_RS_NUM_CPUS` environment /// variable. If both variables are specified, `RAYON_NUM_THREADS` will /// be prefered. - pub fn num_threads(mut self, num_threads: usize) -> ThreadPoolBuilder { + pub fn num_threads(mut self, num_threads: usize) -> Self { self.num_threads = num_threads; self } @@ -300,7 +479,7 @@ impl ThreadPoolBuilder { /// If the panic handler itself panics, this will abort the /// process. To prevent this, wrap the body of your panic handler /// in a call to `std::panic::catch_unwind()`. - pub fn panic_handler(mut self, panic_handler: H) -> ThreadPoolBuilder + pub fn panic_handler(mut self, panic_handler: H) -> Self where H: Fn(Box) + Send + Sync + 'static, { @@ -368,7 +547,7 @@ impl ThreadPoolBuilder { /// Note that this same closure may be invoked multiple times in parallel. /// If this closure panics, the panic will be passed to the panic handler. /// If that handler returns, then startup will continue normally. - pub fn start_handler(mut self, start_handler: H) -> ThreadPoolBuilder + pub fn start_handler(mut self, start_handler: H) -> Self where H: Fn(usize) + Send + Sync + 'static, { @@ -387,7 +566,7 @@ impl ThreadPoolBuilder { /// Note that this same closure may be invoked multiple times in parallel. /// If this closure panics, the panic will be passed to the panic handler. /// If that handler returns, then the thread will exit normally. - pub fn exit_handler(mut self, exit_handler: H) -> ThreadPoolBuilder + pub fn exit_handler(mut self, exit_handler: H) -> Self where H: Fn(usize) + Send + Sync + 'static, { @@ -503,7 +682,7 @@ pub fn initialize(config: Configuration) -> Result<(), Box> { config.into_builder().build_global().map_err(Box::from) } -impl fmt::Debug for ThreadPoolBuilder { +impl fmt::Debug for ThreadPoolBuilder { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let ThreadPoolBuilder { ref num_threads, @@ -512,6 +691,7 @@ impl fmt::Debug for ThreadPoolBuilder { ref stack_size, ref start_handler, ref exit_handler, + spawn_handler: _, ref breadth_first, } = *self; diff --git a/rayon-core/src/private.rs b/rayon-core/src/private.rs new file mode 100644 index 000000000..5d084ff14 --- /dev/null +++ b/rayon-core/src/private.rs @@ -0,0 +1,26 @@ +//! The public parts of this private module are used to create traits +//! that cannot be implemented outside of our own crate. This way we +//! can feel free to extend those traits without worrying about it +//! being a breaking change for other implementations. + +/// If this type is pub but not publicly reachable, third parties +/// can't name it and can't implement traits using it. +#[allow(missing_debug_implementations)] +pub struct PrivateMarker; + +macro_rules! private_decl { + () => { + /// This trait is private; this method exists to make it + /// impossible to implement outside the crate. + #[doc(hidden)] + fn __rayon_private__(&self) -> ::private::PrivateMarker; + } +} + +macro_rules! private_impl { + () => { + fn __rayon_private__(&self) -> ::private::PrivateMarker { + ::private::PrivateMarker + } + } +} diff --git a/rayon-core/src/registry.rs b/rayon-core/src/registry.rs index 0691e3395..76567c370 100644 --- a/rayon-core/src/registry.rs +++ b/rayon-core/src/registry.rs @@ -11,7 +11,9 @@ use sleep::Sleep; use std::any::Any; use std::cell::Cell; use std::collections::hash_map::DefaultHasher; +use std::fmt; use std::hash::Hasher; +use std::io; use std::mem; use std::ptr; #[allow(deprecated)] @@ -24,6 +26,113 @@ use unwind; use util::leak; use {ErrorKind, ExitHandler, PanicHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder}; +/// Thread builder used for customization via +/// [`ThreadPoolBuilder::spawn_handler`](struct.ThreadPoolBuilder.html#method.spawn_handler). +pub struct ThreadBuilder { + name: Option, + stack_size: Option, + worker: Worker, + registry: Arc, + index: usize, +} + +impl ThreadBuilder { + /// Get the index of this thread in the pool, within `0..num_threads`. + pub fn index(&self) -> usize { + self.index + } + + /// Get the string that was specified by `ThreadPoolBuilder::name()`. + pub fn name(&self) -> Option<&str> { + self.name.as_ref().map(String::as_str) + } + + /// Get the value that was specified by `ThreadPoolBuilder::stack_size()`. + pub fn stack_size(&self) -> Option { + self.stack_size + } + + /// Execute the main loop for this thread. This will not return until the + /// thread pool is dropped. + pub fn run(self) { + unsafe { main_loop(self.worker, self.registry, self.index) } + } +} + +impl fmt::Debug for ThreadBuilder { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("ThreadBuilder") + .field("pool", &self.registry.id()) + .field("index", &self.index) + .field("name", &self.name) + .field("stack_size", &self.stack_size) + .finish() + } +} + +/// Generalized trait for spawning a thread in the `Registry`. +/// +/// This trait is pub-in-private -- E0445 forces us to make it public, +/// but we don't actually want to expose these details in the API. +pub trait ThreadSpawn { + private_decl! {} + + /// Spawn a thread with the `ThreadBuilder` parameters, and then + /// call `ThreadBuilder::run()`. + fn spawn(&mut self, ThreadBuilder) -> io::Result<()>; +} + +/// Spawns a thread in the "normal" way with `std::thread::Builder`. +/// +/// This type is pub-in-private -- E0445 forces us to make it public, +/// but we don't actually want to expose these details in the API. +#[derive(Debug, Default)] +pub struct DefaultSpawn; + +impl ThreadSpawn for DefaultSpawn { + private_impl! {} + + fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> { + let mut b = thread::Builder::new(); + if let Some(name) = thread.name() { + b = b.name(name.to_owned()); + } + if let Some(stack_size) = thread.stack_size() { + b = b.stack_size(stack_size); + } + b.spawn(|| thread.run())?; + Ok(()) + } +} + +/// Spawns a thread with a user's custom callback. +/// +/// This type is pub-in-private -- E0445 forces us to make it public, +/// but we don't actually want to expose these details in the API. +#[derive(Debug)] +pub struct CustomSpawn(F); + +impl CustomSpawn +where + F: FnMut(ThreadBuilder) -> io::Result<()>, +{ + pub(super) fn new(spawn: F) -> Self { + CustomSpawn(spawn) + } +} + +impl ThreadSpawn for CustomSpawn +where + F: FnMut(ThreadBuilder) -> io::Result<()>, +{ + private_impl! {} + + #[inline] + fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> { + (self.0)(thread) + } +} + pub(super) struct Registry { thread_infos: Vec, sleep: Sleep, @@ -58,39 +167,41 @@ static THE_REGISTRY_SET: Once = ONCE_INIT; /// initialization has not already occurred, use the default /// configuration. fn global_registry() -> &'static Arc { - THE_REGISTRY_SET.call_once(|| unsafe { init_registry(ThreadPoolBuilder::new()).unwrap() }); - unsafe { THE_REGISTRY.expect("The global thread pool has not been initialized.") } + set_global_registry(|| Registry::new(ThreadPoolBuilder::new())) + .or_else(|err| unsafe { THE_REGISTRY.ok_or(err) }) + .expect("The global thread pool has not been initialized.") } /// Starts the worker threads (if that has not already happened) with /// the given builder. -pub(super) fn init_global_registry( - builder: ThreadPoolBuilder, -) -> Result<&'static Registry, ThreadPoolBuildError> { - let mut called = false; - let mut init_result = Ok(());; - THE_REGISTRY_SET.call_once(|| unsafe { - init_result = init_registry(builder); - called = true; - }); - if called { - init_result?; - Ok(&**global_registry()) - } else { - Err(ThreadPoolBuildError::new( - ErrorKind::GlobalPoolAlreadyInitialized, - )) - } +pub(super) fn init_global_registry( + builder: ThreadPoolBuilder, +) -> Result<&'static Arc, ThreadPoolBuildError> +where + S: ThreadSpawn, +{ + set_global_registry(|| Registry::new(builder)) } -/// Initializes the global registry with the given builder. -/// Meant to be called from within the `THE_REGISTRY_SET` once -/// function. Declared `unsafe` because it writes to `THE_REGISTRY` in -/// an unsynchronized fashion. -unsafe fn init_registry(builder: ThreadPoolBuilder) -> Result<(), ThreadPoolBuildError> { - let registry = Registry::new(builder)?; - THE_REGISTRY = Some(leak(registry)); - Ok(()) +/// Starts the worker threads (if that has not already happened) +/// by creating a registry with the given callback. +fn set_global_registry(registry: F) -> Result<&'static Arc, ThreadPoolBuildError> +where + F: FnOnce() -> Result, ThreadPoolBuildError>, +{ + let mut result = Err(ThreadPoolBuildError::new( + ErrorKind::GlobalPoolAlreadyInitialized, + )); + THE_REGISTRY_SET.call_once(|| { + result = registry().map(|registry| { + let registry = leak(registry); + unsafe { + THE_REGISTRY = Some(registry); + } + registry + }); + }); + result } struct Terminator<'a>(&'a Arc); @@ -102,7 +213,12 @@ impl<'a> Drop for Terminator<'a> { } impl Registry { - pub(super) fn new(mut builder: ThreadPoolBuilder) -> Result, ThreadPoolBuildError> { + pub(super) fn new( + mut builder: ThreadPoolBuilder, + ) -> Result, ThreadPoolBuildError> + where + S: ThreadSpawn, + { let n_threads = builder.get_num_threads(); let breadth_first = builder.get_breadth_first(); @@ -130,15 +246,14 @@ impl Registry { let t1000 = Terminator(®istry); for (index, worker) in workers.into_iter().enumerate() { - let registry = registry.clone(); - let mut b = thread::Builder::new(); - if let Some(name) = builder.get_thread_name(index) { - b = b.name(name); - } - if let Some(stack_size) = builder.get_stack_size() { - b = b.stack_size(stack_size); - } - if let Err(e) = b.spawn(move || unsafe { main_loop(worker, registry, index) }) { + let thread = ThreadBuilder { + name: builder.get_thread_name(index), + stack_size: builder.get_stack_size(), + registry: registry.clone(), + worker, + index, + }; + if let Err(e) = builder.get_spawn_handler().spawn(thread) { return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e))); } } @@ -498,6 +613,16 @@ thread_local! { static WORKER_THREAD_STATE: Cell<*const WorkerThread> = Cell::new(ptr::null()); } +impl Drop for WorkerThread { + fn drop(&mut self) { + // Undo `set_current` + WORKER_THREAD_STATE.with(|t| { + assert!(t.get().eq(&(self as *const _))); + t.set(ptr::null()); + }); + } +} + impl WorkerThread { /// Gets the `WorkerThread` index for the current thread; returns /// NULL if this is not a worker thread. This pointer is valid @@ -656,14 +781,14 @@ impl WorkerThread { /// //////////////////////////////////////////////////////////////////////// unsafe fn main_loop(worker: Worker, registry: Arc, index: usize) { - let worker_thread = WorkerThread { + let worker_thread = &WorkerThread { worker, fifo: JobFifo::new(), index, rng: XorShift64Star::new(), registry: registry.clone(), }; - WorkerThread::set_current(&worker_thread); + WorkerThread::set_current(worker_thread); // let registry know we are ready to do work registry.thread_infos[index].primed.set(); diff --git a/rayon-core/src/test.rs b/rayon-core/src/test.rs index ab15b9abd..2ba27e0e8 100644 --- a/rayon-core/src/test.rs +++ b/rayon-core/src/test.rs @@ -155,3 +155,41 @@ fn configuration() { .build() .unwrap(); } + +#[test] +fn default_pool() { + ThreadPoolBuilder::default().build().unwrap(); +} + +/// Test that custom spawned threads get their `WorkerThread` cleared once +/// the pool is done with them, allowing them to be used with rayon again +/// later. e.g. WebAssembly want to have their own pool of available threads. +#[test] +fn cleared_current_thread() -> Result<(), ThreadPoolBuildError> { + let n_threads = 5; + let mut handles = vec![]; + let pool = ThreadPoolBuilder::new() + .num_threads(n_threads) + .spawn_handler(|thread| { + let handle = std::thread::spawn(move || { + thread.run(); + + // Afterward, the current thread shouldn't be set anymore. + assert_eq!(crate::current_thread_index(), None); + }); + handles.push(handle); + Ok(()) + }) + .build()?; + assert_eq!(handles.len(), n_threads); + + pool.install(|| assert!(crate::current_thread_index().is_some())); + drop(pool); + + // Wait for all threads to make their assertions and exit + for handle in handles { + handle.join().unwrap(); + } + + Ok(()) +} diff --git a/rayon-core/src/thread_pool/mod.rs b/rayon-core/src/thread_pool/mod.rs index 3834a1dbc..8a1875343 100644 --- a/rayon-core/src/thread_pool/mod.rs +++ b/rayon-core/src/thread_pool/mod.rs @@ -4,7 +4,7 @@ //! [`ThreadPool`]: struct.ThreadPool.html use join; -use registry::{Registry, WorkerThread}; +use registry::{Registry, ThreadSpawn, WorkerThread}; use spawn; use std::error::Error; use std::fmt; @@ -61,7 +61,12 @@ impl ThreadPool { Self::build(configuration.into_builder()).map_err(Box::from) } - pub(super) fn build(builder: ThreadPoolBuilder) -> Result { + pub(super) fn build( + builder: ThreadPoolBuilder, + ) -> Result + where + S: ThreadSpawn, + { let registry = Registry::new(builder)?; Ok(ThreadPool { registry }) } diff --git a/rayon-core/tests/scoped_threadpool.rs b/rayon-core/tests/scoped_threadpool.rs new file mode 100644 index 000000000..f1abdea2e --- /dev/null +++ b/rayon-core/tests/scoped_threadpool.rs @@ -0,0 +1,102 @@ +extern crate crossbeam_utils; +extern crate rayon_core; + +#[macro_use] +extern crate scoped_tls; + +use crossbeam_utils::thread; +use rayon_core::ThreadPoolBuilder; + +#[derive(PartialEq, Eq, Debug)] +struct Local(i32); + +scoped_thread_local!(static LOCAL: Local); + +#[test] +fn missing_scoped_tls() { + LOCAL.set(&Local(42), || { + let pool = ThreadPoolBuilder::new() + .build() + .expect("thread pool created"); + + // `LOCAL` is not set in the pool. + pool.install(|| { + assert!(!LOCAL.is_set()); + }); + }); +} + +#[test] +fn spawn_scoped_tls_threadpool() { + LOCAL.set(&Local(42), || { + LOCAL.with(|x| { + thread::scope(|scope| { + let pool = ThreadPoolBuilder::new() + .spawn_handler(move |thread| { + scope + .builder() + .spawn(move |_| { + // Borrow the same local value in the thread pool. + LOCAL.set(x, || thread.run()) + }) + .map(|_| ()) + }) + .build() + .expect("thread pool created"); + + // The pool matches our local value. + pool.install(|| { + assert!(LOCAL.is_set()); + LOCAL.with(|y| { + assert_eq!(x, y); + }); + }); + + // If we change our local value, the pool is not affected. + LOCAL.set(&Local(-1), || { + pool.install(|| { + assert!(LOCAL.is_set()); + LOCAL.with(|y| { + assert_eq!(x, y); + }); + }); + }); + }) + .expect("scope threads ok"); + // `thread::scope` will wait for the threads to exit before returning. + }); + }); +} + +#[test] +fn build_scoped_tls_threadpool() { + LOCAL.set(&Local(42), || { + LOCAL.with(|x| { + ThreadPoolBuilder::new() + .build_scoped( + move |thread| LOCAL.set(x, || thread.run()), + |pool| { + // The pool matches our local value. + pool.install(|| { + assert!(LOCAL.is_set()); + LOCAL.with(|y| { + assert_eq!(x, y); + }); + }); + + // If we change our local value, the pool is not affected. + LOCAL.set(&Local(-1), || { + pool.install(|| { + assert!(LOCAL.is_set()); + LOCAL.with(|y| { + assert_eq!(x, y); + }); + }); + }); + }, + ) + .expect("thread pool created"); + // Internally, `crossbeam::scope` will wait for the threads to exit before returning. + }); + }); +} diff --git a/src/lib.rs b/src/lib.rs index 0346b728c..eefb9d6ba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -118,6 +118,7 @@ mod par_either; mod compile_fail; pub use rayon_core::FnContext; +pub use rayon_core::ThreadBuilder; pub use rayon_core::ThreadPool; pub use rayon_core::ThreadPoolBuildError; pub use rayon_core::ThreadPoolBuilder;