diff --git a/src/cached.rs b/src/cached.rs index 96d9795..6f36d5c 100644 --- a/src/cached.rs +++ b/src/cached.rs @@ -1,5 +1,5 @@ -use crate::runtime::{yield_now, Arc, Mutex}; -use crate::BatchFn; +use crate::runtime::{Arc, Mutex}; +use crate::{yield_fn, BatchFn, WaitForWorkFn}; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use std::hash::{BuildHasher, Hash}; @@ -72,7 +72,7 @@ where { state: Arc>>, load_fn: Arc>, - yield_count: usize, + wait_for_work_fn: Arc, max_batch_size: usize, } @@ -88,7 +88,7 @@ where state: self.state.clone(), max_batch_size: self.max_batch_size, load_fn: self.load_fn.clone(), - yield_count: self.yield_count, + wait_for_work_fn: self.wait_for_work_fn.clone(), } } } @@ -117,7 +117,7 @@ where state: Arc::new(Mutex::new(State::with_cache(cache))), load_fn: Arc::new(Mutex::new(load_fn)), max_batch_size: 200, - yield_count: 10, + wait_for_work_fn: Arc::new(yield_fn(10)), } } @@ -127,10 +127,17 @@ where } pub fn with_yield_count(mut self, yield_count: usize) -> Self { - self.yield_count = yield_count; + self.wait_for_work_fn = Arc::new(yield_fn(yield_count)); self } + /// Replaces the yielding for work behavior with an arbitrary future. Rather than yielding + /// the runtime repeatedly this will generate and `.await` a future of your choice. + /// ***This is incompatible with*** [`Self::with_yield_count()`]. + pub fn with_custom_wait_for_work(mut self, wait_for_work_fn: impl WaitForWorkFn) { + self.wait_for_work_fn = Arc::new(wait_for_work_fn); + } + pub fn max_batch_size(&self) -> usize { self.max_batch_size } @@ -141,7 +148,7 @@ where return Ok((*v).clone()); } - if state.pending.get(&key).is_none() { + if !state.pending.contains(&key) { state.pending.insert(key.clone()); if state.pending.len() >= self.max_batch_size { let keys = state.pending.drain().collect::>(); @@ -159,12 +166,7 @@ where } drop(state); - // yield for other load to append request - let mut i = 0; - while i < self.yield_count { - yield_now().await; - i += 1; - } + (self.wait_for_work_fn)().await; let mut state = self.state.lock().await; if let Some(v) = state.completed.get(&key) { @@ -200,7 +202,7 @@ where ret.insert(key, v); continue; } - if state.pending.get(&key).is_none() { + if !state.pending.contains(&key) { state.pending.insert(key.clone()); if state.pending.len() >= self.max_batch_size { let keys = state.pending.drain().collect::>(); @@ -216,12 +218,7 @@ where } drop(state); - // yield for other load to append request - let mut i = 0; - while i < self.yield_count { - yield_now().await; - i += 1; - } + (self.wait_for_work_fn)().await; if !rest.is_empty() { let mut state = self.state.lock().await; diff --git a/src/lib.rs b/src/lib.rs index e27423c..63a784d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,3 +4,27 @@ pub mod non_cached; mod runtime; pub use batch_fn::BatchFn; + +use std::{future::Future, pin::Pin}; + +/// A trait alias. Read as "a function which returns a pinned box containing a future" +pub trait WaitForWorkFn: + Fn() -> Pin + Send + Sync>> + Send + Sync + 'static +{ +} + +impl WaitForWorkFn for T where + T: Fn() -> Pin + Send + Sync>> + Send + Sync + 'static +{ +} + +pub(crate) fn yield_fn(count: usize) -> impl WaitForWorkFn { + move || { + Box::pin(async move { + // yield for other load to append request + for _ in 0..count { + runtime::yield_now().await; + } + }) + } +} diff --git a/src/non_cached.rs b/src/non_cached.rs index 6d332ca..c8fda33 100644 --- a/src/non_cached.rs +++ b/src/non_cached.rs @@ -1,5 +1,5 @@ -use crate::runtime::{yield_now, Arc, Mutex}; -use crate::BatchFn; +use crate::runtime::{Arc, Mutex}; +use crate::{yield_fn, BatchFn, WaitForWorkFn}; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use std::hash::Hash; @@ -37,7 +37,7 @@ where { state: Arc>>, load_fn: Arc>, - yield_count: usize, + wait_for_work_fn: Arc, max_batch_size: usize, } @@ -52,7 +52,7 @@ where state: self.state.clone(), load_fn: self.load_fn.clone(), max_batch_size: self.max_batch_size, - yield_count: self.yield_count, + wait_for_work_fn: self.wait_for_work_fn.clone(), } } } @@ -68,7 +68,7 @@ where state: Arc::new(Mutex::new(State::new())), load_fn: Arc::new(Mutex::new(load_fn)), max_batch_size: 200, - yield_count: 10, + wait_for_work_fn: Arc::new(yield_fn(10)), } } @@ -78,10 +78,17 @@ where } pub fn with_yield_count(mut self, yield_count: usize) -> Self { - self.yield_count = yield_count; + self.wait_for_work_fn = Arc::new(yield_fn(yield_count)); self } + /// Replaces the yielding for work behavior with an arbitrary future. Rather than yielding + /// the runtime repeatedly this will generate and `.await` a future of your choice. + /// ***This is incompatible with*** [`Self::with_yield_count()`]. + pub fn with_custom_wait_for_work(mut self, wait_for_work_fn: impl WaitForWorkFn) { + self.wait_for_work_fn = Arc::new(wait_for_work_fn); + } + pub fn max_batch_size(&self) -> usize { self.max_batch_size } @@ -122,16 +129,11 @@ where } drop(state); - // yield for other load to append request - let mut i = 0; - while i < self.yield_count { - yield_now().await; - i += 1; - } + (self.wait_for_work_fn)().await; let mut state = self.state.lock().await; - if state.completed.get(&request_id).is_none() { + if !state.completed.contains_key(&request_id) { let batch = state.pending.drain().collect::>(); if !batch.is_empty() { let keys: Vec = batch @@ -208,12 +210,7 @@ where drop(state); - // yield for other load to append request - let mut i = 0; - while i < self.yield_count { - yield_now().await; - i += 1; - } + (self.wait_for_work_fn)().await; let mut state = self.state.lock().await;