Skip to content

Commit

Permalink
test: add test for connect request in udp::handler
Browse files Browse the repository at this point in the history
  • Loading branch information
josecelano committed Sep 20, 2022
1 parent 028e40b commit ba6b26d
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 24 deletions.
6 changes: 5 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::Arc;

use log::info;
use torrust_tracker::tracker::statistics::StatsTracker;
use torrust_tracker::tracker::tracker::TorrentTracker;
use torrust_tracker::{logging, setup, static_time, Configuration};

Expand All @@ -19,8 +20,11 @@ async fn main() {
}
};

// Initialize stats tracker
let stats_tracker = StatsTracker::new_running_instance();

// Initialize Torrust tracker
let tracker = match TorrentTracker::new(config.clone()) {
let tracker = match TorrentTracker::new(config.clone(), Box::new(stats_tracker)) {
Ok(tracker) => Arc::new(tracker),
Err(error) => {
panic!("{}", error)
Expand Down
54 changes: 40 additions & 14 deletions src/tracker/statistics.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use async_trait::async_trait;
use std::sync::Arc;

use tokio::sync::mpsc::error::SendError;
use tokio::sync::mpsc::Sender;
use tokio::sync::{mpsc, RwLock, RwLockReadGuard};

const CHANNEL_BUFFER_SIZE: usize = 65_535;

#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub enum TrackerStatisticsEvent {
Tcp4Announce,
Tcp4Scrape,
Expand Down Expand Up @@ -61,25 +61,19 @@ pub struct StatsTracker {
}

impl StatsTracker {
pub fn new_running_instance() -> Self {
let mut stats_tracker = Self::new();
stats_tracker.run_worker();
stats_tracker
}

pub fn new() -> Self {
Self {
channel_sender: None,
stats: Arc::new(RwLock::new(TrackerStatistics::new())),
}
}

pub async fn get_stats(&self) -> RwLockReadGuard<'_, TrackerStatistics> {
self.stats.read().await
}

pub async fn send_event(&self, event: TrackerStatisticsEvent) -> Option<Result<(), SendError<TrackerStatisticsEvent>>> {
if let Some(tx) = &self.channel_sender {
Some(tx.send(event).await)
} else {
None
}
}

pub fn run_worker(&mut self) {
let (tx, mut rx) = mpsc::channel::<TrackerStatisticsEvent>(CHANNEL_BUFFER_SIZE);

Expand Down Expand Up @@ -134,3 +128,35 @@ impl StatsTracker {
});
}
}

#[async_trait]
pub trait TrackerStatisticsEventSender: Sync + Send {
async fn send_event(&self, event: TrackerStatisticsEvent) -> Option<Result<(), SendError<TrackerStatisticsEvent>>>;
}

#[async_trait]
impl TrackerStatisticsEventSender for StatsTracker {
async fn send_event(&self, event: TrackerStatisticsEvent) -> Option<Result<(), SendError<TrackerStatisticsEvent>>> {
if let Some(tx) = &self.channel_sender {
Some(tx.send(event).await)
} else {
None
}
}
}

#[async_trait]
pub trait TrackerStatisticsRepository: Sync + Send {
async fn get_stats(&self) -> RwLockReadGuard<'_, TrackerStatistics>;
}

#[async_trait]
impl TrackerStatisticsRepository for StatsTracker {
async fn get_stats(&self) -> RwLockReadGuard<'_, TrackerStatistics> {
self.stats.read().await
}
}

pub trait TrackerStatsService: TrackerStatisticsEventSender + TrackerStatisticsRepository {}

impl TrackerStatsService for StatsTracker {}
12 changes: 3 additions & 9 deletions src/tracker/tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::databases::database::Database;
use crate::mode::TrackerMode;
use crate::peer::TorrentPeer;
use crate::protocol::common::InfoHash;
use crate::statistics::{StatsTracker, TrackerStatistics, TrackerStatisticsEvent};
use crate::statistics::{TrackerStatistics, TrackerStatisticsEvent, TrackerStatsService};
use crate::tracker::key;
use crate::tracker::key::AuthKey;
use crate::tracker::torrent::{TorrentEntry, TorrentError, TorrentStats};
Expand All @@ -24,19 +24,13 @@ pub struct TorrentTracker {
keys: RwLock<std::collections::HashMap<String, AuthKey>>,
whitelist: RwLock<std::collections::HashSet<InfoHash>>,
torrents: RwLock<std::collections::BTreeMap<InfoHash, TorrentEntry>>,
stats_tracker: StatsTracker,
stats_tracker: Box<dyn TrackerStatsService>,
database: Box<dyn Database>,
}

