Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Unix domain socket support #594

Merged
merged 11 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ jobs:
if: runner.os == 'Linux'
timeout-minutes: 30
run: cargo run --example cmake-dataflow
- name: "Unix Domain Socket example"
if: runner.os == 'Linux'
run: cargo run --example rust-dataflow -- dataflow_socket.yml

# python examples
- uses: actions/setup-python@v2
Expand Down
15 changes: 15 additions & 0 deletions apis/rust/node/src/daemon_connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,22 @@ use dora_core::{
};
use eyre::{bail, eyre, Context};
use shared_memory_server::{ShmemClient, ShmemConf};
#[cfg(unix)]
use std::os::unix::net::UnixStream;
use std::{
net::{SocketAddr, TcpStream},
time::Duration,
};

mod tcp;
#[cfg(unix)]
mod unix_domain;

/// A connection from a node to the daemon over one of the supported transports.
///
/// Requests are dispatched on the variant: the shared-memory client handles
/// them itself, while the stream-based variants go through the `tcp` and
/// `unix_domain` submodules respectively.
pub enum DaemonChannel {
/// Shared-memory client that sends `Timestamped<DaemonRequest>` and receives `DaemonReply`.
Shmem(ShmemClient<Timestamped<DaemonRequest>, DaemonReply>),
/// Connection over a TCP stream.
Tcp(TcpStream),
/// Connection over a Unix domain socket stream; only compiled on Unix targets.
#[cfg(unix)]
UnixDomain(UnixStream),
}

