Skip to content

Commit

Permalink
use hash of ip and route
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerboa-app committed Feb 25, 2024
1 parent d944dd2 commit 8da8964
Showing 1 changed file with 36 additions and 17 deletions.
53 changes: 36 additions & 17 deletions src/web/throttle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,49 @@ use std::collections::HashMap;
use std::net::{SocketAddr, Ipv4Addr, IpAddr};
use std::time::{Instant, Duration};
use std::sync::Arc;
use openssl::sha::{self, sha512};
use tokio::sync::Mutex;

use axum::
{
http::{Request, StatusCode},
http::{self, StatusCode},
response::Response,
extract::{State, ConnectInfo},
middleware::Next
};

pub struct Requests
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub struct Request
{
hash: [u8; 64]
}

impl Request
{
pub fn new(ip: Ipv4Addr, uri: &str) -> Request
{
Request { hash: sha512(&[uri.as_bytes(), &ip.octets()].concat()) }
}
}

pub struct RequestData
{
count: u32,
last_request_time: Instant,
timeout: bool
}

impl Requests
impl RequestData
{
pub fn clone(&self) -> Requests
pub fn clone(&self) -> RequestData
{
Requests { count: self.count.clone(), last_request_time: self.last_request_time.clone(), timeout: false }
RequestData { count: self.count.clone(), last_request_time: self.last_request_time.clone(), timeout: false }
}
}

pub struct IpThrottler
{
requests_from: HashMap<Ipv4Addr, Requests>,
requests_from: HashMap<Request, RequestData>,
max_requests_per_second: f64,
timeout_millis: u128,
clear_period: Duration,
Expand Down Expand Up @@ -59,7 +74,7 @@ impl IpThrottler
}
}

pub fn is_limited(&mut self, addr: SocketAddr) -> bool
pub fn is_limited(&mut self, addr: SocketAddr, uri: &str) -> bool
{
let ip = addr.ip();
let ipv4: Ipv4Addr;
Expand All @@ -69,15 +84,19 @@ impl IpThrottler
IpAddr::V4(ip4) => {ipv4 = ip4}
IpAddr::V6(_ip6) => {return true}
}

let request = Request::new(ipv4, uri);

println!("{:?}", request);

let requests = if self.requests_from.contains_key(&ipv4)
let requests = if self.requests_from.contains_key(&request)
{
self.requests_from[&ipv4].clone()
self.requests_from[&request].clone()
}
else
{
self.requests_from.insert(ipv4, Requests {count: 0 as u32, last_request_time: Instant::now(), timeout: false});
self.requests_from[&ipv4].clone()
self.requests_from.insert(request.clone(), RequestData {count: 0 as u32, last_request_time: Instant::now(), timeout: false});
self.requests_from[&request].clone()
};

let time = requests.last_request_time.elapsed().as_millis();
Expand All @@ -87,23 +106,23 @@ impl IpThrottler
{
if time < self.timeout_millis
{
*self.requests_from.get_mut(&ipv4).unwrap() = Requests {count: requests.count, last_request_time: requests.last_request_time, timeout: true};
*self.requests_from.get_mut(&request).unwrap() = RequestData {count: requests.count, last_request_time: requests.last_request_time, timeout: true};
return true
}
else
{
*self.requests_from.get_mut(&ipv4).unwrap() = Requests {count: 0, last_request_time: Instant::now(), timeout: false};
*self.requests_from.get_mut(&request).unwrap() = RequestData {count: 0, last_request_time: Instant::now(), timeout: false};
return false
}
}

if time < 1000
{
*self.requests_from.get_mut(&ipv4).unwrap() = Requests {count: requests.count+1, last_request_time: requests.last_request_time, timeout: false};
*self.requests_from.get_mut(&request).unwrap() = RequestData {count: requests.count+1, last_request_time: requests.last_request_time, timeout: false};
}
else
{
*self.requests_from.get_mut(&ipv4).unwrap() = Requests {count: 0, last_request_time: Instant::now(), timeout: false};
*self.requests_from.get_mut(&request).unwrap() = RequestData {count: 0, last_request_time: Instant::now(), timeout: false};
}
return false
}
Expand All @@ -113,15 +132,15 @@ pub async fn handle_throttle<B>
(
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(state): State<Arc<Mutex<IpThrottler>>>,
request: Request<B>,
request: http::Request<B>,
next: Next<B>
) -> Result<Response, StatusCode>
{
let serve_start = Instant::now();
{
let mut throttler = state.lock().await;
throttler.check_clear();
if throttler.is_limited(addr)
if throttler.is_limited(addr, &request.uri().to_string())
{
crate::debug(format!("Denying: {} @/{}", addr, request.uri().to_string()), None);
crate::debug(format!("Serve time: {} s", serve_start.elapsed().as_secs_f64()), Some("PERFORMANCE".to_string()));
Expand Down

0 comments on commit 8da8964

Please sign in to comment.