impl TorrentTracker {
pub fn new(config: Arc<Configuration>) -> Result<TorrentTracker, r2d2::Error> {
pub fn new(config: Arc<Configuration>, stats_tracker: Box<dyn TrackerStatsService>) -> Result<TorrentTracker, r2d2::Error> {
let database = database::connect_database(&config.db_driver, &config.db_path)?;
let mut stats_tracker = StatsTracker::new();

// starts a thread for updating tracker stats
if config.tracker_usage_statistics {
stats_tracker.run_worker();
}

Ok(TorrentTracker {
config: config.clone(),
Expand Down
150 changes: 150 additions & 0 deletions src/udp/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,153 @@ fn handle_error(e: ServerError, transaction_id: TransactionId) -> Response {
message: message.into(),
})
}

#[cfg(test)]
mod tests {
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
sync::Arc,
};

use tokio::sync::{mpsc::error::SendError, RwLock, RwLockReadGuard};

use crate::{
protocol::utils::get_connection_id,
statistics::{
StatsTracker, TrackerStatistics, TrackerStatisticsEvent, TrackerStatisticsEventSender, TrackerStatisticsRepository,
TrackerStatsService,
},
tracker::tracker::TorrentTracker,
udp::handle_connect,
Configuration,
};
use aquatic_udp_protocol::{ConnectRequest, ConnectResponse, Response, TransactionId};
use async_trait::async_trait;

fn default_tracker_config() -> Arc<Configuration> {
Arc::new(Configuration::default())
}

fn initialized_tracker() -> Arc<TorrentTracker> {
Arc::new(TorrentTracker::new(default_tracker_config(), Box::new(StatsTracker::new_running_instance())).unwrap())
}

fn sample_remote_addr() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
}

fn sample_connect_request() -> ConnectRequest {
ConnectRequest {
transaction_id: TransactionId(0i32),
}
}

fn sample_ipv4_socket_address() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
}

fn sample_ipv6_socket_address() -> SocketAddr {
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8080)
}

#[tokio::test]
async fn a_connect_response_should_contain_the_same_transaction_id_as_the_connect_request() {
let request = ConnectRequest {
transaction_id: TransactionId(0i32),
};

let response = handle_connect(sample_remote_addr(), &request, initialized_tracker())
.await
.unwrap();

assert_eq!(
response,
Response::Connect(ConnectResponse {
connection_id: get_connection_id(&sample_remote_addr()),
transaction_id: request.transaction_id
})
);
}

#[tokio::test]
async fn a_connect_response_should_contain_a_new_connection_id() {
let request = ConnectRequest {
transaction_id: TransactionId(0i32),
};

let response = handle_connect(sample_remote_addr(), &request, initialized_tracker())
.await
.unwrap();

assert_eq!(
response,
Response::Connect(ConnectResponse {
connection_id: get_connection_id(&sample_remote_addr()),
transaction_id: request.transaction_id
})
);
}

struct TrackerStatsServiceMock {
stats: Arc<RwLock<TrackerStatistics>>,
expected_event: Option<TrackerStatisticsEvent>,
}

impl TrackerStatsServiceMock {
fn new() -> Self {
Self {
stats: Arc::new(RwLock::new(TrackerStatistics::new())),
expected_event: None,
}
}

fn should_throw_event(&mut self, expected_event: TrackerStatisticsEvent) {
self.expected_event = Some(expected_event);
}
}

#[async_trait]
impl TrackerStatisticsEventSender for TrackerStatsServiceMock {
async fn send_event(&self, _event: TrackerStatisticsEvent) -> Option<Result<(), SendError<TrackerStatisticsEvent>>> {
if self.expected_event.is_some() {
assert_eq!(_event, *self.expected_event.as_ref().unwrap());
}
None
}
}

#[async_trait]
impl TrackerStatisticsRepository for TrackerStatsServiceMock {
async fn get_stats(&self) -> RwLockReadGuard<'_, TrackerStatistics> {
self.stats.read().await
}
}

impl TrackerStatsService for TrackerStatsServiceMock {}

#[tokio::test]
async fn it_should_send_the_upd4_connect_event_when_a_client_tries_to_connect_using_a_ip4_socket_address() {
let mut tracker_stats_service = Box::new(TrackerStatsServiceMock::new());

let client_socket_address = sample_ipv4_socket_address();
tracker_stats_service.should_throw_event(TrackerStatisticsEvent::Udp4Connect);

let torrent_tracker = Arc::new(TorrentTracker::new(default_tracker_config(), tracker_stats_service).unwrap());
handle_connect(client_socket_address, &sample_connect_request(), torrent_tracker)
.await
.unwrap();
}

#[tokio::test]
async fn it_should_send_the_upd6_connect_event_when_a_client_tries_to_connect_using_a_ip6_socket_address() {
let mut tracker_stats_service = Box::new(TrackerStatsServiceMock::new());

let client_socket_address = sample_ipv6_socket_address();
tracker_stats_service.should_throw_event(TrackerStatisticsEvent::Udp6Connect);

let torrent_tracker = Arc::new(TorrentTracker::new(default_tracker_config(), tracker_stats_service).unwrap());
handle_connect(client_socket_address, &sample_connect_request(), torrent_tracker)
.await
.unwrap();
}
}

0 comments on commit ba6b26d

Please sign in to comment.