diff --git a/src/lib.rs b/src/lib.rs index 0458bc7..954a912 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,7 +24,7 @@ pub struct RuntimeOptions pub static mut OPTIONS: RuntimeOptions = RuntimeOptions { debug: false, debug_timestamp: false }; -pub fn debug(msg: String, context: Option) +pub fn debug(msg: String, context: Option<&str>) { unsafe { if OPTIONS.debug == false { return } } diff --git a/src/server/api/stats.rs b/src/server/api/stats.rs index a5b401c..3912bd1 100644 --- a/src/server/api/stats.rs +++ b/src/server/api/stats.rs @@ -85,14 +85,14 @@ impl ApiRequest for StatsDigest Ok(p) => p, Err(e) => { - crate::debug(format!("{} deserialising POST payload",e), Some("Stats Digest".to_string())); + crate::debug(format!("{} deserialising POST payload",e), Some("Stats Digest")); return StatusCode::BAD_REQUEST } } } Err(e) => { - crate::debug(format!("{} deserialising POST payload",e), Some("Stats Digest".to_string())); + crate::debug(format!("{} deserialising POST payload",e), Some("Stats Digest")); return StatusCode::BAD_REQUEST } }; diff --git a/src/server/https.rs b/src/server/https.rs index 822fa9c..b996d08 100644 --- a/src/server/https.rs +++ b/src/server/https.rs @@ -108,7 +108,7 @@ impl Server StatsSaveTask::new ( stats.clone(), - schedule_from_option(config.stats.digest_schedule.clone()) + schedule_from_option(config.stats.save_schedule.clone()) ) ) ); diff --git a/src/server/stats/hits.rs b/src/server/stats/hits.rs index 2e2353b..ee451aa 100644 --- a/src/server/stats/hits.rs +++ b/src/server/stats/hits.rs @@ -123,7 +123,7 @@ pub async fn process_hit "\nTotal stats time: {} s (Passthrough)\nCompute stats time: {} s (Passthrough)", start_time.elapsed().as_secs_f64(), compute_start_time.elapsed().as_secs_f64() - ), Some("PERFORMANCE".to_string())); + ), Some("PERFORMANCE")); return } @@ -141,7 +141,7 @@ pub async fn process_hit } }; - crate::debug(format!("{:?}", hit), Some("Statistics".to_string())); + crate::debug(format!("{:?}", hit), Some("Statistics")); stats.hits.insert(hash, hit); @@ -150,7 +150,7 @@ pub async fn process_hit "\nTotal stats time: {} s\nCompute stats time: {} s", start_time.elapsed().as_secs_f64(), compute_start_time.elapsed().as_secs_f64() - ), Some("PERFORMANCE".to_string())); + ), Some("PERFORMANCE")); } /// Gathers [Hit]s both from disk and those cached in [HitStats] diff --git a/src/server/stats/mod.rs b/src/server/stats/mod.rs index 4084885..5be60ed 100644 --- a/src/server/stats/mod.rs +++ b/src/server/stats/mod.rs @@ -1,4 +1,4 @@ -use std::{fs::create_dir, sync::Arc}; +use std::{collections::HashMap, fs::create_dir, sync::Arc}; use axum::async_trait; use chrono::{DateTime, Utc}; @@ -13,7 +13,8 @@ pub mod hits; pub mod digest; pub mod file; -/// A task to periodically save HitStats to disk +/// A task to periodically save HitStats to disk, clearing +/// the HitStats memory. /// See [crate::task::Task] and [crate::task::TaskPool] pub struct StatsSaveTask { @@ -48,7 +49,7 @@ impl Task for StatsSaveTask { let config = Config::load_or_default(CONFIG_PATH); { - let stats = self.state.lock().await; + let mut stats = self.state.lock().await; if !std::path::Path::new(&config.stats.path).exists() { @@ -62,6 +63,7 @@ impl Task for StatsSaveTask let mut file = StatsFile::new(); file.load(&stats); file.write_bytes(); + stats.hits = HashMap::new(); } self.schedule = schedule_from_option(config.stats.save_schedule.clone()); diff --git a/src/server/throttle.rs b/src/server/throttle.rs index 8a65d75..d73dbfb 100644 --- a/src/server/throttle.rs +++ b/src/server/throttle.rs @@ -19,14 +19,21 @@ pub struct Request hash: [u8; 64] } +/// sha512 an ip and uri impl Request { pub fn new(ip: Ipv4Addr, uri: &str) -> Request { Request { hash: sha512(&[uri.as_bytes(), &ip.octets()].concat()) } } + + pub fn hash(&self) -> [u8; 64] + { + return self.hash + } } +/// Represent a unique [Request] (ip+uri hash) repeated count times pub struct RequestData { count: u32, @@ -42,6 +49,7 @@ impl RequestData } } +/// Detect repeated [Request]s and reflect if block for [IpThrottler::timeout_millis] pub struct IpThrottler { requests_from: HashMap, @@ -58,22 +66,26 @@ impl IpThrottler IpThrottler { requests_from: HashMap::new(), - max_requests_per_second: max_requests_per_second, - timeout_millis: timeout_millis, + max_requests_per_second, + timeout_millis, clear_period: Duration::from_secs(clear_period_seconds), last_clear: Instant::now() } } + /// Free hashmap (= HashMap::new()) if [IpThrottler::clear_period] has elapsed pub fn check_clear(&mut self) { if self.last_clear.elapsed() > self.clear_period { - self.requests_from.clear(); + self.requests_from = HashMap::new(); self.last_clear = Instant::now(); } } + /// Record hit counts for unique [Request]s over a time window of + /// [IpThrottler::clear_period]s. If more than [IpThrottler::max_requests_per_second] + /// the [Request] is marked as in [RequestData::timeout] for [IpThrottler::timeout_millis]ms. pub fn is_limited(&mut self, addr: SocketAddr, uri: &str) -> bool { let ip = addr.ip(); @@ -86,17 +98,15 @@ impl IpThrottler } let request = Request::new(ipv4, uri); - - println!("{:?}", request); let requests = if self.requests_from.contains_key(&request) { - self.requests_from[&request].clone() + &self.requests_from[&request] } else { self.requests_from.insert(request.clone(), RequestData {count: 0 as u32, last_request_time: Instant::now(), timeout: false}); - self.requests_from[&request].clone() + &self.requests_from[&request] }; let time = requests.last_request_time.elapsed().as_millis(); @@ -128,6 +138,8 @@ impl IpThrottler } } +/// Reflects any [Request]s in timeout (see [IpThrottler::is_limited]) as +/// [StatusCode::TOO_MANY_REQUESTS]. pub async fn handle_throttle ( ConnectInfo(addr): ConnectInfo, @@ -142,15 +154,15 @@ pub async fn handle_throttle throttler.check_clear(); if throttler.is_limited(addr, &request.uri().to_string()) { - crate::debug(format!("Denying: {} @/{}", addr, request.uri().to_string()), None); - crate::debug(format!("Serve time: {} s", serve_start.elapsed().as_secs_f64()), Some("PERFORMANCE".to_string())); + crate::debug(format!("Denying: {} @/{}", addr, request.uri().to_string()), Some("THROTTLE")); + crate::debug(format!("Serve time: {} s", serve_start.elapsed().as_secs_f64()), Some("PERFORMANCE")); Err(StatusCode::TOO_MANY_REQUESTS) } else { crate::debug(format!("Allowing: {} @/{}", addr, request.uri().to_string()), None); let response = next.run(request).await; - crate::debug(format!("Serve time: {} s", serve_start.elapsed().as_secs_f64()), Some("PERFORMANCE".to_string())); + crate::debug(format!("Serve time: {} s", serve_start.elapsed().as_secs_f64()), Some("PERFORMANCE")); Ok(response) } } diff --git a/tests/test_throttler.rs b/tests/test_throttler.rs new file mode 100644 index 0000000..a627a0a --- /dev/null +++ b/tests/test_throttler.rs @@ -0,0 +1,60 @@ +mod common; + +#[cfg(test)] +mod test_throttle +{ + use std::net::{Ipv4Addr, SocketAddr}; + + use busser::server::throttle::{IpThrottler, Request}; + use openssl::sha::sha512; + + #[test] + pub fn test_request() + { + let r1 = Request::new + ( + Ipv4Addr::new(127, 0, 0, 0), + "/index.html" + ); + + let r2 = Request::new + ( + Ipv4Addr::new(127, 0, 0, 0), + "/page.html" + ); + + let r3 = Request::new + ( + Ipv4Addr::new(127, 1, 1, 1), + "/index.html" + ); + + assert_ne!(r1, r2); + assert_ne!(r1, r3); + assert_ne!(r2, r3); + + assert_eq!(r1.hash(), sha512(&["/index.html".as_bytes(), &Ipv4Addr::new(127, 0, 0, 0).octets()].concat())); + assert_eq!(r2.hash(), sha512(&["/page.html".as_bytes(), &Ipv4Addr::new(127, 0, 0, 0).octets()].concat())); + assert_eq!(r3.hash(), sha512(&["/index.html".as_bytes(), &Ipv4Addr::new(127, 1, 1, 1).octets()].concat())); + } + + #[test] + pub fn test_throttler() + { + let mut throttle = IpThrottler::new(1e-9, 5000, 3600); + let ip = Ipv4Addr::new(127, 0, 0, 0); + let path = "/index.html"; + assert_eq!(throttle.is_limited(SocketAddr::new(std::net::IpAddr::V4(ip), 80), path), false); + assert_eq!(throttle.is_limited(SocketAddr::new(std::net::IpAddr::V4(ip), 80), path), true); + throttle.check_clear(); + assert_eq!(throttle.is_limited(SocketAddr::new(std::net::IpAddr::V4(ip), 80), path), true); + + let mut throttle = IpThrottler::new(1e-9, 5000, 0); + let ip = Ipv4Addr::new(127, 0, 0, 0); + let path = "/index.html"; + assert_eq!(throttle.is_limited(SocketAddr::new(std::net::IpAddr::V4(ip), 80), path), false); + assert_eq!(throttle.is_limited(SocketAddr::new(std::net::IpAddr::V4(ip), 80), path), true); + throttle.check_clear(); + assert_eq!(throttle.is_limited(SocketAddr::new(std::net::IpAddr::V4(ip), 80), path), false); + } +} \ No newline at end of file