impl DaemonChannel {
Expand All @@ -38,6 +44,13 @@ impl DaemonChannel {
Ok(channel)
}

/// Opens a daemon channel over the Unix domain socket located at `path`.
///
/// The parameter is a `&Path` so that callers can pass any borrowed path;
/// existing callers that hold a `&PathBuf` keep working through deref
/// coercion.
///
/// # Errors
///
/// Returns an error if connecting to the socket at `path` fails.
#[cfg(unix)]
#[tracing::instrument(level = "trace")]
pub fn new_unix_socket(path: &std::path::Path) -> eyre::Result<Self> {
    let stream = UnixStream::connect(path).wrap_err("failed to open Unix socket")?;
    Ok(DaemonChannel::UnixDomain(stream))
}

pub fn register(
&mut self,
dataflow_id: DataflowId,
Expand Down Expand Up @@ -69,6 +82,8 @@ impl DaemonChannel {
match self {
DaemonChannel::Shmem(client) => client.request(request),
DaemonChannel::Tcp(stream) => tcp::request(stream, request),
#[cfg(unix)]
DaemonChannel::UnixDomain(stream) => unix_domain::request(stream, request),
}
}
}
84 changes: 84 additions & 0 deletions apis/rust/node/src/daemon_connection/unix_domain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use dora_core::daemon_messages::{DaemonReply, DaemonRequest, Timestamped};
use eyre::{eyre, Context};
use std::{
io::{Read, Write},
os::unix::net::UnixStream,
};

/// Selects the format used to decode a reply received from the daemon.
enum Serializer {
/// Decode the reply with `bincode::deserialize`.
Bincode,
/// Decode the reply with `serde_json::from_slice`.
SerdeJson,
}
/// Sends `request` to the daemon over `connection` and returns its reply.
///
/// The request decides which reply format (if any) to wait for. Requests that
/// expect neither a bincode nor a JSON reply resolve to `DaemonReply::Empty`
/// immediately after sending.
///
/// # Errors
///
/// Fails if the request cannot be sent, if the reply cannot be received or
/// decoded, or if the connection is closed before an expected reply arrives.
pub fn request(
    connection: &mut UnixStream,
    request: &Timestamped<DaemonRequest>,
) -> eyre::Result<DaemonReply> {
    send_message(connection, request)?;

    let serializer = if request.inner.expects_tcp_bincode_reply() {
        Serializer::Bincode
    } else if request.inner.expects_tcp_json_reply() {
        // Use serde json for message with variable length
        Serializer::SerdeJson
    } else {
        return Ok(DaemonReply::Empty);
    };

    match receive_reply(connection, serializer)? {
        Some(reply) => Ok(reply),
        None => Err(eyre!("server disconnected unexpectedly")),
    }
}

/// Serializes `message` with bincode and writes it to `connection` as one
/// length-prefixed frame.
///
/// # Errors
///
/// Fails if serialization fails or if writing to the stream fails.
fn send_message(
    connection: &mut UnixStream,
    message: &Timestamped<DaemonRequest>,
) -> eyre::Result<()> {
    let payload = bincode::serialize(message).wrap_err("failed to serialize DaemonRequest")?;
    stream_send(connection, &payload).wrap_err("failed to send DaemonRequest")
}

/// Reads one framed reply from `connection` and decodes it with `serializer`.
///
/// Returns `Ok(None)` when the read fails with `UnexpectedEof` or
/// `ConnectionAborted`, which the caller treats as a disconnect.
///
/// # Errors
///
/// Fails on any other I/O error and when the received bytes cannot be
/// deserialized into a `DaemonReply`.
fn receive_reply(
    connection: &mut UnixStream,
    serializer: Serializer,
) -> eyre::Result<Option<DaemonReply>> {
    let raw = match stream_receive(connection) {
        Ok(bytes) => bytes,
        Err(err) => {
            let other = err.kind();
            // A closed connection is reported as "no reply", not as an error.
            if matches!(
                other,
                std::io::ErrorKind::UnexpectedEof | std::io::ErrorKind::ConnectionAborted
            ) {
                return Ok(None);
            }
            return Err(err).with_context(|| {
                format!(
                    "unexpected I/O error (kind {other:?}) while trying to receive DaemonReply"
                )
            });
        }
    };

    let reply: DaemonReply = match serializer {
        Serializer::Bincode => {
            bincode::deserialize(&raw).wrap_err("failed to deserialize DaemonReply")?
        }
        Serializer::SerdeJson => {
            serde_json::from_slice(&raw).wrap_err("failed to deserialize DaemonReply")?
        }
    };
    Ok(Some(reply))
}

/// Writes `message` to `connection` as a single frame.
///
/// A frame is the payload length as an 8-byte little-endian `u64`, followed by
/// the payload bytes. The writer is flushed after the payload.
///
/// # Errors
///
/// Returns the first I/O error reported while writing or flushing.
fn stream_send(connection: &mut (impl Write + Unpin), message: &[u8]) -> std::io::Result<()> {
    let header = u64::to_le_bytes(message.len() as u64);
    // Header first, then the payload.
    for part in [&header[..], message] {
        connection.write_all(part)?;
    }
    connection.flush()
}

/// Reads a single frame from `connection` and returns its payload.
///
/// A frame is the payload length as an 8-byte little-endian `u64`, followed by
/// exactly that many payload bytes.
///
/// The announced length comes from the peer and is therefore not trusted: the
/// buffer is only pre-allocated up to a fixed bound and grows as bytes actually
/// arrive, so a corrupted length prefix cannot trigger a huge allocation or a
/// capacity-overflow panic.
///
/// # Errors
///
/// * `UnexpectedEof` if the stream ends before the header or the full payload
///   has been read.
/// * `InvalidData` if the announced length does not fit into `usize` on this
///   platform.
/// * Any other I/O error reported by the underlying reader.
fn stream_receive(connection: &mut (impl Read + Unpin)) -> std::io::Result<Vec<u8>> {
    // Upper bound for the allocation made before any payload byte was seen.
    const MAX_PREALLOCATION: usize = 64 * 1024;

    let mut len_raw = [0; 8];
    connection.read_exact(&mut len_raw)?;
    let announced = u64::from_le_bytes(len_raw);
    // Checked conversion: `as usize` would silently truncate on 32-bit targets.
    let reply_len = usize::try_from(announced).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "announced message length does not fit into usize",
        )
    })?;

    let mut reply = Vec::with_capacity(reply_len.min(MAX_PREALLOCATION));
    // `take` guarantees that no byte of a following frame is consumed.
    let received = connection
        .by_ref()
        .take(announced)
        .read_to_end(&mut reply)?;
    if received != reply_len {
        return Err(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            "stream ended before the full message was received",
        ));
    }
    Ok(reply)
}
12 changes: 12 additions & 0 deletions apis/rust/node/src/event_stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ impl EventStream {
)?,
DaemonCommunication::Tcp { socket_addr } => DaemonChannel::new_tcp(*socket_addr)
.wrap_err_with(|| format!("failed to connect event stream for node `{node_id}`"))?,
#[cfg(unix)]
DaemonCommunication::UnixDomain { socket_file } => {
DaemonChannel::new_unix_socket(socket_file).wrap_err_with(|| {
format!("failed to connect event stream for node `{node_id}`")
})?
}
};

