Skip to content

Commit

Permalink
feat(ext/net): Refactor TCP socket listeners for future clustering mo…
Browse files Browse the repository at this point in the history
…de (#23037)

Changes:

- Implements a TCP socket listener that will allow for round-robin
load-balancing in-process.
 - Cleans up the raw networking code to make it easier to work with.
  • Loading branch information
mmastrac authored and satyarohith committed Apr 11, 2024
1 parent 95e78a9 commit a91bf05
Show file tree
Hide file tree
Showing 9 changed files with 539 additions and 353 deletions.
6 changes: 5 additions & 1 deletion ext/http/request_properties.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,11 @@ impl HttpPropertyExtractor for DefaultHttpPropertyExtractor {
async fn accept_connection_from_listener(
listener: &NetworkStreamListener,
) -> Result<NetworkStream, AnyError> {
listener.accept().await.map_err(Into::into)
listener
.accept()
.await
.map_err(Into::into)
.map(|(stm, _)| stm)
}

fn listen_properties_from_listener(
Expand Down
67 changes: 36 additions & 31 deletions ext/net/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ pub mod ops_tls;
pub mod ops_unix;
pub mod raw;
pub mod resolve_addr;
mod tcp;

use deno_core::error::AnyError;
use deno_core::op2;
use deno_core::OpState;
use deno_tls::rustls::RootCertStore;
use deno_tls::RootCertStoreProvider;
Expand Down Expand Up @@ -93,21 +93,13 @@ deno_core::extension!(deno_net,
ops_tls::op_net_accept_tls,
ops_tls::op_tls_handshake,

#[cfg(unix)] ops_unix::op_net_accept_unix,
#[cfg(unix)] ops_unix::op_net_connect_unix<P>,
#[cfg(unix)] ops_unix::op_net_listen_unix<P>,
#[cfg(unix)] ops_unix::op_net_listen_unixpacket<P>,
#[cfg(unix)] ops_unix::op_node_unstable_net_listen_unixpacket<P>,
#[cfg(unix)] ops_unix::op_net_recv_unixpacket,
#[cfg(unix)] ops_unix::op_net_send_unixpacket<P>,

#[cfg(not(unix))] op_net_accept_unix,
#[cfg(not(unix))] op_net_connect_unix,
#[cfg(not(unix))] op_net_listen_unix,
#[cfg(not(unix))] op_net_listen_unixpacket,
#[cfg(not(unix))] op_node_unstable_net_listen_unixpacket,
#[cfg(not(unix))] op_net_recv_unixpacket,
#[cfg(not(unix))] op_net_send_unixpacket,
ops_unix::op_net_accept_unix,
ops_unix::op_net_connect_unix<P>,
ops_unix::op_net_listen_unix<P>,
ops_unix::op_net_listen_unixpacket<P>,
ops_unix::op_node_unstable_net_listen_unixpacket<P>,
ops_unix::op_net_recv_unixpacket,
ops_unix::op_net_send_unixpacket<P>,
],
esm = [ "01_net.js", "02_tls.js" ],
options = {
Expand All @@ -124,19 +116,32 @@ deno_core::extension!(deno_net,
},
);

macro_rules! stub_op {
($name:ident) => {
#[op2(fast)]
fn $name() {
panic!("Unsupported on non-unix platforms")
}
};
}
/// Stub ops for non-unix platforms.
#[cfg(not(unix))]
mod ops_unix {
use crate::NetPermissions;
use deno_core::op2;

stub_op!(op_net_accept_unix);
stub_op!(op_net_connect_unix);
stub_op!(op_net_listen_unix);
stub_op!(op_net_listen_unixpacket);
stub_op!(op_node_unstable_net_listen_unixpacket);
stub_op!(op_net_recv_unixpacket);
stub_op!(op_net_send_unixpacket);
macro_rules! stub_op {
($name:ident) => {
#[op2(fast)]
pub fn $name() {
panic!("Unsupported on non-unix platforms")
}
};
($name:ident<P>) => {
#[op2(fast)]
pub fn $name<P: NetPermissions>() {
panic!("Unsupported on non-unix platforms")
}
};
}

stub_op!(op_net_accept_unix);
stub_op!(op_net_connect_unix<P>);
stub_op!(op_net_listen_unix<P>);
stub_op!(op_net_listen_unixpacket<P>);
stub_op!(op_node_unstable_net_listen_unixpacket<P>);
stub_op!(op_net_recv_unixpacket);
stub_op!(op_net_send_unixpacket<P>);
}
49 changes: 9 additions & 40 deletions ext/net/ops.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.

use crate::io::TcpStreamResource;
use crate::raw::NetworkListenerResource;
use crate::resolve_addr::resolve_addr;
use crate::resolve_addr::resolve_addr_sync;
use crate::tcp::TcpListener;
use crate::NetPermissions;
use deno_core::error::bad_resource;
use deno_core::error::custom_error;
Expand Down Expand Up @@ -33,7 +35,6 @@ use std::net::Ipv6Addr;
use std::net::SocketAddr;
use std::rc::Rc;
use std::str::FromStr;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio::net::UdpSocket;
use trust_dns_proto::rr::rdata::caa::Value;
Expand Down Expand Up @@ -85,7 +86,7 @@ pub async fn op_net_accept_tcp(
let resource = state
.borrow()
.resource_table
.get::<TcpListenerResource>(rid)
.get::<NetworkListenerResource<TcpListener>>(rid)
.map_err(|_| bad_resource("Listener has been closed"))?;
let listener = RcRef::map(&resource, |r| &r.listener)
.try_borrow_mut()
Expand Down Expand Up @@ -320,21 +321,6 @@ where
Ok((rid, IpAddr::from(local_addr), IpAddr::from(remote_addr)))
}

pub struct TcpListenerResource {
pub listener: AsyncRefCell<TcpListener>,
pub cancel: CancelHandle,
}

impl Resource for TcpListenerResource {
fn name(&self) -> Cow<str> {
"tcpListener".into()
}

fn close(self: Rc<Self>) {
self.cancel.cancel();
}
}

struct UdpSocketResource {
socket: AsyncRefCell<UdpSocket>,
cancel: CancelHandle,
Expand Down Expand Up @@ -369,29 +355,10 @@ where
let addr = resolve_addr_sync(&addr.hostname, addr.port)?
.next()
.ok_or_else(|| generic_error("No resolved address found"))?;
let domain = if addr.is_ipv4() {
Domain::IPV4
} else {
Domain::IPV6
};
let socket = Socket::new(domain, Type::STREAM, None)?;
#[cfg(not(windows))]
socket.set_reuse_address(true)?;
if reuse_port {
#[cfg(any(target_os = "android", target_os = "linux"))]
socket.set_reuse_port(true)?;
}
let socket_addr = socket2::SockAddr::from(addr);
socket.bind(&socket_addr)?;
socket.listen(128)?;
socket.set_nonblocking(true)?;
let std_listener: std::net::TcpListener = socket.into();
let listener = TcpListener::from_std(std_listener)?;

let listener = TcpListener::bind_direct(addr, reuse_port)?;
let local_addr = listener.local_addr()?;
let listener_resource = TcpListenerResource {
listener: AsyncRefCell::new(listener),
cancel: Default::default(),
};
let listener_resource = NetworkListenerResource::new(listener);
let rid = state.resource_table.add(listener_resource);

Ok((rid, IpAddr::from(local_addr)))
Expand Down Expand Up @@ -781,6 +748,7 @@ mod tests {
use socket2::SockRef;
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
use std::net::ToSocketAddrs;
use std::path::Path;
use std::sync::Arc;
use std::sync::Mutex;
Expand Down Expand Up @@ -1030,7 +998,8 @@ mod tests {
) {
let sockets = Arc::new(Mutex::new(vec![]));
let clone_addr = addr.clone();
let listener = TcpListener::bind(addr).await.unwrap();
let addr = addr.to_socket_addrs().unwrap().next().unwrap();
let listener = TcpListener::bind_direct(addr, false).unwrap();
let accept_fut = listener.accept().boxed_local();
let store_fut = async move {
let socket = accept_fut.await.unwrap();
Expand Down
91 changes: 34 additions & 57 deletions ext/net/ops_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
use crate::io::TcpStreamResource;
use crate::ops::IpAddr;
use crate::ops::TlsHandshakeInfo;
use crate::raw::NetworkListenerResource;
use crate::resolve_addr::resolve_addr;
use crate::resolve_addr::resolve_addr_sync;
use crate::tcp::TcpListener;
use crate::DefaultTlsOptions;
use crate::NetPermissions;
use crate::UnsafelyIgnoreCertificateErrors;
Expand Down Expand Up @@ -36,9 +38,6 @@ use deno_tls::TlsKeys;
use rustls_tokio_stream::TlsStreamRead;
use rustls_tokio_stream::TlsStreamWrite;
use serde::Deserialize;
use socket2::Domain;
use socket2::Socket;
use socket2::Type;
use std::borrow::Cow;
use std::cell::RefCell;
use std::convert::From;
Expand All @@ -47,20 +46,37 @@ use std::fs::File;
use std::io::BufReader;
use std::io::ErrorKind;
use std::io::Read;
use std::net::SocketAddr;
use std::num::NonZeroUsize;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
use tokio::net::TcpStream;

pub use rustls_tokio_stream::TlsStream;

pub(crate) const TLS_BUFFER_SIZE: Option<NonZeroUsize> =
NonZeroUsize::new(65536);

pub struct TlsListener {
pub(crate) tcp_listener: TcpListener,
pub(crate) tls_config: Arc<ServerConfig>,
}

impl TlsListener {
pub async fn accept(&self) -> std::io::Result<(TlsStream, SocketAddr)> {
let (tcp, addr) = self.tcp_listener.accept().await?;
let tls =
TlsStream::new_server_side(tcp, self.tls_config.clone(), TLS_BUFFER_SIZE);
Ok((tls, addr))
}
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.tcp_listener.local_addr()
}
}

#[derive(Debug)]
pub struct TlsStreamResource {
rd: AsyncRefCell<TlsStreamRead>,
Expand Down Expand Up @@ -399,22 +415,6 @@ fn load_private_keys_from_file(
load_private_keys(&key_bytes)
}

pub struct TlsListenerResource {
pub(crate) tcp_listener: AsyncRefCell<TcpListener>,
pub(crate) tls_config: Arc<ServerConfig>,
cancel_handle: CancelHandle,
}

impl Resource for TlsListenerResource {
fn name(&self) -> Cow<str> {
"tlsListener".into()
}

fn close(self: Rc<Self>) {
self.cancel_handle.cancel();
}
}

#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListenTlsArgs {
Expand Down Expand Up @@ -470,31 +470,14 @@ where
let bind_addr = resolve_addr_sync(&addr.hostname, addr.port)?
.next()
.ok_or_else(|| generic_error("No resolved address found"))?;
let domain = if bind_addr.is_ipv4() {
Domain::IPV4
} else {
Domain::IPV6
};
let socket = Socket::new(domain, Type::STREAM, None)?;
#[cfg(not(windows))]
socket.set_reuse_address(true)?;
if args.reuse_port {
#[cfg(any(target_os = "android", target_os = "linux"))]
socket.set_reuse_port(true)?;
}
let socket_addr = socket2::SockAddr::from(bind_addr);
socket.bind(&socket_addr)?;
socket.listen(128)?;
socket.set_nonblocking(true)?;
let std_listener: std::net::TcpListener = socket.into();
let tcp_listener = TcpListener::from_std(std_listener)?;

let tcp_listener = TcpListener::bind_direct(bind_addr, args.reuse_port)?;
let local_addr = tcp_listener.local_addr()?;

let tls_listener_resource = TlsListenerResource {
tcp_listener: AsyncRefCell::new(tcp_listener),
tls_config: Arc::new(tls_config),
cancel_handle: Default::default(),
};
let tls_listener_resource = NetworkListenerResource::new(TlsListener {
tcp_listener,
tls_config: tls_config.into(),
});

let rid = state.resource_table.add(tls_listener_resource);

Expand All @@ -510,16 +493,16 @@ pub async fn op_net_accept_tls(
let resource = state
.borrow()
.resource_table
.get::<TlsListenerResource>(rid)
.get::<NetworkListenerResource<TlsListener>>(rid)
.map_err(|_| bad_resource("Listener has been closed"))?;

let cancel_handle = RcRef::map(&resource, |r| &r.cancel_handle);
let tcp_listener = RcRef::map(&resource, |r| &r.tcp_listener)
let cancel_handle = RcRef::map(&resource, |r| &r.cancel);
let listener = RcRef::map(&resource, |r| &r.listener)
.try_borrow_mut()
.ok_or_else(|| custom_error("Busy", "Another accept task is ongoing"))?;

let (tcp_stream, remote_addr) =
match tcp_listener.accept().try_or_cancel(&cancel_handle).await {
let (tls_stream, remote_addr) =
match listener.accept().try_or_cancel(&cancel_handle).await {
Ok(tuple) => tuple,
Err(err) if err.kind() == ErrorKind::Interrupted => {
// FIXME(bartlomieju): compatibility with current JS implementation.
Expand All @@ -528,14 +511,7 @@ pub async fn op_net_accept_tls(
Err(err) => return Err(err.into()),
};

let local_addr = tcp_stream.local_addr()?;

let tls_stream = TlsStream::new_server_side(
tcp_stream,
resource.tls_config.clone(),
TLS_BUFFER_SIZE,
);

let local_addr = tls_stream.local_addr()?;
let rid = {
let mut state_ = state.borrow_mut();
state_
Expand All @@ -555,6 +531,7 @@ pub async fn op_tls_handshake(
let resource = state
.borrow()
.resource_table
.get::<TlsStreamResource>(rid)?;
.get::<TlsStreamResource>(rid)
.map_err(|_| bad_resource("Listener has been closed"))?;
resource.handshake().await
}
Loading

0 comments on commit a91bf05

Please sign in to comment.