diff --git a/ext/http/request_properties.rs b/ext/http/request_properties.rs index 02ef1387145b8b..1422c7417d2774 100644 --- a/ext/http/request_properties.rs +++ b/ext/http/request_properties.rs @@ -119,7 +119,11 @@ impl HttpPropertyExtractor for DefaultHttpPropertyExtractor { async fn accept_connection_from_listener( listener: &NetworkStreamListener, ) -> Result { - listener.accept().await.map_err(Into::into) + listener + .accept() + .await + .map_err(Into::into) + .map(|(stm, _)| stm) } fn listen_properties_from_listener( diff --git a/ext/net/lib.rs b/ext/net/lib.rs index d6e1d9dc237227..d137aa315a47a4 100644 --- a/ext/net/lib.rs +++ b/ext/net/lib.rs @@ -7,9 +7,9 @@ pub mod ops_tls; pub mod ops_unix; pub mod raw; pub mod resolve_addr; +mod tcp; use deno_core::error::AnyError; -use deno_core::op2; use deno_core::OpState; use deno_tls::rustls::RootCertStore; use deno_tls::RootCertStoreProvider; @@ -93,21 +93,13 @@ deno_core::extension!(deno_net, ops_tls::op_net_accept_tls, ops_tls::op_tls_handshake, - #[cfg(unix)] ops_unix::op_net_accept_unix, - #[cfg(unix)] ops_unix::op_net_connect_unix

, - #[cfg(unix)] ops_unix::op_net_listen_unix

, - #[cfg(unix)] ops_unix::op_net_listen_unixpacket

, - #[cfg(unix)] ops_unix::op_node_unstable_net_listen_unixpacket

, - #[cfg(unix)] ops_unix::op_net_recv_unixpacket, - #[cfg(unix)] ops_unix::op_net_send_unixpacket

, - - #[cfg(not(unix))] op_net_accept_unix, - #[cfg(not(unix))] op_net_connect_unix, - #[cfg(not(unix))] op_net_listen_unix, - #[cfg(not(unix))] op_net_listen_unixpacket, - #[cfg(not(unix))] op_node_unstable_net_listen_unixpacket, - #[cfg(not(unix))] op_net_recv_unixpacket, - #[cfg(not(unix))] op_net_send_unixpacket, + ops_unix::op_net_accept_unix, + ops_unix::op_net_connect_unix

, + ops_unix::op_net_listen_unix

, + ops_unix::op_net_listen_unixpacket

, + ops_unix::op_node_unstable_net_listen_unixpacket

, + ops_unix::op_net_recv_unixpacket, + ops_unix::op_net_send_unixpacket

, ], esm = [ "01_net.js", "02_tls.js" ], options = { @@ -124,19 +116,32 @@ deno_core::extension!(deno_net, }, ); -macro_rules! stub_op { - ($name:ident) => { - #[op2(fast)] - fn $name() { - panic!("Unsupported on non-unix platforms") - } - }; -} +/// Stub ops for non-unix platforms. +#[cfg(not(unix))] +mod ops_unix { + use crate::NetPermissions; + use deno_core::op2; -stub_op!(op_net_accept_unix); -stub_op!(op_net_connect_unix); -stub_op!(op_net_listen_unix); -stub_op!(op_net_listen_unixpacket); -stub_op!(op_node_unstable_net_listen_unixpacket); -stub_op!(op_net_recv_unixpacket); -stub_op!(op_net_send_unixpacket); + macro_rules! stub_op { + ($name:ident) => { + #[op2(fast)] + pub fn $name() { + panic!("Unsupported on non-unix platforms") + } + }; + ($name:ident

) => { + #[op2(fast)] + pub fn $name() { + panic!("Unsupported on non-unix platforms") + } + }; + } + + stub_op!(op_net_accept_unix); + stub_op!(op_net_connect_unix

); + stub_op!(op_net_listen_unix

); + stub_op!(op_net_listen_unixpacket

); + stub_op!(op_node_unstable_net_listen_unixpacket

); + stub_op!(op_net_recv_unixpacket); + stub_op!(op_net_send_unixpacket

); +} diff --git a/ext/net/ops.rs b/ext/net/ops.rs index 4b24529355175c..a25b6c310f044f 100644 --- a/ext/net/ops.rs +++ b/ext/net/ops.rs @@ -1,8 +1,10 @@ // Copyright 2018-2024 the Deno authors. All rights reserved. MIT license. use crate::io::TcpStreamResource; +use crate::raw::NetworkListenerResource; use crate::resolve_addr::resolve_addr; use crate::resolve_addr::resolve_addr_sync; +use crate::tcp::TcpListener; use crate::NetPermissions; use deno_core::error::bad_resource; use deno_core::error::custom_error; @@ -33,7 +35,6 @@ use std::net::Ipv6Addr; use std::net::SocketAddr; use std::rc::Rc; use std::str::FromStr; -use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio::net::UdpSocket; use trust_dns_proto::rr::rdata::caa::Value; @@ -85,7 +86,7 @@ pub async fn op_net_accept_tcp( let resource = state .borrow() .resource_table - .get::(rid) + .get::>(rid) .map_err(|_| bad_resource("Listener has been closed"))?; let listener = RcRef::map(&resource, |r| &r.listener) .try_borrow_mut() @@ -320,21 +321,6 @@ where Ok((rid, IpAddr::from(local_addr), IpAddr::from(remote_addr))) } -pub struct TcpListenerResource { - pub listener: AsyncRefCell, - pub cancel: CancelHandle, -} - -impl Resource for TcpListenerResource { - fn name(&self) -> Cow { - "tcpListener".into() - } - - fn close(self: Rc) { - self.cancel.cancel(); - } -} - struct UdpSocketResource { socket: AsyncRefCell, cancel: CancelHandle, @@ -369,29 +355,10 @@ where let addr = resolve_addr_sync(&addr.hostname, addr.port)? .next() .ok_or_else(|| generic_error("No resolved address found"))?; - let domain = if addr.is_ipv4() { - Domain::IPV4 - } else { - Domain::IPV6 - }; - let socket = Socket::new(domain, Type::STREAM, None)?; - #[cfg(not(windows))] - socket.set_reuse_address(true)?; - if reuse_port { - #[cfg(any(target_os = "android", target_os = "linux"))] - socket.set_reuse_port(true)?; - } - let socket_addr = socket2::SockAddr::from(addr); - socket.bind(&socket_addr)?; - socket.listen(128)?; - socket.set_nonblocking(true)?; - let std_listener: std::net::TcpListener = socket.into(); - let listener = TcpListener::from_std(std_listener)?; + + let listener = TcpListener::bind_direct(addr, reuse_port)?; let local_addr = listener.local_addr()?; - let listener_resource = TcpListenerResource { - listener: AsyncRefCell::new(listener), - cancel: Default::default(), - }; + let listener_resource = NetworkListenerResource::new(listener); let rid = state.resource_table.add(listener_resource); Ok((rid, IpAddr::from(local_addr))) @@ -781,6 +748,7 @@ mod tests { use socket2::SockRef; use std::net::Ipv4Addr; use std::net::Ipv6Addr; + use std::net::ToSocketAddrs; use std::path::Path; use std::sync::Arc; use std::sync::Mutex; @@ -1030,7 +998,8 @@ mod tests { ) { let sockets = Arc::new(Mutex::new(vec![])); let clone_addr = addr.clone(); - let listener = TcpListener::bind(addr).await.unwrap(); + let addr = addr.to_socket_addrs().unwrap().next().unwrap(); + let listener = TcpListener::bind_direct(addr, false).unwrap(); let accept_fut = listener.accept().boxed_local(); let store_fut = async move { let socket = accept_fut.await.unwrap(); diff --git a/ext/net/ops_tls.rs b/ext/net/ops_tls.rs index 874f795f273681..c0ac315865af4e 100644 --- a/ext/net/ops_tls.rs +++ b/ext/net/ops_tls.rs @@ -3,8 +3,10 @@ use crate::io::TcpStreamResource; use crate::ops::IpAddr; use crate::ops::TlsHandshakeInfo; +use crate::raw::NetworkListenerResource; use crate::resolve_addr::resolve_addr; use crate::resolve_addr::resolve_addr_sync; +use crate::tcp::TcpListener; use crate::DefaultTlsOptions; use crate::NetPermissions; use crate::UnsafelyIgnoreCertificateErrors; @@ -36,9 +38,6 @@ use deno_tls::TlsKeys; use rustls_tokio_stream::TlsStreamRead; use rustls_tokio_stream::TlsStreamWrite; use serde::Deserialize; -use socket2::Domain; -use socket2::Socket; -use socket2::Type; use std::borrow::Cow; use std::cell::RefCell; use std::convert::From; @@ -47,13 +46,13 @@ use std::fs::File; use std::io::BufReader; use std::io::ErrorKind; use std::io::Read; +use std::net::SocketAddr; use std::num::NonZeroUsize; use std::path::Path; use std::rc::Rc; use std::sync::Arc; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; -use tokio::net::TcpListener; use tokio::net::TcpStream; pub use rustls_tokio_stream::TlsStream; @@ -61,6 +60,23 @@ pub use rustls_tokio_stream::TlsStream; pub(crate) const TLS_BUFFER_SIZE: Option = NonZeroUsize::new(65536); +pub struct TlsListener { + pub(crate) tcp_listener: TcpListener, + pub(crate) tls_config: Arc, +} + +impl TlsListener { + pub async fn accept(&self) -> std::io::Result<(TlsStream, SocketAddr)> { + let (tcp, addr) = self.tcp_listener.accept().await?; + let tls = + TlsStream::new_server_side(tcp, self.tls_config.clone(), TLS_BUFFER_SIZE); + Ok((tls, addr)) + } + pub fn local_addr(&self) -> std::io::Result { + self.tcp_listener.local_addr() + } +} + #[derive(Debug)] pub struct TlsStreamResource { rd: AsyncRefCell, @@ -399,22 +415,6 @@ fn load_private_keys_from_file( load_private_keys(&key_bytes) } -pub struct TlsListenerResource { - pub(crate) tcp_listener: AsyncRefCell, - pub(crate) tls_config: Arc, - cancel_handle: CancelHandle, -} - -impl Resource for TlsListenerResource { - fn name(&self) -> Cow { - "tlsListener".into() - } - - fn close(self: Rc) { - self.cancel_handle.cancel(); - } -} - #[derive(Deserialize)] #[serde(rename_all = "camelCase")] pub struct ListenTlsArgs { @@ -470,31 +470,14 @@ where let bind_addr = resolve_addr_sync(&addr.hostname, addr.port)? .next() .ok_or_else(|| generic_error("No resolved address found"))?; - let domain = if bind_addr.is_ipv4() { - Domain::IPV4 - } else { - Domain::IPV6 - }; - let socket = Socket::new(domain, Type::STREAM, None)?; - #[cfg(not(windows))] - socket.set_reuse_address(true)?; - if args.reuse_port { - #[cfg(any(target_os = "android", target_os = "linux"))] - socket.set_reuse_port(true)?; - } - let socket_addr = socket2::SockAddr::from(bind_addr); - socket.bind(&socket_addr)?; - socket.listen(128)?; - socket.set_nonblocking(true)?; - let std_listener: std::net::TcpListener = socket.into(); - let tcp_listener = TcpListener::from_std(std_listener)?; + + let tcp_listener = TcpListener::bind_direct(bind_addr, args.reuse_port)?; let local_addr = tcp_listener.local_addr()?; - let tls_listener_resource = TlsListenerResource { - tcp_listener: AsyncRefCell::new(tcp_listener), - tls_config: Arc::new(tls_config), - cancel_handle: Default::default(), - }; + let tls_listener_resource = NetworkListenerResource::new(TlsListener { + tcp_listener, + tls_config: tls_config.into(), + }); let rid = state.resource_table.add(tls_listener_resource); @@ -510,16 +493,16 @@ pub async fn op_net_accept_tls( let resource = state .borrow() .resource_table - .get::(rid) + .get::>(rid) .map_err(|_| bad_resource("Listener has been closed"))?; - let cancel_handle = RcRef::map(&resource, |r| &r.cancel_handle); - let tcp_listener = RcRef::map(&resource, |r| &r.tcp_listener) + let cancel_handle = RcRef::map(&resource, |r| &r.cancel); + let listener = RcRef::map(&resource, |r| &r.listener) .try_borrow_mut() .ok_or_else(|| custom_error("Busy", "Another accept task is ongoing"))?; - let (tcp_stream, remote_addr) = - match tcp_listener.accept().try_or_cancel(&cancel_handle).await { + let (tls_stream, remote_addr) = + match listener.accept().try_or_cancel(&cancel_handle).await { Ok(tuple) => tuple, Err(err) if err.kind() == ErrorKind::Interrupted => { // FIXME(bartlomieju): compatibility with current JS implementation. @@ -528,14 +511,7 @@ pub async fn op_net_accept_tls( Err(err) => return Err(err.into()), }; - let local_addr = tcp_stream.local_addr()?; - - let tls_stream = TlsStream::new_server_side( - tcp_stream, - resource.tls_config.clone(), - TLS_BUFFER_SIZE, - ); - + let local_addr = tls_stream.local_addr()?; let rid = { let mut state_ = state.borrow_mut(); state_ @@ -555,6 +531,7 @@ pub async fn op_tls_handshake( let resource = state .borrow() .resource_table - .get::(rid)?; + .get::(rid) + .map_err(|_| bad_resource("Listener has been closed"))?; resource.handshake().await } diff --git a/ext/net/ops_unix.rs b/ext/net/ops_unix.rs index be3e9d153de2fd..7d2f6af3cb90c4 100644 --- a/ext/net/ops_unix.rs +++ b/ext/net/ops_unix.rs @@ -1,6 +1,7 @@ // Copyright 2018-2024 the Deno authors. All rights reserved. MIT license. use crate::io::UnixStreamResource; +use crate::raw::NetworkListenerResource; use crate::NetPermissions; use deno_core::error::bad_resource; use deno_core::error::custom_error; @@ -32,21 +33,6 @@ pub fn into_string(s: std::ffi::OsString) -> Result { }) } -pub(crate) struct UnixListenerResource { - pub listener: AsyncRefCell, - cancel: CancelHandle, -} - -impl Resource for UnixListenerResource { - fn name(&self) -> Cow { - "unixListener".into() - } - - fn close(self: Rc) { - self.cancel.cancel(); - } -} - pub struct UnixDatagramResource { pub socket: AsyncRefCell, pub cancel: CancelHandle, @@ -81,7 +67,7 @@ pub async fn op_net_accept_unix( let resource = state .borrow() .resource_table - .get::(rid) + .get::>(rid) .map_err(|_| bad_resource("Listener has been closed"))?; let listener = RcRef::map(&resource, |r| &r.listener) .try_borrow_mut() @@ -206,10 +192,7 @@ where let listener = UnixListener::bind(address_path)?; let local_addr = listener.local_addr()?; let pathname = local_addr.as_pathname().map(pathstring).transpose()?; - let listener_resource = UnixListenerResource { - listener: AsyncRefCell::new(listener), - cancel: Default::default(), - }; + let listener_resource = NetworkListenerResource::new(listener); let rid = state.resource_table.add(listener_resource); Ok((rid, pathname)) } diff --git a/ext/net/raw.rs b/ext/net/raw.rs index c583da3bd91920..f2de760652aa18 100644 --- a/ext/net/raw.rs +++ b/ext/net/raw.rs @@ -1,176 +1,305 @@ // Copyright 2018-2024 the Deno authors. All rights reserved. MIT license. use crate::io::TcpStreamResource; -#[cfg(unix)] -use crate::io::UnixStreamResource; -use crate::ops::TcpListenerResource; -use crate::ops_tls::TlsListenerResource; use crate::ops_tls::TlsStreamResource; -use crate::ops_tls::TLS_BUFFER_SIZE; -#[cfg(unix)] -use crate::ops_unix::UnixListenerResource; use deno_core::error::bad_resource; use deno_core::error::bad_resource_id; use deno_core::error::AnyError; +use deno_core::AsyncRefCell; +use deno_core::CancelHandle; +use deno_core::Resource; use deno_core::ResourceId; use deno_core::ResourceTable; -use deno_tls::rustls::ServerConfig; -use pin_project::pin_project; -use rustls_tokio_stream::TlsStream; +use std::borrow::Cow; use std::rc::Rc; -use std::sync::Arc; -use tokio::net::TcpStream; -#[cfg(unix)] -use tokio::net::UnixStream; -/// A raw stream of one of the types handled by this extension. -#[pin_project(project = NetworkStreamProject)] -pub enum NetworkStream { - Tcp(#[pin] TcpStream), - Tls(#[pin] TlsStream), - #[cfg(unix)] - Unix(#[pin] UnixStream), +pub trait NetworkStreamTrait: Into { + type Resource; + const RESOURCE_NAME: &'static str; + fn local_address(&self) -> Result; + fn peer_address(&self) -> Result; } -impl From for NetworkStream { - fn from(value: TcpStream) -> Self { - NetworkStream::Tcp(value) - } +#[allow(async_fn_in_trait)] +pub trait NetworkStreamListenerTrait: + Into + Send + Sync +{ + type Stream: NetworkStreamTrait + 'static; + type Addr: Into + 'static; + /// Additional data, if needed + type ResourceData: Default; + const RESOURCE_NAME: &'static str; + async fn accept(&self) -> std::io::Result<(Self::Stream, Self::Addr)>; + fn listen_address(&self) -> Result; } -impl From for NetworkStream { - fn from(value: TlsStream) -> Self { - NetworkStream::Tls(value) - } +/// A strongly-typed network listener resource for something that +/// implements `NetworkListenerTrait`. +pub struct NetworkListenerResource { + pub listener: AsyncRefCell, + /// Associated data for this resource. Not required. + #[allow(unused)] + pub data: T::ResourceData, + pub cancel: CancelHandle, } -#[cfg(unix)] -impl From for NetworkStream { - fn from(value: UnixStream) -> Self { - NetworkStream::Unix(value) +impl Resource + for NetworkListenerResource +{ + fn name(&self) -> Cow { + T::RESOURCE_NAME.into() } -} -/// A raw stream of one of the types handled by this extension. -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum NetworkStreamType { - Tcp, - Tls, - #[cfg(unix)] - Unix, + fn close(self: Rc) { + self.cancel.cancel(); + } } -impl NetworkStream { - pub fn local_address(&self) -> Result { - match self { - Self::Tcp(tcp) => Ok(NetworkStreamAddress::Ip(tcp.local_addr()?)), - Self::Tls(tls) => Ok(NetworkStreamAddress::Ip(tls.local_addr()?)), - #[cfg(unix)] - Self::Unix(unix) => Ok(NetworkStreamAddress::Unix(unix.local_addr()?)), +impl NetworkListenerResource { + pub fn new(t: T) -> Self { + Self { + listener: AsyncRefCell::new(t), + data: Default::default(), + cancel: Default::default(), } } - pub fn peer_address(&self) -> Result { - match self { - Self::Tcp(tcp) => Ok(NetworkStreamAddress::Ip(tcp.peer_addr()?)), - Self::Tls(tls) => Ok(NetworkStreamAddress::Ip(tls.peer_addr()?)), - #[cfg(unix)] - Self::Unix(unix) => Ok(NetworkStreamAddress::Unix(unix.peer_addr()?)), + /// Returns a [`NetworkStreamListener`] from this resource if it is not in use elsewhere. + fn take( + resource_table: &mut ResourceTable, + listener_rid: ResourceId, + ) -> Result, AnyError> { + if let Ok(resource_rc) = resource_table.take::(listener_rid) { + let resource = Rc::try_unwrap(resource_rc) + .map_err(|_| bad_resource("Listener is currently in use"))?; + return Ok(Some(resource.listener.into_inner().into())); } + Ok(None) } +} - pub fn stream(&self) -> NetworkStreamType { - match self { - Self::Tcp(_) => NetworkStreamType::Tcp, - Self::Tls(_) => NetworkStreamType::Tls, - #[cfg(unix)] - Self::Unix(_) => NetworkStreamType::Unix, +/// Each of the network streams has the exact same pattern for listening, accepting, etc, so +/// we just codegen them all via macro to avoid repeating each one of these N times. +macro_rules! network_stream { + ( $([$i:ident, $il:ident, $stream:path, $listener:path, $addr:path, $stream_resource:ty]),* ) => { + /// A raw stream of one of the types handled by this extension. + #[pin_project::pin_project(project = NetworkStreamProject)] + pub enum NetworkStream { + $( $i (#[pin] $stream), )* } - } -} -impl tokio::io::AsyncRead for NetworkStream { - fn poll_read( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - match self.project() { - NetworkStreamProject::Tcp(s) => s.poll_read(cx, buf), - NetworkStreamProject::Tls(s) => s.poll_read(cx, buf), - #[cfg(unix)] - NetworkStreamProject::Unix(s) => s.poll_read(cx, buf), + /// A raw stream of one of the types handled by this extension. + #[derive(Copy, Clone, PartialEq, Eq)] + pub enum NetworkStreamType { + $( $i, )* } - } -} -impl tokio::io::AsyncWrite for NetworkStream { - fn poll_write( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - match self.project() { - NetworkStreamProject::Tcp(s) => s.poll_write(cx, buf), - NetworkStreamProject::Tls(s) => s.poll_write(cx, buf), - #[cfg(unix)] - NetworkStreamProject::Unix(s) => s.poll_write(cx, buf), + /// A raw stream listener of one of the types handled by this extension. + pub enum NetworkStreamListener { + $( $i( $listener ), )* } - } - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match self.project() { - NetworkStreamProject::Tcp(s) => s.poll_flush(cx), - NetworkStreamProject::Tls(s) => s.poll_flush(cx), - #[cfg(unix)] - NetworkStreamProject::Unix(s) => s.poll_flush(cx), + $( + impl NetworkStreamListenerTrait for $listener { + type Stream = $stream; + type Addr = $addr; + type ResourceData = (); + const RESOURCE_NAME: &'static str = concat!(stringify!($il), "Listener"); + async fn accept(&self) -> std::io::Result<(Self::Stream, Self::Addr)> { + <$listener> :: accept(self).await + } + fn listen_address(&self) -> std::io::Result { + self.local_addr() + } + } + + impl From<$listener> for NetworkStreamListener { + fn from(value: $listener) -> Self { + Self::$i(value) + } + } + + impl NetworkStreamTrait for $stream { + type Resource = $stream_resource; + const RESOURCE_NAME: &'static str = concat!(stringify!($il), "Stream"); + fn local_address(&self) -> Result { + Ok(NetworkStreamAddress::from(self.local_addr()?)) + } + fn peer_address(&self) -> Result { + Ok(NetworkStreamAddress::from(self.peer_addr()?)) + } + } + + impl From<$stream> for NetworkStream { + fn from(value: $stream) -> Self { + Self::$i(value) + } + } + )* + + impl NetworkStream { + pub fn local_address(&self) -> Result { + match self { + $( Self::$i(stm) => Ok(NetworkStreamAddress::from(stm.local_addr()?)), )* + } + } + + pub fn peer_address(&self) -> Result { + match self { + $( Self::$i(stm) => Ok(NetworkStreamAddress::from(stm.peer_addr()?)), )* + } + } + + pub fn stream(&self) -> NetworkStreamType { + match self { + $( Self::$i(_) => NetworkStreamType::$i, )* + } + } } - } - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match self.project() { - NetworkStreamProject::Tcp(s) => s.poll_shutdown(cx), - NetworkStreamProject::Tls(s) => s.poll_shutdown(cx), - #[cfg(unix)] - NetworkStreamProject::Unix(s) => s.poll_shutdown(cx), + impl tokio::io::AsyncRead for NetworkStream { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + match self.project() { + $( NetworkStreamProject::$i(s) => s.poll_read(cx, buf), )* + } + } } - } - fn is_write_vectored(&self) -> bool { - match self { - Self::Tcp(s) => s.is_write_vectored(), - Self::Tls(s) => s.is_write_vectored(), - #[cfg(unix)] - Self::Unix(s) => s.is_write_vectored(), + impl tokio::io::AsyncWrite for NetworkStream { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + match self.project() { + $( NetworkStreamProject::$i(s) => s.poll_write(cx, buf), )* + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.project() { + $( NetworkStreamProject::$i(s) => s.poll_flush(cx), )* + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.project() { + $( NetworkStreamProject::$i(s) => s.poll_shutdown(cx), )* + } + } + + fn is_write_vectored(&self) -> bool { + match self { + $( NetworkStream::$i(s) => s.is_write_vectored(), )* + } + } + + fn poll_write_vectored( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> std::task::Poll> { + match self.project() { + $( NetworkStreamProject::$i(s) => s.poll_write_vectored(cx, bufs), )* + } + } } - } - fn poll_write_vectored( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> std::task::Poll> { - match self.project() { - NetworkStreamProject::Tcp(s) => s.poll_write_vectored(cx, bufs), - NetworkStreamProject::Tls(s) => s.poll_write_vectored(cx, bufs), - #[cfg(unix)] - NetworkStreamProject::Unix(s) => s.poll_write_vectored(cx, bufs), + impl NetworkStreamListener { + /// Accepts a connection on this listener. + pub async fn accept(&self) -> Result<(NetworkStream, NetworkStreamAddress), std::io::Error> { + Ok(match self { + $( + Self::$i(s) => { + let (stm, addr) = s.accept().await?; + (NetworkStream::$i(stm), addr.into()) + } + )* + }) + } + + pub fn listen_address(&self) -> Result { + match self { + $( Self::$i(s) => { Ok(NetworkStreamAddress::from(s.listen_address()?)) } )* + } + } + + pub fn stream(&self) -> NetworkStreamType { + match self { + $( Self::$i(_) => { NetworkStreamType::$i } )* + } + } + + /// Return a `NetworkStreamListener` if a resource exists for this `ResourceId` and it is currently + /// not locked. + pub fn take_resource(resource_table: &mut ResourceTable, listener_rid: ResourceId) -> Result { + $( + if let Some(resource) = NetworkListenerResource::<$listener>::take(resource_table, listener_rid)? { + return Ok(resource) + } + )* + Err(bad_resource_id()) + } } - } + }; } -/// A raw stream listener of one of the types handled by this extension. -pub enum NetworkStreamListener { - Tcp(tokio::net::TcpListener), - Tls(tokio::net::TcpListener, Arc), - #[cfg(unix)] - Unix(tokio::net::UnixListener), -} +#[cfg(unix)] +network_stream!( + [ + Tcp, + tcp, + tokio::net::TcpStream, + crate::tcp::TcpListener, + std::net::SocketAddr, + TcpStreamResource + ], + [ + Tls, + tls, + crate::ops_tls::TlsStream, + crate::ops_tls::TlsListener, + std::net::SocketAddr, + TlsStreamResource + ], + [ + Unix, + unix, + tokio::net::UnixStream, + tokio::net::UnixListener, + tokio::net::unix::SocketAddr, + crate::io::UnixStreamResource + ] +); + +#[cfg(not(unix))] +network_stream!( + [ + Tcp, + tcp, + tokio::net::TcpStream, + crate::tcp::TcpListener, + std::net::SocketAddr, + TcpStreamResource + ], + [ + Tls, + tls, + crate::ops_tls::TlsStream, + crate::ops_tls::TlsListener, + std::net::SocketAddr, + TlsStreamResource + ] +); pub enum NetworkStreamAddress { Ip(std::net::SocketAddr), @@ -178,46 +307,16 @@ pub enum NetworkStreamAddress { Unix(tokio::net::unix::SocketAddr), } -impl NetworkStreamListener { - /// Accepts a connection on this listener. - pub async fn accept(&self) -> Result { - Ok(match self { - Self::Tcp(tcp) => { - let (stream, _addr) = tcp.accept().await?; - NetworkStream::Tcp(stream) - } - Self::Tls(tcp, config) => { - let (stream, _addr) = tcp.accept().await?; - NetworkStream::Tls(TlsStream::new_server_side( - stream, - config.clone(), - TLS_BUFFER_SIZE, - )) - } - #[cfg(unix)] - Self::Unix(unix) => { - let (stream, _addr) = unix.accept().await?; - NetworkStream::Unix(stream) - } - }) - } - - pub fn listen_address(&self) -> Result { - match self { - Self::Tcp(tcp) => Ok(NetworkStreamAddress::Ip(tcp.local_addr()?)), - Self::Tls(tcp, _) => Ok(NetworkStreamAddress::Ip(tcp.local_addr()?)), - #[cfg(unix)] - Self::Unix(unix) => Ok(NetworkStreamAddress::Unix(unix.local_addr()?)), - } +impl From for NetworkStreamAddress { + fn from(value: std::net::SocketAddr) -> Self { + NetworkStreamAddress::Ip(value) } +} - pub fn stream(&self) -> NetworkStreamType { - match self { - Self::Tcp(..) => NetworkStreamType::Tcp, - Self::Tls(..) => NetworkStreamType::Tls, - #[cfg(unix)] - Self::Unix(..) => NetworkStreamType::Unix, - } +#[cfg(unix)] +impl From for NetworkStreamAddress { + fn from(value: tokio::net::unix::SocketAddr) -> Self { + NetworkStreamAddress::Unix(value) } } @@ -252,7 +351,8 @@ pub fn take_network_stream_resource( } #[cfg(unix)] - if let Ok(resource_rc) = resource_table.take::(stream_rid) + if let Ok(resource_rc) = + resource_table.take::(stream_rid) { // This UNIX socket might be used somewhere else. let resource = Rc::try_unwrap(resource_rc) @@ -271,33 +371,5 @@ pub fn take_network_stream_listener_resource( resource_table: &mut ResourceTable, listener_rid: ResourceId, ) -> Result { - if let Ok(resource_rc) = - resource_table.take::(listener_rid) - { - let resource = Rc::try_unwrap(resource_rc) - .map_err(|_| bad_resource("TCP socket listener is currently in use"))?; - return Ok(NetworkStreamListener::Tcp(resource.listener.into_inner())); - } - - if let Ok(resource_rc) = - resource_table.take::(listener_rid) - { - let resource = Rc::try_unwrap(resource_rc) - .map_err(|_| bad_resource("TLS socket listener is currently in use"))?; - return Ok(NetworkStreamListener::Tls( - resource.tcp_listener.into_inner(), - resource.tls_config, - )); - } - - #[cfg(unix)] - if let Ok(resource_rc) = - resource_table.take::(listener_rid) - { - let resource = Rc::try_unwrap(resource_rc) - .map_err(|_| bad_resource("UNIX socket listener is currently in use"))?; - return Ok(NetworkStreamListener::Unix(resource.listener.into_inner())); - } - - Err(bad_resource_id()) + NetworkStreamListener::take_resource(resource_table, listener_rid) } diff --git a/ext/net/tcp.rs b/ext/net/tcp.rs new file mode 100644 index 00000000000000..58362024333b92 --- /dev/null +++ b/ext/net/tcp.rs @@ -0,0 +1,176 @@ +// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license. +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; + +use socket2::Domain; +use socket2::Protocol; +use socket2::Type; + +/// Our per-process `Connections`. We can use this to find an existant listener for +/// a given local address and clone its socket for us to listen on in our thread. +static CONNS: std::sync::OnceLock> = + std::sync::OnceLock::new(); + +/// Maintains a map of listening address to `TcpConnection`. +#[derive(Default)] +struct Connections { + tcp: HashMap>, +} + +/// Holds an open listener. We clone the underlying file descriptor (unix) or socket handle (Windows) +/// and then listen on our copy of it. +pub struct TcpConnection { + /// The pristine FD that we'll clone for each LB listener + #[cfg(unix)] + sock: std::os::fd::OwnedFd, + #[cfg(not(unix))] + sock: std::os::windows::io::OwnedSocket, + key: SocketAddr, +} + +impl TcpConnection { + /// Boot a load-balanced TCP connection + pub fn start(key: SocketAddr) -> std::io::Result { + let listener = bind_socket_and_listen(key, false)?; + let sock = listener.into(); + + Ok(Self { sock, key }) + } + + fn listener(&self) -> std::io::Result { + let listener = std::net::TcpListener::from(self.sock.try_clone()?); + let listener = tokio::net::TcpListener::from_std(listener)?; + Ok(listener) + } +} + +/// A TCP socket listener that optionally allows for round-robin load-balancing in-process. +pub struct TcpListener { + listener: Option, + conn: Option>, +} + +/// Does this platform implement `SO_REUSEPORT` in a load-balancing manner? +const REUSE_PORT_LOAD_BALANCES: bool = + cfg!(any(target_os = "android", target_os = "linux")); + +impl TcpListener { + /// Bind to a port. On Linux, or when we don't have `SO_REUSEPORT` set, we just bind the port directly. + /// On other platforms, we emulate `SO_REUSEPORT` by cloning the socket and having each clone race to + /// accept every connection. + /// + /// ## Why not `SO_REUSEPORT`? + /// + /// The `SO_REUSEPORT` socket option allows multiple sockets on the same host to bind to the same port. This is + /// particularly useful for load balancing or implementing high availability in server applications. + /// + /// On Linux, `SO_REUSEPORT` allows multiple sockets to bind to the same port, and the kernel will load + /// balance incoming connections among those sockets. Each socket can accept connections independently. + /// This is useful for scenarios where you want to distribute incoming connections among multiple processes + /// or threads. + /// + /// On macOS (which is based on BSD), the behaviour of `SO_REUSEPORT` is slightly different. When `SO_REUSEPORT` is set, + /// multiple sockets can still bind to the same port, but the kernel does not perform load balancing as it does on Linux. + /// Instead, it follows a "last bind wins" strategy. This means that the most recently bound socket will receive + /// incoming connections exclusively, while the previously bound sockets will not receive any connections. + /// This behaviour is less useful for load balancing compared to Linux, but it can still be valuable in certain scenarios. + pub fn bind( + socket_addr: SocketAddr, + reuse_port: bool, + ) -> std::io::Result { + if REUSE_PORT_LOAD_BALANCES && reuse_port { + Self::bind_load_balanced(socket_addr) + } else { + Self::bind_direct(socket_addr, reuse_port) + } + } + + /// Bind directly to the port, passing `reuse_port` directly to the socket. On platforms other + /// than Linux, `reuse_port` does not do any load balancing. + pub fn bind_direct( + socket_addr: SocketAddr, + reuse_port: bool, + ) -> std::io::Result { + // We ignore `reuse_port` on platforms other than Linux to match the existing behaviour. + let listener = bind_socket_and_listen(socket_addr, reuse_port)?; + Ok(Self { + listener: Some(tokio::net::TcpListener::from_std(listener)?), + conn: None, + }) + } + + /// Bind to the port in a load-balanced manner. + pub fn bind_load_balanced(socket_addr: SocketAddr) -> std::io::Result { + let tcp = &mut CONNS.get_or_init(Default::default).lock().unwrap().tcp; + if let Some(conn) = tcp.get(&socket_addr) { + let listener = Some(conn.listener()?); + return Ok(Self { + listener, + conn: Some(conn.clone()), + }); + } + let conn = Arc::new(TcpConnection::start(socket_addr)?); + let listener = Some(conn.listener()?); + tcp.insert(socket_addr, conn.clone()); + Ok(Self { + listener, + conn: Some(conn), + }) + } + + pub async fn accept( + &self, + ) -> std::io::Result<(tokio::net::TcpStream, SocketAddr)> { + let (tcp, addr) = self.listener.as_ref().unwrap().accept().await?; + Ok((tcp, addr)) + } + + pub fn local_addr(&self) -> std::io::Result { + self.listener.as_ref().unwrap().local_addr() + } +} + +impl Drop for TcpListener { + fn drop(&mut self) { + // If we're in load-balancing mode + if let Some(conn) = self.conn.take() { + let mut tcp = CONNS.get().unwrap().lock().unwrap(); + if Arc::strong_count(&conn) == 2 { + tcp.tcp.remove(&conn.key); + // Close the connection + debug_assert_eq!(Arc::strong_count(&conn), 1); + drop(conn); + } + } + } +} + +/// Bind a socket to an address and listen with the low-level options we need. +#[allow(unused_variables)] +fn bind_socket_and_listen( + socket_addr: SocketAddr, + reuse_port: bool, +) -> Result { + let socket = if socket_addr.is_ipv4() { + socket2::Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))? + } else { + socket2::Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP))? + }; + #[cfg(not(windows))] + if REUSE_PORT_LOAD_BALANCES && reuse_port { + socket.set_reuse_port(true)?; + } + #[cfg(not(windows))] + // This is required for re-use of a port immediately after closing. There's a small + // security trade-off here but we err on the side of convenience. + // + // https://stackoverflow.com/questions/14388706/how-do-so-reuseaddr-and-so-reuseport-differ + // https://stackoverflow.com/questions/26772549/is-it-a-good-idea-to-reuse-port-using-option-so-reuseaddr-which-is-already-in-ti + socket.set_reuse_address(true)?; + socket.set_nonblocking(true)?; + socket.bind(&socket_addr.into())?; + socket.listen(128)?; + let listener = socket.into(); + Ok(listener) +} diff --git a/tests/unit/net_test.ts b/tests/unit/net_test.ts index eae1ae533d5b0c..dff3cc31fa97d6 100644 --- a/tests/unit/net_test.ts +++ b/tests/unit/net_test.ts @@ -1205,6 +1205,7 @@ Deno.test({ conn.close(); listener1Recv = true; p1 = undefined; + listener1.close(); }).catch(() => {}); } if (!p2) { @@ -1212,14 +1213,13 @@ Deno.test({ conn.close(); listener2Recv = true; p2 = undefined; + listener2.close(); }).catch(() => {}); } const conn = await Deno.connect({ port }); conn.close(); await Promise.race([p1, p2]); } - listener1.close(); - listener2.close(); }); Deno.test({ diff --git a/tests/unit/tls_test.ts b/tests/unit/tls_test.ts index 81d8de315093db..8f0a296c72ff88 100644 --- a/tests/unit/tls_test.ts +++ b/tests/unit/tls_test.ts @@ -1562,6 +1562,7 @@ Deno.test({ conn.close(); listener1Recv = true; p1 = undefined; + listener1.close(); }).catch(() => {}); } if (!p2) { @@ -1569,14 +1570,13 @@ Deno.test({ conn.close(); listener2Recv = true; p2 = undefined; + listener2.close(); }).catch(() => {}); } const conn = await Deno.connectTls({ hostname, port, caCerts }); conn.close(); await Promise.race([p1, p2]); } - listener1.close(); - listener2.close(); }); Deno.test({