Skip to content

Commit

Permalink
Merge pull request #191 from Kuadrant/atomic-expiring-value
Browse files Browse the repository at this point in the history
AtomicExpiringValue
  • Loading branch information
didierofrivia authored Aug 1, 2023
2 parents 0dbf403 + de07e74 commit 5377654
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 57 deletions.
144 changes: 144 additions & 0 deletions limitador/src/storage/atomic_expiring_value.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
use std::sync::atomic::Ordering;
use std::sync::atomic::{AtomicI64, AtomicU64};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

#[derive(Debug)]
pub(crate) struct AtomicExpiringValue {
value: AtomicI64,
expiry: AtomicU64, // in microseconds
}

impl AtomicExpiringValue {
pub fn new(value: i64, expiry: SystemTime) -> Self {
let expiry = Self::get_duration_micros(expiry);
Self {
value: AtomicI64::new(value),
expiry: AtomicU64::new(expiry),
}
}

pub fn value_at(&self, when: SystemTime) -> i64 {
let when = Self::get_duration_micros(when);
let expiry = self.expiry.load(Ordering::SeqCst);
if expiry <= when {
return 0;
}
self.value.load(Ordering::SeqCst)
}

pub fn value(&self) -> i64 {
self.value_at(SystemTime::now())
}

pub fn update(&self, delta: i64, ttl: u64, when: SystemTime) -> i64 {
let ttl_micros = ttl * 1_000_000;
let when_micros = Self::get_duration_micros(when);

let expiry = self.expiry.load(Ordering::SeqCst);
if expiry <= when_micros {
let new_expiry = when_micros + ttl_micros;
if self
.expiry
.compare_exchange(expiry, new_expiry, Ordering::SeqCst, Ordering::SeqCst)
.is_ok()
{
self.value.store(delta, Ordering::SeqCst);
}
return delta;
}
self.value.fetch_add(delta, Ordering::SeqCst) + delta
}

pub fn ttl(&self) -> Duration {
let expiry =
SystemTime::UNIX_EPOCH + Duration::from_micros(self.expiry.load(Ordering::SeqCst));
expiry
.duration_since(SystemTime::now())
.unwrap_or(Duration::ZERO)
}

fn get_duration_micros(when: SystemTime) -> u64 {
when.duration_since(UNIX_EPOCH)
.expect("SystemTime before UNIX EPOCH!")
.as_micros() as u64
}
}

impl Default for AtomicExpiringValue {
fn default() -> Self {
AtomicExpiringValue {
value: AtomicI64::new(0),
expiry: AtomicU64::new(0),
}
}
}

