diff --git a/limitador/src/storage/redis/mod.rs b/limitador/src/storage/redis/mod.rs index 1945e3e6..49f052b1 100644 --- a/limitador/src/storage/redis/mod.rs +++ b/limitador/src/storage/redis/mod.rs @@ -47,7 +47,8 @@ pub fn is_limited( let mut first_limited = None; for (i, counter) in counters.iter_mut().enumerate() { - let remaining = counter_vals[i].unwrap_or(counter.max_value()) - delta; + // remaining = max - (curr_val + delta) + let remaining = counter.max_value() - (counter_vals[i].unwrap_or(0) + delta); counter.set_remaining(remaining); let expires_in = counter_ttls_msecs[i] .map(|x| { diff --git a/limitador/src/storage/redis/redis_async.rs b/limitador/src/storage/redis/redis_async.rs index bb94e933..23782254 100644 --- a/limitador/src/storage/redis/redis_async.rs +++ b/limitador/src/storage/redis/redis_async.rs @@ -44,7 +44,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { .instrument(span) .await? { - Some(val) => Ok(val - delta >= 0), + Some(val) => Ok(val + delta <= counter.max_value()), None => Ok(counter.max_value() - delta >= 0), } } @@ -58,7 +58,6 @@ impl AsyncCounterStorage for AsyncRedisStorage { redis::Script::new(SCRIPT_UPDATE_COUNTER) .key(key_for_counter(counter)) .key(key_for_counters_of_limit(counter.limit())) - .arg(counter.max_value()) .arg(counter.seconds()) .arg(delta) .invoke_async::<_, _>(&mut con) @@ -120,7 +119,8 @@ impl AsyncCounterStorage for AsyncRedisStorage { }; for (i, counter) in counters.iter().enumerate() { - let remaining = counter_vals[i].unwrap_or(counter.max_value()) - delta; + // remaining = max - (curr_val + delta) + let remaining = counter.max_value() - (counter_vals[i].unwrap_or(0) + delta); if remaining < 0 { return Ok(Authorization::Limited( counter.limit().name().map(|n| n.to_owned()), @@ -138,7 +138,6 @@ impl AsyncCounterStorage for AsyncRedisStorage { let result = redis::Script::new(SCRIPT_UPDATE_COUNTER) .key(key) .key(key_for_counters_of_limit(counter.limit())) - .arg(counter.max_value()) .arg(counter.seconds()) .arg(delta) .invoke_async::<_, _>(&mut con) @@ -187,7 +186,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { .await? }; if let Some(val) = option { - counter.set_remaining(val); + counter.set_remaining(limit.max_value() - val); let ttl = { let span = trace_span!("datastore"); async { con.ttl(&counter_key).await } diff --git a/limitador/src/storage/redis/redis_sync.rs b/limitador/src/storage/redis/redis_sync.rs index d6dee2ec..adc9ba3a 100644 --- a/limitador/src/storage/redis/redis_sync.rs +++ b/limitador/src/storage/redis/redis_sync.rs @@ -28,7 +28,7 @@ impl CounterStorage for RedisStorage { let mut con = self.conn_pool.get()?; match con.get::>(key_for_counter(counter))? { - Some(val) => Ok(val - delta >= 0), + Some(val) => Ok(val + delta <= counter.max_value()), None => Ok(counter.max_value() - delta >= 0), } } @@ -45,7 +45,6 @@ impl CounterStorage for RedisStorage { redis::Script::new(SCRIPT_UPDATE_COUNTER) .key(key_for_counter(counter)) .key(key_for_counters_of_limit(counter.limit())) - .arg(counter.max_value()) .arg(counter.seconds()) .arg(delta) .invoke(&mut *con)?; @@ -80,7 +79,8 @@ impl CounterStorage for RedisStorage { .query(&mut *con)?; for (i, counter) in counters.iter().enumerate() { - let remaining = counter_vals[i].unwrap_or(counter.max_value()) - delta; + // remaining = max - (curr_val + delta) + let remaining = counter.max_value() - (counter_vals[i].unwrap_or(0) + delta); if remaining < 0 { return Ok(Authorization::Limited( counter.limit().name().map(|n| n.to_owned()), @@ -95,7 +95,6 @@ impl CounterStorage for RedisStorage { redis::Script::new(SCRIPT_UPDATE_COUNTER) .key(key) .key(key_for_counters_of_limit(counter.limit())) - .arg(counter.max_value()) .arg(counter.seconds()) .arg(delta) .invoke(&mut *con)?; @@ -125,7 +124,7 @@ impl CounterStorage for RedisStorage { // This does not cause any bugs, but consumes memory // unnecessarily. if let Some(val) = con.get::>(counter_key.clone())? { - counter.set_remaining(val); + counter.set_remaining(limit.max_value() - val); let ttl = con.ttl(&counter_key)?; counter.set_expires_in(Duration::from_secs(ttl)); diff --git a/limitador/src/storage/redis/scripts.rs b/limitador/src/storage/redis/scripts.rs index a353d5ce..6c2432c9 100644 --- a/limitador/src/storage/redis/scripts.rs +++ b/limitador/src/storage/redis/scripts.rs @@ -9,15 +9,15 @@ // KEYS[1]: counter key // KEYS[2]: key that contains the counters that belong to the limit -// ARGV[1]: counter max val -// ARGV[2]: counter TTL -// ARGV[3]: delta +// ARGV[1]: counter TTL +// ARGV[2]: delta pub const SCRIPT_UPDATE_COUNTER: &str = " - local set_res = redis.call('set', KEYS[1], ARGV[1], 'EX', ARGV[2], 'NX') - redis.call('incrby', KEYS[1], - ARGV[3]) - if set_res then - redis.call('sadd', KEYS[2], KEYS[1]) - end"; + local c = redis.call('incrby', KEYS[1], ARGV[2]) + if c == tonumber(ARGV[2]) then + redis.call('expire', KEYS[1], ARGV[1]) + redis.call('sadd', KEYS[2], KEYS[1]) + end + return c"; // KEYS: the function returns the value and TTL (in ms) for these keys // The first position of the list returned contains the value of KEYS[1], the