From 8c48bff3d28311ec19b46c0990b8dd7ba4597ea5 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Thu, 25 Apr 2024 08:50:26 -0400 Subject: [PATCH 01/10] Support pending writes within CachedCounterValue --- limitador/src/storage/redis/counters_cache.rs | 45 ++++++++++++++++++- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 2f2b22ca..a367a575 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -5,11 +5,13 @@ use crate::storage::redis::{ DEFAULT_TTL_RATIO_CACHED_COUNTERS, }; use moka::sync::Cache; +use std::sync::atomic::{AtomicI64, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime}; pub struct CachedCounterValue { value: AtomicExpiringValue, + initial_value: AtomicI64, expiry: AtomicExpiryTime, } @@ -24,6 +26,7 @@ impl CachedCounterValue { let now = SystemTime::now(); Self { value: AtomicExpiringValue::new(value, now + Duration::from_secs(counter.seconds())), + initial_value: AtomicI64::new(value), expiry: AtomicExpiryTime::from_now(ttl), } } @@ -39,8 +42,46 @@ impl CachedCounterValue { } pub fn delta(&self, counter: &Counter, delta: i64) -> i64 { - self.value - .update(delta, counter.seconds(), SystemTime::now()) + let value = self + .value + .update(delta, counter.seconds(), SystemTime::now()); + if value == delta { + // new window, invalidate initial value + self.initial_value.store(0, Ordering::SeqCst); + } + value + } + + #[allow(dead_code)] + pub fn pending_writes(&self) -> Result { + let start = self.initial_value.load(Ordering::SeqCst); + let value = self.value.value_at(SystemTime::now()); + let offset = if start == 0 { + value + } else { + let writes = value - start; + if writes > 0 { + writes + } else { + value + } + }; + match self + .initial_value + .compare_exchange(start, value, Ordering::SeqCst, Ordering::SeqCst) + { + Ok(_) => Ok(offset), + Err(newer) => { + if newer == 0 { + // We got expired in the meantime, this fresh value can wait the next iteration + Ok(0) + } else { + // Concurrent call to this method? 
+ // We could support that with a CAS loop in the future if needed + Err(()) + } + } + } } pub fn hits(&self, _: &Counter) -> i64 { From df3b5569addcd57e1ce095ba7392d98971defc9d Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Thu, 25 Apr 2024 15:07:50 -0400 Subject: [PATCH 02/10] Use the same CachedCounterValue in both batcher and cache --- limitador/src/storage/redis/counters_cache.rs | 30 ++++++--- limitador/src/storage/redis/redis_cached.rs | 63 ++++++++----------- 2 files changed, 50 insertions(+), 43 deletions(-) diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index a367a575..7ea1ec10 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -5,8 +5,9 @@ use crate::storage::redis::{ DEFAULT_TTL_RATIO_CACHED_COUNTERS, }; use moka::sync::Cache; +use std::collections::HashMap; use std::sync::atomic::{AtomicI64, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, MutexGuard}; use std::time::{Duration, SystemTime}; pub struct CachedCounterValue { @@ -37,6 +38,7 @@ impl CachedCounterValue { pub fn set_from_authority(&self, counter: &Counter, value: i64, expiry: Duration) { let time_window = Duration::from_secs(counter.seconds()); + self.initial_value.store(value, Ordering::SeqCst); self.value.set(value, time_window); self.expiry.update(expiry); } @@ -52,7 +54,6 @@ impl CachedCounterValue { value } - #[allow(dead_code)] pub fn pending_writes(&self) -> Result { let start = self.initial_value.load(Ordering::SeqCst); let value = self.value.value_at(SystemTime::now()); @@ -178,10 +179,25 @@ impl CountersCache { )) } - pub fn increase_by(&self, counter: &Counter, delta: i64) { - if let Some(val) = self.cache.get(counter) { - val.delta(counter, delta); - }; + pub fn increase_by( + &self, + counter: &Counter, + delta: i64, + batcher: Option<&mut MutexGuard>>>, + ) { + let val = self.cache.get_with_by_ref(counter, || { + Arc::new( + // this TTL is wrong, it needs to be the cache's TTL, not the time window of our limit + // todo fix when introducing the Batcher type! 
+ CachedCounterValue::from(counter, 0, Duration::from_secs(counter.seconds())), + ) + }); + val.delta(counter, delta); + if let Some(batcher) = batcher { + if batcher.get_mut(counter).is_none() { + batcher.insert(counter.clone(), val.clone()); + } + } } fn ttl_from_redis_ttl( @@ -367,7 +383,7 @@ mod tests { Duration::from_secs(0), SystemTime::now(), ); - cache.increase_by(&counter, increase_by); + cache.increase_by(&counter, increase_by, None); assert_eq!( cache.get(&counter).map(|e| e.hits(&counter)).unwrap(), diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 3f9e3da6..364fb7b8 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -1,8 +1,9 @@ use crate::counter::Counter; use crate::limit::Limit; -use crate::storage::atomic_expiring_value::AtomicExpiringValue; use crate::storage::keys::*; -use crate::storage::redis::counters_cache::{CountersCache, CountersCacheBuilder}; +use crate::storage::redis::counters_cache::{ + CachedCounterValue, CountersCache, CountersCacheBuilder, +}; use crate::storage::redis::redis_async::AsyncRedisStorage; use crate::storage::redis::scripts::{BATCH_UPDATE_COUNTERS, VALUES_AND_TTLS}; use crate::storage::redis::{ @@ -40,7 +41,7 @@ use tracing::{debug_span, error, warn, Instrument}; pub struct CachedRedisStorage { cached_counters: Arc, - batcher_counter_updates: Arc>>, + batcher_counter_updates: Arc>>>, async_redis_storage: AsyncRedisStorage, redis_conn_manager: ConnectionManager, partitioned: Arc, @@ -150,28 +151,10 @@ impl AsyncCounterStorage for CachedRedisStorage { } // Update cached values - for counter in counters.iter() { - self.cached_counters.increase_by(counter, delta); - } - - // Batch or update depending on configuration let mut batcher = self.batcher_counter_updates.lock().unwrap(); - let now = SystemTime::now(); for counter in counters.iter() { - match batcher.get_mut(counter) { - Some(val) => { - val.update(delta, counter.seconds(), now); - } - None => { - batcher.insert( - counter.clone(), - AtomicExpiringValue::new( - delta, - now + Duration::from_secs(counter.seconds()), - ), - ); - } - } + self.cached_counters + .increase_by(counter, delta, Some(&mut batcher)); } Ok(Authorization::Ok) @@ -237,7 +220,7 @@ impl CachedRedisStorage { let partitioned = Arc::new(AtomicBool::new(false)); let async_redis_storage = AsyncRedisStorage::new_with_conn_manager(redis_conn_manager.clone()); - let batcher: Arc>> = + let batcher: Arc>>> = Arc::new(Mutex::new(Default::default())); { @@ -398,15 +381,14 @@ impl CachedRedisStorageBuilder { async fn update_counters( redis_conn: &mut C, - counters_and_deltas: HashMap, + counters_and_deltas: HashMap>, ) -> Result, StorageErr> { let redis_script = redis::Script::new(BATCH_UPDATE_COUNTERS); let mut script_invocation = redis_script.prepare_invoke(); let mut res: Vec<(Counter, i64, i64)> = Vec::new(); - let now = SystemTime::now(); for (counter, delta) in counters_and_deltas { - let delta = delta.value_at(now); + let delta = delta.pending_writes().expect("State machine is wrong!"); if delta > 0 { script_invocation.key(key_for_counter(&counter)); script_invocation.key(key_for_counters_of_limit(counter.limit())); @@ -439,7 +421,7 @@ async fn update_counters( async fn flush_batcher_and_update_counters( mut redis_conn: C, - batcher: Arc>>, + batcher: Arc>>>, storage_is_alive: bool, cached_counters: Arc, partitioned: Arc, @@ -487,9 +469,10 @@ async fn flush_batcher_and_update_counters( mod tests { use 
crate::counter::Counter; use crate::limit::Limit; - use crate::storage::atomic_expiring_value::AtomicExpiringValue; use crate::storage::keys::{key_for_counter, key_for_counters_of_limit}; - use crate::storage::redis::counters_cache::{CountersCache, CountersCacheBuilder}; + use crate::storage::redis::counters_cache::{ + CachedCounterValue, CountersCache, CountersCacheBuilder, + }; use crate::storage::redis::redis_cached::{flush_batcher_and_update_counters, update_counters}; use crate::storage::redis::CachedRedisStorage; use redis::{ErrorKind, Value}; @@ -529,10 +512,14 @@ mod tests { Default::default(), ); - let expiring_value = - AtomicExpiringValue::new(1, SystemTime::now() + Duration::from_secs(60)); - - counters_and_deltas.insert(counter.clone(), expiring_value); + counters_and_deltas.insert( + counter.clone(), + Arc::new(CachedCounterValue::from( + &counter, + 1, + Duration::from_secs(60), + )), + ); let mock_response = Value::Bulk(vec![Value::Int(10), Value::Int(60)]); @@ -589,10 +576,14 @@ mod tests { let mut batched_counters = HashMap::new(); batched_counters.insert( counter.clone(), - AtomicExpiringValue::new(2, SystemTime::now() + Duration::from_secs(60)), + Arc::new(CachedCounterValue::from( + &counter, + 2, + Duration::from_secs(60), + )), ); - let batcher: Arc>> = + let batcher: Arc>>> = Arc::new(Mutex::new(batched_counters)); let cache = CountersCacheBuilder::new().build(); cache.insert( From f3a7f077cb0129b52902549485fab15d2c38231c Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Thu, 25 Apr 2024 15:27:43 -0400 Subject: [PATCH 03/10] Added Batcher type back --- limitador/src/storage/redis/counters_cache.rs | 54 ++++++++++++++----- limitador/src/storage/redis/redis_cached.rs | 41 ++++---------- 2 files changed, 50 insertions(+), 45 deletions(-) diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 7ea1ec10..3074fe01 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -7,7 +7,7 @@ use crate::storage::redis::{ use moka::sync::Cache; use std::collections::HashMap; use std::sync::atomic::{AtomicI64, Ordering}; -use std::sync::{Arc, MutexGuard}; +use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime}; pub struct CachedCounterValue { @@ -16,10 +16,42 @@ pub struct CachedCounterValue { expiry: AtomicExpiryTime, } +pub struct Batcher { + updates: Mutex>>, +} + +impl Batcher { + fn new() -> Self { + Self { + updates: Mutex::new(Default::default()), + } + } + + pub fn is_empty(&self) -> bool { + self.updates.lock().unwrap().is_empty() + } + + pub fn consume_all(&self) -> HashMap> { + let mut batch = self.updates.lock().unwrap(); + std::mem::take(&mut *batch) + } + + pub fn add(&self, counter: Counter, value: Arc) { + self.updates.lock().unwrap().entry(counter).or_insert(value); + } +} + +impl Default for Batcher { + fn default() -> Self { + Self::new() + } +} + pub struct CountersCache { max_ttl_cached_counters: Duration, pub ttl_ratio_cached_counters: u64, cache: Cache>, + batcher: Batcher, } impl CachedCounterValue { @@ -137,6 +169,7 @@ impl CountersCacheBuilder { max_ttl_cached_counters: self.max_ttl_cached_counters, ttl_ratio_cached_counters: self.ttl_ratio_cached_counters, cache: Cache::new(self.max_cached_counters as u64), + batcher: Default::default(), } } } @@ -146,6 +179,10 @@ impl CountersCache { self.cache.get(counter) } + pub fn batcher(&self) -> &Batcher { + &self.batcher + } + pub fn insert( &self, counter: Counter, @@ -179,12 +216,7 @@ impl 
CountersCache { )) } - pub fn increase_by( - &self, - counter: &Counter, - delta: i64, - batcher: Option<&mut MutexGuard>>>, - ) { + pub fn increase_by(&self, counter: &Counter, delta: i64) { let val = self.cache.get_with_by_ref(counter, || { Arc::new( // this TTL is wrong, it needs to be the cache's TTL, not the time window of our limit @@ -193,11 +225,7 @@ impl CountersCache { ) }); val.delta(counter, delta); - if let Some(batcher) = batcher { - if batcher.get_mut(counter).is_none() { - batcher.insert(counter.clone(), val.clone()); - } - } + self.batcher.add(counter.clone(), val.clone()); } fn ttl_from_redis_ttl( @@ -383,7 +411,7 @@ mod tests { Duration::from_secs(0), SystemTime::now(), ); - cache.increase_by(&counter, increase_by, None); + cache.increase_by(&counter, increase_by); assert_eq!( cache.get(&counter).map(|e| e.hits(&counter)).unwrap(), diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 364fb7b8..a7aba7cc 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -17,7 +17,7 @@ use redis::{ConnectionInfo, RedisError}; use std::collections::{HashMap, HashSet}; use std::str::FromStr; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::{Duration, Instant, SystemTime}; use tracing::{debug_span, error, warn, Instrument}; @@ -41,7 +41,6 @@ use tracing::{debug_span, error, warn, Instrument}; pub struct CachedRedisStorage { cached_counters: Arc, - batcher_counter_updates: Arc>>>, async_redis_storage: AsyncRedisStorage, redis_conn_manager: ConnectionManager, partitioned: Arc, @@ -151,10 +150,8 @@ impl AsyncCounterStorage for CachedRedisStorage { } // Update cached values - let mut batcher = self.batcher_counter_updates.lock().unwrap(); for counter in counters.iter() { - self.cached_counters - .increase_by(counter, delta, Some(&mut batcher)); + self.cached_counters.increase_by(counter, delta); } Ok(Authorization::Ok) @@ -220,21 +217,17 @@ impl CachedRedisStorage { let partitioned = Arc::new(AtomicBool::new(false)); let async_redis_storage = AsyncRedisStorage::new_with_conn_manager(redis_conn_manager.clone()); - let batcher: Arc>>> = - Arc::new(Mutex::new(Default::default())); { let storage = async_redis_storage.clone(); let counters_cache_clone = counters_cache.clone(); let conn = redis_conn_manager.clone(); let p = Arc::clone(&partitioned); - let batcher_flusher = batcher.clone(); let mut interval = tokio::time::interval(flushing_period); tokio::spawn(async move { loop { flush_batcher_and_update_counters( conn.clone(), - batcher_flusher.clone(), storage.is_alive().await, counters_cache_clone.clone(), p.clone(), @@ -247,7 +240,6 @@ impl CachedRedisStorage { Ok(Self { cached_counters: counters_cache, - batcher_counter_updates: batcher, redis_conn_manager, async_redis_storage, partitioned, @@ -421,21 +413,16 @@ async fn update_counters( async fn flush_batcher_and_update_counters( mut redis_conn: C, - batcher: Arc>>>, storage_is_alive: bool, cached_counters: Arc, partitioned: Arc, ) { if partitioned.load(Ordering::Acquire) || !storage_is_alive { - let batch = batcher.lock().unwrap(); - if !batch.is_empty() { + if !cached_counters.batcher().is_empty() { flip_partitioned(&partitioned, false); } } else { - let counters = { - let mut batch = batcher.lock().unwrap(); - std::mem::take(&mut *batch) - }; + let counters = cached_counters.batcher().consume_all(); let time_start_update_counters = Instant::now(); @@ -479,7 +466,7 @@ 
mod tests { use redis_test::{MockCmd, MockRedisConnection}; use std::collections::HashMap; use std::sync::atomic::AtomicBool; - use std::sync::{Arc, Mutex}; + use std::sync::Arc; use std::time::{Duration, SystemTime}; #[tokio::test] @@ -573,8 +560,8 @@ mod tests { Ok(mock_response.clone()), )]); - let mut batched_counters = HashMap::new(); - batched_counters.insert( + let cache = CountersCacheBuilder::new().build(); + cache.batcher().add( counter.clone(), Arc::new(CachedCounterValue::from( &counter, @@ -582,10 +569,6 @@ mod tests { Duration::from_secs(60), )), ); - - let batcher: Arc>>> = - Arc::new(Mutex::new(batched_counters)); - let cache = CountersCacheBuilder::new().build(); cache.insert( counter.clone(), Some(1), @@ -600,14 +583,8 @@ mod tests { assert_eq!(c.hits(&counter), 1); } - flush_batcher_and_update_counters( - mock_client, - batcher, - true, - cached_counters.clone(), - partitioned, - ) - .await; + flush_batcher_and_update_counters(mock_client, true, cached_counters.clone(), partitioned) + .await; if let Some(c) = cached_counters.get(&counter) { assert_eq!(c.hits(&counter), 8); From 620c0215782c95cc7877553e74f668148b5dc643 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Thu, 25 Apr 2024 16:38:18 -0400 Subject: [PATCH 04/10] Async flush either periodically or on batch size being reached --- limitador/src/storage/redis/counters_cache.rs | 41 +++++++++++++++---- limitador/src/storage/redis/redis_cached.rs | 6 +-- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 3074fe01..ba3ba348 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -9,6 +9,9 @@ use std::collections::HashMap; use std::sync::atomic::{AtomicI64, Ordering}; use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime}; +use tokio::select; +use tokio::sync::Notify; +use tokio::time::interval; pub struct CachedCounterValue { value: AtomicExpiringValue, @@ -18,12 +21,16 @@ pub struct CachedCounterValue { pub struct Batcher { updates: Mutex>>, + notifier: Notify, + interval: Duration, } impl Batcher { - fn new() -> Self { + fn new(period: Duration) -> Self { Self { updates: Mutex::new(Default::default()), + notifier: Default::default(), + interval: period, } } @@ -31,6 +38,21 @@ impl Batcher { self.updates.lock().unwrap().is_empty() } + pub async fn consume(&self, min: usize) -> HashMap> { + let mut interval = interval(self.interval); + let mut ready = self.updates.lock().unwrap().len() >= min; + loop { + if ready { + return self.consume_all(); + } else { + ready = select! 
{ + _ = self.notifier.notified() => self.updates.lock().unwrap().len() >= min, + _ = interval.tick() => true, + } + } + } + } + pub fn consume_all(&self) -> HashMap> { let mut batch = self.updates.lock().unwrap(); std::mem::take(&mut *batch) @@ -38,12 +60,13 @@ impl Batcher { pub fn add(&self, counter: Counter, value: Arc) { self.updates.lock().unwrap().entry(counter).or_insert(value); + self.notifier.notify_one(); } } impl Default for Batcher { fn default() -> Self { - Self::new() + Self::new(Duration::from_millis(100)) } } @@ -164,12 +187,12 @@ impl CountersCacheBuilder { self } - pub fn build(&self) -> CountersCache { + pub fn build(&self, period: Duration) -> CountersCache { CountersCache { max_ttl_cached_counters: self.max_ttl_cached_counters, ttl_ratio_cached_counters: self.ttl_ratio_cached_counters, cache: Cache::new(self.max_cached_counters as u64), - batcher: Default::default(), + batcher: Batcher::new(period), } } } @@ -294,7 +317,7 @@ mod tests { values, ); - let cache = CountersCacheBuilder::new().build(); + let cache = CountersCacheBuilder::new().build(Duration::default()); cache.insert( counter.clone(), Some(10), @@ -321,7 +344,7 @@ mod tests { values, ); - let cache = CountersCacheBuilder::new().build(); + let cache = CountersCacheBuilder::new().build(Duration::default()); assert!(cache.get(&counter).is_none()); } @@ -343,7 +366,7 @@ mod tests { values, ); - let cache = CountersCacheBuilder::new().build(); + let cache = CountersCacheBuilder::new().build(Duration::default()); cache.insert( counter.clone(), Some(current_value), @@ -374,7 +397,7 @@ mod tests { values, ); - let cache = CountersCacheBuilder::new().build(); + let cache = CountersCacheBuilder::new().build(Duration::default()); cache.insert( counter.clone(), None, @@ -403,7 +426,7 @@ mod tests { values, ); - let cache = CountersCacheBuilder::new().build(); + let cache = CountersCacheBuilder::new().build(Duration::default()); cache.insert( counter.clone(), Some(current_val), diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index a7aba7cc..634c36ad 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -211,7 +211,7 @@ impl CachedRedisStorage { .max_cached_counters(max_cached_counters) .max_ttl_cached_counter(ttl_cached_counters) .ttl_ratio_cached_counter(ttl_ratio_cached_counters) - .build(); + .build(flushing_period); let counters_cache = Arc::new(cached_counters); let partitioned = Arc::new(AtomicBool::new(false)); @@ -422,7 +422,7 @@ async fn flush_batcher_and_update_counters( flip_partitioned(&partitioned, false); } } else { - let counters = cached_counters.batcher().consume_all(); + let counters = cached_counters.batcher().consume(1).await; let time_start_update_counters = Instant::now(); @@ -560,7 +560,7 @@ mod tests { Ok(mock_response.clone()), )]); - let cache = CountersCacheBuilder::new().build(); + let cache = CountersCacheBuilder::new().build(Duration::from_millis(1)); cache.batcher().add( counter.clone(), Arc::new(CachedCounterValue::from( From 826b8a98a234fe63b495adf9ff10214682618897 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Fri, 26 Apr 2024 12:22:43 -0400 Subject: [PATCH 05/10] Do the priority dance --- Cargo.lock | 1 + limitador/Cargo.toml | 1 + limitador/src/storage/redis/counters_cache.rs | 89 +++++++++++++++---- limitador/src/storage/redis/redis_cached.rs | 11 +-- 4 files changed, 80 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 301842c1..7a553227 100644 --- 
a/Cargo.lock +++ b/Cargo.lock @@ -1477,6 +1477,7 @@ dependencies = [ "base64 0.22.0", "cfg-if", "criterion", + "dashmap", "futures", "getrandom", "infinispan", diff --git a/limitador/Cargo.toml b/limitador/Cargo.toml index 2d6646c6..7f791fe8 100644 --- a/limitador/Cargo.toml +++ b/limitador/Cargo.toml @@ -22,6 +22,7 @@ lenient_conditions = [] [dependencies] moka = { version = "0.12", features = ["sync"] } +dashmap = "5.5.3" getrandom = { version = "0.2", features = ["js"] } serde = { version = "1", features = ["derive"] } postcard = { version = "1.0.4", features = ["use-std"] } diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index ba3ba348..daad510a 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -4,10 +4,12 @@ use crate::storage::redis::{ DEFAULT_MAX_CACHED_COUNTERS, DEFAULT_MAX_TTL_CACHED_COUNTERS_SEC, DEFAULT_TTL_RATIO_CACHED_COUNTERS, }; +use dashmap::DashMap; use moka::sync::Cache; use std::collections::HashMap; -use std::sync::atomic::{AtomicI64, Ordering}; -use std::sync::{Arc, Mutex}; +use std::future::Future; +use std::sync::atomic::{AtomicBool, AtomicI64, Ordering}; +use std::sync::Arc; use std::time::{Duration, SystemTime}; use tokio::select; use tokio::sync::Notify; @@ -20,46 +22,91 @@ pub struct CachedCounterValue { } pub struct Batcher { - updates: Mutex>>, + updates: DashMap>, notifier: Notify, interval: Duration, + priority_flush: AtomicBool, } impl Batcher { fn new(period: Duration) -> Self { Self { - updates: Mutex::new(Default::default()), + updates: Default::default(), notifier: Default::default(), interval: period, + priority_flush: AtomicBool::new(false), } } pub fn is_empty(&self) -> bool { - self.updates.lock().unwrap().is_empty() + self.updates.is_empty() } - pub async fn consume(&self, min: usize) -> HashMap> { + pub async fn consume(&self, min: usize, consumer: F) -> O + where + F: FnOnce(HashMap>) -> Fut, + Fut: Future, + { let mut interval = interval(self.interval); - let mut ready = self.updates.lock().unwrap().len() >= min; + let mut ready = self.updates.len() >= min; loop { if ready { - return self.consume_all(); + let mut batch = Vec::with_capacity(min); + let mut probably_fake = Vec::with_capacity(min); + for entry in &self.updates { + if entry.value().value.ttl() < self.interval { + batch.push(entry.key().clone()); + if batch.len() == min { + break; + } + } + if entry.value().expiry.duration() == Duration::from_secs(entry.key().seconds()) + { + probably_fake.push(entry.key().clone()); + if probably_fake.len() == min { + break; + } + } + } + if let Some(remaining) = min.checked_sub(batch.len()) { + let take = probably_fake.into_iter().take(remaining); + batch.append(&mut take.collect()); + } + if let Some(remaining) = min.checked_sub(batch.len()) { + let take = self.updates.iter().take(remaining); + batch.append(&mut take.map(|e| e.key().clone()).collect()); + } + let mut result = HashMap::new(); + for counter in &batch { + let value = self.updates.get(counter).unwrap().clone(); + result.insert(counter.clone(), value); + } + let result = consumer(result).await; + for counter in &batch { + self.updates + .remove_if(counter, |_, v| v.no_pending_writes()); + } + return result; } else { ready = select! 
{ - _ = self.notifier.notified() => self.updates.lock().unwrap().len() >= min, + _ = self.notifier.notified() => { + self.updates.len() >= min || + self.priority_flush + .compare_exchange(true, false, Ordering::Release, Ordering::Acquire) + .is_ok() + }, _ = interval.tick() => true, } } } } - pub fn consume_all(&self) -> HashMap> { - let mut batch = self.updates.lock().unwrap(); - std::mem::take(&mut *batch) - } - - pub fn add(&self, counter: Counter, value: Arc) { - self.updates.lock().unwrap().entry(counter).or_insert(value); + pub fn add(&self, counter: Counter, value: Arc, priority: bool) { + let priority = priority || value.value.ttl() < self.interval; + self.updates.entry(counter).or_insert(value); + if priority { + self.priority_flush.store(true, Ordering::Release); + } self.notifier.notify_one(); } } @@ -140,6 +187,12 @@ impl CachedCounterValue { } } + fn no_pending_writes(&self) -> bool { + let start = self.initial_value.load(Ordering::SeqCst); + let value = self.value.value_at(SystemTime::now()); + value - start == 0 + } + pub fn hits(&self, _: &Counter) -> i64 { self.value.value_at(SystemTime::now()) } @@ -240,7 +293,9 @@ impl CountersCache { } pub fn increase_by(&self, counter: &Counter, delta: i64) { + let mut priority = false; let val = self.cache.get_with_by_ref(counter, || { + priority = true; Arc::new( // this TTL is wrong, it needs to be the cache's TTL, not the time window of our limit // todo fix when introducing the Batcher type! @@ -248,7 +303,7 @@ impl CountersCache { ) }); val.delta(counter, delta); - self.batcher.add(counter.clone(), val.clone()); + self.batcher.add(counter.clone(), val.clone(), priority); } fn ttl_from_redis_ttl( diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 634c36ad..ce9068cd 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -422,11 +422,9 @@ async fn flush_batcher_and_update_counters( flip_partitioned(&partitioned, false); } } else { - let counters = cached_counters.batcher().consume(1).await; - - let time_start_update_counters = Instant::now(); - - let updated_counters = update_counters(&mut redis_conn, counters) + let updated_counters = cached_counters + .batcher() + .consume(1, |counters| update_counters(&mut redis_conn, counters)) .await .or_else(|err| { if err.is_transient() { @@ -438,6 +436,8 @@ async fn flush_batcher_and_update_counters( }) .expect("Unrecoverable Redis error!"); + let time_start_update_counters = Instant::now(); + for (counter, value, ttl) in updated_counters { cached_counters.insert( counter, @@ -568,6 +568,7 @@ mod tests { 2, Duration::from_secs(60), )), + false, ); cache.insert( counter.clone(), From 87f69a73051264f97fcd9e4ccded7cd749723a28 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Fri, 26 Apr 2024 16:54:49 -0400 Subject: [PATCH 06/10] no need for the outer interval --- limitador/src/storage/redis/redis_cached.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index ce9068cd..b1596b67 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -223,7 +223,6 @@ impl CachedRedisStorage { let counters_cache_clone = counters_cache.clone(); let conn = redis_conn_manager.clone(); let p = Arc::clone(&partitioned); - let mut interval = tokio::time::interval(flushing_period); tokio::spawn(async move { loop { flush_batcher_and_update_counters( @@ -233,7 +232,6 
@@ impl CachedRedisStorage { p.clone(), ) .await; - interval.tick().await; } }); } From 9d2b2c966b52fb27119a736e9d82fdcc356ca48f Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Fri, 26 Apr 2024 17:47:11 -0400 Subject: [PATCH 07/10] Look up the write-behind queue on miss --- limitador/src/storage/redis/counters_cache.rs | 32 ++++++++++++++----- limitador/src/storage/redis/redis_cached.rs | 2 +- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index daad510a..79a1d2c3 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -252,7 +252,15 @@ impl CountersCacheBuilder { impl CountersCache { pub fn get(&self, counter: &Counter) -> Option<Arc<CachedCounterValue>> { - self.cache.get(counter) + let option = self.cache.get(counter); + if option.is_none() { + let from_queue = self.batcher.updates.get(counter); + if let Some(entry) = from_queue { + self.cache.insert(counter.clone(), entry.value().clone()); + return Some(entry.value().clone()) + } + } + option } pub fn batcher(&self) -> &Batcher { @@ -277,7 +285,11 @@ impl CountersCache { if let Some(ttl) = cache_ttl.checked_sub(ttl_margin) { if ttl > Duration::ZERO { let previous = self.cache.get_with(counter.clone(), || { - Arc::new(CachedCounterValue::from(&counter, counter_val, cache_ttl)) + if let Some(entry) = self.batcher.updates.get(&counter) { + entry.value().clone() + } else { + Arc::new(CachedCounterValue::from(&counter, counter_val, cache_ttl)) + } }); if previous.expired_at(now) || previous.value.value() < counter_val { previous.set_from_authority(&counter, counter_val, cache_ttl); @@ -295,12 +307,16 @@ impl CountersCache { pub fn increase_by(&self, counter: &Counter, delta: i64) { let mut priority = false; let val = self.cache.get_with_by_ref(counter, || { - priority = true; - Arc::new( - // this TTL is wrong, it needs to be the cache's TTL, not the time window of our limit - // todo fix when introducing the Batcher type! - CachedCounterValue::from(counter, 0, Duration::from_secs(counter.seconds())), - ) + if let Some(entry) = self.batcher.updates.get(&counter) { + entry.value().clone() + } else { + priority = true; + Arc::new( + // this TTL is wrong, it needs to be the cache's TTL, not the time window of our limit + // todo fix when introducing the Batcher type! 
+ CachedCounterValue::from(counter, 0, Duration::from_secs(counter.seconds())), + ) + } }); val.delta(counter, delta); self.batcher.add(counter.clone(), val.clone(), priority); diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index b1596b67..02a425ca 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -579,7 +579,7 @@ mod tests { let partitioned = Arc::new(AtomicBool::new(false)); if let Some(c) = cached_counters.get(&counter) { - assert_eq!(c.hits(&counter), 1); + assert_eq!(c.hits(&counter), 2); } flush_batcher_and_update_counters(mock_client, true, cached_counters.clone(), partitioned) From 49dd4c4707cc4373d98b4742d9056b88424b3cab Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Mon, 29 Apr 2024 16:36:26 -0400 Subject: [PATCH 08/10] Fold priority within CachedCounterValue --- limitador/src/storage/redis/counters_cache.rs | 65 ++++++++++--------- limitador/src/storage/redis/redis_cached.rs | 5 +- 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 79a1d2c3..925145f3 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -19,6 +19,7 @@ pub struct CachedCounterValue { value: AtomicExpiringValue, initial_value: AtomicI64, expiry: AtomicExpiryTime, + from_authority: AtomicBool, } pub struct Batcher { @@ -52,25 +53,13 @@ impl Batcher { loop { if ready { let mut batch = Vec::with_capacity(min); - let mut probably_fake = Vec::with_capacity(min); for entry in &self.updates { - if entry.value().value.ttl() < self.interval { + if entry.value().requires_fast_flush(&self.interval) { batch.push(entry.key().clone()); if batch.len() == min { break; } } - if entry.value().expiry.duration() == Duration::from_secs(entry.key().seconds()) - { - probably_fake.push(entry.key().clone()); - if probably_fake.len() == min { - break; - } - } - } - if let Some(remaining) = min.checked_sub(batch.len()) { - let take = probably_fake.into_iter().take(remaining); - batch.append(&mut take.collect()); } if let Some(remaining) = min.checked_sub(batch.len()) { let take = self.updates.iter().take(remaining); @@ -101,8 +90,8 @@ impl Batcher { } } - pub fn add(&self, counter: Counter, value: Arc, priority: bool) { - let priority = priority || value.value.ttl() < self.interval; + pub fn add(&self, counter: Counter, value: Arc) { + let priority = value.requires_fast_flush(&self.interval); self.updates.entry(counter).or_insert(value); if priority { self.priority_flush.store(true, Ordering::Release); @@ -125,12 +114,26 @@ pub struct CountersCache { } impl CachedCounterValue { - pub fn from(counter: &Counter, value: i64, ttl: Duration) -> Self { + pub fn from_authority(counter: &Counter, value: i64, ttl: Duration) -> Self { let now = SystemTime::now(); Self { value: AtomicExpiringValue::new(value, now + Duration::from_secs(counter.seconds())), initial_value: AtomicI64::new(value), expiry: AtomicExpiryTime::from_now(ttl), + from_authority: AtomicBool::new(true), + } + } + + pub fn load_from_authority_asap(counter: &Counter, temp_value: i64) -> Self { + let now = SystemTime::now(); + Self { + value: AtomicExpiringValue::new( + temp_value, + now + Duration::from_secs(counter.seconds()), + ), + initial_value: AtomicI64::new(temp_value), + expiry: AtomicExpiryTime::from_now(Duration::from_secs(counter.seconds())), + from_authority: AtomicBool::new(false), } } 
@@ -143,6 +146,7 @@ impl CachedCounterValue { self.initial_value.store(value, Ordering::SeqCst); self.value.set(value, time_window); self.expiry.update(expiry); + self.from_authority.store(true, Ordering::Release); } pub fn delta(&self, counter: &Counter, delta: i64) -> i64 { @@ -208,6 +212,10 @@ impl CachedCounterValue { pub fn to_next_window(&self) -> Duration { self.value.ttl() } + + pub fn requires_fast_flush(&self, within: &Duration) -> bool { + self.from_authority.load(Ordering::Acquire) || &self.value.ttl() <= within + } } pub struct CountersCacheBuilder { @@ -257,7 +265,7 @@ impl CountersCache { let from_queue = self.batcher.updates.get(counter); if let Some(entry) = from_queue { self.cache.insert(counter.clone(), entry.value().clone()); - return Some(entry.value().clone()) + return Some(entry.value().clone()); } } option @@ -288,38 +296,35 @@ impl CountersCache { if let Some(entry) = self.batcher.updates.get(&counter) { entry.value().clone() } else { - Arc::new(CachedCounterValue::from(&counter, counter_val, cache_ttl)) + Arc::new(CachedCounterValue::from_authority( + &counter, + counter_val, + ttl, + )) } }); if previous.expired_at(now) || previous.value.value() < counter_val { - previous.set_from_authority(&counter, counter_val, cache_ttl); + previous.set_from_authority(&counter, counter_val, ttl); } return previous; } } - Arc::new(CachedCounterValue::from( + Arc::new(CachedCounterValue::load_from_authority_asap( &counter, counter_val, - Duration::ZERO, )) } pub fn increase_by(&self, counter: &Counter, delta: i64) { - let mut priority = false; let val = self.cache.get_with_by_ref(counter, || { - if let Some(entry) = self.batcher.updates.get(&counter) { + if let Some(entry) = self.batcher.updates.get(counter) { entry.value().clone() } else { - priority = true; - Arc::new( - // this TTL is wrong, it needs to be the cache's TTL, not the time window of our limit - // todo fix when introducing the Batcher type! 
- CachedCounterValue::from(counter, 0, Duration::from_secs(counter.seconds())), - ) + Arc::new(CachedCounterValue::load_from_authority_asap(counter, 0)) } }); val.delta(counter, delta); - self.batcher.add(counter.clone(), val.clone(), priority); + self.batcher.add(counter.clone(), val.clone()); } fn ttl_from_redis_ttl( diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 02a425ca..69b43675 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -499,7 +499,7 @@ mod tests { counters_and_deltas.insert( counter.clone(), - Arc::new(CachedCounterValue::from( + Arc::new(CachedCounterValue::from_authority( &counter, 1, Duration::from_secs(60), @@ -561,12 +561,11 @@ mod tests { let cache = CountersCacheBuilder::new().build(Duration::from_millis(1)); cache.batcher().add( counter.clone(), - Arc::new(CachedCounterValue::from( + Arc::new(CachedCounterValue::from_authority( &counter, 2, Duration::from_secs(60), )), - false, ); cache.insert( counter.clone(), From 83ba6a7e79bc616c8bb7e74a8b1c5de0046bccf0 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Mon, 29 Apr 2024 17:11:52 -0400 Subject: [PATCH 09/10] Merge in case we face a race --- limitador/src/storage/redis/counters_cache.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 925145f3..c9af9091 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -4,6 +4,7 @@ use crate::storage::redis::{ DEFAULT_MAX_CACHED_COUNTERS, DEFAULT_MAX_TTL_CACHED_COUNTERS_SEC, DEFAULT_TTL_RATIO_CACHED_COUNTERS, }; +use dashmap::mapref::entry::Entry; use dashmap::DashMap; use moka::sync::Cache; use std::collections::HashMap; @@ -92,7 +93,17 @@ impl Batcher { pub fn add(&self, counter: Counter, value: Arc) { let priority = value.requires_fast_flush(&self.interval); - self.updates.entry(counter).or_insert(value); + match self.updates.entry(counter.clone()) { + Entry::Occupied(needs_merge) => { + let arc = needs_merge.get(); + if !Arc::ptr_eq(arc, &value) { + arc.delta(&counter, value.pending_writes().unwrap()); + } + } + Entry::Vacant(miss) => { + miss.insert_entry(value); + } + }; if priority { self.priority_flush.store(true, Ordering::Release); } From 2b2230140db019596e5a3049939906fe65867418 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Mon, 29 Apr 2024 17:22:38 -0400 Subject: [PATCH 10/10] Fix #298: Fake 0 as the default value, and reload from authority ASAP --- limitador/src/storage/redis/redis_cached.rs | 93 ++------------------- 1 file changed, 5 insertions(+), 88 deletions(-) diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 69b43675..20dedcc0 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -5,7 +5,7 @@ use crate::storage::redis::counters_cache::{ CachedCounterValue, CountersCache, CountersCacheBuilder, }; use crate::storage::redis::redis_async::AsyncRedisStorage; -use crate::storage::redis::scripts::{BATCH_UPDATE_COUNTERS, VALUES_AND_TTLS}; +use crate::storage::redis::scripts::BATCH_UPDATE_COUNTERS; use crate::storage::redis::{ DEFAULT_FLUSHING_PERIOD_SEC, DEFAULT_MAX_CACHED_COUNTERS, DEFAULT_MAX_TTL_CACHED_COUNTERS_SEC, DEFAULT_RESPONSE_TIMEOUT_MS, DEFAULT_TTL_RATIO_CACHED_COUNTERS, @@ -42,8 +42,6 @@ use tracing::{debug_span, error, warn, Instrument}; 
pub struct CachedRedisStorage { cached_counters: Arc, async_redis_storage: AsyncRedisStorage, - redis_conn_manager: ConnectionManager, - partitioned: Arc, } #[async_trait] @@ -102,37 +100,9 @@ impl AsyncCounterStorage for CachedRedisStorage { // Fetch non-cached counters, cache them, and check them if !not_cached.is_empty() { - let time_start_get_ttl = Instant::now(); - - let (counter_vals, counter_ttls_msecs) = if self.is_partitioned() { - self.fallback_vals_ttls(¬_cached) - } else { - self.values_with_ttls(¬_cached).await.or_else(|err| { - if err.is_transient() { - self.partitioned(true); - Ok(self.fallback_vals_ttls(¬_cached)) - } else { - Err(err) - } - })? - }; - - // Some time could have passed from the moment we got the TTL from Redis. - // This margin is not exact, because we don't know exactly the - // moment that Redis returned a particular TTL, but this - // approximation should be good enough. - let ttl_margin = - Duration::from_millis((Instant::now() - time_start_get_ttl).as_millis() as u64); - - for (i, counter) in not_cached.iter_mut().enumerate() { - let cached_value = self.cached_counters.insert( - counter.clone(), - counter_vals[i], - counter_ttls_msecs[i], - ttl_margin, - now, - ); - let remaining = cached_value.remaining(counter); + for counter in not_cached.iter_mut() { + let fake = CachedCounterValue::load_from_authority_asap(counter, 0); + let remaining = fake.remaining(counter); if first_limited.is_none() && remaining <= 0 { first_limited = Some(Authorization::Limited( counter.limit().name().map(|n| n.to_owned()), @@ -140,7 +110,7 @@ impl AsyncCounterStorage for CachedRedisStorage { } if load_counters { counter.set_remaining(remaining - delta); - counter.set_expires_in(cached_value.to_next_window()); + counter.set_expires_in(fake.to_next_window()); // todo: this is a plain lie! } } } @@ -238,62 +208,9 @@ impl CachedRedisStorage { Ok(Self { cached_counters: counters_cache, - redis_conn_manager, async_redis_storage, - partitioned, }) } - - fn is_partitioned(&self) -> bool { - self.partitioned.load(Ordering::Acquire) - } - - fn partitioned(&self, partition: bool) -> bool { - flip_partitioned(&self.partitioned, partition) - } - - fn fallback_vals_ttls(&self, counters: &Vec<&mut Counter>) -> (Vec>, Vec) { - let mut vals = Vec::with_capacity(counters.len()); - let mut ttls = Vec::with_capacity(counters.len()); - for counter in counters { - vals.push(Some(0i64)); - ttls.push(counter.limit().seconds() as i64 * 1000); - } - (vals, ttls) - } - - async fn values_with_ttls( - &self, - counters: &[&mut Counter], - ) -> Result<(Vec>, Vec), StorageErr> { - let mut redis_con = self.redis_conn_manager.clone(); - - let counter_keys: Vec = counters - .iter() - .map(|counter| key_for_counter(counter)) - .collect(); - - let script = redis::Script::new(VALUES_AND_TTLS); - let mut script_invocation = script.prepare_invoke(); - - for counter_key in counter_keys { - script_invocation.key(counter_key); - } - - let script_res: Vec> = script_invocation - .invoke_async::<_, _>(&mut redis_con) - .await?; - - let mut counter_vals: Vec> = vec![]; - let mut counter_ttls_msecs: Vec = vec![]; - - for val_ttl_pair in script_res.chunks(2) { - counter_vals.push(val_ttl_pair[0]); - counter_ttls_msecs.push(val_ttl_pair[1].unwrap()); - } - - Ok((counter_vals, counter_ttls_msecs)) - } } fn flip_partitioned(storage: &AtomicBool, partition: bool) -> bool {
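For reference, the write-behind flow this series converges on can be driven roughly as sketched below. This is an illustrative, uncompiled sketch (not part of the patches) assuming the CountersCache, Batcher and CachedCounterValue API introduced in patches 03 through 08 (increase_by, batcher(), consume, pending_writes); the real flush loop lives in flush_batcher_and_update_counters and hands the deltas to the BATCH_UPDATE_COUNTERS script:

    // Hot path: record the hit against the cached value; the counter is
    // enqueued in the batcher the first time it is touched.
    cached_counters.increase_by(&counter, delta);

    // Flush loop: wait for at least one pending update, a priority flush,
    // or the flushing interval, then push the accumulated deltas to Redis.
    cached_counters
        .batcher()
        .consume(1, |counters| async move {
            for (counter, value) in counters {
                // pending_writes() reports only the delta not yet written to Redis.
                let delta = value.pending_writes().unwrap_or(0);
                // ... hand (counter, delta) to the BATCH_UPDATE_COUNTERS script here ...
            }
        })
        .await;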