impl Clone for AtomicExpiringValue {
fn clone(&self) -> Self {
AtomicExpiringValue {
value: AtomicI64::new(self.value.load(Ordering::SeqCst)),
expiry: AtomicU64::new(self.expiry.load(Ordering::SeqCst)),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::thread;
use std::time::{Duration, SystemTime};

#[test]
fn returns_value_when_valid() {
let now = SystemTime::now();
let val = AtomicExpiringValue::new(42, now);
assert_eq!(val.value_at(now - Duration::from_secs(1)), 42);
}

#[test]
fn returns_default_when_expired() {
let now = SystemTime::now();
let val = AtomicExpiringValue::new(42, now - Duration::from_secs(1));
assert_eq!(val.value_at(now), 0);
}

#[test]
fn returns_default_on_expiry() {
let now = SystemTime::now();
let val = AtomicExpiringValue::new(42, now);
assert_eq!(val.value_at(now), 0);
}

#[test]
fn updates_when_valid() {
let now = SystemTime::now();
let val = AtomicExpiringValue::new(42, now + Duration::from_secs(1));
val.update(3, 10, now);
assert_eq!(val.value_at(now - Duration::from_secs(1)), 45);
}

#[test]
fn updates_when_expired() {
let now = SystemTime::now();
let val = AtomicExpiringValue::new(42, now);
assert_eq!(val.ttl(), Duration::ZERO);
val.update(3, 10, now);
assert_eq!(val.value_at(now - Duration::from_secs(1)), 3);
}

#[test]
fn test_overlapping_updates() {
let now = SystemTime::now();
let atomic_expiring_value = AtomicExpiringValue::new(42, now + Duration::from_secs(10));

thread::scope(|s| {
s.spawn(|| {
atomic_expiring_value.update(1, 1, now);
});
s.spawn(|| {
atomic_expiring_value.update(2, 1, now + Duration::from_secs(11));
});
});
assert!([2i64, 3i64].contains(&atomic_expiring_value.value.load(Ordering::SeqCst)));
}
}
11 changes: 0 additions & 11 deletions limitador/src/storage/expiring_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,6 @@ impl ExpiringValue {
Self { value, expiry }
}

pub fn update_mut(&mut self, delta: i64, ttl: u64, now: SystemTime) {
let expiry = if self.expiry <= now {
now + Duration::from_secs(ttl)
} else {
self.expiry
};

self.value = self.value_at(now) + delta;
self.expiry = expiry;
}

#[must_use]
pub fn merge(self, other: ExpiringValue, now: SystemTime) -> Self {
if self.expiry > now {
Expand Down
83 changes: 37 additions & 46 deletions limitador/src/storage/in_memory.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::counter::Counter;
use crate::limit::{Limit, Namespace};
use crate::storage::expiring_value::ExpiringValue;
use crate::storage::atomic_expiring_value::AtomicExpiringValue;
use crate::storage::{Authorization, CounterStorage, StorageErr};
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
Expand Down Expand Up @@ -53,7 +53,7 @@ impl PartialEq for CounterKey {
type NamespacedLimitCounters<T> = HashMap<Namespace, HashMap<Limit, HashMap<CounterKey, T>>>;

pub struct InMemoryStorage {
limits_for_namespace: RwLock<NamespacedLimitCounters<ExpiringValue>>,
limits_for_namespace: RwLock<NamespacedLimitCounters<AtomicExpiringValue>>,
}

impl CounterStorage for InMemoryStorage {
Expand Down Expand Up @@ -103,6 +103,8 @@ impl CounterStorage for InMemoryStorage {
) -> Result<Authorization, StorageErr> {
let mut limits_by_namespace = self.limits_for_namespace.write().unwrap();
let mut first_limited = None;
let mut counter_values_to_update: Vec<(&AtomicExpiringValue, u64)> = Vec::new();
let now = SystemTime::now();

let mut process_counter =
|counter: &mut Counter, value: i64, delta: i64| -> Option<Authorization> {
Expand All @@ -123,56 +125,42 @@ impl CounterStorage for InMemoryStorage {
None
};

for counter in counters.iter_mut() {
if counter.max_value() < delta {
if let Some(limited) = process_counter(counter, 0, delta) {
if !load_counters {
return Ok(limited);
}
}
continue;
}
// Normalize counters and values
for counter in counters.iter() {
limits_by_namespace
.entry(counter.limit().namespace().clone())
.or_insert_with(HashMap::new)
.entry(counter.limit().clone())
.or_insert_with(HashMap::new)
.entry(counter.into())
.or_insert_with(AtomicExpiringValue::default);
}

let value = Some(
limits_by_namespace
.get(counter.limit().namespace())
.and_then(|limits| limits.get(counter.limit()))
.and_then(|counters| counters.get(&counter.into()))
.map(|expiring_value| expiring_value.value())
.unwrap_or(0),
);
// Process counters
for counter in counters.iter_mut() {
let atomic_expiring_value: &AtomicExpiringValue = limits_by_namespace
.get(counter.limit().namespace())
.and_then(|limits| limits.get(counter.limit()))
.and_then(|counters| counters.get(&counter.into()))
.unwrap();

if let Some(limited) = process_counter(counter, value.unwrap(), delta) {
if let Some(limited) = process_counter(counter, atomic_expiring_value.value(), delta) {
if !load_counters {
return Ok(limited);
}
}

counter_values_to_update.push((atomic_expiring_value, counter.seconds()));
}

if let Some(limited) = first_limited {
return Ok(limited);
}

for counter in counters.iter_mut() {
let now = SystemTime::now();
match limits_by_namespace
.entry(counter.limit().namespace().clone())
.or_insert_with(HashMap::new)
.entry(counter.limit().clone())
.or_insert_with(HashMap::new)
.entry(counter.into())
{
Entry::Vacant(v) => {
v.insert(ExpiringValue::new(
delta,
now + Duration::from_secs(counter.seconds()),
));
}
Entry::Occupied(mut o) => {
o.get_mut().update_mut(delta, counter.seconds(), now);
}
}
}
// Update counters
counter_values_to_update.iter().for_each(|(v, ttl)| {
v.update(delta, *ttl, now);
});

Ok(Authorization::Ok)
}
Expand Down Expand Up @@ -224,8 +212,11 @@ impl InMemoryStorage {
}
}

fn counters_in_namespace(&self, namespace: &Namespace) -> HashMap<Counter, ExpiringValue> {
let mut res: HashMap<Counter, ExpiringValue> = HashMap::new();
fn counters_in_namespace(
&self,
namespace: &Namespace,
) -> HashMap<Counter, AtomicExpiringValue> {
let mut res: HashMap<Counter, AtomicExpiringValue> = HashMap::new();

if let Some(counters_by_limit) = self.limits_for_namespace.read().unwrap().get(namespace) {
for (limit, values) in counters_by_limit {
Expand All @@ -251,20 +242,20 @@ impl InMemoryStorage {

fn insert_or_update_counter(
&self,
counters: &mut HashMap<CounterKey, ExpiringValue>,
counters: &mut HashMap<CounterKey, AtomicExpiringValue>,
counter: &Counter,
delta: i64,
) {
let now = SystemTime::now();
match counters.entry(counter.into()) {
Entry::Vacant(v) => {
v.insert(ExpiringValue::new(
v.insert(AtomicExpiringValue::new(
delta,
now + Duration::from_secs(counter.seconds()),
));
}
Entry::Occupied(mut o) => {
o.get_mut().update_mut(delta, counter.seconds(), now);
Entry::Occupied(o) => {
o.get().update(delta, counter.seconds(), now);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions limitador/src/storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub mod redis;
#[cfg(feature = "infinispan_storage")]
pub mod infinispan;

mod atomic_expiring_value;
mod expiring_value;
mod keys;

Expand Down

0 comments on commit 5377654

Please sign in to comment.