Skip to content

Commit

Permalink
Support providing an optional id to limits/counters
Browse files Browse the repository at this point in the history
* add a key_for_counters_v2 function that uses the id as the key if set, otherwise uses the previous key encoding strategy.
* updated the distributed store to use key_for_counters_v2.  Since we can’t decode a partial counter from id based keys, we now also keep in memory the Counter in a counter field of the limits map.

Signed-off-by: Hiram Chirino <[email protected]>
  • Loading branch information
chirino committed Jul 2, 2024
1 parent 72c2b12 commit 70ebd0e
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 136 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- uses: abelfodil/protoc-action@v1
with:
protoc-version: '3.19.4'
- run: cargo check
- run: cargo check --all-features

test:
name: Test Suite
Expand Down
6 changes: 6 additions & 0 deletions limitador-server/src/http_api/request_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub struct CheckAndReportInfo {

#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Apiv2Schema)]
pub struct Limit {
id: Option<String>,
namespace: String,
max_value: u64,
seconds: u64,
Expand All @@ -29,6 +30,7 @@ pub struct Limit {
impl From<&LimitadorLimit> for Limit {
fn from(ll: &LimitadorLimit) -> Self {
Self {
id: ll.id().clone(),
namespace: ll.namespace().as_ref().to_string(),
max_value: ll.max_value(),
seconds: ll.seconds(),
Expand All @@ -49,6 +51,10 @@ impl From<Limit> for LimitadorLimit {
limit.variables,
);

if let Some(id) = limit.id {
limitador_limit.set_id(id);
}

if let Some(name) = limit.name {
limitador_limit.set_name(name)
}
Expand Down
2 changes: 1 addition & 1 deletion limitador/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ lenient_conditions = []
moka = { version = "0.12", features = ["sync"] }
dashmap = "5.5.3"
getrandom = { version = "0.2", features = ["js"] }
serde = { version = "1", features = ["derive"] }
serde = { version = "1", features = ["derive", "rc"] }
postcard = { version = "1.0.4", features = ["use-std"] }
serde_json = "1"
rmp-serde = "1.1.0"
Expand Down
4 changes: 4 additions & 0 deletions limitador/src/counter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ impl Counter {
Duration::from_secs(self.limit.seconds())
}

pub fn id(&self) -> &Option<String> {
self.limit.id()
}

pub fn namespace(&self) -> &Namespace {
self.limit.namespace()
}
Expand Down
25 changes: 25 additions & 0 deletions limitador/src/limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ impl From<String> for Namespace {

#[derive(Eq, Debug, Clone, Serialize, Deserialize)]
pub struct Limit {
#[serde(skip_serializing, default)]
id: Option<String>,
namespace: Namespace,
#[serde(skip_serializing, default)]
max_value: u64,
Expand Down Expand Up @@ -319,6 +321,7 @@ impl Limit {
{
// the above where-clause is needed in order to call unwrap().
Self {
id: None,
namespace: namespace.into(),
max_value,
seconds,
Expand All @@ -335,6 +338,14 @@ impl Limit {
&self.namespace
}

pub fn set_id(&mut self, value: String) {
self.id = Some(value);
}

pub fn id(&self) -> &Option<String> {
&self.id
}

pub fn max_value(&self) -> u64 {
self.max_value
}
Expand Down Expand Up @@ -998,4 +1009,18 @@ mod tests {
let result = serde_json::to_string(&condition).expect("Should serialize");
assert_eq!(result, r#""foobar == \"ok\"""#.to_string());
}

#[test]
fn limit_id() {
let mut limit = Limit::new(
"test_namespace",
10,
60,
vec!["req.method == 'GET'"],
vec!["app_id"],
);
limit.set_id("test_id".to_string());

assert_eq!(limit.id().clone(), Some("test_id".to_string()))
}
}
27 changes: 17 additions & 10 deletions limitador/src/storage/distributed/grpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use std::{error::Error, io::ErrorKind, pin::Pin};

use crate::counter::Counter;
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::{Permit, Sender};
use tokio::sync::{broadcast, mpsc, Notify, RwLock};
Expand Down Expand Up @@ -156,9 +157,10 @@ impl Session {
update = udpates_to_send.recv() => {
let update = update.map_err(|_| Status::unknown("broadcast error"))?;
// Multiple updates collapse into a single update for the same key
if !tx_updates_by_key.contains_key(&update.key) {
tx_updates_by_key.insert(update.key.clone(), update.value);
tx_updates_order.push(update.key);
let key = &update.key.clone();
if !tx_updates_by_key.contains_key(key) {
tx_updates_by_key.insert(key.clone(), update);
tx_updates_order.push(key.clone());
notifier.notify_one();
}
}
Expand All @@ -174,7 +176,7 @@ impl Session {

let key = tx_updates_order.remove(0);
let cr_counter_value = tx_updates_by_key.remove(&key).unwrap().clone();
let (expiry, values) = (*cr_counter_value).clone().into_inner();
let (expiry, values) = cr_counter_value.value.clone().into_inner();

// only send the update if it has not expired.
if expiry > SystemTime::now() {
Expand Down Expand Up @@ -437,19 +439,24 @@ type CounterUpdateFn = Pin<Box<dyn Fn(CounterUpdate) + Sync + Send>>;
#[derive(Clone, Debug)]
pub struct CounterEntry {
pub key: Vec<u8>,
pub value: Arc<CrCounterValue<String>>,
pub counter: Counter,
pub value: CrCounterValue<String>,
}

impl CounterEntry {
pub fn new(key: Vec<u8>, value: Arc<CrCounterValue<String>>) -> Self {
Self { key, value }
pub fn new(key: Vec<u8>, counter: Counter, value: CrCounterValue<String>) -> Self {
Self {
key,
counter,
value,
}
}
}

#[derive(Clone)]
struct BrokerState {
id: String,
publisher: broadcast::Sender<CounterEntry>,
publisher: broadcast::Sender<Arc<CounterEntry>>,
on_counter_update: Arc<CounterUpdateFn>,
on_re_sync: Arc<Sender<Sender<Option<CounterUpdate>>>>,
}
Expand All @@ -471,7 +478,7 @@ impl Broker {
on_re_sync: Sender<Sender<Option<CounterUpdate>>>,
) -> Broker {
let (tx, _) = broadcast::channel(16);
let publisher: broadcast::Sender<CounterEntry> = tx;
let publisher: broadcast::Sender<Arc<CounterEntry>> = tx;

Broker {
listen_address,
Expand All @@ -489,7 +496,7 @@ impl Broker {
}
}

pub fn publish(&self, counter_update: CounterEntry) {
pub fn publish(&self, counter_update: Arc<CounterEntry>) {
// ignore the send error, it just means there are no active subscribers
_ = self.broker_state.publisher.send(counter_update);
}
Expand Down
Loading

0 comments on commit 70ebd0e

Please sign in to comment.