diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index f7793194..bbfd8222 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -14,6 +14,7 @@ use async_trait::async_trait; use redis::aio::{ConnectionLike, ConnectionManager}; use redis::{ConnectionInfo, RedisError}; use std::collections::{HashMap, HashSet}; +use std::future::Future; use std::str::FromStr; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; @@ -252,7 +253,7 @@ impl CachedRedisStorage { flush_batcher_and_update_counters( conn.clone(), batcher_flusher.clone(), - storage.clone(), + storage.is_alive(), cacher_clone.clone(), p.clone(), ) @@ -428,15 +429,15 @@ async fn update_counters( Ok(res) } -async fn flush_batcher_and_update_counters( - mut conn: ConnectionManager, +async fn flush_batcher_and_update_counters( + mut redis_conn: C, batcher: Arc>>, - storage: AsyncRedisStorage, + storage_is_alive: impl Future, cached_counters: Arc, partitioned: Arc, ) { if partitioned.load(Ordering::Acquire) { - if storage.is_alive().await { + if storage_is_alive.await { warn!("Partition to Redis resolved!"); partitioned.store(false, Ordering::Release); } @@ -448,7 +449,7 @@ async fn flush_batcher_and_update_counters( let time_start_update_counters = Instant::now(); - let updated_counters = update_counters(&mut conn, counters) + let updated_counters = update_counters(&mut redis_conn, counters) .await .or_else(|err| { if err.is_transient() { @@ -480,11 +481,14 @@ mod tests { use crate::limit::Limit; use crate::storage::atomic_expiring_value::AtomicExpiringValue; use crate::storage::keys::{key_for_counter, key_for_counters_of_limit}; + use crate::storage::redis::counters_cache::{CountersCache, CountersCacheBuilder}; + use crate::storage::redis::redis_cached::{flush_batcher_and_update_counters, update_counters}; use crate::storage::redis::CachedRedisStorage; - use crate::storage::redis::redis_cached::update_counters; use redis::{ErrorKind, Value}; use redis_test::{MockCmd, MockRedisConnection}; use std::collections::HashMap; + use std::sync::atomic::AtomicBool; + use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime}; #[tokio::test] @@ -535,8 +539,7 @@ mod tests { Ok(mock_response.clone()), )]); - let result = - update_counters(&mut mock_client, counters_and_deltas).await; + let result = update_counters(&mut mock_client, counters_and_deltas).await; assert!(result.is_ok()); @@ -548,4 +551,71 @@ mod tests { assert_eq!(10, v); assert_eq!(60, t); } + + #[tokio::test] + async fn flush_batcher_and_update_counters_test() { + let counter = Counter::new( + Limit::new( + "test_namespace", + 10, + 60, + vec!["req.method == 'POST'"], + vec!["app_id"], + ), + Default::default(), + ); + + let mock_response = Value::Bulk(vec![Value::Int(8), Value::Int(60)]); + + let mock_client = MockRedisConnection::new(vec![MockCmd::new( + redis::cmd("EVALSHA") + .arg("1e87383cf7dba2bd0f9972ed73671274e6cbd5da") + .arg("2") + .arg(key_for_counter(&counter)) + .arg(key_for_counters_of_limit(counter.limit())) + .arg(60) + .arg(2), + Ok(mock_response.clone()), + )]); + + let mut batched_counters = HashMap::new(); + batched_counters.insert( + counter.clone(), + AtomicExpiringValue::new(2, SystemTime::now() + Duration::from_secs(60)), + ); + + let batcher: Arc>> = + Arc::new(Mutex::new(batched_counters)); + let cache = CountersCacheBuilder::new().build(); + cache.insert( + counter.clone(), + Some(1), + 10, + Duration::from_secs(0), + SystemTime::now(), + ); + let cached_counters: Arc = Arc::new(cache); + let partitioned = Arc::new(AtomicBool::new(false)); + + async fn future_true() -> bool { + true + } + + if let Some(c) = cached_counters.get(&counter) { + assert_eq!(c.hits(&counter), 1); + } + + flush_batcher_and_update_counters( + mock_client, + batcher, + future_true(), + cached_counters.clone(), + partitioned, + ) + .await; + + if let Some(c) = cached_counters.get(&counter) { + assert_eq!(c.hits(&counter), 8); + } + } }