Skip to content

Commit

Permalink
feat: make ban service generic for all trackers
Browse files Browse the repository at this point in the history
All UDP tracker will share the same service. In the future, the HTTP
trackers can also use it.

The service was not include inside the tracker (easy solution) becuase
the Tracker type is too big. It has became the app container. In fact,
we want to reduce it in the future by extracting the services outside of
the tracker: stats, whitelist, etc. Those services will be instantiate
independently in the future in the app bootstrap.
  • Loading branch information
josecelano committed Jan 7, 2025
1 parent 6f9b44c commit d9cfb38
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 37 deletions.
14 changes: 11 additions & 3 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@
//! - Tracker REST API: the tracker API can be enabled/disabled.
use std::sync::Arc;

use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use torrust_tracker_configuration::Configuration;
use tracing::instrument;

use crate::bootstrap::jobs::{health_check_api, http_tracker, torrent_cleanup, tracker_apis, udp_tracker};
use crate::servers::registar::Registar;
use crate::servers::udp::server::banning::BanService;
use crate::{core, servers};

/// # Panics
Expand All @@ -37,8 +39,12 @@ use crate::{core, servers};
///
/// - Can't retrieve tracker keys from database.
/// - Can't load whitelist from database.
#[instrument(skip(config, tracker))]
pub async fn start(config: &Configuration, tracker: Arc<core::Tracker>) -> Vec<JoinHandle<()>> {
#[instrument(skip(config, tracker, ban_service))]
pub async fn start(
config: &Configuration,
tracker: Arc<core::Tracker>,
ban_service: Arc<RwLock<BanService>>,
) -> Vec<JoinHandle<()>> {
if config.http_api.is_none()
&& (config.udp_trackers.is_none() || config.udp_trackers.as_ref().map_or(true, std::vec::Vec::is_empty))
&& (config.http_trackers.is_none() || config.http_trackers.as_ref().map_or(true, std::vec::Vec::is_empty))
Expand Down Expand Up @@ -75,7 +81,9 @@ pub async fn start(config: &Configuration, tracker: Arc<core::Tracker>) -> Vec<J
udp_tracker_config.bind_address
);
} else {
jobs.push(udp_tracker::start_job(udp_tracker_config, tracker.clone(), registar.give_form()).await);
jobs.push(
udp_tracker::start_job(udp_tracker_config, tracker.clone(), ban_service.clone(), registar.give_form()).await,
);
}
}
} else {
Expand Down
9 changes: 7 additions & 2 deletions src/bootstrap/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//! 4. Initialize the domain tracker.
use std::sync::Arc;

use tokio::sync::RwLock;
use torrust_tracker_clock::static_time;
use torrust_tracker_configuration::validator::Validator;
use torrust_tracker_configuration::Configuration;
Expand All @@ -22,6 +23,8 @@ use super::config::initialize_configuration;
use crate::bootstrap;
use crate::core::services::tracker_factory;
use crate::core::Tracker;
use crate::servers::udp::server::banning::BanService;
use crate::servers::udp::server::launcher::MAX_CONNECTION_ID_ERRORS_PER_IP;
use crate::shared::crypto::ephemeral_instance_keys;
use crate::shared::crypto::keys::{self, Keeper as _};

Expand All @@ -32,7 +35,7 @@ use crate::shared::crypto::keys::{self, Keeper as _};
/// Setup can file if the configuration is invalid.
#[must_use]
#[instrument(skip())]
pub fn setup() -> (Configuration, Arc<Tracker>) {
pub fn setup() -> (Configuration, Arc<Tracker>, Arc<RwLock<BanService>>) {
#[cfg(not(test))]
check_seed();

Expand All @@ -44,9 +47,11 @@ pub fn setup() -> (Configuration, Arc<Tracker>) {

let tracker = initialize_with_configuration(&configuration);

let ban_service = Arc::new(RwLock::new(BanService::new(MAX_CONNECTION_ID_ERRORS_PER_IP)));

tracing::info!("Configuration:\n{}", configuration.clone().mask_secrets().to_json());

(configuration, tracker)
(configuration, tracker, ban_service)
}

/// checks if the seed is the instance seed in production.
Expand Down
13 changes: 10 additions & 3 deletions src/bootstrap/jobs/udp_tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
//! > for the configuration options.
use std::sync::Arc;

use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use torrust_tracker_configuration::UdpTracker;
use tracing::instrument;

use crate::core;
use crate::servers::registar::ServiceRegistrationForm;
use crate::servers::udp::server::banning::BanService;
use crate::servers::udp::server::spawner::Spawner;
use crate::servers::udp::server::Server;
use crate::servers::udp::UDP_TRACKER_LOG_TARGET;
Expand All @@ -29,13 +31,18 @@ use crate::servers::udp::UDP_TRACKER_LOG_TARGET;
/// It will panic if the task did not finish successfully.
#[must_use]
#[allow(clippy::async_yields_async)]
#[instrument(skip(config, tracker, form))]
pub async fn start_job(config: &UdpTracker, tracker: Arc<core::Tracker>, form: ServiceRegistrationForm) -> JoinHandle<()> {
#[instrument(skip(config, tracker, ban_service, form))]
pub async fn start_job(
config: &UdpTracker,
tracker: Arc<core::Tracker>,
ban_service: Arc<RwLock<BanService>>,
form: ServiceRegistrationForm,
) -> JoinHandle<()> {
let bind_to = config.bind_address;
let cookie_lifetime = config.cookie_lifetime;

let server = Server::new(Spawner::new(bind_to))
.start(tracker, form, cookie_lifetime)
.start(tracker, ban_service, form, cookie_lifetime)
.await
.expect("it should be able to start the udp tracker");

Expand Down
4 changes: 2 additions & 2 deletions src/console/profiling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ pub async fn run() {
return;
};

let (config, tracker) = bootstrap::app::setup();
let (config, tracker, ban_service) = bootstrap::app::setup();

let jobs = app::start(&config, tracker).await;
let jobs = app::start(&config, tracker, ban_service).await;

// Run the tracker for a fixed duration
let run_duration = sleep(Duration::from_secs(duration_secs));
Expand Down
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use torrust_tracker_lib::{app, bootstrap};

#[tokio::main]
async fn main() {
let (config, tracker) = bootstrap::app::setup();
let (config, tracker, ban_service) = bootstrap::app::setup();

let jobs = app::start(&config, tracker).await;
let jobs = app::start(&config, tracker, ban_service).await;

// handle the signals
tokio::select! {
Expand Down
11 changes: 3 additions & 8 deletions src/servers/udp/server/banning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,21 @@ use std::net::IpAddr;

use bloom::{CountingBloomFilter, ASMS};
use tokio::time::Instant;
use url::Url;

use crate::servers::udp::UDP_TRACKER_LOG_TARGET;

pub struct BanService {
max_connection_id_errors_per_ip: u32,
fuzzy_error_counter: CountingBloomFilter,
accurate_error_counter: HashMap<IpAddr, u32>,
local_addr: Url,
last_connection_id_errors_reset: Instant,
}

impl BanService {
#[must_use]
pub fn new(max_connection_id_errors_per_ip: u32, local_addr: Url) -> Self {
pub fn new(max_connection_id_errors_per_ip: u32) -> Self {
Self {
max_connection_id_errors_per_ip,
local_addr,
fuzzy_error_counter: CountingBloomFilter::with_rate(4, 0.01, 100),
accurate_error_counter: HashMap::new(),
last_connection_id_errors_reset: tokio::time::Instant::now(),
Expand Down Expand Up @@ -82,8 +79,7 @@ impl BanService {

self.last_connection_id_errors_reset = Instant::now();

let local_addr = self.local_addr.to_string();
tracing::info!(target: UDP_TRACKER_LOG_TARGET, local_addr, "Udp::run_udp_server::loop (connection id errors filter cleared)");
tracing::info!(target: UDP_TRACKER_LOG_TARGET, "Udp::run_udp_server::loop (connection id errors filter cleared)");
}
}

Expand All @@ -95,8 +91,7 @@ mod tests {

/// Sample service with one day ban duration.
fn ban_service(counter_limit: u32) -> BanService {
let udp_tracker_url = "udp://127.0.0.1".parse().unwrap();
BanService::new(counter_limit, udp_tracker_url)
BanService::new(counter_limit)
}

#[test]
Expand Down
21 changes: 11 additions & 10 deletions src/servers/udp/server/launcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::servers::udp::UDP_TRACKER_LOG_TARGET;

/// The maximum number of connection id errors per ip. Clients will be banned if
/// they exceed this limit.
const MAX_CONNECTION_ID_ERRORS_PER_IP: u32 = 10;
pub const MAX_CONNECTION_ID_ERRORS_PER_IP: u32 = 10;
const IP_BANS_RESET_INTERVAL_IN_SECS: u64 = 3600;

/// A UDP server instance launcher.
Expand All @@ -40,9 +40,10 @@ impl Launcher {
/// It panics if unable to send address of socket.
/// It panics if the udp server is loaded when the tracker is private.
///
#[instrument(skip(tracker, bind_to, tx_start, rx_halt))]
#[instrument(skip(tracker, ban_service, bind_to, tx_start, rx_halt))]
pub async fn run_with_graceful_shutdown(
tracker: Arc<Tracker>,
ban_service: Arc<RwLock<BanService>>,
bind_to: SocketAddr,
cookie_lifetime: Duration,
tx_start: oneshot::Sender<Started>,
Expand Down Expand Up @@ -80,7 +81,7 @@ impl Launcher {
let local_addr = local_udp_url.clone();
tokio::task::spawn(async move {
tracing::debug!(target: UDP_TRACKER_LOG_TARGET, local_addr, "Udp::run_with_graceful_shutdown::task (listening...)");
let () = Self::run_udp_server_main(receiver, tracker.clone(), cookie_lifetime).await;
let () = Self::run_udp_server_main(receiver, tracker.clone(), ban_service.clone(), cookie_lifetime).await;
})
};

Expand Down Expand Up @@ -117,8 +118,13 @@ impl Launcher {
ServiceHealthCheckJob::new(binding, info, job)
}

#[instrument(skip(receiver, tracker))]
async fn run_udp_server_main(mut receiver: Receiver, tracker: Arc<Tracker>, cookie_lifetime: Duration) {
#[instrument(skip(receiver, tracker, ban_service))]
async fn run_udp_server_main(
mut receiver: Receiver,
tracker: Arc<Tracker>,
ban_service: Arc<RwLock<BanService>>,
cookie_lifetime: Duration,
) {
let active_requests = &mut ActiveRequests::default();

let addr = receiver.bound_socket_address();
Expand All @@ -127,11 +133,6 @@ impl Launcher {

let cookie_lifetime = cookie_lifetime.as_secs_f64();

let ban_service = Arc::new(RwLock::new(BanService::new(
MAX_CONNECTION_ID_ERRORS_PER_IP,
local_addr.parse().unwrap(),
)));

let ban_cleaner = ban_service.clone();

tokio::spawn(async move {
Expand Down
13 changes: 11 additions & 2 deletions src/servers/udp/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,23 @@ mod tests {
use std::sync::Arc;
use std::time::Duration;

use tokio::sync::RwLock;
use torrust_tracker_test_helpers::configuration::ephemeral_public;

use super::spawner::Spawner;
use super::Server;
use crate::bootstrap::app::initialize_with_configuration;
use crate::servers::registar::Registar;
use crate::servers::udp::server::banning::BanService;
use crate::servers::udp::server::launcher::MAX_CONNECTION_ID_ERRORS_PER_IP;

#[tokio::test]
async fn it_should_be_able_to_start_and_stop() {
let cfg = Arc::new(ephemeral_public());

let tracker = initialize_with_configuration(&cfg);
let ban_service = Arc::new(RwLock::new(BanService::new(MAX_CONNECTION_ID_ERRORS_PER_IP)));

let udp_trackers = cfg.udp_trackers.clone().expect("missing UDP trackers configuration");
let config = &udp_trackers[0];
let bind_to = config.bind_address;
Expand All @@ -77,7 +83,7 @@ mod tests {
let stopped = Server::new(Spawner::new(bind_to));

let started = stopped
.start(tracker, register.give_form(), config.cookie_lifetime)
.start(tracker, ban_service, register.give_form(), config.cookie_lifetime)
.await
.expect("it should start the server");

Expand All @@ -91,15 +97,18 @@ mod tests {
#[tokio::test]
async fn it_should_be_able_to_start_and_stop_with_wait() {
let cfg = Arc::new(ephemeral_public());

let tracker = initialize_with_configuration(&cfg);
let ban_service = Arc::new(RwLock::new(BanService::new(MAX_CONNECTION_ID_ERRORS_PER_IP)));

let config = &cfg.udp_trackers.as_ref().unwrap().first().unwrap();
let bind_to = config.bind_address;
let register = &Registar::default();

let stopped = Server::new(Spawner::new(bind_to));

let started = stopped
.start(tracker, register.give_form(), config.cookie_lifetime)
.start(tracker, ban_service, register.give_form(), config.cookie_lifetime)
.await
.expect("it should start the server");

Expand Down
6 changes: 4 additions & 2 deletions src/servers/udp/server/spawner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ use std::time::Duration;

use derive_more::derive::Display;
use derive_more::Constructor;
use tokio::sync::oneshot;
use tokio::sync::{oneshot, RwLock};
use tokio::task::JoinHandle;

use super::banning::BanService;
use super::launcher::Launcher;
use crate::bootstrap::jobs::Started;
use crate::core::Tracker;
Expand All @@ -28,14 +29,15 @@ impl Spawner {
pub fn spawn_launcher(
&self,
tracker: Arc<Tracker>,
ban_service: Arc<RwLock<BanService>>,
cookie_lifetime: Duration,
tx_start: oneshot::Sender<Started>,
rx_halt: oneshot::Receiver<Halted>,
) -> JoinHandle<Spawner> {
let spawner = Self::new(self.bind_to);

tokio::spawn(async move {
Launcher::run_with_graceful_shutdown(tracker, spawner.bind_to, cookie_lifetime, tx_start, rx_halt).await;
Launcher::run_with_graceful_shutdown(tracker, ban_service, spawner.bind_to, cookie_lifetime, tx_start, rx_halt).await;
spawner
})
}
Expand Down
11 changes: 9 additions & 2 deletions src/servers/udp/server/states.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ use std::time::Duration;

use derive_more::derive::Display;
use derive_more::Constructor;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tracing::{instrument, Level};

use super::banning::BanService;
use super::spawner::Spawner;
use super::{Server, UdpError};
use crate::bootstrap::jobs::Started;
Expand Down Expand Up @@ -62,10 +64,12 @@ impl Server<Stopped> {
///
/// It panics if unable to receive the bound socket address from service.
///
#[instrument(skip(self, tracker, form), err, ret(Display, level = Level::INFO))]
#[instrument(skip(self, tracker, ban_service, form), err, ret(Display, level = Level::INFO))]
pub async fn start(
self,
tracker: Arc<Tracker>,

ban_service: Arc<RwLock<BanService>>,
form: ServiceRegistrationForm,
cookie_lifetime: Duration,
) -> Result<Server<Running>, std::io::Error> {
Expand All @@ -75,7 +79,10 @@ impl Server<Stopped> {
assert!(!tx_halt.is_closed(), "Halt channel for UDP tracker should be open");

// May need to wrap in a task to about a tokio bug.
let task = self.state.spawner.spawn_launcher(tracker, cookie_lifetime, tx_start, rx_halt);
let task = self
.state
.spawner
.spawn_launcher(tracker, ban_service, cookie_lifetime, tx_start, rx_halt);

let local_addr = rx_start.await.expect("it should be able to start the service").address;

Expand Down
Loading

0 comments on commit d9cfb38

Please sign in to comment.