diff --git a/limitador/src/storage/redis/batcher.rs b/limitador/src/storage/redis/batcher.rs deleted file mode 100644 index e5496334..00000000 --- a/limitador/src/storage/redis/batcher.rs +++ /dev/null @@ -1,44 +0,0 @@ -use crate::counter::Counter; -use crate::storage::redis::AsyncRedisStorage; -use crate::storage::AsyncCounterStorage; -use std::collections::HashMap; -use tokio::sync::Mutex; - -pub struct Batcher { - accumulated_counter_updates: Mutex>, - redis_storage: AsyncRedisStorage, -} - -impl Batcher { - pub fn new(redis_storage: AsyncRedisStorage) -> Self { - Self { - accumulated_counter_updates: Mutex::new(HashMap::new()), - redis_storage, - } - } - - pub async fn add_counter(&self, counter: &Counter, delta: i64) { - let mut accumulated_counter_updates = self.accumulated_counter_updates.lock().await; - - match accumulated_counter_updates.get_mut(counter) { - Some(val) => { - *val += delta; - } - None => { - accumulated_counter_updates.insert(counter.clone(), delta); - } - } - } - - pub async fn flush(&self) { - let mut accumulated_counter_updates = self.accumulated_counter_updates.lock().await; - - for (counter, delta) in accumulated_counter_updates.iter() { - self.redis_storage - .update_counter(counter, *delta) - .await - .unwrap(); - } - accumulated_counter_updates.clear(); - } -} diff --git a/limitador/src/storage/redis/mod.rs b/limitador/src/storage/redis/mod.rs index 1945e3e6..1eec4626 100644 --- a/limitador/src/storage/redis/mod.rs +++ b/limitador/src/storage/redis/mod.rs @@ -1,7 +1,6 @@ use ::redis::RedisError; use std::time::Duration; -mod batcher; mod counters_cache; mod redis_async; mod redis_cached; diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 2f6e1fc1..fd543aee 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -2,7 +2,6 @@ use crate::counter::Counter; use crate::limit::Limit; use crate::prometheus_metrics::CounterAccess; use crate::storage::keys::*; -use crate::storage::redis::batcher::Batcher; use crate::storage::redis::counters_cache::{CountersCache, CountersCacheBuilder}; use crate::storage::redis::redis_async::AsyncRedisStorage; use crate::storage::redis::scripts::VALUES_AND_TTLS; @@ -14,11 +13,10 @@ use crate::storage::{AsyncCounterStorage, Authorization, StorageErr}; use async_trait::async_trait; use redis::aio::ConnectionManager; use redis::{ConnectionInfo, RedisError}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; -use tokio::sync::Mutex; // This is just a first version. // @@ -40,7 +38,7 @@ use tokio::sync::Mutex; pub struct CachedRedisStorage { cached_counters: Mutex, - batcher_counter_updates: Arc>, + batcher_counter_updates: Arc>>, async_redis_storage: AsyncRedisStorage, redis_conn_manager: ConnectionManager, batching_is_enabled: bool, @@ -81,7 +79,7 @@ impl AsyncCounterStorage for CachedRedisStorage { // Check cached counters { - let cached_counters = self.cached_counters.lock().await; + let cached_counters = self.cached_counters.lock().unwrap(); for counter in counters.iter_mut() { match cached_counters.get(counter) { Some(val) => { @@ -122,7 +120,7 @@ impl AsyncCounterStorage for CachedRedisStorage { Duration::from_millis((Instant::now() - time_start_get_ttl).as_millis() as u64); { - let mut cached_counters = self.cached_counters.lock().await; + let mut cached_counters = self.cached_counters.lock().unwrap(); for (i, counter) in not_cached.iter_mut().enumerate() { cached_counters.insert( counter.clone(), @@ -150,7 +148,7 @@ impl AsyncCounterStorage for CachedRedisStorage { // Update cached values { - let mut cached_counters = self.cached_counters.lock().await; + let mut cached_counters = self.cached_counters.lock().unwrap(); for counter in counters.iter() { cached_counters.decrease_by(counter, delta); } @@ -158,9 +156,16 @@ impl AsyncCounterStorage for CachedRedisStorage { // Batch or update depending on configuration if self.batching_is_enabled { - let batcher = self.batcher_counter_updates.lock().await; + let mut batcher = self.batcher_counter_updates.lock().unwrap(); for counter in counters.iter() { - batcher.add_counter(counter, delta).await + match batcher.get_mut(counter) { + Some(val) => { + *val += delta; + } + None => { + batcher.insert(counter.clone(), delta); + } + } } } else { for counter in counters.iter() { @@ -216,17 +221,21 @@ impl CachedRedisStorage { let async_redis_storage = AsyncRedisStorage::new_with_conn_manager(redis_conn_manager.clone()); - let batcher = Arc::new(Mutex::new(Batcher::new(async_redis_storage.clone()))); + let storage = async_redis_storage.clone(); + let batcher = Arc::new(Mutex::new(Default::default())); if let Some(flushing_period) = flushing_period { let batcher_flusher = batcher.clone(); + let mut interval = tokio::time::interval(flushing_period); tokio::spawn(async move { loop { - let time_start = Instant::now(); - batcher_flusher.lock().await.flush().await; - let sleep_time = flushing_period - .checked_sub(time_start.elapsed()) - .unwrap_or_else(|| Duration::from_secs(0)); - tokio::time::sleep(sleep_time).await; + let counters = { + let mut batch = batcher_flusher.lock().unwrap(); + std::mem::take(&mut *batch) + }; + for (counter, delta) in counters { + storage.update_counter(&counter, delta).await.unwrap(); + } + interval.tick().await; } }); }