diff --git a/.github/workflows/bvt.yml b/.github/workflows/bvt.yml index 1b5eff31..e9917550 100644 --- a/.github/workflows/bvt.yml +++ b/.github/workflows/bvt.yml @@ -6,7 +6,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macos-latest] + os: [ubuntu-latest, macos-latest, windows-latest] steps: - name: Checkout uses: actions/checkout@v3 @@ -22,7 +22,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macos-latest] + os: [ubuntu-latest, macos-latest, windows-latest] steps: - name: Checkout uses: actions/checkout@v3 diff --git a/Cargo.toml b/Cargo.toml index 2ce7f5f1..897e4d89 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,17 +17,22 @@ nix = "0.23.0" log = "0.4" byteorder = "1.3.2" thiserror = "1.0" - async-trait = { version = "0.1.31", optional = true } tokio = { version = "1", features = ["rt", "sync", "io-util", "macros", "time"], optional = true } futures = { version = "0.3", optional = true } +[target.'cfg(windows)'.dependencies] +windows-sys = {version = "0.45", features = [ "Win32_Foundation", "Win32_Storage_FileSystem", "Win32_System_IO", "Win32_System_Pipes", "Win32_Security", "Win32_System_Threading"]} + [target.'cfg(any(target_os = "linux", target_os = "android"))'.dependencies] tokio-vsock = { version = "0.3.1", optional = true } [build-dependencies] protobuf-codegen = "3.1.0" +[dev-dependencies] +assert_cmd = "2.0.7" + [features] default = ["sync"] async = ["async-trait", "tokio", "futures", "tokio-vsock"] diff --git a/Makefile b/Makefile index 703c4abf..b0b3082e 100644 --- a/Makefile +++ b/Makefile @@ -23,8 +23,13 @@ build: debug .PHONY: test test: +ifeq ($OS,Windows_NT) + # async isn't enabled for windows, don't test that feature + cargo test --verbose +else cargo test --all-features --verbose - +endif + .PHONY: check check: cargo fmt --all -- --check diff --git a/example/async-client.rs b/example/async-client.rs index b4d752fc..4c4b7e9f 100644 --- a/example/async-client.rs +++ b/example/async-client.rs @@ -5,11 +5,18 @@ mod protocols; mod utils; - +#[cfg(unix)] use protocols::r#async::{agent, agent_ttrpc, health, health_ttrpc}; use ttrpc::context::{self, Context}; +#[cfg(unix)] use ttrpc::r#async::Client; +#[cfg(windows)] +fn main() { + println!("This example only works on Unix-like OSes"); +} + +#[cfg(unix)] #[tokio::main(flavor = "current_thread")] async fn main() { let c = Client::connect(utils::SOCK_ADDR).unwrap(); diff --git a/example/async-server.rs b/example/async-server.rs index 81bf048b..1ee2fb79 100644 --- a/example/async-server.rs +++ b/example/async-server.rs @@ -13,17 +13,22 @@ use std::sync::Arc; use log::LevelFilter; +#[cfg(unix)] use protocols::r#async::{agent, agent_ttrpc, health, health_ttrpc, types}; +#[cfg(unix)] use ttrpc::asynchronous::Server; use ttrpc::error::{Error, Result}; use ttrpc::proto::{Code, Status}; +#[cfg(unix)] use async_trait::async_trait; +#[cfg(unix)] use tokio::signal::unix::{signal, SignalKind}; use tokio::time::sleep; struct HealthService; +#[cfg(unix)] #[async_trait] impl health_ttrpc::Health for HealthService { async fn check( @@ -58,7 +63,7 @@ impl health_ttrpc::Health for HealthService { } struct AgentService; - +#[cfg(unix)] #[async_trait] impl agent_ttrpc::AgentService for AgentService { async fn list_interfaces( @@ -82,6 +87,12 @@ impl agent_ttrpc::AgentService for AgentService { } } +#[cfg(windows)] +fn main() { + println!("This example only works on Unix-like OSes"); +} + +#[cfg(unix)] #[tokio::main(flavor = "current_thread")] async fn main() { simple_logging::log_to_stderr(LevelFilter::Trace); diff --git a/example/async-stream-client.rs b/example/async-stream-client.rs index ba953596..4ba2b0c3 100644 --- a/example/async-stream-client.rs +++ b/example/async-stream-client.rs @@ -5,11 +5,18 @@ mod protocols; mod utils; - +#[cfg(unix)] use protocols::r#async::{empty, streaming, streaming_ttrpc}; use ttrpc::context::{self, Context}; +#[cfg(unix)] use ttrpc::r#async::Client; +#[cfg(windows)] +fn main() { + println!("This example only works on Unix-like OSes"); +} + +#[cfg(unix)] #[tokio::main(flavor = "current_thread")] async fn main() { simple_logging::log_to_stderr(log::LevelFilter::Info); @@ -48,6 +55,7 @@ fn default_ctx() -> Context { ctx } +#[cfg(unix)] async fn echo_request(cli: streaming_ttrpc::StreamingClient) { let echo1 = streaming::EchoPayload { seq: 1, @@ -59,6 +67,7 @@ async fn echo_request(cli: streaming_ttrpc::StreamingClient) { assert_eq!(resp.seq, echo1.seq + 1); } +#[cfg(unix)] async fn echo_stream(cli: streaming_ttrpc::StreamingClient) { let mut stream = cli.echo_stream(default_ctx()).await.unwrap(); @@ -81,6 +90,7 @@ async fn echo_stream(cli: streaming_ttrpc::StreamingClient) { assert!(matches!(ret, Err(ttrpc::Error::Eof))); } +#[cfg(unix)] async fn sum_stream(cli: streaming_ttrpc::StreamingClient) { let mut stream = cli.sum_stream(default_ctx()).await.unwrap(); @@ -108,6 +118,7 @@ async fn sum_stream(cli: streaming_ttrpc::StreamingClient) { assert_eq!(ssum.num, sum.num); } +#[cfg(unix)] async fn divide_stream(cli: streaming_ttrpc::StreamingClient) { let expected = streaming::Sum { sum: 392, @@ -127,6 +138,7 @@ async fn divide_stream(cli: streaming_ttrpc::StreamingClient) { assert_eq!(actual.num, expected.num); } +#[cfg(unix)] async fn echo_null(cli: streaming_ttrpc::StreamingClient) { let mut stream = cli.echo_null(default_ctx()).await.unwrap(); @@ -142,6 +154,7 @@ async fn echo_null(cli: streaming_ttrpc::StreamingClient) { assert_eq!(res, empty::Empty::new()); } +#[cfg(unix)] async fn echo_null_stream(cli: streaming_ttrpc::StreamingClient) { let stream = cli.echo_null_stream(default_ctx()).await.unwrap(); diff --git a/example/async-stream-server.rs b/example/async-stream-server.rs index d828e2d3..0404cfbc 100644 --- a/example/async-stream-server.rs +++ b/example/async-stream-server.rs @@ -10,15 +10,20 @@ use std::sync::Arc; use log::{info, LevelFilter}; +#[cfg(unix)] use protocols::r#async::{empty, streaming, streaming_ttrpc}; +#[cfg(unix)] use ttrpc::asynchronous::Server; +#[cfg(unix)] use async_trait::async_trait; +#[cfg(unix)] use tokio::signal::unix::{signal, SignalKind}; use tokio::time::sleep; struct StreamingService; +#[cfg(unix)] #[async_trait] impl streaming_ttrpc::Streaming for StreamingService { async fn echo( @@ -131,6 +136,12 @@ impl streaming_ttrpc::Streaming for StreamingService { } } +#[cfg(windows)] +fn main() { + println!("This example only works on Unix-like OSes"); +} + +#[cfg(unix)] #[tokio::main(flavor = "current_thread")] async fn main() { simple_logging::log_to_stderr(LevelFilter::Info); diff --git a/example/build.rs b/example/build.rs index cd9d5e50..a90274b4 100644 --- a/example/build.rs +++ b/example/build.rs @@ -45,7 +45,7 @@ fn main() { async_all: true, ..Default::default() }) - .rust_protobuf_customize(protobuf_customized.clone()) + .rust_protobuf_customize(protobuf_customized) .run() .expect("Gen async code failed."); @@ -75,7 +75,7 @@ fn replace_text_in_file(file_name: &str, from: &str, to: &str) -> Result<(), std let new_contents = contents.replace(from, to); - let mut dst = File::create(&file_name)?; + let mut dst = File::create(file_name)?; dst.write(new_contents.as_bytes())?; Ok(()) diff --git a/example/client.rs b/example/client.rs index bd2fb898..04b96722 100644 --- a/example/client.rs +++ b/example/client.rs @@ -15,12 +15,17 @@ mod protocols; mod utils; +use log::LevelFilter; use protocols::sync::{agent, agent_ttrpc, health, health_ttrpc}; use std::thread; use ttrpc::context::{self, Context}; +use ttrpc::error::Error; +use ttrpc::proto::Code; use ttrpc::Client; fn main() { + simple_logging::log_to_stderr(LevelFilter::Trace); + let c = Client::connect(utils::SOCK_ADDR).unwrap(); let hc = health_ttrpc::HealthClient::new(c.clone()); let ac = agent_ttrpc::AgentServiceClient::new(c); @@ -33,69 +38,97 @@ fn main() { let t = thread::spawn(move || { let req = health::CheckRequest::new(); println!( - "OS Thread {:?} - {} started: {:?}", + "OS Thread {:?} - health.check() started: {:?}", std::thread::current().id(), - "health.check()", now.elapsed(), ); + + let rsp = thc.check(default_ctx(), &req); + match rsp.as_ref() { + Err(Error::RpcStatus(s)) => { + assert_eq!(Code::NOT_FOUND, s.code()); + assert_eq!("Just for fun".to_string(), s.message()) + } + Err(e) => { + panic!("not expecting an error from the example server: {:?}", e) + } + Ok(x) => { + panic!("not expecting a OK response from the example server: {:?}", x) + } + } println!( - "OS Thread {:?} - {} -> {:?} ended: {:?}", + "OS Thread {:?} - health.check() -> {:?} ended: {:?}", std::thread::current().id(), - "health.check()", - thc.check(default_ctx(), &req), + rsp, now.elapsed(), ); }); let t2 = thread::spawn(move || { println!( - "OS Thread {:?} - {} started: {:?}", + "OS Thread {:?} - agent.list_interfaces() started: {:?}", std::thread::current().id(), - "agent.list_interfaces()", now.elapsed(), ); let show = match tac.list_interfaces(default_ctx(), &agent::ListInterfacesRequest::new()) { - Err(e) => format!("{:?}", e), - Ok(s) => format!("{:?}", s), + Err(e) => { + panic!("not expecting an error from the example server: {:?}", e) + } + Ok(s) => { + assert_eq!("first".to_string(), s.Interfaces[0].name); + assert_eq!("second".to_string(), s.Interfaces[1].name); + format!("{s:?}") + } }; println!( - "OS Thread {:?} - {} -> {} ended: {:?}", + "OS Thread {:?} - agent.list_interfaces() -> {} ended: {:?}", std::thread::current().id(), - "agent.list_interfaces()", show, now.elapsed(), ); }); println!( - "Main OS Thread - {} started: {:?}", - "agent.online_cpu_mem()", + "Main OS Thread - agent.online_cpu_mem() started: {:?}", now.elapsed() ); let show = match ac.online_cpu_mem(default_ctx(), &agent::OnlineCPUMemRequest::new()) { - Err(e) => format!("{:?}", e), - Ok(s) => format!("{:?}", s), + Err(Error::RpcStatus(s)) => { + assert_eq!(Code::NOT_FOUND, s.code()); + assert_eq!( + "/grpc.AgentService/OnlineCPUMem is not supported".to_string(), + s.message() + ); + format!("{s:?}") + } + Err(e) => { + panic!("not expecting an error from the example server: {:?}", e) + } + Ok(s) => { + panic!("not expecting a OK response from the example server: {:?}", s) + } }; println!( - "Main OS Thread - {} -> {} ended: {:?}", - "agent.online_cpu_mem()", + "Main OS Thread - agent.online_cpu_mem() -> {} ended: {:?}", show, now.elapsed() ); println!("\nsleep 2 seconds ...\n"); thread::sleep(std::time::Duration::from_secs(2)); + + let version = hc.version(default_ctx(), &health::CheckRequest::new()); + assert_eq!("mock.0.1", version.as_ref().unwrap().agent_version.as_str()); + assert_eq!("0.0.1", version.as_ref().unwrap().grpc_version.as_str()); println!( - "Main OS Thread - {} started: {:?}", - "health.version()", + "Main OS Thread - health.version() started: {:?}", now.elapsed() ); println!( - "Main OS Thread - {} -> {:?} ended: {:?}", - "health.version()", - hc.version(default_ctx(), &health::CheckRequest::new()), + "Main OS Thread - health.version() -> {:?} ended: {:?}", + version, now.elapsed() ); diff --git a/example/protocols/mod.rs b/example/protocols/mod.rs index b81f3d7d..d3d3a275 100644 --- a/example/protocols/mod.rs +++ b/example/protocols/mod.rs @@ -2,7 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 // - +#[cfg(unix)] pub mod asynchronous; -pub mod sync; +#[cfg(unix)] pub use asynchronous as r#async; +pub mod sync; diff --git a/example/utils.rs b/example/utils.rs index 24c6393f..816c430a 100644 --- a/example/utils.rs +++ b/example/utils.rs @@ -1,18 +1,28 @@ #![allow(dead_code)] -use std::fs; use std::io::Result; -use std::path::Path; -pub const SOCK_ADDR: &str = "unix:///tmp/ttrpc-test"; +#[cfg(unix)] +pub const SOCK_ADDR: &str = r"unix:///tmp/ttrpc-test"; +#[cfg(windows)] +pub const SOCK_ADDR: &str = r"\\.\pipe\ttrpc-test"; + +#[cfg(unix)] pub fn remove_if_sock_exist(sock_addr: &str) -> Result<()> { let path = sock_addr .strip_prefix("unix://") .expect("socket address is not expected"); - if Path::new(path).exists() { - fs::remove_file(&path)?; + if std::path::Path::new(path).exists() { + std::fs::remove_file(path)?; } Ok(()) } + +#[cfg(windows)] +pub fn remove_if_sock_exist(_sock_addr: &str) -> Result<()> { + //todo force close file handle? + + Ok(()) +} diff --git a/src/asynchronous/client.rs b/src/asynchronous/client.rs index 8ba9a890..4476779c 100644 --- a/src/asynchronous/client.rs +++ b/src/asynchronous/client.rs @@ -294,7 +294,7 @@ impl ReaderDelegate for ClientReader { }; resp_tx .send(Err(Error::Others(format!( - "Recver got malformed packet {msg:?}" + "Receiver got malformed packet {msg:?}" )))) .await .unwrap_or_else(|_e| error!("The request has returned")); diff --git a/src/asynchronous/stream.rs b/src/asynchronous/stream.rs index 1b72be29..7259c50a 100644 --- a/src/asynchronous/stream.rs +++ b/src/asynchronous/stream.rs @@ -311,7 +311,7 @@ where async fn _recv(rx: &mut ResultReceiver) -> Result { rx.recv().await.unwrap_or_else(|| { Err(Error::Others( - "Receive packet from recver error".to_string(), + "Receive packet from Receiver error".to_string(), )) }) } diff --git a/src/common.rs b/src/common.rs index 9c2ec40f..a3783852 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,9 +1,10 @@ +#![cfg(not(windows))] // Copyright (c) 2020 Ant Financial // // SPDX-License-Identifier: Apache-2.0 // -//! Common functions and macros. +//! Common functions. use crate::error::{Error, Result}; #[cfg(any( @@ -173,26 +174,6 @@ pub(crate) unsafe fn client_connect(sockaddr: &str) -> Result { Ok(fd) } -macro_rules! cfg_sync { - ($($item:item)*) => { - $( - #[cfg(feature = "sync")] - #[cfg_attr(docsrs, doc(cfg(feature = "sync")))] - $item - )* - } -} - -macro_rules! cfg_async { - ($($item:item)*) => { - $( - #[cfg(feature = "async")] - #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - $item - )* - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/error.rs b/src/error.rs index 4f7e2e3a..5c66e190 100644 --- a/src/error.rs +++ b/src/error.rs @@ -27,9 +27,14 @@ pub enum Error { #[error("rpc status: {0:?}")] RpcStatus(Status), + #[cfg(unix)] #[error("Nix error: {0}")] Nix(#[from] nix::Error), + #[cfg(windows)] + #[error("Windows error: {0}")] + Windows(i32), + #[error("ttrpc err: local stream closed")] LocalClosed, diff --git a/src/lib.rs b/src/lib.rs index 4f913d44..cd4872a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,6 +49,9 @@ pub mod error; #[macro_use] mod common; +#[macro_use] +mod macros; + pub mod context; pub mod proto; diff --git a/src/macros.rs b/src/macros.rs new file mode 100644 index 00000000..7695a157 --- /dev/null +++ b/src/macros.rs @@ -0,0 +1,26 @@ +// Copyright (c) 2020 Ant Financial +// +// SPDX-License-Identifier: Apache-2.0 +// + +//! macro functions. + +macro_rules! cfg_sync { + ($($item:item)*) => { + $( + #[cfg(feature = "sync")] + #[cfg_attr(docsrs, doc(cfg(feature = "sync")))] + $item + )* + } +} + +macro_rules! cfg_async { + ($($item:item)*) => { + $( + #[cfg(all(feature = "async", target_family="unix"))] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + $item + )* + } +} diff --git a/src/sync/channel.rs b/src/sync/channel.rs index a23b5915..18070b89 100644 --- a/src/sync/channel.rs +++ b/src/sync/channel.rs @@ -12,18 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use nix::sys::socket::*; -use std::os::unix::io::RawFd; use crate::error::{get_rpc_status, sock_error_msg, Error, Result}; +use crate::sync::sys::PipeConnection; use crate::proto::{Code, MessageHeader, MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX}; -fn retryable(e: nix::Error) -> bool { - use ::nix::Error; - e == Error::EINTR || e == Error::EAGAIN -} - -fn read_count(fd: RawFd, count: usize) -> Result> { +fn read_count(conn: &PipeConnection, count: usize) -> Result> { let mut v: Vec = vec![0; count]; let mut len = 0; @@ -32,7 +26,7 @@ fn read_count(fd: RawFd, count: usize) -> Result> { } loop { - match recv(fd, &mut v[len..], MsgFlags::empty()) { + match conn.read(&mut v[len..]) { Ok(l) => { len += l; // when socket peer closed, it would return 0. @@ -40,11 +34,6 @@ fn read_count(fd: RawFd, count: usize) -> Result> { break; } } - - Err(e) if retryable(e) => { - // Should retry - } - Err(e) => { return Err(Error::Socket(e.to_string())); } @@ -54,7 +43,7 @@ fn read_count(fd: RawFd, count: usize) -> Result> { Ok(v[0..len].to_vec()) } -fn write_count(fd: RawFd, buf: &[u8], count: usize) -> Result { +fn write_count(conn: &PipeConnection, buf: &[u8], count: usize) -> Result { let mut len = 0; if count == 0 { @@ -62,18 +51,13 @@ fn write_count(fd: RawFd, buf: &[u8], count: usize) -> Result { } loop { - match send(fd, &buf[len..], MsgFlags::empty()) { + match conn.write(&buf[len..]){ Ok(l) => { len += l; if len == count { break; } } - - Err(e) if retryable(e) => { - // Should retry - } - Err(e) => { return Err(Error::Socket(e.to_string())); } @@ -83,8 +67,8 @@ fn write_count(fd: RawFd, buf: &[u8], count: usize) -> Result { Ok(len) } -fn read_message_header(fd: RawFd) -> Result { - let buf = read_count(fd, MESSAGE_HEADER_LENGTH)?; +fn read_message_header(conn: &PipeConnection) -> Result { + let buf = read_count(conn, MESSAGE_HEADER_LENGTH)?; let size = buf.len(); if size != MESSAGE_HEADER_LENGTH { return Err(sock_error_msg( @@ -98,8 +82,8 @@ fn read_message_header(fd: RawFd) -> Result { Ok(mh) } -pub fn read_message(fd: RawFd) -> Result<(MessageHeader, Vec)> { - let mh = read_message_header(fd)?; +pub fn read_message(conn: &PipeConnection) -> Result<(MessageHeader, Vec)> { + let mh = read_message_header(conn)?; trace!("Got Message header {:?}", mh); if mh.length > MESSAGE_LENGTH_MAX as u32 { @@ -112,7 +96,7 @@ pub fn read_message(fd: RawFd) -> Result<(MessageHeader, Vec)> { )); } - let buf = read_count(fd, mh.length as usize)?; + let buf = read_count(conn, mh.length as usize)?; let size = buf.len(); if size != mh.length as usize { return Err(sock_error_msg( @@ -125,10 +109,10 @@ pub fn read_message(fd: RawFd) -> Result<(MessageHeader, Vec)> { Ok((mh, buf)) } -fn write_message_header(fd: RawFd, mh: MessageHeader) -> Result<()> { +fn write_message_header(conn: &PipeConnection, mh: MessageHeader) -> Result<()> { let buf: Vec = mh.into(); - let size = write_count(fd, &buf, MESSAGE_HEADER_LENGTH)?; + let size = write_count(conn, &buf, MESSAGE_HEADER_LENGTH)?; if size != MESSAGE_HEADER_LENGTH { return Err(sock_error_msg( size, @@ -139,10 +123,10 @@ fn write_message_header(fd: RawFd, mh: MessageHeader) -> Result<()> { Ok(()) } -pub fn write_message(fd: RawFd, mh: MessageHeader, buf: Vec) -> Result<()> { - write_message_header(fd, mh)?; +pub fn write_message(conn: &PipeConnection, mh: MessageHeader, buf: Vec) -> Result<()> { + write_message_header(conn, mh)?; - let size = write_count(fd, &buf, buf.len())?; + let size = write_count(conn, &buf, buf.len())?; if size != buf.len() { return Err(sock_error_msg( size, diff --git a/src/sync/client.rs b/src/sync/client.rs index 94872651..86ceb777 100644 --- a/src/sync/client.rs +++ b/src/sync/client.rs @@ -14,21 +14,22 @@ //! Sync client of ttrpc. -use nix::sys::socket::*; -use nix::unistd::close; -use std::collections::HashMap; +#[cfg(unix)] use std::os::unix::io::RawFd; + +use std::collections::HashMap; use std::sync::mpsc; use std::sync::{Arc, Mutex}; -use std::{io, thread}; +use std::{thread}; +use std::time::Duration; -#[cfg(target_os = "macos")] -use crate::common::set_fd_close_exec; -use crate::common::{client_connect, SOCK_CLOEXEC}; use crate::error::{Error, Result}; +use crate::sync::sys::{ClientConnection}; use crate::proto::{Code, Codec, MessageHeader, Request, Response, MESSAGE_TYPE_RESPONSE}; use crate::sync::channel::{read_message, write_message}; -use std::time::Duration; + +#[cfg(windows)] +use super::sys::PipeConnection; type Sender = mpsc::Sender<(Vec, mpsc::SyncSender>>)>; type Receiver = mpsc::Receiver<(Vec, mpsc::SyncSender>>)>; @@ -36,38 +37,42 @@ type Receiver = mpsc::Receiver<(Vec, mpsc::SyncSender>>)>; /// A ttrpc Client (sync). #[derive(Clone)] pub struct Client { - _fd: RawFd, + _connection: Arc, sender_tx: Sender, - _client_close: Arc, } impl Client { pub fn connect(sockaddr: &str) -> Result { - let fd = unsafe { client_connect(sockaddr)? }; - Ok(Self::new(fd)) + let conn = ClientConnection::client_connect(sockaddr)?; + + Self::new_client(conn) } + #[cfg(unix)] /// Initialize a new [`Client`] from raw file descriptor. pub fn new(fd: RawFd) -> Client { - let (sender_tx, rx): (Sender, Receiver) = mpsc::channel(); - - let (recver_fd, close_fd) = - socketpair(AddressFamily::Unix, SockType::Stream, None, SOCK_CLOEXEC).unwrap(); - - // MacOS doesn't support descriptor creation with SOCK_CLOEXEC automically, - // so there is a chance of leak if fork + exec happens in between of these calls. - #[cfg(target_os = "macos")] - { - set_fd_close_exec(recver_fd).unwrap(); - set_fd_close_exec(close_fd).unwrap(); - } - - let client_close = Arc::new(ClientClose { fd, close_fd }); + let conn = ClientConnection::new(fd); + + // TODO: upgrade the API of Client::new and remove this panic for the major version release + Self::new_client(conn).unwrap_or_else(|e| { + panic!( + "client was not successfully initialized: {}", e + ) + }) + } + fn new_client(pipe_client: ClientConnection) -> Result { + let client = Arc::new(pipe_client); + + let (sender_tx, rx): (Sender, Receiver) = mpsc::channel(); let recver_map_orig = Arc::new(Mutex::new(HashMap::new())); + + let receiver_map = recver_map_orig.clone(); + let connection = Arc::new(client.get_pipe_connection()?); + let sender_client = connection.clone(); + //Sender - let recver_map = recver_map_orig.clone(); thread::spawn(move || { let mut stream_id: u32 = 1; for (buf, recver_tx) in rx.iter() { @@ -75,15 +80,16 @@ impl Client { stream_id += 2; //Put current_stream_id and recver_tx to recver_map { - let mut map = recver_map.lock().unwrap(); + let mut map = receiver_map.lock().unwrap(); map.insert(current_stream_id, recver_tx.clone()); } let mut mh = MessageHeader::new_request(0, buf.len() as u32); mh.set_stream_id(current_stream_id); - if let Err(e) = write_message(fd, mh, buf) { + + if let Err(e) = write_message(&sender_client, mh, buf) { //Remove current_stream_id and recver_tx to recver_map { - let mut map = recver_map.lock().unwrap(); + let mut map = receiver_map.lock().unwrap(); map.remove(¤t_stream_id); } recver_tx @@ -95,53 +101,24 @@ impl Client { }); //Recver + let receiver_connection = connection; + let receiver_client = client.clone(); thread::spawn(move || { - let mut pollers = vec![ - libc::pollfd { - fd: recver_fd, - events: libc::POLLIN, - revents: 0, - }, - libc::pollfd { - fd, - events: libc::POLLIN, - revents: 0, - }, - ]; - loop { - let returned = unsafe { - let pollers: &mut [libc::pollfd] = &mut pollers; - libc::poll( - pollers as *mut _ as *mut libc::pollfd, - pollers.len() as _, - -1, - ) - }; - - if returned == -1 { - let err = io::Error::last_os_error(); - if err.raw_os_error() == Some(libc::EINTR) { + match receiver_client.ready() { + Ok(None) => { continue; } - - error!("fatal error in process reaper:{}", err); - break; - } else if returned < 1 { - continue; - } - - if pollers[0].revents != 0 { - break; - } - - if pollers[pollers.len() - 1].revents == 0 { - continue; + Ok(_) => {} + Err(e) => { + error!("pipeConnection ready error {:?}", e); + break; + } } let mh; let buf; - match read_message(fd) { + match read_message(&receiver_connection) { Ok((x, y)) => { mh = x; buf = y; @@ -170,14 +147,14 @@ impl Client { let recver_tx = match map.get(&mh.stream_id) { Some(tx) => tx, None => { - debug!("Recver got unknown packet {:?} {:?}", mh, buf); + debug!("Receiver got unknown packet {:?} {:?}", mh, buf); continue; } }; if mh.type_ != MESSAGE_TYPE_RESPONSE { recver_tx .send(Err(Error::Others(format!( - "Recver got malformed packet {mh:?} {buf:?}" + "Receiver got malformed packet {mh:?} {buf:?}" )))) .unwrap_or_else(|_e| error!("The request has returned")); continue; @@ -190,21 +167,19 @@ impl Client { map.remove(&mh.stream_id); } - let _ = close(recver_fd).map_err(|e| { + let _ = receiver_client.close_receiver().map_err(|e| { warn!( - "failed to close recver_fd: {} with error: {:?}", - recver_fd, e + "failed to close with error: {:?}", e ) }); - trace!("Recver quit"); + trace!("Receiver quit"); }); - Client { - _fd: fd, + Ok(Client { + _connection: client, sender_tx, - _client_close: client_close, - } + }) } pub fn request(&self, req: Request) -> Result { let buf = req.encode().map_err(err_to_others_err!(e, ""))?; @@ -217,12 +192,12 @@ impl Client { let result = if req.timeout_nano == 0 { rx.recv() - .map_err(err_to_others_err!(e, "Receive packet from recver error: "))? + .map_err(err_to_others_err!(e, "Receive packet from Receiver error: "))? } else { rx.recv_timeout(Duration::from_nanos(req.timeout_nano as u64)) .map_err(err_to_others_err!( e, - "Receive packet from recver timeout: " + "Receive packet from Receiver timeout: " ))? }; @@ -239,15 +214,22 @@ impl Client { } } -struct ClientClose { - fd: RawFd, - close_fd: RawFd, +impl Drop for ClientConnection { + fn drop(&mut self) { + self.close().unwrap(); + trace!("Client is dropped"); + } } -impl Drop for ClientClose { +// close everything up from the pipe connection on Windows +#[cfg(windows)] +impl Drop for PipeConnection { fn drop(&mut self) { - close(self.close_fd).unwrap(); - close(self.fd).unwrap(); - trace!("All client is droped"); + self.close().unwrap_or_else(|e| { + trace!( + "connection may already be closed: {}", e + ) + }); + trace!("pipe connection is dropped"); } } diff --git a/src/sync/mod.rs b/src/sync/mod.rs index 2830460b..53d0680c 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -8,6 +8,7 @@ mod channel; mod client; mod server; +mod sys; #[macro_use] mod utils; diff --git a/src/sync/server.rs b/src/sync/server.rs index 0a0e6e3f..a8e5883d 100644 --- a/src/sync/server.rs +++ b/src/sync/server.rs @@ -13,27 +13,26 @@ // limitations under the License. //! Sync server of ttrpc. +//! + +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; -use nix::sys::socket::{self, *}; -use nix::unistd::*; use protobuf::{CodedInputStream, Message}; use std::collections::HashMap; -use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::mpsc::{channel, sync_channel, Receiver, Sender, SyncSender}; use std::sync::{Arc, Mutex}; use std::thread::JoinHandle; -use std::{io, thread}; +use std::{thread}; use super::utils::response_to_channel; -use crate::common; -#[cfg(not(any(target_os = "linux", target_os = "android")))] -use crate::common::set_fd_close_exec; use crate::context; use crate::error::{get_status, Error, Result}; use crate::proto::{Code, MessageHeader, Request, Response, MESSAGE_TYPE_REQUEST}; use crate::sync::channel::{read_message, write_message}; use crate::{MethodHandler, TtrpcContext}; +use crate::sync::sys::{PipeListener, PipeConnection}; // poll_queue will create WAIT_THREAD_COUNT_DEFAULT threads in begin. // If wait thread count < WAIT_THREAD_COUNT_MIN, create number to WAIT_THREAD_COUNT_DEFAULT. @@ -47,10 +46,9 @@ type MessageReceiver = Receiver<(MessageHeader, Vec)>; /// A ttrpc Server (sync). pub struct Server { - listeners: Vec, - monitor_fd: (RawFd, RawFd), + listeners: Vec>, listener_quit_flag: Arc, - connections: Arc>>, + connections: Arc>>, methods: Arc>>, handler: Option>, reaper: Option<(Sender, JoinHandle<()>)>, @@ -60,21 +58,26 @@ pub struct Server { } struct Connection { - fd: RawFd, + connection: Arc, quit: Arc, handler: Option>, } impl Connection { - fn close(&self) { + fn close (&self) { + self.connection.close().unwrap_or(()); + } + + fn shutdown(&self) { self.quit.store(true, Ordering::SeqCst); + // in case the connection had closed - socket::shutdown(self.fd, Shutdown::Read).unwrap_or(()); + self.connection.shutdown().unwrap_or(()); } } struct ThreadS<'a> { - fd: RawFd, + connection: &'a Arc, fdlock: &'a Arc>, wtc: &'a Arc, quit: &'a Arc, @@ -88,7 +91,7 @@ struct ThreadS<'a> { #[allow(clippy::too_many_arguments)] fn start_method_handler_thread( - fd: RawFd, + connection: Arc, fdlock: Arc>, wtc: Arc, quit: Arc, @@ -116,7 +119,7 @@ fn start_method_handler_thread( .unwrap_or_else(|err| trace!("Failed to send {:?}", err)); break; } - result = read_message(fd); + result = read_message(&connection); } if quit.load(Ordering::SeqCst) { @@ -207,7 +210,7 @@ fn start_method_handler_thread( continue; }; let ctx = TtrpcContext { - fd, + fd: connection.id(), mh, res_tx: res_tx.clone(), metadata: context::from_pb(&req.metadata), @@ -234,7 +237,7 @@ fn start_method_handler_threads(num: usize, ts: &ThreadS) { break; } start_method_handler_thread( - ts.fd, + ts.connection.clone(), ts.fdlock.clone(), ts.wtc.clone(), ts.quit.clone(), @@ -258,7 +261,6 @@ impl Default for Server { fn default() -> Self { Server { listeners: Vec::with_capacity(1), - monitor_fd: (-1, -1), listener_quit_flag: Arc::new(AtomicBool::new(false)), connections: Arc::new(Mutex::new(HashMap::new())), methods: Arc::new(HashMap::new()), @@ -283,15 +285,23 @@ impl Server { )); } - let (fd, _) = common::do_bind(sockaddr)?; - common::do_listen(fd)?; + let listener = PipeListener::new(sockaddr)?; - self.listeners.push(fd); + self.listeners.push(Arc::new(listener)); Ok(self) } + #[cfg(unix)] pub fn add_listener(mut self, fd: RawFd) -> Result { - self.listeners.push(fd); + if !self.listeners.is_empty() { + return Err(Error::Others( + "ttrpc-rust just support 1 sockaddr now".to_string(), + )); + } + + let listener = PipeListener::new_from_fd(fd)?; + + self.listeners.push(Arc::new(listener)); Ok(self) } @@ -329,27 +339,14 @@ impl Server { self.listener_quit_flag.store(false, Ordering::SeqCst); - #[cfg(any(target_os = "linux", target_os = "android"))] - let fds = pipe2(nix::fcntl::OFlag::O_CLOEXEC)?; - - #[cfg(not(any(target_os = "linux", target_os = "android")))] - let fds = { - let (rfd, wfd) = pipe()?; - set_fd_close_exec(rfd)?; - set_fd_close_exec(wfd)?; - (rfd, wfd) - }; - - self.monitor_fd = fds; - - let listener = self.listeners[0]; + + let listener = self.listeners[0].clone(); let methods = self.methods.clone(); let default = self.thread_count_default; let min = self.thread_count_min; let max = self.thread_count_max; let listener_quit_flag = self.listener_quit_flag.clone(); - let monitor_fd = self.monitor_fd.0; let reaper_tx = match self.reaper.take() { None => { @@ -366,7 +363,7 @@ impl Server { .map(|mut cn| { cn.handler.take().map(|handler| { handler.join().unwrap(); - close(fd).unwrap(); + cn.close() }) }); } @@ -386,78 +383,17 @@ impl Server { let handler = thread::Builder::new() .name("listener_loop".into()) .spawn(move || { - let mut pollers = vec![ - libc::pollfd { - fd: monitor_fd, - events: libc::POLLIN, - revents: 0, - }, - libc::pollfd { - fd: listener, - events: libc::POLLIN, - revents: 0, - }, - ]; - - loop { - if listener_quit_flag.load(Ordering::SeqCst) { - info!("listener shutdown for quit flag"); - break; - } - - let returned = unsafe { - let pollers: &mut [libc::pollfd] = &mut pollers; - libc::poll( - pollers as *mut _ as *mut libc::pollfd, - pollers.len() as _, - -1, - ) - }; - - if returned == -1 { - let err = io::Error::last_os_error(); - if err.raw_os_error() == Some(libc::EINTR) { + loop { + trace!("listening..."); + let pipe_connection = match listener.accept(&listener_quit_flag) { + Ok(None) => { continue; } - - error!("fatal error in listener_loop:{:?}", err); - break; - } else if returned < 1 { - continue; - } - - if pollers[0].revents != 0 || pollers[pollers.len() - 1].revents == 0 { - continue; - } - - if listener_quit_flag.load(Ordering::SeqCst) { - info!("listener shutdown for quit flag"); - break; - } - - #[cfg(any(target_os = "linux", target_os = "android"))] - let fd = match accept4(listener, SockFlag::SOCK_CLOEXEC) { - Ok(fd) => fd, - Err(e) => { - error!("failed to accept error {:?}", e); - break; - } - }; - - // Non Linux platforms do not support accept4 with SOCK_CLOEXEC flag, so instead - // use accept and call fcntl separately to set SOCK_CLOEXEC. - // Because of this there is chance of the descriptor leak if fork + exec happens in between. - #[cfg(not(any(target_os = "linux", target_os = "android")))] - let fd = match accept(listener) { - Ok(fd) => { - if let Err(err) = set_fd_close_exec(fd) { - error!("fcntl failed after accept: {:?}", err); - break; - }; - fd + Ok(Some(conn)) => { + Arc::new(conn) } Err(e) => { - error!("failed to accept error {:?}", e); + error!("listener accept got {:?}", e); break; } }; @@ -466,6 +402,7 @@ impl Server { let quit = Arc::new(AtomicBool::new(false)); let child_quit = quit.clone(); let reaper_tx_child = reaper_tx.clone(); + let pipe_connection_child = pipe_connection.clone(); let handler = thread::Builder::new() .name("client_handler".into()) @@ -473,11 +410,12 @@ impl Server { debug!("Got new client"); // Start response thread let quit_res = child_quit.clone(); + let pipe = pipe_connection_child.clone(); let (res_tx, res_rx): (MessageSender, MessageReceiver) = channel(); let handler = thread::spawn(move || { for r in res_rx.iter() { trace!("response thread get {:?}", r); - if let Err(e) = write_message(fd, r.0, r.1) { + if let Err(e) = write_message(&pipe, r.0, r.1) { error!("write_message got {:?}", e); quit_res.store(true, Ordering::SeqCst); break; @@ -487,10 +425,11 @@ impl Server { trace!("response thread quit"); }); + let pipe = pipe_connection_child.clone(); let (control_tx, control_rx): (SyncSender<()>, Receiver<()>) = sync_channel(0); let ts = ThreadS { - fd, + connection: &pipe, fdlock: &Arc::new(Mutex::new(())), wtc: &Arc::new(AtomicUsize::new(0)), methods: &methods, @@ -517,17 +456,19 @@ impl Server { handler.join().unwrap_or(()); // client_handler should not close fd before exit // , which prevent fd reuse issue. - reaper_tx_child.send(fd).unwrap(); + reaper_tx_child.send(pipe.id()).unwrap(); debug!("client thread quit"); }) .unwrap(); let mut cns = connections.lock().unwrap(); + + let id = pipe_connection.id(); cns.insert( - fd, + id, Connection { - fd, + connection: pipe_connection, handler: Some(handler), quit: quit.clone(), }, @@ -553,7 +494,7 @@ impl Server { } if self.thread_count_default <= self.thread_count_min { return Err(Error::Others( - "thread_count_default should biger than thread_count_min".to_string(), + "thread_count_default should bigger than thread_count_min".to_string(), )); } self.start_listen()?; @@ -563,12 +504,13 @@ impl Server { pub fn stop_listen(mut self) -> Self { self.listener_quit_flag.store(true, Ordering::SeqCst); - close(self.monitor_fd.1).unwrap_or_else(|e| { + + self.listeners[0].close().unwrap_or_else(|e| { warn!( - "failed to close notify fd: {} with error: {}", - self.monitor_fd.1, e + "failed to close connection with error: {}", e ) }); + info!("close monitor"); if let Some(handler) = self.handler.take() { handler.join().unwrap(); @@ -582,7 +524,7 @@ impl Server { let connections = self.connections.lock().unwrap(); for (_fd, c) in connections.iter() { - c.close(); + c.shutdown(); } // release connections's lock, since the following handler.join() // would wait on the other thread's exit in which would take the lock. @@ -601,14 +543,16 @@ impl Server { } } +#[cfg(unix)] impl FromRawFd for Server { unsafe fn from_raw_fd(fd: RawFd) -> Self { Self::default().add_listener(fd).unwrap() } } +#[cfg(unix)] impl AsRawFd for Server { fn as_raw_fd(&self) -> RawFd { - self.listeners[0] + self.listeners[0].as_raw_fd() } } diff --git a/src/sync/sys/mod.rs b/src/sync/sys/mod.rs new file mode 100644 index 00000000..89546c43 --- /dev/null +++ b/src/sync/sys/mod.rs @@ -0,0 +1,9 @@ +#[cfg(unix)] +mod unix; +#[cfg(unix)] +pub use crate::sync::sys::unix::{PipeConnection, PipeListener, ClientConnection}; + +#[cfg(windows)] +mod windows; +#[cfg(windows)] +pub use crate::sync::sys::windows::{PipeConnection, PipeListener, ClientConnection}; diff --git a/src/sync/sys/unix/mod.rs b/src/sync/sys/unix/mod.rs new file mode 100644 index 00000000..bc36d736 --- /dev/null +++ b/src/sync/sys/unix/mod.rs @@ -0,0 +1,2 @@ +mod net; +pub use net::{PipeConnection, PipeListener, ClientConnection}; diff --git a/src/sync/sys/unix/net.rs b/src/sync/sys/unix/net.rs new file mode 100644 index 00000000..3fdf47b8 --- /dev/null +++ b/src/sync/sys/unix/net.rs @@ -0,0 +1,334 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +use crate::error::Result; +use nix::sys::socket::*; +use std::io::{self}; +use std::os::unix::io::RawFd; +use std::os::unix::prelude::AsRawFd; +use nix::Error; + +use nix::unistd::*; +use std::sync::{Arc}; +use std::sync::atomic::{AtomicBool, Ordering}; +use crate::common::{self, client_connect, SOCK_CLOEXEC}; +#[cfg(target_os = "macos")] +use crate::common::set_fd_close_exec; +use nix::sys::socket::{self}; + +pub struct PipeListener { + fd: RawFd, + monitor_fd: (RawFd, RawFd), +} + +impl AsRawFd for PipeListener { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +impl PipeListener { + pub(crate) fn new(sockaddr: &str) -> Result { + let (fd, _) = common::do_bind(sockaddr)?; + common::do_listen(fd)?; + + let fds = PipeListener::new_monitor_fd()?; + + Ok(PipeListener { + fd, + monitor_fd: fds, + }) + } + + pub(crate) fn new_from_fd(fd: RawFd) -> Result { + let fds = PipeListener::new_monitor_fd()?; + + Ok(PipeListener { + fd, + monitor_fd: fds, + }) + } + + fn new_monitor_fd() -> Result<(i32, i32)> { + #[cfg(target_os = "linux")] + let fds = pipe2(nix::fcntl::OFlag::O_CLOEXEC)?; + + + #[cfg(target_os = "macos")] + let fds = { + let (rfd, wfd) = pipe()?; + set_fd_close_exec(rfd)?; + set_fd_close_exec(wfd)?; + (rfd, wfd) + }; + + Ok(fds) + } + + // accept returns: + // - Ok(Some(PipeConnection)) if a new connection is established + // - Ok(None) if spurious wake up with no new connection + // - Err(io::Error) if there is an error and listener loop should be shutdown + pub(crate) fn accept( &self, quit_flag: &Arc) -> std::result::Result, io::Error> { + if quit_flag.load(Ordering::SeqCst) { + return Err(io::Error::new(io::ErrorKind::Other, "listener shutdown for quit flag")); + } + + let mut pollers = vec![ + libc::pollfd { + fd: self.monitor_fd.0, + events: libc::POLLIN, + revents: 0, + }, + libc::pollfd { + fd: self.fd, + events: libc::POLLIN, + revents: 0, + }, + ]; + + let returned = unsafe { + let pollers: &mut [libc::pollfd] = &mut pollers; + libc::poll( + pollers as *mut _ as *mut libc::pollfd, + pollers.len() as _, + -1, + ) + }; + + if returned == -1 { + let err = io::Error::last_os_error(); + if err.raw_os_error() == Some(libc::EINTR) { + return Err(err); + } + + error!("fatal error in listener_loop:{:?}", err); + return Err(err); + } else if returned < 1 { + return Ok(None) + } + + if pollers[0].revents != 0 || pollers[pollers.len() - 1].revents == 0 { + return Ok(None); + } + + if quit_flag.load(Ordering::SeqCst) { + return Err(io::Error::new(io::ErrorKind::Other, "listener shutdown for quit flag")); + } + + #[cfg(target_os = "linux")] + let fd = match accept4(self.fd, SockFlag::SOCK_CLOEXEC) { + Ok(fd) => fd, + Err(e) => { + error!("failed to accept error {:?}", e); + return Err(std::io::Error::from_raw_os_error(e as i32)); + } + }; + + // Non Linux platforms do not support accept4 with SOCK_CLOEXEC flag, so instead + // use accept and call fcntl separately to set SOCK_CLOEXEC. + // Because of this there is chance of the descriptor leak if fork + exec happens in between. + #[cfg(target_os = "macos")] + let fd = match accept(self.fd) { + Ok(fd) => { + if let Err(err) = set_fd_close_exec(fd) { + error!("fcntl failed after accept: {:?}", err); + return Err(io::Error::new(io::ErrorKind::Other, format!("{err:?}"))); + }; + fd + } + Err(e) => { + error!("failed to accept error {:?}", e); + return Err(std::io::Error::from_raw_os_error(e as i32)); + } + }; + + + Ok(Some(PipeConnection { fd })) + } + + pub fn close(&self) -> Result<()> { + close(self.monitor_fd.1).unwrap_or_else(|e| { + warn!( + "failed to close notify fd: {} with error: {}", + self.monitor_fd.1, e + ) + }); + Ok(()) + } +} + + +pub struct PipeConnection { + fd: RawFd, +} + +impl PipeConnection { + pub(crate) fn new(fd: RawFd) -> PipeConnection { + PipeConnection { fd } + } + + pub(crate) fn id(&self) -> i32 { + self.fd + } + + pub fn read(&self, buf: &mut [u8]) -> Result { + loop { + match recv(self.fd, buf, MsgFlags::empty()) { + Ok(l) => return Ok(l), + Err(e) if retryable(e) => { + // Should retry + continue; + } + Err(e) => { + return Err(crate::Error::Nix(e)); + } + } + } + } + + pub fn write(&self, buf: &[u8]) -> Result { + loop { + match send(self.fd, buf, MsgFlags::empty()) { + Ok(l) => return Ok(l), + Err(e) if retryable(e) => { + // Should retry + continue; + } + Err(e) => { + return Err(crate::Error::Nix(e)); + } + } + } + } + + pub fn close(&self) -> Result<()> { + match close(self.fd) { + Ok(_) => Ok(()), + Err(e) => Err(crate::Error::Nix(e)) + } + } + + pub fn shutdown(&self) -> Result<()> { + match socket::shutdown(self.fd, Shutdown::Read) { + Ok(_) => Ok(()), + Err(e) => Err(crate::Error::Nix(e)) + } + } +} + +pub struct ClientConnection { + fd: RawFd, + socket_pair: (RawFd, RawFd), +} + +impl ClientConnection { + pub fn client_connect(sockaddr: &str)-> Result { + let fd = unsafe { client_connect(sockaddr)? }; + Ok(ClientConnection::new(fd)) + } + + pub(crate) fn new(fd: RawFd) -> ClientConnection { + let (recver_fd, close_fd) = + socketpair(AddressFamily::Unix, SockType::Stream, None, SOCK_CLOEXEC).unwrap(); + + // MacOS doesn't support descriptor creation with SOCK_CLOEXEC automically, + // so there is a chance of leak if fork + exec happens in between of these calls. + #[cfg(target_os = "macos")] + { + set_fd_close_exec(recver_fd).unwrap(); + set_fd_close_exec(close_fd).unwrap(); + } + + + ClientConnection { + fd, + socket_pair: (recver_fd, close_fd) + } + } + + pub fn ready(&self) -> std::result::Result, io::Error> { + let mut pollers = vec![ + libc::pollfd { + fd: self.socket_pair.0, + events: libc::POLLIN, + revents: 0, + }, + libc::pollfd { + fd: self.fd, + events: libc::POLLIN, + revents: 0, + }, + ]; + + let returned = unsafe { + let pollers: &mut [libc::pollfd] = &mut pollers; + libc::poll( + pollers as *mut _ as *mut libc::pollfd, + pollers.len() as _, + -1, + ) + }; + + if returned == -1 { + let err = io::Error::last_os_error(); + if err.raw_os_error() == Some(libc::EINTR) { + return Ok(None) + } + + error!("fatal error in process reaper:{}", err); + return Err(err); + } else if returned < 1 { + return Ok(None) + } + + if pollers[0].revents != 0 { + return Err(io::Error::new(io::ErrorKind::Other, "pipe closed")); + } + + if pollers[pollers.len() - 1].revents == 0 { + return Ok(None) + } + + Ok(Some(())) + } + + pub fn get_pipe_connection(&self) -> Result { + Ok(PipeConnection::new(self.fd)) + } + + pub fn close_receiver(&self) -> Result<()> { + match close(self.socket_pair.0) { + Ok(_) => Ok(()), + Err(e) => Err(crate::Error::Nix(e)) + } + } + + pub fn close(&self) -> Result<()> { + match close(self.socket_pair.1) { + Ok(_) => {}, + Err(e) => return Err(crate::Error::Nix(e)) + }; + + match close(self.fd) { + Ok(_) => Ok(()), + Err(e) => Err(crate::Error::Nix(e)) + } + } +} + +fn retryable(e: nix::Error) -> bool { + e == Error::EINTR || e == Error::EAGAIN +} diff --git a/src/sync/sys/windows/mod.rs b/src/sync/sys/windows/mod.rs new file mode 100644 index 00000000..bc36d736 --- /dev/null +++ b/src/sync/sys/windows/mod.rs @@ -0,0 +1,2 @@ +mod net; +pub use net::{PipeConnection, PipeListener, ClientConnection}; diff --git a/src/sync/sys/windows/net.rs b/src/sync/sys/windows/net.rs new file mode 100644 index 00000000..1a5adbca --- /dev/null +++ b/src/sync/sys/windows/net.rs @@ -0,0 +1,338 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +use crate::error::Result; +use crate::error::Error; +use std::cell::UnsafeCell; +use std::ffi::OsStr; +use std::fs::OpenOptions; +use std::os::windows::ffi::OsStrExt; +use std::os::windows::fs::OpenOptionsExt; +use std::os::windows::io::{IntoRawHandle}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc}; +use std::{io}; + +use windows_sys::Win32::Foundation::{ CloseHandle, ERROR_IO_PENDING, ERROR_PIPE_CONNECTED, INVALID_HANDLE_VALUE }; +use windows_sys::Win32::Storage::FileSystem::{ ReadFile, WriteFile, FILE_FLAG_FIRST_PIPE_INSTANCE, FILE_FLAG_OVERLAPPED, PIPE_ACCESS_DUPLEX }; +use windows_sys::Win32::System::IO::{ GetOverlappedResult, OVERLAPPED }; +use windows_sys::Win32::System::Pipes::{ CreateNamedPipeW, ConnectNamedPipe,DisconnectNamedPipe, PIPE_WAIT, PIPE_UNLIMITED_INSTANCES, PIPE_REJECT_REMOTE_CLIENTS }; +use windows_sys::Win32::System::Threading::{CreateEventW, SetEvent}; + +const PIPE_BUFFER_SIZE: u32 = 65536; +const WAIT_FOR_EVENT: i32 = 1; + +pub struct PipeListener { + first_instance: AtomicBool, + address: String, + connection_event: isize, +} + +#[repr(C)] +struct Overlapped { + inner: UnsafeCell, +} + +impl Overlapped { + fn new_with_event(event: isize) -> Overlapped { + let mut ol = Overlapped { + inner: UnsafeCell::new(unsafe { std::mem::zeroed() }), + }; + ol.inner.get_mut().hEvent = event; + ol + } + + fn as_mut_ptr(&self) -> *mut OVERLAPPED { + self.inner.get() + } +} + +impl PipeListener { + pub(crate) fn new(sockaddr: &str) -> Result { + let connection_event = create_event()?; + Ok(PipeListener { + first_instance: AtomicBool::new(true), + address: sockaddr.to_string(), + connection_event + }) + } + + // accept returns: + // - Ok(Some(PipeConnection)) if a new connection is established + // - Err(io::Error) if there is an error and listener loop should be shutdown + pub(crate) fn accept(&self, quit_flag: &Arc) -> std::result::Result, io::Error> { + if quit_flag.load(Ordering::SeqCst) { + return Err(io::Error::new( + io::ErrorKind::Other, + "listener shutdown for quit flag", + )); + } + + // Create a new pipe instance for every new client + let instance = self.new_instance()?; + let np = match PipeConnection::new(instance) { + Ok(np) => np, + Err(e) => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("failed to create new pipe instance: {:?}", e), + )); + } + }; + + let ol = Overlapped::new_with_event(self.connection_event); + + trace!("listening for connection"); + let result = unsafe { ConnectNamedPipe(np.named_pipe, ol.as_mut_ptr())}; + if result != 0 { + return Err(io::Error::last_os_error()); + } + + match io::Error::last_os_error() { + e if e.raw_os_error() == Some(ERROR_IO_PENDING as i32) => { + let mut bytes_transfered = 0; + let res = unsafe {GetOverlappedResult(np.named_pipe, ol.as_mut_ptr(), &mut bytes_transfered, WAIT_FOR_EVENT) }; + match res { + 0 => { + return Err(io::Error::last_os_error()); + } + _ => { + Ok(Some(np)) + } + } + } + e if e.raw_os_error() == Some(ERROR_PIPE_CONNECTED as i32) => { + Ok(Some(np)) + } + e => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("failed to connect pipe: {:?}", e), + )); + } + } + } + + fn new_instance(&self) -> io::Result { + let name = OsStr::new(&self.address.as_str()) + .encode_wide() + .chain(Some(0)) // add NULL termination + .collect::>(); + + let mut open_mode = PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED ; + + if self.first_instance.load(Ordering::SeqCst) { + open_mode |= FILE_FLAG_FIRST_PIPE_INSTANCE; + self.first_instance.swap(false, Ordering::SeqCst); + } + + // null for security attributes means the handle cannot be inherited and write access is restricted to system + // https://learn.microsoft.com/en-us/windows/win32/ipc/named-pipe-security-and-access-rights + match unsafe { CreateNamedPipeW(name.as_ptr(), open_mode, PIPE_WAIT | PIPE_REJECT_REMOTE_CLIENTS, PIPE_UNLIMITED_INSTANCES, PIPE_BUFFER_SIZE, PIPE_BUFFER_SIZE, 0, std::ptr::null_mut())} { + INVALID_HANDLE_VALUE => { + return Err(io::Error::last_os_error()) + } + h => { + return Ok(h) + }, + }; + } + + pub fn close(&self) -> Result<()> { + // release the ConnectNamedPipe thread by signaling the event and clean up event handle + set_event(self.connection_event)?; + close_handle(self.connection_event) + } +} + +pub struct PipeConnection { + named_pipe: isize, + read_event: isize, + write_event: isize, +} + +// PipeConnection on Windows is used by both the Server and Client to read and write to the named pipe +// The named pipe is created with the overlapped flag enable the simultaneous read and write operations. +// This is required since a read and write be issued at the same time on a given named pipe instance. +// +// An event is created for the read and write operations. When the read or write is issued +// it either returns immediately or the thread is suspended until the event is signaled when +// the overlapped (async) operation completes and the event is triggered allow the thread to continue. +// +// Due to the implementation of the sync Server and client there is always only one read and one write +// operation in flight at a time so we can reuse the same event. +// +// For more information on overlapped and events: https://learn.microsoft.com/en-us/windows/win32/api/ioapiset/nf-ioapiset-getoverlappedresult#remarks +// "It is safer to use an event object because of the confusion that can occur when multiple simultaneous overlapped operations are performed on the same file, named pipe, or communications device." +// "In this situation, there is no way to know which operation caused the object's state to be signaled." +impl PipeConnection { + pub(crate) fn new(h: isize) -> Result { + trace!("creating events for thread {:?} on pipe instance {}", std::thread::current().id(), h as i32); + let read_event = create_event()?; + let write_event = create_event()?; + Ok(PipeConnection { + named_pipe: h, + read_event: read_event, + write_event: write_event, + }) + } + + pub(crate) fn id(&self) -> i32 { + self.named_pipe as i32 + } + + pub fn read(&self, buf: &mut [u8]) -> Result { + trace!("starting read for thread {:?} on pipe instance {}", std::thread::current().id(), self.named_pipe as i32); + let ol = Overlapped::new_with_event(self.read_event); + + let len = std::cmp::min(buf.len(), u32::MAX as usize) as u32; + let mut bytes_read= 0; + let result = unsafe { ReadFile(self.named_pipe, buf.as_mut_ptr() as *mut _, len, &mut bytes_read,ol.as_mut_ptr()) }; + if result > 0 && bytes_read > 0 { + // Got result no need to wait for pending read to complete + return Ok(bytes_read as usize) + } + + // wait for pending operation to complete (thread will be suspended until event is signaled) + match io::Error::last_os_error() { + ref e if e.raw_os_error() == Some(ERROR_IO_PENDING as i32) => { + let mut bytes_transfered = 0; + let res = unsafe {GetOverlappedResult(self.named_pipe, ol.as_mut_ptr(), &mut bytes_transfered, WAIT_FOR_EVENT) }; + match res { + 0 => { + return Err(Error::Windows(io::Error::last_os_error().raw_os_error().unwrap())) + } + _ => { + return Ok(bytes_transfered as usize) + } + } + } + ref e => { + return Err(Error::Others(format!("failed to read from pipe: {:?}", e))) + } + } + } + + pub fn write(&self, buf: &[u8]) -> Result { + trace!("starting write for thread {:?} on pipe instance {}", std::thread::current().id(), self.named_pipe as i32); + let ol = Overlapped::new_with_event(self.write_event); + let mut bytes_written = 0; + let len = std::cmp::min(buf.len(), u32::MAX as usize) as u32; + let result = unsafe { WriteFile(self.named_pipe, buf.as_ptr() as *const _,len, &mut bytes_written, ol.as_mut_ptr())}; + if result > 0 && bytes_written > 0 { + // No need to wait for pending write to complete + return Ok(bytes_written as usize) + } + + // wait for pending operation to complete (thread will be suspended until event is signaled) + match io::Error::last_os_error() { + ref e if e.raw_os_error() == Some(ERROR_IO_PENDING as i32) => { + let mut bytes_transfered = 0; + let res = unsafe {GetOverlappedResult(self.named_pipe, ol.as_mut_ptr(), &mut bytes_transfered, WAIT_FOR_EVENT) }; + match res { + 0 => { + return Err(Error::Windows(io::Error::last_os_error().raw_os_error().unwrap())) + } + _ => { + return Ok(bytes_transfered as usize) + } + } + } + ref e => { + return Err(Error::Others(format!("failed to write to pipe: {:?}", e))) + } + } + } + + pub fn close(&self) -> Result<()> { + close_handle(self.named_pipe)?; + close_handle(self.read_event)?; + close_handle(self.write_event) + } + + pub fn shutdown(&self) -> Result<()> { + let result = unsafe { DisconnectNamedPipe(self.named_pipe) }; + match result { + 0 => Err(Error::Windows(io::Error::last_os_error().raw_os_error().unwrap())), + _ => Ok(()), + } + } +} + +pub struct ClientConnection { + address: String +} + +fn close_handle(handle: isize) -> Result<()> { + let result = unsafe { CloseHandle(handle) }; + match result { + 0 => Err(Error::Windows(io::Error::last_os_error().raw_os_error().unwrap())), + _ => Ok(()), + } +} + +fn create_event() -> Result { + let result = unsafe { CreateEventW(std::ptr::null_mut(), 0, 1, std::ptr::null_mut()) }; + match result { + 0 => Err(Error::Windows(io::Error::last_os_error().raw_os_error().unwrap())), + _ => Ok(result), + } +} + +fn set_event(event: isize) -> Result<()> { + let result = unsafe { SetEvent(event) }; + match result { + 0 => Err(Error::Windows(io::Error::last_os_error().raw_os_error().unwrap())), + _ => Ok(()), + } +} + +impl ClientConnection { + pub fn client_connect(sockaddr: &str) -> Result { + Ok(ClientConnection::new(sockaddr)) + } + + pub(crate) fn new(sockaddr: &str) -> ClientConnection { + ClientConnection { + address: sockaddr.to_string() + } + } + + pub fn ready(&self) -> std::result::Result, io::Error> { + // Windows is a "completion" based system so "readiness" isn't really applicable + Ok(Some(())) + } + + pub fn get_pipe_connection(&self) -> Result { + let mut opts = OpenOptions::new(); + opts.read(true) + .write(true) + .custom_flags(FILE_FLAG_OVERLAPPED); + let file = opts.open(self.address.as_str()); + + return PipeConnection::new(file.unwrap().into_raw_handle() as isize) + } + + pub fn close_receiver(&self) -> Result<()> { + // close the pipe from the pipe connection + Ok(()) + } + + pub fn close(&self) -> Result<()> { + // close the pipe from the pipe connection + Ok(()) + } +} diff --git a/src/sync/utils.rs b/src/sync/utils.rs index 80069ae6..3ad6a845 100644 --- a/src/sync/utils.rs +++ b/src/sync/utils.rs @@ -97,7 +97,10 @@ macro_rules! client_request { /// The context of ttrpc (sync). #[derive(Debug)] pub struct TtrpcContext { + #[cfg(unix)] pub fd: std::os::unix::io::RawFd, + #[cfg(windows)] + pub fd: i32, pub mh: MessageHeader, pub res_tx: std::sync::mpsc::Sender<(MessageHeader, Vec)>, pub metadata: HashMap>, diff --git a/tests/sync-test.rs b/tests/sync-test.rs new file mode 100644 index 00000000..a5a92149 --- /dev/null +++ b/tests/sync-test.rs @@ -0,0 +1,61 @@ +use std::{ + io::{BufRead, BufReader}, + process::Command, + time::Duration, +}; + +#[test] +fn run_sync_example() -> Result<(), Box> { + // start the server and give it a moment to start. + let mut server = run_example("server").spawn().unwrap(); + std::thread::sleep(Duration::from_secs(2)); + + let mut client = run_example("client").spawn().unwrap(); + let mut client_succeeded = false; + let start = std::time::Instant::now(); + let timeout = Duration::from_secs(600); + loop { + if start.elapsed() > timeout { + println!("Running the client timed out. output:"); + client.kill().unwrap_or_else(|e| { + println!("This may occur on Windows if the process has exited: {e}"); + }); + let output = client.stdout.unwrap(); + BufReader::new(output).lines().for_each(|line| { + println!("{}", line.unwrap()); + }); + break; + } + + match client.try_wait() { + Ok(Some(status)) => { + client_succeeded = status.success(); + break; + } + Ok(None) => { + // still running + continue; + } + Err(e) => { + println!("Error: {e}"); + break; + } + } + } + + // be sure to clean up the server, the client should have run to completion + server.kill()?; + assert!(client_succeeded); + Ok(()) +} + +fn run_example(example: &str) -> Command { + let mut cmd = Command::new("cargo"); + cmd.arg("run") + .arg("--example") + .arg(example) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .current_dir("example"); + cmd +}