diff --git a/limitador/src/storage/redis/redis_async.rs b/limitador/src/storage/redis/redis_async.rs index 8bcda255..0c35d32b 100644 --- a/limitador/src/storage/redis/redis_async.rs +++ b/limitador/src/storage/redis/redis_async.rs @@ -239,13 +239,13 @@ impl AsyncRedisStorage { pub(crate) async fn update_counters( redis_conn: &mut C, counters_and_deltas: HashMap, - ) -> Result, StorageErr> { + ) -> Result, StorageErr> { let span = trace_span!("datastore"); let redis_script = redis::Script::new(BATCH_UPDATE_COUNTERS); let mut script_invocation = redis_script.prepare_invoke(); - let mut res: Vec<(Counter, i64)> = Vec::new(); + let mut res: Vec<(Counter, i64, i64)> = Vec::new(); let now = SystemTime::now(); for (counter, delta) in counters_and_deltas { let delta = delta.value_at(now); @@ -255,7 +255,7 @@ impl AsyncRedisStorage { script_invocation.arg(counter.seconds()); script_invocation.arg(delta); // We need to store the counter in the actual order we are sending it to the script - res.push((counter, 0)); + res.push((counter, 0, 0)); } } @@ -265,11 +265,14 @@ impl AsyncRedisStorage { .instrument(span) .await?; - // We need to update the values of the counters with the values returned by redis - for (i, (_, value)) in res.iter_mut().enumerate() { - if let Some(new_value) = script_res.get(i) { - *value = *new_value; - } + // We need to update the values and ttls returned by redis + let counters_range = 0..res.len(); + let script_res_range = (0..script_res.len()).step_by(2); + + for (i, j) in counters_range.zip(script_res_range) { + let (_, val, ttl) = &mut res[i]; + *val = script_res[j]; + *ttl = script_res[j + 1]; } Ok(res) @@ -323,11 +326,11 @@ mod tests { counters_and_deltas.insert(counter.clone(), expiring_value); - let mock_response = Value::Bulk(vec![Value::Int(10), Value::Int(20)]); + let mock_response = Value::Bulk(vec![Value::Int(10), Value::Int(60)]); let mut mock_client = MockRedisConnection::new(vec![MockCmd::new( redis::cmd("EVALSHA") - .arg("8ee7a63a239b1e196b6a557956da849c10ffefcf") + .arg("8fbdbae84f16e71bcaef347c46f8887564b01213") .arg("2") .arg(key_for_counter(&counter)) .arg(key_for_counters_of_limit(counter.limit())) @@ -341,11 +344,12 @@ mod tests { assert!(result.is_ok()); - let (c, v) = result.unwrap()[0].clone(); + let (c, v, t) = result.unwrap()[0].clone(); assert_eq!( "req.method == \"GET\"", c.limit().conditions().iter().collect::>()[0] ); assert_eq!(10, v); + assert_eq!(60, t); } } diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 8381f5d8..a724fe89 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -399,6 +399,8 @@ async fn flush_batcher_and_update_counters( let mut conn = storage.conn_manager.clone(); + let time_start_update_counters = Instant::now(); + let updated_counters = AsyncRedisStorage::update_counters(&mut conn, counters) .await .or_else(|err| { @@ -411,13 +413,14 @@ async fn flush_batcher_and_update_counters( }) .expect("Unrecoverable Redis error!"); - for (counter, value) in updated_counters { - //TODO: Populate the right ttls + for (counter, value, ttl) in updated_counters { cached_counters.insert( counter, Option::from(value), - 0, - Duration::from_secs(0), + ttl, + Duration::from_millis( + (Instant::now() - time_start_update_counters).as_millis() as u64 + ), SystemTime::now(), ); } diff --git a/limitador/src/storage/redis/scripts.rs b/limitador/src/storage/redis/scripts.rs index f75b3adb..c81b2e89 100644 --- a/limitador/src/storage/redis/scripts.rs +++ b/limitador/src/storage/redis/scripts.rs @@ -23,6 +23,8 @@ pub const SCRIPT_UPDATE_COUNTER: &str = " // KEY[i+1]: Limit key // ARGV[i]: TTLs // ARGV[i+1]: Deltas +// This function returns a list with the values and TTLs for the updated counter_keys, +// the first position the counter value and the second the TTL pub const BATCH_UPDATE_COUNTERS: &str = " local res = {} for i = 1, #KEYS, 2 do @@ -38,6 +40,7 @@ pub const BATCH_UPDATE_COUNTERS: &str = " end table.insert(res, c) + table.insert(res, redis.call('pttl', counter_key)) end return res ";