diff --git a/limitador-server/src/envoy_rls/server.rs b/limitador-server/src/envoy_rls/server.rs index 8f252ce5..d9925a68 100644 --- a/limitador-server/src/envoy_rls/server.rs +++ b/limitador-server/src/envoy_rls/server.rs @@ -109,10 +109,12 @@ impl RateLimitService for MyRateLimiter { req.hits_addend }; + let ctx = (&values).into(); + let rate_limited_resp = match &*self.limiter { Limiter::Blocking(limiter) => limiter.check_rate_limited_and_update( &namespace, - &values, + &ctx, u64::from(hits_addend), self.rate_limit_headers != RateLimitHeaders::None, ), @@ -120,7 +122,7 @@ impl RateLimitService for MyRateLimiter { limiter .check_rate_limited_and_update( &namespace, - &values, + &ctx, u64::from(hits_addend), self.rate_limit_headers != RateLimitHeaders::None, ) diff --git a/limitador-server/src/http_api/server.rs b/limitador-server/src/http_api/server.rs index d436e2a0..ddbd0d64 100644 --- a/limitador-server/src/http_api/server.rs +++ b/limitador-server/src/http_api/server.rs @@ -122,9 +122,10 @@ async fn check( response_headers: _, } = request.into_inner(); let namespace = namespace.into(); + let ctx = (&values).into(); let is_rate_limited_result = match state.get_ref().limiter() { - Limiter::Blocking(limiter) => limiter.is_rate_limited(&namespace, &values, delta), - Limiter::Async(limiter) => limiter.is_rate_limited(&namespace, &values, delta).await, + Limiter::Blocking(limiter) => limiter.is_rate_limited(&namespace, &ctx, delta), + Limiter::Async(limiter) => limiter.is_rate_limited(&namespace, &ctx, delta).await, }; match is_rate_limited_result { @@ -152,9 +153,10 @@ async fn report( response_headers: _, } = request.into_inner(); let namespace = namespace.into(); + let ctx = (&values).into(); let update_counters_result = match data.get_ref().limiter() { - Limiter::Blocking(limiter) => limiter.update_counters(&namespace, &values, delta), - Limiter::Async(limiter) => limiter.update_counters(&namespace, &values, delta).await, + Limiter::Blocking(limiter) => limiter.update_counters(&namespace, &ctx, delta), + Limiter::Async(limiter) => limiter.update_counters(&namespace, &ctx, delta).await, }; match update_counters_result { @@ -176,22 +178,18 @@ async fn check_and_report( response_headers, } = request.into_inner(); let namespace = namespace.into(); + let ctx = (&values).into(); let rate_limit_data = data.get_ref(); let rate_limited_and_update_result = match rate_limit_data.limiter() { Limiter::Blocking(limiter) => limiter.check_rate_limited_and_update( &namespace, - &values, + &ctx, delta, response_headers.is_some(), ), Limiter::Async(limiter) => { limiter - .check_rate_limited_and_update( - &namespace, - &values, - delta, - response_headers.is_some(), - ) + .check_rate_limited_and_update(&namespace, &ctx, delta, response_headers.is_some()) .await } }; diff --git a/limitador/benches/bench.rs b/limitador/benches/bench.rs index b6c0cf70..68a360b9 100644 --- a/limitador/benches/bench.rs +++ b/limitador/benches/bench.rs @@ -329,7 +329,7 @@ fn bench_is_rate_limited( rate_limiter .is_rate_limited( ¶ms.namespace.to_owned().into(), - ¶ms.values, + &(¶ms.values).into(), params.delta, ) .unwrap(), @@ -357,7 +357,7 @@ fn async_bench_is_rate_limited( rate_limiter .is_rate_limited( ¶ms.namespace.to_owned().into(), - ¶ms.values, + &(¶ms.values).into(), params.delta, ) .await @@ -383,7 +383,7 @@ fn bench_update_counters( rate_limiter .update_counters( ¶ms.namespace.to_owned().into(), - ¶ms.values, + &(¶ms.values).into(), params.delta, ) .unwrap(); @@ -410,7 +410,7 @@ fn async_bench_update_counters( rate_limiter .update_counters( ¶ms.namespace.to_owned().into(), - ¶ms.values, + &(¶ms.values).into(), params.delta, ) .await @@ -437,7 +437,7 @@ fn bench_check_rate_limited_and_update( rate_limiter .check_rate_limited_and_update( ¶ms.namespace.to_owned().into(), - ¶ms.values, + &(¶ms.values).into(), params.delta, false, ) @@ -467,7 +467,7 @@ fn async_bench_check_rate_limited_and_update( rate_limiter .check_rate_limited_and_update( ¶ms.namespace.to_owned().into(), - ¶ms.values, + &(¶ms.values).into(), params.delta, false, ) diff --git a/limitador/src/counter.rs b/limitador/src/counter.rs index 788bf8a7..68453a8e 100644 --- a/limitador/src/counter.rs +++ b/limitador/src/counter.rs @@ -1,4 +1,4 @@ -use crate::limit::{Limit, Namespace}; +use crate::limit::{Context, Limit, Namespace}; use crate::LimitadorResult; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; @@ -17,15 +17,9 @@ pub struct Counter { } impl Counter { - pub fn new>>( - limit: L, - set_variables: HashMap, - ) -> LimitadorResult> { + pub fn new>>(limit: L, ctx: &Context) -> LimitadorResult> { let limit = limit.into(); - let mut vars = set_variables; - vars.retain(|var, _| limit.has_variable(var)); - - let variables = limit.resolve_variables(vars)?; + let variables = limit.resolve_variables(ctx)?; match variables { None => Ok(None), Some(variables) => Ok(Some(Self { @@ -159,11 +153,9 @@ mod tests { Vec::default(), [var.try_into().expect("failed parsing!")], ); - let counter = Counter::new( - limit, - HashMap::from([("ts".to_string(), "2019-10-12T13:20:50.52Z".to_string())]), - ) - .expect("failed creating counter"); + let map = HashMap::from([("ts".to_string(), "2019-10-12T13:20:50.52Z".to_string())]); + let ctx = (&map).into(); + let counter = Counter::new(limit, &ctx).expect("failed creating counter"); assert_eq!( counter.unwrap().set_variables.get(var), Some("13".to_string()).as_ref() diff --git a/limitador/src/lib.rs b/limitador/src/lib.rs index d5016a18..6902028c 100644 --- a/limitador/src/lib.rs +++ b/limitador/src/lib.rs @@ -94,7 +94,7 @@ //! //! ``` //! use limitador::RateLimiter; -//! use limitador::limit::Limit; +//! use limitador::limit::{Limit, Context}; //! use std::collections::HashMap; //! //! let mut rate_limiter = RateLimiter::new(1000); @@ -116,22 +116,23 @@ //! //! // Check if we can report //! let namespace = "my_namespace".into(); -//! assert!(!rate_limiter.is_rate_limited(&namespace, &values_to_report, 1).unwrap()); +//! let ctx = &values_to_report.into(); +//! assert!(!rate_limiter.is_rate_limited(&namespace, &ctx, 1).unwrap()); //! //! // Report -//! rate_limiter.update_counters(&namespace, &values_to_report, 1).unwrap(); +//! rate_limiter.update_counters(&namespace, &ctx, 1).unwrap(); //! //! // Check and report again -//! assert!(!rate_limiter.is_rate_limited(&namespace, &values_to_report, 1).unwrap()); -//! rate_limiter.update_counters(&namespace, &values_to_report, 1).unwrap(); +//! assert!(!rate_limiter.is_rate_limited(&namespace, &ctx, 1).unwrap()); +//! rate_limiter.update_counters(&namespace, &ctx, 1).unwrap(); //! //! // We've already reported 2, so reporting another one should not be allowed -//! assert!(rate_limiter.is_rate_limited(&namespace, &values_to_report, 1).unwrap()); +//! assert!(rate_limiter.is_rate_limited(&namespace, &ctx, 1).unwrap()); //! //! // You can also check and report if not limited in a single call. It's useful //! // for example, when calling Limitador from a proxy. Instead of doing 2 //! // separate calls, we can issue just one: -//! rate_limiter.check_rate_limited_and_update(&namespace, &values_to_report, 1, false).unwrap(); +//! rate_limiter.check_rate_limited_and_update(&namespace, &ctx, 1, false).unwrap(); //! ``` //! //! # Async @@ -194,7 +195,7 @@ use crate::counter::Counter; use crate::errors::LimitadorError; -use crate::limit::{Limit, Namespace}; +use crate::limit::{Context, Limit, Namespace}; use crate::storage::in_memory::InMemoryStorage; use crate::storage::{AsyncCounterStorage, AsyncStorage, Authorization, CounterStorage, Storage}; use std::collections::{HashMap, HashSet}; @@ -358,7 +359,7 @@ impl RateLimiter { pub fn is_rate_limited( &self, namespace: &Namespace, - values: &HashMap, + values: &Context, delta: u64, ) -> LimitadorResult { let counters = self.counters_that_apply(namespace, values)?; @@ -380,10 +381,10 @@ impl RateLimiter { pub fn update_counters( &self, namespace: &Namespace, - values: &HashMap, + ctx: &Context, delta: u64, ) -> LimitadorResult<()> { - let counters = self.counters_that_apply(namespace, values)?; + let counters = self.counters_that_apply(namespace, ctx)?; counters .iter() @@ -394,11 +395,11 @@ impl RateLimiter { pub fn check_rate_limited_and_update( &self, namespace: &Namespace, - values: &HashMap, + ctx: &Context, delta: u64, load_counters: bool, ) -> LimitadorResult { - let mut counters = self.counters_that_apply(namespace, values)?; + let mut counters = self.counters_that_apply(namespace, ctx)?; if counters.is_empty() { return Ok(CheckResult { @@ -476,14 +477,13 @@ impl RateLimiter { fn counters_that_apply( &self, namespace: &Namespace, - values: &HashMap, + ctx: &Context, ) -> LimitadorResult> { let limits = self.storage.get_limits(namespace); - let ctx = values.into(); limits .iter() - .filter(|lim| lim.applies(&ctx)) - .filter_map(|lim| match Counter::new(Arc::clone(lim), values.clone()) { + .filter(|lim| lim.applies(ctx)) + .filter_map(|lim| match Counter::new(Arc::clone(lim), ctx) { Ok(None) => None, Ok(Some(c)) => Some(Ok(c)), Err(e) => Some(Err(e)), @@ -533,10 +533,10 @@ impl AsyncRateLimiter { pub async fn is_rate_limited( &self, namespace: &Namespace, - values: &HashMap, + ctx: &Context<'_>, delta: u64, ) -> LimitadorResult { - let counters = self.counters_that_apply(namespace, values).await?; + let counters = self.counters_that_apply(namespace, ctx).await?; for counter in counters { match self.storage.is_within_limits(&counter, delta).await { @@ -554,10 +554,10 @@ impl AsyncRateLimiter { pub async fn update_counters( &self, namespace: &Namespace, - values: &HashMap, + ctx: &Context<'_>, delta: u64, ) -> LimitadorResult<()> { - let counters = self.counters_that_apply(namespace, values).await?; + let counters = self.counters_that_apply(namespace, ctx).await?; for counter in counters { self.storage.update_counter(&counter, delta).await? @@ -569,12 +569,12 @@ impl AsyncRateLimiter { pub async fn check_rate_limited_and_update( &self, namespace: &Namespace, - values: &HashMap, + ctx: &Context<'_>, delta: u64, load_counters: bool, ) -> LimitadorResult { // the above where-clause is needed in order to call unwrap(). - let mut counters = self.counters_that_apply(namespace, values).await?; + let mut counters = self.counters_that_apply(namespace, ctx).await?; if counters.is_empty() { return Ok(CheckResult { @@ -657,14 +657,13 @@ impl AsyncRateLimiter { async fn counters_that_apply( &self, namespace: &Namespace, - values: &HashMap, + ctx: &Context<'_>, ) -> LimitadorResult> { let limits = self.storage.get_limits(namespace); - let ctx = values.into(); limits .iter() - .filter(|lim| lim.applies(&ctx)) - .filter_map(|lim| match Counter::new(Arc::clone(lim), values.clone()) { + .filter(|lim| lim.applies(ctx)) + .filter_map(|lim| match Counter::new(Arc::clone(lim), ctx) { Ok(None) => None, Ok(Some(c)) => Some(Ok(c)), Err(e) => Some(Err(e)), @@ -696,9 +695,8 @@ fn classify_limits_by_namespace( #[cfg(test)] mod test { - use crate::limit::{Expression, Limit}; + use crate::limit::{Context, Expression, Limit}; use crate::RateLimiter; - use std::collections::HashMap; #[test] fn properly_updates_existing_limits() { @@ -713,7 +711,7 @@ mod test { assert_eq!(limits.iter().next().unwrap().max_value(), 42); let r = rl - .check_rate_limited_and_update(&namespace.into(), &HashMap::default(), 1, true) + .check_rate_limited_and_update(&namespace.into(), &Context::default(), 1, true) .unwrap(); assert_eq!(r.counters.first().unwrap().max_value(), 42); @@ -727,7 +725,7 @@ mod test { assert_eq!(limits.iter().next().unwrap().max_value(), 50); let r = rl - .check_rate_limited_and_update(&namespace.into(), &HashMap::default(), 1, true) + .check_rate_limited_and_update(&namespace.into(), &Context::default(), 1, true) .unwrap(); assert_eq!(r.counters.first().unwrap().max_value(), 50); } diff --git a/limitador/src/limit.rs b/limitador/src/limit.rs index 6653cb07..d6102234 100644 --- a/limitador/src/limit.rs +++ b/limitador/src/limit.rs @@ -1,14 +1,13 @@ -use cel::Context; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; -use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashSet}; use std::fmt::Debug; use std::hash::{Hash, Hasher}; mod cel; +pub use cel::{Context, Expression, Predicate}; pub use cel::{EvaluationError, ParseError}; -pub use cel::{Expression, Predicate}; #[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord, Serialize, Deserialize)] pub struct Namespace(String); @@ -133,13 +132,12 @@ impl Limit { pub fn resolve_variables( &self, - vars: HashMap, + ctx: &Context, ) -> Result>, EvaluationError> { - let ctx = Context::new(String::default(), &vars); let mut map = BTreeMap::new(); for variable in &self.variables { let name = variable.source().into(); - match variable.eval(&ctx)? { + match variable.eval(ctx)? { None => return Ok(None), Some(value) => { map.insert(name, value); @@ -230,6 +228,7 @@ mod tests { use super::*; use crate::counter::Counter; use std::cmp::Ordering::Equal; + use std::collections::HashMap; #[test] fn limit_can_have_an_optional_name() { @@ -466,7 +465,7 @@ mod tests { let ctx = Context::new(String::default(), &map); assert!(limit.applies(&ctx)); assert_eq!( - Counter::new(limit, map) + Counter::new(limit, &ctx) .expect("failed") .unwrap() .set_variables() diff --git a/limitador/src/storage/disk/rocksdb_storage.rs b/limitador/src/storage/disk/rocksdb_storage.rs index 356aaaaf..0d758c27 100644 --- a/limitador/src/storage/disk/rocksdb_storage.rs +++ b/limitador/src/storage/disk/rocksdb_storage.rs @@ -249,12 +249,11 @@ mod tests { vec!["req_method == 'GET'".try_into().expect("failed parsing!")], vec!["app_id".try_into().expect("failed parsing!")], ); - let counter = Counter::new( - limit, - HashMap::from([("app_id".to_string(), "foo".to_string())]), - ) - .unwrap() - .expect("must have a counter"); + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = (&map).into(); + let counter = Counter::new(limit, &ctx) + .unwrap() + .expect("must have a counter"); let tmp = TempDir::new().expect("We should have a dir!"); { diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index aee13fa2..1c8f23c7 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -9,7 +9,7 @@ use tokio::sync::mpsc::Sender; use tracing::debug; use crate::counter::Counter; -use crate::limit::Limit; +use crate::limit::{Context, Limit}; use crate::storage::distributed::cr_counter_value::CrCounterValue; use crate::storage::distributed::grpc::v1::CounterUpdate; use crate::storage::distributed::grpc::{Broker, CounterEntry}; @@ -47,7 +47,7 @@ impl CounterStorage for CrInMemoryStorage { let key = encode_limit_to_key(limit); limits.entry(key.clone()).or_insert(Arc::new(CounterEntry { key, - counter: Counter::new(limit.clone(), HashMap::default()) + counter: Counter::new(limit.clone(), &Context::default()) .expect("counter creation can't fail! no vars to resolve!") .expect("must have a counter"), value: CrCounterValue::new( @@ -336,15 +336,14 @@ fn encode_counter_to_key(counter: &Counter) -> Vec { fn encode_limit_to_key(limit: &Limit) -> Vec { // fixme this is broken! - let counter = Counter::new( - limit.clone(), - limit - .variables() - .into_iter() - .map(|k| (k, "".to_string())) - .collect(), - ) - .expect("counter creation can't fail! faked vars!") - .expect("must have a counter"); + let vars: HashMap = limit + .variables() + .into_iter() + .map(|k| (k, "".to_string())) + .collect(); + let ctx = (&vars).into(); + let counter = Counter::new(limit.clone(), &ctx) + .expect("counter creation can't fail! faked vars!") + .expect("must have a counter"); key_for_counter_v2(&counter) } diff --git a/limitador/src/storage/in_memory.rs b/limitador/src/storage/in_memory.rs index 3b273d9c..6d741f65 100644 --- a/limitador/src/storage/in_memory.rs +++ b/limitador/src/storage/in_memory.rs @@ -1,5 +1,5 @@ use crate::counter::Counter; -use crate::limit::{Limit, Namespace}; +use crate::limit::{Context, Limit, Namespace}; use crate::storage::atomic_expiring_value::AtomicExpiringValue; use crate::storage::{Authorization, CounterStorage, StorageErr}; use moka::sync::Cache; @@ -212,7 +212,7 @@ impl InMemoryStorage { if limit.namespace() == namespace { res.insert( // todo fixme - Counter::new(limit.clone(), HashMap::default()) + Counter::new(limit.clone(), &Context::default()) .unwrap() .unwrap(), counter.clone(), @@ -269,18 +269,14 @@ mod tests { vec!["req_method == 'GET'".try_into().expect("failed parsing!")], vec!["app_id".try_into().expect("failed parsing!")], ); - let counter_1 = Counter::new( - limit_1, - HashMap::from([("app_id".to_string(), "foo".to_string())]), - ) - .expect("counter creation failed!") - .expect("Should have a counter"); - let counter_2 = Counter::new( - limit_2, - HashMap::from([("app_id".to_string(), "foo".to_string())]), - ) - .expect("counter creation failed!") - .expect("Should have a counter"); + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = (&map).into(); + let counter_1 = Counter::new(limit_1, &ctx) + .expect("counter creation failed!") + .expect("Should have a counter"); + let counter_2 = Counter::new(limit_2, &ctx) + .expect("counter creation failed!") + .expect("Should have a counter"); storage.update_counter(&counter_1, 1).unwrap(); storage.update_counter(&counter_2, 1).unwrap(); diff --git a/limitador/src/storage/keys.rs b/limitador/src/storage/keys.rs index 58506328..247eb2f6 100644 --- a/limitador/src/storage/keys.rs +++ b/limitador/src/storage/keys.rs @@ -153,12 +153,11 @@ mod tests { vec!["req_method == 'GET'".try_into().expect("failed parsing!")], vec!["app_id".try_into().expect("failed parsing!")], ); - let counter = Counter::new( - limit.clone(), - HashMap::from([("app_id".to_string(), "foo".to_string())]), - ) - .expect("counter creation failed!") - .expect("must have a counter"); + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = (&map).into(); + let counter = Counter::new(limit.clone(), &ctx) + .expect("counter creation failed!") + .expect("must have a counter"); let raw = key_for_counter(&counter); assert_eq!(counter, partial_counter_from_counter_key(&raw)); } @@ -173,12 +172,11 @@ mod tests { vec!["req_method == 'GET'".try_into().expect("failed parsing!")], vec!["app_id".try_into().expect("failed parsing!")], ); - let counter = Counter::new( - limit.clone(), - HashMap::from([("app_id".to_string(), "foo".to_string())]), - ) - .expect("counter creation failed!") - .expect("must have a counter"); + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = (&map).into(); + let counter = Counter::new(limit.clone(), &ctx) + .expect("counter creation failed!") + .expect("must have a counter"); let mut other = counter.clone(); other.set_remaining(123); other.set_expires_in(Duration::from_millis(456)); @@ -389,7 +387,8 @@ pub mod bin { vars.insert("role".to_string(), "admin".to_string()); vars.insert("app_id".to_string(), "123".to_string()); vars.insert("wat".to_string(), "dunno".to_string()); - let counter = Counter::new(limit.clone(), vars) + let ctx = (&vars).into(); + let counter = Counter::new(limit.clone(), &ctx) .expect("counter creation failed!") .expect("must have a counter"); @@ -412,7 +411,8 @@ pub mod bin { ); let mut variables = HashMap::default(); variables.insert("app_id".to_string(), "123".to_string()); - let counter = Counter::new(limit.clone(), variables) + let ctx = (&variables).into(); + let counter = Counter::new(limit.clone(), &ctx) .expect("counter creation failed!") .expect("must have a counter"); let raw = key_for_counter(&counter); @@ -429,12 +429,11 @@ pub mod bin { vec!["req_method == 'GET'".try_into().expect("failed parsing!")], vec!["app_id".try_into().expect("failed parsing!")], ); - let counter = Counter::new( - limit, - HashMap::from([("app_id".to_string(), "foo".to_string())]), - ) - .expect("counter creation failed!") - .expect("must have a counter"); + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = (&map).into(); + let counter = Counter::new(limit, &ctx) + .expect("counter creation failed!") + .expect("must have a counter"); let serialized_counter = key_for_counter(&counter); let prefix = prefix_for_namespace(namespace); @@ -460,20 +459,16 @@ pub mod bin { vec!["app_id".try_into().expect("failed parsing!")], ); - let counter_with_id = Counter::new( - limit_with_id, - HashMap::from([("app_id".to_string(), "foo".to_string())]), - ) - .expect("counter creation failed!") - .expect("must have a counter"); + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = (&map).into(); + let counter_with_id = Counter::new(limit_with_id, &ctx) + .expect("counter creation failed!") + .expect("must have a counter"); let serialized_with_id_counter = key_for_counter(&counter_with_id); - let counter_without_id = Counter::new( - limit_without_id, - HashMap::from([("app_id".to_string(), "foo".to_string())]), - ) - .expect("counter creation failed!") - .expect("must have a counter"); + let counter_without_id = Counter::new(limit_without_id, &ctx) + .expect("counter creation failed!") + .expect("must have a counter"); let serialized_without_id_counter = key_for_counter(&counter_without_id); // the original key_for_counter continues to encode kinda big diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 03cdc900..fb05300e 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -676,6 +676,7 @@ mod tests { if let Some(overrides) = other_values { values.extend(overrides); } + let ctx = (&values).into(); Counter::new( Limit::new( "test_namespace", @@ -684,7 +685,7 @@ mod tests { vec!["req_method == 'POST'".try_into().expect("failed parsing!")], vec!["app_id".try_into().expect("failed parsing!")], ), - values, + &ctx, ) .expect("failed creating counter") .expect("Should have a counter") diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 70aaac75..0cc0e99e 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -439,6 +439,8 @@ mod tests { const LOCAL_INCREMENTS: u64 = 2; let mut counters_and_deltas = HashMap::new(); + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = (&map).into(); let counter = Counter::new( Limit::new( "test_namespace", @@ -447,7 +449,7 @@ mod tests { vec!["req_method == 'GET'".try_into().expect("failed parsing!")], vec!["app_id".try_into().expect("failed parsing!")], ), - HashMap::from([("app_id".to_string(), "foo".to_string())]), + &ctx, ) .expect("counter creation failed!") .expect("must have a counter"); @@ -503,6 +505,8 @@ mod tests { #[tokio::test] async fn flush_batcher_and_update_counters_test() { + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = (&map).into(); let counter = Counter::new( Limit::new( "test_namespace", @@ -511,7 +515,7 @@ mod tests { vec!["req_method == 'POST'".try_into().expect("failed parsing!")], vec!["app_id".try_into().expect("failed parsing!")], ), - HashMap::from([("app_id".to_string(), "foo".to_string())]), + &ctx, ) .expect("counter creation failed!") .expect("must have a counter"); @@ -564,6 +568,8 @@ mod tests { #[tokio::test] async fn flush_batcher_reverts_on_err() { + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = (&map).into(); let counter = Counter::new( Limit::new( "test_namespace", @@ -572,7 +578,7 @@ mod tests { vec!["req_method == 'POST'".try_into().expect("failed parsing!")], vec!["app_id".try_into().expect("failed parsing!")], ), - HashMap::from([("app_id".to_string(), "foo".to_string())]), + &ctx, ) .expect("counter creation failed!") .expect("must have a counter"); diff --git a/limitador/tests/helpers/tests_limiter.rs b/limitador/tests/helpers/tests_limiter.rs index 2bae0c3e..89e631b1 100644 --- a/limitador/tests/helpers/tests_limiter.rs +++ b/limitador/tests/helpers/tests_limiter.rs @@ -1,8 +1,8 @@ use limitador::counter::Counter; use limitador::errors::LimitadorError; -use limitador::limit::{Limit, Namespace}; +use limitador::limit::{Context, Limit, Namespace}; use limitador::{AsyncRateLimiter, CheckResult, RateLimiter}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; // This exposes a struct that wraps both implementations of the rate limiter, // the blocking and the async one. This allows us to avoid duplications in the @@ -70,17 +70,15 @@ impl TestsLimiter { pub async fn is_rate_limited( &self, namespace: &str, - values: &HashMap, + ctx: &Context<'_>, delta: u64, ) -> Result { match &self.limiter_impl { LimiterImpl::Blocking(limiter) => { - limiter.is_rate_limited(&namespace.into(), values, delta) + limiter.is_rate_limited(&namespace.into(), ctx, delta) } LimiterImpl::Async(limiter) => { - limiter - .is_rate_limited(&namespace.into(), values, delta) - .await + limiter.is_rate_limited(&namespace.into(), ctx, delta).await } } } @@ -88,17 +86,15 @@ impl TestsLimiter { pub async fn update_counters( &self, namespace: &str, - values: &HashMap, + ctx: &Context<'_>, delta: u64, ) -> Result<(), LimitadorError> { match &self.limiter_impl { LimiterImpl::Blocking(limiter) => { - limiter.update_counters(&namespace.into(), values, delta) + limiter.update_counters(&namespace.into(), ctx, delta) } LimiterImpl::Async(limiter) => { - limiter - .update_counters(&namespace.into(), values, delta) - .await + limiter.update_counters(&namespace.into(), ctx, delta).await } } } @@ -106,20 +102,17 @@ impl TestsLimiter { pub async fn check_rate_limited_and_update( &self, namespace: &str, - values: &HashMap, + ctx: &Context<'_>, delta: u64, load_counters: bool, ) -> Result { match &self.limiter_impl { - LimiterImpl::Blocking(limiter) => limiter.check_rate_limited_and_update( - &namespace.into(), - values, - delta, - load_counters, - ), + LimiterImpl::Blocking(limiter) => { + limiter.check_rate_limited_and_update(&namespace.into(), ctx, delta, load_counters) + } LimiterImpl::Async(limiter) => { limiter - .check_rate_limited_and_update(&namespace.into(), values, delta, load_counters) + .check_rate_limited_and_update(&namespace.into(), ctx, delta, load_counters) .await } } diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index 7a71372c..4f8f2d4a 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -380,7 +380,7 @@ mod test { values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "1".to_string()); rate_limiter - .update_counters(namespace, &values, 1) + .update_counters(namespace, &(&values).into(), 1) .await .unwrap(); @@ -473,7 +473,7 @@ mod test { values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "1".to_string()); rate_limiter - .update_counters(namespace, &values, 1) + .update_counters(namespace, &(&values).into(), 1) .await .unwrap(); @@ -506,22 +506,23 @@ mod test { let mut values: HashMap = HashMap::new(); values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = (&values).into(); for i in 0..max_hits { assert!( !rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap(), "Must not be limited after {i}" ); rate_limiter - .update_counters(namespace, &values, 1) + .update_counters(namespace, &ctx, 1) .await .unwrap(); } assert!(rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap()); } @@ -543,22 +544,23 @@ mod test { let mut values: HashMap = HashMap::new(); values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = (&values).into(); for i in 0..max_hits { assert!( !rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap(), "Must not be limited after {i}" ); rate_limiter - .update_counters(namespace, &values, 1) + .update_counters(namespace, &ctx, 1) .await .unwrap(); } assert!(rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap()); } @@ -590,32 +592,34 @@ mod test { let mut get_values: HashMap = HashMap::new(); get_values.insert("req_method".to_string(), "GET".to_string()); get_values.insert("app_id".to_string(), "test_app_id".to_string()); + let get_ctx = (&get_values).into(); let mut post_values: HashMap = HashMap::new(); post_values.insert("req_method".to_string(), "POST".to_string()); post_values.insert("app_id".to_string(), "test_app_id".to_string()); + let post_ctx = (&post_values).into(); for i in 0..max_hits { assert!( !rate_limiter - .is_rate_limited(namespace, &get_values, 1) + .is_rate_limited(namespace, &get_ctx, 1) .await .unwrap(), "Must not be limited after {i}" ); assert!( !rate_limiter - .is_rate_limited(namespace, &post_values, 1) + .is_rate_limited(namespace, &post_ctx, 1) .await .unwrap(), "Must not be limited after {i}" ); rate_limiter - .check_rate_limited_and_update(namespace, &get_values, 1, false) + .check_rate_limited_and_update(namespace, &get_ctx, 1, false) .await .unwrap(); rate_limiter - .check_rate_limited_and_update(namespace, &post_values, 1, false) + .check_rate_limited_and_update(namespace, &post_ctx, 1, false) .await .unwrap(); } @@ -624,11 +628,11 @@ mod test { tokio::time::sleep(Duration::from_millis(40)).await; assert!(rate_limiter - .is_rate_limited(namespace, &get_values, 1) + .is_rate_limited(namespace, &get_ctx, 1) .await .unwrap()); assert!(!rate_limiter - .is_rate_limited(namespace, &post_values, 1) + .is_rate_limited(namespace, &post_ctx, 1) .await .unwrap()); } @@ -648,21 +652,22 @@ mod test { let mut values: HashMap = HashMap::new(); values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = (&values).into(); // Report 5 hits twice. The limit is 10, so the first limited call should be // the third one. for _ in 0..2 { assert!(!rate_limiter - .is_rate_limited(namespace, &values, 5) + .is_rate_limited(namespace, &ctx, 5) .await .unwrap()); rate_limiter - .update_counters(namespace, &values, 5) + .update_counters(namespace, &ctx, 5) .await .unwrap(); } assert!(rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap()); } @@ -683,9 +688,10 @@ mod test { let mut values: HashMap = HashMap::new(); values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = (&values).into(); assert!(rate_limiter - .is_rate_limited(namespace, &values, max + 1) + .is_rate_limited(namespace, &ctx, max + 1) .await .unwrap()) } @@ -706,6 +712,7 @@ mod test { let mut values: HashMap = HashMap::new(); values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = (&values).into(); for i in 0..max_hits { // Add an extra value that does not apply to the limit on each @@ -714,18 +721,18 @@ mod test { assert!( !rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap(), "Must not be limited after {i}" ); rate_limiter - .update_counters(namespace, &values, 1) + .update_counters(namespace, &ctx, 1) .await .unwrap(); } assert!(rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap()); } @@ -735,9 +742,10 @@ mod test { ) { let mut values: HashMap = HashMap::new(); values.insert("req_method".to_string(), "GET".to_string()); + let ctx = (&values).into(); assert!(!rate_limiter - .is_rate_limited("test_namespace", &values, 1) + .is_rate_limited("test_namespace", &ctx, 1) .await .unwrap()); } @@ -761,9 +769,10 @@ mod test { let mut values: HashMap = HashMap::new(); values.insert("req_method".to_string(), "POST".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = (&values).into(); assert!(!rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap()); } @@ -783,9 +792,10 @@ mod test { let mut values: HashMap = HashMap::new(); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = (&values).into(); assert!(rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap()); } @@ -807,11 +817,12 @@ mod test { let mut values: HashMap = HashMap::new(); values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = (&values).into(); for _ in 0..max_hits { assert!( !rate_limiter - .check_rate_limited_and_update(namespace, &values, 1, false) + .check_rate_limited_and_update(namespace, &ctx, 1, false) .await .unwrap() .limited @@ -820,7 +831,7 @@ mod test { assert!( rate_limiter - .check_rate_limited_and_update(namespace, &values, 1, false) + .check_rate_limited_and_update(namespace, &ctx, 1, false) .await .unwrap() .limited @@ -844,10 +855,11 @@ mod test { let mut values: HashMap = HashMap::new(); values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = (&values).into(); for hit in 0..max_hits { let result = rate_limiter - .check_rate_limited_and_update(namespace, &values, 1, true) + .check_rate_limited_and_update(namespace, &ctx, 1, true) .await .unwrap(); assert!(!result.limited); @@ -862,7 +874,7 @@ mod test { } let result = rate_limiter - .check_rate_limited_and_update(namespace, &values, 1, true) + .check_rate_limited_and_update(namespace, &ctx, 1, true) .await .unwrap(); assert!(result.limited); @@ -895,10 +907,11 @@ mod test { values.insert("app_id".to_string(), "test_app_id".to_string()); // Does not match the limit defined values.insert("req_method".to_string(), "POST".to_string()); + let ctx = (&values).into(); assert!( !rate_limiter - .check_rate_limited_and_update(namespace, &values, 1, false) + .check_rate_limited_and_update(namespace, &ctx, 1, false) .await .unwrap() .limited @@ -922,10 +935,11 @@ mod test { let mut values: HashMap = HashMap::new(); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = (&values).into(); assert!( rate_limiter - .check_rate_limited_and_update(namespace, &values, 1, false) + .check_rate_limited_and_update(namespace, &ctx, 1, false) .await .unwrap() .limited @@ -951,14 +965,15 @@ mod test { let mut values = HashMap::new(); values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "1".to_string()); + let ctx = (&values).into(); rate_limiter - .update_counters(namespace, &values, hits_app_1) + .update_counters(namespace, &ctx, hits_app_1) .await .unwrap(); values.insert("app_id".to_string(), "2".to_string()); rate_limiter - .update_counters(namespace, &values, hits_app_2) + .update_counters(namespace, &ctx, hits_app_2) .await .unwrap(); @@ -1029,8 +1044,9 @@ mod test { let mut values = HashMap::new(); values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "1".to_string()); + let ctx = (&values).into(); rate_limiter - .update_counters(namespace, &values, 1) + .update_counters(namespace, &ctx, 1) .await .unwrap(); @@ -1093,8 +1109,9 @@ mod test { let mut values = HashMap::new(); values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "1".to_string()); + let ctx = (&values).into(); rate_limiter - .update_counters(namespace, &values, hits_to_report) + .update_counters(namespace, &ctx, hits_to_report) .await .unwrap(); @@ -1245,19 +1262,20 @@ mod test { let mut values: HashMap = HashMap::new(); values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = (&values).into(); for i in 0..max_hits { // Alternate between the two rate limiters let rate_limiter = rate_limiters.get((i % 2) as usize).unwrap(); assert!( !rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap(), "Must not be limited after {i}" ); rate_limiter - .update_counters(namespace, &values, 1) + .update_counters(namespace, &ctx, 1) .await .unwrap(); } @@ -1269,7 +1287,7 @@ mod test { || async { let rate_limiter = rate_limiters.first().unwrap(); rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap() }