diff --git a/.dockerignore b/.dockerignore index 8b2bbc53..3ed62899 100644 --- a/.dockerignore +++ b/.dockerignore @@ -3,3 +3,6 @@ Dockerfile *.swp .dockerignore .git + +# OSX files +.DS_Store \ No newline at end of file diff --git a/.gitignore b/.gitignore index 7ba900ac..57e5d6c5 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ # These are backup files generated by rustfmt **/*.rs.bk + +# OSX files +.DS_Store \ No newline at end of file diff --git a/doc/server/configuration.md b/doc/server/configuration.md index d4f5025e..4e0b01ce 100644 --- a/doc/server/configuration.md +++ b/doc/server/configuration.md @@ -12,15 +12,36 @@ ARGS: The limit file to use OPTIONS: - -b, --rls-ip The IP to listen on for RLS [default: 0.0.0.0] - -p, --rls-port The port to listen on for RLS [default: 8081] - -B, --http-ip The IP to listen on for HTTP [default: 0.0.0.0] - -P, --http-port The port to listen on for HTTP [default: 8080] - -l, --limit-name-in-labels Include the Limit Name in prometheus label - -v Sets the level of verbosity - --validate Validates the LIMITS_FILE and exits - -h, --help Print help information - -V, --version Print version information + -b, --rls-ip + The IP to listen on for RLS [default: 0.0.0.0] + + -p, --rls-port + The port to listen on for RLS [default: 8081] + + -B, --http-ip + The IP to listen on for HTTP [default: 0.0.0.0] + + -P, --http-port + The port to listen on for HTTP [default: 8080] + + -l, --limit-name-in-labels + Include the Limit Name in prometheus label + + -v + Sets the level of verbosity + + --validate + Validates the LIMITS_FILE and exits + + -H, --rate-limit-headers + Enables rate limit response headers [default: NONE] [possible values: NONE, + DRAFT_VERSION_03] + + -h, --help + Print help information + + -V, --version + Print version information STORAGES: memory Counters are held in Limitador (ephemeral) @@ -319,3 +340,13 @@ require Redis. - Optional. By default, Limitador stores the limits in memory and does not require Infinispan. - Format: `URL`, in the format of `http://username:password@127.0.0.1:11222`. + + +#### `RATE_LIMIT_HEADERS` + +- Enables rate limit response headers. Only supported by the RLS server. +- Optional. Defaults to `"NONE"`. +- Must be one of: + - `"NONE"` - Does not add any additional headers to the http response. + - `"DRAFT_VERSION_03"`. Adds response headers per https://datatracker.ietf.org/doc/id/draft-polli-ratelimit-headers-03.html + diff --git a/limitador-server/README.md b/limitador-server/README.md index d7ffc754..62102b18 100644 --- a/limitador-server/README.md +++ b/limitador-server/README.md @@ -22,15 +22,36 @@ ARGS: The limit file to use OPTIONS: - -b, --rls-ip The IP to listen on for RLS [default: 0.0.0.0] - -p, --rls-port The port to listen on for RLS [default: 8081] - -B, --http-ip The IP to listen on for HTTP [default: 0.0.0.0] - -P, --http-port The port to listen on for HTTP [default: 8080] - -l, --limit-name-in-labels Include the Limit Name in prometheus label - -v Sets the level of verbosity - --validate Validates the LIMITS_FILE and exits - -h, --help Print help information - -V, --version Print version information + -b, --rls-ip + The IP to listen on for RLS [default: 0.0.0.0] + + -p, --rls-port + The port to listen on for RLS [default: 8081] + + -B, --http-ip + The IP to listen on for HTTP [default: 0.0.0.0] + + -P, --http-port + The port to listen on for HTTP [default: 8080] + + -l, --limit-name-in-labels + Include the Limit Name in prometheus label + + -v + Sets the level of verbosity + + --validate + Validates the LIMITS_FILE and exits + + -H, --rate-limit-headers + Enables rate limit response headers [default: NONE] [possible values: NONE, + DRAFT_VERSION_03] + + -h, --help + Print help information + + -V, --version + Print version information STORAGES: memory Counters are held in Limitador (ephemeral) diff --git a/limitador-server/src/config.rs b/limitador-server/src/config.rs index e1355f72..db63fb4b 100644 --- a/limitador-server/src/config.rs +++ b/limitador-server/src/config.rs @@ -18,6 +18,7 @@ // HTTP_API_HOST: host // just to become HTTP_API_HOST:HTTP_API_PORT as &str // HTTP_API_PORT: port +use crate::envoy_rls::server::RateLimitHeaders; use log::LevelFilter; #[derive(Debug)] @@ -30,6 +31,7 @@ pub struct Configuration { http_port: u16, pub limit_name_in_labels: bool, pub log_level: Option, + pub rate_limit_headers: RateLimitHeaders, } impl Configuration { @@ -37,6 +39,7 @@ impl Configuration { pub const DEFAULT_HTTP_PORT: &'static str = "8080"; pub const DEFAULT_IP_BIND: &'static str = "0.0.0.0"; + #[allow(clippy::too_many_arguments)] pub fn with( storage: StorageConfiguration, limits_file: String, @@ -45,6 +48,7 @@ impl Configuration { http_host: String, http_port: u16, limit_name_in_labels: bool, + rate_limit_headers: RateLimitHeaders, ) -> Self { Self { limits_file, @@ -55,6 +59,7 @@ impl Configuration { http_port, limit_name_in_labels, log_level: None, + rate_limit_headers, } } @@ -79,6 +84,7 @@ impl Default for Configuration { http_port: 0, limit_name_in_labels: false, log_level: None, + rate_limit_headers: RateLimitHeaders::None, } } } diff --git a/limitador-server/src/envoy_rls/server.rs b/limitador-server/src/envoy_rls/server.rs index 30c6c638..3cf93275 100644 --- a/limitador-server/src/envoy_rls/server.rs +++ b/limitador-server/src/envoy_rls/server.rs @@ -1,3 +1,12 @@ +use std::cmp::Ordering; +use std::collections::HashMap; +use std::sync::Arc; + +use tonic::{transport, transport::Server, Request, Response, Status}; + +use limitador::counter::Counter; + +use crate::envoy_rls::server::envoy::config::core::v3::HeaderValue; use crate::envoy_rls::server::envoy::service::ratelimit::v3::rate_limit_response::Code; use crate::envoy_rls::server::envoy::service::ratelimit::v3::rate_limit_service_server::{ RateLimitService, RateLimitServiceServer, @@ -6,19 +15,26 @@ use crate::envoy_rls::server::envoy::service::ratelimit::v3::{ RateLimitRequest, RateLimitResponse, }; use crate::Limiter; -use std::collections::HashMap; -use std::sync::Arc; -use tonic::{transport, transport::Server, Request, Response, Status}; include!("envoy_types.rs"); +#[derive(PartialEq, Eq, Debug, Clone)] +pub enum RateLimitHeaders { + None, + DraftVersion03, +} + pub struct MyRateLimiter { limiter: Arc, + rate_limit_headers: RateLimitHeaders, } impl MyRateLimiter { - pub fn new(limiter: Arc) -> Self { - Self { limiter } + pub fn new(limiter: Arc, rate_limit_headers: RateLimitHeaders) -> Self { + Self { + limiter, + rate_limit_headers, + } } } @@ -62,45 +78,54 @@ impl RateLimitService for MyRateLimiter { req.hits_addend }; - let is_rate_limited_res = match &*self.limiter { - Limiter::Blocking(limiter) => { - limiter.check_rate_limited_and_update(&namespace, &values, i64::from(hits_addend)) - } + let rate_limited_resp = match &*self.limiter { + Limiter::Blocking(limiter) => limiter.check_rate_limited_and_update( + &namespace, + &values, + i64::from(hits_addend), + self.rate_limit_headers != RateLimitHeaders::None, + ), Limiter::Async(limiter) => { limiter - .check_rate_limited_and_update(&namespace, &values, i64::from(hits_addend)) + .check_rate_limited_and_update( + &namespace, + &values, + i64::from(hits_addend), + self.rate_limit_headers != RateLimitHeaders::None, + ) .await } }; - let resp_code = match is_rate_limited_res { - Ok(rate_limited) => { - if rate_limited { - Code::OverLimit - } else { - Code::Ok - } - } - Err(e) => { - // In this case we could return "Code::Unknown" but that's not - // very helpful. When envoy receives "Unknown" it simply lets - // the request pass and this cannot be configured using the - // "failure_mode_deny" attribute, so it's equivalent to - // returning "Code::Ok". That's why we return an "unavailable" - // error here. What envoy does after receiving that kind of - // error can be configured with "failure_mode_deny". The only - // errors that can happen here have to do with connecting to the - // limits storage, which should be temporary. - error!("Error: {:?}", e); - return Err(Status::unavailable("Service unavailable")); - } + if let Err(e) = rate_limited_resp { + // In this case we could return "Code::Unknown" but that's not + // very helpful. When envoy receives "Unknown" it simply lets + // the request pass and this cannot be configured using the + // "failure_mode_deny" attribute, so it's equivalent to + // returning "Code::Ok". That's why we return an "unavailable" + // error here. What envoy does after receiving that kind of + // error can be configured with "failure_mode_deny". The only + // errors that can happen here have to do with connecting to the + // limits storage, which should be temporary. + error!("Error: {:?}", e); + return Err(Status::unavailable("Service unavailable")); + } + + let mut rate_limited_resp = rate_limited_resp.unwrap(); + let resp_code = if rate_limited_resp.limited { + Code::OverLimit + } else { + Code::Ok }; let reply = RateLimitResponse { overall_code: resp_code.into(), statuses: vec![], request_headers_to_add: vec![], - response_headers_to_add: vec![], + response_headers_to_add: to_response_header( + &self.rate_limit_headers, + &mut rate_limited_resp.counters, + ), raw_body: vec![], dynamic_metadata: None, quota: None, @@ -110,11 +135,71 @@ impl RateLimitService for MyRateLimiter { } } +pub fn to_response_header( + rate_limit_headers: &RateLimitHeaders, + counters: &mut Vec, +) -> Vec { + let mut headers = Vec::new(); + match rate_limit_headers { + RateLimitHeaders::None => {} + + // creates response headers per https://datatracker.ietf.org/doc/id/draft-polli-ratelimit-headers-03.html + RateLimitHeaders::DraftVersion03 => { + // sort by the limit remaining.. + counters.sort_by(|a, b| { + let a_remaining = a.remaining().unwrap_or(a.max_value()); + let b_remaining = b.remaining().unwrap_or(b.max_value()); + if a_remaining - b_remaining < 0 { + Ordering::Less + } else { + Ordering::Greater + } + }); + + let mut all_limits_text = String::with_capacity(20 * counters.len()); + counters.iter_mut().for_each(|counter| { + all_limits_text.push_str( + format!(", {};w={}", counter.max_value(), counter.seconds()).as_str(), + ); + if let Some(name) = counter.limit().name() { + all_limits_text + .push_str(format!(";name=\"{}\"", name.replace('"', "'")).as_str()); + } + }); + + if let Some(counter) = counters.first() { + headers.push(HeaderValue { + key: "X-RateLimit-Limit".to_string(), + value: format!("{}{}", counter.max_value(), all_limits_text), + }); + + let mut remaining = counter.remaining().unwrap_or(counter.max_value()); + if remaining < 0 { + remaining = 0 + } + headers.push(HeaderValue { + key: "X-RateLimit-Remaining".to_string(), + value: format!("{}", remaining), + }); + + if let Some(duration) = counter.expires_in() { + headers.push(HeaderValue { + key: "X-RateLimit-Reset".to_string(), + value: format!("{}", duration.as_secs()), + }); + } + } + } + }; + headers +} + pub async fn run_envoy_rls_server( address: String, limiter: Arc, + rate_limit_headers: RateLimitHeaders, ) -> Result<(), transport::Error> { - let rate_limiter = MyRateLimiter::new(limiter); + let rate_limiter = MyRateLimiter::new(limiter, rate_limit_headers); let svc = RateLimitServiceServer::new(rate_limiter); Server::builder() @@ -125,13 +210,23 @@ pub async fn run_envoy_rls_server( #[cfg(test)] mod tests { - use super::*; + use tonic::IntoRequest; + + use limitador::limit::Limit; + use limitador::RateLimiter; + use crate::envoy_rls::server::envoy::extensions::common::ratelimit::v3::rate_limit_descriptor::Entry; use crate::envoy_rls::server::envoy::extensions::common::ratelimit::v3::RateLimitDescriptor; use crate::Configuration; - use limitador::limit::Limit; - use limitador::RateLimiter; - use tonic::IntoRequest; + + use super::*; + + fn header_value(key: &str, value: &str) -> HeaderValue { + HeaderValue { + key: key.to_string(), + value: value.to_string(), + } + } // All these tests use the in-memory storage implementation to simplify. We // know that some storage implementations like the Redis one trade @@ -154,7 +249,10 @@ mod tests { let limiter = RateLimiter::default(); limiter.add_limit(limit); - let rate_limiter = MyRateLimiter::new(Arc::new(Limiter::Blocking(limiter))); + let rate_limiter = MyRateLimiter::new( + Arc::new(Limiter::Blocking(limiter)), + RateLimitHeaders::DraftVersion03, + ); let req = RateLimitRequest { domain: namespace.to_string(), @@ -177,33 +275,42 @@ mod tests { // There's a limit of 1, so the first request should return "OK" and the // second "OverLimit". + let response = rate_limiter + .should_rate_limit(req.clone().into_request()) + .await + .unwrap() + .into_inner(); + assert_eq!(response.overall_code, i32::from(Code::Ok)); assert_eq!( - rate_limiter - .should_rate_limit(req.clone().into_request()) - .await - .unwrap() - .into_inner() - .overall_code, - i32::from(Code::Ok) + response.response_headers_to_add, + vec![ + header_value("X-RateLimit-Limit", "1, 1;w=60"), + header_value("X-RateLimit-Remaining", "0"), + ], ); + let response = rate_limiter + .should_rate_limit(req.clone().into_request()) + .await + .unwrap() + .into_inner(); + assert_eq!(response.overall_code, i32::from(Code::OverLimit)); assert_eq!( - rate_limiter - .should_rate_limit(req.clone().into_request()) - .await - .unwrap() - .into_inner() - .overall_code, - i32::from(Code::OverLimit) + response.response_headers_to_add, + vec![ + header_value("X-RateLimit-Limit", "1, 1;w=60"), + header_value("X-RateLimit-Remaining", "0"), + ], ); } #[tokio::test] async fn test_returns_ok_when_no_limits_apply() { // No limits saved - let rate_limiter = MyRateLimiter::new(Arc::new( - Limiter::new(Configuration::default()).await.unwrap(), - )); + let rate_limiter = MyRateLimiter::new( + Arc::new(Limiter::new(Configuration::default()).await.unwrap()), + RateLimitHeaders::DraftVersion03, + ); let req = RateLimitRequest { domain: "test_namespace".to_string(), @@ -218,22 +325,22 @@ mod tests { } .into_request(); - assert_eq!( - rate_limiter - .should_rate_limit(req) - .await - .unwrap() - .into_inner() - .overall_code, - i32::from(Code::Ok) - ); + let response = rate_limiter + .should_rate_limit(req) + .await + .unwrap() + .into_inner(); + + assert_eq!(response.overall_code, i32::from(Code::Ok)); + assert_eq!(response.response_headers_to_add, vec![],); } #[tokio::test] async fn test_returns_unknown_when_domain_is_empty() { - let rate_limiter = MyRateLimiter::new(Arc::new( - Limiter::new(Configuration::default()).await.unwrap(), - )); + let rate_limiter = MyRateLimiter::new( + Arc::new(Limiter::new(Configuration::default()).await.unwrap()), + RateLimitHeaders::DraftVersion03, + ); let req = RateLimitRequest { domain: "".to_string(), @@ -248,15 +355,13 @@ mod tests { } .into_request(); - assert_eq!( - rate_limiter - .should_rate_limit(req) - .await - .unwrap() - .into_inner() - .overall_code, - i32::from(Code::Unknown) - ); + let response = rate_limiter + .should_rate_limit(req) + .await + .unwrap() + .into_inner(); + assert_eq!(response.overall_code, i32::from(Code::Unknown)); + assert_eq!(response.response_headers_to_add, vec![],); } #[tokio::test] @@ -274,7 +379,10 @@ mod tests { limiter.add_limit(limit); }); - let rate_limiter = MyRateLimiter::new(Arc::new(Limiter::Blocking(limiter))); + let rate_limiter = MyRateLimiter::new( + Arc::new(Limiter::Blocking(limiter)), + RateLimitHeaders::DraftVersion03, + ); let req = RateLimitRequest { domain: namespace.to_string(), @@ -305,14 +413,19 @@ mod tests { hits_addend: 1, }; + let response = rate_limiter + .should_rate_limit(req.clone().into_request()) + .await + .unwrap() + .into_inner(); + + assert_eq!(response.overall_code, i32::from(Code::OverLimit)); assert_eq!( - rate_limiter - .should_rate_limit(req.clone().into_request()) - .await - .unwrap() - .into_inner() - .overall_code, - i32::from(Code::OverLimit) + response.response_headers_to_add, + vec![ + header_value("X-RateLimit-Limit", "0, 0;w=60, 10;w=60"), + header_value("X-RateLimit-Remaining", "0"), + ], ); } @@ -324,7 +437,10 @@ mod tests { let limiter = RateLimiter::default(); limiter.add_limit(limit); - let rate_limiter = MyRateLimiter::new(Arc::new(Limiter::Blocking(limiter))); + let rate_limiter = MyRateLimiter::new( + Arc::new(Limiter::Blocking(limiter)), + RateLimitHeaders::DraftVersion03, + ); let req = RateLimitRequest { domain: namespace.to_string(), @@ -347,24 +463,32 @@ mod tests { // There's a limit of 10, "hits_addend" is 6, so the first request // should return "Ok" and the second "OverLimit". + let response = rate_limiter + .should_rate_limit(req.clone().into_request()) + .await + .unwrap() + .into_inner(); + assert_eq!(response.overall_code, i32::from(Code::Ok)); assert_eq!( - rate_limiter - .should_rate_limit(req.clone().into_request()) - .await - .unwrap() - .into_inner() - .overall_code, - i32::from(Code::Ok) + response.response_headers_to_add, + vec![ + header_value("X-RateLimit-Limit", "10, 10;w=60"), + header_value("X-RateLimit-Remaining", "4"), + ], ); + let response = rate_limiter + .should_rate_limit(req.clone().into_request()) + .await + .unwrap() + .into_inner(); + assert_eq!(response.overall_code, i32::from(Code::OverLimit)); assert_eq!( - rate_limiter - .should_rate_limit(req.clone().into_request()) - .await - .unwrap() - .into_inner() - .overall_code, - i32::from(Code::OverLimit) + response.response_headers_to_add, + vec![ + header_value("X-RateLimit-Limit", "10, 10;w=60"), + header_value("X-RateLimit-Remaining", "0"), + ], ); } @@ -378,7 +502,10 @@ mod tests { let limiter = RateLimiter::default(); limiter.add_limit(limit); - let rate_limiter = MyRateLimiter::new(Arc::new(Limiter::Blocking(limiter))); + let rate_limiter = MyRateLimiter::new( + Arc::new(Limiter::Blocking(limiter)), + RateLimitHeaders::DraftVersion03, + ); let req = RateLimitRequest { domain: namespace.to_string(), @@ -401,24 +528,32 @@ mod tests { // There's a limit of 1, and hits_addend is converted to 1, so the first // request should return "OK" and the second "OverLimit". + let response = rate_limiter + .should_rate_limit(req.clone().into_request()) + .await + .unwrap() + .into_inner(); + assert_eq!(response.overall_code, i32::from(Code::Ok)); assert_eq!( - rate_limiter - .should_rate_limit(req.clone().into_request()) - .await - .unwrap() - .into_inner() - .overall_code, - i32::from(Code::Ok) + response.response_headers_to_add, + vec![ + header_value("X-RateLimit-Limit", "1, 1;w=60"), + header_value("X-RateLimit-Remaining", "0"), + ], ); + let response = rate_limiter + .should_rate_limit(req.clone().into_request()) + .await + .unwrap() + .into_inner(); + assert_eq!(response.overall_code, i32::from(Code::OverLimit)); assert_eq!( - rate_limiter - .should_rate_limit(req.clone().into_request()) - .await - .unwrap() - .into_inner() - .overall_code, - i32::from(Code::OverLimit) + response.response_headers_to_add, + vec![ + header_value("X-RateLimit-Limit", "1, 1;w=60"), + header_value("X-RateLimit-Remaining", "0"), + ], ); } } diff --git a/limitador-server/src/http_api/server.rs b/limitador-server/src/http_api/server.rs index eadcf755..aad0c4e8 100644 --- a/limitador-server/src/http_api/server.rs +++ b/limitador-server/src/http_api/server.rs @@ -152,18 +152,18 @@ async fn check_and_report( let namespace = namespace.into(); let rate_limited_and_update_result = match data.get_ref().as_ref() { Limiter::Blocking(limiter) => { - limiter.check_rate_limited_and_update(&namespace, &values, delta) + limiter.check_rate_limited_and_update(&namespace, &values, delta, false) } Limiter::Async(limiter) => { limiter - .check_rate_limited_and_update(&namespace, &values, delta) + .check_rate_limited_and_update(&namespace, &values, delta, false) .await } }; match rate_limited_and_update_result { Ok(is_rate_limited) => { - if is_rate_limited { + if is_rate_limited.limited { Err(ErrorResponse::TooManyRequests) } else { Ok(Json(())) diff --git a/limitador-server/src/main.rs b/limitador-server/src/main.rs index d0c1cdf7..14235bae 100644 --- a/limitador-server/src/main.rs +++ b/limitador-server/src/main.rs @@ -9,7 +9,7 @@ use crate::config::InfinispanStorageConfiguration; use crate::config::{ Configuration, RedisStorageCacheConfiguration, RedisStorageConfiguration, StorageConfiguration, }; -use crate::envoy_rls::server::run_envoy_rls_server; +use crate::envoy_rls::server::{run_envoy_rls_server, RateLimitHeaders}; use crate::http_api::server::run_http_server; use clap::{App, Arg, SubCommand}; use env_logger::Builder; @@ -276,6 +276,7 @@ async fn main() -> Result<(), Box> { let limit_file = config.limits_file.clone(); let envoy_rls_address = config.rlp_address(); let http_api_address = config.http_address(); + let rate_limit_headers = config.rate_limit_headers.clone(); let rate_limiter: Arc = match Limiter::new(config).await { Ok(limiter) => Arc::new(limiter), @@ -363,6 +364,7 @@ async fn main() -> Result<(), Box> { tokio::spawn(run_envoy_rls_server( envoy_rls_address.to_string(), rate_limiter.clone(), + rate_limit_headers, )); info!("HTTP server starting on {}", http_api_address); @@ -417,6 +419,8 @@ fn create_config() -> (Configuration, String) { let infinispan_consistency_default = env::var("INFINISPAN_COUNTERS_CONSISTENCY") .unwrap_or_else(|_| DEFAULT_INFINISPAN_CONSISTENCY.to_string()); + let rate_limit_headers_default = env::var("RATE_LIMIT_HEADERS").unwrap_or("NONE".to_string()); + // wire args based of defaults let limit_arg = Arg::with_name("LIMITS_FILE") .help("The limit file to use") @@ -498,6 +502,18 @@ fn create_config() -> (Configuration, String) { .display_order(7) .help("Validates the LIMITS_FILE and exits"), ) + .arg( + Arg::with_name("rate_limit_headers") + .long("rate-limit-headers") + .short('H') + .display_order(8) + .default_value(&rate_limit_headers_default) + .value_parser(clap::builder::PossibleValuesParser::new([ + "NONE", + "DRAFT_VERSION_03", + ])) + .help("Enables rate limit response headers"), + ) .subcommand( SubCommand::with_name("memory") .display_order(1) @@ -660,6 +676,12 @@ fn create_config() -> (Configuration, String) { _ => unreachable!("Some storage wasn't configured!"), }; + let rate_limit_headers = match matches.value_of("rate_limit_headers").unwrap() { + "NONE" => RateLimitHeaders::None, + "DRAFT_VERSION_03" => RateLimitHeaders::DraftVersion03, + _ => unreachable!("invalid --rate-limit-headers value"), + }; + let mut config = Configuration::with( storage, limits_file.to_string(), @@ -669,6 +691,7 @@ fn create_config() -> (Configuration, String) { matches.value_of("http_port").unwrap().parse().unwrap(), matches.value_of("limit_name_in_labels").is_some() || env_option_is_enabled("LIMIT_NAME_IN_PROMETHEUS_LABELS"), + rate_limit_headers, ); config.log_level = match matches.occurrences_of("v") { diff --git a/limitador/benches/bench.rs b/limitador/benches/bench.rs index 57105eac..436768e5 100644 --- a/limitador/benches/bench.rs +++ b/limitador/benches/bench.rs @@ -185,6 +185,7 @@ fn bench_check_rate_limited_and_update( ¶ms.namespace.to_owned().into(), ¶ms.values, params.delta, + false, ) .unwrap(), ) diff --git a/limitador/src/lib.rs b/limitador/src/lib.rs index 10751a5e..641b615c 100644 --- a/limitador/src/lib.rs +++ b/limitador/src/lib.rs @@ -129,7 +129,7 @@ //! // You can also check and report if not limited in a single call. It's useful //! // for example, when calling Limitador from a proxy. Instead of doing 2 //! // separate calls, we can issue just one: -//! rate_limiter.check_rate_limited_and_update(&namespace, &values_to_report, 1).unwrap(); +//! rate_limiter.check_rate_limited_and_update(&namespace, &values_to_report, 1, false).unwrap(); //! ``` //! //! # Async @@ -224,6 +224,17 @@ pub struct RateLimiterBuilder { prometheus_limit_name_labels_enabled: bool, } +pub struct CheckResult { + pub limited: bool, + pub counters: Vec, +} + +impl From for bool { + fn from(value: CheckResult) -> Self { + value.limited + } +} + impl RateLimiterBuilder { pub fn new() -> Self { Self { @@ -375,27 +386,43 @@ impl RateLimiter { namespace: &Namespace, values: &HashMap, delta: i64, - ) -> Result { - let counters = self.counters_that_apply(namespace, values)?; + load_counters: bool, + ) -> Result { + let mut counters = self.counters_that_apply(namespace, values)?; if counters.is_empty() { self.prometheus_metrics.incr_authorized_calls(namespace); - return Ok(false); + return Ok(CheckResult { + limited: false, + counters, + }); } let check_result = self .storage - .check_and_update(counters.into_iter().collect(), delta)?; + .check_and_update(&mut counters, delta, load_counters)?; + + let counters = if load_counters { + counters + } else { + Vec::default() + }; match check_result { Authorization::Ok => { self.prometheus_metrics.incr_authorized_calls(namespace); - Ok(false) + Ok(CheckResult { + limited: false, + counters, + }) } Authorization::Limited(name) => { self.prometheus_metrics .incr_limited_calls(namespace, name.as_deref()); - Ok(true) + Ok(CheckResult { + limited: true, + counters, + }) } } } @@ -551,29 +578,46 @@ impl AsyncRateLimiter { namespace: &Namespace, values: &HashMap, delta: i64, - ) -> Result { + load_counters: bool, + ) -> Result { // the above where-clause is needed in order to call unwrap(). - let counters = self.counters_that_apply(namespace, values).await?; + let mut counters = self.counters_that_apply(namespace, values).await?; if counters.is_empty() { self.prometheus_metrics.incr_authorized_calls(namespace); - return Ok(false); + return Ok(CheckResult { + limited: false, + counters, + }); } let check_result = self .storage - .check_and_update(counters.into_iter().collect(), delta) + .check_and_update(&mut counters, delta, load_counters) .await?; + let counters = if load_counters { + counters + } else { + Vec::default() + }; + match check_result { Authorization::Ok => { self.prometheus_metrics.incr_authorized_calls(namespace); - Ok(false) + + Ok(CheckResult { + limited: false, + counters, + }) } Authorization::Limited(name) => { self.prometheus_metrics .incr_limited_calls(namespace, name.as_deref()); - Ok(true) + Ok(CheckResult { + limited: true, + counters, + }) } } } diff --git a/limitador/src/storage/in_memory.rs b/limitador/src/storage/in_memory.rs index de2a96f3..86176a06 100644 --- a/limitador/src/storage/in_memory.rs +++ b/limitador/src/storage/in_memory.rs @@ -31,22 +31,40 @@ impl CounterStorage for InMemoryStorage { fn check_and_update( &self, - counters: HashSet, + counters: &mut Vec, delta: i64, + load_counters: bool, ) -> Result { // This makes the operator of check + update atomic let mut stored_counters = self.counters.write().unwrap(); - for counter in counters.iter() { - if !Self::counter_is_within_limits(counter, stored_counters.get(counter), delta) { - return Ok(Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - )); + if load_counters { + let mut first_limited = None; + for counter in counters.iter_mut() { + let remaining = + *stored_counters.get(counter).unwrap_or(&counter.max_value()) - delta; + counter.set_remaining(remaining); + if first_limited.is_none() && remaining < 0 { + first_limited = Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )) + } + } + if let Some(l) = first_limited { + return Ok(l); + } + } else { + for counter in counters.iter() { + if !Self::counter_is_within_limits(counter, stored_counters.get(counter), delta) { + return Ok(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); + } } } for counter in counters { - self.insert_or_update_counter(&mut stored_counters, &counter, delta) + self.insert_or_update_counter(&mut stored_counters, counter, delta) } Ok(Authorization::Ok) diff --git a/limitador/src/storage/infinispan/infinispan_storage.rs b/limitador/src/storage/infinispan/infinispan_storage.rs index 33d30ebe..0626dee8 100644 --- a/limitador/src/storage/infinispan/infinispan_storage.rs +++ b/limitador/src/storage/infinispan/infinispan_storage.rs @@ -67,20 +67,42 @@ impl AsyncCounterStorage for InfinispanStorage { async fn check_and_update( &self, - counters: HashSet, + counters: &mut Vec, delta: i64, + load_counters: bool, ) -> Result { - for counter in counters.iter() { - if !self.is_within_limits(counter, delta).await? { - return Ok(Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - )); + if load_counters { + let mut first_limited = None; + for counter in counters.iter_mut() { + let counter_key = key_for_counter(counter); + let counter_val = + counters::get_value(&self.infinispan, &self.cache_name, &counter_key).await?; + + let remaining = counter_val.unwrap_or(counter.max_value()) - delta; + counter.set_remaining(remaining); + + if first_limited.is_none() && remaining < 0 { + first_limited = Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )) + } + } + if let Some(l) = first_limited { + return Ok(l); + } + } else { + for counter in counters.iter() { + if !self.is_within_limits(counter, delta).await? { + return Ok(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); + } } } // Update only if all are withing limits for counter in counters { - self.update_counter(&counter, delta).await? + self.update_counter(counter, delta).await? } Ok(Authorization::Ok) diff --git a/limitador/src/storage/mod.rs b/limitador/src/storage/mod.rs index cc3fd3c0..e9751e79 100644 --- a/limitador/src/storage/mod.rs +++ b/limitador/src/storage/mod.rs @@ -121,10 +121,12 @@ impl Storage { pub fn check_and_update( &self, - counters: HashSet, + counters: &mut Vec, delta: i64, + load_counters: bool, ) -> Result { - self.counters.check_and_update(counters, delta) + self.counters + .check_and_update(counters, delta, load_counters) } pub fn get_counters(&self, namespace: &Namespace) -> Result, StorageErr> { @@ -239,10 +241,13 @@ impl AsyncStorage { pub async fn check_and_update( &self, - counters: HashSet, + counters: &mut Vec, delta: i64, + load_counters: bool, ) -> Result { - self.counters.check_and_update(counters, delta).await + self.counters + .check_and_update(counters, delta, load_counters) + .await } pub async fn get_counters( @@ -264,8 +269,9 @@ pub trait CounterStorage: Sync + Send { fn update_counter(&self, counter: &Counter, delta: i64) -> Result<(), StorageErr>; fn check_and_update( &self, - counters: HashSet, + counters: &mut Vec, delta: i64, + load_counters: bool, ) -> Result; fn get_counters(&self, limits: &HashSet) -> Result, StorageErr>; fn delete_counters(&self, limits: HashSet) -> Result<(), StorageErr>; @@ -278,8 +284,9 @@ pub trait AsyncCounterStorage: Sync + Send { async fn update_counter(&self, counter: &Counter, delta: i64) -> Result<(), StorageErr>; async fn check_and_update( &self, - counters: HashSet, + counters: &mut Vec, delta: i64, + load_counters: bool, ) -> Result; async fn get_counters(&self, limits: HashSet) -> Result, StorageErr>; async fn delete_counters(&self, limits: HashSet) -> Result<(), StorageErr>; diff --git a/limitador/src/storage/redis/mod.rs b/limitador/src/storage/redis/mod.rs index b9d89e69..2f692712 100644 --- a/limitador/src/storage/redis/mod.rs +++ b/limitador/src/storage/redis/mod.rs @@ -1,4 +1,5 @@ use ::redis::RedisError; +use std::time::Duration; mod batcher; mod counters_cache; @@ -12,7 +13,8 @@ pub const DEFAULT_MAX_CACHED_COUNTERS: usize = 10000; pub const DEFAULT_MAX_TTL_CACHED_COUNTERS_SEC: u64 = 5; pub const DEFAULT_TTL_RATIO_CACHED_COUNTERS: u64 = 10; -use crate::storage::StorageErr; +use crate::counter::Counter; +use crate::storage::{Authorization, StorageErr}; pub use redis_async::AsyncRedisStorage; pub use redis_cached::CachedRedisStorage; pub use redis_cached::CachedRedisStorageBuilder; @@ -29,3 +31,35 @@ impl From<::r2d2::Error> for StorageErr { Self { msg: e.to_string() } } } + +pub fn is_limited( + counters: &mut [Counter], + delta: i64, + script_res: Vec>, +) -> Option { + let mut counter_vals: Vec> = vec![]; + let mut counter_ttls_msecs: Vec> = vec![]; + + for val_ttl_pair in script_res.chunks(2) { + counter_vals.push(val_ttl_pair[0]); + counter_ttls_msecs.push(val_ttl_pair[1]); + } + + let mut first_limited = None; + for (i, counter) in counters.iter_mut().enumerate() { + let remaining = counter_vals[i].unwrap_or(counter.max_value()) - delta; + counter.set_remaining(remaining); + let expires_in = Duration::from_secs( + counter_ttls_msecs[i] + .map(|x| x as u64) + .unwrap_or(counter.seconds()), + ); + counter.set_expires_in(expires_in); + if first_limited.is_none() && remaining < 0 { + first_limited = Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )) + } + } + first_limited +} diff --git a/limitador/src/storage/redis/redis_async.rs b/limitador/src/storage/redis/redis_async.rs index 029c0ff1..8163e4d4 100644 --- a/limitador/src/storage/redis/redis_async.rs +++ b/limitador/src/storage/redis/redis_async.rs @@ -5,7 +5,8 @@ use self::redis::ConnectionInfo; use crate::counter::Counter; use crate::limit::Limit; use crate::storage::keys::*; -use crate::storage::redis::scripts::SCRIPT_UPDATE_COUNTER; +use crate::storage::redis::is_limited; +use crate::storage::redis::scripts::{SCRIPT_UPDATE_COUNTER, VALUES_AND_TTLS}; use crate::storage::{AsyncCounterStorage, Authorization, StorageErr}; use async_trait::async_trait; use redis::{AsyncCommands, RedisError}; @@ -58,40 +59,44 @@ impl AsyncCounterStorage for AsyncRedisStorage { async fn check_and_update( &self, - counters: HashSet, + counters: &mut Vec, delta: i64, + load_counters: bool, ) -> Result { let mut con = self.conn_manager.clone(); - let counter_keys: Vec = counters.iter().map(key_for_counter).collect(); - let counter_vals: Vec> = redis::cmd("MGET") - .arg(counter_keys) - .query_async(&mut con) - .await?; + if load_counters { + let script = redis::Script::new(VALUES_AND_TTLS); + let mut script_invocation = script.prepare_invoke(); - for (i, counter) in counters.iter().enumerate() { - match counter_vals[i] { - Some(val) => { - if val - delta < 0 { - return Ok(Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - )); - } - } - None => { - if counter.max_value() - delta < 0 { - return Ok(Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - )); - } + for counter_key in counter_keys { + script_invocation.key(counter_key); + } + + let script_res: Vec> = script_invocation.invoke_async(&mut con).await?; + if let Some(res) = is_limited(counters, delta, script_res) { + return Ok(res); + } + } else { + let counter_vals: Vec> = redis::cmd("MGET") + .arg(counter_keys) + .query_async(&mut con) + .await?; + + for (i, counter) in counters.iter().enumerate() { + let remaining = counter_vals[i].unwrap_or(counter.max_value()) - delta; + if remaining < 0 { + return Ok(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); } } } // TODO: this can be optimized by using pipelines with multiple updates for counter in counters { - self.update_counter(&counter, delta).await? + self.update_counter(counter, delta).await? } Ok(Authorization::Ok) diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 361c7483..34e05f3c 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -65,23 +65,34 @@ impl AsyncCounterStorage for CachedRedisStorage { // This function trades accuracy for speed. async fn check_and_update( &self, - counters: HashSet, + counters: &mut Vec, delta: i64, + load_counters: bool, ) -> Result { let mut con = self.redis_conn_manager.clone(); - let mut not_cached: Vec<&Counter> = vec![]; + let mut not_cached: Vec<&mut Counter> = vec![]; + let mut first_limited = None; // Check cached counters { let cached_counters = self.cached_counters.lock().await; - for counter in counters.iter() { + for counter in counters.iter_mut() { match cached_counters.get(counter) { Some(val) => { - if val - delta < 0 { - return Ok(Authorization::Limited( + if first_limited.is_none() && val - delta < 0 { + let a = Authorization::Limited( counter.limit().name().map(|n| n.to_owned()), - )); + ); + if !load_counters { + return Ok(a); + } + first_limited = Some(a); + } + if load_counters { + counter.set_remaining(val); + // todo: how do we get the ttl for this entry? + // counter.set_expires_in(Duration::from_secs(counter.seconds())); } } None => { @@ -107,36 +118,31 @@ impl AsyncCounterStorage for CachedRedisStorage { { let mut cached_counters = self.cached_counters.lock().await; - for (i, &counter) in not_cached.iter().enumerate() { + for (i, counter) in not_cached.iter_mut().enumerate() { cached_counters.insert( counter.clone(), counter_vals[i], counter_ttls_msecs[i], ttl_margin, ); - } - } - - for (i, counter) in not_cached.into_iter().enumerate() { - match counter_vals[i] { - Some(val) => { - if val - delta < 0 { - return Ok(Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - )); - } + let remaining = counter_vals[i].unwrap_or(counter.max_value()) - delta; + if first_limited.is_none() && remaining < 0 { + first_limited = Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); } - None => { - if counter.max_value() - delta < 0 { - return Ok(Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - )); - } + if load_counters { + counter.set_remaining(remaining); + counter.set_expires_in(Duration::from_millis(counter_ttls_msecs[i] as u64)); } } } } + if let Some(l) = first_limited { + return Ok(l); + } + // Update cached values { let mut cached_counters = self.cached_counters.lock().await; @@ -148,12 +154,12 @@ impl AsyncCounterStorage for CachedRedisStorage { // Batch or update depending on configuration if self.batching_is_enabled { let batcher = self.batcher_counter_updates.lock().await; - for counter in counters { - batcher.add_counter(&counter, delta).await + for counter in counters.iter() { + batcher.add_counter(counter, delta).await } } else { - for counter in counters { - self.update_counter(&counter, delta).await? + for counter in counters.iter() { + self.update_counter(counter, delta).await? } } @@ -233,7 +239,7 @@ impl CachedRedisStorage { } async fn values_with_ttls( - counters: &[&Counter], + counters: &[&mut Counter], redis_con: &mut ConnectionManager, ) -> Result<(Vec>, Vec), StorageErr> { let counter_keys: Vec = counters diff --git a/limitador/src/storage/redis/redis_sync.rs b/limitador/src/storage/redis/redis_sync.rs index 9d62c608..ce0de530 100644 --- a/limitador/src/storage/redis/redis_sync.rs +++ b/limitador/src/storage/redis/redis_sync.rs @@ -4,7 +4,8 @@ use self::redis::{Commands, ConnectionInfo, ConnectionLike, IntoConnectionInfo, use crate::counter::Counter; use crate::limit::Limit; use crate::storage::keys::*; -use crate::storage::redis::scripts::SCRIPT_UPDATE_COUNTER; +use crate::storage::redis::is_limited; +use crate::storage::redis::scripts::{SCRIPT_UPDATE_COUNTER, VALUES_AND_TTLS}; use crate::storage::{Authorization, CounterStorage, StorageErr}; use r2d2::{ManageConnection, Pool}; use std::collections::HashSet; @@ -47,31 +48,34 @@ impl CounterStorage for RedisStorage { fn check_and_update( &self, - counters: HashSet, + counters: &mut Vec, delta: i64, + load_counters: bool, ) -> Result { let mut con = self.conn_pool.get()?; - let counter_keys: Vec = counters.iter().map(key_for_counter).collect(); - let counter_vals: Vec> = - redis::cmd("MGET").arg(counter_keys).query(&mut *con)?; - - for (i, counter) in counters.iter().enumerate() { - match counter_vals[i] { - Some(val) => { - if val - delta < 0 { - return Ok(Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - )); - } - } - None => { - if counter.max_value() - delta < 0 { - return Ok(Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - )); - } + if load_counters { + let script = redis::Script::new(VALUES_AND_TTLS); + let mut script_invocation = script.prepare_invoke(); + for counter_key in counter_keys { + script_invocation.key(counter_key); + } + let script_res: Vec> = script_invocation.invoke(&mut *con)?; + + if let Some(res) = is_limited(counters, delta, script_res) { + return Ok(res); + } + } else { + let counter_vals: Vec> = + redis::cmd("MGET").arg(counter_keys).query(&mut *con)?; + + for (i, counter) in counters.iter().enumerate() { + let remaining = counter_vals[i].unwrap_or(counter.max_value()) - delta; + if remaining < 0 { + return Ok(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); } } } diff --git a/limitador/src/storage/wasm.rs b/limitador/src/storage/wasm.rs index d6c15c80..1d4237c1 100644 --- a/limitador/src/storage/wasm.rs +++ b/limitador/src/storage/wasm.rs @@ -93,22 +93,53 @@ impl CounterStorage for WasmStorage { fn check_and_update( &self, - counters: HashSet, + counters: &mut Vec, delta: i64, + load_counters: bool, ) -> Result { // This makes the operator of check + update atomic let mut stored_counters = self.counters.write().unwrap(); - for counter in counters.iter() { - if !self.counter_is_within_limits(counter, stored_counters.get(counter), delta) { - return Ok(Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - )); + if load_counters { + let mut first_limited = None; + for counter in counters.iter_mut() { + let (remaining, expires_in) = match stored_counters.get(counter) { + Some(entry) => ( + entry.value - delta, + entry + .expires_at + .duration_since(self.clock.get_current_time()) + .unwrap_or(Duration::from_secs(0)), + ), + None => ( + counter.max_value() - delta, + Duration::from_secs(counter.seconds()), + ), + }; + counter.set_remaining(remaining); + counter.set_expires_in(expires_in); + + if first_limited.is_none() && remaining < 0 { + first_limited = Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )) + } + } + if let Some(l) = first_limited { + return Ok(l); + } + } else { + for counter in counters.iter() { + if !self.counter_is_within_limits(counter, stored_counters.get(counter), delta) { + return Ok(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); + } } } - for counter in counters { - self.insert_or_update_counter(&mut stored_counters, &counter, delta) + for counter in counters.iter() { + self.insert_or_update_counter(&mut stored_counters, counter, delta) } Ok(Authorization::Ok) diff --git a/limitador/tests/helpers/tests_limiter.rs b/limitador/tests/helpers/tests_limiter.rs index f8cc298b..62be9933 100644 --- a/limitador/tests/helpers/tests_limiter.rs +++ b/limitador/tests/helpers/tests_limiter.rs @@ -111,14 +111,15 @@ impl TestsLimiter { ) -> Result { match &self.limiter_impl { LimiterImpl::Blocking(limiter) => { - limiter.check_rate_limited_and_update(&namespace.into(), values, delta) + limiter.check_rate_limited_and_update(&namespace.into(), values, delta, false) } LimiterImpl::Async(limiter) => { limiter - .check_rate_limited_and_update(&namespace.into(), values, delta) + .check_rate_limited_and_update(&namespace.into(), values, delta, false) .await } } + .map(|cr| cr.into()) } pub async fn get_counters(&self, namespace: &str) -> Result, LimitadorError> {