From 68011b5bcf5d00e21e6eb054706e29868e1cbcc5 Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Thu, 18 Apr 2024 10:45:05 +0200 Subject: [PATCH] [wip] Testing, refactoring... --- limitador/src/storage/redis/redis_async.rs | 107 +++++++++----------- limitador/src/storage/redis/redis_cached.rs | 12 ++- limitador/src/storage/redis/scripts.rs | 3 +- 3 files changed, 59 insertions(+), 63 deletions(-) diff --git a/limitador/src/storage/redis/redis_async.rs b/limitador/src/storage/redis/redis_async.rs index bc96a45f..8bcda255 100644 --- a/limitador/src/storage/redis/redis_async.rs +++ b/limitador/src/storage/redis/redis_async.rs @@ -239,12 +239,13 @@ impl AsyncRedisStorage { pub(crate) async fn update_counters( redis_conn: &mut C, counters_and_deltas: HashMap, - ) -> Result, StorageErr> { + ) -> Result, StorageErr> { let span = trace_span!("datastore"); let redis_script = redis::Script::new(BATCH_UPDATE_COUNTERS); let mut script_invocation = redis_script.prepare_invoke(); + let mut res: Vec<(Counter, i64)> = Vec::new(); let now = SystemTime::now(); for (counter, delta) in counters_and_deltas { let delta = delta.value_at(now); @@ -253,24 +254,39 @@ impl AsyncRedisStorage { script_invocation.key(key_for_counters_of_limit(counter.limit())); script_invocation.arg(counter.seconds()); script_invocation.arg(delta); + // We need to store the counter in the actual order we are sending it to the script + res.push((counter, 0)); } } - let script_res: Vec> = script_invocation - .invoke_async::<_, _>(redis_conn) + // The redis crate is not working with tables, thus the response will be a Vec of counter values + let script_res: Vec = script_invocation + .invoke_async(redis_conn) .instrument(span) .await?; - Ok(script_res.into_iter().flatten().collect()) + // We need to update the values of the counters with the values returned by redis + for (i, (_, value)) in res.iter_mut().enumerate() { + if let Some(new_value) = script_res.get(i) { + *value = *new_value; + } + } + + Ok(res) } } #[cfg(test)] mod tests { - use crate::storage::redis::AsyncRedisStorage; - use redis::ErrorKind; - use redis_test::{MockCmd, MockRedisConnection, IntoRedisValue}; + use crate::counter::Counter; + use crate::limit::Limit; + use crate::storage::atomic_expiring_value::AtomicExpiringValue; use crate::storage::keys::{key_for_counter, key_for_counters_of_limit}; + use crate::storage::redis::AsyncRedisStorage; + use redis::{ErrorKind, Value}; + use redis_test::{MockCmd, MockRedisConnection}; + use std::collections::HashMap; + use std::time::{Duration, SystemTime}; #[tokio::test] async fn errs_on_bad_url() { @@ -290,73 +306,46 @@ mod tests { #[tokio::test] async fn batch_update_counters() { - - let mut counters_and_deltas = std::collections::HashMap::new(); - let counter = crate::counter::Counter::new( - crate::limit::Limit::new( + let mut counters_and_deltas = HashMap::new(); + let counter = Counter::new( + Limit::new( "test_namespace", 10, 60, vec!["req.method == 'GET'"], vec!["app_id"], ), - std::collections::HashMap::new(), + Default::default(), ); - let expiring_value = crate::storage::atomic_expiring_value::AtomicExpiringValue::new( - 1, - std::time::SystemTime::now() + std::time::Duration::from_secs(60), - ); + let expiring_value = + AtomicExpiringValue::new(1, SystemTime::now() + Duration::from_secs(60)); - let counter_key = key_for_counter(&counter); - let key_for_counters_of_limit = key_for_counters_of_limit(counter.limit()); + counters_and_deltas.insert(counter.clone(), expiring_value); - counters_and_deltas.insert(counter, expiring_value); + let mock_response = Value::Bulk(vec![Value::Int(10), Value::Int(20)]); - let mock_response = format!( - "{{{{ {},1}}}}", - counter_key.clone(), - ); - - let mut mock_client = MockRedisConnection::new( - vec![ - MockCmd::new( - redis::cmd("EVALSHA") - .arg("13e042bb900a9a1104370208a300432bcdd45383") - .arg("2") - .arg(counter_key.clone()) - .arg(key_for_counters_of_limit.clone()) - .arg(60) - .arg(1), - Ok(IntoRedisValue::into_redis_value(mock_response)), - ), - MockCmd::new( - redis::cmd("incrby") - .arg(counter_key.clone()) - .arg(1), - Ok("1"), - ), - MockCmd::new( - redis::cmd("EXPIRE") - .arg(counter_key.clone()) - .arg(60), - Ok("1"), - ), - MockCmd::new( - redis::cmd("SADD") - .arg(key_for_counters_of_limit) - .arg(counter_key.clone()), - Ok("1"), - ), - ], - ); + let mut mock_client = MockRedisConnection::new(vec![MockCmd::new( + redis::cmd("EVALSHA") + .arg("8ee7a63a239b1e196b6a557956da849c10ffefcf") + .arg("2") + .arg(key_for_counter(&counter)) + .arg(key_for_counters_of_limit(counter.limit())) + .arg(60) + .arg(1), + Ok(mock_response.clone()), + )]); let result = AsyncRedisStorage::update_counters(&mut mock_client, counters_and_deltas).await; - - assert!(result.is_ok()); - //assert!(result.unwrap(), "{}", vec![("test_namespace:app_id:GET:1", 1)]); + + let (c, v) = result.unwrap()[0].clone(); + assert_eq!( + "req.method == \"GET\"", + c.limit().conditions().iter().collect::>()[0] + ); + assert_eq!(10, v); } } diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 0ae03bba..8381f5d8 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -411,9 +411,15 @@ async fn flush_batcher_and_update_counters( }) .expect("Unrecoverable Redis error!"); - for (counter_key, value) in updated_counters { - let counter = partial_counter_from_counter_key(&counter_key); - cached_counters.increase_by(&counter, value); + for (counter, value) in updated_counters { + //TODO: Populate the right ttls + cached_counters.insert( + counter, + Option::from(value), + 0, + Duration::from_secs(0), + SystemTime::now(), + ); } } } diff --git a/limitador/src/storage/redis/scripts.rs b/limitador/src/storage/redis/scripts.rs index d7554772..f75b3adb 100644 --- a/limitador/src/storage/redis/scripts.rs +++ b/limitador/src/storage/redis/scripts.rs @@ -36,7 +36,8 @@ pub const BATCH_UPDATE_COUNTERS: &str = " redis.call('expire', counter_key, ttl) redis.call('sadd', limit_key, counter_key) end - table.insert(res, { counter_key, c }) + + table.insert(res, c) end return res ";