From cb4d18bcfc1050b346838d03737c222f7e366a5a Mon Sep 17 00:00:00 2001 From: James Sturtevant Date: Wed, 1 Mar 2023 14:27:48 -0800 Subject: [PATCH 1/4] Add integration test for sync example Adds asserts to the example project so it will fail if something isn't work between the client and server. This also adds an integration test that builds and runs the sync example so the changes to the project can be validated e2e. Signed-off-by: James Sturtevant --- example/client.rs | 77 +++++++++++++++++++++++++++++++++------------- tests/sync-test.rs | 61 ++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 22 deletions(-) create mode 100644 tests/sync-test.rs diff --git a/example/client.rs b/example/client.rs index bd2fb898..04b96722 100644 --- a/example/client.rs +++ b/example/client.rs @@ -15,12 +15,17 @@ mod protocols; mod utils; +use log::LevelFilter; use protocols::sync::{agent, agent_ttrpc, health, health_ttrpc}; use std::thread; use ttrpc::context::{self, Context}; +use ttrpc::error::Error; +use ttrpc::proto::Code; use ttrpc::Client; fn main() { + simple_logging::log_to_stderr(LevelFilter::Trace); + let c = Client::connect(utils::SOCK_ADDR).unwrap(); let hc = health_ttrpc::HealthClient::new(c.clone()); let ac = agent_ttrpc::AgentServiceClient::new(c); @@ -33,69 +38,97 @@ fn main() { let t = thread::spawn(move || { let req = health::CheckRequest::new(); println!( - "OS Thread {:?} - {} started: {:?}", + "OS Thread {:?} - health.check() started: {:?}", std::thread::current().id(), - "health.check()", now.elapsed(), ); + + let rsp = thc.check(default_ctx(), &req); + match rsp.as_ref() { + Err(Error::RpcStatus(s)) => { + assert_eq!(Code::NOT_FOUND, s.code()); + assert_eq!("Just for fun".to_string(), s.message()) + } + Err(e) => { + panic!("not expecting an error from the example server: {:?}", e) + } + Ok(x) => { + panic!("not expecting a OK response from the example server: {:?}", x) + } + } println!( - "OS Thread {:?} - {} -> {:?} ended: {:?}", + "OS Thread {:?} - health.check() -> {:?} ended: {:?}", std::thread::current().id(), - "health.check()", - thc.check(default_ctx(), &req), + rsp, now.elapsed(), ); }); let t2 = thread::spawn(move || { println!( - "OS Thread {:?} - {} started: {:?}", + "OS Thread {:?} - agent.list_interfaces() started: {:?}", std::thread::current().id(), - "agent.list_interfaces()", now.elapsed(), ); let show = match tac.list_interfaces(default_ctx(), &agent::ListInterfacesRequest::new()) { - Err(e) => format!("{:?}", e), - Ok(s) => format!("{:?}", s), + Err(e) => { + panic!("not expecting an error from the example server: {:?}", e) + } + Ok(s) => { + assert_eq!("first".to_string(), s.Interfaces[0].name); + assert_eq!("second".to_string(), s.Interfaces[1].name); + format!("{s:?}") + } }; println!( - "OS Thread {:?} - {} -> {} ended: {:?}", + "OS Thread {:?} - agent.list_interfaces() -> {} ended: {:?}", std::thread::current().id(), - "agent.list_interfaces()", show, now.elapsed(), ); }); println!( - "Main OS Thread - {} started: {:?}", - "agent.online_cpu_mem()", + "Main OS Thread - agent.online_cpu_mem() started: {:?}", now.elapsed() ); let show = match ac.online_cpu_mem(default_ctx(), &agent::OnlineCPUMemRequest::new()) { - Err(e) => format!("{:?}", e), - Ok(s) => format!("{:?}", s), + Err(Error::RpcStatus(s)) => { + assert_eq!(Code::NOT_FOUND, s.code()); + assert_eq!( + "/grpc.AgentService/OnlineCPUMem is not supported".to_string(), + s.message() + ); + format!("{s:?}") + } + Err(e) => { + panic!("not expecting an error from the example server: {:?}", e) + } + Ok(s) => { + panic!("not expecting a OK response from the example server: {:?}", s) + } }; println!( - "Main OS Thread - {} -> {} ended: {:?}", - "agent.online_cpu_mem()", + "Main OS Thread - agent.online_cpu_mem() -> {} ended: {:?}", show, now.elapsed() ); println!("\nsleep 2 seconds ...\n"); thread::sleep(std::time::Duration::from_secs(2)); + + let version = hc.version(default_ctx(), &health::CheckRequest::new()); + assert_eq!("mock.0.1", version.as_ref().unwrap().agent_version.as_str()); + assert_eq!("0.0.1", version.as_ref().unwrap().grpc_version.as_str()); println!( - "Main OS Thread - {} started: {:?}", - "health.version()", + "Main OS Thread - health.version() started: {:?}", now.elapsed() ); println!( - "Main OS Thread - {} -> {:?} ended: {:?}", - "health.version()", - hc.version(default_ctx(), &health::CheckRequest::new()), + "Main OS Thread - health.version() -> {:?} ended: {:?}", + version, now.elapsed() ); diff --git a/tests/sync-test.rs b/tests/sync-test.rs new file mode 100644 index 00000000..a5a92149 --- /dev/null +++ b/tests/sync-test.rs @@ -0,0 +1,61 @@ +use std::{ + io::{BufRead, BufReader}, + process::Command, + time::Duration, +}; + +#[test] +fn run_sync_example() -> Result<(), Box> { + // start the server and give it a moment to start. + let mut server = run_example("server").spawn().unwrap(); + std::thread::sleep(Duration::from_secs(2)); + + let mut client = run_example("client").spawn().unwrap(); + let mut client_succeeded = false; + let start = std::time::Instant::now(); + let timeout = Duration::from_secs(600); + loop { + if start.elapsed() > timeout { + println!("Running the client timed out. output:"); + client.kill().unwrap_or_else(|e| { + println!("This may occur on Windows if the process has exited: {e}"); + }); + let output = client.stdout.unwrap(); + BufReader::new(output).lines().for_each(|line| { + println!("{}", line.unwrap()); + }); + break; + } + + match client.try_wait() { + Ok(Some(status)) => { + client_succeeded = status.success(); + break; + } + Ok(None) => { + // still running + continue; + } + Err(e) => { + println!("Error: {e}"); + break; + } + } + } + + // be sure to clean up the server, the client should have run to completion + server.kill()?; + assert!(client_succeeded); + Ok(()) +} + +fn run_example(example: &str) -> Command { + let mut cmd = Command::new("cargo"); + cmd.arg("run") + .arg("--example") + .arg(example) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .current_dir("example"); + cmd +} From e735bb35de17e633d1096ee5d22cd85ca81e4125 Mon Sep 17 00:00:00 2001 From: James Sturtevant Date: Tue, 24 Jan 2023 00:56:13 -0800 Subject: [PATCH 2/4] Refactor to support other OSes This moves unix specific calls out of the main server and client functions. It does this by introducing several new types: PipeListener, PipeConnection, and ClientConnection. These types are contain the unix specific functionality to communitcate with Unix Domain sockets and are hidden behind a conditional compilation flag. Signed-off-by: James Sturtevant --- src/sync/channel.rs | 34 ++-- src/sync/client.rs | 113 +++++-------- src/sync/mod.rs | 1 + src/sync/server.rs | 194 ++++++++--------------- src/sync/sys/mod.rs | 5 + src/sync/sys/unix/mod.rs | 2 + src/sync/sys/unix/net.rs | 334 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 459 insertions(+), 224 deletions(-) create mode 100644 src/sync/sys/mod.rs create mode 100644 src/sync/sys/unix/mod.rs create mode 100644 src/sync/sys/unix/net.rs diff --git a/src/sync/channel.rs b/src/sync/channel.rs index a23b5915..e9e341ca 100644 --- a/src/sync/channel.rs +++ b/src/sync/channel.rs @@ -12,18 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use nix::sys::socket::*; -use std::os::unix::io::RawFd; use crate::error::{get_rpc_status, sock_error_msg, Error, Result}; +use crate::sync::sys::{PipeConnection}; use crate::proto::{Code, MessageHeader, MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX}; -fn retryable(e: nix::Error) -> bool { - use ::nix::Error; - e == Error::EINTR || e == Error::EAGAIN -} - -fn read_count(fd: RawFd, count: usize) -> Result> { +fn read_count (fd: &PipeConnection, count: usize) -> Result> { let mut v: Vec = vec![0; count]; let mut len = 0; @@ -32,7 +26,7 @@ fn read_count(fd: RawFd, count: usize) -> Result> { } loop { - match recv(fd, &mut v[len..], MsgFlags::empty()) { + match fd.read(&mut v[len..]) { Ok(l) => { len += l; // when socket peer closed, it would return 0. @@ -40,11 +34,6 @@ fn read_count(fd: RawFd, count: usize) -> Result> { break; } } - - Err(e) if retryable(e) => { - // Should retry - } - Err(e) => { return Err(Error::Socket(e.to_string())); } @@ -54,7 +43,7 @@ fn read_count(fd: RawFd, count: usize) -> Result> { Ok(v[0..len].to_vec()) } -fn write_count(fd: RawFd, buf: &[u8], count: usize) -> Result { +fn write_count(fd: &PipeConnection, buf: &[u8], count: usize) -> Result { let mut len = 0; if count == 0 { @@ -62,18 +51,13 @@ fn write_count(fd: RawFd, buf: &[u8], count: usize) -> Result { } loop { - match send(fd, &buf[len..], MsgFlags::empty()) { + match fd.write(&buf[len..]){ Ok(l) => { len += l; if len == count { break; } } - - Err(e) if retryable(e) => { - // Should retry - } - Err(e) => { return Err(Error::Socket(e.to_string())); } @@ -83,7 +67,7 @@ fn write_count(fd: RawFd, buf: &[u8], count: usize) -> Result { Ok(len) } -fn read_message_header(fd: RawFd) -> Result { +fn read_message_header(fd: &PipeConnection) -> Result { let buf = read_count(fd, MESSAGE_HEADER_LENGTH)?; let size = buf.len(); if size != MESSAGE_HEADER_LENGTH { @@ -98,7 +82,7 @@ fn read_message_header(fd: RawFd) -> Result { Ok(mh) } -pub fn read_message(fd: RawFd) -> Result<(MessageHeader, Vec)> { +pub fn read_message(fd: &PipeConnection) -> Result<(MessageHeader, Vec)> { let mh = read_message_header(fd)?; trace!("Got Message header {:?}", mh); @@ -125,7 +109,7 @@ pub fn read_message(fd: RawFd) -> Result<(MessageHeader, Vec)> { Ok((mh, buf)) } -fn write_message_header(fd: RawFd, mh: MessageHeader) -> Result<()> { +fn write_message_header(fd: &PipeConnection, mh: MessageHeader) -> Result<()> { let buf: Vec = mh.into(); let size = write_count(fd, &buf, MESSAGE_HEADER_LENGTH)?; @@ -139,7 +123,7 @@ fn write_message_header(fd: RawFd, mh: MessageHeader) -> Result<()> { Ok(()) } -pub fn write_message(fd: RawFd, mh: MessageHeader, buf: Vec) -> Result<()> { +pub fn write_message(fd: &PipeConnection, mh: MessageHeader, buf: Vec) -> Result<()> { write_message_header(fd, mh)?; let size = write_count(fd, &buf, buf.len())?; diff --git a/src/sync/client.rs b/src/sync/client.rs index 94872651..564b78e4 100644 --- a/src/sync/client.rs +++ b/src/sync/client.rs @@ -14,18 +14,15 @@ //! Sync client of ttrpc. -use nix::sys::socket::*; -use nix::unistd::close; + use std::collections::HashMap; use std::os::unix::io::RawFd; use std::sync::mpsc; use std::sync::{Arc, Mutex}; -use std::{io, thread}; +use std::{thread}; -#[cfg(target_os = "macos")] -use crate::common::set_fd_close_exec; -use crate::common::{client_connect, SOCK_CLOEXEC}; use crate::error::{Error, Result}; +use crate::sync::sys::{ClientConnection}; use crate::proto::{Code, Codec, MessageHeader, Request, Response, MESSAGE_TYPE_RESPONSE}; use crate::sync::channel::{read_message, write_message}; use std::time::Duration; @@ -36,38 +33,36 @@ type Receiver = mpsc::Receiver<(Vec, mpsc::SyncSender>>)>; /// A ttrpc Client (sync). #[derive(Clone)] pub struct Client { - _fd: RawFd, + _fd: Arc, sender_tx: Sender, - _client_close: Arc, } impl Client { pub fn connect(sockaddr: &str) -> Result { - let fd = unsafe { client_connect(sockaddr)? }; - Ok(Self::new(fd)) + let conn = ClientConnection::client_connect(sockaddr)?; + + Ok(Self::new_client(conn)) } /// Initialize a new [`Client`] from raw file descriptor. pub fn new(fd: RawFd) -> Client { - let (sender_tx, rx): (Sender, Receiver) = mpsc::channel(); + let conn = ClientConnection::new(fd); - let (recver_fd, close_fd) = - socketpair(AddressFamily::Unix, SockType::Stream, None, SOCK_CLOEXEC).unwrap(); + Self::new_client(conn) + } - // MacOS doesn't support descriptor creation with SOCK_CLOEXEC automically, - // so there is a chance of leak if fork + exec happens in between of these calls. - #[cfg(target_os = "macos")] - { - set_fd_close_exec(recver_fd).unwrap(); - set_fd_close_exec(close_fd).unwrap(); - } + fn new_client(pipe_client: ClientConnection) -> Client { + let client = Arc::new(pipe_client); + - let client_close = Arc::new(ClientClose { fd, close_fd }); + let (sender_tx, rx): (Sender, Receiver) = mpsc::channel(); + let recver_map_orig = Arc::new(Mutex::new(HashMap::new())); //Sender let recver_map = recver_map_orig.clone(); + let sender_client = client.clone(); thread::spawn(move || { let mut stream_id: u32 = 1; for (buf, recver_tx) in rx.iter() { @@ -80,7 +75,8 @@ impl Client { } let mut mh = MessageHeader::new_request(0, buf.len() as u32); mh.set_stream_id(current_stream_id); - if let Err(e) = write_message(fd, mh, buf) { + let c = sender_client.get_pipe_connection(); + if let Err(e) = write_message(&c, mh, buf) { //Remove current_stream_id and recver_tx to recver_map { let mut map = recver_map.lock().unwrap(); @@ -95,53 +91,28 @@ impl Client { }); //Recver + let reciever_client = client.clone(); thread::spawn(move || { - let mut pollers = vec![ - libc::pollfd { - fd: recver_fd, - events: libc::POLLIN, - revents: 0, - }, - libc::pollfd { - fd, - events: libc::POLLIN, - revents: 0, - }, - ]; + loop { - let returned = unsafe { - let pollers: &mut [libc::pollfd] = &mut pollers; - libc::poll( - pollers as *mut _ as *mut libc::pollfd, - pollers.len() as _, - -1, - ) - }; - - if returned == -1 { - let err = io::Error::last_os_error(); - if err.raw_os_error() == Some(libc::EINTR) { + + match reciever_client.ready() { + Ok(None) => { continue; } - - error!("fatal error in process reaper:{}", err); - break; - } else if returned < 1 { - continue; - } - - if pollers[0].revents != 0 { - break; - } - - if pollers[pollers.len() - 1].revents == 0 { - continue; + Ok(_) => {} + Err(e) => { + error!("pipeConnection ready error {:?}", e); + break; + } } - let mh; let buf; - match read_message(fd) { + + let pipe_connection = reciever_client.get_pipe_connection(); + + match read_message(&pipe_connection) { Ok((x, y)) => { mh = x; buf = y; @@ -190,10 +161,9 @@ impl Client { map.remove(&mh.stream_id); } - let _ = close(recver_fd).map_err(|e| { + let _ = reciever_client.close_receiver().map_err(|e| { warn!( - "failed to close recver_fd: {} with error: {:?}", - recver_fd, e + "failed to close with error: {:?}", e ) }); @@ -201,9 +171,8 @@ impl Client { }); Client { - _fd: fd, + _fd: client, sender_tx, - _client_close: client_close, } } pub fn request(&self, req: Request) -> Result { @@ -239,15 +208,9 @@ impl Client { } } -struct ClientClose { - fd: RawFd, - close_fd: RawFd, -} - -impl Drop for ClientClose { +impl Drop for ClientConnection { fn drop(&mut self) { - close(self.close_fd).unwrap(); - close(self.fd).unwrap(); - trace!("All client is droped"); + self.close().unwrap(); + trace!("All client is dropped"); } } diff --git a/src/sync/mod.rs b/src/sync/mod.rs index 2830460b..53d0680c 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -8,6 +8,7 @@ mod channel; mod client; mod server; +mod sys; #[macro_use] mod utils; diff --git a/src/sync/server.rs b/src/sync/server.rs index 0a0e6e3f..6faf99d6 100644 --- a/src/sync/server.rs +++ b/src/sync/server.rs @@ -13,27 +13,26 @@ // limitations under the License. //! Sync server of ttrpc. +//! + +#[cfg(target_os = "linux")] +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; -use nix::sys::socket::{self, *}; -use nix::unistd::*; use protobuf::{CodedInputStream, Message}; use std::collections::HashMap; -use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::mpsc::{channel, sync_channel, Receiver, Sender, SyncSender}; use std::sync::{Arc, Mutex}; use std::thread::JoinHandle; -use std::{io, thread}; +use std::{thread}; use super::utils::response_to_channel; -use crate::common; -#[cfg(not(any(target_os = "linux", target_os = "android")))] -use crate::common::set_fd_close_exec; use crate::context; use crate::error::{get_status, Error, Result}; use crate::proto::{Code, MessageHeader, Request, Response, MESSAGE_TYPE_REQUEST}; use crate::sync::channel::{read_message, write_message}; use crate::{MethodHandler, TtrpcContext}; +use crate::sync::sys::{PipeListener, PipeConnection}; // poll_queue will create WAIT_THREAD_COUNT_DEFAULT threads in begin. // If wait thread count < WAIT_THREAD_COUNT_MIN, create number to WAIT_THREAD_COUNT_DEFAULT. @@ -47,10 +46,9 @@ type MessageReceiver = Receiver<(MessageHeader, Vec)>; /// A ttrpc Server (sync). pub struct Server { - listeners: Vec, - monitor_fd: (RawFd, RawFd), + listeners: Vec>, listener_quit_flag: Arc, - connections: Arc>>, + connections: Arc>>, methods: Arc>>, handler: Option>, reaper: Option<(Sender, JoinHandle<()>)>, @@ -59,22 +57,30 @@ pub struct Server { thread_count_max: usize, } -struct Connection { - fd: RawFd, +struct Connection + { + fd: Arc, quit: Arc, handler: Option>, } -impl Connection { - fn close(&self) { +impl Connection + { + fn close (&self) { + self.fd.close().unwrap_or(()); + } + + fn shutdown(&self) { self.quit.store(true, Ordering::SeqCst); + // in case the connection had closed - socket::shutdown(self.fd, Shutdown::Read).unwrap_or(()); + self.fd.shutdown().unwrap_or(()); } } -struct ThreadS<'a> { - fd: RawFd, +struct ThreadS<'a> +{ + fd: &'a Arc, fdlock: &'a Arc>, wtc: &'a Arc, quit: &'a Arc, @@ -88,7 +94,7 @@ struct ThreadS<'a> { #[allow(clippy::too_many_arguments)] fn start_method_handler_thread( - fd: RawFd, + fd: Arc, fdlock: Arc>, wtc: Arc, quit: Arc, @@ -116,7 +122,7 @@ fn start_method_handler_thread( .unwrap_or_else(|err| trace!("Failed to send {:?}", err)); break; } - result = read_message(fd); + result = read_message(&fd); } if quit.load(Ordering::SeqCst) { @@ -207,7 +213,7 @@ fn start_method_handler_thread( continue; }; let ctx = TtrpcContext { - fd, + fd: fd.id(), mh, res_tx: res_tx.clone(), metadata: context::from_pb(&req.metadata), @@ -228,13 +234,14 @@ fn start_method_handler_thread( }); } -fn start_method_handler_threads(num: usize, ts: &ThreadS) { +fn start_method_handler_threads(num: usize, ts: &ThreadS) + { for _ in 0..num { if ts.quit.load(Ordering::SeqCst) { break; } start_method_handler_thread( - ts.fd, + ts.fd.clone(), ts.fdlock.clone(), ts.wtc.clone(), ts.quit.clone(), @@ -247,7 +254,8 @@ fn start_method_handler_threads(num: usize, ts: &ThreadS) { } } -fn check_method_handler_threads(ts: &ThreadS) { +fn check_method_handler_threads(ts: &ThreadS) + { let c = ts.wtc.load(Ordering::SeqCst); if c < ts.min { start_method_handler_threads(ts.default - c, ts); @@ -258,7 +266,6 @@ impl Default for Server { fn default() -> Self { Server { listeners: Vec::with_capacity(1), - monitor_fd: (-1, -1), listener_quit_flag: Arc::new(AtomicBool::new(false)), connections: Arc::new(Mutex::new(HashMap::new())), methods: Arc::new(HashMap::new()), @@ -283,15 +290,22 @@ impl Server { )); } - let (fd, _) = common::do_bind(sockaddr)?; - common::do_listen(fd)?; + let listener = PipeListener::new(sockaddr)?; - self.listeners.push(fd); + self.listeners.push(Arc::new(listener)); Ok(self) } pub fn add_listener(mut self, fd: RawFd) -> Result { - self.listeners.push(fd); + if !self.listeners.is_empty() { + return Err(Error::Others( + "ttrpc-rust just support 1 sockaddr now".to_string(), + )); + } + + let listener = PipeListener::new_from_fd(fd)?; + + self.listeners.push(Arc::new(listener)); Ok(self) } @@ -329,27 +343,14 @@ impl Server { self.listener_quit_flag.store(false, Ordering::SeqCst); - #[cfg(any(target_os = "linux", target_os = "android"))] - let fds = pipe2(nix::fcntl::OFlag::O_CLOEXEC)?; - - #[cfg(not(any(target_os = "linux", target_os = "android")))] - let fds = { - let (rfd, wfd) = pipe()?; - set_fd_close_exec(rfd)?; - set_fd_close_exec(wfd)?; - (rfd, wfd) - }; - - self.monitor_fd = fds; - - let listener = self.listeners[0]; + + let listener = self.listeners[0].clone(); let methods = self.methods.clone(); let default = self.thread_count_default; let min = self.thread_count_min; let max = self.thread_count_max; let listener_quit_flag = self.listener_quit_flag.clone(); - let monitor_fd = self.monitor_fd.0; let reaper_tx = match self.reaper.take() { None => { @@ -366,7 +367,7 @@ impl Server { .map(|mut cn| { cn.handler.take().map(|handler| { handler.join().unwrap(); - close(fd).unwrap(); + cn.close() }) }); } @@ -386,86 +387,29 @@ impl Server { let handler = thread::Builder::new() .name("listener_loop".into()) .spawn(move || { - let mut pollers = vec![ - libc::pollfd { - fd: monitor_fd, - events: libc::POLLIN, - revents: 0, - }, - libc::pollfd { - fd: listener, - events: libc::POLLIN, - revents: 0, - }, - ]; - - loop { - if listener_quit_flag.load(Ordering::SeqCst) { - info!("listener shutdown for quit flag"); - break; - } - - let returned = unsafe { - let pollers: &mut [libc::pollfd] = &mut pollers; - libc::poll( - pollers as *mut _ as *mut libc::pollfd, - pollers.len() as _, - -1, - ) - }; + + let listener = listener; - if returned == -1 { - let err = io::Error::last_os_error(); - if err.raw_os_error() == Some(libc::EINTR) { + loop { + let pipe_connection = match listener.accept(&listener_quit_flag) { + Ok(None) => { continue; } - - error!("fatal error in listener_loop:{:?}", err); - break; - } else if returned < 1 { - continue; - } - - if pollers[0].revents != 0 || pollers[pollers.len() - 1].revents == 0 { - continue; - } - - if listener_quit_flag.load(Ordering::SeqCst) { - info!("listener shutdown for quit flag"); - break; - } - - #[cfg(any(target_os = "linux", target_os = "android"))] - let fd = match accept4(listener, SockFlag::SOCK_CLOEXEC) { - Ok(fd) => fd, - Err(e) => { - error!("failed to accept error {:?}", e); - break; - } - }; - - // Non Linux platforms do not support accept4 with SOCK_CLOEXEC flag, so instead - // use accept and call fcntl separately to set SOCK_CLOEXEC. - // Because of this there is chance of the descriptor leak if fork + exec happens in between. - #[cfg(not(any(target_os = "linux", target_os = "android")))] - let fd = match accept(listener) { - Ok(fd) => { - if let Err(err) = set_fd_close_exec(fd) { - error!("fcntl failed after accept: {:?}", err); - break; - }; - fd + Ok(Some(conn)) => { + Arc::new(conn) } Err(e) => { - error!("failed to accept error {:?}", e); + error!("listener accept got {:?}", e); break; } }; + let methods = methods.clone(); let quit = Arc::new(AtomicBool::new(false)); let child_quit = quit.clone(); let reaper_tx_child = reaper_tx.clone(); + let pipe_connection_child = pipe_connection.clone(); let handler = thread::Builder::new() .name("client_handler".into()) @@ -473,11 +417,12 @@ impl Server { debug!("Got new client"); // Start response thread let quit_res = child_quit.clone(); + let pipe = pipe_connection_child.clone(); let (res_tx, res_rx): (MessageSender, MessageReceiver) = channel(); let handler = thread::spawn(move || { for r in res_rx.iter() { trace!("response thread get {:?}", r); - if let Err(e) = write_message(fd, r.0, r.1) { + if let Err(e) = write_message(&pipe, r.0, r.1) { error!("write_message got {:?}", e); quit_res.store(true, Ordering::SeqCst); break; @@ -487,10 +432,11 @@ impl Server { trace!("response thread quit"); }); + let pipe = pipe_connection_child.clone(); let (control_tx, control_rx): (SyncSender<()>, Receiver<()>) = sync_channel(0); let ts = ThreadS { - fd, + fd: &pipe, fdlock: &Arc::new(Mutex::new(())), wtc: &Arc::new(AtomicUsize::new(0)), methods: &methods, @@ -517,17 +463,18 @@ impl Server { handler.join().unwrap_or(()); // client_handler should not close fd before exit // , which prevent fd reuse issue. - reaper_tx_child.send(fd).unwrap(); + reaper_tx_child.send(pipe.id()).unwrap(); debug!("client thread quit"); }) .unwrap(); let mut cns = connections.lock().unwrap(); + cns.insert( - fd, + pipe_connection.id(), Connection { - fd, + fd: pipe_connection, handler: Some(handler), quit: quit.clone(), }, @@ -563,12 +510,9 @@ impl Server { pub fn stop_listen(mut self) -> Self { self.listener_quit_flag.store(true, Ordering::SeqCst); - close(self.monitor_fd.1).unwrap_or_else(|e| { - warn!( - "failed to close notify fd: {} with error: {}", - self.monitor_fd.1, e - ) - }); + + self.listeners[0].close().unwrap(); + info!("close monitor"); if let Some(handler) = self.handler.take() { handler.join().unwrap(); @@ -582,7 +526,7 @@ impl Server { let connections = self.connections.lock().unwrap(); for (_fd, c) in connections.iter() { - c.close(); + c.shutdown(); } // release connections's lock, since the following handler.join() // would wait on the other thread's exit in which would take the lock. @@ -601,14 +545,16 @@ impl Server { } } +#[cfg(target_os = "linux")] impl FromRawFd for Server { unsafe fn from_raw_fd(fd: RawFd) -> Self { Self::default().add_listener(fd).unwrap() } } +#[cfg(target_os = "linux")] impl AsRawFd for Server { fn as_raw_fd(&self) -> RawFd { - self.listeners[0] + self.listeners[0].as_raw_fd() } } diff --git a/src/sync/sys/mod.rs b/src/sync/sys/mod.rs new file mode 100644 index 00000000..f0e91790 --- /dev/null +++ b/src/sync/sys/mod.rs @@ -0,0 +1,5 @@ +#[cfg(unix)] +mod unix; +#[cfg(unix)] +pub use crate::sync::sys::unix::{PipeConnection, PipeListener, ClientConnection}; + diff --git a/src/sync/sys/unix/mod.rs b/src/sync/sys/unix/mod.rs new file mode 100644 index 00000000..bc36d736 --- /dev/null +++ b/src/sync/sys/unix/mod.rs @@ -0,0 +1,2 @@ +mod net; +pub use net::{PipeConnection, PipeListener, ClientConnection}; diff --git a/src/sync/sys/unix/net.rs b/src/sync/sys/unix/net.rs new file mode 100644 index 00000000..eb29dfb6 --- /dev/null +++ b/src/sync/sys/unix/net.rs @@ -0,0 +1,334 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +use crate::error::Result; +use nix::sys::socket::*; +use std::io::{self}; +use std::os::unix::io::RawFd; +use std::os::unix::prelude::AsRawFd; +use nix::Error; + +use nix::unistd::*; +use std::sync::{Arc}; +use std::sync::atomic::{AtomicBool, Ordering}; +use crate::common::{self, client_connect, SOCK_CLOEXEC}; +#[cfg(target_os = "macos")] +use crate::common::set_fd_close_exec; +use nix::sys::socket::{self}; + +pub struct PipeListener { + fd: RawFd, + monitor_fd: (RawFd, RawFd), +} + +impl AsRawFd for PipeListener { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +impl PipeListener { + pub(crate) fn new(sockaddr: &str) -> Result { + let (fd, _) = common::do_bind(sockaddr)?; + common::do_listen(fd)?; + + let fds = PipeListener::new_monitor_fd()?; + + Ok(PipeListener { + fd, + monitor_fd: fds, + }) + } + + pub(crate) fn new_from_fd(fd: RawFd) -> Result { + let fds = PipeListener::new_monitor_fd()?; + + Ok(PipeListener { + fd, + monitor_fd: fds, + }) + } + + fn new_monitor_fd() -> Result<(i32, i32)> { + #[cfg(target_os = "linux")] + let fds = pipe2(nix::fcntl::OFlag::O_CLOEXEC)?; + + + #[cfg(target_os = "macos")] + let fds = { + let (rfd, wfd) = pipe()?; + set_fd_close_exec(rfd)?; + set_fd_close_exec(wfd)?; + (rfd, wfd) + }; + + Ok(fds) + } + + // accept returns: + // - Ok(Some(PipeConnection)) if a new connection is established + // - Ok(None) if spurious wake up with no new connection + // - Err(io::Error) if there is an error and listener loop should be shutdown + pub(crate) fn accept( &self, quit_flag: &Arc) -> std::result::Result, io::Error> { + if quit_flag.load(Ordering::SeqCst) { + return Err(io::Error::new(io::ErrorKind::Other, "listener shutdown for quit flag")); + } + + let mut pollers = vec![ + libc::pollfd { + fd: self.monitor_fd.0, + events: libc::POLLIN, + revents: 0, + }, + libc::pollfd { + fd: self.fd, + events: libc::POLLIN, + revents: 0, + }, + ]; + + let returned = unsafe { + let pollers: &mut [libc::pollfd] = &mut pollers; + libc::poll( + pollers as *mut _ as *mut libc::pollfd, + pollers.len() as _, + -1, + ) + }; + + if returned == -1 { + let err = io::Error::last_os_error(); + if err.raw_os_error() == Some(libc::EINTR) { + return Err(err); + } + + error!("fatal error in listener_loop:{:?}", err); + return Err(err); + } else if returned < 1 { + return Ok(None) + } + + if pollers[0].revents != 0 || pollers[pollers.len() - 1].revents == 0 { + return Ok(None); + } + + if quit_flag.load(Ordering::SeqCst) { + return Err(io::Error::new(io::ErrorKind::Other, "listener shutdown for quit flag")); + } + + #[cfg(target_os = "linux")] + let fd = match accept4(self.fd, SockFlag::SOCK_CLOEXEC) { + Ok(fd) => fd, + Err(e) => { + error!("failed to accept error {:?}", e); + return Err(std::io::Error::from_raw_os_error(e as i32)); + } + }; + + // Non Linux platforms do not support accept4 with SOCK_CLOEXEC flag, so instead + // use accept and call fcntl separately to set SOCK_CLOEXEC. + // Because of this there is chance of the descriptor leak if fork + exec happens in between. + #[cfg(target_os = "macos")] + let fd = match accept(self.fd) { + Ok(fd) => { + if let Err(err) = set_fd_close_exec(fd) { + error!("fcntl failed after accept: {:?}", err); + return Err(io::Error::new(io::ErrorKind::Other, format!("{err:?}"))); + }; + fd + } + Err(e) => { + error!("failed to accept error {:?}", e); + return Err(std::io::Error::from_raw_os_error(e as i32)); + } + }; + + + Ok(Some(PipeConnection { fd })) + } + + pub fn close(&self) -> Result<()> { + close(self.monitor_fd.1).unwrap_or_else(|e| { + warn!( + "failed to close notify fd: {} with error: {}", + self.monitor_fd.1, e + ) + }); + Ok(()) + } +} + + +pub struct PipeConnection { + fd: RawFd, +} + +impl PipeConnection { + pub(crate) fn new(fd: RawFd) -> PipeConnection { + PipeConnection { fd } + } + + pub(crate) fn id(&self) -> i32 { + self.fd + } + + pub fn read(&self, buf: &mut [u8]) -> Result { + loop { + match recv(self.fd, buf, MsgFlags::empty()) { + Ok(l) => return Ok(l), + Err(e) if retryable(e) => { + // Should retry + continue; + } + Err(e) => { + return Err(crate::Error::Nix(e)); + } + } + } + } + + pub fn write(&self, buf: &[u8]) -> Result { + loop { + match send(self.fd, buf, MsgFlags::empty()) { + Ok(l) => return Ok(l), + Err(e) if retryable(e) => { + // Should retry + continue; + } + Err(e) => { + return Err(crate::Error::Nix(e)); + } + } + } + } + + pub fn close(&self) -> Result<()> { + match close(self.fd) { + Ok(_) => Ok(()), + Err(e) => Err(crate::Error::Nix(e)) + } + } + + pub fn shutdown(&self) -> Result<()> { + match socket::shutdown(self.fd, Shutdown::Read) { + Ok(_) => Ok(()), + Err(e) => Err(crate::Error::Nix(e)) + } + } +} + +pub struct ClientConnection { + fd: RawFd, + socket_pair: (RawFd, RawFd), +} + +impl ClientConnection { + pub fn client_connect(sockaddr: &str)-> Result { + let fd = unsafe { client_connect(sockaddr)? }; + Ok(ClientConnection::new(fd)) + } + + pub(crate) fn new(fd: RawFd) -> ClientConnection { + let (recver_fd, close_fd) = + socketpair(AddressFamily::Unix, SockType::Stream, None, SOCK_CLOEXEC).unwrap(); + + // MacOS doesn't support descriptor creation with SOCK_CLOEXEC automically, + // so there is a chance of leak if fork + exec happens in between of these calls. + #[cfg(target_os = "macos")] + { + set_fd_close_exec(recver_fd).unwrap(); + set_fd_close_exec(close_fd).unwrap(); + } + + + ClientConnection { + fd, + socket_pair: (recver_fd, close_fd) + } + } + + pub fn ready(&self) -> std::result::Result, io::Error> { + let mut pollers = vec![ + libc::pollfd { + fd: self.socket_pair.0, + events: libc::POLLIN, + revents: 0, + }, + libc::pollfd { + fd: self.fd, + events: libc::POLLIN, + revents: 0, + }, + ]; + + let returned = unsafe { + let pollers: &mut [libc::pollfd] = &mut pollers; + libc::poll( + pollers as *mut _ as *mut libc::pollfd, + pollers.len() as _, + -1, + ) + }; + + if returned == -1 { + let err = io::Error::last_os_error(); + if err.raw_os_error() == Some(libc::EINTR) { + return Ok(None) + } + + error!("fatal error in process reaper:{}", err); + return Err(err); + } else if returned < 1 { + return Ok(None) + } + + if pollers[0].revents != 0 { + return Err(io::Error::new(io::ErrorKind::Other, "pipe closed")); + } + + if pollers[pollers.len() - 1].revents == 0 { + return Ok(None) + } + + Ok(Some(())) + } + + pub fn get_pipe_connection(&self) -> PipeConnection { + PipeConnection::new(self.fd) + } + + pub fn close_receiver(&self) -> Result<()> { + match close(self.socket_pair.0) { + Ok(_) => Ok(()), + Err(e) => Err(crate::Error::Nix(e)) + } + } + + pub fn close(&self) -> Result<()> { + match close(self.socket_pair.1) { + Ok(_) => {}, + Err(e) => return Err(crate::Error::Nix(e)) + }; + + match close(self.fd) { + Ok(_) => Ok(()), + Err(e) => Err(crate::Error::Nix(e)) + } + } +} + +fn retryable(e: nix::Error) -> bool { + e == Error::EINTR || e == Error::EAGAIN +} From 6f6b9f1291f3b28d2be6f2020a85ad4a67db10ee Mon Sep 17 00:00:00 2001 From: James Sturtevant Date: Wed, 1 Mar 2023 14:53:40 -0800 Subject: [PATCH 3/4] Add Windows Implementation for sync server and client Adds the windows functionality for PipeListener, PipeConnection, and ClientConnection and does the few other changes required to build and run the example projects. This includes adding feature support to the examples so they wouldn't build the async projects (as the unix specific code hasn't been removed yet). Namedpipes are used as Containerd is one of the main use cases for this project on Windows and containerd only supports namedpipes. Signed-off-by: James Sturtevant --- .github/workflows/bvt.yml | 4 +- Cargo.toml | 7 +- Makefile | 7 +- example/async-client.rs | 9 +- example/async-server.rs | 13 +- example/async-stream-client.rs | 15 +- example/async-stream-server.rs | 11 ++ example/build.rs | 4 +- example/protocols/mod.rs | 5 +- example/utils.rs | 20 ++- src/asynchronous/client.rs | 2 +- src/asynchronous/stream.rs | 2 +- src/common.rs | 23 +-- src/error.rs | 5 + src/lib.rs | 3 + src/macros.rs | 26 +++ src/sync/channel.rs | 30 ++-- src/sync/client.rs | 68 ++++--- src/sync/server.rs | 58 +++--- src/sync/sys/mod.rs | 4 + src/sync/sys/windows/mod.rs | 2 + src/sync/sys/windows/net.rs | 313 +++++++++++++++++++++++++++++++++ src/sync/utils.rs | 3 + 23 files changed, 521 insertions(+), 113 deletions(-) create mode 100644 src/macros.rs create mode 100644 src/sync/sys/windows/mod.rs create mode 100644 src/sync/sys/windows/net.rs diff --git a/.github/workflows/bvt.yml b/.github/workflows/bvt.yml index 1b5eff31..e9917550 100644 --- a/.github/workflows/bvt.yml +++ b/.github/workflows/bvt.yml @@ -6,7 +6,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macos-latest] + os: [ubuntu-latest, macos-latest, windows-latest] steps: - name: Checkout uses: actions/checkout@v3 @@ -22,7 +22,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macos-latest] + os: [ubuntu-latest, macos-latest, windows-latest] steps: - name: Checkout uses: actions/checkout@v3 diff --git a/Cargo.toml b/Cargo.toml index 2ce7f5f1..897e4d89 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,17 +17,22 @@ nix = "0.23.0" log = "0.4" byteorder = "1.3.2" thiserror = "1.0" - async-trait = { version = "0.1.31", optional = true } tokio = { version = "1", features = ["rt", "sync", "io-util", "macros", "time"], optional = true } futures = { version = "0.3", optional = true } +[target.'cfg(windows)'.dependencies] +windows-sys = {version = "0.45", features = [ "Win32_Foundation", "Win32_Storage_FileSystem", "Win32_System_IO", "Win32_System_Pipes", "Win32_Security", "Win32_System_Threading"]} + [target.'cfg(any(target_os = "linux", target_os = "android"))'.dependencies] tokio-vsock = { version = "0.3.1", optional = true } [build-dependencies] protobuf-codegen = "3.1.0" +[dev-dependencies] +assert_cmd = "2.0.7" + [features] default = ["sync"] async = ["async-trait", "tokio", "futures", "tokio-vsock"] diff --git a/Makefile b/Makefile index 703c4abf..b0b3082e 100644 --- a/Makefile +++ b/Makefile @@ -23,8 +23,13 @@ build: debug .PHONY: test test: +ifeq ($OS,Windows_NT) + # async isn't enabled for windows, don't test that feature + cargo test --verbose +else cargo test --all-features --verbose - +endif + .PHONY: check check: cargo fmt --all -- --check diff --git a/example/async-client.rs b/example/async-client.rs index b4d752fc..4c4b7e9f 100644 --- a/example/async-client.rs +++ b/example/async-client.rs @@ -5,11 +5,18 @@ mod protocols; mod utils; - +#[cfg(unix)] use protocols::r#async::{agent, agent_ttrpc, health, health_ttrpc}; use ttrpc::context::{self, Context}; +#[cfg(unix)] use ttrpc::r#async::Client; +#[cfg(windows)] +fn main() { + println!("This example only works on Unix-like OSes"); +} + +#[cfg(unix)] #[tokio::main(flavor = "current_thread")] async fn main() { let c = Client::connect(utils::SOCK_ADDR).unwrap(); diff --git a/example/async-server.rs b/example/async-server.rs index 81bf048b..1ee2fb79 100644 --- a/example/async-server.rs +++ b/example/async-server.rs @@ -13,17 +13,22 @@ use std::sync::Arc; use log::LevelFilter; +#[cfg(unix)] use protocols::r#async::{agent, agent_ttrpc, health, health_ttrpc, types}; +#[cfg(unix)] use ttrpc::asynchronous::Server; use ttrpc::error::{Error, Result}; use ttrpc::proto::{Code, Status}; +#[cfg(unix)] use async_trait::async_trait; +#[cfg(unix)] use tokio::signal::unix::{signal, SignalKind}; use tokio::time::sleep; struct HealthService; +#[cfg(unix)] #[async_trait] impl health_ttrpc::Health for HealthService { async fn check( @@ -58,7 +63,7 @@ impl health_ttrpc::Health for HealthService { } struct AgentService; - +#[cfg(unix)] #[async_trait] impl agent_ttrpc::AgentService for AgentService { async fn list_interfaces( @@ -82,6 +87,12 @@ impl agent_ttrpc::AgentService for AgentService { } } +#[cfg(windows)] +fn main() { + println!("This example only works on Unix-like OSes"); +} + +#[cfg(unix)] #[tokio::main(flavor = "current_thread")] async fn main() { simple_logging::log_to_stderr(LevelFilter::Trace); diff --git a/example/async-stream-client.rs b/example/async-stream-client.rs index ba953596..4ba2b0c3 100644 --- a/example/async-stream-client.rs +++ b/example/async-stream-client.rs @@ -5,11 +5,18 @@ mod protocols; mod utils; - +#[cfg(unix)] use protocols::r#async::{empty, streaming, streaming_ttrpc}; use ttrpc::context::{self, Context}; +#[cfg(unix)] use ttrpc::r#async::Client; +#[cfg(windows)] +fn main() { + println!("This example only works on Unix-like OSes"); +} + +#[cfg(unix)] #[tokio::main(flavor = "current_thread")] async fn main() { simple_logging::log_to_stderr(log::LevelFilter::Info); @@ -48,6 +55,7 @@ fn default_ctx() -> Context { ctx } +#[cfg(unix)] async fn echo_request(cli: streaming_ttrpc::StreamingClient) { let echo1 = streaming::EchoPayload { seq: 1, @@ -59,6 +67,7 @@ async fn echo_request(cli: streaming_ttrpc::StreamingClient) { assert_eq!(resp.seq, echo1.seq + 1); } +#[cfg(unix)] async fn echo_stream(cli: streaming_ttrpc::StreamingClient) { let mut stream = cli.echo_stream(default_ctx()).await.unwrap(); @@ -81,6 +90,7 @@ async fn echo_stream(cli: streaming_ttrpc::StreamingClient) { assert!(matches!(ret, Err(ttrpc::Error::Eof))); } +#[cfg(unix)] async fn sum_stream(cli: streaming_ttrpc::StreamingClient) { let mut stream = cli.sum_stream(default_ctx()).await.unwrap(); @@ -108,6 +118,7 @@ async fn sum_stream(cli: streaming_ttrpc::StreamingClient) { assert_eq!(ssum.num, sum.num); } +#[cfg(unix)] async fn divide_stream(cli: streaming_ttrpc::StreamingClient) { let expected = streaming::Sum { sum: 392, @@ -127,6 +138,7 @@ async fn divide_stream(cli: streaming_ttrpc::StreamingClient) { assert_eq!(actual.num, expected.num); } +#[cfg(unix)] async fn echo_null(cli: streaming_ttrpc::StreamingClient) { let mut stream = cli.echo_null(default_ctx()).await.unwrap(); @@ -142,6 +154,7 @@ async fn echo_null(cli: streaming_ttrpc::StreamingClient) { assert_eq!(res, empty::Empty::new()); } +#[cfg(unix)] async fn echo_null_stream(cli: streaming_ttrpc::StreamingClient) { let stream = cli.echo_null_stream(default_ctx()).await.unwrap(); diff --git a/example/async-stream-server.rs b/example/async-stream-server.rs index d828e2d3..0404cfbc 100644 --- a/example/async-stream-server.rs +++ b/example/async-stream-server.rs @@ -10,15 +10,20 @@ use std::sync::Arc; use log::{info, LevelFilter}; +#[cfg(unix)] use protocols::r#async::{empty, streaming, streaming_ttrpc}; +#[cfg(unix)] use ttrpc::asynchronous::Server; +#[cfg(unix)] use async_trait::async_trait; +#[cfg(unix)] use tokio::signal::unix::{signal, SignalKind}; use tokio::time::sleep; struct StreamingService; +#[cfg(unix)] #[async_trait] impl streaming_ttrpc::Streaming for StreamingService { async fn echo( @@ -131,6 +136,12 @@ impl streaming_ttrpc::Streaming for StreamingService { } } +#[cfg(windows)] +fn main() { + println!("This example only works on Unix-like OSes"); +} + +#[cfg(unix)] #[tokio::main(flavor = "current_thread")] async fn main() { simple_logging::log_to_stderr(LevelFilter::Info); diff --git a/example/build.rs b/example/build.rs index cd9d5e50..a90274b4 100644 --- a/example/build.rs +++ b/example/build.rs @@ -45,7 +45,7 @@ fn main() { async_all: true, ..Default::default() }) - .rust_protobuf_customize(protobuf_customized.clone()) + .rust_protobuf_customize(protobuf_customized) .run() .expect("Gen async code failed."); @@ -75,7 +75,7 @@ fn replace_text_in_file(file_name: &str, from: &str, to: &str) -> Result<(), std let new_contents = contents.replace(from, to); - let mut dst = File::create(&file_name)?; + let mut dst = File::create(file_name)?; dst.write(new_contents.as_bytes())?; Ok(()) diff --git a/example/protocols/mod.rs b/example/protocols/mod.rs index b81f3d7d..d3d3a275 100644 --- a/example/protocols/mod.rs +++ b/example/protocols/mod.rs @@ -2,7 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 // - +#[cfg(unix)] pub mod asynchronous; -pub mod sync; +#[cfg(unix)] pub use asynchronous as r#async; +pub mod sync; diff --git a/example/utils.rs b/example/utils.rs index 24c6393f..816c430a 100644 --- a/example/utils.rs +++ b/example/utils.rs @@ -1,18 +1,28 @@ #![allow(dead_code)] -use std::fs; use std::io::Result; -use std::path::Path; -pub const SOCK_ADDR: &str = "unix:///tmp/ttrpc-test"; +#[cfg(unix)] +pub const SOCK_ADDR: &str = r"unix:///tmp/ttrpc-test"; +#[cfg(windows)] +pub const SOCK_ADDR: &str = r"\\.\pipe\ttrpc-test"; + +#[cfg(unix)] pub fn remove_if_sock_exist(sock_addr: &str) -> Result<()> { let path = sock_addr .strip_prefix("unix://") .expect("socket address is not expected"); - if Path::new(path).exists() { - fs::remove_file(&path)?; + if std::path::Path::new(path).exists() { + std::fs::remove_file(path)?; } Ok(()) } + +#[cfg(windows)] +pub fn remove_if_sock_exist(_sock_addr: &str) -> Result<()> { + //todo force close file handle? + + Ok(()) +} diff --git a/src/asynchronous/client.rs b/src/asynchronous/client.rs index 8ba9a890..4476779c 100644 --- a/src/asynchronous/client.rs +++ b/src/asynchronous/client.rs @@ -294,7 +294,7 @@ impl ReaderDelegate for ClientReader { }; resp_tx .send(Err(Error::Others(format!( - "Recver got malformed packet {msg:?}" + "Receiver got malformed packet {msg:?}" )))) .await .unwrap_or_else(|_e| error!("The request has returned")); diff --git a/src/asynchronous/stream.rs b/src/asynchronous/stream.rs index 1b72be29..7259c50a 100644 --- a/src/asynchronous/stream.rs +++ b/src/asynchronous/stream.rs @@ -311,7 +311,7 @@ where async fn _recv(rx: &mut ResultReceiver) -> Result { rx.recv().await.unwrap_or_else(|| { Err(Error::Others( - "Receive packet from recver error".to_string(), + "Receive packet from Receiver error".to_string(), )) }) } diff --git a/src/common.rs b/src/common.rs index 9c2ec40f..a3783852 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,9 +1,10 @@ +#![cfg(not(windows))] // Copyright (c) 2020 Ant Financial // // SPDX-License-Identifier: Apache-2.0 // -//! Common functions and macros. +//! Common functions. use crate::error::{Error, Result}; #[cfg(any( @@ -173,26 +174,6 @@ pub(crate) unsafe fn client_connect(sockaddr: &str) -> Result { Ok(fd) } -macro_rules! cfg_sync { - ($($item:item)*) => { - $( - #[cfg(feature = "sync")] - #[cfg_attr(docsrs, doc(cfg(feature = "sync")))] - $item - )* - } -} - -macro_rules! cfg_async { - ($($item:item)*) => { - $( - #[cfg(feature = "async")] - #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - $item - )* - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/error.rs b/src/error.rs index 4f7e2e3a..5c66e190 100644 --- a/src/error.rs +++ b/src/error.rs @@ -27,9 +27,14 @@ pub enum Error { #[error("rpc status: {0:?}")] RpcStatus(Status), + #[cfg(unix)] #[error("Nix error: {0}")] Nix(#[from] nix::Error), + #[cfg(windows)] + #[error("Windows error: {0}")] + Windows(i32), + #[error("ttrpc err: local stream closed")] LocalClosed, diff --git a/src/lib.rs b/src/lib.rs index 4f913d44..cd4872a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,6 +49,9 @@ pub mod error; #[macro_use] mod common; +#[macro_use] +mod macros; + pub mod context; pub mod proto; diff --git a/src/macros.rs b/src/macros.rs new file mode 100644 index 00000000..7695a157 --- /dev/null +++ b/src/macros.rs @@ -0,0 +1,26 @@ +// Copyright (c) 2020 Ant Financial +// +// SPDX-License-Identifier: Apache-2.0 +// + +//! macro functions. + +macro_rules! cfg_sync { + ($($item:item)*) => { + $( + #[cfg(feature = "sync")] + #[cfg_attr(docsrs, doc(cfg(feature = "sync")))] + $item + )* + } +} + +macro_rules! cfg_async { + ($($item:item)*) => { + $( + #[cfg(all(feature = "async", target_family="unix"))] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + $item + )* + } +} diff --git a/src/sync/channel.rs b/src/sync/channel.rs index e9e341ca..18070b89 100644 --- a/src/sync/channel.rs +++ b/src/sync/channel.rs @@ -14,10 +14,10 @@ use crate::error::{get_rpc_status, sock_error_msg, Error, Result}; -use crate::sync::sys::{PipeConnection}; +use crate::sync::sys::PipeConnection; use crate::proto::{Code, MessageHeader, MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX}; -fn read_count (fd: &PipeConnection, count: usize) -> Result> { +fn read_count(conn: &PipeConnection, count: usize) -> Result> { let mut v: Vec = vec![0; count]; let mut len = 0; @@ -26,7 +26,7 @@ fn read_count (fd: &PipeConnection, count: usize) -> Result> { } loop { - match fd.read(&mut v[len..]) { + match conn.read(&mut v[len..]) { Ok(l) => { len += l; // when socket peer closed, it would return 0. @@ -43,7 +43,7 @@ fn read_count (fd: &PipeConnection, count: usize) -> Result> { Ok(v[0..len].to_vec()) } -fn write_count(fd: &PipeConnection, buf: &[u8], count: usize) -> Result { +fn write_count(conn: &PipeConnection, buf: &[u8], count: usize) -> Result { let mut len = 0; if count == 0 { @@ -51,7 +51,7 @@ fn write_count(fd: &PipeConnection, buf: &[u8], count: usize) -> Result { } loop { - match fd.write(&buf[len..]){ + match conn.write(&buf[len..]){ Ok(l) => { len += l; if len == count { @@ -67,8 +67,8 @@ fn write_count(fd: &PipeConnection, buf: &[u8], count: usize) -> Result { Ok(len) } -fn read_message_header(fd: &PipeConnection) -> Result { - let buf = read_count(fd, MESSAGE_HEADER_LENGTH)?; +fn read_message_header(conn: &PipeConnection) -> Result { + let buf = read_count(conn, MESSAGE_HEADER_LENGTH)?; let size = buf.len(); if size != MESSAGE_HEADER_LENGTH { return Err(sock_error_msg( @@ -82,8 +82,8 @@ fn read_message_header(fd: &PipeConnection) -> Result { Ok(mh) } -pub fn read_message(fd: &PipeConnection) -> Result<(MessageHeader, Vec)> { - let mh = read_message_header(fd)?; +pub fn read_message(conn: &PipeConnection) -> Result<(MessageHeader, Vec)> { + let mh = read_message_header(conn)?; trace!("Got Message header {:?}", mh); if mh.length > MESSAGE_LENGTH_MAX as u32 { @@ -96,7 +96,7 @@ pub fn read_message(fd: &PipeConnection) -> Result<(MessageHeader, Vec)> { )); } - let buf = read_count(fd, mh.length as usize)?; + let buf = read_count(conn, mh.length as usize)?; let size = buf.len(); if size != mh.length as usize { return Err(sock_error_msg( @@ -109,10 +109,10 @@ pub fn read_message(fd: &PipeConnection) -> Result<(MessageHeader, Vec)> { Ok((mh, buf)) } -fn write_message_header(fd: &PipeConnection, mh: MessageHeader) -> Result<()> { +fn write_message_header(conn: &PipeConnection, mh: MessageHeader) -> Result<()> { let buf: Vec = mh.into(); - let size = write_count(fd, &buf, MESSAGE_HEADER_LENGTH)?; + let size = write_count(conn, &buf, MESSAGE_HEADER_LENGTH)?; if size != MESSAGE_HEADER_LENGTH { return Err(sock_error_msg( size, @@ -123,10 +123,10 @@ fn write_message_header(fd: &PipeConnection, mh: MessageHeader) -> Result<()> { Ok(()) } -pub fn write_message(fd: &PipeConnection, mh: MessageHeader, buf: Vec) -> Result<()> { - write_message_header(fd, mh)?; +pub fn write_message(conn: &PipeConnection, mh: MessageHeader, buf: Vec) -> Result<()> { + write_message_header(conn, mh)?; - let size = write_count(fd, &buf, buf.len())?; + let size = write_count(conn, &buf, buf.len())?; if size != buf.len() { return Err(sock_error_msg( size, diff --git a/src/sync/client.rs b/src/sync/client.rs index 564b78e4..9cc117d8 100644 --- a/src/sync/client.rs +++ b/src/sync/client.rs @@ -14,18 +14,22 @@ //! Sync client of ttrpc. +#[cfg(unix)] +use std::os::unix::io::RawFd; use std::collections::HashMap; -use std::os::unix::io::RawFd; use std::sync::mpsc; use std::sync::{Arc, Mutex}; use std::{thread}; +use std::time::Duration; use crate::error::{Error, Result}; use crate::sync::sys::{ClientConnection}; use crate::proto::{Code, Codec, MessageHeader, Request, Response, MESSAGE_TYPE_RESPONSE}; use crate::sync::channel::{read_message, write_message}; -use std::time::Duration; + +#[cfg(windows)] +use super::sys::PipeConnection; type Sender = mpsc::Sender<(Vec, mpsc::SyncSender>>)>; type Receiver = mpsc::Receiver<(Vec, mpsc::SyncSender>>)>; @@ -33,7 +37,7 @@ type Receiver = mpsc::Receiver<(Vec, mpsc::SyncSender>>)>; /// A ttrpc Client (sync). #[derive(Clone)] pub struct Client { - _fd: Arc, + _connection: Arc, sender_tx: Sender, } @@ -44,6 +48,7 @@ impl Client { Ok(Self::new_client(conn)) } + #[cfg(unix)] /// Initialize a new [`Client`] from raw file descriptor. pub fn new(fd: RawFd) -> Client { let conn = ClientConnection::new(fd); @@ -54,15 +59,15 @@ impl Client { fn new_client(pipe_client: ClientConnection) -> Client { let client = Arc::new(pipe_client); - let (sender_tx, rx): (Sender, Receiver) = mpsc::channel(); - - let recver_map_orig = Arc::new(Mutex::new(HashMap::new())); + + let receiver_map = recver_map_orig.clone(); + let connection = Arc::new(client.get_pipe_connection()); + let sender_client = connection.clone(); + //Sender - let recver_map = recver_map_orig.clone(); - let sender_client = client.clone(); thread::spawn(move || { let mut stream_id: u32 = 1; for (buf, recver_tx) in rx.iter() { @@ -70,16 +75,16 @@ impl Client { stream_id += 2; //Put current_stream_id and recver_tx to recver_map { - let mut map = recver_map.lock().unwrap(); + let mut map = receiver_map.lock().unwrap(); map.insert(current_stream_id, recver_tx.clone()); } let mut mh = MessageHeader::new_request(0, buf.len() as u32); mh.set_stream_id(current_stream_id); - let c = sender_client.get_pipe_connection(); - if let Err(e) = write_message(&c, mh, buf) { + + if let Err(e) = write_message(&sender_client, mh, buf) { //Remove current_stream_id and recver_tx to recver_map { - let mut map = recver_map.lock().unwrap(); + let mut map = receiver_map.lock().unwrap(); map.remove(¤t_stream_id); } recver_tx @@ -91,13 +96,11 @@ impl Client { }); //Recver - let reciever_client = client.clone(); + let receiver_connection = connection; + let receiver_client = client.clone(); thread::spawn(move || { - - loop { - - match reciever_client.ready() { + match receiver_client.ready() { Ok(None) => { continue; } @@ -107,12 +110,10 @@ impl Client { break; } } + let mh; let buf; - - let pipe_connection = reciever_client.get_pipe_connection(); - - match read_message(&pipe_connection) { + match read_message(&receiver_connection) { Ok((x, y)) => { mh = x; buf = y; @@ -141,14 +142,14 @@ impl Client { let recver_tx = match map.get(&mh.stream_id) { Some(tx) => tx, None => { - debug!("Recver got unknown packet {:?} {:?}", mh, buf); + debug!("Receiver got unknown packet {:?} {:?}", mh, buf); continue; } }; if mh.type_ != MESSAGE_TYPE_RESPONSE { recver_tx .send(Err(Error::Others(format!( - "Recver got malformed packet {mh:?} {buf:?}" + "Receiver got malformed packet {mh:?} {buf:?}" )))) .unwrap_or_else(|_e| error!("The request has returned")); continue; @@ -161,17 +162,17 @@ impl Client { map.remove(&mh.stream_id); } - let _ = reciever_client.close_receiver().map_err(|e| { + let _ = receiver_client.close_receiver().map_err(|e| { warn!( "failed to close with error: {:?}", e ) }); - trace!("Recver quit"); + trace!("Receiver quit"); }); Client { - _fd: client, + _connection: client, sender_tx, } } @@ -186,12 +187,12 @@ impl Client { let result = if req.timeout_nano == 0 { rx.recv() - .map_err(err_to_others_err!(e, "Receive packet from recver error: "))? + .map_err(err_to_others_err!(e, "Receive packet from Receiver error: "))? } else { rx.recv_timeout(Duration::from_nanos(req.timeout_nano as u64)) .map_err(err_to_others_err!( e, - "Receive packet from recver timeout: " + "Receive packet from Receiver timeout: " ))? }; @@ -211,6 +212,15 @@ impl Client { impl Drop for ClientConnection { fn drop(&mut self) { self.close().unwrap(); - trace!("All client is dropped"); + trace!("Client is dropped"); + } +} + +// close everything up from the pipe connection on Windows +#[cfg(windows)] +impl Drop for PipeConnection { + fn drop(&mut self) { + self.close().unwrap(); + trace!("pipe connection is dropped"); } } diff --git a/src/sync/server.rs b/src/sync/server.rs index 6faf99d6..a8e5883d 100644 --- a/src/sync/server.rs +++ b/src/sync/server.rs @@ -15,7 +15,7 @@ //! Sync server of ttrpc. //! -#[cfg(target_os = "linux")] +#[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use protobuf::{CodedInputStream, Message}; @@ -57,30 +57,27 @@ pub struct Server { thread_count_max: usize, } -struct Connection - { - fd: Arc, +struct Connection { + connection: Arc, quit: Arc, handler: Option>, } -impl Connection - { +impl Connection { fn close (&self) { - self.fd.close().unwrap_or(()); + self.connection.close().unwrap_or(()); } fn shutdown(&self) { self.quit.store(true, Ordering::SeqCst); // in case the connection had closed - self.fd.shutdown().unwrap_or(()); + self.connection.shutdown().unwrap_or(()); } } -struct ThreadS<'a> -{ - fd: &'a Arc, +struct ThreadS<'a> { + connection: &'a Arc, fdlock: &'a Arc>, wtc: &'a Arc, quit: &'a Arc, @@ -94,7 +91,7 @@ struct ThreadS<'a> #[allow(clippy::too_many_arguments)] fn start_method_handler_thread( - fd: Arc, + connection: Arc, fdlock: Arc>, wtc: Arc, quit: Arc, @@ -122,7 +119,7 @@ fn start_method_handler_thread( .unwrap_or_else(|err| trace!("Failed to send {:?}", err)); break; } - result = read_message(&fd); + result = read_message(&connection); } if quit.load(Ordering::SeqCst) { @@ -213,7 +210,7 @@ fn start_method_handler_thread( continue; }; let ctx = TtrpcContext { - fd: fd.id(), + fd: connection.id(), mh, res_tx: res_tx.clone(), metadata: context::from_pb(&req.metadata), @@ -234,14 +231,13 @@ fn start_method_handler_thread( }); } -fn start_method_handler_threads(num: usize, ts: &ThreadS) - { +fn start_method_handler_threads(num: usize, ts: &ThreadS) { for _ in 0..num { if ts.quit.load(Ordering::SeqCst) { break; } start_method_handler_thread( - ts.fd.clone(), + ts.connection.clone(), ts.fdlock.clone(), ts.wtc.clone(), ts.quit.clone(), @@ -254,8 +250,7 @@ fn start_method_handler_threads(num: usize, ts: &ThreadS) } } -fn check_method_handler_threads(ts: &ThreadS) - { +fn check_method_handler_threads(ts: &ThreadS) { let c = ts.wtc.load(Ordering::SeqCst); if c < ts.min { start_method_handler_threads(ts.default - c, ts); @@ -296,6 +291,7 @@ impl Server { Ok(self) } + #[cfg(unix)] pub fn add_listener(mut self, fd: RawFd) -> Result { if !self.listeners.is_empty() { return Err(Error::Others( @@ -387,10 +383,8 @@ impl Server { let handler = thread::Builder::new() .name("listener_loop".into()) .spawn(move || { - - let listener = listener; - loop { + trace!("listening..."); let pipe_connection = match listener.accept(&listener_quit_flag) { Ok(None) => { continue; @@ -403,7 +397,6 @@ impl Server { break; } }; - let methods = methods.clone(); let quit = Arc::new(AtomicBool::new(false)); @@ -436,7 +429,7 @@ impl Server { let (control_tx, control_rx): (SyncSender<()>, Receiver<()>) = sync_channel(0); let ts = ThreadS { - fd: &pipe, + connection: &pipe, fdlock: &Arc::new(Mutex::new(())), wtc: &Arc::new(AtomicUsize::new(0)), methods: &methods, @@ -471,10 +464,11 @@ impl Server { let mut cns = connections.lock().unwrap(); + let id = pipe_connection.id(); cns.insert( - pipe_connection.id(), + id, Connection { - fd: pipe_connection, + connection: pipe_connection, handler: Some(handler), quit: quit.clone(), }, @@ -500,7 +494,7 @@ impl Server { } if self.thread_count_default <= self.thread_count_min { return Err(Error::Others( - "thread_count_default should biger than thread_count_min".to_string(), + "thread_count_default should bigger than thread_count_min".to_string(), )); } self.start_listen()?; @@ -511,7 +505,11 @@ impl Server { pub fn stop_listen(mut self) -> Self { self.listener_quit_flag.store(true, Ordering::SeqCst); - self.listeners[0].close().unwrap(); + self.listeners[0].close().unwrap_or_else(|e| { + warn!( + "failed to close connection with error: {}", e + ) + }); info!("close monitor"); if let Some(handler) = self.handler.take() { @@ -545,14 +543,14 @@ impl Server { } } -#[cfg(target_os = "linux")] +#[cfg(unix)] impl FromRawFd for Server { unsafe fn from_raw_fd(fd: RawFd) -> Self { Self::default().add_listener(fd).unwrap() } } -#[cfg(target_os = "linux")] +#[cfg(unix)] impl AsRawFd for Server { fn as_raw_fd(&self) -> RawFd { self.listeners[0].as_raw_fd() diff --git a/src/sync/sys/mod.rs b/src/sync/sys/mod.rs index f0e91790..89546c43 100644 --- a/src/sync/sys/mod.rs +++ b/src/sync/sys/mod.rs @@ -3,3 +3,7 @@ mod unix; #[cfg(unix)] pub use crate::sync::sys::unix::{PipeConnection, PipeListener, ClientConnection}; +#[cfg(windows)] +mod windows; +#[cfg(windows)] +pub use crate::sync::sys::windows::{PipeConnection, PipeListener, ClientConnection}; diff --git a/src/sync/sys/windows/mod.rs b/src/sync/sys/windows/mod.rs new file mode 100644 index 00000000..bc36d736 --- /dev/null +++ b/src/sync/sys/windows/mod.rs @@ -0,0 +1,2 @@ +mod net; +pub use net::{PipeConnection, PipeListener, ClientConnection}; diff --git a/src/sync/sys/windows/net.rs b/src/sync/sys/windows/net.rs new file mode 100644 index 00000000..4ec9d02c --- /dev/null +++ b/src/sync/sys/windows/net.rs @@ -0,0 +1,313 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +use crate::error::Result; +use crate::error::Error; +use std::cell::UnsafeCell; +use std::ffi::OsStr; +use std::fs::OpenOptions; +use std::os::windows::ffi::OsStrExt; +use std::os::windows::fs::OpenOptionsExt; +use std::os::windows::io::{IntoRawHandle}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc}; +use std::{io}; + +use windows_sys::Win32::Foundation::{ CloseHandle, ERROR_IO_PENDING, ERROR_PIPE_CONNECTED, INVALID_HANDLE_VALUE }; +use windows_sys::Win32::Storage::FileSystem::{ ReadFile, WriteFile, FILE_FLAG_FIRST_PIPE_INSTANCE, FILE_FLAG_OVERLAPPED, PIPE_ACCESS_DUPLEX }; +use windows_sys::Win32::System::IO::{ GetOverlappedResult, OVERLAPPED }; +use windows_sys::Win32::System::Pipes::{ CreateNamedPipeW, ConnectNamedPipe,DisconnectNamedPipe, PIPE_WAIT, PIPE_UNLIMITED_INSTANCES, PIPE_REJECT_REMOTE_CLIENTS }; +use windows_sys::Win32::System::Threading::CreateEventW; + +const PIPE_BUFFER_SIZE: u32 = 65536; +const WAIT_FOR_EVENT: i32 = 1; + +pub struct PipeListener { + first_instance: AtomicBool, + address: String, +} + +#[repr(C)] +struct Overlapped { + inner: UnsafeCell, +} + +impl Overlapped { + fn new_with_event(event: isize) -> Overlapped { + let mut ol = Overlapped { + inner: UnsafeCell::new(unsafe { std::mem::zeroed() }), + }; + ol.inner.get_mut().hEvent = event; + ol + } + + fn new() -> Overlapped { + Overlapped { + inner: UnsafeCell::new(unsafe { std::mem::zeroed() }), + } + } + + fn as_mut_ptr(&self) -> *mut OVERLAPPED { + self.inner.get() + } +} + +impl PipeListener { + pub(crate) fn new(sockaddr: &str) -> Result { + Ok(PipeListener { + first_instance: AtomicBool::new(true), + address: sockaddr.to_string(), + }) + } + + // accept returns: + // - Ok(Some(PipeConnection)) if a new connection is established + // - Err(io::Error) if there is an error and listener loop should be shutdown + pub(crate) fn accept(&self, quit_flag: &Arc) -> std::result::Result, io::Error> { + if quit_flag.load(Ordering::SeqCst) { + return Err(io::Error::new( + io::ErrorKind::Other, + "listener shutdown for quit flag", + )); + } + + // Create a new pipe instance for every new client + let np = self.new_instance().unwrap(); + let ol = Overlapped::new(); + + trace!("listening for connection"); + let result = unsafe { ConnectNamedPipe(np, ol.as_mut_ptr())}; + if result != 0 { + return Err(io::Error::last_os_error()); + } + + match io::Error::last_os_error() { + e if e.raw_os_error() == Some(ERROR_IO_PENDING as i32) => { + let mut bytes_transfered = 0; + let res = unsafe {GetOverlappedResult(np, ol.as_mut_ptr(), &mut bytes_transfered, WAIT_FOR_EVENT) }; + match res { + 0 => { + return Err(io::Error::last_os_error()); + } + _ => { + Ok(Some(PipeConnection::new(np))) + } + } + } + e if e.raw_os_error() == Some(ERROR_PIPE_CONNECTED as i32) => { + Ok(Some(PipeConnection::new(np))) + } + e => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("failed to connect pipe: {:?}", e), + )); + } + } + } + + fn new_instance(&self) -> io::Result { + let name = OsStr::new(&self.address.as_str()) + .encode_wide() + .chain(Some(0)) // add NULL termination + .collect::>(); + + let mut open_mode = PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED ; + + if self.first_instance.load(Ordering::SeqCst) { + open_mode |= FILE_FLAG_FIRST_PIPE_INSTANCE; + self.first_instance.swap(false, Ordering::SeqCst); + } + + // null for security attributes means the handle cannot be inherited and write access is restricted to system + // https://learn.microsoft.com/en-us/windows/win32/ipc/named-pipe-security-and-access-rights + match unsafe { CreateNamedPipeW(name.as_ptr(), open_mode, PIPE_WAIT | PIPE_REJECT_REMOTE_CLIENTS, PIPE_UNLIMITED_INSTANCES, PIPE_BUFFER_SIZE, PIPE_BUFFER_SIZE, 0, std::ptr::null_mut())} { + INVALID_HANDLE_VALUE => { + return Err(io::Error::last_os_error()) + } + h => { + return Ok(h) + }, + }; + } + + pub fn close(&self) -> Result<()> { + Ok(()) + } +} + +pub struct PipeConnection { + named_pipe: isize, + read_event: isize, + write_event: isize, +} + +// PipeConnection on Windows is used by both the Server and Client to read and write to the named pipe +// The named pipe is created with the overlapped flag enable the simultaneous read and write operations. +// This is required since a read and write be issued at the same time on a given named pipe instance. +// +// An event is created for the read and write operations. When the read or write is issued +// it either returns immediately or the thread is suspended until the event is signaled when +// the overlapped (async) operation completes and the event is triggered allow the thread to continue. +// +// Due to the implementation of the sync Server and client there is always only one read and one write +// operation in flight at a time so we can reuse the same event. +// +// For more information on overlapped and events: https://learn.microsoft.com/en-us/windows/win32/api/ioapiset/nf-ioapiset-getoverlappedresult#remarks +// "It is safer to use an event object because of the confusion that can occur when multiple simultaneous overlapped operations are performed on the same file, named pipe, or communications device." +// "In this situation, there is no way to know which operation caused the object's state to be signaled." +impl PipeConnection { + pub(crate) fn new(h: isize) -> PipeConnection { + trace!("creating events for thread {:?} on pipe instance {}", std::thread::current().id(), h as i32); + let read_event = unsafe { CreateEventW(std::ptr::null_mut(), 0, 1, std::ptr::null_mut()) }; + let write_event = unsafe { CreateEventW(std::ptr::null_mut(), 0, 1, std::ptr::null_mut()) }; + PipeConnection { + named_pipe: h, + read_event: read_event, + write_event: write_event, + } + } + + pub(crate) fn id(&self) -> i32 { + self.named_pipe as i32 + } + + pub fn read(&self, buf: &mut [u8]) -> Result { + trace!("starting read for thread {:?} on pipe instance {}", std::thread::current().id(), self.named_pipe as i32); + let ol = Overlapped::new_with_event(self.read_event); + + let len = std::cmp::min(buf.len(), u32::MAX as usize) as u32; + let mut bytes_read= 0; + let result = unsafe { ReadFile(self.named_pipe, buf.as_mut_ptr() as *mut _, len, &mut bytes_read,ol.as_mut_ptr()) }; + if result > 0 && bytes_read > 0 { + // Got result no need to wait for pending read to complete + return Ok(bytes_read as usize) + } + + // wait for pending operation to complete (thread will be suspended until event is signaled) + match io::Error::last_os_error() { + ref e if e.raw_os_error() == Some(ERROR_IO_PENDING as i32) => { + let mut bytes_transfered = 0; + let res = unsafe {GetOverlappedResult(self.named_pipe, ol.as_mut_ptr(), &mut bytes_transfered, WAIT_FOR_EVENT) }; + match res { + 0 => { + return Err(Error::Windows(io::Error::last_os_error().raw_os_error().unwrap())) + } + _ => { + return Ok(bytes_transfered as usize) + } + } + } + ref e => { + return Err(Error::Others(format!("failed to read from pipe: {:?}", e))) + } + } + } + + pub fn write(&self, buf: &[u8]) -> Result { + trace!("starting write for thread {:?} on pipe instance {}", std::thread::current().id(), self.named_pipe as i32); + let ol = Overlapped::new_with_event(self.write_event); + let mut bytes_written = 0; + let len = std::cmp::min(buf.len(), u32::MAX as usize) as u32; + let result = unsafe { WriteFile(self.named_pipe, buf.as_ptr() as *const _,len, &mut bytes_written, ol.as_mut_ptr())}; + if result > 0 && bytes_written > 0 { + // No need to wait for pending write to complete + return Ok(bytes_written as usize) + } + + // wait for pending operation to complete (thread will be suspended until event is signaled) + match io::Error::last_os_error() { + ref e if e.raw_os_error() == Some(ERROR_IO_PENDING as i32) => { + let mut bytes_transfered = 0; + let res = unsafe {GetOverlappedResult(self.named_pipe, ol.as_mut_ptr(), &mut bytes_transfered, WAIT_FOR_EVENT) }; + match res { + 0 => { + return Err(Error::Windows(io::Error::last_os_error().raw_os_error().unwrap())) + } + _ => { + return Ok(bytes_transfered as usize) + } + } + } + ref e => { + return Err(Error::Others(format!("failed to write to pipe: {:?}", e))) + } + } + } + + pub fn close(&self) -> Result<()> { + close_handle(self.named_pipe)?; + close_handle(self.read_event)?; + close_handle(self.write_event) + } + + pub fn shutdown(&self) -> Result<()> { + let result = unsafe { DisconnectNamedPipe(self.named_pipe) }; + match result { + 0 => Err(Error::Windows(io::Error::last_os_error().raw_os_error().unwrap())), + _ => Ok(()), + } + } +} + +pub struct ClientConnection { + address: String +} + +fn close_handle(handle: isize) -> Result<()> { + let result = unsafe { CloseHandle(handle) }; + match result { + 0 => Err(Error::Windows(io::Error::last_os_error().raw_os_error().unwrap())), + _ => Ok(()), + } +} + +impl ClientConnection { + pub fn client_connect(sockaddr: &str) -> Result { + Ok(ClientConnection::new(sockaddr)) + } + + pub(crate) fn new(sockaddr: &str) -> ClientConnection { + ClientConnection { + address: sockaddr.to_string() + } + } + + pub fn ready(&self) -> std::result::Result, io::Error> { + // Windows is a "completion" based system so "readiness" isn't really applicable + Ok(Some(())) + } + + pub fn get_pipe_connection(&self) -> PipeConnection { + let mut opts = OpenOptions::new(); + opts.read(true) + .write(true) + .custom_flags(FILE_FLAG_OVERLAPPED); + let file = opts.open(self.address.as_str()); + + PipeConnection::new(file.unwrap().into_raw_handle() as isize) + } + + pub fn close_receiver(&self) -> Result<()> { + // close the pipe from the pipe connection + Ok(()) + } + + pub fn close(&self) -> Result<()> { + // close the pipe from the pipe connection + Ok(()) + } +} diff --git a/src/sync/utils.rs b/src/sync/utils.rs index 80069ae6..3ad6a845 100644 --- a/src/sync/utils.rs +++ b/src/sync/utils.rs @@ -97,7 +97,10 @@ macro_rules! client_request { /// The context of ttrpc (sync). #[derive(Debug)] pub struct TtrpcContext { + #[cfg(unix)] pub fd: std::os::unix::io::RawFd, + #[cfg(windows)] + pub fd: i32, pub mh: MessageHeader, pub res_tx: std::sync::mpsc::Sender<(MessageHeader, Vec)>, pub metadata: HashMap>, From 4975099450cc3f11a05cb50ae4f70561c3595d58 Mon Sep 17 00:00:00 2001 From: James Sturtevant Date: Thu, 2 Mar 2023 15:39:23 -0800 Subject: [PATCH 4/4] Fix stuck thread on server shutdown The connect namedpipe thread would be in a suspended state when shutdown on the server is called. Setting the event to a signalled state to wake the thread up so everything can shut down properly. Signed-off-by: James Sturtevant --- src/sync/client.rs | 23 +++++++++---- src/sync/sys/unix/net.rs | 4 +-- src/sync/sys/windows/net.rs | 67 +++++++++++++++++++++++++------------ 3 files changed, 64 insertions(+), 30 deletions(-) diff --git a/src/sync/client.rs b/src/sync/client.rs index 9cc117d8..86ceb777 100644 --- a/src/sync/client.rs +++ b/src/sync/client.rs @@ -45,7 +45,7 @@ impl Client { pub fn connect(sockaddr: &str) -> Result { let conn = ClientConnection::client_connect(sockaddr)?; - Ok(Self::new_client(conn)) + Self::new_client(conn) } #[cfg(unix)] @@ -53,10 +53,15 @@ impl Client { pub fn new(fd: RawFd) -> Client { let conn = ClientConnection::new(fd); - Self::new_client(conn) + // TODO: upgrade the API of Client::new and remove this panic for the major version release + Self::new_client(conn).unwrap_or_else(|e| { + panic!( + "client was not successfully initialized: {}", e + ) + }) } - fn new_client(pipe_client: ClientConnection) -> Client { + fn new_client(pipe_client: ClientConnection) -> Result { let client = Arc::new(pipe_client); let (sender_tx, rx): (Sender, Receiver) = mpsc::channel(); @@ -64,7 +69,7 @@ impl Client { let receiver_map = recver_map_orig.clone(); - let connection = Arc::new(client.get_pipe_connection()); + let connection = Arc::new(client.get_pipe_connection()?); let sender_client = connection.clone(); //Sender @@ -171,10 +176,10 @@ impl Client { trace!("Receiver quit"); }); - Client { + Ok(Client { _connection: client, sender_tx, - } + }) } pub fn request(&self, req: Request) -> Result { let buf = req.encode().map_err(err_to_others_err!(e, ""))?; @@ -220,7 +225,11 @@ impl Drop for ClientConnection { #[cfg(windows)] impl Drop for PipeConnection { fn drop(&mut self) { - self.close().unwrap(); + self.close().unwrap_or_else(|e| { + trace!( + "connection may already be closed: {}", e + ) + }); trace!("pipe connection is dropped"); } } diff --git a/src/sync/sys/unix/net.rs b/src/sync/sys/unix/net.rs index eb29dfb6..3fdf47b8 100644 --- a/src/sync/sys/unix/net.rs +++ b/src/sync/sys/unix/net.rs @@ -305,8 +305,8 @@ impl ClientConnection { Ok(Some(())) } - pub fn get_pipe_connection(&self) -> PipeConnection { - PipeConnection::new(self.fd) + pub fn get_pipe_connection(&self) -> Result { + Ok(PipeConnection::new(self.fd)) } pub fn close_receiver(&self) -> Result<()> { diff --git a/src/sync/sys/windows/net.rs b/src/sync/sys/windows/net.rs index 4ec9d02c..1a5adbca 100644 --- a/src/sync/sys/windows/net.rs +++ b/src/sync/sys/windows/net.rs @@ -30,7 +30,7 @@ use windows_sys::Win32::Foundation::{ CloseHandle, ERROR_IO_PENDING, ERROR_PIPE_ use windows_sys::Win32::Storage::FileSystem::{ ReadFile, WriteFile, FILE_FLAG_FIRST_PIPE_INSTANCE, FILE_FLAG_OVERLAPPED, PIPE_ACCESS_DUPLEX }; use windows_sys::Win32::System::IO::{ GetOverlappedResult, OVERLAPPED }; use windows_sys::Win32::System::Pipes::{ CreateNamedPipeW, ConnectNamedPipe,DisconnectNamedPipe, PIPE_WAIT, PIPE_UNLIMITED_INSTANCES, PIPE_REJECT_REMOTE_CLIENTS }; -use windows_sys::Win32::System::Threading::CreateEventW; +use windows_sys::Win32::System::Threading::{CreateEventW, SetEvent}; const PIPE_BUFFER_SIZE: u32 = 65536; const WAIT_FOR_EVENT: i32 = 1; @@ -38,6 +38,7 @@ const WAIT_FOR_EVENT: i32 = 1; pub struct PipeListener { first_instance: AtomicBool, address: String, + connection_event: isize, } #[repr(C)] @@ -54,12 +55,6 @@ impl Overlapped { ol } - fn new() -> Overlapped { - Overlapped { - inner: UnsafeCell::new(unsafe { std::mem::zeroed() }), - } - } - fn as_mut_ptr(&self) -> *mut OVERLAPPED { self.inner.get() } @@ -67,9 +62,11 @@ impl Overlapped { impl PipeListener { pub(crate) fn new(sockaddr: &str) -> Result { + let connection_event = create_event()?; Ok(PipeListener { first_instance: AtomicBool::new(true), address: sockaddr.to_string(), + connection_event }) } @@ -85,11 +82,21 @@ impl PipeListener { } // Create a new pipe instance for every new client - let np = self.new_instance().unwrap(); - let ol = Overlapped::new(); + let instance = self.new_instance()?; + let np = match PipeConnection::new(instance) { + Ok(np) => np, + Err(e) => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("failed to create new pipe instance: {:?}", e), + )); + } + }; + + let ol = Overlapped::new_with_event(self.connection_event); trace!("listening for connection"); - let result = unsafe { ConnectNamedPipe(np, ol.as_mut_ptr())}; + let result = unsafe { ConnectNamedPipe(np.named_pipe, ol.as_mut_ptr())}; if result != 0 { return Err(io::Error::last_os_error()); } @@ -97,18 +104,18 @@ impl PipeListener { match io::Error::last_os_error() { e if e.raw_os_error() == Some(ERROR_IO_PENDING as i32) => { let mut bytes_transfered = 0; - let res = unsafe {GetOverlappedResult(np, ol.as_mut_ptr(), &mut bytes_transfered, WAIT_FOR_EVENT) }; + let res = unsafe {GetOverlappedResult(np.named_pipe, ol.as_mut_ptr(), &mut bytes_transfered, WAIT_FOR_EVENT) }; match res { 0 => { return Err(io::Error::last_os_error()); } _ => { - Ok(Some(PipeConnection::new(np))) + Ok(Some(np)) } } } e if e.raw_os_error() == Some(ERROR_PIPE_CONNECTED as i32) => { - Ok(Some(PipeConnection::new(np))) + Ok(Some(np)) } e => { return Err(io::Error::new( @@ -145,7 +152,9 @@ impl PipeListener { } pub fn close(&self) -> Result<()> { - Ok(()) + // release the ConnectNamedPipe thread by signaling the event and clean up event handle + set_event(self.connection_event)?; + close_handle(self.connection_event) } } @@ -170,15 +179,15 @@ pub struct PipeConnection { // "It is safer to use an event object because of the confusion that can occur when multiple simultaneous overlapped operations are performed on the same file, named pipe, or communications device." // "In this situation, there is no way to know which operation caused the object's state to be signaled." impl PipeConnection { - pub(crate) fn new(h: isize) -> PipeConnection { + pub(crate) fn new(h: isize) -> Result { trace!("creating events for thread {:?} on pipe instance {}", std::thread::current().id(), h as i32); - let read_event = unsafe { CreateEventW(std::ptr::null_mut(), 0, 1, std::ptr::null_mut()) }; - let write_event = unsafe { CreateEventW(std::ptr::null_mut(), 0, 1, std::ptr::null_mut()) }; - PipeConnection { + let read_event = create_event()?; + let write_event = create_event()?; + Ok(PipeConnection { named_pipe: h, read_event: read_event, write_event: write_event, - } + }) } pub(crate) fn id(&self) -> i32 { @@ -275,6 +284,22 @@ fn close_handle(handle: isize) -> Result<()> { } } +fn create_event() -> Result { + let result = unsafe { CreateEventW(std::ptr::null_mut(), 0, 1, std::ptr::null_mut()) }; + match result { + 0 => Err(Error::Windows(io::Error::last_os_error().raw_os_error().unwrap())), + _ => Ok(result), + } +} + +fn set_event(event: isize) -> Result<()> { + let result = unsafe { SetEvent(event) }; + match result { + 0 => Err(Error::Windows(io::Error::last_os_error().raw_os_error().unwrap())), + _ => Ok(()), + } +} + impl ClientConnection { pub fn client_connect(sockaddr: &str) -> Result { Ok(ClientConnection::new(sockaddr)) @@ -291,14 +316,14 @@ impl ClientConnection { Ok(Some(())) } - pub fn get_pipe_connection(&self) -> PipeConnection { + pub fn get_pipe_connection(&self) -> Result { let mut opts = OpenOptions::new(); opts.read(true) .write(true) .custom_flags(FILE_FLAG_OVERLAPPED); let file = opts.open(self.address.as_str()); - PipeConnection::new(file.unwrap().into_raw_handle() as isize) + return PipeConnection::new(file.unwrap().into_raw_handle() as isize) } pub fn close_receiver(&self) -> Result<()> {