Skip to content

Commit

Permalink
Unified server and local association manager implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
zonyitoo committed May 23, 2020
1 parent b305219 commit 3b47fa6
Show file tree
Hide file tree
Showing 2 changed files with 327 additions and 307 deletions.
318 changes: 302 additions & 16 deletions src/relay/udprelay/association.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::{
io::{self, Cursor, Read},
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::Arc,
time::Duration,
};

use async_trait::async_trait;
Expand All @@ -25,11 +26,12 @@ use tokio::{

use crate::{
config::{Config, ServerAddr, ServerConfig},
context::Context,
context::{Context, SharedContext},
relay::{
loadbalancing::server::{ServerData, SharedServerStatistic},
socks5::Address,
sys::create_udp_socket_with_context,
sys::{create_udp_socket, create_udp_socket_with_context},
utils::try_timeout,
},
};

Expand Down Expand Up @@ -180,7 +182,7 @@ impl ProxyAssociation {
Ok(ProxyAssociation { tx, watchers })
}

pub async fn send(&mut self, target: Address, payload: Vec<u8>) {
async fn send(&mut self, target: Address, payload: Vec<u8>) {
if let Err(..) = self.tx.send((target, payload)).await {
// SHOULDn't HAPPEN
unreachable!("UDP association local -> remote queue closed unexpectly");
Expand Down Expand Up @@ -427,24 +429,36 @@ impl ProxyAssociation {
}
}

#[derive(Clone)]
pub struct ProxyAssociationManager<K> {
map: Arc<Mutex<LruCache<K, ProxyAssociation>>>,
struct AssociationManagerInner<K, A> {
map: Arc<Mutex<LruCache<K, A>>>,
watcher: AbortHandle,
}

impl<K> Drop for ProxyAssociationManager<K> {
impl<K, A> Drop for AssociationManagerInner<K, A> {
fn drop(&mut self) {
self.watcher.abort()
}
}

impl<K> ProxyAssociationManager<K>
pub struct AssociationManager<K, A> {
inner: Arc<AssociationManagerInner<K, A>>,
}

impl<K, A> Clone for AssociationManager<K, A> {
fn clone(&self) -> Self {
AssociationManager {
inner: self.inner.clone(),
}
}
}

impl<K, A> AssociationManager<K, A>
where
K: Ord + Clone + Send + 'static,
A: Send + 'static,
{
/// Create a new ProxyAssociationManager based on Config
pub fn new(config: &Config) -> ProxyAssociationManager<K> {
/// Create a new AssociationManager based on Config
pub fn new(config: &Config) -> AssociationManager<K, A> {
let timeout = config.udp_timeout.unwrap_or(DEFAULT_TIMEOUT);

// TODO: Set default capacity by getrlimit #262
Expand Down Expand Up @@ -474,34 +488,306 @@ where

tokio::spawn(release_task);

ProxyAssociationManager { map, watcher }
AssociationManager {
inner: Arc::new(AssociationManagerInner { map, watcher }),
}
}

/// Try to reset ProxyAssociation's last used time by key
///
/// Return true if ProxyAssociation is still exist
pub async fn keep_alive(&self, key: &K) -> bool {
let mut assoc = self.map.lock().await;
let mut assoc = self.inner.map.lock().await;
assoc.get(key).is_some()
}
}

impl<K> AssociationManager<K, ProxyAssociation>
where
K: Ord + Clone + Send + 'static,
{
/// Send a packet to target address
///
/// Create a new association by `create_fut` if association doesn't exist
pub async fn send_packet<F>(&self, key: K, target: Address, pkt: Vec<u8>, create_fut: F) -> io::Result<()>
pub async fn send_packet<CFut>(&self, key: K, target: Address, payload: Vec<u8>, create_fut: CFut) -> io::Result<()>
where
F: Future<Output = io::Result<ProxyAssociation>>,
CFut: Future<Output = io::Result<ProxyAssociation>>,
{
let mut assoc_map = self.map.lock().await;
let mut assoc_map = self.inner.map.lock().await;
let assoc = match assoc_map.entry(key) {
Entry::Occupied(oc) => oc.into_mut(),
Entry::Vacant(vc) => vc.insert(create_fut.await?),
};

// FIXME: Lock is still kept for a mutable reference
// Send to local -> remote task
assoc.send(target, pkt).await;
assoc.send(target, payload).await;

Ok(())
}
}

/// Association manager for local
pub type ProxyAssociationManager<K> = AssociationManager<K, ProxyAssociation>;

// Represent a UDP association in server
pub struct ServerAssociation {
// local -> remote Queue
// Drops tx, will close local -> remote task
tx: mpsc::Sender<Vec<u8>>,

// local <- remote task life watcher
watcher: AbortHandle,
}

impl Drop for ServerAssociation {
fn drop(&mut self) {
self.watcher.abort();
}
}

impl ServerAssociation {
/// Create an association with addr
pub async fn associate(
context: SharedContext,
svr_idx: usize,
src_addr: SocketAddr,
mut response_tx: mpsc::Sender<(SocketAddr, BytesMut)>,
) -> io::Result<ServerAssociation> {
// Create a socket for receiving packets
let local_addr = match context.config().local_addr {
None => {
// Let system allocate an address for us
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)
}
Some(ref addr) => {
// Uses configured local address
addr.bind_addr(&context).await?
}
};
let remote_udp = create_udp_socket(&local_addr).await?;

let local_addr = remote_udp.local_addr().expect("could not determine port bound to");
debug!("created UDP Association for {} from {}", src_addr, local_addr);

// Create a channel for sending packets to remote
// FIXME: Channel size 1024?
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(1024);

// Splits socket into sender and receiver
let (mut receiver, mut sender) = remote_udp.split();

let timeout = context.config().udp_timeout.unwrap_or(DEFAULT_TIMEOUT);

// local -> remote
{
let context = context.clone();
tokio::spawn(async move {
let svr_cfg = context.server_config(svr_idx);

while let Some(pkt) = rx.recv().await {
// pkt is already a raw packet, so just send it
if let Err(err) =
ServerAssociation::relay_l2r(&context, src_addr, &mut sender, &pkt[..], timeout, svr_cfg).await
{
error!("failed to relay packet, {} -> ..., error: {}", src_addr, err);

// FIXME: Ignore? Or how to deal with it?
}
}

debug!("UDP ASSOCIATE {} -> .. finished", src_addr);
});
}

let (r2l_task, close_flag) = future::abortable(async move {
let svr_cfg = context.server_config(svr_idx);

loop {
// Read and send back to source
match ServerAssociation::relay_r2l(&context, src_addr, &mut receiver, &mut response_tx, svr_cfg).await {
Ok(..) => {}
Err(err) => {
error!("failed to receive packet, {} <- .., error: {}", src_addr, err);

// FIXME: Don't break, or if you can find a way to drop the ServerAssociation
// break;
}
}
}
});

// local <- remote
tokio::spawn(async move {
let _ = r2l_task.await;

debug!("UDP ASSOCIATE {} <- .. finished", src_addr);
});

Ok(ServerAssociation {
tx,
watcher: close_flag,
})
}

/// Relay packets from local to remote
async fn relay_l2r(
context: &Context,
src: SocketAddr,
remote_udp: &mut SendHalf,
pkt: &[u8],
timeout: Duration,
svr_cfg: &ServerConfig,
) -> io::Result<()> {
// First of all, decrypt payload CLIENT -> SERVER
let decrypted_pkt = match decrypt_payload(context, svr_cfg.method(), svr_cfg.key(), pkt) {
Ok(Some(pkt)) => pkt,
Ok(None) => {
error!("failed to decrypt pkt in UDP relay, packet too short");
let err = io::Error::new(io::ErrorKind::InvalidData, "packet too short");
return Err(err);
}
Err(err) => {
error!("failed to decrypt pkt in UDP relay: {}", err);
let err = io::Error::new(io::ErrorKind::InvalidData, "decrypt failed");
return Err(err);
}
};

// CLIENT -> SERVER protocol: ADDRESS + PAYLOAD
let mut cur = Cursor::new(decrypted_pkt);

let addr = Address::read_from(&mut cur).await?;

debug!("UDP ASSOCIATE {} <-> {} establishing", src, addr);

if context.check_outbound_blocked(&addr) {
warn!("outbound {} is blocked by ACL rules", addr);
return Ok(());
}

// Take out internal buffer for optimizing one byte copy
let header_len = cur.position() as usize;
let decrypted_pkt = cur.into_inner();
let body = &decrypted_pkt[header_len..];

let send_len = match addr {
Address::SocketAddress(ref remote_addr) => {
debug!(
"UDP ASSOCIATE {} -> {} ({}), payload length {} bytes",
src,
addr,
remote_addr,
body.len()
);
try_timeout(remote_udp.send_to(body, remote_addr), Some(timeout)).await?
}
Address::DomainNameAddress(ref dname, port) => lookup_outbound_then!(context, dname, port, |remote_addr| {
match try_timeout(remote_udp.send_to(body, &remote_addr), Some(timeout)).await {
Ok(l) => {
debug!(
"UDP ASSOCIATE {} -> {} ({}), payload length {} bytes",
src,
addr,
remote_addr,
body.len()
);
Ok(l)
}
Err(err) => {
error!(
"UDP ASSOCIATE {} -> {} ({}), payload length {} bytes",
src,
addr,
remote_addr,
body.len()
);
Err(err)
}
}
})
.map(|(_, l)| l)?,
};

assert_eq!(body.len(), send_len);

Ok(())
}

/// Relay packets from remote to local
async fn relay_r2l(
context: &Context,
src_addr: SocketAddr,
remote_udp: &mut RecvHalf,
response_tx: &mut mpsc::Sender<(SocketAddr, BytesMut)>,
svr_cfg: &ServerConfig,
) -> io::Result<()> {
// Waiting for response from server SERVER -> CLIENT
// Packet length is limited by MAXIMUM_UDP_PAYLOAD_SIZE, excess bytes will be discarded.
let mut remote_buf = vec![0u8; MAXIMUM_UDP_PAYLOAD_SIZE];
let (remote_recv_len, remote_addr) = remote_udp.recv_from(&mut remote_buf).await?;

debug!(
"UDP ASSOCIATE {} <- {}, payload length {} bytes",
src_addr, remote_addr, remote_recv_len
);

// FIXME: The Address should be the Address that client sent
let addr = Address::SocketAddress(remote_addr);

// CLIENT <- SERVER protocol: ADDRESS + PAYLOAD
let mut send_buf = Vec::new();
addr.write_to_buf(&mut send_buf);
send_buf.extend_from_slice(&remote_buf[..remote_recv_len]);

let mut encrypt_buf = BytesMut::new();
encrypt_payload(context, svr_cfg.method(), svr_cfg.key(), &send_buf, &mut encrypt_buf)?;

// Send back to src_addr
if let Err(err) = response_tx.send((src_addr, encrypt_buf)).await {
error!("failed to send packet into response channel, error: {}", err);

// FIXME: What to do? Ignore?
}

Ok(())
}

// Send packet to remote
//
// Return `Err` if receiver have been closed
async fn send(&mut self, pkt: Vec<u8>) {
if let Err(..) = self.tx.send(pkt).await {
// SHOULDn't HAPPEN
unreachable!("UDP Association local -> remote Queue closed unexpectly");
}
}
}

impl<K> AssociationManager<K, ServerAssociation>
where
K: Ord + Clone + Send + 'static,
{
/// Send a packet to target address
///
/// Create a new association by `create_fut` if association doesn't exist
pub async fn send_packet<CFut>(&self, key: K, payload: Vec<u8>, create_fut: CFut) -> io::Result<()>
where
CFut: Future<Output = io::Result<ServerAssociation>>,
{
let mut assoc_map = self.inner.map.lock().await;
let assoc = match assoc_map.entry(key) {
Entry::Occupied(oc) => oc.into_mut(),
Entry::Vacant(vc) => vc.insert(create_fut.await?),
};

// FIXME: Lock is still kept for a mutable reference
// Send to local -> remote task
assoc.send(payload).await;

Ok(())
}
}

/// Association manager for server
pub type ServerAssociationManager<K> = AssociationManager<K, ServerAssociation>;
Loading

0 comments on commit 3b47fa6

Please sign in to comment.