diff --git a/Cargo.lock b/Cargo.lock index ab170a96..3b5909bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -305,6 +305,15 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" +[[package]] +name = "ascii-canvas" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef1e3e699d84ab1b0911a1010c5c106aa34ae89aeac103be5ce0c3859db1e891" +dependencies = [ + "term", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -498,6 +507,21 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + [[package]] name = "bitflags" version = "1.3.2" @@ -595,6 +619,33 @@ dependencies = [ "shlex", ] +[[package]] +name = "cel-interpreter" +version = "0.9.0" +source = "git+https://github.com/clarkmcc/cel-rust?rev=5b02b08#5b02b0817ced05c7cfc1c72bab03bc97bbfa2dea" +dependencies = [ + "base64 0.22.1", + "cel-parser", + "chrono", + "nom", + "paste", + "regex", + "serde", + "serde_json", + "thiserror 1.0.69", +] + +[[package]] +name = "cel-parser" +version = "0.8.0" +source = "git+https://github.com/clarkmcc/cel-rust?rev=5b02b08#5b02b0817ced05c7cfc1c72bab03bc97bbfa2dea" +dependencies = [ + "lalrpop", + "lalrpop-util", + "regex", + "thiserror 1.0.69", +] + [[package]] name = "cexpr" version = "0.6.0" @@ -950,6 +1001,15 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d" +[[package]] +name = "ena" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d248bdd43ce613d87415282f69b9bb99d947d290b10962dd6c56233312c2ad5" +dependencies = [ + "log", +] + [[package]] name = "encoding_rs" version = "0.8.35" @@ -1286,6 +1346,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "http" version = "0.2.12" @@ -1727,6 +1796,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "keccak" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" +dependencies = [ + "cpufeatures", +] + [[package]] name = "kqueue" version = "1.0.8" @@ -1747,6 +1825,38 @@ dependencies = [ "libc", ] +[[package]] +name = "lalrpop" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06093b57658c723a21da679530e061a8c25340fa5a6f98e313b542268c7e2a1f" +dependencies = [ + "ascii-canvas", + "bit-set", + "ena", + "itertools 0.13.0", + "lalrpop-util", + "petgraph", + "pico-args", + "regex", + "regex-syntax 0.8.5", + "sha3", + "string_cache", + "term", + "unicode-xid", + "walkdir", +] + +[[package]] +name = "lalrpop-util" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "feee752d43abd0f4807a921958ab4131f692a44d4d599733d4419c5d586176ce" +dependencies = [ + "regex-automata 0.4.9", + "rustversion", +] + [[package]] name = "language-tags" version = "0.3.2" @@ -1825,6 +1935,8 @@ version = "0.8.0-dev" dependencies = [ "async-trait", "base64 0.22.1", + "cel-interpreter", + "cel-parser", "cfg-if", "criterion", "dashmap", @@ -2109,6 +2221,12 @@ dependencies = [ "tempfile", ] +[[package]] +name = "new_debug_unreachable" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" + [[package]] name = "nom" version = "7.1.3" @@ -2507,6 +2625,21 @@ dependencies = [ "indexmap 2.6.0", ] +[[package]] +name = "phf_shared" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" +dependencies = [ + "siphasher", +] + +[[package]] +name = "pico-args" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5be167a7af36ee22fe3115051bc51f6e6c7054c9348e28deb4f49bd6f705a315" + [[package]] name = "pin-project" version = "1.1.7" @@ -2607,6 +2740,12 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "precomputed-hash" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" + [[package]] name = "prettyplease" version = "0.2.25" @@ -3146,6 +3285,16 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -3170,6 +3319,12 @@ dependencies = [ "libc", ] +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + [[package]] name = "sketches-ddsketch" version = "0.2.2" @@ -3216,6 +3371,19 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "string_cache" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" +dependencies = [ + "new_debug_unreachable", + "once_cell", + "parking_lot", + "phf_shared", + "precomputed-hash", +] + [[package]] name = "strsim" version = "0.11.1" @@ -3319,6 +3487,16 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "term" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4df4175de05129f31b80458c6df371a15e7fc3fd367272e6bf938e5c351c7ea0" +dependencies = [ + "home", + "windows-sys 0.52.0", +] + [[package]] name = "thiserror" version = "1.0.69" diff --git a/doc/migrations/conditions.md b/doc/migrations/conditions.md index 29fd1a8b..ffd4bb49 100644 --- a/doc/migrations/conditions.md +++ b/doc/migrations/conditions.md @@ -23,22 +23,3 @@ case `foo` was the identifier of the variable, while `bar` was the value to eval after the operator `==` would be equally important. SO that `foo == bar` would test for a `foo ` variable being equal to ` bar` where the trailing whitespace after the identifier, and the one prefixing the value, would have been evaluated. - -## Server binary users - -The server still allows for the deprecated syntax, but warns about its usage. You can easily migrate your limits file, -using the following command: - -```commandline -limitador-server --validate old_limits.yaml > updated_limits.yaml -``` - -Which should output `Deprecated syntax for conditions corrected!` to `stderr` while `stdout` would be the limits using -the new syntax. It is recommended you manually verify the resulting `LIMITS_FILE`. - - -## Crate users - -A feature `lenient_conditions` has been added, which lets you use the syntax used in previous version of the crate. -The function `limitador::limit::check_deprecated_syntax_usages_and_reset()` lets you verify if the deprecated syntax -has been used as `limit::Limit`s are created with their condition strings using the deprecated syntax. diff --git a/doc/server/configuration.md b/doc/server/configuration.md index b3374a06..3e1bc519 100644 --- a/doc/server/configuration.md +++ b/doc/server/configuration.md @@ -94,7 +94,7 @@ Here is an example of such a limit definition: max_value: 10 seconds: 60 conditions: - - "req.method == 'GET'" + - "req_method == 'GET'" variables: - user_id ``` diff --git a/limitador-server/Cargo.toml b/limitador-server/Cargo.toml index 6ed286c9..5608e939 100644 --- a/limitador-server/Cargo.toml +++ b/limitador-server/Cargo.toml @@ -18,7 +18,7 @@ distributed_storage = ["limitador/distributed_storage"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -limitador = { path = "../limitador", features = ['lenient_conditions'] } +limitador = { path = "../limitador" } tokio = { version = "1", features = ["full"] } thiserror = "2" tonic = "0.12.3" diff --git a/limitador-server/examples/envoy.yaml b/limitador-server/examples/envoy.yaml index 6cb44a45..1a7d0eda 100644 --- a/limitador-server/examples/envoy.yaml +++ b/limitador-server/examples/envoy.yaml @@ -28,8 +28,8 @@ static_resources: rate_limits: - stage: 0 actions: - - {request_headers: {header_name: "userid", descriptor_key: "user_id"}} - - {request_headers: {header_name: ":method", descriptor_key: "req.method"}} + - { request_headers: { header_name: "userid", descriptor_key: "user_id" } } + - { request_headers: { header_name: ":method", descriptor_key: "descriptors[0]['method']" } } http_filters: - name: envoy.filters.http.ratelimit typed_config: diff --git a/limitador-server/examples/limits.yaml b/limitador-server/examples/limits.yaml index afcb2b50..c47aedf4 100644 --- a/limitador-server/examples/limits.yaml +++ b/limitador-server/examples/limits.yaml @@ -10,6 +10,6 @@ max_value: 5 seconds: 60 conditions: - - "req.method == 'POST'" + - "descriptors[0]['req.method'] == 'POST'" variables: - user_id diff --git a/limitador-server/sandbox/limits.yaml b/limitador-server/sandbox/limits.yaml index 68b6df8d..cb354bf2 100644 --- a/limitador-server/sandbox/limits.yaml +++ b/limitador-server/sandbox/limits.yaml @@ -3,20 +3,20 @@ max_value: 10 seconds: 60 conditions: - - "req.method == 'GET'" - - "req.path != '/json'" + - "descriptors[0]['req.method'] == 'GET'" + - "descriptors[0]['req.path'] != '/json'" variables: [] - namespace: test_namespace max_value: 5 seconds: 60 conditions: - - "req.method == 'POST'" - - "req.path != '/json'" + - "descriptors[0]['req.method'] == 'POST'" + - "descriptors[0]['req.path'] != '/json'" variables: [] - namespace: test_namespace max_value: 50000 seconds: 10 conditions: - - "req.method == 'GET'" - - "req.path == '/json'" + - "descriptors[0]['req.method'] == 'GET'" + - "descriptors[0]['req.path'] == '/json'" variables: [] diff --git a/limitador-server/src/envoy_rls/server.rs b/limitador-server/src/envoy_rls/server.rs index 4eef9151..0528f6b7 100644 --- a/limitador-server/src/envoy_rls/server.rs +++ b/limitador-server/src/envoy_rls/server.rs @@ -3,12 +3,6 @@ use opentelemetry::propagation::Extractor; use std::collections::HashMap; use std::sync::Arc; -use limitador::CheckResult; -use tonic::codegen::http::HeaderMap; -use tonic::{transport, transport::Server, Request, Response, Status}; -use tracing::Span; -use tracing_opentelemetry::OpenTelemetrySpanExt; - use crate::envoy_rls::server::envoy::config::core::v3::HeaderValue; use crate::envoy_rls::server::envoy::service::ratelimit::v3::rate_limit_response::Code; use crate::envoy_rls::server::envoy::service::ratelimit::v3::rate_limit_service_server::{ @@ -19,6 +13,12 @@ use crate::envoy_rls::server::envoy::service::ratelimit::v3::{ }; use crate::prometheus_metrics::PrometheusMetrics; use crate::Limiter; +use limitador::limit::Context; +use limitador::CheckResult; +use tonic::codegen::http::HeaderMap; +use tonic::{transport, transport::Server, Request, Response, Status}; +use tracing::Span; +use tracing_opentelemetry::OpenTelemetrySpanExt; include!("envoy_types.rs"); @@ -72,7 +72,7 @@ impl RateLimitService for MyRateLimiter { ) -> Result, Status> { debug!("Request received: {:?}", request); - let mut values: HashMap = HashMap::new(); + let mut values: Vec> = Vec::default(); let (metadata, _ext, req) = request.into_parts(); let namespace = req.domain; let rl_headers = RateLimitRequestHeaders::new(metadata.into_headers()); @@ -96,9 +96,11 @@ impl RateLimitService for MyRateLimiter { let namespace = namespace.into(); for descriptor in &req.descriptors { + let mut map = HashMap::default(); for entry in &descriptor.entries { - values.insert(entry.key.clone(), entry.value.clone()); + map.insert(entry.key.clone(), entry.value.clone()); } + values.push(map); } // "hits_addend" is optional according to the spec, and should default @@ -109,10 +111,13 @@ impl RateLimitService for MyRateLimiter { req.hits_addend }; + let mut ctx = Context::default(); + ctx.list_binding("descriptors".to_string(), values); + let rate_limited_resp = match &*self.limiter { Limiter::Blocking(limiter) => limiter.check_rate_limited_and_update( &namespace, - &values, + &ctx, u64::from(hits_addend), self.rate_limit_headers != RateLimitHeaders::None, ), @@ -120,7 +125,7 @@ impl RateLimitService for MyRateLimiter { limiter .check_rate_limited_and_update( &namespace, - &values, + &ctx, u64::from(hits_addend), self.rate_limit_headers != RateLimitHeaders::None, ) @@ -253,10 +258,13 @@ mod tests { namespace, 1, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["descriptors[0]['req.method'] == 'GET'" + .try_into() + .expect("failed parsing!")], + vec!["descriptors[0]['app.id']" + .try_into() + .expect("failed parsing!")], + ); let limiter = RateLimiter::new(10_000); limiter.add_limit(limit); @@ -279,7 +287,7 @@ mod tests { value: "GET".to_string(), }, Entry { - key: "app_id".to_string(), + key: "app.id".to_string(), value: "1".to_string(), }, ], @@ -395,10 +403,29 @@ mod tests { let namespace = "test_namespace"; vec![ - Limit::new(namespace, 10, 60, vec!["x == '1'"], vec!["z"]) - .expect("This must be a valid limit!"), - Limit::new(namespace, 0, 60, vec!["x == '1'", "y == '2'"], vec!["z"]) - .expect("This must be a valid limit!"), + Limit::new( + namespace, + 10, + 60, + vec!["descriptors[0].x == '1'" + .try_into() + .expect("failed parsing!")], + vec!["descriptors[0].z".try_into().expect("failed parsing!")], + ), + Limit::new( + namespace, + 0, + 60, + vec![ + "descriptors[0].x == '1'" + .try_into() + .expect("failed parsing!"), + "descriptors[1].y == '2'" + .try_into() + .expect("failed parsing!"), + ], + vec!["descriptors[0].z".try_into().expect("failed parsing!")], + ), ] .into_iter() .for_each(|limit| { @@ -462,8 +489,15 @@ mod tests { #[tokio::test] async fn test_takes_into_account_the_hits_addend_param() { let namespace = "test_namespace"; - let limit = Limit::new(namespace, 10, 60, vec!["x == '1'"], vec!["y"]) - .expect("This must be a valid limit!"); + let limit = Limit::new( + namespace, + 10, + 60, + vec!["descriptors[0].x == '1'" + .try_into() + .expect("failed parsing!")], + vec!["descriptors[0].y".try_into().expect("failed parsing!")], + ); let limiter = RateLimiter::new(10_000); limiter.add_limit(limit); @@ -532,8 +566,15 @@ mod tests { // "hits_addend" is optional according to the spec, and should default // to 1, However, with the autogenerated structs it defaults to 0. let namespace = "test_namespace"; - let limit = Limit::new(namespace, 1, 60, vec!["x == '1'"], vec!["y"]) - .expect("This must be a valid limit!"); + let limit = Limit::new( + namespace, + 1, + 60, + vec!["descriptors[0].x == '1'" + .try_into() + .expect("failed parsing!")], + vec!["descriptors[0].y".try_into().expect("failed parsing!")], + ); let limiter = RateLimiter::new(10_000); limiter.add_limit(limit); diff --git a/limitador-server/src/http_api/request_types.rs b/limitador-server/src/http_api/request_types.rs index 1cae899a..6939ba38 100644 --- a/limitador-server/src/http_api/request_types.rs +++ b/limitador-server/src/http_api/request_types.rs @@ -1,6 +1,5 @@ use limitador::counter::Counter as LimitadorCounter; -use limitador::errors::LimitadorError; -use limitador::limit::Limit as LimitadorLimit; +use limitador::limit::{Expression, Limit as LimitadorLimit, ParseError, Predicate}; use paperclip::actix::Apiv2Schema; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; @@ -42,26 +41,31 @@ impl From<&LimitadorLimit> for Limit { } impl TryFrom for LimitadorLimit { - type Error = LimitadorError; + type Error = ParseError; fn try_from(limit: Limit) -> Result { + let conditions: Result, ParseError> = + limit.conditions.into_iter().map(|p| p.try_into()).collect(); + let variables: Result, ParseError> = + limit.variables.into_iter().map(|v| v.try_into()).collect(); + let mut limitador_limit = if let Some(id) = limit.id { Self::with_id( id, limit.namespace, limit.max_value, limit.seconds, - limit.conditions, - limit.variables, - )? + conditions?, + variables?, + ) } else { Self::new( limit.namespace, limit.max_value, limit.seconds, - limit.conditions, - limit.variables, - )? + conditions?, + variables?, + ) }; if let Some(name) = limit.name { diff --git a/limitador-server/src/http_api/server.rs b/limitador-server/src/http_api/server.rs index 25a1cd76..c9085502 100644 --- a/limitador-server/src/http_api/server.rs +++ b/limitador-server/src/http_api/server.rs @@ -3,6 +3,7 @@ use crate::prometheus_metrics::PrometheusMetrics; use crate::Limiter; use actix_web::{http::StatusCode, HttpResponse, HttpResponseBuilder, ResponseError}; use actix_web::{App, HttpServer}; +use limitador::limit::Context; use limitador::CheckResult; use paperclip::actix::{ api_v2_errors, @@ -122,9 +123,11 @@ async fn check( response_headers: _, } = request.into_inner(); let namespace = namespace.into(); + let mut ctx = Context::default(); + ctx.list_binding("descriptors".to_string(), vec![values]); let is_rate_limited_result = match state.get_ref().limiter() { - Limiter::Blocking(limiter) => limiter.is_rate_limited(&namespace, &values, delta), - Limiter::Async(limiter) => limiter.is_rate_limited(&namespace, &values, delta).await, + Limiter::Blocking(limiter) => limiter.is_rate_limited(&namespace, &ctx, delta), + Limiter::Async(limiter) => limiter.is_rate_limited(&namespace, &ctx, delta).await, }; match is_rate_limited_result { @@ -152,9 +155,11 @@ async fn report( response_headers: _, } = request.into_inner(); let namespace = namespace.into(); + let mut ctx = Context::default(); + ctx.list_binding("descriptors".to_string(), vec![values]); let update_counters_result = match data.get_ref().limiter() { - Limiter::Blocking(limiter) => limiter.update_counters(&namespace, &values, delta), - Limiter::Async(limiter) => limiter.update_counters(&namespace, &values, delta).await, + Limiter::Blocking(limiter) => limiter.update_counters(&namespace, &ctx, delta), + Limiter::Async(limiter) => limiter.update_counters(&namespace, &ctx, delta).await, }; match update_counters_result { @@ -176,22 +181,19 @@ async fn check_and_report( response_headers, } = request.into_inner(); let namespace = namespace.into(); + let mut ctx = Context::default(); + ctx.list_binding("descriptors".to_string(), vec![values]); let rate_limit_data = data.get_ref(); let rate_limited_and_update_result = match rate_limit_data.limiter() { Limiter::Blocking(limiter) => limiter.check_rate_limited_and_update( &namespace, - &values, + &ctx, delta, response_headers.is_some(), ), Limiter::Async(limiter) => { limiter - .check_rate_limited_and_update( - &namespace, - &values, - delta, - response_headers.is_some(), - ) + .check_rate_limited_and_update(&namespace, &ctx, delta, response_headers.is_some()) .await } }; @@ -385,7 +387,7 @@ mod tests { // Prepare values to check let mut values = HashMap::new(); values.insert("req.method".into(), "GET".into()); - values.insert("app_id".into(), "1".into()); + values.insert("req.id".into(), "1".into()); let info = CheckAndReportInfo { namespace: namespace.into(), values, @@ -436,7 +438,7 @@ mod tests { // Prepare values to check let mut values = HashMap::new(); values.insert("req.method".into(), "GET".into()); - values.insert("app_id".into(), "1".into()); + values.insert("app.id".into(), "1".into()); let info = CheckAndReportInfo { namespace: namespace.into(), values, @@ -511,7 +513,7 @@ mod tests { // Prepare values to check let mut values = HashMap::new(); values.insert("req.method".into(), "GET".into()); - values.insert("app_id".into(), "1".into()); + values.insert("app.id".into(), "1".into()); let info = CheckAndReportInfo { namespace: namespace.into(), values, @@ -553,10 +555,13 @@ mod tests { namespace, max, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["descriptors[0]['req.method'] == 'GET'" + .try_into() + .expect("failed parsing!")], + vec!["descriptors[0]['app.id']" + .try_into() + .expect("failed parsing!")], + ); match &limiter { Limiter::Blocking(limiter) => limiter.add_limit(limit.clone()), diff --git a/limitador-server/src/main.rs b/limitador-server/src/main.rs index ec12e420..3be5c0ee 100644 --- a/limitador-server/src/main.rs +++ b/limitador-server/src/main.rs @@ -194,9 +194,6 @@ impl Limiter { Self::Blocking(limiter) => limiter.configure_with(limits)?, Self::Async(limiter) => limiter.configure_with(limits).await?, } - if limitador::limit::check_deprecated_syntax_usages_and_reset() { - error!("You are using deprecated syntax for your conditions! See the migration guide https://docs.kuadrant.io/limitador/doc/migrations/conditions/") - } Ok(()) } Err(e) => Err(LimitadorServerError::ConfigFile(format!( @@ -597,10 +594,6 @@ fn create_config() -> (Configuration, &'static str) { let parsed_limits: Result, _> = serde_yaml::from_reader(f); match parsed_limits { Ok(limits) => { - if limitador::limit::check_deprecated_syntax_usages_and_reset() { - eprintln!("Deprecated syntax for conditions corrected!\n") - } - let output: Vec = limits.iter().map(|l| l.into()).collect(); match serde_yaml::to_string(&output) { diff --git a/limitador/Cargo.toml b/limitador/Cargo.toml index d9240b5f..ed822b1d 100644 --- a/limitador/Cargo.toml +++ b/limitador/Cargo.toml @@ -17,7 +17,6 @@ default = ["disk_storage", "redis_storage"] disk_storage = ["rocksdb"] distributed_storage = ["tokio", "tokio-stream", "h2", "base64", "uuid", "tonic", "tonic-reflection", "prost", "prost-types"] redis_storage = ["redis", "r2d2", "tokio"] -lenient_conditions = [] [dependencies] moka = { version = "0.12", features = ["sync"] } @@ -54,6 +53,8 @@ tonic = { version = "0.12.3", optional = true } tonic-reflection = { version = "0.12.3", optional = true } prost = { version = "0.13.3", optional = true } prost-types = { version = "0.13.3", optional = true } +cel-interpreter = { git = "https://github.com/clarkmcc/cel-rust", rev = "5b02b08", features = ["json", "regex", "chrono"] } +cel-parser = { git = "https://github.com/clarkmcc/cel-rust", rev = "5b02b08" } [dev-dependencies] serial_test = "3.0" diff --git a/limitador/README.md b/limitador/README.md index fd8f6f95..305f9d45 100644 --- a/limitador/README.md +++ b/limitador/README.md @@ -11,5 +11,4 @@ For the complete documentation of the crate's API, please refer to [docs.rs](htt * `redis_storage`: support for using Redis as the data storage backend. * `disk_storage`: support for using RocksDB as a local disk storage backend. -* `lenient_conditions`: support for the deprecated syntax of `Condition`s * `default`: `redis_storage`. diff --git a/limitador/benches/bench.rs b/limitador/benches/bench.rs index 5dcb3677..fed3fec8 100644 --- a/limitador/benches/bench.rs +++ b/limitador/benches/bench.rs @@ -7,7 +7,7 @@ use criterion::{black_box, criterion_group, criterion_main, Bencher, BenchmarkId use rand::seq::SliceRandom; use rand::SeedableRng; -use limitador::limit::Limit; +use limitador::limit::{Context, Limit}; #[cfg(feature = "disk_storage")] use limitador::storage::disk::{DiskStorage, OptimizeFor}; #[cfg(feature = "distributed_storage")] @@ -89,9 +89,9 @@ const TEST_SCENARIOS: &[&TestScenario] = &[ }, ]; -struct TestCallParams { +struct TestCallParams<'a> { namespace: String, - values: HashMap, + ctx: Context<'a>, delta: u64, } @@ -329,7 +329,7 @@ fn bench_is_rate_limited( rate_limiter .is_rate_limited( ¶ms.namespace.to_owned().into(), - ¶ms.values, + ¶ms.ctx, params.delta, ) .unwrap(), @@ -357,7 +357,7 @@ fn async_bench_is_rate_limited( rate_limiter .is_rate_limited( ¶ms.namespace.to_owned().into(), - ¶ms.values, + ¶ms.ctx, params.delta, ) .await @@ -383,7 +383,7 @@ fn bench_update_counters( rate_limiter .update_counters( ¶ms.namespace.to_owned().into(), - ¶ms.values, + ¶ms.ctx, params.delta, ) .unwrap(); @@ -410,7 +410,7 @@ fn async_bench_update_counters( rate_limiter .update_counters( ¶ms.namespace.to_owned().into(), - ¶ms.values, + ¶ms.ctx, params.delta, ) .await @@ -437,7 +437,7 @@ fn bench_check_rate_limited_and_update( rate_limiter .check_rate_limited_and_update( ¶ms.namespace.to_owned().into(), - ¶ms.values, + ¶ms.ctx, params.delta, false, ) @@ -467,7 +467,7 @@ fn async_bench_check_rate_limited_and_update( rate_limiter .check_rate_limited_and_update( ¶ms.namespace.to_owned().into(), - ¶ms.values, + ¶ms.ctx, params.delta, false, ) @@ -529,14 +529,18 @@ fn generate_test_limits(scenario: &TestScenario) -> (Vec, Vec (Vec, Vec>>(limit: L, set_variables: HashMap) -> Self { - // TODO: check that all the variables defined in the limit are set. + pub fn new>>(limit: L, ctx: &Context) -> LimitadorResult> { + let limit = limit.into(); + let variables = limit.resolve_variables(ctx)?; + match variables { + None => Ok(None), + Some(variables) => Ok(Some(Self { + limit, + set_variables: variables, + remaining: None, + expires_in: None, + })), + } + } + pub(super) fn resolved_vars>>( + limit: L, + set_variables: HashMap, + ) -> LimitadorResult { let limit = limit.into(); let mut vars = set_variables; vars.retain(|var, _| limit.has_variable(var)); - Self { + Ok(Self { limit, set_variables: vars.into_iter().collect(), remaining: None, expires_in: None, - } + }) } #[cfg(any(feature = "redis_storage", feature = "disk_storage"))] @@ -120,3 +136,29 @@ impl PartialEq for Counter { self.limit == other.limit && self.set_variables == other.set_variables } } + +#[cfg(test)] +mod tests { + use crate::counter::Counter; + use crate::limit::Limit; + use std::collections::HashMap; + + #[test] + fn resolves_variables() { + let var = "timestamp(ts).getHours()"; + let limit = Limit::new( + "", + 10, + 60, + Vec::default(), + [var.try_into().expect("failed parsing!")], + ); + let map = HashMap::from([("ts".to_string(), "2019-10-12T13:20:50.52Z".to_string())]); + let ctx = map.into(); + let counter = Counter::new(limit, &ctx).expect("failed creating counter"); + assert_eq!( + counter.unwrap().set_variables.get(var), + Some("13".to_string()).as_ref() + ); + } +} diff --git a/limitador/src/errors.rs b/limitador/src/errors.rs index d590aafa..61070ed7 100644 --- a/limitador/src/errors.rs +++ b/limitador/src/errors.rs @@ -1,4 +1,5 @@ -use crate::limit::ConditionParsingError; +use crate::limit::EvaluationError; +use crate::limit::ParseError; use crate::storage::StorageErr; use std::convert::Infallible; use std::error::Error; @@ -7,7 +8,7 @@ use std::fmt::{Display, Formatter}; #[derive(Debug)] pub enum LimitadorError { StorageError(StorageErr), - InterpreterError(ConditionParsingError), + InterpreterError(EvaluationError), } impl Display for LimitadorError { @@ -38,13 +39,13 @@ impl From for LimitadorError { } } -impl From for LimitadorError { - fn from(err: ConditionParsingError) -> Self { +impl From for LimitadorError { + fn from(err: EvaluationError) -> Self { LimitadorError::InterpreterError(err) } } -impl From for LimitadorError { +impl From for ParseError { fn from(value: Infallible) -> Self { unreachable!("unexpected infallible value: {:?}", value) } diff --git a/limitador/src/lib.rs b/limitador/src/lib.rs index 69628052..6902028c 100644 --- a/limitador/src/lib.rs +++ b/limitador/src/lib.rs @@ -54,8 +54,8 @@ //! "my_namespace", //! 10, //! 60, -//! vec!["req.method == 'GET'"], -//! vec!["user_id"], +//! vec!["req_method == 'GET'".try_into().expect("failed parsing!")], +//! vec!["user_id".try_into().expect("failed parsing!")], //! ); //! ``` //! @@ -71,9 +71,9 @@ //! "my_namespace", //! 10, //! 60, -//! vec!["req.method == 'GET'"], -//! vec!["user_id"], -//! ).unwrap(); +//! vec!["req_method == 'GET'".try_into().expect("failed parsing!")], +//! vec!["user_id".try_into().expect("failed parsing!")], +//! ); //! let mut rate_limiter = RateLimiter::new(1000); //! //! // Add a limit @@ -94,7 +94,7 @@ //! //! ``` //! use limitador::RateLimiter; -//! use limitador::limit::Limit; +//! use limitador::limit::{Limit, Context}; //! use std::collections::HashMap; //! //! let mut rate_limiter = RateLimiter::new(1000); @@ -103,35 +103,36 @@ //! "my_namespace", //! 2, //! 60, -//! vec!["req.method == 'GET'"], -//! vec!["user_id"], -//! ).unwrap(); +//! vec!["req_method == 'GET'".try_into().expect("failed parsing!")], +//! vec!["user_id".try_into().expect("failed parsing!")], +//! ); //! rate_limiter.add_limit(limit); //! //! // We've defined a limit of 2. So we can report 2 times before being //! // rate-limited //! let mut values_to_report: HashMap = HashMap::new(); -//! values_to_report.insert("req.method".to_string(), "GET".to_string()); +//! values_to_report.insert("req_method".to_string(), "GET".to_string()); //! values_to_report.insert("user_id".to_string(), "1".to_string()); //! //! // Check if we can report //! let namespace = "my_namespace".into(); -//! assert!(!rate_limiter.is_rate_limited(&namespace, &values_to_report, 1).unwrap()); +//! let ctx = &values_to_report.into(); +//! assert!(!rate_limiter.is_rate_limited(&namespace, &ctx, 1).unwrap()); //! //! // Report -//! rate_limiter.update_counters(&namespace, &values_to_report, 1).unwrap(); +//! rate_limiter.update_counters(&namespace, &ctx, 1).unwrap(); //! //! // Check and report again -//! assert!(!rate_limiter.is_rate_limited(&namespace, &values_to_report, 1).unwrap()); -//! rate_limiter.update_counters(&namespace, &values_to_report, 1).unwrap(); +//! assert!(!rate_limiter.is_rate_limited(&namespace, &ctx, 1).unwrap()); +//! rate_limiter.update_counters(&namespace, &ctx, 1).unwrap(); //! //! // We've already reported 2, so reporting another one should not be allowed -//! assert!(rate_limiter.is_rate_limited(&namespace, &values_to_report, 1).unwrap()); +//! assert!(rate_limiter.is_rate_limited(&namespace, &ctx, 1).unwrap()); //! //! // You can also check and report if not limited in a single call. It's useful //! // for example, when calling Limitador from a proxy. Instead of doing 2 //! // separate calls, we can issue just one: -//! rate_limiter.check_rate_limited_and_update(&namespace, &values_to_report, 1, false).unwrap(); +//! rate_limiter.check_rate_limited_and_update(&namespace, &ctx, 1, false).unwrap(); //! ``` //! //! # Async @@ -167,9 +168,9 @@ //! "my_namespace", //! 10, //! 60, -//! vec!["req.method == 'GET'"], -//! vec!["user_id"], -//! ).unwrap(); +//! vec!["req_method == 'GET'".try_into().expect("failed parsing!")], +//! vec!["user_id".try_into().expect("failed parsing!")], +//! ); //! //! async { //! let rate_limiter = AsyncRateLimiter::new_with_storage( @@ -192,14 +193,13 @@ // TODO this needs review to reduce the bloat pulled in by dependencies #![allow(clippy::multiple_crate_versions)] -use std::collections::{HashMap, HashSet}; -use std::sync::Arc; - use crate::counter::Counter; use crate::errors::LimitadorError; -use crate::limit::{Limit, Namespace}; +use crate::limit::{Context, Limit, Namespace}; use crate::storage::in_memory::InMemoryStorage; use crate::storage::{AsyncCounterStorage, AsyncStorage, Authorization, CounterStorage, Storage}; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; #[macro_use] extern crate core; @@ -359,7 +359,7 @@ impl RateLimiter { pub fn is_rate_limited( &self, namespace: &Namespace, - values: &HashMap, + values: &Context, delta: u64, ) -> LimitadorResult { let counters = self.counters_that_apply(namespace, values)?; @@ -381,10 +381,10 @@ impl RateLimiter { pub fn update_counters( &self, namespace: &Namespace, - values: &HashMap, + ctx: &Context, delta: u64, ) -> LimitadorResult<()> { - let counters = self.counters_that_apply(namespace, values)?; + let counters = self.counters_that_apply(namespace, ctx)?; counters .iter() @@ -395,11 +395,11 @@ impl RateLimiter { pub fn check_rate_limited_and_update( &self, namespace: &Namespace, - values: &HashMap, + ctx: &Context, delta: u64, load_counters: bool, ) -> LimitadorResult { - let mut counters = self.counters_that_apply(namespace, values)?; + let mut counters = self.counters_that_apply(namespace, ctx)?; if counters.is_empty() { return Ok(CheckResult { @@ -477,17 +477,18 @@ impl RateLimiter { fn counters_that_apply( &self, namespace: &Namespace, - values: &HashMap, + ctx: &Context, ) -> LimitadorResult> { let limits = self.storage.get_limits(namespace); - - let counters = limits + limits .iter() - .filter(|lim| lim.applies(values)) - .map(|lim| Counter::new(Arc::clone(lim), values.clone())) - .collect(); - - Ok(counters) + .filter(|lim| lim.applies(ctx)) + .filter_map(|lim| match Counter::new(Arc::clone(lim), ctx) { + Ok(None) => None, + Ok(Some(c)) => Some(Ok(c)), + Err(e) => Some(Err(e)), + }) + .collect() } } @@ -532,10 +533,10 @@ impl AsyncRateLimiter { pub async fn is_rate_limited( &self, namespace: &Namespace, - values: &HashMap, + ctx: &Context<'_>, delta: u64, ) -> LimitadorResult { - let counters = self.counters_that_apply(namespace, values).await?; + let counters = self.counters_that_apply(namespace, ctx).await?; for counter in counters { match self.storage.is_within_limits(&counter, delta).await { @@ -553,10 +554,10 @@ impl AsyncRateLimiter { pub async fn update_counters( &self, namespace: &Namespace, - values: &HashMap, + ctx: &Context<'_>, delta: u64, ) -> LimitadorResult<()> { - let counters = self.counters_that_apply(namespace, values).await?; + let counters = self.counters_that_apply(namespace, ctx).await?; for counter in counters { self.storage.update_counter(&counter, delta).await? @@ -568,12 +569,12 @@ impl AsyncRateLimiter { pub async fn check_rate_limited_and_update( &self, namespace: &Namespace, - values: &HashMap, + ctx: &Context<'_>, delta: u64, load_counters: bool, ) -> LimitadorResult { // the above where-clause is needed in order to call unwrap(). - let mut counters = self.counters_that_apply(namespace, values).await?; + let mut counters = self.counters_that_apply(namespace, ctx).await?; if counters.is_empty() { return Ok(CheckResult { @@ -656,17 +657,18 @@ impl AsyncRateLimiter { async fn counters_that_apply( &self, namespace: &Namespace, - values: &HashMap, + ctx: &Context<'_>, ) -> LimitadorResult> { let limits = self.storage.get_limits(namespace); - - let counters = limits + limits .iter() - .filter(|lim| lim.applies(values)) - .map(|lim| Counter::new(Arc::clone(lim), values.clone())) - .collect(); - - Ok(counters) + .filter(|lim| lim.applies(ctx)) + .filter_map(|lim| match Counter::new(Arc::clone(lim), ctx) { + Ok(None) => None, + Ok(Some(c)) => Some(Ok(c)), + Err(e) => Some(Err(e)), + }) + .collect() } } @@ -693,23 +695,15 @@ fn classify_limits_by_namespace( #[cfg(test)] mod test { - use crate::limit::Limit; + use crate::limit::{Context, Expression, Limit}; use crate::RateLimiter; - use std::collections::HashMap; #[test] fn properly_updates_existing_limits() { let rl = RateLimiter::new(100); let namespace = "foo"; - let l = Limit::new::<_, String>( - namespace, - 42, - 100, - Vec::::default(), - Vec::::default(), - ) - .expect("This must be a valid limit!"); + let l = Limit::new(namespace, 42, 100, vec![], Vec::::default()); rl.add_limit(l.clone()); let limits = rl.get_limits(&namespace.into()); assert_eq!(limits.len(), 1); @@ -717,7 +711,7 @@ mod test { assert_eq!(limits.iter().next().unwrap().max_value(), 42); let r = rl - .check_rate_limited_and_update(&namespace.into(), &HashMap::default(), 1, true) + .check_rate_limited_and_update(&namespace.into(), &Context::default(), 1, true) .unwrap(); assert_eq!(r.counters.first().unwrap().max_value(), 42); @@ -731,7 +725,7 @@ mod test { assert_eq!(limits.iter().next().unwrap().max_value(), 50); let r = rl - .check_rate_limited_and_update(&namespace.into(), &HashMap::default(), 1, true) + .check_rate_limited_and_update(&namespace.into(), &Context::default(), 1, true) .unwrap(); assert_eq!(r.counters.first().unwrap().max_value(), 50); } diff --git a/limitador/src/limit.rs b/limitador/src/limit.rs index 94b95b1b..b667847b 100644 --- a/limitador/src/limit.rs +++ b/limitador/src/limit.rs @@ -1,34 +1,13 @@ -use crate::limit::conditions::{ErrorType, Literal, SyntaxError, Token, TokenType}; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; -use std::collections::{BTreeSet, HashMap, HashSet}; -use std::error::Error; -use std::fmt::{Debug, Display, Formatter}; +use std::collections::{BTreeMap, BTreeSet, HashSet}; +use std::fmt::Debug; use std::hash::{Hash, Hasher}; -#[cfg(feature = "lenient_conditions")] -mod deprecated { - use std::sync::atomic::{AtomicBool, Ordering}; +mod cel; - static DEPRECATED_SYNTAX: AtomicBool = AtomicBool::new(false); - - pub fn check_deprecated_syntax_usages_and_reset() -> bool { - match DEPRECATED_SYNTAX.compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed) - { - Ok(previous) => previous, - Err(previous) => previous, - } - } - - pub fn deprecated_syntax_used() { - DEPRECATED_SYNTAX.fetch_or(true, Ordering::SeqCst); - } -} - -use crate::errors::LimitadorError; -use crate::LimitadorResult; -#[cfg(feature = "lenient_conditions")] -pub use deprecated::check_deprecated_syntax_usages_and_reset; +pub use cel::{Context, Expression, Predicate}; +pub use cel::{EvaluationError, ParseError}; #[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord, Serialize, Deserialize)] pub struct Namespace(String); @@ -64,285 +43,48 @@ pub struct Limit { // Need to sort to generate the same object when using the JSON as a key or // value in Redis. - conditions: BTreeSet, - variables: BTreeSet, -} - -#[derive(Deserialize, Serialize, PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)] -#[serde(try_from = "String", into = "String")] -pub struct Condition { - var_name: String, - predicate: Predicate, - operand: String, -} - -#[derive(Debug)] -pub struct ConditionParsingError { - error: SyntaxError, - pub tokens: Vec, - condition: String, -} - -impl Display for ConditionParsingError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{} of condition \"{}\"", self.error, self.condition) - } -} - -impl Error for ConditionParsingError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - Some(&self.error) - } -} - -impl TryFrom<&str> for Condition { - type Error = ConditionParsingError; - - fn try_from(value: &str) -> Result { - value.to_owned().try_into() - } -} - -impl TryFrom for Condition { - type Error = ConditionParsingError; - - fn try_from(value: String) -> Result { - match conditions::Scanner::scan(value.clone()) { - Ok(tokens) => match tokens.len().cmp(&(3_usize)) { - Ordering::Equal => { - match ( - &tokens[0].token_type, - &tokens[1].token_type, - &tokens[2].token_type, - ) { - ( - TokenType::Identifier, - TokenType::EqualEqual | TokenType::NotEqual, - TokenType::String, - ) => { - if let ( - Some(Literal::Identifier(var_name)), - Some(Literal::String(operand)), - ) = (&tokens[0].literal, &tokens[2].literal) - { - let predicate = match &tokens[1].token_type { - TokenType::EqualEqual => Predicate::Equal, - TokenType::NotEqual => Predicate::NotEqual, - _ => unreachable!(), - }; - Ok(Condition { - var_name: var_name.clone(), - predicate, - operand: operand.clone(), - }) - } else { - panic!( - "Unexpected state {tokens:?} returned from Scanner for: `{value}`" - ) - } - } - ( - TokenType::String, - TokenType::EqualEqual | TokenType::NotEqual, - TokenType::Identifier, - ) => { - if let ( - Some(Literal::String(operand)), - Some(Literal::Identifier(var_name)), - ) = (&tokens[0].literal, &tokens[2].literal) - { - let predicate = match &tokens[1].token_type { - TokenType::EqualEqual => Predicate::Equal, - TokenType::NotEqual => Predicate::NotEqual, - _ => unreachable!(), - }; - Ok(Condition { - var_name: var_name.clone(), - predicate, - operand: operand.clone(), - }) - } else { - panic!( - "Unexpected state {tokens:?} returned from Scanner for: `{value}`" - ) - } - } - #[cfg(feature = "lenient_conditions")] - (TokenType::Identifier, TokenType::EqualEqual, TokenType::Identifier) => { - if let ( - Some(Literal::Identifier(var_name)), - Some(Literal::Identifier(operand)), - ) = (&tokens[0].literal, &tokens[2].literal) - { - deprecated::deprecated_syntax_used(); - Ok(Condition { - var_name: var_name.clone(), - predicate: Predicate::Equal, - operand: operand.clone(), - }) - } else { - panic!( - "Unexpected state {tokens:?} returned from Scanner for: `{value}`" - ) - } - } - #[cfg(feature = "lenient_conditions")] - (TokenType::Identifier, TokenType::EqualEqual, TokenType::Number) => { - if let ( - Some(Literal::Identifier(var_name)), - Some(Literal::Number(operand)), - ) = (&tokens[0].literal, &tokens[2].literal) - { - deprecated::deprecated_syntax_used(); - Ok(Condition { - var_name: var_name.clone(), - predicate: Predicate::Equal, - operand: operand.to_string(), - }) - } else { - panic!( - "Unexpected state {tokens:?} returned from Scanner for: `{value}`" - ) - } - } - (t1, t2, _) => { - let faulty = match (t1, t2) { - ( - TokenType::Identifier | TokenType::String, - TokenType::EqualEqual | TokenType::NotEqual, - ) => 2, - (TokenType::Identifier | TokenType::String, _) => 1, - (_, _) => 0, - }; - Err(ConditionParsingError { - error: SyntaxError { - pos: tokens[faulty].pos, - error: ErrorType::UnexpectedToken(tokens[faulty].clone()), - }, - tokens, - condition: value, - }) - } - } - } - Ordering::Less => Err(ConditionParsingError { - error: SyntaxError { - pos: value.len(), - error: ErrorType::MissingToken, - }, - tokens, - condition: value, - }), - Ordering::Greater => Err(ConditionParsingError { - error: SyntaxError { - pos: tokens[3].pos, - error: ErrorType::UnexpectedToken(tokens[3].clone()), - }, - tokens, - condition: value, - }), - }, - Err(err) => Err(ConditionParsingError { - error: err, - tokens: Vec::new(), - condition: value, - }), - } - } -} - -impl From for String { - fn from(condition: Condition) -> Self { - let p = &condition.predicate; - let predicate: String = p.clone().into(); - let quotes = if condition.operand.contains('"') { - '\'' - } else { - '"' - }; - format!( - "{} {} {}{}{}", - condition.var_name, predicate, quotes, condition.operand, quotes - ) - } -} - -#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Hash)] -pub enum Predicate { - Equal, - NotEqual, -} - -impl Predicate { - fn test(&self, lhs: &str, rhs: &str) -> bool { - match self { - Predicate::Equal => lhs == rhs, - Predicate::NotEqual => lhs != rhs, - } - } -} - -impl From for String { - fn from(op: Predicate) -> Self { - match op { - Predicate::Equal => "==".to_string(), - Predicate::NotEqual => "!=".to_string(), - } - } + conditions: BTreeSet, + variables: BTreeSet, } impl Limit { - pub fn new, T: TryInto>( + pub fn new>( namespace: N, max_value: u64, seconds: u64, - conditions: impl IntoIterator, - variables: impl IntoIterator>, - ) -> LimitadorResult + conditions: impl IntoIterator, + variables: impl IntoIterator, + ) -> Self where - >::Error: core::fmt::Debug, - >::Error: core::fmt::Debug, - LimitadorError: From<>::Error>, + >::Error: Debug, { - // the above where-clause is needed in order to call unwrap(). - let conditions: Result, _> = - conditions.into_iter().map(|cond| cond.try_into()).collect(); - match conditions { - Ok(conditions) => Ok(Self { - id: None, - namespace: namespace.into(), - max_value, - seconds, - name: None, - conditions, - variables: variables.into_iter().map(|var| var.into()).collect(), - }), - Err(err) => Err(err.into()), + Self { + id: None, + namespace: namespace.into(), + max_value, + seconds, + name: None, + conditions: conditions.into_iter().collect(), + variables: variables.into_iter().collect(), } } - pub fn with_id, N: Into, T: TryInto>( + pub fn with_id, N: Into>( id: S, namespace: N, max_value: u64, seconds: u64, - conditions: impl IntoIterator, - variables: impl IntoIterator>, - ) -> LimitadorResult - where - LimitadorError: From<>::Error>, - { - match conditions.into_iter().map(|cond| cond.try_into()).collect() { - Ok(conditions) => Ok(Self { - id: Some(id.into()), - namespace: namespace.into(), - max_value, - seconds, - name: None, - conditions, - variables: variables.into_iter().map(|var| var.into()).collect(), - }), - Err(err) => Err(err.into()), + conditions: impl IntoIterator, + variables: impl IntoIterator, + ) -> Self { + Self { + id: Some(id.into()), + namespace: namespace.into(), + max_value, + seconds, + name: None, + conditions: conditions.into_iter().collect(), + variables: variables.into_iter().collect(), } } @@ -382,43 +124,64 @@ impl Limit { } pub fn variables(&self) -> HashSet { - self.variables.iter().map(|var| var.into()).collect() + self.variables + .iter() + .map(|var| var.source().into()) + .collect() + } + + pub fn resolve_variables( + &self, + ctx: &Context, + ) -> Result>, EvaluationError> { + let mut map = BTreeMap::new(); + for variable in &self.variables { + let name = variable.source().into(); + match variable.eval(ctx)? { + None => return Ok(None), + Some(value) => { + map.insert(name, value); + } + } + } + Ok(Some(map)) } #[cfg(feature = "disk_storage")] pub(crate) fn variables_for_key(&self) -> Vec<&str> { let mut variables = Vec::with_capacity(self.variables.len()); for var in &self.variables { - variables.push(var.as_str()); + variables.push(var.source()); } variables.sort(); variables } pub fn has_variable(&self, var: &str) -> bool { - self.variables.contains(var) + self.variables + .iter() + .flat_map(|v| v.variables()) + .any(|v| v.as_str() == var) } - pub fn applies(&self, values: &HashMap) -> bool { + pub fn applies(&self, ctx: &Context) -> bool { + let ctx = ctx.for_limit(self); let all_conditions_apply = self .conditions .iter() - .all(|cond| Self::condition_applies(cond, values)); + .all(|predicate| predicate.test(&ctx.for_limit(self)).unwrap()); - let all_vars_are_set = self.variables.iter().all(|var| values.contains_key(var)); + let all_vars_are_set = self.variables.iter().all(|var| { + ctx.has_variables( + &var.variables() + .iter() + .map(String::as_str) + .collect::>(), + ) + }); all_conditions_apply && all_vars_are_set } - - fn condition_applies(condition: &Condition, values: &HashMap) -> bool { - let left_operand = condition.var_name.as_str(); - let right_operand = condition.operand.as_str(); - - match values.get(left_operand) { - Some(val) => condition.predicate.test(val, right_operand), - None => false, - } - } } impl Hash for Limit { @@ -460,410 +223,22 @@ impl PartialEq for Limit { } } -mod conditions { - use std::error::Error; - use std::fmt::{Debug, Display, Formatter}; - use std::num::IntErrorKind; - - #[derive(Debug)] - pub struct SyntaxError { - pub pos: usize, - pub error: ErrorType, - } - - #[derive(Debug, Eq, PartialEq)] - pub enum ErrorType { - UnexpectedToken(Token), - MissingToken, - InvalidCharacter(char), - InvalidNumber, - UnclosedStringLiteral(char), - } - - impl Display for SyntaxError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match &self.error { - ErrorType::UnexpectedToken(token) => write!( - f, - "SyntaxError: Unexpected token `{}` at offset {}", - token, self.pos - ), - ErrorType::InvalidCharacter(char) => write!( - f, - "SyntaxError: Invalid character `{}` at offset {}", - char, self.pos - ), - ErrorType::InvalidNumber => { - write!(f, "SyntaxError: Invalid number at offset {}", self.pos) - } - ErrorType::MissingToken => { - write!(f, "SyntaxError: Expected token at offset {}", self.pos) - } - ErrorType::UnclosedStringLiteral(char) => { - write!(f, "SyntaxError: Missing closing `{}` for string literal starting at offset {}", char, self.pos) - } - } - } - } - - impl Error for SyntaxError {} - - #[derive(Clone, Eq, PartialEq, Debug)] - pub enum TokenType { - // Predicates - EqualEqual, - NotEqual, - - //Literals - Identifier, - String, - Number, - } - - #[derive(Clone, Eq, PartialEq, Debug)] - pub enum Literal { - Identifier(String), - String(String), - Number(i64), - } - - impl Display for Literal { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Literal::Identifier(id) => write!(f, "{id}"), - Literal::String(string) => write!(f, "'{string}'"), - Literal::Number(number) => write!(f, "{number}"), - } - } - } - - #[derive(Clone, Eq, PartialEq, Debug)] - pub struct Token { - pub token_type: TokenType, - pub literal: Option, - pub pos: usize, - } - - impl Display for Token { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.token_type { - TokenType::EqualEqual => write!(f, "Equality (==)"), - TokenType::NotEqual => write!(f, "Unequal (!=)"), - TokenType::Identifier => { - write!(f, "Identifier: {}", self.literal.as_ref().unwrap()) - } - TokenType::String => { - write!(f, "String literal: {}", self.literal.as_ref().unwrap()) - } - TokenType::Number => { - write!(f, "Number literal: {}", self.literal.as_ref().unwrap()) - } - } - } - } - - pub struct Scanner { - input: Vec, - pos: usize, - } - - impl Scanner { - pub fn scan(condition: String) -> Result, SyntaxError> { - let mut tokens: Vec = Vec::with_capacity(3); - let mut scanner = Scanner { - input: condition.chars().collect(), - pos: 0, - }; - while !scanner.done() { - match scanner.next_token() { - Ok(token) => { - if let Some(token) = token { - tokens.push(token) - } - } - Err(err) => { - return Err(err); - } - } - } - Ok(tokens) - } - - fn next_token(&mut self) -> Result, SyntaxError> { - let character = self.advance(); - match character { - '=' => { - if self.next_matches('=') { - Ok(Some(Token { - token_type: TokenType::EqualEqual, - literal: None, - pos: self.pos - 1, - })) - } else { - Err(SyntaxError { - pos: self.pos, - error: ErrorType::InvalidCharacter(self.input[self.pos - 1]), - }) - } - } - '!' => { - if self.next_matches('=') { - Ok(Some(Token { - token_type: TokenType::NotEqual, - literal: None, - pos: self.pos - 1, - })) - } else { - Err(SyntaxError { - pos: self.pos, - error: ErrorType::InvalidCharacter(self.input[self.pos - 1]), - }) - } - } - '"' | '\'' => self.scan_string(character).map(Some), - ' ' | '\n' | '\r' | '\t' => Ok(None), - _ => { - if character.is_alphabetic() { - self.scan_identifier().map(Some) - } else if character.is_numeric() { - self.scan_number().map(Some) - } else { - Err(SyntaxError { - pos: self.pos, - error: ErrorType::InvalidCharacter(character), - }) - } - } - } - } - - fn scan_identifier(&mut self) -> Result { - let start = self.pos; - while !self.done() && self.valid_id_char() { - self.advance(); - } - Ok(Token { - token_type: TokenType::Identifier, - literal: Some(Literal::Identifier( - self.input[start - 1..self.pos].iter().collect(), - )), - pos: start, - }) - } - - fn valid_id_char(&mut self) -> bool { - let char = self.input[self.pos]; - char.is_alphanumeric() || char == '.' || char == '_' - } - - fn scan_string(&mut self, until: char) -> Result { - let start = self.pos; - loop { - if self.done() { - return Err(SyntaxError { - pos: start, - error: ErrorType::UnclosedStringLiteral(until), - }); - } - if self.advance() == until { - return Ok(Token { - token_type: TokenType::String, - literal: Some(Literal::String( - self.input[start..self.pos - 1].iter().collect(), - )), - pos: start, - }); - } - } - } - - fn scan_number(&mut self) -> Result { - let start = self.pos; - while !self.done() && self.input[self.pos].is_numeric() { - self.advance(); - } - let number_str = self.input[start - 1..self.pos].iter().collect::(); - match number_str.parse::() { - Ok(number) => Ok(Token { - token_type: TokenType::Number, - literal: Some(Literal::Number(number)), - pos: start, - }), - Err(err) => { - let syntax_error = match err.kind() { - IntErrorKind::Empty => { - unreachable!("This means a bug in the scanner!") - } - IntErrorKind::Zero => { - unreachable!("We're parsing Numbers as i64, so 0 should always work!") - } - _ => SyntaxError { - pos: start, - error: ErrorType::InvalidNumber, - }, - }; - Err(syntax_error) - } - } - } - - fn advance(&mut self) -> char { - let char = self.input[self.pos]; - self.pos += 1; - char - } - - fn next_matches(&mut self, c: char) -> bool { - if self.done() || self.input[self.pos] != c { - return false; - } - - self.pos += 1; - true - } - - fn done(&self) -> bool { - self.pos >= self.input.len() - } - } - - #[cfg(test)] - mod tests { - use crate::limit::conditions::Literal::Identifier; - use crate::limit::conditions::{ErrorType, Literal, Scanner, Token, TokenType}; - - #[test] - fn test_scanner() { - let mut tokens = - Scanner::scan("foo=='bar '".to_owned()).expect("Should parse alright!"); - assert_eq!(tokens.len(), 3); - assert_eq!( - tokens[0], - Token { - token_type: TokenType::Identifier, - literal: Some(Identifier("foo".to_owned())), - pos: 1, - } - ); - assert_eq!( - tokens[1], - Token { - token_type: TokenType::EqualEqual, - literal: None, - pos: 4, - } - ); - assert_eq!( - tokens[2], - Token { - token_type: TokenType::String, - literal: Some(Literal::String("bar ".to_owned())), - pos: 6, - } - ); - - tokens[1].pos += 1; - tokens[2].pos += 2; - assert_eq!( - tokens, - Scanner::scan("foo == 'bar '".to_owned()).expect("Should parse alright!") - ); - - tokens[0].pos += 2; - tokens[1].pos += 2; - tokens[2].pos += 2; - assert_eq!( - tokens, - Scanner::scan(" foo == 'bar ' ".to_owned()).expect("Should parse alright!") - ); - - tokens[1].pos += 2; - tokens[2].pos += 4; - assert_eq!( - tokens, - Scanner::scan(" foo == 'bar ' ".to_owned()).expect("Should parse alright!") - ); - } - - #[test] - fn test_number_literal() { - let tokens = Scanner::scan("var == 42".to_owned()).expect("Should parse alright!"); - assert_eq!(tokens.len(), 3); - assert_eq!( - tokens[0], - Token { - token_type: TokenType::Identifier, - literal: Some(Identifier("var".to_owned())), - pos: 1, - } - ); - assert_eq!( - tokens[1], - Token { - token_type: TokenType::EqualEqual, - literal: None, - pos: 5, - } - ); - assert_eq!( - tokens[2], - Token { - token_type: TokenType::Number, - literal: Some(Literal::Number(42)), - pos: 8, - } - ); - } - - #[test] - fn test_charset() { - let tokens = - Scanner::scan(" 変数 == ' 💖 '".to_owned()).expect("Should parse alright!"); - assert_eq!(tokens.len(), 3); - assert_eq!( - tokens[0], - Token { - token_type: TokenType::Identifier, - literal: Some(Identifier("変数".to_owned())), - pos: 2, - } - ); - assert_eq!( - tokens[1], - Token { - token_type: TokenType::EqualEqual, - literal: None, - pos: 5, - } - ); - assert_eq!( - tokens[2], - Token { - token_type: TokenType::String, - literal: Some(Literal::String(" 💖 ".to_owned())), - pos: 8, - } - ); - } - - #[test] - fn unclosed_string_literal() { - let error = Scanner::scan("foo == 'ba".to_owned()).expect_err("Should fail!"); - assert_eq!(error.pos, 8); - assert_eq!(error.error, ErrorType::UnclosedStringLiteral('\'')); - } - } -} - #[cfg(test)] mod tests { use super::*; + use crate::counter::Counter; use std::cmp::Ordering::Equal; + use std::collections::HashMap; #[test] fn limit_can_have_an_optional_name() { - let mut limit = Limit::new("test_namespace", 10, 60, vec!["x == \"5\""], vec!["y"]) - .expect("This must be a valid limit!"); + let mut limit = Limit::new( + "test_namespace", + 10, + 60, + vec!["x == \"5\"".try_into().expect("failed parsing!")], + vec!["y".try_into().expect("failed parsing!")], + ); assert!(limit.name.is_none()); let name = "Test Limit"; @@ -873,77 +248,71 @@ mod tests { #[test] fn limit_applies() { - let limit = Limit::new("test_namespace", 10, 60, vec!["x == \"5\""], vec!["y"]) - .expect("This must be a valid limit!"); + let limit = Limit::new( + "test_namespace", + 10, + 60, + vec!["x == \"5\"".try_into().expect("failed parsing!")], + vec!["y".try_into().expect("failed parsing!")], + ); let mut values: HashMap = HashMap::new(); values.insert("x".into(), "5".into()); values.insert("y".into(), "1".into()); - assert!(limit.applies(&values)) + assert!(limit.applies(&values.into())) } #[test] fn limit_does_not_apply_when_cond_is_false() { - let limit = Limit::new("test_namespace", 10, 60, vec!["x == \"5\""], vec!["y"]) - .expect("This must be a valid limit!"); - - let mut values: HashMap = HashMap::new(); - values.insert("x".into(), "1".into()); - values.insert("y".into(), "1".into()); - - assert!(!limit.applies(&values)) - } - - #[test] - #[cfg(feature = "lenient_conditions")] - fn limit_does_not_apply_when_cond_is_false_deprecated_style() { - let limit = Limit::new("test_namespace", 10, 60, vec!["x == 5"], vec!["y"]) - .expect("This must be a valid limit!"); + let limit = Limit::new( + "test_namespace", + 10, + 60, + vec!["x == \"5\"".try_into().expect("failed parsing!")], + vec!["y".try_into().expect("failed parsing!")], + ); let mut values: HashMap = HashMap::new(); values.insert("x".into(), "1".into()); values.insert("y".into(), "1".into()); - assert!(!limit.applies(&values)); - assert!(check_deprecated_syntax_usages_and_reset()); - assert!(!check_deprecated_syntax_usages_and_reset()); - - let limit = Limit::new("test_namespace", 10, 60, vec!["x == foobar"], vec!["y"]) - .expect("This must be a valid limit!"); - - let mut values: HashMap = HashMap::new(); - values.insert("x".into(), "foobar".into()); - values.insert("y".into(), "1".into()); - - assert!(limit.applies(&values)); - assert!(check_deprecated_syntax_usages_and_reset()); - assert!(!check_deprecated_syntax_usages_and_reset()); + assert!(!limit.applies(&values.into())) } #[test] fn limit_does_not_apply_when_cond_var_is_not_set() { - let limit = Limit::new("test_namespace", 10, 60, vec!["x == \"5\""], vec!["y"]) - .expect("This must be a valid limit!"); + let limit = Limit::new( + "test_namespace", + 10, + 60, + vec!["x == \"5\"".try_into().expect("failed parsing!")], + vec!["y".try_into().expect("failed parsing!")], + ); // Notice that "x" is not set let mut values: HashMap = HashMap::new(); values.insert("a".into(), "1".into()); values.insert("y".into(), "1".into()); - assert!(!limit.applies(&values)) + assert!(!limit.applies(&values.into())) } #[test] fn limit_does_not_apply_when_var_not_set() { - let limit = Limit::new("test_namespace", 10, 60, vec!["x == \"5\""], vec!["y"]) - .expect("This must be a valid limit!"); + let limit = Limit::new( + "test_namespace", + 10, + 60, + vec!["x == \"5\"".try_into().expect("failed parsing!")], + vec!["y".try_into().expect("failed parsing!")], + ); // Notice that "y" is not set let mut values: HashMap = HashMap::new(); values.insert("x".into(), "5".into()); - assert!(!limit.applies(&values)) + assert!(!limit.applies(&values.into())) } #[test] @@ -952,17 +321,19 @@ mod tests { "test_namespace", 10, 60, - vec!["x == \"5\"", "y == \"2\""], - vec!["z"], - ) - .expect("This must be a valid limit!"); + vec![ + "x == \"5\"".try_into().expect("failed parsing!"), + "y == \"2\"".try_into().expect("failed parsing!"), + ], + vec!["z".try_into().expect("failed parsing!")], + ); let mut values: HashMap = HashMap::new(); values.insert("x".into(), "5".into()); values.insert("y".into(), "2".into()); values.insert("z".into(), "1".into()); - assert!(limit.applies(&values)) + assert!(limit.applies(&values.into())) } #[test] @@ -971,82 +342,19 @@ mod tests { "test_namespace", 10, 60, - vec!["x == \"5\"", "y == \"2\""], - vec!["z"], - ) - .expect("This must be a valid limit!"); + vec![ + "x == \"5\"".try_into().expect("failed parsing!"), + "y == \"2\"".try_into().expect("failed parsing!"), + ], + vec!["z".try_into().expect("failed parsing!")], + ); let mut values: HashMap = HashMap::new(); values.insert("x".into(), "3".into()); values.insert("y".into(), "2".into()); values.insert("z".into(), "1".into()); - assert!(!limit.applies(&values)) - } - - #[test] - fn valid_condition_literal_parsing() { - let result: Condition = serde_json::from_str(r#""x == '5'""#).expect("Should deserialize"); - assert_eq!( - result, - Condition { - var_name: "x".to_string(), - predicate: Predicate::Equal, - operand: "5".to_string(), - } - ); - - let result: Condition = - serde_json::from_str(r#"" foobar=='ok' ""#).expect("Should deserialize"); - assert_eq!( - result, - Condition { - var_name: "foobar".to_string(), - predicate: Predicate::Equal, - operand: "ok".to_string(), - } - ); - - let result: Condition = - serde_json::from_str(r#"" foobar == 'ok' ""#).expect("Should deserialize"); - assert_eq!( - result, - Condition { - var_name: "foobar".to_string(), - predicate: Predicate::Equal, - operand: "ok".to_string(), - } - ); - } - - #[test] - #[cfg(not(feature = "lenient_conditions"))] - fn invalid_deprecated_condition_parsing() { - let _result = serde_json::from_str::(r#""x == 5""#) - .err() - .expect("Should fail!"); - } - - #[test] - fn invalid_condition_parsing() { - let result = serde_json::from_str::(r#""x != 5 && x > 12""#) - .expect_err("should fail parsing"); - assert_eq!( - result.to_string(), - "SyntaxError: Invalid character `&` at offset 8 of condition \"x != 5 && x > 12\"" - .to_string() - ); - } - - #[test] - fn condition_serialization() { - let condition = Condition { - var_name: "foobar".to_string(), - predicate: Predicate::Equal, - operand: "ok".to_string(), - }; - let result = serde_json::to_string(&condition).expect("Should serialize"); - assert_eq!(result, r#""foobar == \"ok\"""#.to_string()); + assert!(!limit.applies(&values.into())) } #[test] @@ -1056,10 +364,9 @@ mod tests { "test_namespace", 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); assert_eq!(limit.id(), Some("test_id")) } @@ -1071,22 +378,99 @@ mod tests { "test_namespace", 42, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); let mut limit2 = Limit::new( limit1.namespace.clone(), limit1.max_value + 10, limit1.seconds, - limit1.conditions.clone(), + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], limit1.variables.clone(), - ) - .expect("This must be a valid limit!"); + ); limit2.set_name("Who cares?".to_string()); assert_eq!(limit1.partial_cmp(&limit2), Some(Equal)); assert_eq!(limit1, limit2); } + + #[test] + fn resolves_variables() { + let limit = Limit::new( + "", + 10, + 60, + Vec::default(), + ["int(x) * 3".try_into().expect("failed parsing!")], + ); + assert!(limit.has_variable("x")); + } + + #[test] + fn conditions_have_limit_info() { + let mut limit = Limit::new( + "ns", + 42, + 10, + vec!["limit.name == 'named_limit'" + .try_into() + .expect("failed parsing!")], + Vec::default(), + ); + assert!(!limit.applies(&Context::default())); + + limit.set_name("named_limit".to_string()); + assert!(limit.applies(&Context::default())); + + let limit = Limit::with_id( + "my_id", + "ns", + 42, + 10, + vec![ + "limit.id == 'my_id'".try_into().expect("failed parsing!"), + "limit.name == null".try_into().expect("failed parsing!"), + ], + Vec::default(), + ); + assert!(limit.applies(&Context::default())); + + let limit = Limit::with_id( + "my_id", + "ns", + 42, + 10, + vec!["limit.id == 'other_id'" + .try_into() + .expect("failed parsing!")], + Vec::default(), + ); + assert!(!limit.applies(&Context::default())); + } + + #[test] + fn cel_limit_applies() { + let limit = Limit::new( + "ns", + 42, + 10, + vec!["foo.contains('bar')".try_into().expect("failed parsing!")], + vec!["bar.endsWith('baz')".try_into().expect("failed parsing!")], + ); + let map = HashMap::from([ + ("foo".to_string(), "nice bar!".to_string()), + ("bar".to_string(), "foo,baz".to_string()), + ]); + let ctx = map.into(); + assert!(limit.applies(&ctx)); + assert_eq!( + Counter::new(limit, &ctx) + .expect("failed") + .unwrap() + .set_variables() + .get("bar.endsWith('baz')"), + Some(&"true".to_string()) + ); + } } diff --git a/limitador/src/limit/cel.rs b/limitador/src/limit/cel.rs new file mode 100644 index 00000000..d44609d1 --- /dev/null +++ b/limitador/src/limit/cel.rs @@ -0,0 +1,448 @@ +use crate::limit::Limit; +use cel_interpreter::{ExecutionError, Value}; +pub use errors::{EvaluationError, ParseError}; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::collections::{HashMap, HashSet}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +pub(super) mod errors { + use cel_interpreter::ExecutionError; + use std::error::Error; + use std::fmt::{Display, Formatter}; + + #[derive(Debug, PartialEq)] + pub enum EvaluationError { + UnexpectedValueType(String), + ExecutionError(ExecutionError), + } + + impl Display for EvaluationError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + EvaluationError::UnexpectedValueType(value) => { + write!(f, "unexpected value of type {}", value) + } + EvaluationError::ExecutionError(error) => error.fmt(f), + } + } + } + + impl Error for EvaluationError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + EvaluationError::UnexpectedValueType(_) => None, + EvaluationError::ExecutionError(err) => Some(err), + } + } + } + + #[derive(Debug)] + pub struct ParseError { + input: String, + source: Box, + } + + impl ParseError { + pub fn from(source: cel_parser::ParseError, input: String) -> Self { + Self { + input, + source: Box::new(source), + } + } + } + + impl Display for ParseError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "couldn't parse {}: {}", self.input, self.source) + } + } + + impl Error for ParseError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + Some(self.source.as_ref()) + } + } + + impl From for EvaluationError { + fn from(err: ExecutionError) -> Self { + EvaluationError::ExecutionError(err) + } + } +} + +pub struct Context<'a> { + variables: HashSet, + ctx: cel_interpreter::Context<'a>, +} + +impl<'a> Context<'a> { + pub(crate) fn new(root: String, values: HashMap) -> Self { + let mut ctx = cel_interpreter::Context::default(); + let mut variables = HashSet::new(); + + if root.is_empty() { + for (binding, value) in values { + ctx.add_variable_from_value(binding.clone(), value.clone()); + variables.insert(binding); + } + } else { + let map = cel_interpreter::objects::Map::from(values.clone()); + ctx.add_variable_from_value(root, Value::Map(map)); + } + + Self { variables, ctx } + } + + pub fn list_binding(&mut self, name: String, value: Vec>) { + let v = value + .iter() + .map(|values| { + let map = cel_interpreter::objects::Map::from(values.clone()); + Value::Map(map) + }) + .collect::>(); + self.variables.insert(name.clone()); + self.ctx + .add_variable_from_value(name, Value::List(v.into())); + } + + pub(crate) fn for_limit<'b>(&'b self, limit: &Limit) -> Self + where + 'b: 'a, + { + let mut inner = self.ctx.new_inner_scope(); + let limit_data = cel_interpreter::objects::Map::from(HashMap::from([ + ( + "name", + limit + .name + .as_ref() + .map(|n| Value::String(Arc::new(n.to_string()))) + .unwrap_or(Value::Null), + ), + ( + "id", + limit + .id + .as_ref() + .map(|n| Value::String(Arc::new(n.to_string()))) + .unwrap_or(Value::Null), + ), + ])); + inner.add_variable_from_value("limit", Value::Map(limit_data)); + Self { + variables: self.variables.clone(), + ctx: inner, + } + } + + pub(crate) fn has_variables(&self, names: &[&str]) -> bool { + names.iter().all(|name| self.variables.contains(*name)) + } +} + +impl Default for Context<'_> { + fn default() -> Self { + Self::new(String::default(), HashMap::default()) + } +} + +impl From> for Context<'_> { + fn from(value: HashMap) -> Self { + Self::new(String::default(), value) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub struct Expression { + source: String, + expression: cel_parser::Expression, +} + +impl Expression { + pub fn parse(source: T) -> Result { + let source = source.to_string(); + match cel_parser::parse(&source) { + Ok(expression) => Ok(Self { source, expression }), + Err(err) => Err(ParseError::from(err, source)), + } + } + + pub fn eval(&self, ctx: &Context) -> Result, EvaluationError> { + let result = self.resolve(ctx); + match result { + Ok(value) => match value { + Value::Int(i) => Ok(i.to_string()), + Value::UInt(i) => Ok(i.to_string()), + Value::Float(f) => Ok(f.to_string()), + Value::String(s) => Ok(s.to_string()), + Value::Null => Ok("null".to_owned()), + Value::Bool(b) => Ok(b.to_string()), + val => Err(err_on_value(val)), + } + .map(Some), + Err(ExecutionError::NoSuchKey(_)) => Ok(None), + Err(err) => Err(err.into()), + } + } + + pub(super) fn resolve(&self, ctx: &Context) -> Result { + Value::resolve(&self.expression, &ctx.ctx) + } + + pub fn source(&self) -> &str { + self.source.as_str() + } + + pub fn variables(&self) -> Vec { + self.expression + .references() + .variables() + .into_iter() + .map(String::from) + .collect() + } +} + +fn err_on_value(val: Value) -> EvaluationError { + match val { + Value::List(list) => EvaluationError::UnexpectedValueType(format!("list: `{:?}`", *list)), + Value::Map(map) => EvaluationError::UnexpectedValueType(format!("map: `{:?}`", *map.map)), + Value::Function(ident, _) => { + EvaluationError::UnexpectedValueType(format!("function: `{}`", *ident)) + } + Value::Bytes(b) => EvaluationError::UnexpectedValueType(format!("function: `{:?}`", *b)), + Value::Duration(d) => EvaluationError::UnexpectedValueType(format!("duration: `{d}`")), + Value::Timestamp(ts) => EvaluationError::UnexpectedValueType(format!("timestamp: `{ts}`")), + Value::Int(i) => EvaluationError::UnexpectedValueType(format!("integer: `{i}`")), + Value::UInt(u) => EvaluationError::UnexpectedValueType(format!("unsigned integer: `{u}`")), + Value::Float(f) => EvaluationError::UnexpectedValueType(format!("float: `{f}`")), + Value::String(s) => EvaluationError::UnexpectedValueType(format!("string: `{s}`")), + Value::Bool(b) => EvaluationError::UnexpectedValueType(format!("bool: `{b}`")), + Value::Null => EvaluationError::UnexpectedValueType("null".to_owned()), + } +} + +impl TryFrom for Expression { + type Error = ParseError; + + fn try_from(value: String) -> Result { + Self::parse(value) + } +} + +impl TryFrom<&str> for Predicate { + type Error = ParseError; + + fn try_from(value: &str) -> Result { + Self::parse(value) + } +} + +impl From for String { + fn from(value: Expression) -> Self { + value.source + } +} + +impl PartialEq for Expression { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl Eq for Expression {} + +impl Hash for Expression { + fn hash(&self, state: &mut H) { + self.source.hash(state); + } +} + +impl PartialOrd for Expression { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Expression { + fn cmp(&self, other: &Self) -> Ordering { + self.source.cmp(&other.source) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub struct Predicate { + #[serde(skip_serializing, default)] + variables: HashSet, + expression: Expression, +} + +impl Predicate { + pub fn parse(source: T) -> Result { + Expression::parse(source).map(|e| Self { + variables: e + .expression + .references() + .variables() + .into_iter() + .map(String::from) + .collect(), + expression: e, + }) + } + + pub fn test(&self, ctx: &Context) -> Result { + if !self + .variables + .iter() + .filter(|binding| binding.as_str() != "limit") + .all(|v| ctx.variables.contains(v)) + { + return Ok(false); + } + + match self.expression.resolve(ctx) { + Ok(value) => match value { + Value::Bool(b) => Ok(b), + v => Err(err_on_value(v)), + }, + Err(ExecutionError::NoSuchKey(_)) => Ok(false), + Err(err) => Err(err.into()), + } + } +} + +impl Eq for Predicate {} + +impl PartialEq for Predicate { + fn eq(&self, other: &Self) -> bool { + self.expression.source == other.expression.source + } +} + +impl PartialOrd for Predicate { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Predicate { + fn cmp(&self, other: &Self) -> Ordering { + self.expression.cmp(&other.expression) + } +} + +impl Hash for Predicate { + fn hash(&self, state: &mut H) { + self.expression.source.hash(state); + } +} + +impl TryFrom for Predicate { + type Error = ParseError; + + fn try_from(value: String) -> Result { + Self::parse(value) + } +} + +impl TryFrom<&str> for Expression { + type Error = ParseError; + + fn try_from(value: &str) -> Result { + Self::parse(value) + } +} + +impl From for String { + fn from(value: Predicate) -> Self { + value.expression.source + } +} + +#[cfg(test)] +mod tests { + use super::{Context, Expression, Predicate}; + use std::collections::{HashMap, HashSet}; + + #[test] + fn expression() { + let exp = Expression::parse("100").expect("failed to parse"); + assert_eq!(exp.eval(&ctx()), Ok(Some(String::from("100")))); + } + + #[test] + fn expression_serialization() { + let exp = Expression::parse("100").expect("failed to parse"); + let serialized = serde_json::to_string(&exp).expect("failed to serialize"); + let deserialized: Expression = + serde_json::from_str(&serialized).expect("failed to deserialize"); + assert_eq!(exp.eval(&ctx()), deserialized.eval(&ctx())); + } + + #[test] + fn unexpected_value_type_expression() { + let exp = Expression::parse("['100']").expect("failed to parse"); + assert_eq!( + exp.eval(&ctx()).map_err(|e| format!("{e}")), + Err("unexpected value of type list: `[String(\"100\")]`".to_string()) + ); + } + + #[test] + fn predicate() { + let pred = Predicate::parse("42 == uint('42')").expect("failed to parse"); + assert_eq!(pred.test(&ctx()), Ok(true)); + } + + #[test] + fn predicate_no_var() { + let pred = Predicate::parse("not_there == 42").expect("failed to parse"); + assert_eq!(pred.test(&ctx()), Ok(false)); + } + + #[test] + fn predicate_no_key() { + let pred = Predicate::parse("there.not == 42").expect("failed to parse"); + assert_eq!( + pred.test(&HashMap::from([("there".to_string(), String::default())]).into()), + Ok(false) + ); + } + + #[test] + fn unexpected_value_predicate() { + let pred = Predicate::parse("42").expect("failed to parse"); + assert_eq!( + pred.test(&ctx()).map_err(|e| format!("{e}")), + Err("unexpected value of type integer: `42`".to_string()) + ); + } + + #[test] + fn supports_list_bindings() { + let pred = Predicate::parse("root[0].key == '1' && root[1]['key'] == '2'") + .expect("failed to parse"); + let mut ctx = Context::default(); + ctx.list_binding( + "root".to_string(), + vec![ + HashMap::from([("key".to_string(), "1".to_string())]), + HashMap::from([("key".to_string(), "2".to_string())]), + ], + ); + assert_eq!(pred.test(&ctx).map_err(|e| format!("{e}")), Ok(true)); + } + + fn ctx<'a>() -> Context<'a> { + Context { + variables: HashSet::default(), + ctx: cel_interpreter::Context::default(), + } + } +} diff --git a/limitador/src/storage/disk/rocksdb_storage.rs b/limitador/src/storage/disk/rocksdb_storage.rs index 12aadce0..96496693 100644 --- a/limitador/src/storage/disk/rocksdb_storage.rs +++ b/limitador/src/storage/disk/rocksdb_storage.rs @@ -242,9 +242,18 @@ mod tests { #[test] fn opens_db_on_disk() { let namespace = "test_namespace"; - let limit = Limit::new(namespace, 1, 2, vec!["req.method == 'GET'"], vec!["app_id"]) - .expect("This must be a valid limit!"); - let counter = Counter::new(limit, HashMap::default()); + let limit = Limit::new( + namespace, + 1, + 2, + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = map.into(); + let counter = Counter::new(limit, &ctx) + .unwrap() + .expect("must have a counter"); let tmp = TempDir::new().expect("We should have a dir!"); { diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 020918c5..77f0a3ff 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -9,7 +9,7 @@ use tokio::sync::mpsc::Sender; use tracing::debug; use crate::counter::Counter; -use crate::limit::Limit; +use crate::limit::{Context, Limit}; use crate::storage::distributed::cr_counter_value::CrCounterValue; use crate::storage::distributed::grpc::v1::CounterUpdate; use crate::storage::distributed::grpc::{Broker, CounterEntry}; @@ -47,7 +47,9 @@ impl CounterStorage for CrInMemoryStorage { let key = encode_limit_to_key(limit); limits.entry(key.clone()).or_insert(Arc::new(CounterEntry { key, - counter: Counter::new(limit.clone(), HashMap::default()), + counter: Counter::new(limit.clone(), &Context::default()) + .expect("counter creation can't fail! no vars to resolve!") + .expect("must have a counter"), value: CrCounterValue::new( self.identifier.clone(), limit.max_value(), @@ -333,6 +335,15 @@ fn encode_counter_to_key(counter: &Counter) -> Vec { } fn encode_limit_to_key(limit: &Limit) -> Vec { - let counter = Counter::new(limit.clone(), HashMap::default()); + // fixme this is broken! + let vars: HashMap = limit + .variables() + .into_iter() + .map(|k| (k, "".to_string())) + .collect(); + let ctx = vars.into(); + let counter = Counter::new(limit.clone(), &ctx) + .expect("counter creation can't fail! faked vars!") + .expect("must have a counter"); key_for_counter_v2(&counter) } diff --git a/limitador/src/storage/in_memory.rs b/limitador/src/storage/in_memory.rs index 3148b6df..28dc2119 100644 --- a/limitador/src/storage/in_memory.rs +++ b/limitador/src/storage/in_memory.rs @@ -1,5 +1,5 @@ use crate::counter::Counter; -use crate::limit::{Limit, Namespace}; +use crate::limit::{Context, Limit, Namespace}; use crate::storage::atomic_expiring_value::AtomicExpiringValue; use crate::storage::{Authorization, CounterStorage, StorageErr}; use moka::sync::Cache; @@ -211,7 +211,10 @@ impl InMemoryStorage { for (limit, counter) in self.simple_limits.read().unwrap().iter() { if limit.namespace() == namespace { res.insert( - Counter::new(limit.clone(), HashMap::default()), + // todo fixme + Counter::new(limit.clone(), &Context::default()) + .unwrap() + .unwrap(), counter.clone(), ); } @@ -252,18 +255,28 @@ mod tests { fn counters_for_multiple_limit_per_ns() { let storage = InMemoryStorage::default(); let namespace = "test_namespace"; - let limit_1 = Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]) - .expect("This must be a valid limit!"); + let limit_1 = Limit::new( + namespace, + 1, + 1, + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); let limit_2 = Limit::new( namespace, 1, 10, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); - let counter_1 = Counter::new(limit_1, HashMap::default()); - let counter_2 = Counter::new(limit_2, HashMap::default()); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = map.into(); + let counter_1 = Counter::new(limit_1, &ctx) + .expect("counter creation failed!") + .expect("Should have a counter"); + let counter_2 = Counter::new(limit_2, &ctx) + .expect("counter creation failed!") + .expect("Should have a counter"); storage.update_counter(&counter_1, 1).unwrap(); storage.update_counter(&counter_2, 1).unwrap(); diff --git a/limitador/src/storage/keys.rs b/limitador/src/storage/keys.rs index 621eb202..10fe5a44 100644 --- a/limitador/src/storage/keys.rs +++ b/limitador/src/storage/keys.rs @@ -119,12 +119,11 @@ mod tests { "example.com", 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); assert_eq!( - "namespace:{example.com},counters_of_limit:{\"namespace\":\"example.com\",\"seconds\":60,\"conditions\":[\"req.method == \\\"GET\\\"\"],\"variables\":[\"app_id\"]}".as_bytes(), + "namespace:{example.com},counters_of_limit:{\"namespace\":\"example.com\",\"seconds\":60,\"conditions\":[\"req_method == 'GET'\"],\"variables\":[\"app_id\"]}".as_bytes(), key_for_counters_of_limit(&limit)) } @@ -135,10 +134,9 @@ mod tests { "example.com", 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); assert_eq!( "\u{2}\u{7}test_id".as_bytes(), key_for_counters_of_limit(&limit) @@ -148,9 +146,18 @@ mod tests { #[test] fn counter_key_and_counter_are_symmetric() { let namespace = "ns_counter:"; - let limit = Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]) - .expect("This must be a valid limit!"); - let counter = Counter::new(limit.clone(), HashMap::default()); + let limit = Limit::new( + namespace, + 1, + 1, + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = map.into(); + let counter = Counter::new(limit.clone(), &ctx) + .expect("counter creation failed!") + .expect("must have a counter"); let raw = key_for_counter(&counter); assert_eq!(counter, partial_counter_from_counter_key(&raw)); } @@ -158,9 +165,18 @@ mod tests { #[test] fn counter_key_does_not_include_transient_state() { let namespace = "ns_counter:"; - let limit = Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]) - .expect("This must be a valid limit!"); - let counter = Counter::new(limit.clone(), HashMap::default()); + let limit = Limit::new( + namespace, + 1, + 1, + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = map.into(); + let counter = Counter::new(limit.clone(), &ctx) + .expect("counter creation failed!") + .expect("must have a counter"); let mut other = counter.clone(); other.set_remaining(123); other.set_expires_in(Duration::from_millis(456)); @@ -174,7 +190,7 @@ pub mod bin { use std::collections::HashMap; use crate::counter::Counter; - use crate::limit::Limit; + use crate::limit::{Limit, Predicate}; #[derive(PartialEq, Debug, Serialize, Deserialize)] struct IdCounterKey<'a> { @@ -272,9 +288,17 @@ pub mod bin { .into_iter() .map(|(var, value)| (var.to_string(), value.to_string())) .collect(); - let limit = - Limit::new(ns, u64::default(), seconds, conditions, map.keys()).unwrap(); - Counter::new(limit, map) + let limit = Limit::new( + ns, + u64::default(), + seconds, + conditions + .into_iter() + .map(|p| p.try_into().expect("condition corrupted!")), + map.keys() + .map(|var| var.as_str().try_into().expect("variable corrupted!")), + ); + Counter::resolved_vars(limit, map).expect("counter creation failed!") } 2u8 => { let IdCounterKey { id, variables } = postcard::from_bytes(key).unwrap(); @@ -284,16 +308,16 @@ pub mod bin { .collect(); // we are not able to rebuild the full limit since we only have the id and variables. - let limit = Limit::with_id::<&str, &str, &str>( + let limit = Limit::with_id::<&str, &str>( id, "", u64::default(), 0, vec![], - map.keys(), - ) - .unwrap(); - Counter::new(limit, map) + map.keys() + .map(|var| var.as_str().try_into().expect("variable corrupted!")), + ); + Counter::resolved_vars(limit, map).expect("counter creation failed!") } _ => panic!("Unknown version: {}", version), } @@ -321,8 +345,18 @@ pub mod bin { .into_iter() .map(|(var, value)| (var.to_string(), value.to_string())) .collect(); - let limit = Limit::new(ns, u64::default(), seconds, conditions, map.keys()); - Counter::new(limit.unwrap(), map) + let limit = Limit::new( + ns, + u64::default(), + seconds, + conditions + .into_iter() + .map(|p| p.try_into().expect("condition corrupted!")) + .collect::>(), + map.keys() + .map(|p| p.as_str().try_into().expect("variable corrupted!")), + ); + Counter::resolved_vars(limit, map).unwrap() } #[cfg(test)] @@ -342,15 +376,21 @@ pub mod bin { namespace, 1, 2, - vec!["foo == 'bar'"], - vec!["app_id", "role", "wat"], - ) - .expect("This must be a valid limit!"); + vec!["foo == 'bar'".try_into().expect("failed parsing!")], + vec![ + "app_id".try_into().expect("failed parsing!"), + "role".try_into().expect("failed parsing!"), + "wat".try_into().expect("failed parsing!"), + ], + ); let mut vars = HashMap::default(); vars.insert("role".to_string(), "admin".to_string()); vars.insert("app_id".to_string(), "123".to_string()); vars.insert("wat".to_string(), "dunno".to_string()); - let counter = Counter::new(limit.clone(), vars); + let ctx = vars.into(); + let counter = Counter::new(limit.clone(), &ctx) + .expect("counter creation failed!") + .expect("must have a counter"); let raw = key_for_counter(&counter); let key_back: CounterKey = @@ -362,11 +402,19 @@ pub mod bin { #[test] fn counter_key_and_counter_are_symmetric() { let namespace = "ns_counter:"; - let limit = Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]) - .expect("This must be a valid limit!"); + let limit = Limit::new( + namespace, + 1, + 1, + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); let mut variables = HashMap::default(); variables.insert("app_id".to_string(), "123".to_string()); - let counter = Counter::new(limit.clone(), variables); + let ctx = variables.into(); + let counter = Counter::new(limit.clone(), &ctx) + .expect("counter creation failed!") + .expect("must have a counter"); let raw = key_for_counter(&counter); assert_eq!(counter, partial_counter_from_counter_key(&raw)); } @@ -374,9 +422,18 @@ pub mod bin { #[test] fn counter_key_starts_with_namespace_prefix() { let namespace = "ns_counter:"; - let limit = Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]) - .expect("This must be a valid limit!"); - let counter = Counter::new(limit, HashMap::default()); + let limit = Limit::new( + namespace, + 1, + 1, + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = map.into(); + let counter = Counter::new(limit, &ctx) + .expect("counter creation failed!") + .expect("must have a counter"); let serialized_counter = key_for_counter(&counter); let prefix = prefix_for_namespace(namespace); @@ -386,37 +443,46 @@ pub mod bin { #[test] fn counters_with_id() { let namespace = "ns_counter:"; - let limit_without_id = - Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]) - .expect("This must be a valid limit!"); + let limit_without_id = Limit::new( + namespace, + 1, + 1, + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); let limit_with_id = Limit::with_id( "id200", namespace, 1, 1, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); - - let counter_with_id = Counter::new(limit_with_id, HashMap::default()); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); + + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = map.into(); + let counter_with_id = Counter::new(limit_with_id, &ctx) + .expect("counter creation failed!") + .expect("must have a counter"); let serialized_with_id_counter = key_for_counter(&counter_with_id); - let counter_without_id = Counter::new(limit_without_id, HashMap::default()); + let counter_without_id = Counter::new(limit_without_id, &ctx) + .expect("counter creation failed!") + .expect("must have a counter"); let serialized_without_id_counter = key_for_counter(&counter_without_id); // the original key_for_counter continues to encode kinda big - assert_eq!(serialized_without_id_counter.len(), 35); - assert_eq!(serialized_with_id_counter.len(), 35); + assert_eq!(serialized_without_id_counter.len(), 46); + assert_eq!(serialized_with_id_counter.len(), 46); // serialized_counter_v2 will only encode the id.... so it will be smaller for // counters with an id. let serialized_counter_with_id_v2 = key_for_counter_v2(&counter_with_id); - assert_eq!(serialized_counter_with_id_v2.clone().len(), 8); + assert_eq!(serialized_counter_with_id_v2.clone().len(), 19); // but continues to be large for counters without an id. let serialized_counter_without_id_v2 = key_for_counter_v2(&counter_without_id); - assert_eq!(serialized_counter_without_id_v2.clone().len(), 36); + assert_eq!(serialized_counter_without_id_v2.clone().len(), 47); } } } diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 52e87136..49c83e45 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -569,7 +569,8 @@ mod tests { .consume(1, |items| { assert_eq!(items.len(), 1); assert!( - SystemTime::now().duration_since(start).unwrap() < Duration::from_millis(5) + SystemTime::now().duration_since(start).unwrap() + < Duration::from_millis(10) ); async { Ok::<(), ()>(()) } }) @@ -680,11 +681,12 @@ mod tests { "test_namespace", max_val, 60, - vec!["req.method == 'POST'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"), - values, + vec!["req_method == 'POST'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ), + &values.into(), ) + .expect("failed creating counter") + .expect("Should have a counter") } } diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 8f1b58dc..eda05161 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -439,17 +439,20 @@ mod tests { const LOCAL_INCREMENTS: u64 = 2; let mut counters_and_deltas = HashMap::new(); + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = map.into(); let counter = Counter::new( Limit::new( "test_namespace", 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"), - Default::default(), - ); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ), + &ctx, + ) + .expect("counter creation failed!") + .expect("must have a counter"); let arc = Arc::new(CachedCounterValue::from_authority( &counter, @@ -502,17 +505,20 @@ mod tests { #[tokio::test] async fn flush_batcher_and_update_counters_test() { + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = map.into(); let counter = Counter::new( Limit::new( "test_namespace", 10, 60, - vec!["req.method == 'POST'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"), - Default::default(), - ); + vec!["req_method == 'POST'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ), + &ctx, + ) + .expect("counter creation failed!") + .expect("must have a counter"); let mock_response = Value::Array(vec![ Value::Int(8), @@ -562,17 +568,20 @@ mod tests { #[tokio::test] async fn flush_batcher_reverts_on_err() { + let map = HashMap::from([("app_id".to_string(), "foo".to_string())]); + let ctx = map.into(); let counter = Counter::new( Limit::new( "test_namespace", 10, 60, - vec!["req.method == 'POST'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"), - Default::default(), - ); + vec!["req_method == 'POST'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ), + &ctx, + ) + .expect("counter creation failed!") + .expect("must have a counter"); let error: RedisError = io::Error::new(io::ErrorKind::TimedOut, "That was long!").into(); assert!(error.is_timeout()); diff --git a/limitador/tests/helpers/tests_limiter.rs b/limitador/tests/helpers/tests_limiter.rs index 2bae0c3e..89e631b1 100644 --- a/limitador/tests/helpers/tests_limiter.rs +++ b/limitador/tests/helpers/tests_limiter.rs @@ -1,8 +1,8 @@ use limitador::counter::Counter; use limitador::errors::LimitadorError; -use limitador::limit::{Limit, Namespace}; +use limitador::limit::{Context, Limit, Namespace}; use limitador::{AsyncRateLimiter, CheckResult, RateLimiter}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; // This exposes a struct that wraps both implementations of the rate limiter, // the blocking and the async one. This allows us to avoid duplications in the @@ -70,17 +70,15 @@ impl TestsLimiter { pub async fn is_rate_limited( &self, namespace: &str, - values: &HashMap, + ctx: &Context<'_>, delta: u64, ) -> Result { match &self.limiter_impl { LimiterImpl::Blocking(limiter) => { - limiter.is_rate_limited(&namespace.into(), values, delta) + limiter.is_rate_limited(&namespace.into(), ctx, delta) } LimiterImpl::Async(limiter) => { - limiter - .is_rate_limited(&namespace.into(), values, delta) - .await + limiter.is_rate_limited(&namespace.into(), ctx, delta).await } } } @@ -88,17 +86,15 @@ impl TestsLimiter { pub async fn update_counters( &self, namespace: &str, - values: &HashMap, + ctx: &Context<'_>, delta: u64, ) -> Result<(), LimitadorError> { match &self.limiter_impl { LimiterImpl::Blocking(limiter) => { - limiter.update_counters(&namespace.into(), values, delta) + limiter.update_counters(&namespace.into(), ctx, delta) } LimiterImpl::Async(limiter) => { - limiter - .update_counters(&namespace.into(), values, delta) - .await + limiter.update_counters(&namespace.into(), ctx, delta).await } } } @@ -106,20 +102,17 @@ impl TestsLimiter { pub async fn check_rate_limited_and_update( &self, namespace: &str, - values: &HashMap, + ctx: &Context<'_>, delta: u64, load_counters: bool, ) -> Result { match &self.limiter_impl { - LimiterImpl::Blocking(limiter) => limiter.check_rate_limited_and_update( - &namespace.into(), - values, - delta, - load_counters, - ), + LimiterImpl::Blocking(limiter) => { + limiter.check_rate_limited_and_update(&namespace.into(), ctx, delta, load_counters) + } LimiterImpl::Async(limiter) => { limiter - .check_rate_limited_and_update(&namespace.into(), values, delta, load_counters) + .check_rate_limited_and_update(&namespace.into(), ctx, delta, load_counters) .await } } diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index be458d22..a19eb756 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -220,18 +220,16 @@ mod test { "first_namespace", 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"), + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ), Limit::new( "second_namespace", 20, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"), + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ), ]; for limit in limits { @@ -254,19 +252,17 @@ mod test { "first_namespace", 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); let lim2 = Limit::new( "second_namespace", 20, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); for limit in [&lim1, &lim2] { rate_limiter.add_limit(limit).await; @@ -288,10 +284,9 @@ mod test { "test_namespace", 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; @@ -309,10 +304,9 @@ mod test { "test_namespace", 10, 60, - vec!["req.method == 'GET'"], - Vec::::new(), - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + Vec::default(), + ); rate_limiter.add_limit(&limit).await; @@ -332,19 +326,17 @@ mod test { namespace, 10, 60, - vec!["req.method == 'POST'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'POST'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); let limit_2 = Limit::new( namespace, 5, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit_1).await; rate_limiter.add_limit(&limit_2).await; @@ -361,10 +353,9 @@ mod test { "test_namespace", 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; @@ -379,18 +370,17 @@ mod test { namespace, 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values = HashMap::new(); - values.insert("req.method".to_string(), "GET".to_string()); + values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "1".to_string()); rate_limiter - .update_counters(namespace, &values, 1) + .update_counters(namespace, &values.into(), 1) .await .unwrap(); @@ -415,18 +405,16 @@ mod test { namespace, 10, 60, - vec!["req.method == 'POST'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"), + vec!["req_method == 'POST'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ), Limit::new( namespace, 5, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"), + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ), ]; for limit in limits.iter() { @@ -445,16 +433,22 @@ mod test { let namespace2 = "test_namespace_2"; rate_limiter - .add_limit( - &Limit::new(namespace1, 10, 60, vec!["x == '10'"], vec!["z"]) - .expect("This must be a valid limit!"), - ) + .add_limit(&Limit::new( + namespace1, + 10, + 60, + vec!["x == '10'".try_into().expect("failed parsing!")], + vec!["z".try_into().expect("failed parsing!")], + )) .await; rate_limiter - .add_limit( - &Limit::new(namespace2, 5, 60, vec!["x == '10'"], vec!["z"]) - .expect("This must be a valid limit!"), - ) + .add_limit(&Limit::new( + namespace2, + 5, + 60, + vec!["x == '10'".try_into().expect("failed parsing!")], + vec!["z".try_into().expect("failed parsing!")], + )) .await; rate_limiter.delete_limits(namespace1).await.unwrap(); @@ -469,18 +463,17 @@ mod test { namespace, 5, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values = HashMap::new(); - values.insert("req.method".to_string(), "GET".to_string()); + values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "1".to_string()); rate_limiter - .update_counters(namespace, &values, 1) + .update_counters(namespace, &values.into(), 1) .await .unwrap(); @@ -504,32 +497,32 @@ mod test { namespace, max_hits, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values: HashMap = HashMap::new(); - values.insert("req.method".to_string(), "GET".to_string()); + values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = values.into(); for i in 0..max_hits { assert!( !rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap(), "Must not be limited after {i}" ); rate_limiter - .update_counters(namespace, &values, 1) + .update_counters(namespace, &ctx, 1) .await .unwrap(); } assert!(rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap()); } @@ -542,32 +535,32 @@ mod test { namespace, max_hits, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values: HashMap = HashMap::new(); - values.insert("req.method".to_string(), "GET".to_string()); + values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = values.into(); for i in 0..max_hits { assert!( !rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap(), "Must not be limited after {i}" ); rate_limiter - .update_counters(namespace, &values, 1) + .update_counters(namespace, &ctx, 1) .await .unwrap(); } assert!(rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap()); } @@ -580,18 +573,16 @@ mod test { namespace, max_hits, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"), + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ), Limit::new( namespace, max_hits + 1, 60, - vec!["req.method == 'POST'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"), + vec!["req_method == 'POST'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ), ]; for limit in limits { @@ -599,34 +590,36 @@ mod test { } let mut get_values: HashMap = HashMap::new(); - get_values.insert("req.method".to_string(), "GET".to_string()); + get_values.insert("req_method".to_string(), "GET".to_string()); get_values.insert("app_id".to_string(), "test_app_id".to_string()); + let get_ctx = get_values.into(); let mut post_values: HashMap = HashMap::new(); - post_values.insert("req.method".to_string(), "POST".to_string()); + post_values.insert("req_method".to_string(), "POST".to_string()); post_values.insert("app_id".to_string(), "test_app_id".to_string()); + let post_ctx = post_values.into(); for i in 0..max_hits { assert!( !rate_limiter - .is_rate_limited(namespace, &get_values, 1) + .is_rate_limited(namespace, &get_ctx, 1) .await .unwrap(), "Must not be limited after {i}" ); assert!( !rate_limiter - .is_rate_limited(namespace, &post_values, 1) + .is_rate_limited(namespace, &post_ctx, 1) .await .unwrap(), "Must not be limited after {i}" ); rate_limiter - .check_rate_limited_and_update(namespace, &get_values, 1, false) + .check_rate_limited_and_update(namespace, &get_ctx, 1, false) .await .unwrap(); rate_limiter - .check_rate_limited_and_update(namespace, &post_values, 1, false) + .check_rate_limited_and_update(namespace, &post_ctx, 1, false) .await .unwrap(); } @@ -635,11 +628,11 @@ mod test { tokio::time::sleep(Duration::from_millis(40)).await; assert!(rate_limiter - .is_rate_limited(namespace, &get_values, 1) + .is_rate_limited(namespace, &get_ctx, 1) .await .unwrap()); assert!(!rate_limiter - .is_rate_limited(namespace, &post_values, 1) + .is_rate_limited(namespace, &post_ctx, 1) .await .unwrap()); } @@ -650,31 +643,31 @@ mod test { namespace, 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values: HashMap = HashMap::new(); - values.insert("req.method".to_string(), "GET".to_string()); + values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = values.into(); // Report 5 hits twice. The limit is 10, so the first limited call should be // the third one. for _ in 0..2 { assert!(!rate_limiter - .is_rate_limited(namespace, &values, 5) + .is_rate_limited(namespace, &ctx, 5) .await .unwrap()); rate_limiter - .update_counters(namespace, &values, 5) + .update_counters(namespace, &ctx, 5) .await .unwrap(); } assert!(rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap()); } @@ -686,19 +679,19 @@ mod test { namespace, max, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values: HashMap = HashMap::new(); - values.insert("req.method".to_string(), "GET".to_string()); + values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = values.into(); assert!(rate_limiter - .is_rate_limited(namespace, &values, max + 1) + .is_rate_limited(namespace, &ctx, max + 1) .await .unwrap()) } @@ -710,36 +703,38 @@ mod test { namespace, max_hits, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values: HashMap = HashMap::new(); - values.insert("req.method".to_string(), "GET".to_string()); + values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); for i in 0..max_hits { // Add an extra value that does not apply to the limit on each // iteration. It should not affect. + let mut values = values.clone(); values.insert("does_not_apply".to_string(), i.to_string()); + let ctx = values.into(); assert!( !rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap(), "Must not be limited after {i}" ); rate_limiter - .update_counters(namespace, &values, 1) + .update_counters(namespace, &ctx, 1) .await .unwrap(); } + let ctx = values.into(); assert!(rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap()); } @@ -748,10 +743,11 @@ mod test { rate_limiter: &mut TestsLimiter, ) { let mut values: HashMap = HashMap::new(); - values.insert("req.method".to_string(), "GET".to_string()); + values.insert("req_method".to_string(), "GET".to_string()); + let ctx = values.into(); assert!(!rate_limiter - .is_rate_limited("test_namespace", &values, 1) + .is_rate_limited("test_namespace", &ctx, 1) .await .unwrap()); } @@ -765,20 +761,20 @@ mod test { namespace, 0, // So reporting 1 more would not be allowed 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; // Notice that does not match because the method is "POST". let mut values: HashMap = HashMap::new(); - values.insert("req.method".to_string(), "POST".to_string()); + values.insert("req_method".to_string(), "POST".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = values.into(); assert!(!rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap()); } @@ -790,18 +786,18 @@ mod test { namespace, 0, // So reporting 1 more would not be allowed 60, - Vec::::new(), // unconditional - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + Vec::default(), // unconditional + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values: HashMap = HashMap::new(); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = values.into(); assert!(rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap()); } @@ -814,21 +810,21 @@ mod test { namespace, max_hits, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values: HashMap = HashMap::new(); - values.insert("req.method".to_string(), "GET".to_string()); + values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = values.into(); for _ in 0..max_hits { assert!( !rate_limiter - .check_rate_limited_and_update(namespace, &values, 1, false) + .check_rate_limited_and_update(namespace, &ctx, 1, false) .await .unwrap() .limited @@ -837,7 +833,7 @@ mod test { assert!( rate_limiter - .check_rate_limited_and_update(namespace, &values, 1, false) + .check_rate_limited_and_update(namespace, &ctx, 1, false) .await .unwrap() .limited @@ -852,20 +848,20 @@ mod test { namespace, max_hits, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values: HashMap = HashMap::new(); - values.insert("req.method".to_string(), "GET".to_string()); + values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = values.into(); for hit in 0..max_hits { let result = rate_limiter - .check_rate_limited_and_update(namespace, &values, 1, true) + .check_rate_limited_and_update(namespace, &ctx, 1, true) .await .unwrap(); assert!(!result.limited); @@ -880,7 +876,7 @@ mod test { } let result = rate_limiter - .check_rate_limited_and_update(namespace, &values, 1, true) + .check_rate_limited_and_update(namespace, &ctx, 1, true) .await .unwrap(); assert!(result.limited); @@ -903,21 +899,21 @@ mod test { namespace, 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values: HashMap = HashMap::new(); values.insert("app_id".to_string(), "test_app_id".to_string()); // Does not match the limit defined - values.insert("req.method".to_string(), "POST".to_string()); + values.insert("req_method".to_string(), "POST".to_string()); + let ctx = values.into(); assert!( !rate_limiter - .check_rate_limited_and_update(namespace, &values, 1, false) + .check_rate_limited_and_update(namespace, &ctx, 1, false) .await .unwrap() .limited @@ -933,19 +929,19 @@ mod test { namespace, 0, // So reporting 1 more would not be allowed 60, - Vec::::new(), // unconditional - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + Vec::default(), // unconditional + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values: HashMap = HashMap::new(); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = values.into(); assert!( rate_limiter - .check_rate_limited_and_update(namespace, &values, 1, false) + .check_rate_limited_and_update(namespace, &ctx, 1, false) .await .unwrap() .limited @@ -962,24 +958,27 @@ mod test { namespace, max_hits, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values = HashMap::new(); - values.insert("req.method".to_string(), "GET".to_string()); + values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "1".to_string()); + let ctx = values.into(); rate_limiter - .update_counters(namespace, &values, hits_app_1) + .update_counters(namespace, &ctx, hits_app_1) .await .unwrap(); + let mut values = HashMap::new(); + values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "2".to_string()); + let ctx = values.into(); rate_limiter - .update_counters(namespace, &values, hits_app_2) + .update_counters(namespace, &ctx, hits_app_2) .await .unwrap(); @@ -1020,10 +1019,9 @@ mod test { "test_namespace", 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; @@ -1042,18 +1040,18 @@ mod test { namespace, 10, limit_time, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values = HashMap::new(); - values.insert("req.method".to_string(), "GET".to_string()); + values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "1".to_string()); + let ctx = values.into(); rate_limiter - .update_counters(namespace, &values, 1) + .update_counters(namespace, &ctx, 1) .await .unwrap(); @@ -1068,19 +1066,17 @@ mod test { "first_namespace", 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); let second_limit = Limit::new( "second_namespace", 20, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter .configure_with(vec![first_limit.clone(), second_limit.clone()]) @@ -1109,18 +1105,18 @@ mod test { namespace, max_value, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit).await; let mut values = HashMap::new(); - values.insert("req.method".to_string(), "GET".to_string()); + values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "1".to_string()); + let ctx = values.into(); rate_limiter - .update_counters(namespace, &values, hits_to_report) + .update_counters(namespace, &ctx, hits_to_report) .await .unwrap(); @@ -1149,19 +1145,17 @@ mod test { namespace, 10, 1, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); let limit_to_be_deleted = Limit::new( namespace, 20, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); for limit in [&limit_to_be_kept, &limit_to_be_deleted].iter() { rate_limiter.add_limit(limit).await; @@ -1185,19 +1179,17 @@ mod test { namespace, 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); let limit_update = Limit::new( namespace, 20, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); rate_limiter.add_limit(&limit_orig).await; @@ -1219,28 +1211,25 @@ mod test { namespace, 10, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); let limit_2 = Limit::new( namespace, 20, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); let mut limit_3 = Limit::new( namespace, 20, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); limit_3.set_name("Name is irrelevant too".to_owned()); assert!(rate_limiter.add_limit(&limit_1).await); @@ -1267,31 +1256,31 @@ mod test { namespace, max_hits, 60, - vec!["req.method == 'GET'"], - vec!["app_id"], - ) - .expect("This must be a valid limit!"); + vec!["req_method == 'GET'".try_into().expect("failed parsing!")], + vec!["app_id".try_into().expect("failed parsing!")], + ); for rate_limiter in rate_limiters.iter() { rate_limiter.add_limit(&limit).await; } let mut values: HashMap = HashMap::new(); - values.insert("req.method".to_string(), "GET".to_string()); + values.insert("req_method".to_string(), "GET".to_string()); values.insert("app_id".to_string(), "test_app_id".to_string()); + let ctx = values.into(); for i in 0..max_hits { // Alternate between the two rate limiters let rate_limiter = rate_limiters.get((i % 2) as usize).unwrap(); assert!( !rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap(), "Must not be limited after {i}" ); rate_limiter - .update_counters(namespace, &values, 1) + .update_counters(namespace, &ctx, 1) .await .unwrap(); } @@ -1303,7 +1292,7 @@ mod test { || async { let rate_limiter = rate_limiters.first().unwrap(); rate_limiter - .is_rate_limited(namespace, &values, 1) + .is_rate_limited(namespace, &ctx, 1) .await .unwrap() }