let close_channel = match daemon_communication {
Expand All @@ -63,6 +69,12 @@ impl EventStream {
.wrap_err_with(|| {
format!("failed to connect event close channel for node `{node_id}`")
})?,
#[cfg(unix)]
DaemonCommunication::UnixDomain { socket_file } => {
DaemonChannel::new_unix_socket(socket_file).wrap_err_with(|| {
format!("failed to connect event close channel for node `{node_id}`")
})?
}
};

Self::init_on_channel(dataflow_id, node_id, channel, close_channel, clock)
Expand Down
5 changes: 5 additions & 0 deletions apis/rust/node/src/node/control_channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ impl ControlChannel {
.wrap_err("failed to create shmem control channel")?,
DaemonCommunication::Tcp { socket_addr } => DaemonChannel::new_tcp(*socket_addr)
.wrap_err("failed to connect control channel")?,
#[cfg(unix)]
DaemonCommunication::UnixDomain { socket_file } => {
DaemonChannel::new_unix_socket(socket_file)
.wrap_err("failed to connect control channel")?
}
};

Self::init_on_channel(dataflow_id, node_id, channel, clock)
Expand Down
6 changes: 6 additions & 0 deletions apis/rust/node/src/node/drop_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ impl DropStream {
}
DaemonCommunication::Tcp { socket_addr } => DaemonChannel::new_tcp(*socket_addr)
.wrap_err_with(|| format!("failed to connect drop stream for node `{node_id}`"))?,
#[cfg(unix)]
DaemonCommunication::UnixDomain { socket_file } => {
DaemonChannel::new_unix_socket(socket_file).wrap_err_with(|| {
format!("failed to connect drop stream for node `{node_id}`")
})?
}
};

