Skip to content

Commit

Permalink
Add rate limiter config and integration test
Browse files Browse the repository at this point in the history
* Make period optional, default to 1sec.
* Reject configs with period < 100ms since we can't
  really guarantee correctness at lower resolutions.

Work on #5
  • Loading branch information
iffyio committed Aug 11, 2020
1 parent 8b488a4 commit 856e1d5
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 21 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ edition = "2018"

[dependencies]
clap = "2.33.0"
humantime-serde = "1.0.0"
prometheus = { version = "0.9", default-features = false }
serde = { version = "1.0.104", features = ["derive"] }
serde_yaml = "0.8.11"
Expand Down
2 changes: 2 additions & 0 deletions src/extensions/filter_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ pub trait Filter: Send + Sync {
pub enum Error {
NotFound(String),
FieldInvalid { field: String, reason: String },
DeserializeFailed(String),
}

impl fmt::Display for Error {
Expand All @@ -78,6 +79,7 @@ impl fmt::Display for Error {
Error::FieldInvalid { field, reason } => {
write!(f, "field {} is invalid: {}", field, reason)
}
Error::DeserializeFailed(reason) => write!(f, "{}", reason),
}
}
}
Expand Down
62 changes: 45 additions & 17 deletions src/extensions/filters/local_rate_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,71 @@
*/

use crate::config::EndPoint;
use crate::extensions::Filter;
use crate::extensions::{Error, Filter, FilterFactory};
use serde::{Deserialize, Serialize};
use serde_yaml::Value;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::oneshot::{channel, Sender};
use tokio::time::{self, Instant};

#[allow(dead_code)]
/// Config represents a RateLimitFilter's configuration.
#[derive(Serialize, Deserialize, Debug)]
struct Config {
/// max_packets is the maximum number of packets allowed
/// to be forwarded by the rate limiter in a given duration.
max_packets: usize,
/// period is the duration during which max_packets applies.
period: Duration,
/// If none is provided, it defaults to 1 second.
#[serde(with = "humantime_serde")]
period: Option<Duration>,
}

#[allow(dead_code)]
/// RateLimitFilter is a filter that implements packet rate limiting
/// based on the token-bucket algorithm.
/// Creates instances of RateLimitFilter.
pub struct RateLimitFilterFactory;

impl RateLimitFilterFactory {
pub fn new() -> Self {
RateLimitFilterFactory {}
}
}

