diff --git a/limitador-server/src/envoy_rls/server.rs b/limitador-server/src/envoy_rls/server.rs index 3511b202..7735f8f6 100644 --- a/limitador-server/src/envoy_rls/server.rs +++ b/limitador-server/src/envoy_rls/server.rs @@ -1,6 +1,5 @@ use opentelemetry::global; use opentelemetry::propagation::Extractor; -use std::cmp::Ordering; use std::collections::HashMap; use std::sync::Arc; @@ -100,7 +99,7 @@ impl RateLimitService for MyRateLimiter { Limiter::Blocking(limiter) => limiter.check_rate_limited_and_update( &namespace, &values, - i64::from(hits_addend), + u64::from(hits_addend), self.rate_limit_headers != RateLimitHeaders::None, ), Limiter::Async(limiter) => { @@ -108,7 +107,7 @@ impl RateLimitService for MyRateLimiter { .check_rate_limited_and_update( &namespace, &values, - i64::from(hits_addend), + u64::from(hits_addend), self.rate_limit_headers != RateLimitHeaders::None, ) .await @@ -170,11 +169,7 @@ pub fn to_response_header( counters.sort_by(|a, b| { let a_remaining = a.remaining().unwrap_or(a.max_value()); let b_remaining = b.remaining().unwrap_or(b.max_value()); - if a_remaining - b_remaining < 0 { - Ordering::Less - } else { - Ordering::Greater - } + a_remaining.cmp(&b_remaining) }); let mut all_limits_text = String::with_capacity(20 * counters.len()); @@ -194,10 +189,7 @@ pub fn to_response_header( value: format!("{}{}", counter.max_value(), all_limits_text), }); - let mut remaining = counter.remaining().unwrap_or(counter.max_value()); - if remaining < 0 { - remaining = 0 - } + let remaining = counter.remaining().unwrap_or(counter.max_value()); headers.push(HeaderValue { key: "X-RateLimit-Remaining".to_string(), value: format!("{}", remaining), diff --git a/limitador-server/src/http_api/request_types.rs b/limitador-server/src/http_api/request_types.rs index 2ec9bab3..f8d7dc45 100644 --- a/limitador-server/src/http_api/request_types.rs +++ b/limitador-server/src/http_api/request_types.rs @@ -12,14 +12,14 @@ use std::collections::HashMap; pub struct CheckAndReportInfo { pub namespace: String, pub values: HashMap, - pub delta: i64, + pub delta: u64, pub response_headers: Option, } #[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Apiv2Schema)] pub struct Limit { namespace: String, - max_value: i64, + max_value: u64, seconds: u64, name: Option, conditions: Vec, @@ -61,7 +61,7 @@ impl From for LimitadorLimit { pub struct Counter { limit: Limit, set_variables: HashMap, - remaining: Option, + remaining: Option, expires_in_seconds: Option, } diff --git a/limitador-server/src/http_api/server.rs b/limitador-server/src/http_api/server.rs index 5407bc2c..ec0606b6 100644 --- a/limitador-server/src/http_api/server.rs +++ b/limitador-server/src/http_api/server.rs @@ -11,7 +11,6 @@ use paperclip::actix::{ // extension trait for actix_web::App and proc-macro attributes OpenApiExt, }; -use std::cmp::Ordering; use std::fmt; use std::sync::Arc; @@ -248,11 +247,7 @@ pub fn add_response_header( counters.sort_by(|a, b| { let a_remaining = a.remaining().unwrap_or(a.max_value()); let b_remaining = b.remaining().unwrap_or(b.max_value()); - if a_remaining - b_remaining < 0 { - Ordering::Less - } else { - Ordering::Greater - } + a_remaining.cmp(&b_remaining) }); let mut all_limits_text = String::with_capacity(20 * counters.len()); @@ -272,10 +267,7 @@ pub fn add_response_header( format!("{}{}", counter.max_value(), all_limits_text), )); - let mut remaining = counter.remaining().unwrap_or(counter.max_value()); - if remaining < 0 { - remaining = 0 - } + let remaining = counter.remaining().unwrap_or(counter.max_value()); resp.insert_header(( "X-RateLimit-Remaining".to_string(), format!("{}", remaining), @@ -581,7 +573,7 @@ mod tests { assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS); } - async fn create_test_limit(limiter: &Limiter, namespace: &str, max: i64) -> LimitadorLimit { + async fn create_test_limit(limiter: &Limiter, namespace: &str, max: u64) -> LimitadorLimit { // Create a limit let limit = LimitadorLimit::new( namespace, diff --git a/limitador-server/src/main.rs b/limitador-server/src/main.rs index 2b9bded1..8962d5ef 100644 --- a/limitador-server/src/main.rs +++ b/limitador-server/src/main.rs @@ -162,21 +162,16 @@ impl Limiter { Ok(f) => { let parsed_limits: Result, _> = serde_yaml::from_reader(f); match parsed_limits { - Ok(limits) => match find_first_negative_limit(&limits) { - None => { - match &self { - Self::Blocking(limiter) => limiter.configure_with(limits)?, - Self::Async(limiter) => limiter.configure_with(limits).await?, - } - if limitador::limit::check_deprecated_syntax_usages_and_reset() { - error!("You are using deprecated syntax for your conditions! See the migration guide https://docs.kuadrant.io/limitador/doc/migrations/conditions/") - } - Ok(()) + Ok(limits) => { + match &self { + Self::Blocking(limiter) => limiter.configure_with(limits)?, + Self::Async(limiter) => limiter.configure_with(limits).await?, + } + if limitador::limit::check_deprecated_syntax_usages_and_reset() { + error!("You are using deprecated syntax for your conditions! See the migration guide https://docs.kuadrant.io/limitador/doc/migrations/conditions/") } - Some(index) => Err(LimitadorServerError::ConfigFile(format!( - ".[{index}]: invalid value for `max_value`: positive integer expected" - ))), - }, + Ok(()) + } Err(e) => Err(LimitadorServerError::ConfigFile(format!( "Couldn't parse: {e}" ))), @@ -191,15 +186,6 @@ impl Limiter { } } -fn find_first_negative_limit(limits: &[Limit]) -> Option { - for (index, limit) in limits.iter().enumerate() { - if limit.max_value() < 0 { - return Some(index); - } - } - None -} - #[actix_rt::main] async fn main() -> Result<(), Box> { let config = { @@ -588,28 +574,23 @@ fn create_config() -> (Configuration, &'static str) { Ok(f) => { let parsed_limits: Result, _> = serde_yaml::from_reader(f); match parsed_limits { - Ok(limits) => match find_first_negative_limit(&limits) { - Some(index) => LimitadorServerError::ConfigFile(format!( - ".[{index}]: invalid value for `max_value`: positive integer expected" - )), - None => { - if limitador::limit::check_deprecated_syntax_usages_and_reset() { - eprintln!("Deprecated syntax for conditions corrected!\n") - } + Ok(limits) => { + if limitador::limit::check_deprecated_syntax_usages_and_reset() { + eprintln!("Deprecated syntax for conditions corrected!\n") + } - let output: Vec = - limits.iter().map(|l| l.into()).collect(); - match serde_yaml::to_string(&output) { - Ok(cfg) => { - println!("{cfg}"); - } - Err(err) => { - eprintln!("Config file is valid, but can't be output: {err}"); - } + let output: Vec = + limits.iter().map(|l| l.into()).collect(); + match serde_yaml::to_string(&output) { + Ok(cfg) => { + println!("{cfg}"); + } + Err(err) => { + eprintln!("Config file is valid, but can't be output: {err}"); } - process::exit(0); } - }, + process::exit(0); + } Err(e) => LimitadorServerError::ConfigFile(format!("Couldn't parse: {e}")), } } @@ -738,29 +719,3 @@ fn guess_cache_size() -> Option { fn leak(s: D) -> &'static str { return Box::leak(format!("{}", s).into_boxed_str()); } - -#[cfg(test)] -mod tests { - use crate::find_first_negative_limit; - use limitador::limit::Limit; - - #[test] - fn finds_negative_limits() { - let variables: [&str; 0] = []; - let mut limits: Vec = vec![ - Limit::new::<_, &str>("foo", 42, 10, [], variables), - Limit::new::<_, &str>("foo", -42, 10, [], variables), - ]; - - assert_eq!(find_first_negative_limit(&limits), Some(1)); - limits[0].set_max_value(-42); - assert_eq!(find_first_negative_limit(&limits), Some(0)); - limits[1].set_max_value(42); - assert_eq!(find_first_negative_limit(&limits), Some(0)); - limits[0].set_max_value(42); - assert_eq!(find_first_negative_limit(&limits), None); - - let nothing: [Limit; 0] = []; - assert_eq!(find_first_negative_limit(¬hing), None); - } -} diff --git a/limitador/benches/bench.rs b/limitador/benches/bench.rs index 1926efcf..5905a7af 100644 --- a/limitador/benches/bench.rs +++ b/limitador/benches/bench.rs @@ -56,7 +56,7 @@ const TEST_SCENARIOS: &[&TestScenario] = &[ struct TestCallParams { namespace: String, values: HashMap, - delta: i64, + delta: u64, } impl Display for TestScenario { @@ -280,7 +280,7 @@ fn generate_test_data( for limit_idx in 0..scenario.n_limits_per_ns { test_limits.push(Limit::new( namespace.clone(), - i64::MAX, + u64::MAX, ((limit_idx * 60) + 10) as u64, conditions.clone(), variables.clone(), diff --git a/limitador/src/counter.rs b/limitador/src/counter.rs index cc51c744..702d0c70 100644 --- a/limitador/src/counter.rs +++ b/limitador/src/counter.rs @@ -13,7 +13,7 @@ pub struct Counter { #[serde(serialize_with = "ordered_map")] set_variables: HashMap, - remaining: Option, + remaining: Option, expires_in: Option, } @@ -53,7 +53,7 @@ impl Counter { &self.limit } - pub fn max_value(&self) -> i64 { + pub fn max_value(&self) -> u64 { self.limit.max_value() } @@ -80,11 +80,11 @@ impl Counter { &self.set_variables } - pub fn remaining(&self) -> Option { + pub fn remaining(&self) -> Option { self.remaining } - pub fn set_remaining(&mut self, remaining: i64) { + pub fn set_remaining(&mut self, remaining: u64) { self.remaining = Some(remaining) } diff --git a/limitador/src/lib.rs b/limitador/src/lib.rs index 25010829..59f07a67 100644 --- a/limitador/src/lib.rs +++ b/limitador/src/lib.rs @@ -310,7 +310,7 @@ impl RateLimiter { &self, namespace: &Namespace, values: &HashMap, - delta: i64, + delta: u64, ) -> Result { let counters = self.counters_that_apply(namespace, values)?; @@ -332,7 +332,7 @@ impl RateLimiter { &self, namespace: &Namespace, values: &HashMap, - delta: i64, + delta: u64, ) -> Result<(), LimitadorError> { let counters = self.counters_that_apply(namespace, values)?; @@ -346,7 +346,7 @@ impl RateLimiter { &self, namespace: &Namespace, values: &HashMap, - delta: i64, + delta: u64, load_counters: bool, ) -> Result { let mut counters = self.counters_that_apply(namespace, values)?; @@ -482,7 +482,7 @@ impl AsyncRateLimiter { &self, namespace: &Namespace, values: &HashMap, - delta: i64, + delta: u64, ) -> Result { let counters = self.counters_that_apply(namespace, values).await?; @@ -503,7 +503,7 @@ impl AsyncRateLimiter { &self, namespace: &Namespace, values: &HashMap, - delta: i64, + delta: u64, ) -> Result<(), LimitadorError> { let counters = self.counters_that_apply(namespace, values).await?; @@ -518,7 +518,7 @@ impl AsyncRateLimiter { &self, namespace: &Namespace, values: &HashMap, - delta: i64, + delta: u64, load_counters: bool, ) -> Result { // the above where-clause is needed in order to call unwrap(). diff --git a/limitador/src/limit.rs b/limitador/src/limit.rs index d541007d..12adb7ff 100644 --- a/limitador/src/limit.rs +++ b/limitador/src/limit.rs @@ -53,7 +53,7 @@ impl From for Namespace { pub struct Limit { namespace: Namespace, #[serde(skip_serializing, default)] - max_value: i64, + max_value: u64, seconds: u64, #[serde(skip_serializing, default)] name: Option, @@ -308,7 +308,7 @@ where impl Limit { pub fn new, T: TryInto>( namespace: N, - max_value: i64, + max_value: u64, seconds: u64, conditions: impl IntoIterator, variables: impl IntoIterator>, @@ -335,7 +335,7 @@ impl Limit { &self.namespace } - pub fn max_value(&self) -> i64 { + pub fn max_value(&self) -> u64 { self.max_value } @@ -351,7 +351,7 @@ impl Limit { self.name = Some(name) } - pub fn set_max_value(&mut self, value: i64) { + pub fn set_max_value(&mut self, value: u64) { self.max_value = value; } diff --git a/limitador/src/storage/atomic_expiring_value.rs b/limitador/src/storage/atomic_expiring_value.rs index 465f7daf..f80eaa22 100644 --- a/limitador/src/storage/atomic_expiring_value.rs +++ b/limitador/src/storage/atomic_expiring_value.rs @@ -1,39 +1,39 @@ +use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; -use std::sync::atomic::{AtomicI64, AtomicU64}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; #[derive(Debug)] pub(crate) struct AtomicExpiringValue { - value: AtomicI64, + value: AtomicU64, expiry: AtomicExpiryTime, } impl AtomicExpiringValue { - pub fn new(value: i64, expiry: SystemTime) -> Self { + pub fn new(value: u64, expiry: SystemTime) -> Self { Self { - value: AtomicI64::new(value), + value: AtomicU64::new(value), expiry: AtomicExpiryTime::new(expiry), } } - pub fn value_at(&self, when: SystemTime) -> i64 { + pub fn value_at(&self, when: SystemTime) -> u64 { if self.expiry.expired_at(when) { return 0; } self.value.load(Ordering::SeqCst) } - pub fn value(&self) -> i64 { + pub fn value(&self) -> u64 { self.value_at(SystemTime::now()) } #[allow(dead_code)] - pub fn add_and_set_expiry(&self, delta: i64, expire_at: SystemTime) -> i64 { + pub fn add_and_set_expiry(&self, delta: u64, expire_at: SystemTime) -> u64 { self.expiry.update(expire_at); self.value.fetch_add(delta, Ordering::SeqCst) + delta } - pub fn update(&self, delta: i64, ttl: u64, when: SystemTime) -> i64 { + pub fn update(&self, delta: u64, ttl: u64, when: SystemTime) -> u64 { if self.expiry.update_if_expired(ttl, when) { self.value.store(delta, Ordering::SeqCst); return delta; @@ -115,7 +115,7 @@ impl Clone for AtomicExpiryTime { impl Default for AtomicExpiringValue { fn default() -> Self { AtomicExpiringValue { - value: AtomicI64::new(0), + value: AtomicU64::new(0), expiry: AtomicExpiryTime::new(UNIX_EPOCH), } } @@ -124,7 +124,7 @@ impl Default for AtomicExpiringValue { impl Clone for AtomicExpiringValue { fn clone(&self) -> Self { AtomicExpiringValue { - value: AtomicI64::new(self.value.load(Ordering::SeqCst)), + value: AtomicU64::new(self.value.load(Ordering::SeqCst)), expiry: self.expiry.clone(), } } @@ -187,7 +187,7 @@ mod tests { atomic_expiring_value.update(2, 1, now + Duration::from_secs(11)); }); }); - assert!([2i64, 3i64].contains(&atomic_expiring_value.value.load(Ordering::SeqCst))); + assert!([2u64, 3u64].contains(&atomic_expiring_value.value.load(Ordering::SeqCst))); } #[test] diff --git a/limitador/src/storage/disk/expiring_value.rs b/limitador/src/storage/disk/expiring_value.rs index b7c0cc81..948a85db 100644 --- a/limitador/src/storage/disk/expiring_value.rs +++ b/limitador/src/storage/disk/expiring_value.rs @@ -4,28 +4,28 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; #[derive(Clone, Debug)] pub(crate) struct ExpiringValue { - value: i64, + value: u64, expiry: SystemTime, } impl ExpiringValue { - pub fn new(value: i64, expiry: SystemTime) -> Self { + pub fn new(value: u64, expiry: SystemTime) -> Self { Self { value, expiry } } - pub fn value_at(&self, when: SystemTime) -> i64 { + pub fn value_at(&self, when: SystemTime) -> u64 { if self.expiry <= when { return 0; } self.value } - pub fn value(&self) -> i64 { + pub fn value(&self) -> u64 { self.value_at(SystemTime::now()) } #[must_use] - pub fn update(self, delta: i64, ttl: u64, now: SystemTime) -> Self { + pub fn update(self, delta: u64, ttl: u64, now: SystemTime) -> Self { let expiry = if self.expiry <= now { now + Duration::from_secs(ttl) } else { @@ -71,7 +71,7 @@ impl TryFrom<&[u8]> for ExpiringValue { let raw_val: [u8; 8] = raw[0..8].try_into()?; let raw_exp: [u8; 8] = raw[8..16].try_into()?; - let val = i64::from_be_bytes(raw_val); + let val = u64::from_be_bytes(raw_val); let exp = u64::from_be_bytes(raw_exp); Ok(Self { diff --git a/limitador/src/storage/disk/rocksdb_storage.rs b/limitador/src/storage/disk/rocksdb_storage.rs index a57a610d..4304af09 100644 --- a/limitador/src/storage/disk/rocksdb_storage.rs +++ b/limitador/src/storage/disk/rocksdb_storage.rs @@ -20,7 +20,7 @@ pub struct RocksDbStorage { impl CounterStorage for RocksDbStorage { #[tracing::instrument(skip_all)] - fn is_within_limits(&self, counter: &Counter, delta: i64) -> Result { + fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result { let key = key_for_counter(counter); let value = self.insert_or_update(&key, counter, 0)?; Ok(counter.max_value() >= value.value() + delta) @@ -32,7 +32,7 @@ impl CounterStorage for RocksDbStorage { } #[tracing::instrument(skip_all)] - fn update_counter(&self, counter: &Counter, delta: i64) -> Result<(), StorageErr> { + fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> { let key = key_for_counter(counter); self.insert_or_update(&key, counter, delta)?; Ok(()) @@ -42,7 +42,7 @@ impl CounterStorage for RocksDbStorage { fn check_and_update( &self, counters: &mut Vec, - delta: i64, + delta: u64, load_counters: bool, ) -> Result { let mut keys: Vec> = Vec::with_capacity(counters.len()); @@ -66,7 +66,12 @@ impl CounterStorage for RocksDbStorage { if load_counters { counter.set_expires_in(ttl); - counter.set_remaining(counter.max_value() - val - delta); + counter.set_remaining( + counter + .max_value() + .checked_sub(val + delta) + .unwrap_or_default(), + ); } if counter.max_value() < val + delta { @@ -192,7 +197,7 @@ impl RocksDbStorage { &self, key: &[u8], counter: &Counter, - delta: i64, + delta: u64, ) -> Result { let now = SystemTime::now(); let entry = { diff --git a/limitador/src/storage/in_memory.rs b/limitador/src/storage/in_memory.rs index 35f2a681..10ad1f6a 100644 --- a/limitador/src/storage/in_memory.rs +++ b/limitador/src/storage/in_memory.rs @@ -18,7 +18,7 @@ pub struct InMemoryStorage { impl CounterStorage for InMemoryStorage { #[tracing::instrument(skip_all)] - fn is_within_limits(&self, counter: &Counter, delta: i64) -> Result { + fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result { let limits_by_namespace = self.limits_for_namespace.read().unwrap(); let mut value = 0; @@ -50,7 +50,7 @@ impl CounterStorage for InMemoryStorage { } #[tracing::instrument(skip_all)] - fn update_counter(&self, counter: &Counter, delta: i64) -> Result<(), StorageErr> { + fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> { let mut limits_by_namespace = self.limits_for_namespace.write().unwrap(); let now = SystemTime::now(); if counter.is_qualified() { @@ -97,7 +97,7 @@ impl CounterStorage for InMemoryStorage { fn check_and_update( &self, counters: &mut Vec, - delta: i64, + delta: u64, load_counters: bool, ) -> Result { let limits_by_namespace = self.limits_for_namespace.write().unwrap(); @@ -108,11 +108,11 @@ impl CounterStorage for InMemoryStorage { let now = SystemTime::now(); let mut process_counter = - |counter: &mut Counter, value: i64, delta: i64| -> Option { + |counter: &mut Counter, value: u64, delta: u64| -> Option { if load_counters { - let remaining = counter.max_value() - (value + delta); - counter.set_remaining(remaining); - if first_limited.is_none() && remaining < 0 { + let remaining = counter.max_value().checked_sub(value + delta); + counter.set_remaining(remaining.unwrap_or_default()); + if first_limited.is_none() && remaining.is_none() { first_limited = Some(Authorization::Limited( counter.limit().name().map(|n| n.to_owned()), )); @@ -278,7 +278,7 @@ impl InMemoryStorage { } } - fn counter_is_within_limits(counter: &Counter, current_val: Option<&i64>, delta: i64) -> bool { + fn counter_is_within_limits(counter: &Counter, current_val: Option<&u64>, delta: u64) -> bool { match current_val { Some(current_val) => current_val + delta <= counter.max_value(), None => counter.max_value() >= delta, diff --git a/limitador/src/storage/keys.rs b/limitador/src/storage/keys.rs index 61efd23e..7b8a3596 100644 --- a/limitador/src/storage/keys.rs +++ b/limitador/src/storage/keys.rs @@ -208,7 +208,7 @@ pub mod bin { .into_iter() .map(|(var, value)| (var.to_string(), value.to_string())) .collect(); - let limit = Limit::new(ns, i64::default(), seconds, conditions, map.keys()); + let limit = Limit::new(ns, u64::default(), seconds, conditions, map.keys()); Counter::new(limit, map) } diff --git a/limitador/src/storage/mod.rs b/limitador/src/storage/mod.rs index d00c14b7..4db70278 100644 --- a/limitador/src/storage/mod.rs +++ b/limitador/src/storage/mod.rs @@ -107,18 +107,18 @@ impl Storage { Ok(()) } - pub fn is_within_limits(&self, counter: &Counter, delta: i64) -> Result { + pub fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result { self.counters.is_within_limits(counter, delta) } - pub fn update_counter(&self, counter: &Counter, delta: i64) -> Result<(), StorageErr> { + pub fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> { self.counters.update_counter(counter, delta) } pub fn check_and_update( &self, counters: &mut Vec, - delta: i64, + delta: u64, load_counters: bool, ) -> Result { self.counters @@ -220,19 +220,19 @@ impl AsyncStorage { pub async fn is_within_limits( &self, counter: &Counter, - delta: i64, + delta: u64, ) -> Result { self.counters.is_within_limits(counter, delta).await } - pub async fn update_counter(&self, counter: &Counter, delta: i64) -> Result<(), StorageErr> { + pub async fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> { self.counters.update_counter(counter, delta).await } pub async fn check_and_update<'a>( &self, counters: &mut Vec, - delta: i64, + delta: u64, load_counters: bool, ) -> Result { self.counters @@ -255,13 +255,13 @@ impl AsyncStorage { } pub trait CounterStorage: Sync + Send { - fn is_within_limits(&self, counter: &Counter, delta: i64) -> Result; + fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result; fn add_counter(&self, limit: &Limit) -> Result<(), StorageErr>; - fn update_counter(&self, counter: &Counter, delta: i64) -> Result<(), StorageErr>; + fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr>; fn check_and_update( &self, counters: &mut Vec, - delta: i64, + delta: u64, load_counters: bool, ) -> Result; fn get_counters(&self, limits: &HashSet) -> Result, StorageErr>; @@ -271,12 +271,12 @@ pub trait CounterStorage: Sync + Send { #[async_trait] pub trait AsyncCounterStorage: Sync + Send { - async fn is_within_limits(&self, counter: &Counter, delta: i64) -> Result; - async fn update_counter(&self, counter: &Counter, delta: i64) -> Result<(), StorageErr>; + async fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result; + async fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr>; async fn check_and_update<'a>( &self, counters: &mut Vec, - delta: i64, + delta: u64, load_counters: bool, ) -> Result; async fn get_counters(&self, limits: HashSet) -> Result, StorageErr>; diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 800b3efe..a88ef332 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -7,7 +7,7 @@ use moka::sync::Cache; use std::collections::HashMap; use std::future::Future; use std::ops::Not; -use std::sync::atomic::{AtomicBool, AtomicI64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime}; use tokio::select; @@ -16,39 +16,39 @@ use tokio::sync::{Notify, Semaphore}; #[derive(Debug)] pub struct CachedCounterValue { value: AtomicExpiringValue, - initial_value: AtomicI64, + initial_value: AtomicU64, from_authority: AtomicBool, } impl CachedCounterValue { - pub fn from_authority(counter: &Counter, value: i64) -> Self { + pub fn from_authority(counter: &Counter, value: u64) -> Self { let now = SystemTime::now(); Self { value: AtomicExpiringValue::new(value, now + Duration::from_secs(counter.seconds())), - initial_value: AtomicI64::new(value), + initial_value: AtomicU64::new(value), from_authority: AtomicBool::new(true), } } - pub fn load_from_authority_asap(counter: &Counter, temp_value: i64) -> Self { + pub fn load_from_authority_asap(counter: &Counter, temp_value: u64) -> Self { let now = SystemTime::now(); Self { value: AtomicExpiringValue::new( temp_value, now + Duration::from_secs(counter.seconds()), ), - initial_value: AtomicI64::new(0), + initial_value: AtomicU64::new(0), from_authority: AtomicBool::new(false), } } - pub fn add_from_authority(&self, delta: i64, expire_at: SystemTime) { + pub fn add_from_authority(&self, delta: u64, expire_at: SystemTime) { self.value.add_and_set_expiry(delta, expire_at); self.initial_value.fetch_add(delta, Ordering::SeqCst); self.from_authority.store(true, Ordering::Release); } - pub fn delta(&self, counter: &Counter, delta: i64) -> i64 { + pub fn delta(&self, counter: &Counter, delta: u64) -> u64 { let value = self .value .update(delta, counter.seconds(), SystemTime::now()); @@ -60,24 +60,20 @@ impl CachedCounterValue { value } - pub fn pending_writes(&self) -> Result { + pub fn pending_writes(&self) -> Result { self.pending_writes_and_value().map(|(writes, _)| writes) } - pub fn pending_writes_and_value(&self) -> Result<(i64, i64), ()> { + pub fn pending_writes_and_value(&self) -> Result<(u64, u64), ()> { let start = self.initial_value.load(Ordering::SeqCst); let value = self.value.value_at(SystemTime::now()); let offset = if start == 0 { value } else { - let writes = value - start; - if writes >= 0 { - writes - } else { - // self.value expired, is now less than the writes of the previous window - // which have not yet been reset... it'll be 0, so treat it as such. - value - } + let writes = value.checked_sub(start); + // self.value expired, is now less than the writes of the previous window + // which have not yet been reset... it'll be 0, so treat it as such. + writes.unwrap_or(value) }; match self .initial_value @@ -103,15 +99,15 @@ impl CachedCounterValue { value - start == 0 } - pub fn hits(&self, _: &Counter) -> i64 { + pub fn hits(&self, _: &Counter) -> u64 { self.value.value_at(SystemTime::now()) } - pub fn remaining(&self, counter: &Counter) -> i64 { + pub fn remaining(&self, counter: &Counter) -> u64 { counter.max_value() - self.hits(counter) } - pub fn is_limited(&self, counter: &Counter, delta: i64) -> bool { + pub fn is_limited(&self, counter: &Counter, delta: u64) -> bool { self.hits(counter) as i128 + delta as i128 > counter.max_value() as i128 } @@ -250,8 +246,8 @@ impl CountersCache { pub fn apply_remote_delta( &self, counter: Counter, - redis_val: i64, - remote_deltas: i64, + redis_val: u64, + remote_deltas: u64, redis_expiry: i64, ) -> Arc { if redis_expiry > 0 { @@ -279,7 +275,7 @@ impl CountersCache { )) } - pub async fn increase_by(&self, counter: &Counter, delta: i64) { + pub async fn increase_by(&self, counter: &Counter, delta: u64) { let val = self.cache.get_with_by_ref(counter, || { if let Some(entry) = self.batcher.updates.get(counter) { entry.value().clone() @@ -607,7 +603,7 @@ mod tests { ); } - fn test_counter(max_val: i64, other_values: Option>) -> Counter { + fn test_counter(max_val: u64, other_values: Option>) -> Counter { let mut values = HashMap::new(); values.insert("app_id".to_string(), "1".to_string()); if let Some(overrides) = other_values { diff --git a/limitador/src/storage/redis/mod.rs b/limitador/src/storage/redis/mod.rs index 180d8dd6..2b5c14e2 100644 --- a/limitador/src/storage/redis/mod.rs +++ b/limitador/src/storage/redis/mod.rs @@ -33,7 +33,7 @@ impl From for StorageErr { pub fn is_limited( counters: &mut [Counter], - delta: i64, + delta: u64, script_res: Vec>, ) -> Option { let mut counter_vals: Vec> = vec![]; @@ -47,8 +47,10 @@ pub fn is_limited( let mut first_limited = None; for (i, counter) in counters.iter_mut().enumerate() { // remaining = max - (curr_val + delta) - let remaining = counter.max_value() - (counter_vals[i].unwrap_or(0) + delta); - counter.set_remaining(remaining); + let remaining = counter + .max_value() + .checked_sub((counter_vals[i].unwrap_or(0) as u64) + delta); + counter.set_remaining(remaining.unwrap_or_default()); let expires_in = counter_ttls_msecs[i] .map(|x| { if x >= 0 { @@ -60,7 +62,7 @@ pub fn is_limited( .unwrap_or(Duration::from_secs(counter.seconds())); counter.set_expires_in(expires_in); - if first_limited.is_none() && remaining < 0 { + if first_limited.is_none() && remaining.is_none() { first_limited = Some(Authorization::Limited( counter.limit().name().map(|n| n.to_owned()), )) diff --git a/limitador/src/storage/redis/redis_async.rs b/limitador/src/storage/redis/redis_async.rs index 9718da5d..3a353b72 100644 --- a/limitador/src/storage/redis/redis_async.rs +++ b/limitador/src/storage/redis/redis_async.rs @@ -32,7 +32,7 @@ pub struct AsyncRedisStorage { #[async_trait] impl AsyncCounterStorage for AsyncRedisStorage { #[tracing::instrument(skip_all)] - async fn is_within_limits(&self, counter: &Counter, delta: i64) -> Result { + async fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result { let mut con = self.conn_manager.clone(); match con @@ -40,13 +40,13 @@ impl AsyncCounterStorage for AsyncRedisStorage { .instrument(debug_span!("datastore")) .await? { - Some(val) => Ok(val + delta <= counter.max_value()), - None => Ok(counter.max_value() - delta >= 0), + Some(val) => Ok(u64::try_from(val).unwrap_or(0) + delta <= counter.max_value()), + None => Ok(counter.max_value().checked_sub(delta).is_some()), } } #[tracing::instrument(skip_all)] - async fn update_counter(&self, counter: &Counter, delta: i64) -> Result<(), StorageErr> { + async fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> { let mut con = self.conn_manager.clone(); redis::Script::new(SCRIPT_UPDATE_COUNTER) @@ -65,7 +65,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { async fn check_and_update<'a>( &self, counters: &mut Vec, - delta: i64, + delta: u64, load_counters: bool, ) -> Result { let mut con = self.conn_manager.clone(); @@ -99,8 +99,10 @@ impl AsyncCounterStorage for AsyncRedisStorage { for (i, counter) in counters.iter().enumerate() { // remaining = max - (curr_val + delta) - let remaining = counter.max_value() - (counter_vals[i].unwrap_or(0) + delta); - if remaining < 0 { + let remaining = counter + .max_value() + .checked_sub(u64::try_from(counter_vals[i].unwrap_or(0)).unwrap_or(0) + delta); + if remaining.is_none() { return Ok(Authorization::Limited( counter.limit().name().map(|n| n.to_owned()), )); @@ -153,13 +155,13 @@ impl AsyncCounterStorage for AsyncRedisStorage { .await? }; if let Some(val) = option { - counter.set_remaining(limit.max_value() - val); - let ttl = { + counter.set_remaining(limit.max_value() - u64::try_from(val).unwrap_or(0)); + let ttl: i64 = { con.ttl(&counter_key) .instrument(debug_span!("datastore")) .await? }; - counter.set_expires_in(Duration::from_secs(ttl)); + counter.set_expires_in(Duration::from_secs(u64::try_from(ttl).unwrap_or(0))); res.insert(counter); } diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index d3f094e6..38f5c831 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -48,14 +48,14 @@ pub struct CachedRedisStorage { #[async_trait] impl AsyncCounterStorage for CachedRedisStorage { #[tracing::instrument(skip_all)] - async fn is_within_limits(&self, counter: &Counter, delta: i64) -> Result { + async fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result { self.async_redis_storage .is_within_limits(counter, delta) .await } #[tracing::instrument(skip_all)] - async fn update_counter(&self, counter: &Counter, delta: i64) -> Result<(), StorageErr> { + async fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> { self.async_redis_storage .update_counter(counter, delta) .await @@ -69,7 +69,7 @@ impl AsyncCounterStorage for CachedRedisStorage { async fn check_and_update<'a>( &self, counters: &mut Vec, - delta: i64, + delta: u64, load_counters: bool, ) -> Result { let mut not_cached: Vec<&mut Counter> = vec![]; @@ -88,7 +88,11 @@ impl AsyncCounterStorage for CachedRedisStorage { first_limited = Some(a); } if load_counters { - counter.set_remaining(val.remaining(counter) - delta); + counter.set_remaining( + val.remaining(counter) + .checked_sub(delta) + .unwrap_or_default(), + ); counter.set_expires_in(val.to_next_window()); } } @@ -103,7 +107,7 @@ impl AsyncCounterStorage for CachedRedisStorage { for counter in not_cached.iter_mut() { let fake = CachedCounterValue::load_from_authority_asap(counter, 0); let remaining = fake.remaining(counter); - if first_limited.is_none() && remaining <= 0 { + if first_limited.is_none() && remaining == 0 { first_limited = Some(Authorization::Limited( counter.limit().name().map(|n| n.to_owned()), )); @@ -280,14 +284,14 @@ impl CachedRedisStorageBuilder { async fn update_counters( redis_conn: &mut C, counters_and_deltas: HashMap>, -) -> Result, StorageErr> { +) -> Result, StorageErr> { let redis_script = redis::Script::new(BATCH_UPDATE_COUNTERS); let mut script_invocation = redis_script.prepare_invoke(); let res = if counters_and_deltas.is_empty() { Default::default() } else { - let mut res: Vec<(Counter, i64, i64, i64)> = Vec::with_capacity(counters_and_deltas.len()); + let mut res: Vec<(Counter, u64, u64, i64)> = Vec::with_capacity(counters_and_deltas.len()); for (counter, value) in counters_and_deltas { let (delta, last_value_from_redis) = value @@ -316,8 +320,10 @@ async fn update_counters( for (i, j) in counters_range.zip(script_res_range) { let (_, val, delta, expires_at) = &mut res[i]; - *val = script_res[j]; - *delta = script_res[j] - *delta; + *val = u64::try_from(script_res[j]).unwrap_or(0); + *delta = u64::try_from(script_res[j]) + .unwrap_or(0) + .saturating_sub(*delta); *expires_at = script_res[j + 1]; } res @@ -396,9 +402,9 @@ mod tests { #[tokio::test] async fn batch_update_counters() { - const NEW_VALUE_FROM_REDIS: i64 = 10; - const INITIAL_VALUE_FROM_REDIS: i64 = 1; - const LOCAL_INCREMENTS: i64 = 2; + const NEW_VALUE_FROM_REDIS: u64 = 10; + const INITIAL_VALUE_FROM_REDIS: u64 = 1; + const LOCAL_INCREMENTS: u64 = 2; let mut counters_and_deltas = HashMap::new(); let counter = Counter::new( @@ -424,7 +430,7 @@ mod tests { .duration_since(UNIX_EPOCH) .unwrap(); let mock_response = Value::Bulk(vec![ - Value::Int(NEW_VALUE_FROM_REDIS), + Value::Int(NEW_VALUE_FROM_REDIS as i64), Value::Int(one_sec_from_now.as_millis() as i64), ]); diff --git a/limitador/src/storage/redis/redis_sync.rs b/limitador/src/storage/redis/redis_sync.rs index cf7088f6..11113de1 100644 --- a/limitador/src/storage/redis/redis_sync.rs +++ b/limitador/src/storage/redis/redis_sync.rs @@ -24,12 +24,12 @@ pub struct RedisStorage { impl CounterStorage for RedisStorage { #[tracing::instrument(skip_all)] - fn is_within_limits(&self, counter: &Counter, delta: i64) -> Result { + fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result { let mut con = self.conn_pool.get()?; match con.get::>(key_for_counter(counter))? { - Some(val) => Ok(val + delta <= counter.max_value()), - None => Ok(counter.max_value() - delta >= 0), + Some(val) => Ok(u64::try_from(val).unwrap_or(0) + delta <= counter.max_value()), + None => Ok(counter.max_value().checked_sub(delta).is_some()), } } @@ -39,7 +39,7 @@ impl CounterStorage for RedisStorage { } #[tracing::instrument(skip_all)] - fn update_counter(&self, counter: &Counter, delta: i64) -> Result<(), StorageErr> { + fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> { let mut con = self.conn_pool.get()?; redis::Script::new(SCRIPT_UPDATE_COUNTER) @@ -56,7 +56,7 @@ impl CounterStorage for RedisStorage { fn check_and_update( &self, counters: &mut Vec, - delta: i64, + delta: u64, load_counters: bool, ) -> Result { let mut con = self.conn_pool.get()?; @@ -80,8 +80,10 @@ impl CounterStorage for RedisStorage { for (i, counter) in counters.iter().enumerate() { // remaining = max - (curr_val + delta) - let remaining = counter.max_value() - (counter_vals[i].unwrap_or(0) + delta); - if remaining < 0 { + let remaining = counter + .max_value() + .checked_sub(u64::try_from(counter_vals[i].unwrap_or(0)).unwrap_or(0) + delta); + if remaining.is_none() { return Ok(Authorization::Limited( counter.limit().name().map(|n| n.to_owned()), )); @@ -124,7 +126,11 @@ impl CounterStorage for RedisStorage { // This does not cause any bugs, but consumes memory // unnecessarily. if let Some(val) = con.get::>(counter_key.clone())? { - counter.set_remaining(limit.max_value() - val); + counter.set_remaining( + limit + .max_value() + .saturating_sub(u64::try_from(val).unwrap_or(0)), + ); let ttl = con.ttl(&counter_key)?; counter.set_expires_in(Duration::from_secs(ttl)); diff --git a/limitador/tests/helpers/tests_limiter.rs b/limitador/tests/helpers/tests_limiter.rs index 7e77fdfc..b7bc4ca9 100644 --- a/limitador/tests/helpers/tests_limiter.rs +++ b/limitador/tests/helpers/tests_limiter.rs @@ -71,7 +71,7 @@ impl TestsLimiter { &self, namespace: &str, values: &HashMap, - delta: i64, + delta: u64, ) -> Result { match &self.limiter_impl { LimiterImpl::Blocking(limiter) => { @@ -89,7 +89,7 @@ impl TestsLimiter { &self, namespace: &str, values: &HashMap, - delta: i64, + delta: u64, ) -> Result<(), LimitadorError> { match &self.limiter_impl { LimiterImpl::Blocking(limiter) => { @@ -107,7 +107,7 @@ impl TestsLimiter { &self, namespace: &str, values: &HashMap, - delta: i64, + delta: u64, load_counters: bool, ) -> Result { match &self.limiter_impl { diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index 276df09a..f14d8f95 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -748,7 +748,7 @@ mod test { if let Some(ttl) = counter.expires_in() { assert!(ttl.as_secs() <= 60); } - assert_eq!(counter.remaining().unwrap(), -1); + assert_eq!(counter.remaining().unwrap(), 0); } }