diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index f9b5f278..c7a48ac6 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -9,6 +9,9 @@ use std::collections::HashMap; use std::sync::atomic::{AtomicI64, Ordering}; use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime}; +use tokio::select; +use tokio::sync::Notify; +use tokio::time::interval; pub struct CachedCounterValue { value: AtomicExpiringValue, @@ -18,12 +21,16 @@ pub struct CachedCounterValue { pub struct Batcher { updates: Mutex>>, + notifier: Notify, + interval: Duration, } impl Batcher { - fn new() -> Self { + fn new(period: Duration) -> Self { Self { updates: Mutex::new(Default::default()), + notifier: Default::default(), + interval: period, } } @@ -31,6 +38,21 @@ impl Batcher { self.updates.lock().unwrap().is_empty() } + pub async fn consume(&self, min: usize) -> HashMap> { + let mut interval = interval(self.interval); + let mut ready = self.updates.lock().unwrap().len() >= min; + loop { + if ready { + return self.consume_all(); + } else { + ready = select! { + _ = self.notifier.notified() => self.updates.lock().unwrap().len() >= min, + _ = interval.tick() => true, + } + } + } + } + pub fn consume_all(&self) -> HashMap> { let mut batch = self.updates.lock().unwrap(); std::mem::take(&mut *batch) @@ -38,12 +60,13 @@ impl Batcher { pub fn add(&self, counter: Counter, value: Arc) { self.updates.lock().unwrap().entry(counter).or_insert(value); + self.notifier.notify_one(); } } impl Default for Batcher { fn default() -> Self { - Self::new() + Self::new(Duration::from_millis(100)) } } @@ -164,12 +187,12 @@ impl CountersCacheBuilder { self } - pub fn build(&self) -> CountersCache { + pub fn build(&self, period: Duration) -> CountersCache { CountersCache { max_ttl_cached_counters: self.max_ttl_cached_counters, ttl_ratio_cached_counters: self.ttl_ratio_cached_counters, cache: Cache::new(self.max_cached_counters as u64), - batcher: Default::default(), + batcher: Batcher::new(period), } } } @@ -294,7 +317,7 @@ mod tests { values, ); - let cache = CountersCacheBuilder::new().build(); + let cache = CountersCacheBuilder::new().build(Duration::default()); cache.insert( counter.clone(), Some(10), @@ -321,7 +344,7 @@ mod tests { values, ); - let cache = CountersCacheBuilder::new().build(); + let cache = CountersCacheBuilder::new().build(Duration::default()); assert!(cache.get(&counter).is_none()); } @@ -343,7 +366,7 @@ mod tests { values, ); - let cache = CountersCacheBuilder::new().build(); + let cache = CountersCacheBuilder::new().build(Duration::default()); cache.insert( counter.clone(), Some(current_value), @@ -374,7 +397,7 @@ mod tests { values, ); - let cache = CountersCacheBuilder::new().build(); + let cache = CountersCacheBuilder::new().build(Duration::default()); cache.insert( counter.clone(), None, @@ -403,7 +426,7 @@ mod tests { values, ); - let cache = CountersCacheBuilder::new().build(); + let cache = CountersCacheBuilder::new().build(Duration::default()); cache.insert( counter.clone(), Some(current_val), diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index a7aba7cc..634c36ad 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -211,7 +211,7 @@ impl CachedRedisStorage { .max_cached_counters(max_cached_counters) .max_ttl_cached_counter(ttl_cached_counters) .ttl_ratio_cached_counter(ttl_ratio_cached_counters) - .build(); + .build(flushing_period); let counters_cache = Arc::new(cached_counters); let partitioned = Arc::new(AtomicBool::new(false)); @@ -422,7 +422,7 @@ async fn flush_batcher_and_update_counters( flip_partitioned(&partitioned, false); } } else { - let counters = cached_counters.batcher().consume_all(); + let counters = cached_counters.batcher().consume(1).await; let time_start_update_counters = Instant::now(); @@ -560,7 +560,7 @@ mod tests { Ok(mock_response.clone()), )]); - let cache = CountersCacheBuilder::new().build(); + let cache = CountersCacheBuilder::new().build(Duration::from_millis(1)); cache.batcher().add( counter.clone(), Arc::new(CachedCounterValue::from(