Self::init_on_channel(dataflow_id, node_id, channel, hlc)
Expand Down
10 changes: 5 additions & 5 deletions binaries/daemon/src/coordinator.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
tcp_utils::{tcp_receive, tcp_send},
socket_stream_utils::{socket_stream_receive, socket_stream_send},
DaemonCoordinatorEvent,
};
use dora_core::{
Expand Down Expand Up @@ -41,10 +41,10 @@ pub async fn register(
},
timestamp: clock.new_timestamp(),
})?;
tcp_send(&mut stream, &register)
socket_stream_send(&mut stream, &register)
.await
.wrap_err("failed to send register request to dora-coordinator")?;
let reply_raw = tcp_receive(&mut stream)
let reply_raw = socket_stream_receive(&mut stream)
.await
.wrap_err("failed to register reply from dora-coordinator")?;
let result: Timestamped<RegisterResult> = serde_json::from_slice(&reply_raw)
Expand All @@ -59,7 +59,7 @@ pub async fn register(
let (tx, rx) = mpsc::channel(1);
tokio::spawn(async move {
loop {
let event = match tcp_receive(&mut stream).await {
let event = match socket_stream_receive(&mut stream).await {
Ok(raw) => match serde_json::from_slice(&raw) {
Ok(event) => event,
Err(err) => {
Expand Down Expand Up @@ -109,7 +109,7 @@ pub async fn register(
continue;
}
};
if let Err(err) = tcp_send(&mut stream, &serialized).await {
if let Err(err) = socket_stream_send(&mut stream, &serialized).await {
tracing::warn!("failed to send reply to coordinator: {err}");
continue;
};
Expand Down
6 changes: 3 additions & 3 deletions binaries/daemon/src/inter_daemon.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::tcp_utils::{tcp_receive, tcp_send};
use crate::socket_stream_utils::{socket_stream_receive, socket_stream_send};
use dora_core::daemon_messages::{InterDaemonEvent, Timestamped};
use eyre::{Context, ContextCompat};
use std::{collections::BTreeMap, io::ErrorKind, net::SocketAddr};
Expand Down Expand Up @@ -52,7 +52,7 @@ pub async fn send_inter_daemon_event(
.connect()
.await
.wrap_err_with(|| format!("failed to connect to machine `{target_machine}`"))?;
tcp_send(connection, &message)
socket_stream_send(connection, &message)
.await
.wrap_err_with(|| format!("failed to send event to machine `{target_machine}`"))?;
}
Expand Down Expand Up @@ -131,7 +131,7 @@ async fn handle_connection_loop(
async fn receive_message(
connection: &mut TcpStream,
) -> eyre::Result<Option<Timestamped<InterDaemonEvent>>> {
let raw = match tcp_receive(connection).await {
let raw = match socket_stream_receive(connection).await {
Ok(raw) => raw,
Err(err) => match err.kind() {
ErrorKind::UnexpectedEof
Expand Down
10 changes: 5 additions & 5 deletions binaries/daemon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use inter_daemon::InterDaemonConnection;
use local_listener::DynamicNodeEventWrapper;
use pending::PendingNodes;
use shared_memory_server::ShmemConf;
use socket_stream_utils::socket_stream_send;
use std::sync::Arc;
use std::time::Instant;
use std::{
Expand All @@ -39,7 +40,6 @@ use std::{
time::Duration,
};
use sysinfo::Pid;
use tcp_utils::tcp_send;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
Expand All @@ -56,8 +56,8 @@ mod local_listener;
mod log;
mod node_communication;
mod pending;
mod socket_stream_utils;
mod spawn;
mod tcp_utils;

#[cfg(feature = "telemetry")]
use dora_tracing::telemetry::serialize_context;
Expand Down Expand Up @@ -314,7 +314,7 @@ impl Daemon {
},
timestamp: self.clock.new_timestamp(),
})?;
tcp_send(connection, &msg)
socket_stream_send(connection, &msg)
.await
.wrap_err("failed to send watchdog message to dora-coordinator")?;

Expand Down Expand Up @@ -345,7 +345,7 @@ impl Daemon {
},
timestamp: self.clock.new_timestamp(),
})?;
tcp_send(connection, &msg)
socket_stream_send(connection, &msg)
.await
.wrap_err("failed to send watchdog message to dora-coordinator")?;

Expand Down Expand Up @@ -1103,7 +1103,7 @@ impl Daemon {
},
timestamp: self.clock.new_timestamp(),
})?;
tcp_send(connection, &msg)
socket_stream_send(connection, &msg)
.await
.wrap_err("failed to report dataflow finish to dora-coordinator")?;
}
Expand Down
6 changes: 3 additions & 3 deletions binaries/daemon/src/local_listener.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::tcp_utils::{tcp_receive, tcp_send};
use crate::socket_stream_utils::{socket_stream_receive, socket_stream_send};
use dora_core::daemon_messages::{DaemonReply, DaemonRequest, DynamicNodeEvent, Timestamped};
use eyre::Context;
use std::{io::ErrorKind, net::SocketAddr};
Expand Down Expand Up @@ -99,7 +99,7 @@ async fn handle_connection_loop(
continue;
}
};
if let Err(err) = tcp_send(&mut connection, &serialized).await {
if let Err(err) = socket_stream_send(&mut connection, &serialized).await {
tracing::warn!("failed to send reply: {err}");
continue;
};
Expand All @@ -120,7 +120,7 @@ async fn handle_connection_loop(
async fn receive_message(
connection: &mut TcpStream,
) -> eyre::Result<Option<Timestamped<DaemonRequest>>> {
let raw = match tcp_receive(connection).await {
let raw = match socket_stream_receive(connection).await {
Ok(raw) => raw,
Err(err) => match err.kind() {
ErrorKind::UnexpectedEof
Expand Down
34 changes: 34 additions & 0 deletions binaries/daemon/src/node_communication/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use std::{
sync::Arc,
task::Poll,
};
#[cfg(unix)]
use tokio::net::UnixListener;
use tokio::{
net::TcpListener,
sync::{
Expand All @@ -28,6 +30,8 @@ use tokio::{
// TODO unify and avoid duplication;
pub mod shmem;
pub mod tcp;
#[cfg(unix)]
pub mod unix_domain;

pub async fn spawn_listener_loop(
dataflow_id: &DataflowId,
Expand Down Expand Up @@ -138,6 +142,36 @@ pub async fn spawn_listener_loop(
daemon_events_close_region_id,
})
}
#[cfg(unix)]
LocalCommunicationConfig::UnixDomain => {
use std::path::Path;
let tmpfile_dir = Path::new("/tmp");
let tmpfile_dir = tmpfile_dir.join(dataflow_id.to_string());
if !tmpfile_dir.exists() {
std::fs::create_dir_all(&tmpfile_dir).context("could not create tmp dir")?;
}
let socket_file = tmpfile_dir.join(format!("{}.sock", node_id));
let socket = match UnixListener::bind(&socket_file) {
Ok(socket) => socket,
Err(err) => {
return Err(eyre::Report::new(err)
.wrap_err("failed to create local Unix domain socket"))
}
};

let event_loop_node_id = format!("{dataflow_id}/{node_id}");
let daemon_tx = daemon_tx.clone();
tokio::spawn(async move {
unix_domain::listener_loop(socket, daemon_tx, queue_sizes, clock).await;
tracing::debug!("event listener loop finished for `{event_loop_node_id}`");
});

Ok(DaemonCommunication::UnixDomain { socket_file })
}
#[cfg(not(unix))]
LocalCommunicationConfig::UnixDomain => {
eyre::bail!("Communication via UNIX domain sockets is only supported on UNIX systems")
}
}
}

Expand Down
Loading
Loading