diff --git a/limitador-server/Cargo.toml b/limitador-server/Cargo.toml index f7b2c146..4acdb31d 100644 --- a/limitador-server/Cargo.toml +++ b/limitador-server/Cargo.toml @@ -12,6 +12,9 @@ documentation = "https://kuadrant.io/docs/limitador" readme = "README.md" edition = "2021" +[features] +distributed_storage = ["limitador/distributed_storage"] + # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] diff --git a/limitador-server/build.rs b/limitador-server/build.rs index f7a0d811..9700c93b 100644 --- a/limitador-server/build.rs +++ b/limitador-server/build.rs @@ -6,6 +6,7 @@ use std::process::Command; fn main() -> Result<(), Box> { set_git_hash("LIMITADOR_GIT_HASH"); set_profile("LIMITADOR_PROFILE"); + set_features("LIMITADOR_FEATURES"); generate_protobuf() } @@ -31,6 +32,14 @@ fn set_profile(env: &str) { } } +fn set_features(env: &str) { + let mut features = vec![]; + if cfg!(feature = "distributed_storage") { + features.push("+distributed"); + } + println!("cargo:rustc-env={env}={features:?}"); +} + fn set_git_hash(env: &str) { let git_sha = Command::new("/usr/bin/git") .args(["rev-parse", "HEAD"]) diff --git a/limitador-server/examples/limits.yaml b/limitador-server/examples/limits.yaml index f0ea815b..afcb2b50 100644 --- a/limitador-server/examples/limits.yaml +++ b/limitador-server/examples/limits.yaml @@ -1,12 +1,10 @@ --- - namespace: test_namespace - max_value: 10 + max_value: 1000000 seconds: 60 conditions: - - "req.method == 'GET'" variables: - - user_id - namespace: test_namespace max_value: 5 diff --git a/limitador-server/src/config.rs b/limitador-server/src/config.rs index 949bf446..dc4ef59c 100644 --- a/limitador-server/src/config.rs +++ b/limitador-server/src/config.rs @@ -140,6 +140,8 @@ pub enum StorageConfiguration { InMemory(InMemoryStorageConfiguration), Disk(DiskStorageConfiguration), Redis(RedisStorageConfiguration), + #[cfg(feature = "distributed_storage")] + 
Distributed(DistributedStorageConfiguration), } #[derive(PartialEq, Eq, Debug)] @@ -147,6 +149,15 @@ pub struct InMemoryStorageConfiguration { pub cache_size: Option, } +#[derive(PartialEq, Eq, Debug)] +#[cfg(feature = "distributed_storage")] +pub struct DistributedStorageConfiguration { + pub name: String, + pub cache_size: Option, + pub local: String, + pub broadcast: String, +} + #[derive(PartialEq, Eq, Debug)] pub struct DiskStorageConfiguration { pub path: String, diff --git a/limitador-server/src/main.rs b/limitador-server/src/main.rs index 8962d5ef..a4ced111 100644 --- a/limitador-server/src/main.rs +++ b/limitador-server/src/main.rs @@ -5,6 +5,8 @@ extern crate log; extern crate clap; +#[cfg(feature = "distributed_storage")] +use crate::config::DistributedStorageConfiguration; use crate::config::{ Configuration, DiskStorageConfiguration, InMemoryStorageConfiguration, RedisStorageCacheConfiguration, RedisStorageConfiguration, StorageConfiguration, @@ -23,6 +25,8 @@ use limitador::storage::redis::{ AsyncRedisStorage, CachedRedisStorage, CachedRedisStorageBuilder, DEFAULT_BATCH_SIZE, DEFAULT_FLUSHING_PERIOD_SEC, DEFAULT_MAX_CACHED_COUNTERS, DEFAULT_RESPONSE_TIMEOUT_MS, }; +#[cfg(feature = "distributed_storage")] +use limitador::storage::DistributedInMemoryStorage; use limitador::storage::{AsyncCounterStorage, AsyncStorage, Storage}; use limitador::{ storage, AsyncRateLimiter, AsyncRateLimiterBuilder, RateLimiter, RateLimiterBuilder, @@ -57,6 +61,7 @@ pub mod prometheus_metrics; const LIMITADOR_VERSION: &str = env!("CARGO_PKG_VERSION"); const LIMITADOR_PROFILE: &str = env!("LIMITADOR_PROFILE"); +const LIMITADOR_FEATURES: &str = env!("LIMITADOR_FEATURES"); const LIMITADOR_HEADER: &str = "Limitador Server"; #[derive(Error, Debug)] @@ -83,6 +88,8 @@ impl Limiter { let rate_limiter = match config.storage { StorageConfiguration::Redis(cfg) => Self::redis_limiter(cfg).await, StorageConfiguration::InMemory(cfg) => Self::in_memory_limiter(cfg), + #[cfg(feature = 
"distributed_storage")] + StorageConfiguration::Distributed(cfg) => Self::distributed_limiter(cfg), StorageConfiguration::Disk(cfg) => Self::disk_limiter(cfg), }; @@ -154,6 +161,20 @@ impl Limiter { Self::Blocking(rate_limiter_builder.build()) } + #[cfg(feature = "distributed_storage")] + fn distributed_limiter(cfg: DistributedStorageConfiguration) -> Self { + let storage = DistributedInMemoryStorage::new( + cfg.name, + cfg.cache_size.or_else(guess_cache_size).unwrap(), + cfg.local, + cfg.broadcast, + ); + let rate_limiter_builder = + RateLimiterBuilder::with_storage(Storage::with_counter_storage(Box::new(storage))); + + Self::Blocking(rate_limiter_builder.build()) + } + pub async fn load_limits_from_file>( &self, path: &P, @@ -350,12 +371,12 @@ async fn main() -> Result<(), Box> { fn create_config() -> (Configuration, &'static str) { let full_version: &'static str = formatcp!( - "v{} ({}) {}", + "v{} ({}) {} {}", LIMITADOR_VERSION, env!("LIMITADOR_GIT_HASH"), + LIMITADOR_FEATURES, LIMITADOR_PROFILE, ); - // wire args based of defaults let limit_arg = Arg::new("LIMITS_FILE") .action(ArgAction::Set) @@ -565,6 +586,43 @@ fn create_config() -> (Configuration, &'static str) { ), ); + #[cfg(feature = "distributed_storage")] + let cmdline = cmdline.subcommand( + Command::new("distributed") + .about("Replicates CRDT-based counters across multiple Limitador servers") + .display_order(5) + .arg( + Arg::new("NAME") + .action(ArgAction::Set) + .required(true) + .display_order(2) + .help("Unique name to identify this Limitador instance"), + ) + .arg( + Arg::new("LOCAL") + .action(ArgAction::Set) + .required(true) + .display_order(2) + .help("Local IP:PORT to send datagrams from"), + ) + .arg( + Arg::new("BROADCAST") + .action(ArgAction::Set) + .required(true) + .display_order(3) + .help("Broadcast IP:PORT to send datagrams to"), + ) + .arg( + Arg::new("CACHE_SIZE") + .long("cache") + .short('c') + .action(ArgAction::Set) + .value_parser(value_parser!(u64)) + .display_order(4) 
+ .help("Sets the size of the cache for 'qualified counters'"), + ), + ); + let matches = cmdline.get_matches(); let limits_file = matches.get_one::("LIMITS_FILE").unwrap(); @@ -630,6 +688,15 @@ fn create_config() -> (Configuration, &'static str) { Some(("memory", sub)) => StorageConfiguration::InMemory(InMemoryStorageConfiguration { cache_size: sub.get_one::("CACHE_SIZE").copied(), }), + #[cfg(feature = "distributed_storage")] + Some(("distributed", sub)) => { + StorageConfiguration::Distributed(DistributedStorageConfiguration { + name: sub.get_one::("NAME").unwrap().to_owned(), + local: sub.get_one::("LOCAL").unwrap().to_owned(), + broadcast: sub.get_one::("BROADCAST").unwrap().to_owned(), + cache_size: sub.get_one::("CACHE_SIZE").copied(), + }) + } None => storage_config_from_env(), _ => unreachable!("Some storage wasn't configured!"), }; diff --git a/limitador/Cargo.toml b/limitador/Cargo.toml index 456b3309..8f0a681b 100644 --- a/limitador/Cargo.toml +++ b/limitador/Cargo.toml @@ -15,6 +15,7 @@ edition = "2021" [features] default = ["disk_storage", "redis_storage"] disk_storage = ["rocksdb"] +distributed_storage = [] redis_storage = ["redis", "r2d2", "tokio"] lenient_conditions = [] diff --git a/limitador/src/storage/atomic_expiring_value.rs b/limitador/src/storage/atomic_expiring_value.rs index f80eaa22..8b00c7bd 100644 --- a/limitador/src/storage/atomic_expiring_value.rs +++ b/limitador/src/storage/atomic_expiring_value.rs @@ -102,6 +102,47 @@ impl AtomicExpiryTime { } false } + + #[allow(dead_code)] + pub fn merge(&self, other: Self) { + let mut other = other; + loop { + let now = SystemTime::now(); + other = match self.merge_at(other, now) { + Ok(_) => return, + Err(other) => other, + }; + } + } + + pub fn merge_at(&self, other: Self, when: SystemTime) -> Result<(), Self> { + let other_exp = other.expiry.load(Ordering::SeqCst); + let expiry = self.expiry.load(Ordering::SeqCst); + if other_exp < expiry && other_exp > Self::since_epoch(when) { + // if our 
expiry changed, some thread observed the time window as elapsed... + // `other` can't be in the future anymore! Safely ignoring the failure scenario + return match self.expiry.compare_exchange( + expiry, + other_exp, + Ordering::SeqCst, + Ordering::SeqCst, + ) { + Ok(_) => Ok(()), + Err(_) => Err(other), + }; + } + Ok(()) + } + + #[allow(dead_code)] + pub fn into_inner(self) -> SystemTime { + self.expires_at() + } + + #[allow(dead_code)] + pub fn expires_at(&self) -> SystemTime { + SystemTime::UNIX_EPOCH + Duration::from_micros(self.expiry.load(Ordering::SeqCst)) + } } impl Clone for AtomicExpiryTime { @@ -130,6 +171,12 @@ impl Clone for AtomicExpiringValue { } } +impl From for AtomicExpiryTime { + fn from(value: SystemTime) -> Self { + AtomicExpiryTime::new(value) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/limitador/src/storage/distributed/cr_counter_value.rs b/limitador/src/storage/distributed/cr_counter_value.rs new file mode 100644 index 00000000..eb6fc1fb --- /dev/null +++ b/limitador/src/storage/distributed/cr_counter_value.rs @@ -0,0 +1,287 @@ +use crate::storage::atomic_expiring_value::AtomicExpiryTime; +use std::collections::btree_map::Entry; +use std::collections::BTreeMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::RwLock; +use std::time::{Duration, SystemTime}; + +#[derive(Debug)] +pub struct CrCounterValue { + ourselves: A, + value: AtomicU64, + others: RwLock>, + expiry: AtomicExpiryTime, +} + +#[allow(dead_code)] +impl CrCounterValue { + pub fn new(actor: A, time_window: Duration) -> Self { + Self { + ourselves: actor, + value: Default::default(), + others: RwLock::default(), + expiry: AtomicExpiryTime::from_now(time_window), + } + } + + pub fn read(&self) -> u64 { + self.read_at(SystemTime::now()) + } + + pub fn read_at(&self, when: SystemTime) -> u64 { + if self.expiry.expired_at(when) { + 0 + } else { + let guard = self.others.read().unwrap(); + let others: u64 = guard.values().sum(); + others + 
self.value.load(Ordering::Relaxed)
+        }
+    }
+
+    /// Adds `increment` to this node's own share of the counter, as of the current system time.
+    pub fn inc(&self, increment: u64, time_window: Duration) {
+        self.inc_at(increment, time_window, SystemTime::now())
+    }
+
+    /// Adds `increment` to this node's own share of the counter, as of `when`.
+    ///
+    /// When `update_if_expired` returns `true` for `when`, our value restarts at `increment`;
+    /// otherwise `increment` is added to the current value.
+    pub fn inc_at(&self, increment: u64, time_window: Duration, when: SystemTime) {
+        if self.expiry.update_if_expired(time_window.as_secs(), when) {
+            self.value.store(increment, Ordering::SeqCst);
+        } else {
+            self.value.fetch_add(increment, Ordering::SeqCst);
+        }
+    }
+
+    /// Records `increment` on behalf of `actor`, as of the current system time.
+    pub fn inc_actor(&self, actor: A, increment: u64, time_window: Duration) {
+        self.inc_actor_at(actor, increment, time_window, SystemTime::now());
+    }
+
+    /// Records `increment` on behalf of `actor`, as of `when`.
+    ///
+    /// Increments for ourselves are delegated to [`Self::inc_at`]. For any other actor the
+    /// write lock on `others` is taken and that actor's entry is either replaced (when
+    /// `update_if_expired` returns `true`) or accumulated.
+    pub fn inc_actor_at(&self, actor: A, increment: u64, time_window: Duration, when: SystemTime) {
+        if actor == self.ourselves {
+            self.inc_at(increment, time_window, when);
+        } else {
+            let mut guard = self.others.write().unwrap();
+            // The TTL is handed over in whole seconds, exactly as `inc_at` does. This call used
+            // to pass `as_micros()`, i.e. a different unit for the very same parameter.
+            // NOTE(review): assumes `update_if_expired` expects seconds - confirm.
+            if self.expiry.update_if_expired(time_window.as_secs(), when) {
+                guard.insert(actor, increment);
+            } else {
+                *guard.entry(actor).or_insert(0) += increment;
+            }
+        }
+    }
+
+    /// Merges `other` into this value, as of the current system time.
+    pub fn merge(&self, other: Self) {
+        self.merge_at(other, SystemTime::now());
+    }
+
+    /// Merges `other` into this value, as of `when`.
+    ///
+    /// Nothing happens unless `other` expires after `when`. Otherwise the expiry of `other` is
+    /// offered to our own expiry, our state is reset if our window has elapsed at `when`, and
+    /// every per-actor value of `other` is folded in by keeping the larger of the two values.
+    pub fn merge_at(&self, other: Self, when: SystemTime) {
+        let (expiry, other_values) = other.into_inner();
+        if expiry > when {
+            // A failed exchange is deliberately ignored here.
+            let _ = self.expiry.merge_at(expiry.into(), when);
+            if self.expiry.expired_at(when) {
+                self.reset(expiry);
+            }
+            let ourselves = self.value.load(Ordering::SeqCst);
+            let mut others = self.others.write().unwrap();
+            for (actor, other_value) in other_values {
+                if actor == self.ourselves {
+                    // Only catch up when the peer has seen more of our hits than we hold.
+                    if other_value > ourselves {
+                        self.value
+                            .fetch_add(other_value - ourselves, Ordering::SeqCst);
+                    }
+                } else {
+                    match others.entry(actor) {
+                        Entry::Vacant(entry) => {
+                            // Don't grow the map with actors that have nothing to report.
+                            if other_value > 0 {
+                                entry.insert(other_value);
+                            }
+                        }
+                        Entry::Occupied(mut known) => {
+                            let local = known.get_mut();
+                            if other_value > *local {
+                                *local = other_value;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    /// Returns `self.expiry.duration()`.
+    /// NOTE(review): presumably the time left until expiry - confirm against `AtomicExpiryTime`.
+    pub fn ttl(&self) -> Duration {
+        self.expiry.duration()
+    }
+
+    pub fn expiry(&self) ->
SystemTime { + self.expiry.expires_at() + } + + pub fn into_inner(self) -> (SystemTime, BTreeMap) { + let Self { + ourselves, + value, + others, + expiry, + } = self; + let mut map = others.into_inner().unwrap(); + map.insert(ourselves, value.into_inner()); + (expiry.into_inner(), map) + } + + fn reset(&self, expiry: SystemTime) { + let mut guard = self.others.write().unwrap(); + self.expiry.update(expiry); + self.value.store(0, Ordering::SeqCst); + guard.clear() + } +} + +impl Clone for CrCounterValue { + fn clone(&self) -> Self { + Self { + ourselves: self.ourselves.clone(), + value: AtomicU64::new(self.value.load(Ordering::SeqCst)), + others: RwLock::new(self.others.read().unwrap().clone()), + expiry: self.expiry.clone(), + } + } +} + +impl From<(SystemTime, BTreeMap)> for CrCounterValue { + fn from(value: (SystemTime, BTreeMap)) -> Self { + Self { + ourselves: A::default(), + value: Default::default(), + others: RwLock::new(value.1), + expiry: value.0.into(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::storage::distributed::cr_counter_value::CrCounterValue; + use std::time::{Duration, SystemTime}; + + #[test] + fn local_increments_are_readable() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + a.inc(3, window); + assert_eq!(3, a.read()); + a.inc(2, window); + assert_eq!(5, a.read()); + } + + #[test] + fn local_increments_expire() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + let now = SystemTime::now(); + a.inc_at(3, window, now); + assert_eq!(3, a.read()); + a.inc_at(2, window, now + window); + assert_eq!(2, a.read()); + } + + #[test] + fn other_increments_are_readable() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + a.inc_actor('B', 3, window); + assert_eq!(3, a.read()); + a.inc_actor('B', 2, window); + assert_eq!(5, a.read()); + } + + #[test] + fn other_increments_expire() { + let window = Duration::from_secs(1); + let a = 
CrCounterValue::new('A', window); + let now = SystemTime::now(); + a.inc_actor_at('B', 3, window, now); + assert_eq!(3, a.read()); + a.inc_actor_at('B', 2, window, now + window); + assert_eq!(2, a.read()); + } + + #[test] + fn merges() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + a.merge(b); + assert_eq!(a.read(), 5); + } + + #[test] + fn merges_symetric() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + b.merge(a); + assert_eq!(b.read(), 5); + } + + #[test] + fn merges_overrides_with_larger_value() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + b.inc_actor('A', 2, window); // older value! + b.merge(a); // merges the 3 + assert_eq!(b.read(), 5); + } + + #[test] + fn merges_ignore_lesser_values() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + b.inc_actor('A', 5, window); // newer value! 
+ b.merge(a); // ignores the 3 and keeps its own 5 for a + assert_eq!(b.read(), 7); + } + + #[test] + fn merge_ignores_expired_sets() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', Duration::ZERO); + a.inc(3, Duration::ZERO); + let b = CrCounterValue::new('B', window); + b.inc(2, window); + b.merge(a); + assert_eq!(b.read(), 2); + } + + #[test] + fn merge_ignores_expired_sets_symmetric() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', Duration::ZERO); + a.inc(3, Duration::ZERO); + let b = CrCounterValue::new('B', window); + b.inc(2, window); + a.merge(b); + assert_eq!(a.read(), 2); + } + + #[test] + fn merge_uses_earliest_expiry() { + let later = Duration::from_secs(1); + let a = CrCounterValue::new('A', later); + let sooner = Duration::from_millis(200); + let b = CrCounterValue::new('B', sooner); + a.inc(3, later); + b.inc(2, later); + a.merge(b); + assert!(a.expiry.duration() < sooner); + } +} diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs new file mode 100644 index 00000000..5732a322 --- /dev/null +++ b/limitador/src/storage/distributed/mod.rs @@ -0,0 +1,438 @@ +use std::collections::hash_map::Entry; +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::net::ToSocketAddrs; +use std::ops::Deref; +use std::sync::{Arc, RwLock}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use moka::sync::Cache; +use serde::{Deserialize, Serialize}; +use tokio::net::UdpSocket; +use tokio::sync::mpsc; +use tokio::sync::mpsc::Sender; + +use crate::counter::Counter; +use crate::limit::{Limit, Namespace}; +use crate::storage::distributed::cr_counter_value::CrCounterValue; +use crate::storage::{Authorization, CounterStorage, StorageErr}; + +mod cr_counter_value; + +type NamespacedLimitCounters = HashMap>; + +pub struct CrInMemoryStorage { + identifier: String, + sender: Sender, + limits_for_namespace: Arc>>>, + qualified_counters: Arc>>>, +} + +impl CounterStorage 
for CrInMemoryStorage { + #[tracing::instrument(skip_all)] + fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result { + let limits_by_namespace = self.limits_for_namespace.read().unwrap(); + + let mut value = 0; + + if counter.is_qualified() { + if let Some(counter) = self.qualified_counters.get(counter) { + value = counter.read(); + } + } else if let Some(limits) = limits_by_namespace.get(counter.limit().namespace()) { + if let Some(counter) = limits.get(counter.limit()) { + value = counter.read(); + } + } + + Ok(counter.max_value() >= value + delta) + } + + #[tracing::instrument(skip_all)] + fn add_counter(&self, limit: &Limit) -> Result<(), StorageErr> { + if limit.variables().is_empty() { + let mut limits_by_namespace = self.limits_for_namespace.write().unwrap(); + limits_by_namespace + .entry(limit.namespace().clone()) + .or_default() + .entry(limit.clone()) + .or_insert(CrCounterValue::new( + self.identifier.clone(), + Duration::from_secs(limit.seconds()), + )); + } + Ok(()) + } + + #[tracing::instrument(skip_all)] + fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> { + let mut limits_by_namespace = self.limits_for_namespace.write().unwrap(); + let now = SystemTime::now(); + if counter.is_qualified() { + let value = match self.qualified_counters.get(counter) { + None => self.qualified_counters.get_with(counter.clone(), || { + Arc::new(CrCounterValue::new( + self.identifier.clone(), + Duration::from_secs(counter.seconds()), + )) + }), + Some(counter) => counter, + }; + self.increment_counter(counter.clone(), &value, delta, now); + } else { + match limits_by_namespace.entry(counter.limit().namespace().clone()) { + Entry::Vacant(v) => { + let mut limits = HashMap::new(); + let duration = Duration::from_secs(counter.seconds()); + let counter_val = CrCounterValue::new(self.identifier.clone(), duration); + self.increment_counter(counter.clone(), &counter_val, delta, now); + limits.insert(counter.limit().clone(), 
counter_val); + v.insert(limits); + } + Entry::Occupied(mut o) => match o.get_mut().entry(counter.limit().clone()) { + Entry::Vacant(v) => { + let duration = Duration::from_secs(counter.seconds()); + let counter_value = CrCounterValue::new(self.identifier.clone(), duration); + self.increment_counter(counter.clone(), &counter_value, delta, now); + v.insert(counter_value); + } + Entry::Occupied(o) => { + self.increment_counter(counter.clone(), o.get(), delta, now); + } + }, + } + } + Ok(()) + } + + #[tracing::instrument(skip_all)] + fn check_and_update( + &self, + counters: &mut Vec, + delta: u64, + load_counters: bool, + ) -> Result { + let limits_by_namespace = self.limits_for_namespace.write().unwrap(); + let mut first_limited = None; + let mut counter_values_to_update: Vec<(&CrCounterValue, Counter)> = Vec::new(); + let mut qualified_counter_values_to_updated: Vec<(Arc>, Counter)> = + Vec::new(); + let now = SystemTime::now(); + + let mut process_counter = + |counter: &mut Counter, value: u64, delta: u64| -> Option { + if load_counters { + let remaining = counter.max_value().checked_sub(value + delta); + counter.set_remaining(remaining.unwrap_or(0)); + if first_limited.is_none() && remaining.is_none() { + first_limited = Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); + } + } + if !Self::counter_is_within_limits(counter, Some(&value), delta) { + return Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); + } + None + }; + + // Process simple counters + for counter in counters.iter_mut().filter(|c| !c.is_qualified()) { + let atomic_expiring_value: &CrCounterValue = limits_by_namespace + .get(counter.limit().namespace()) + .and_then(|limits| limits.get(counter.limit())) + .unwrap(); + + if let Some(limited) = process_counter(counter, atomic_expiring_value.read(), delta) { + if !load_counters { + return Ok(limited); + } + } + counter_values_to_update.push((atomic_expiring_value, counter.clone())); + } 
+
+        // Process qualified counters
+        for counter in counters.iter_mut().filter(|c| c.is_qualified()) {
+            let value = match self.qualified_counters.get(counter) {
+                None => self.qualified_counters.get_with(counter.clone(), || {
+                    Arc::new(CrCounterValue::new(
+                        self.identifier.clone(),
+                        Duration::from_secs(counter.seconds()),
+                    ))
+                }),
+                Some(counter) => counter,
+            };
+
+            if let Some(limited) = process_counter(counter, value.read(), delta) {
+                if !load_counters {
+                    return Ok(limited);
+                }
+            }
+
+            qualified_counter_values_to_updated.push((value, counter.clone()));
+        }
+
+        if let Some(limited) = first_limited {
+            return Ok(limited);
+        }
+
+        // Update counters
+        counter_values_to_update
+            .into_iter()
+            .for_each(|(v, counter)| {
+                self.increment_counter(counter, v, delta, now);
+            });
+        qualified_counter_values_to_updated
+            .into_iter()
+            .for_each(|(v, counter)| {
+                self.increment_counter(counter, v.deref(), delta, now);
+            });
+
+        Ok(Authorization::Ok)
+    }
+
+    /// Returns the counters of the requested `limits` whose window has not elapsed, with
+    /// `remaining` and `expires_in` filled in.
+    ///
+    /// Only counters of limits contained in `limits` are reported. The simple counters used
+    /// to be filtered with a check of the namespace's own map against its own keys, which was
+    /// always true and therefore returned every counter of the namespace.
+    #[tracing::instrument(skip_all)]
+    fn get_counters(&self, limits: &HashSet<Limit>) -> Result<HashSet<Counter>, StorageErr> {
+        let mut res = HashSet::new();
+
+        // Fills in `remaining`/`expires_in` and keeps the counter only while its TTL is positive.
+        // `saturating_sub`: the sum over all actors can exceed `max_value`, and a plain
+        // subtraction would then underflow.
+        let mut collect = |mut counter: Counter, value: &CrCounterValue<String>| {
+            counter.set_remaining(counter.max_value().saturating_sub(value.read()));
+            counter.set_expires_in(value.ttl());
+            if counter.expires_in().unwrap() > Duration::ZERO {
+                res.insert(counter);
+            }
+        };
+
+        // Simple counters: looked up under a single read guard, without calling back into
+        // helpers that would acquire the same lock again while it is held.
+        {
+            let limits_by_namespace = self.limits_for_namespace.read().unwrap();
+            for limit in limits {
+                let known = limits_by_namespace
+                    .get(limit.namespace())
+                    .and_then(|counters_by_limit| counters_by_limit.get_key_value(limit));
+                if let Some((stored_limit, value)) = known {
+                    collect(Counter::new(stored_limit.clone(), HashMap::default()), value);
+                }
+            }
+        }
+
+        // Qualified counters live in the cache, keyed by the full counter.
+        for (counter, expiring_value) in self.qualified_counters.iter() {
+            if limits.contains(counter.limit()) {
+                collect(counter.deref().clone(), expiring_value.deref());
+            }
+        }
+
+        Ok(res)
+    }
+
+    /// Deletes the counters of every limit in `limits` via `delete_counters_of_limit`.
+    #[tracing::instrument(skip_all)]
+    fn delete_counters(&self, limits: HashSet<Limit>) -> Result<(), StorageErr> {
+        for limit in limits {
+            self.delete_counters_of_limit(&limit);
+        }
+        Ok(())
+    }
+
+    /// Drops every counter held by this storage, simple and qualified alike.
+    #[tracing::instrument(skip_all)]
+    fn clear(&self) -> Result<(), StorageErr> {
+        self.limits_for_namespace.write().unwrap().clear();
+        // The qualified counters are kept in a separate cache and have to be cleared as well.
+        self.qualified_counters.invalidate_all();
+        Ok(())
+    }
+}
+
+impl CrInMemoryStorage {
+    /// Creates the storage and spawns its two background tasks on the current Tokio runtime.
+    ///
+    /// * `identifier` - the actor name used for the counter values this node creates
+    /// * `cache_size` - capacity handed to the cache of qualified counters
+    /// * `local` - address the sending socket binds to
+    /// * `broadcast` - address updates are sent to, and that the receiving socket binds to
+    ///
+    /// # Panics
+    ///
+    /// Panics if `local` does not resolve to a socket address. The spawned tasks panic on
+    /// their own if their sockets cannot be set up.
+    pub fn new(identifier: String, cache_size: u64, local: String, broadcast: String) -> Self {
+        let (sender, mut rx) = mpsc::channel::<CounterValueMessage>(1000);
+
+        let local = local.to_socket_addrs().unwrap().next().unwrap();
+        let remote = broadcast.clone();
+        tokio::spawn(async move {
+            let sock = UdpSocket::bind(local).await.unwrap();
+            sock.set_broadcast(true).unwrap();
+            sock.connect(remote).await.unwrap();
+            // `recv` yields `None` once every sender is gone: end the task instead of panicking.
+            while let Some(message) = rx.recv().await {
+                let buf = postcard::to_stdvec(&message).unwrap();
+                match sock.send(&buf).await {
+                    Ok(len) => {
+                        if len != buf.len() {
+                            println!("Couldn't send complete message!");
+                        }
+                    }
+                    Err(err) => println!("Couldn't send update: {:?}", err),
+                };
+            }
+        });
+
+        let limits_for_namespace = Arc::new(RwLock::new(HashMap::<
+            Namespace,
+            HashMap<Limit, CrCounterValue<String>>,
+        >::new()));
+        let qualified_counters: Arc<Cache<Counter, Arc<CrCounterValue<String>>>> =
+            Arc::new(Cache::new(cache_size));
+
+        {
+            let limits_for_namespace = limits_for_namespace.clone();
+            let qualified_counters = qualified_counters.clone();
+            tokio::spawn(async move {
+                let sock = UdpSocket::bind(broadcast).await.unwrap();
+                sock.set_broadcast(true).unwrap();
+                let mut buf = [0; 1024];
+                loop {
+                    let (len, addr) = sock.recv_from(&mut buf).await.unwrap();
+                    // Datagrams coming from our own sending socket are skipped.
+                    if addr != local {
+                        match postcard::from_bytes::<CounterValueMessage>(&buf[..len]) {
+                            Ok(message) => {
+                                let CounterValueMessage {
+
counter_key,
+                                    expiry,
+                                    values,
+                                } = message;
+                                let counter = <CounterKey as Into<Counter>>::into(counter_key);
+                                if counter.is_qualified() {
+                                    if let Some(counter) = qualified_counters.get(&counter) {
+                                        counter.merge(
+                                            (UNIX_EPOCH + Duration::from_secs(expiry), values)
+                                                .into(),
+                                        );
+                                    }
+                                } else {
+                                    let counters = limits_for_namespace.read().unwrap();
+                                    // The key comes off the network: a peer may report a limit
+                                    // this node doesn't hold. Skip it, like the qualified branch
+                                    // does, instead of panicking and ending the receive task.
+                                    if let Some(value) = counters
+                                        .get(counter.namespace())
+                                        .and_then(|limits| limits.get(counter.limit()))
+                                    {
+                                        value.merge(
+                                            (UNIX_EPOCH + Duration::from_secs(expiry), values)
+                                                .into(),
+                                        );
+                                    }
+                                };
+                            }
+                            Err(err) => {
+                                println!("Error from {} bytes: {:?} \n{:?}", len, err, &buf[..len])
+                            }
+                        }
+                    }
+                }
+            });
+        }
+
+        Self {
+            identifier,
+            sender,
+            limits_for_namespace,
+            qualified_counters,
+        }
+    }
+
+    /// Returns a snapshot of every counter in `namespace`.
+    ///
+    /// Simple counters are reported under a `Counter` built from their limit with no
+    /// variables set; qualified counters are taken from the cache. All values are clones.
+    fn counters_in_namespace(
+        &self,
+        namespace: &Namespace,
+    ) -> HashMap<Counter, CrCounterValue<String>> {
+        let mut res: HashMap<Counter, CrCounterValue<String>> = HashMap::new();
+
+        if let Some(counters_by_limit) = self.limits_for_namespace.read().unwrap().get(namespace) {
+            for (limit, value) in counters_by_limit {
+                res.insert(
+                    Counter::new(limit.clone(), HashMap::default()),
+                    value.clone(),
+                );
+            }
+        }
+
+        for (counter, value) in self.qualified_counters.iter() {
+            if counter.namespace() == namespace {
+                res.insert(counter.deref().clone(), value.deref().clone());
+            }
+        }
+
+        res
+    }
+
+    /// Removes the simple counter stored for `limit`, if any.
+    ///
+    /// Entries in the cache of qualified counters are not touched by this function.
+    fn delete_counters_of_limit(&self, limit: &Limit) {
+        if let Some(counters_by_limit) = self
+            .limits_for_namespace
+            .write()
+            .unwrap()
+            .get_mut(limit.namespace())
+        {
+            counters_by_limit.remove(limit);
+        }
+    }
+
+    /// Tells whether adding `delta` keeps `counter` at or below its `max_value`.
+    ///
+    /// A missing `current_val` is treated as a counter that has not been hit yet.
+    fn counter_is_within_limits(counter: &Counter, current_val: Option<&u64>, delta: u64) -> bool {
+        match current_val {
+            Some(current_val) => current_val + delta <= counter.max_value(),
+            None => counter.max_value() >= delta,
+        }
+    }
+
+    fn increment_counter(
+        &self,
+        key: Counter,
+        counter: &CrCounterValue<String>,
+        delta: u64,
+        when: SystemTime,
+    ) {
+        counter.inc_at(delta, Duration::from_secs(key.seconds()), when);
+        let sender = self.sender.clone();
+        let counter = counter.clone();
+
tokio::spawn(async move { + let (expiry, values) = counter.into_inner(); + let message = CounterValueMessage { + counter_key: key.into(), + expiry: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(), + values, + }; + sender.send(message).await + }); + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct CounterValueMessage { + counter_key: CounterKey, + expiry: u64, + values: BTreeMap, +} + +#[derive(Debug, Serialize, Deserialize)] +struct CounterKey { + namespace: Namespace, + seconds: u64, + conditions: HashSet, + variables: HashSet, + vars: HashMap, +} + +impl From for CounterKey { + fn from(value: Counter) -> Self { + Self { + namespace: value.namespace().clone(), + seconds: value.seconds(), + variables: value.limit().variables(), + conditions: value.limit().conditions(), + vars: value.set_variables().clone(), + } + } +} + +impl From for Counter { + fn from(value: CounterKey) -> Self { + Self::new( + Limit::new( + value.namespace, + 0, + value.seconds, + value.conditions, + value.vars.keys(), + ), + value.vars, + ) + } +} diff --git a/limitador/src/storage/mod.rs b/limitador/src/storage/mod.rs index 4db70278..22abd33a 100644 --- a/limitador/src/storage/mod.rs +++ b/limitador/src/storage/mod.rs @@ -8,8 +8,13 @@ use thiserror::Error; #[cfg(feature = "disk_storage")] pub mod disk; +#[cfg(feature = "distributed_storage")] +pub mod distributed; pub mod in_memory; +#[cfg(feature = "distributed_storage")] +pub use crate::storage::distributed::CrInMemoryStorage as DistributedInMemoryStorage; + #[cfg(feature = "redis_storage")] pub mod redis; diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index f14d8f95..2b1e9afe 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -13,6 +13,14 @@ macro_rules! 
test_with_all_storage_impls { $function(&mut TestsLimiter::new_from_blocking_impl(rate_limiter)).await; } + #[cfg(feature = "distributed_storage")] + #[tokio::test] + async fn [<$function _distributed_storage>]() { + let rate_limiter = + RateLimiter::new_with_storage(Box::new(CrInMemoryStorage::new("test_node".to_owned(), 10_000, "127.0.0.1:19876".to_owned(), "127.0.0.255:19876".to_owned()))); + $function(&mut TestsLimiter::new_from_blocking_impl(rate_limiter)).await; + } + #[tokio::test] async fn [<$function _disk_storage>]() { let dir = TempDir::new().expect("We should have a dir!"); @@ -89,6 +97,8 @@ mod test { use crate::helpers::tests_limiter::*; use limitador::limit::Limit; use limitador::storage::disk::{DiskStorage, OptimizeFor}; + #[cfg(feature = "distributed_storage")] + use limitador::storage::distributed::CrInMemoryStorage; use limitador::storage::in_memory::InMemoryStorage; use std::collections::{HashMap, HashSet}; use std::thread::sleep;