/// A filter that implements rate limiting on packets based on
/// the token-bucket algorithm.
/// Packets that violate the rate limit are dropped.
/// It only applies rate limiting on packets that are destined for the
/// proxy's endpoints. All other packets flow through the filter untouched.
pub struct RateLimitFilter {
/// max_tokens is the number of tokens in a full bucket.
max_tokens: usize,
struct RateLimitFilter {
/// available_tokens is how many tokens are left in the bucket any
/// any given moment.
available_tokens: Arc<AtomicUsize>,
/// shutdown_tx signals the spawned token refill future to exit.
shutdown_tx: Option<Sender<()>>,
}

impl FilterFactory for RateLimitFilterFactory {
fn name(&self) -> String {
"quilkin.extensions.filters.local_rate_limit.alphav1.LocalRateLimit".into()
}

fn create_from_config(&self, config_value: &Value) -> Result<Box<dyn Filter>, Error> {
let config: Config = serde_yaml::to_string(config_value)
.and_then(|raw_config| serde_yaml::from_str(raw_config.as_str()))
.map_err(|err| Error::DeserializeFailed(err.to_string()))?;

match config.period {
Some(period) if period.lt(&Duration::from_millis(100)) => Err(Error::FieldInvalid {
field: "period".into(),
reason: format!("value must be at least 100ms"),
}),
_ => Ok(Box::new(RateLimitFilter::new(config))),
}
}
}

impl RateLimitFilter {
#[allow(dead_code)]
/// new returns a new RateLimitFilter. It spawns a future in the background
/// that periodically refills the rate limiter's tokens.
fn new(config: Config) -> Self {
Expand All @@ -59,7 +88,7 @@ impl RateLimitFilter {
let tokens = Arc::new(AtomicUsize::new(config.max_packets));

let max_tokens = config.max_packets;
let period = config.period;
let period = config.period.unwrap_or(Duration::from_secs(1));
let available_tokens = tokens.clone();
let _ = tokio::spawn(async move {
let mut interval = time::interval_at(Instant::now() + period, period);
Expand All @@ -85,7 +114,6 @@ impl RateLimitFilter {
});

RateLimitFilter {
max_tokens: config.max_packets,
available_tokens: tokens,
shutdown_tx: Some(shutdown_tx),
}
Expand Down Expand Up @@ -168,7 +196,7 @@ mod tests {
// Test that we always start with the max number of tokens available.
let r = RateLimitFilter::new(Config {
max_packets: 3,
period: Duration::from_millis(100),
period: Some(Duration::from_millis(100)),
});

assert_eq!(r.acquire_token(), Some(()));
Expand All @@ -181,7 +209,7 @@ mod tests {
async fn token_exhaustion_and_refill() {
let r = RateLimitFilter::new(Config {
max_packets: 2,
period: Duration::from_millis(100),
period: Some(Duration::from_millis(100)),
});

// Exhaust tokens
Expand All @@ -204,7 +232,7 @@ mod tests {

let r = RateLimitFilter::new(Config {
max_packets: 3,
period: Duration::from_millis(100),
period: Some(Duration::from_millis(100)),
});

// Use up some of the tokens.
Expand All @@ -224,7 +252,7 @@ mod tests {
async fn filter_with_no_available_tokens() {
let r = RateLimitFilter::new(Config {
max_packets: 0,
period: Duration::from_millis(100),
period: Some(Duration::from_millis(100)),
});

// Check that other routes are not affected.
Expand Down Expand Up @@ -260,7 +288,7 @@ mod tests {
async fn filter_with_available_tokens() {
let r = RateLimitFilter::new(Config {
max_packets: 1,
period: Duration::from_millis(100),
period: Some(Duration::from_millis(100)),
});

assert_eq!(
Expand Down
1 change: 1 addition & 0 deletions src/extensions/filters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

pub use debug_filter::{DebugFilter, DebugFilterFactory};
pub use local_rate_limit::RateLimitFilterFactory;

mod debug_filter;
mod local_rate_limit;
1 change: 1 addition & 0 deletions src/extensions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ mod filter_chain;
pub fn default_registry(base: &Logger) -> FilterRegistry {
let mut fr = FilterRegistry::new();
fr.insert(filters::DebugFilterFactory::new(base));
fr.insert(filters::RateLimitFilterFactory::new());
fr
}
1 change: 0 additions & 1 deletion src/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ impl Server {
"origin" => packet.dest(),
"contents" => String::from_utf8(packet.contents().clone()).unwrap(),
);

if let Some(data) =
chain.local_send_filter(packet.dest(), packet.contents().to_vec())
{
Expand Down
8 changes: 5 additions & 3 deletions src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,11 @@ pub async fn echo_server() -> SocketAddr {
let mut socket = ephemeral_socket().await;
let addr = socket.local_addr().unwrap();
tokio::spawn(async move {
let mut buf = vec![0; 1024];
let (size, sender) = socket.recv_from(&mut buf).await.unwrap();
socket.send_to(&buf[..size], sender).await.unwrap();
loop {
let mut buf = vec![0; 1024];
let (size, sender) = socket.recv_from(&mut buf).await.unwrap();
socket.send_to(&buf[..size], sender).await.unwrap();
}
});
addr
}
Expand Down
77 changes: 77 additions & 0 deletions tests/local_rate_limit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright 2020 Google LLC All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

extern crate quilkin;

#[cfg(test)]
mod tests {
use std::net::{IpAddr, Ipv4Addr, SocketAddr};

use quilkin::config::{Config, ConnectionConfig, EndPoint, Filter, Local};
use quilkin::extensions::filters::RateLimitFilterFactory;
use quilkin::extensions::{default_registry, FilterFactory};
use quilkin::test_utils::{echo_server, logger, recv_multiple_packets, run_proxy};

#[tokio::test]
async fn local_rate_limit_filter() {
let base_logger = logger();

let yaml = "
max_packets: 2
period: 1s
";
let echo = echo_server().await;

let server_port = 12346;
let server_config = Config {
local: Local { port: server_port },
filters: vec![Filter {
name: RateLimitFilterFactory::new().name(),
config: serde_yaml::from_str(yaml).unwrap(),
}],
connections: ConnectionConfig::Server {
endpoints: vec![EndPoint {
name: "server".to_string(),
address: echo,
connection_ids: vec![],
}],
},
};

let close_server = run_proxy(&base_logger, default_registry(&base_logger), server_config);

let (mut recv_chan, mut send) = recv_multiple_packets(&base_logger).await;

let server_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), server_port);

for _ in 0..3 {
send.send_to("hello".as_bytes(), &server_addr)
.await
.unwrap();
}

for _ in 0..2 {
assert_eq!(recv_chan.recv().await.unwrap(), "hello");
}

// Allow enough time to have received any response.
tokio::time::delay_for(std::time::Duration::from_millis(100)).await;
// Check that we do not get any response.
assert!(recv_chan.try_recv().is_err());

close_server();
}
}

0 comments on commit 856e1d5

Please sign in to comment.