Skip to content

Commit

Permalink
Expose Context all the way up
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Snaps <[email protected]>
alexsnaps committed Dec 3, 2024
1 parent 345be88 commit e0ae4f6
Showing 14 changed files with 192 additions and 196 deletions.
6 changes: 4 additions & 2 deletions limitador-server/src/envoy_rls/server.rs
Original file line number Diff line number Diff line change
@@ -109,18 +109,20 @@ impl RateLimitService for MyRateLimiter {
req.hits_addend
};

let ctx = (&values).into();

let rate_limited_resp = match &*self.limiter {
Limiter::Blocking(limiter) => limiter.check_rate_limited_and_update(
&namespace,
&values,
&ctx,
u64::from(hits_addend),
self.rate_limit_headers != RateLimitHeaders::None,
),
Limiter::Async(limiter) => {
limiter
.check_rate_limited_and_update(
&namespace,
&values,
&ctx,
u64::from(hits_addend),
self.rate_limit_headers != RateLimitHeaders::None,
)
20 changes: 9 additions & 11 deletions limitador-server/src/http_api/server.rs
Original file line number Diff line number Diff line change
@@ -122,9 +122,10 @@ async fn check(
response_headers: _,
} = request.into_inner();
let namespace = namespace.into();
let ctx = (&values).into();
let is_rate_limited_result = match state.get_ref().limiter() {
Limiter::Blocking(limiter) => limiter.is_rate_limited(&namespace, &values, delta),
Limiter::Async(limiter) => limiter.is_rate_limited(&namespace, &values, delta).await,
Limiter::Blocking(limiter) => limiter.is_rate_limited(&namespace, &ctx, delta),
Limiter::Async(limiter) => limiter.is_rate_limited(&namespace, &ctx, delta).await,
};

match is_rate_limited_result {
@@ -152,9 +153,10 @@ async fn report(
response_headers: _,
} = request.into_inner();
let namespace = namespace.into();
let ctx = (&values).into();
let update_counters_result = match data.get_ref().limiter() {
Limiter::Blocking(limiter) => limiter.update_counters(&namespace, &values, delta),
Limiter::Async(limiter) => limiter.update_counters(&namespace, &values, delta).await,
Limiter::Blocking(limiter) => limiter.update_counters(&namespace, &ctx, delta),
Limiter::Async(limiter) => limiter.update_counters(&namespace, &ctx, delta).await,
};

match update_counters_result {
@@ -176,22 +178,18 @@ async fn check_and_report(
response_headers,
} = request.into_inner();
let namespace = namespace.into();
let ctx = (&values).into();
let rate_limit_data = data.get_ref();
let rate_limited_and_update_result = match rate_limit_data.limiter() {
Limiter::Blocking(limiter) => limiter.check_rate_limited_and_update(
&namespace,
&values,
&ctx,
delta,
response_headers.is_some(),
),
Limiter::Async(limiter) => {
limiter
.check_rate_limited_and_update(
&namespace,
&values,
delta,
response_headers.is_some(),
)
.check_rate_limited_and_update(&namespace, &ctx, delta, response_headers.is_some())
.await
}
};
12 changes: 6 additions & 6 deletions limitador/benches/bench.rs
Original file line number Diff line number Diff line change
@@ -329,7 +329,7 @@ fn bench_is_rate_limited(
rate_limiter
.is_rate_limited(
&params.namespace.to_owned().into(),
&params.values,
&(&params.values).into(),
params.delta,
)
.unwrap(),
@@ -357,7 +357,7 @@ fn async_bench_is_rate_limited<F>(
rate_limiter
.is_rate_limited(
&params.namespace.to_owned().into(),
&params.values,
&(&params.values).into(),
params.delta,
)
.await
@@ -383,7 +383,7 @@ fn bench_update_counters(
rate_limiter
.update_counters(
&params.namespace.to_owned().into(),
&params.values,
&(&params.values).into(),
params.delta,
)
.unwrap();
@@ -410,7 +410,7 @@ fn async_bench_update_counters<F>(
rate_limiter
.update_counters(
&params.namespace.to_owned().into(),
&params.values,
&(&params.values).into(),
params.delta,
)
.await
@@ -437,7 +437,7 @@ fn bench_check_rate_limited_and_update(
rate_limiter
.check_rate_limited_and_update(
&params.namespace.to_owned().into(),
&params.values,
&(&params.values).into(),
params.delta,
false,
)
@@ -467,7 +467,7 @@ fn async_bench_check_rate_limited_and_update<F>(
rate_limiter
.check_rate_limited_and_update(
&params.namespace.to_owned().into(),
&params.values,
&(&params.values).into(),
params.delta,
false,
)
20 changes: 6 additions & 14 deletions limitador/src/counter.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::limit::{Limit, Namespace};
use crate::limit::{Context, Limit, Namespace};
use crate::LimitadorResult;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
@@ -17,15 +17,9 @@ pub struct Counter {
}

impl Counter {
pub fn new<L: Into<Arc<Limit>>>(
limit: L,
set_variables: HashMap<String, String>,
) -> LimitadorResult<Option<Self>> {
pub fn new<L: Into<Arc<Limit>>>(limit: L, ctx: &Context) -> LimitadorResult<Option<Self>> {
let limit = limit.into();
let mut vars = set_variables;
vars.retain(|var, _| limit.has_variable(var));

let variables = limit.resolve_variables(vars)?;
let variables = limit.resolve_variables(ctx)?;
match variables {
None => Ok(None),
Some(variables) => Ok(Some(Self {
@@ -159,11 +153,9 @@ mod tests {
Vec::default(),
[var.try_into().expect("failed parsing!")],
);
let counter = Counter::new(
limit,
HashMap::from([("ts".to_string(), "2019-10-12T13:20:50.52Z".to_string())]),
)
.expect("failed creating counter");
let map = HashMap::from([("ts".to_string(), "2019-10-12T13:20:50.52Z".to_string())]);
let ctx = (&map).into();
let counter = Counter::new(limit, &ctx).expect("failed creating counter");
assert_eq!(
counter.unwrap().set_variables.get(var),
Some("13".to_string()).as_ref()
60 changes: 29 additions & 31 deletions limitador/src/lib.rs
Original file line number Diff line number Diff line change
@@ -94,7 +94,7 @@
//!
//! ```
//! use limitador::RateLimiter;
//! use limitador::limit::Limit;
//! use limitador::limit::{Limit, Context};
//! use std::collections::HashMap;
//!
//! let mut rate_limiter = RateLimiter::new(1000);
@@ -116,22 +116,23 @@
//!
//! // Check if we can report
//! let namespace = "my_namespace".into();
//! assert!(!rate_limiter.is_rate_limited(&namespace, &values_to_report, 1).unwrap());
//! let ctx = &values_to_report.into();
//! assert!(!rate_limiter.is_rate_limited(&namespace, &ctx, 1).unwrap());
//!
//! // Report
//! rate_limiter.update_counters(&namespace, &values_to_report, 1).unwrap();
//! rate_limiter.update_counters(&namespace, &ctx, 1).unwrap();
//!
//! // Check and report again
//! assert!(!rate_limiter.is_rate_limited(&namespace, &values_to_report, 1).unwrap());
//! rate_limiter.update_counters(&namespace, &values_to_report, 1).unwrap();
//! assert!(!rate_limiter.is_rate_limited(&namespace, &ctx, 1).unwrap());
//! rate_limiter.update_counters(&namespace, &ctx, 1).unwrap();
//!
//! // We've already reported 2, so reporting another one should not be allowed
//! assert!(rate_limiter.is_rate_limited(&namespace, &values_to_report, 1).unwrap());
//! assert!(rate_limiter.is_rate_limited(&namespace, &ctx, 1).unwrap());
//!
//! // You can also check and report if not limited in a single call. It's useful
//! // for example, when calling Limitador from a proxy. Instead of doing 2
//! // separate calls, we can issue just one:
//! rate_limiter.check_rate_limited_and_update(&namespace, &values_to_report, 1, false).unwrap();
//! rate_limiter.check_rate_limited_and_update(&namespace, &ctx, 1, false).unwrap();
//! ```
//!
//! # Async
@@ -194,7 +195,7 @@

use crate::counter::Counter;
use crate::errors::LimitadorError;
use crate::limit::{Limit, Namespace};
use crate::limit::{Context, Limit, Namespace};
use crate::storage::in_memory::InMemoryStorage;
use crate::storage::{AsyncCounterStorage, AsyncStorage, Authorization, CounterStorage, Storage};
use std::collections::{HashMap, HashSet};
@@ -358,7 +359,7 @@ impl RateLimiter {
pub fn is_rate_limited(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
values: &Context,
delta: u64,
) -> LimitadorResult<bool> {
let counters = self.counters_that_apply(namespace, values)?;
@@ -380,10 +381,10 @@ impl RateLimiter {
pub fn update_counters(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
ctx: &Context,
delta: u64,
) -> LimitadorResult<()> {
let counters = self.counters_that_apply(namespace, values)?;
let counters = self.counters_that_apply(namespace, ctx)?;

counters
.iter()
@@ -394,11 +395,11 @@ impl RateLimiter {
pub fn check_rate_limited_and_update(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
ctx: &Context,
delta: u64,
load_counters: bool,
) -> LimitadorResult<CheckResult> {
let mut counters = self.counters_that_apply(namespace, values)?;
let mut counters = self.counters_that_apply(namespace, ctx)?;

if counters.is_empty() {
return Ok(CheckResult {
@@ -476,14 +477,13 @@ impl RateLimiter {
fn counters_that_apply(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
ctx: &Context,
) -> LimitadorResult<Vec<Counter>> {
let limits = self.storage.get_limits(namespace);
let ctx = values.into();
limits
.iter()
.filter(|lim| lim.applies(&ctx))
.filter_map(|lim| match Counter::new(Arc::clone(lim), values.clone()) {
.filter(|lim| lim.applies(ctx))
.filter_map(|lim| match Counter::new(Arc::clone(lim), ctx) {
Ok(None) => None,
Ok(Some(c)) => Some(Ok(c)),
Err(e) => Some(Err(e)),
@@ -533,10 +533,10 @@ impl AsyncRateLimiter {
pub async fn is_rate_limited(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
ctx: &Context<'_>,
delta: u64,
) -> LimitadorResult<bool> {
let counters = self.counters_that_apply(namespace, values).await?;
let counters = self.counters_that_apply(namespace, ctx).await?;

for counter in counters {
match self.storage.is_within_limits(&counter, delta).await {
@@ -554,10 +554,10 @@ impl AsyncRateLimiter {
pub async fn update_counters(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
ctx: &Context<'_>,
delta: u64,
) -> LimitadorResult<()> {
let counters = self.counters_that_apply(namespace, values).await?;
let counters = self.counters_that_apply(namespace, ctx).await?;

for counter in counters {
self.storage.update_counter(&counter, delta).await?
@@ -569,12 +569,12 @@ impl AsyncRateLimiter {
pub async fn check_rate_limited_and_update(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
ctx: &Context<'_>,
delta: u64,
load_counters: bool,
) -> LimitadorResult<CheckResult> {
// the above where-clause is needed in order to call unwrap().
let mut counters = self.counters_that_apply(namespace, values).await?;
let mut counters = self.counters_that_apply(namespace, ctx).await?;

if counters.is_empty() {
return Ok(CheckResult {
@@ -657,14 +657,13 @@ impl AsyncRateLimiter {
async fn counters_that_apply(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
ctx: &Context<'_>,
) -> LimitadorResult<Vec<Counter>> {
let limits = self.storage.get_limits(namespace);
let ctx = values.into();
limits
.iter()
.filter(|lim| lim.applies(&ctx))
.filter_map(|lim| match Counter::new(Arc::clone(lim), values.clone()) {
.filter(|lim| lim.applies(ctx))
.filter_map(|lim| match Counter::new(Arc::clone(lim), ctx) {
Ok(None) => None,
Ok(Some(c)) => Some(Ok(c)),
Err(e) => Some(Err(e)),
@@ -696,9 +695,8 @@ fn classify_limits_by_namespace(

#[cfg(test)]
mod test {
use crate::limit::{Expression, Limit};
use crate::limit::{Context, Expression, Limit};
use crate::RateLimiter;
use std::collections::HashMap;

#[test]
fn properly_updates_existing_limits() {
@@ -713,7 +711,7 @@ mod test {
assert_eq!(limits.iter().next().unwrap().max_value(), 42);

let r = rl
.check_rate_limited_and_update(&namespace.into(), &HashMap::default(), 1, true)
.check_rate_limited_and_update(&namespace.into(), &Context::default(), 1, true)
.unwrap();
assert_eq!(r.counters.first().unwrap().max_value(), 42);

@@ -727,7 +725,7 @@ mod test {
assert_eq!(limits.iter().next().unwrap().max_value(), 50);

let r = rl
.check_rate_limited_and_update(&namespace.into(), &HashMap::default(), 1, true)
.check_rate_limited_and_update(&namespace.into(), &Context::default(), 1, true)
.unwrap();
assert_eq!(r.counters.first().unwrap().max_value(), 50);
}
13 changes: 6 additions & 7 deletions limitador/src/limit.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use cel::Context;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::collections::{BTreeMap, BTreeSet, HashSet};
use std::fmt::Debug;
use std::hash::{Hash, Hasher};

mod cel;

pub use cel::{Context, Expression, Predicate};
pub use cel::{EvaluationError, ParseError};
pub use cel::{Expression, Predicate};

#[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Namespace(String);
@@ -133,13 +132,12 @@ impl Limit {

pub fn resolve_variables(
&self,
vars: HashMap<String, String>,
ctx: &Context,
) -> Result<Option<BTreeMap<String, String>>, EvaluationError> {
let ctx = Context::new(String::default(), &vars);
let mut map = BTreeMap::new();
for variable in &self.variables {
let name = variable.source().into();
match variable.eval(&ctx)? {
match variable.eval(ctx)? {
None => return Ok(None),
Some(value) => {
map.insert(name, value);
@@ -230,6 +228,7 @@ mod tests {
use super::*;
use crate::counter::Counter;
use std::cmp::Ordering::Equal;
use std::collections::HashMap;

#[test]
fn limit_can_have_an_optional_name() {
@@ -466,7 +465,7 @@ mod tests {
let ctx = Context::new(String::default(), &map);
assert!(limit.applies(&ctx));
assert_eq!(
Counter::new(limit, map)
Counter::new(limit, &ctx)
.expect("failed")
.unwrap()
.set_variables()
11 changes: 5 additions & 6 deletions limitador/src/storage/disk/rocksdb_storage.rs
Original file line number Diff line number Diff line change
@@ -249,12 +249,11 @@ mod tests {
vec!["req_method == 'GET'".try_into().expect("failed parsing!")],
vec!["app_id".try_into().expect("failed parsing!")],
);
let counter = Counter::new(
limit,
HashMap::from([("app_id".to_string(), "foo".to_string())]),
)
.unwrap()
.expect("must have a counter");
let map = HashMap::from([("app_id".to_string(), "foo".to_string())]);
let ctx = (&map).into();
let counter = Counter::new(limit, &ctx)
.unwrap()
.expect("must have a counter");

let tmp = TempDir::new().expect("We should have a dir!");
{
23 changes: 11 additions & 12 deletions limitador/src/storage/distributed/mod.rs
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@ use tokio::sync::mpsc::Sender;
use tracing::debug;

use crate::counter::Counter;
use crate::limit::Limit;
use crate::limit::{Context, Limit};
use crate::storage::distributed::cr_counter_value::CrCounterValue;
use crate::storage::distributed::grpc::v1::CounterUpdate;
use crate::storage::distributed::grpc::{Broker, CounterEntry};
@@ -47,7 +47,7 @@ impl CounterStorage for CrInMemoryStorage {
let key = encode_limit_to_key(limit);
limits.entry(key.clone()).or_insert(Arc::new(CounterEntry {
key,
counter: Counter::new(limit.clone(), HashMap::default())
counter: Counter::new(limit.clone(), &Context::default())
.expect("counter creation can't fail! no vars to resolve!")
.expect("must have a counter"),
value: CrCounterValue::new(
@@ -336,15 +336,14 @@ fn encode_counter_to_key(counter: &Counter) -> Vec<u8> {

fn encode_limit_to_key(limit: &Limit) -> Vec<u8> {
// fixme this is broken!
let counter = Counter::new(
limit.clone(),
limit
.variables()
.into_iter()
.map(|k| (k, "".to_string()))
.collect(),
)
.expect("counter creation can't fail! faked vars!")
.expect("must have a counter");
let vars: HashMap<String, String> = limit
.variables()
.into_iter()
.map(|k| (k, "".to_string()))
.collect();
let ctx = (&vars).into();
let counter = Counter::new(limit.clone(), &ctx)
.expect("counter creation can't fail! faked vars!")
.expect("must have a counter");
key_for_counter_v2(&counter)
}
24 changes: 10 additions & 14 deletions limitador/src/storage/in_memory.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::counter::Counter;
use crate::limit::{Limit, Namespace};
use crate::limit::{Context, Limit, Namespace};
use crate::storage::atomic_expiring_value::AtomicExpiringValue;
use crate::storage::{Authorization, CounterStorage, StorageErr};
use moka::sync::Cache;
@@ -212,7 +212,7 @@ impl InMemoryStorage {
if limit.namespace() == namespace {
res.insert(
// todo fixme
Counter::new(limit.clone(), HashMap::default())
Counter::new(limit.clone(), &Context::default())
.unwrap()
.unwrap(),
counter.clone(),
@@ -269,18 +269,14 @@ mod tests {
vec!["req_method == 'GET'".try_into().expect("failed parsing!")],
vec!["app_id".try_into().expect("failed parsing!")],
);
let counter_1 = Counter::new(
limit_1,
HashMap::from([("app_id".to_string(), "foo".to_string())]),
)
.expect("counter creation failed!")
.expect("Should have a counter");
let counter_2 = Counter::new(
limit_2,
HashMap::from([("app_id".to_string(), "foo".to_string())]),
)
.expect("counter creation failed!")
.expect("Should have a counter");
let map = HashMap::from([("app_id".to_string(), "foo".to_string())]);
let ctx = (&map).into();
let counter_1 = Counter::new(limit_1, &ctx)
.expect("counter creation failed!")
.expect("Should have a counter");
let counter_2 = Counter::new(limit_2, &ctx)
.expect("counter creation failed!")
.expect("Should have a counter");
storage.update_counter(&counter_1, 1).unwrap();
storage.update_counter(&counter_2, 1).unwrap();

59 changes: 27 additions & 32 deletions limitador/src/storage/keys.rs
Original file line number Diff line number Diff line change
@@ -153,12 +153,11 @@ mod tests {
vec!["req_method == 'GET'".try_into().expect("failed parsing!")],
vec!["app_id".try_into().expect("failed parsing!")],
);
let counter = Counter::new(
limit.clone(),
HashMap::from([("app_id".to_string(), "foo".to_string())]),
)
.expect("counter creation failed!")
.expect("must have a counter");
let map = HashMap::from([("app_id".to_string(), "foo".to_string())]);
let ctx = (&map).into();
let counter = Counter::new(limit.clone(), &ctx)
.expect("counter creation failed!")
.expect("must have a counter");
let raw = key_for_counter(&counter);
assert_eq!(counter, partial_counter_from_counter_key(&raw));
}
@@ -173,12 +172,11 @@ mod tests {
vec!["req_method == 'GET'".try_into().expect("failed parsing!")],
vec!["app_id".try_into().expect("failed parsing!")],
);
let counter = Counter::new(
limit.clone(),
HashMap::from([("app_id".to_string(), "foo".to_string())]),
)
.expect("counter creation failed!")
.expect("must have a counter");
let map = HashMap::from([("app_id".to_string(), "foo".to_string())]);
let ctx = (&map).into();
let counter = Counter::new(limit.clone(), &ctx)
.expect("counter creation failed!")
.expect("must have a counter");
let mut other = counter.clone();
other.set_remaining(123);
other.set_expires_in(Duration::from_millis(456));
@@ -389,7 +387,8 @@ pub mod bin {
vars.insert("role".to_string(), "admin".to_string());
vars.insert("app_id".to_string(), "123".to_string());
vars.insert("wat".to_string(), "dunno".to_string());
let counter = Counter::new(limit.clone(), vars)
let ctx = (&vars).into();
let counter = Counter::new(limit.clone(), &ctx)
.expect("counter creation failed!")
.expect("must have a counter");

@@ -412,7 +411,8 @@ pub mod bin {
);
let mut variables = HashMap::default();
variables.insert("app_id".to_string(), "123".to_string());
let counter = Counter::new(limit.clone(), variables)
let ctx = (&variables).into();
let counter = Counter::new(limit.clone(), &ctx)
.expect("counter creation failed!")
.expect("must have a counter");
let raw = key_for_counter(&counter);
@@ -429,12 +429,11 @@ pub mod bin {
vec!["req_method == 'GET'".try_into().expect("failed parsing!")],
vec!["app_id".try_into().expect("failed parsing!")],
);
let counter = Counter::new(
limit,
HashMap::from([("app_id".to_string(), "foo".to_string())]),
)
.expect("counter creation failed!")
.expect("must have a counter");
let map = HashMap::from([("app_id".to_string(), "foo".to_string())]);
let ctx = (&map).into();
let counter = Counter::new(limit, &ctx)
.expect("counter creation failed!")
.expect("must have a counter");
let serialized_counter = key_for_counter(&counter);

let prefix = prefix_for_namespace(namespace);
@@ -460,20 +459,16 @@ pub mod bin {
vec!["app_id".try_into().expect("failed parsing!")],
);

let counter_with_id = Counter::new(
limit_with_id,
HashMap::from([("app_id".to_string(), "foo".to_string())]),
)
.expect("counter creation failed!")
.expect("must have a counter");
let map = HashMap::from([("app_id".to_string(), "foo".to_string())]);
let ctx = (&map).into();
let counter_with_id = Counter::new(limit_with_id, &ctx)
.expect("counter creation failed!")
.expect("must have a counter");
let serialized_with_id_counter = key_for_counter(&counter_with_id);

let counter_without_id = Counter::new(
limit_without_id,
HashMap::from([("app_id".to_string(), "foo".to_string())]),
)
.expect("counter creation failed!")
.expect("must have a counter");
let counter_without_id = Counter::new(limit_without_id, &ctx)
.expect("counter creation failed!")
.expect("must have a counter");
let serialized_without_id_counter = key_for_counter(&counter_without_id);

// the original key_for_counter continues to encode kinda big
3 changes: 2 additions & 1 deletion limitador/src/storage/redis/counters_cache.rs
Original file line number Diff line number Diff line change
@@ -676,6 +676,7 @@ mod tests {
if let Some(overrides) = other_values {
values.extend(overrides);
}
let ctx = (&values).into();
Counter::new(
Limit::new(
"test_namespace",
@@ -684,7 +685,7 @@ mod tests {
vec!["req_method == 'POST'".try_into().expect("failed parsing!")],
vec!["app_id".try_into().expect("failed parsing!")],
),
values,
&ctx,
)
.expect("failed creating counter")
.expect("Should have a counter")
12 changes: 9 additions & 3 deletions limitador/src/storage/redis/redis_cached.rs
Original file line number Diff line number Diff line change
@@ -439,6 +439,8 @@ mod tests {
const LOCAL_INCREMENTS: u64 = 2;

let mut counters_and_deltas = HashMap::new();
let map = HashMap::from([("app_id".to_string(), "foo".to_string())]);
let ctx = (&map).into();
let counter = Counter::new(
Limit::new(
"test_namespace",
@@ -447,7 +449,7 @@ mod tests {
vec!["req_method == 'GET'".try_into().expect("failed parsing!")],
vec!["app_id".try_into().expect("failed parsing!")],
),
HashMap::from([("app_id".to_string(), "foo".to_string())]),
&ctx,
)
.expect("counter creation failed!")
.expect("must have a counter");
@@ -503,6 +505,8 @@ mod tests {

#[tokio::test]
async fn flush_batcher_and_update_counters_test() {
let map = HashMap::from([("app_id".to_string(), "foo".to_string())]);
let ctx = (&map).into();
let counter = Counter::new(
Limit::new(
"test_namespace",
@@ -511,7 +515,7 @@ mod tests {
vec!["req_method == 'POST'".try_into().expect("failed parsing!")],
vec!["app_id".try_into().expect("failed parsing!")],
),
HashMap::from([("app_id".to_string(), "foo".to_string())]),
&ctx,
)
.expect("counter creation failed!")
.expect("must have a counter");
@@ -564,6 +568,8 @@ mod tests {

#[tokio::test]
async fn flush_batcher_reverts_on_err() {
let map = HashMap::from([("app_id".to_string(), "foo".to_string())]);
let ctx = (&map).into();
let counter = Counter::new(
Limit::new(
"test_namespace",
@@ -572,7 +578,7 @@ mod tests {
vec!["req_method == 'POST'".try_into().expect("failed parsing!")],
vec!["app_id".try_into().expect("failed parsing!")],
),
HashMap::from([("app_id".to_string(), "foo".to_string())]),
&ctx,
)
.expect("counter creation failed!")
.expect("must have a counter");
33 changes: 13 additions & 20 deletions limitador/tests/helpers/tests_limiter.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use limitador::counter::Counter;
use limitador::errors::LimitadorError;
use limitador::limit::{Limit, Namespace};
use limitador::limit::{Context, Limit, Namespace};
use limitador::{AsyncRateLimiter, CheckResult, RateLimiter};
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;

// This exposes a struct that wraps both implementations of the rate limiter,
// the blocking and the async one. This allows us to avoid duplications in the
@@ -70,56 +70,49 @@ impl TestsLimiter {
pub async fn is_rate_limited(
&self,
namespace: &str,
values: &HashMap<String, String>,
ctx: &Context<'_>,
delta: u64,
) -> Result<bool, LimitadorError> {
match &self.limiter_impl {
LimiterImpl::Blocking(limiter) => {
limiter.is_rate_limited(&namespace.into(), values, delta)
limiter.is_rate_limited(&namespace.into(), ctx, delta)
}
LimiterImpl::Async(limiter) => {
limiter
.is_rate_limited(&namespace.into(), values, delta)
.await
limiter.is_rate_limited(&namespace.into(), ctx, delta).await
}
}
}

pub async fn update_counters(
&self,
namespace: &str,
values: &HashMap<String, String>,
ctx: &Context<'_>,
delta: u64,
) -> Result<(), LimitadorError> {
match &self.limiter_impl {
LimiterImpl::Blocking(limiter) => {
limiter.update_counters(&namespace.into(), values, delta)
limiter.update_counters(&namespace.into(), ctx, delta)
}
LimiterImpl::Async(limiter) => {
limiter
.update_counters(&namespace.into(), values, delta)
.await
limiter.update_counters(&namespace.into(), ctx, delta).await
}
}
}

pub async fn check_rate_limited_and_update(
&self,
namespace: &str,
values: &HashMap<String, String>,
ctx: &Context<'_>,
delta: u64,
load_counters: bool,
) -> Result<CheckResult, LimitadorError> {
match &self.limiter_impl {
LimiterImpl::Blocking(limiter) => limiter.check_rate_limited_and_update(
&namespace.into(),
values,
delta,
load_counters,
),
LimiterImpl::Blocking(limiter) => {
limiter.check_rate_limited_and_update(&namespace.into(), ctx, delta, load_counters)
}
LimiterImpl::Async(limiter) => {
limiter
.check_rate_limited_and_update(&namespace.into(), values, delta, load_counters)
.check_rate_limited_and_update(&namespace.into(), ctx, delta, load_counters)
.await
}
}
92 changes: 55 additions & 37 deletions limitador/tests/integration_tests.rs
Original file line number Diff line number Diff line change
@@ -380,7 +380,7 @@ mod test {
values.insert("req_method".to_string(), "GET".to_string());
values.insert("app_id".to_string(), "1".to_string());
rate_limiter
.update_counters(namespace, &values, 1)
.update_counters(namespace, &(&values).into(), 1)
.await
.unwrap();

@@ -473,7 +473,7 @@ mod test {
values.insert("req_method".to_string(), "GET".to_string());
values.insert("app_id".to_string(), "1".to_string());
rate_limiter
.update_counters(namespace, &values, 1)
.update_counters(namespace, &(&values).into(), 1)
.await
.unwrap();

@@ -506,22 +506,23 @@ mod test {
let mut values: HashMap<String, String> = HashMap::new();
values.insert("req_method".to_string(), "GET".to_string());
values.insert("app_id".to_string(), "test_app_id".to_string());
let ctx = (&values).into();

for i in 0..max_hits {
assert!(
!rate_limiter
.is_rate_limited(namespace, &values, 1)
.is_rate_limited(namespace, &ctx, 1)
.await
.unwrap(),
"Must not be limited after {i}"
);
rate_limiter
.update_counters(namespace, &values, 1)
.update_counters(namespace, &ctx, 1)
.await
.unwrap();
}
assert!(rate_limiter
.is_rate_limited(namespace, &values, 1)
.is_rate_limited(namespace, &ctx, 1)
.await
.unwrap());
}
@@ -543,22 +544,23 @@ mod test {
let mut values: HashMap<String, String> = HashMap::new();
values.insert("req_method".to_string(), "GET".to_string());
values.insert("app_id".to_string(), "test_app_id".to_string());
let ctx = (&values).into();

for i in 0..max_hits {
assert!(
!rate_limiter
.is_rate_limited(namespace, &values, 1)
.is_rate_limited(namespace, &ctx, 1)
.await
.unwrap(),
"Must not be limited after {i}"
);
rate_limiter
.update_counters(namespace, &values, 1)
.update_counters(namespace, &ctx, 1)
.await
.unwrap();
}
assert!(rate_limiter
.is_rate_limited(namespace, &values, 1)
.is_rate_limited(namespace, &ctx, 1)
.await
.unwrap());
}
@@ -590,32 +592,34 @@ mod test {
let mut get_values: HashMap<String, String> = HashMap::new();
get_values.insert("req_method".to_string(), "GET".to_string());
get_values.insert("app_id".to_string(), "test_app_id".to_string());
let get_ctx = (&get_values).into();

let mut post_values: HashMap<String, String> = HashMap::new();
post_values.insert("req_method".to_string(), "POST".to_string());
post_values.insert("app_id".to_string(), "test_app_id".to_string());
let post_ctx = (&post_values).into();

for i in 0..max_hits {
assert!(
!rate_limiter
.is_rate_limited(namespace, &get_values, 1)
.is_rate_limited(namespace, &get_ctx, 1)
.await
.unwrap(),
"Must not be limited after {i}"
);
assert!(
!rate_limiter
.is_rate_limited(namespace, &post_values, 1)
.is_rate_limited(namespace, &post_ctx, 1)
.await
.unwrap(),
"Must not be limited after {i}"
);
rate_limiter
.check_rate_limited_and_update(namespace, &get_values, 1, false)
.check_rate_limited_and_update(namespace, &get_ctx, 1, false)
.await
.unwrap();
rate_limiter
.check_rate_limited_and_update(namespace, &post_values, 1, false)
.check_rate_limited_and_update(namespace, &post_ctx, 1, false)
.await
.unwrap();
}
@@ -624,11 +628,11 @@ mod test {
tokio::time::sleep(Duration::from_millis(40)).await;

assert!(rate_limiter
.is_rate_limited(namespace, &get_values, 1)
.is_rate_limited(namespace, &get_ctx, 1)
.await
.unwrap());
assert!(!rate_limiter
.is_rate_limited(namespace, &post_values, 1)
.is_rate_limited(namespace, &post_ctx, 1)
.await
.unwrap());
}
@@ -648,21 +652,22 @@ mod test {
let mut values: HashMap<String, String> = HashMap::new();
values.insert("req_method".to_string(), "GET".to_string());
values.insert("app_id".to_string(), "test_app_id".to_string());
let ctx = (&values).into();

// Report 5 hits twice. The limit is 10, so the first limited call should be
// the third one.
for _ in 0..2 {
assert!(!rate_limiter
.is_rate_limited(namespace, &values, 5)
.is_rate_limited(namespace, &ctx, 5)
.await
.unwrap());
rate_limiter
.update_counters(namespace, &values, 5)
.update_counters(namespace, &ctx, 5)
.await
.unwrap();
}
assert!(rate_limiter
.is_rate_limited(namespace, &values, 1)
.is_rate_limited(namespace, &ctx, 1)
.await
.unwrap());
}
@@ -683,9 +688,10 @@ mod test {
let mut values: HashMap<String, String> = HashMap::new();
values.insert("req_method".to_string(), "GET".to_string());
values.insert("app_id".to_string(), "test_app_id".to_string());
let ctx = (&values).into();

assert!(rate_limiter
.is_rate_limited(namespace, &values, max + 1)
.is_rate_limited(namespace, &ctx, max + 1)
.await
.unwrap())
}
@@ -706,6 +712,7 @@ mod test {
let mut values: HashMap<String, String> = HashMap::new();
values.insert("req_method".to_string(), "GET".to_string());
values.insert("app_id".to_string(), "test_app_id".to_string());
let ctx = (&values).into();

for i in 0..max_hits {
// Add an extra value that does not apply to the limit on each
@@ -714,18 +721,18 @@ mod test {

assert!(
!rate_limiter
.is_rate_limited(namespace, &values, 1)
.is_rate_limited(namespace, &ctx, 1)
.await
.unwrap(),
"Must not be limited after {i}"
);
rate_limiter
.update_counters(namespace, &values, 1)
.update_counters(namespace, &ctx, 1)
.await
.unwrap();
}
assert!(rate_limiter
.is_rate_limited(namespace, &values, 1)
.is_rate_limited(namespace, &ctx, 1)
.await
.unwrap());
}
@@ -735,9 +742,10 @@ mod test {
) {
let mut values: HashMap<String, String> = HashMap::new();
values.insert("req_method".to_string(), "GET".to_string());
let ctx = (&values).into();

assert!(!rate_limiter
.is_rate_limited("test_namespace", &values, 1)
.is_rate_limited("test_namespace", &ctx, 1)
.await
.unwrap());
}
@@ -761,9 +769,10 @@ mod test {
let mut values: HashMap<String, String> = HashMap::new();
values.insert("req_method".to_string(), "POST".to_string());
values.insert("app_id".to_string(), "test_app_id".to_string());
let ctx = (&values).into();

assert!(!rate_limiter
.is_rate_limited(namespace, &values, 1)
.is_rate_limited(namespace, &ctx, 1)
.await
.unwrap());
}
@@ -783,9 +792,10 @@ mod test {

let mut values: HashMap<String, String> = HashMap::new();
values.insert("app_id".to_string(), "test_app_id".to_string());
let ctx = (&values).into();

assert!(rate_limiter
.is_rate_limited(namespace, &values, 1)
.is_rate_limited(namespace, &ctx, 1)
.await
.unwrap());
}
@@ -807,11 +817,12 @@ mod test {
let mut values: HashMap<String, String> = HashMap::new();
values.insert("req_method".to_string(), "GET".to_string());
values.insert("app_id".to_string(), "test_app_id".to_string());
let ctx = (&values).into();

for _ in 0..max_hits {
assert!(
!rate_limiter
.check_rate_limited_and_update(namespace, &values, 1, false)
.check_rate_limited_and_update(namespace, &ctx, 1, false)
.await
.unwrap()
.limited
@@ -820,7 +831,7 @@ mod test {

assert!(
rate_limiter
.check_rate_limited_and_update(namespace, &values, 1, false)
.check_rate_limited_and_update(namespace, &ctx, 1, false)
.await
.unwrap()
.limited
@@ -844,10 +855,11 @@ mod test {
let mut values: HashMap<String, String> = HashMap::new();
values.insert("req_method".to_string(), "GET".to_string());
values.insert("app_id".to_string(), "test_app_id".to_string());
let ctx = (&values).into();

for hit in 0..max_hits {
let result = rate_limiter
.check_rate_limited_and_update(namespace, &values, 1, true)
.check_rate_limited_and_update(namespace, &ctx, 1, true)
.await
.unwrap();
assert!(!result.limited);
@@ -862,7 +874,7 @@ mod test {
}

let result = rate_limiter
.check_rate_limited_and_update(namespace, &values, 1, true)
.check_rate_limited_and_update(namespace, &ctx, 1, true)
.await
.unwrap();
assert!(result.limited);
@@ -895,10 +907,11 @@ mod test {
values.insert("app_id".to_string(), "test_app_id".to_string());
// Does not match the limit defined
values.insert("req_method".to_string(), "POST".to_string());
let ctx = (&values).into();

assert!(
!rate_limiter
.check_rate_limited_and_update(namespace, &values, 1, false)
.check_rate_limited_and_update(namespace, &ctx, 1, false)
.await
.unwrap()
.limited
@@ -922,10 +935,11 @@ mod test {

let mut values: HashMap<String, String> = HashMap::new();
values.insert("app_id".to_string(), "test_app_id".to_string());
let ctx = (&values).into();

assert!(
rate_limiter
.check_rate_limited_and_update(namespace, &values, 1, false)
.check_rate_limited_and_update(namespace, &ctx, 1, false)
.await
.unwrap()
.limited
@@ -951,14 +965,15 @@ mod test {
let mut values = HashMap::new();
values.insert("req_method".to_string(), "GET".to_string());
values.insert("app_id".to_string(), "1".to_string());
let ctx = (&values).into();
rate_limiter
.update_counters(namespace, &values, hits_app_1)
.update_counters(namespace, &ctx, hits_app_1)
.await
.unwrap();

values.insert("app_id".to_string(), "2".to_string());
rate_limiter
.update_counters(namespace, &values, hits_app_2)
.update_counters(namespace, &ctx, hits_app_2)
.await
.unwrap();

@@ -1029,8 +1044,9 @@ mod test {
let mut values = HashMap::new();
values.insert("req_method".to_string(), "GET".to_string());
values.insert("app_id".to_string(), "1".to_string());
let ctx = (&values).into();
rate_limiter
.update_counters(namespace, &values, 1)
.update_counters(namespace, &ctx, 1)
.await
.unwrap();

@@ -1093,8 +1109,9 @@ mod test {
let mut values = HashMap::new();
values.insert("req_method".to_string(), "GET".to_string());
values.insert("app_id".to_string(), "1".to_string());
let ctx = (&values).into();
rate_limiter
.update_counters(namespace, &values, hits_to_report)
.update_counters(namespace, &ctx, hits_to_report)
.await
.unwrap();

@@ -1245,19 +1262,20 @@ mod test {
let mut values: HashMap<String, String> = HashMap::new();
values.insert("req_method".to_string(), "GET".to_string());
values.insert("app_id".to_string(), "test_app_id".to_string());
let ctx = (&values).into();

for i in 0..max_hits {
// Alternate between the two rate limiters
let rate_limiter = rate_limiters.get((i % 2) as usize).unwrap();
assert!(
!rate_limiter
.is_rate_limited(namespace, &values, 1)
.is_rate_limited(namespace, &ctx, 1)
.await
.unwrap(),
"Must not be limited after {i}"
);
rate_limiter
.update_counters(namespace, &values, 1)
.update_counters(namespace, &ctx, 1)
.await
.unwrap();
}
@@ -1269,7 +1287,7 @@ mod test {
|| async {
let rate_limiter = rate_limiters.first().unwrap();
rate_limiter
.is_rate_limited(namespace, &values, 1)
.is_rate_limited(namespace, &ctx, 1)
.await
.unwrap()
}

0 comments on commit e0ae4f6

Please sign in to comment.