Skip to content

Commit

Permalink
Add domain unix socket supports (#594)
Browse files Browse the repository at this point in the history
  • Loading branch information
XxChang authored Jul 19, 2024
1 parent f722f1a commit bccb1ae
Show file tree
Hide file tree
Showing 19 changed files with 321 additions and 24 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ jobs:
if: runner.os == 'Linux'
timeout-minutes: 30
run: cargo run --example cmake-dataflow
- name: "Unix Domain Socket example"
if: runner.os == 'Linux'
run: cargo run --example rust-dataflow -- dataflow_socket.yml

# python examples
- uses: actions/setup-python@v2
Expand Down
15 changes: 15 additions & 0 deletions apis/rust/node/src/daemon_connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,22 @@ use dora_core::{
};
use eyre::{bail, eyre, Context};
use shared_memory_server::{ShmemClient, ShmemConf};
#[cfg(unix)]
use std::os::unix::net::UnixStream;
use std::{
net::{SocketAddr, TcpStream},
time::Duration,
};

mod tcp;
#[cfg(unix)]
mod unix_domain;

pub enum DaemonChannel {
Shmem(ShmemClient<Timestamped<DaemonRequest>, DaemonReply>),
Tcp(TcpStream),
#[cfg(unix)]
UnixDomain(UnixStream),
}

impl DaemonChannel {
Expand All @@ -38,6 +44,13 @@ impl DaemonChannel {
Ok(channel)
}

#[cfg(unix)]
#[tracing::instrument(level = "trace")]
pub fn new_unix_socket(path: &std::path::PathBuf) -> eyre::Result<Self> {
let stream = UnixStream::connect(path).wrap_err("failed to open Unix socket")?;
Ok(DaemonChannel::UnixDomain(stream))
}

pub fn register(
&mut self,
dataflow_id: DataflowId,
Expand Down Expand Up @@ -69,6 +82,8 @@ impl DaemonChannel {
match self {
DaemonChannel::Shmem(client) => client.request(request),
DaemonChannel::Tcp(stream) => tcp::request(stream, request),
#[cfg(unix)]
DaemonChannel::UnixDomain(stream) => unix_domain::request(stream, request),
}
}
}
84 changes: 84 additions & 0 deletions apis/rust/node/src/daemon_connection/unix_domain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use dora_core::daemon_messages::{DaemonReply, DaemonRequest, Timestamped};
use eyre::{eyre, Context};
use std::{
io::{Read, Write},
os::unix::net::UnixStream,
};

enum Serializer {
Bincode,
SerdeJson,
}
pub fn request(
connection: &mut UnixStream,
request: &Timestamped<DaemonRequest>,
) -> eyre::Result<DaemonReply> {
send_message(connection, request)?;
if request.inner.expects_tcp_bincode_reply() {
receive_reply(connection, Serializer::Bincode)
.and_then(|reply| reply.ok_or_else(|| eyre!("server disconnected unexpectedly")))
// Use serde json for message with variable length
} else if request.inner.expects_tcp_json_reply() {
receive_reply(connection, Serializer::SerdeJson)
.and_then(|reply| reply.ok_or_else(|| eyre!("server disconnected unexpectedly")))
} else {
Ok(DaemonReply::Empty)
}
}

fn send_message(
connection: &mut UnixStream,
message: &Timestamped<DaemonRequest>,
) -> eyre::Result<()> {
let serialized = bincode::serialize(&message).wrap_err("failed to serialize DaemonRequest")?;
stream_send(connection, &serialized).wrap_err("failed to send DaemonRequest")?;
Ok(())
}

fn receive_reply(
connection: &mut UnixStream,
serializer: Serializer,
) -> eyre::Result<Option<DaemonReply>> {
let raw = match stream_receive(connection) {
Ok(raw) => raw,
Err(err) => match err.kind() {
std::io::ErrorKind::UnexpectedEof | std::io::ErrorKind::ConnectionAborted => {
return Ok(None)
}
other => {
return Err(err).with_context(|| {
format!(
"unexpected I/O error (kind {other:?}) while trying to receive DaemonReply"
)
})
}
},
};
match serializer {
Serializer::Bincode => bincode::deserialize(&raw)
.wrap_err("failed to deserialize DaemonReply")
.map(Some),
Serializer::SerdeJson => serde_json::from_slice(&raw)
.wrap_err("failed to deserialize DaemonReply")
.map(Some),
}
}

fn stream_send(connection: &mut (impl Write + Unpin), message: &[u8]) -> std::io::Result<()> {
let len_raw = (message.len() as u64).to_le_bytes();
connection.write_all(&len_raw)?;
connection.write_all(message)?;
connection.flush()?;
Ok(())
}

fn stream_receive(connection: &mut (impl Read + Unpin)) -> std::io::Result<Vec<u8>> {
let reply_len = {
let mut raw = [0; 8];
connection.read_exact(&mut raw)?;
u64::from_le_bytes(raw) as usize
};
let mut reply = vec![0; reply_len];
connection.read_exact(&mut reply)?;
Ok(reply)
}
12 changes: 12 additions & 0 deletions apis/rust/node/src/event_stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ impl EventStream {
)?,
DaemonCommunication::Tcp { socket_addr } => DaemonChannel::new_tcp(*socket_addr)
.wrap_err_with(|| format!("failed to connect event stream for node `{node_id}`"))?,
#[cfg(unix)]
DaemonCommunication::UnixDomain { socket_file } => {
DaemonChannel::new_unix_socket(socket_file).wrap_err_with(|| {
format!("failed to connect event stream for node `{node_id}`")
})?
}
};

let close_channel = match daemon_communication {
Expand All @@ -63,6 +69,12 @@ impl EventStream {
.wrap_err_with(|| {
format!("failed to connect event close channel for node `{node_id}`")
})?,
#[cfg(unix)]
DaemonCommunication::UnixDomain { socket_file } => {
DaemonChannel::new_unix_socket(socket_file).wrap_err_with(|| {
format!("failed to connect event close channel for node `{node_id}`")
})?
}
};

Self::init_on_channel(dataflow_id, node_id, channel, close_channel, clock)
Expand Down
5 changes: 5 additions & 0 deletions apis/rust/node/src/node/control_channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ impl ControlChannel {
.wrap_err("failed to create shmem control channel")?,
DaemonCommunication::Tcp { socket_addr } => DaemonChannel::new_tcp(*socket_addr)
.wrap_err("failed to connect control channel")?,
#[cfg(unix)]
DaemonCommunication::UnixDomain { socket_file } => {
DaemonChannel::new_unix_socket(socket_file)
.wrap_err("failed to connect control channel")?
}
};

Self::init_on_channel(dataflow_id, node_id, channel, clock)
Expand Down
6 changes: 6 additions & 0 deletions apis/rust/node/src/node/drop_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ impl DropStream {
}
DaemonCommunication::Tcp { socket_addr } => DaemonChannel::new_tcp(*socket_addr)
.wrap_err_with(|| format!("failed to connect drop stream for node `{node_id}`"))?,
#[cfg(unix)]
DaemonCommunication::UnixDomain { socket_file } => {
DaemonChannel::new_unix_socket(socket_file).wrap_err_with(|| {
format!("failed to connect drop stream for node `{node_id}`")
})?
}
};

Self::init_on_channel(dataflow_id, node_id, channel, hlc)
Expand Down
10 changes: 5 additions & 5 deletions binaries/daemon/src/coordinator.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
tcp_utils::{tcp_receive, tcp_send},
socket_stream_utils::{socket_stream_receive, socket_stream_send},
DaemonCoordinatorEvent,
};
use dora_core::{
Expand Down Expand Up @@ -41,10 +41,10 @@ pub async fn register(
},
timestamp: clock.new_timestamp(),
})?;
tcp_send(&mut stream, &register)
socket_stream_send(&mut stream, &register)
.await
.wrap_err("failed to send register request to dora-coordinator")?;
let reply_raw = tcp_receive(&mut stream)
let reply_raw = socket_stream_receive(&mut stream)
.await
.wrap_err("failed to register reply from dora-coordinator")?;
let result: Timestamped<RegisterResult> = serde_json::from_slice(&reply_raw)
Expand All @@ -59,7 +59,7 @@ pub async fn register(
let (tx, rx) = mpsc::channel(1);
tokio::spawn(async move {
loop {
let event = match tcp_receive(&mut stream).await {
let event = match socket_stream_receive(&mut stream).await {
Ok(raw) => match serde_json::from_slice(&raw) {
Ok(event) => event,
Err(err) => {
Expand Down Expand Up @@ -109,7 +109,7 @@ pub async fn register(
continue;
}
};
if let Err(err) = tcp_send(&mut stream, &serialized).await {
if let Err(err) = socket_stream_send(&mut stream, &serialized).await {
tracing::warn!("failed to send reply to coordinator: {err}");
continue;
};
Expand Down
6 changes: 3 additions & 3 deletions binaries/daemon/src/inter_daemon.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::tcp_utils::{tcp_receive, tcp_send};
use crate::socket_stream_utils::{socket_stream_receive, socket_stream_send};
use dora_core::daemon_messages::{InterDaemonEvent, Timestamped};
use eyre::{Context, ContextCompat};
use std::{collections::BTreeMap, io::ErrorKind, net::SocketAddr};
Expand Down Expand Up @@ -52,7 +52,7 @@ pub async fn send_inter_daemon_event(
.connect()
.await
.wrap_err_with(|| format!("failed to connect to machine `{target_machine}`"))?;
tcp_send(connection, &message)
socket_stream_send(connection, &message)
.await
.wrap_err_with(|| format!("failed to send event to machine `{target_machine}`"))?;
}
Expand Down Expand Up @@ -131,7 +131,7 @@ async fn handle_connection_loop(
async fn receive_message(
connection: &mut TcpStream,
) -> eyre::Result<Option<Timestamped<InterDaemonEvent>>> {
let raw = match tcp_receive(connection).await {
let raw = match socket_stream_receive(connection).await {
Ok(raw) => raw,
Err(err) => match err.kind() {
ErrorKind::UnexpectedEof
Expand Down
10 changes: 5 additions & 5 deletions binaries/daemon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use inter_daemon::InterDaemonConnection;
use local_listener::DynamicNodeEventWrapper;
use pending::PendingNodes;
use shared_memory_server::ShmemConf;
use socket_stream_utils::socket_stream_send;
use std::sync::Arc;
use std::time::Instant;
use std::{
Expand All @@ -39,7 +40,6 @@ use std::{
time::Duration,
};
use sysinfo::Pid;
use tcp_utils::tcp_send;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
Expand All @@ -56,8 +56,8 @@ mod local_listener;
mod log;
mod node_communication;
mod pending;
mod socket_stream_utils;
mod spawn;
mod tcp_utils;

#[cfg(feature = "telemetry")]
use dora_tracing::telemetry::serialize_context;
Expand Down Expand Up @@ -314,7 +314,7 @@ impl Daemon {
},
timestamp: self.clock.new_timestamp(),
})?;
tcp_send(connection, &msg)
socket_stream_send(connection, &msg)
.await
.wrap_err("failed to send watchdog message to dora-coordinator")?;

Expand Down Expand Up @@ -345,7 +345,7 @@ impl Daemon {
},
timestamp: self.clock.new_timestamp(),
})?;
tcp_send(connection, &msg)
socket_stream_send(connection, &msg)
.await
.wrap_err("failed to send watchdog message to dora-coordinator")?;

Expand Down Expand Up @@ -1103,7 +1103,7 @@ impl Daemon {
},
timestamp: self.clock.new_timestamp(),
})?;
tcp_send(connection, &msg)
socket_stream_send(connection, &msg)
.await
.wrap_err("failed to report dataflow finish to dora-coordinator")?;
}
Expand Down
6 changes: 3 additions & 3 deletions binaries/daemon/src/local_listener.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::tcp_utils::{tcp_receive, tcp_send};
use crate::socket_stream_utils::{socket_stream_receive, socket_stream_send};
use dora_core::daemon_messages::{DaemonReply, DaemonRequest, DynamicNodeEvent, Timestamped};
use eyre::Context;
use std::{io::ErrorKind, net::SocketAddr};
Expand Down Expand Up @@ -99,7 +99,7 @@ async fn handle_connection_loop(
continue;
}
};
if let Err(err) = tcp_send(&mut connection, &serialized).await {
if let Err(err) = socket_stream_send(&mut connection, &serialized).await {
tracing::warn!("failed to send reply: {err}");
continue;
};
Expand All @@ -120,7 +120,7 @@ async fn handle_connection_loop(
async fn receive_message(
connection: &mut TcpStream,
) -> eyre::Result<Option<Timestamped<DaemonRequest>>> {
let raw = match tcp_receive(connection).await {
let raw = match socket_stream_receive(connection).await {
Ok(raw) => raw,
Err(err) => match err.kind() {
ErrorKind::UnexpectedEof
Expand Down
34 changes: 34 additions & 0 deletions binaries/daemon/src/node_communication/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use std::{
sync::Arc,
task::Poll,
};
#[cfg(unix)]
use tokio::net::UnixListener;
use tokio::{
net::TcpListener,
sync::{
Expand All @@ -28,6 +30,8 @@ use tokio::{
// TODO unify and avoid duplication;
pub mod shmem;
pub mod tcp;
#[cfg(unix)]
pub mod unix_domain;

pub async fn spawn_listener_loop(
dataflow_id: &DataflowId,
Expand Down Expand Up @@ -138,6 +142,36 @@ pub async fn spawn_listener_loop(
daemon_events_close_region_id,
})
}
#[cfg(unix)]
LocalCommunicationConfig::UnixDomain => {
use std::path::Path;
let tmpfile_dir = Path::new("/tmp");
let tmpfile_dir = tmpfile_dir.join(dataflow_id.to_string());
if !tmpfile_dir.exists() {
std::fs::create_dir_all(&tmpfile_dir).context("could not create tmp dir")?;
}
let socket_file = tmpfile_dir.join(format!("{}.sock", node_id));
let socket = match UnixListener::bind(&socket_file) {
Ok(socket) => socket,
Err(err) => {
return Err(eyre::Report::new(err)
.wrap_err("failed to create local Unix domain socket"))
}
};

let event_loop_node_id = format!("{dataflow_id}/{node_id}");
let daemon_tx = daemon_tx.clone();
tokio::spawn(async move {
unix_domain::listener_loop(socket, daemon_tx, queue_sizes, clock).await;
tracing::debug!("event listener loop finished for `{event_loop_node_id}`");
});

Ok(DaemonCommunication::UnixDomain { socket_file })
}
#[cfg(not(unix))]
LocalCommunicationConfig::UnixDomain => {
eyre::bail!("Communication via UNIX domain sockets is only supported on UNIX systems")
}
}
}

Expand Down
Loading

0 comments on commit bccb1ae

Please sign in to comment.