From a12587f435ba67f0544424c59867f2fa4659d10b Mon Sep 17 00:00:00 2001 From: taskooh Date: Thu, 3 Oct 2024 21:32:13 +0900 Subject: [PATCH] WIP --- mpc-net/src/lib.rs | 25 +++--- mpc-net/src/multi.rs | 199 ++++++++++++++++++------------------------- 2 files changed, 93 insertions(+), 131 deletions(-) diff --git a/mpc-net/src/lib.rs b/mpc-net/src/lib.rs index 74cd1e8..d46899e 100644 --- a/mpc-net/src/lib.rs +++ b/mpc-net/src/lib.rs @@ -23,16 +23,15 @@ pub enum MultiplexedStreamID { Two = 2, } -pub trait MpcNet { +pub trait MpcNet: Send + Sync { /// Am I the first party? - #[inline] - fn is_leader() -> bool { - Self::party_id() == 0 + fn is_leader(&self) -> bool { + self.party_id() == 0 } /// How many parties are there? - fn n_parties() -> usize; + fn n_parties(&self) -> usize; /// What is my party number (0 to n-1)? - fn party_id() -> usize; + fn party_id(&self) -> usize; /// Initialize the network layer from a file. /// The file should contain one HOST:PORT setting per line, corresponding to the addresses of /// the parties in increasing order. @@ -40,7 +39,7 @@ pub trait MpcNet { /// Parties are zero-indexed. fn init_from_file(path: &str, party_id: usize); /// Is the network layer initalized? - fn is_init() -> bool; + fn is_init(&self) -> bool; /// Uninitialize the network layer, closing all connections. fn deinit(); /// Set statistics to zero. @@ -48,12 +47,12 @@ pub trait MpcNet { /// Get statistics. fn stats() -> Stats; /// All parties send bytes to each other. - fn broadcast_bytes(bytes: &[u8]) -> Vec>; + fn broadcast_bytes(&self, bytes: &[u8]) -> Vec>; /// All parties send bytes to the king. - fn worker_send_or_leader_receive(bytes: &[u8]) -> Option>>; + fn worker_send_or_leader_receive(&self, bytes: &[u8]) -> Option>>; /// All parties recv bytes from the king. /// Provide bytes iff you're the king! - fn worker_receive_or_leader_send(bytes: Option>>) -> Vec; + fn worker_receive_or_leader_send(&self, bytes: Option>>) -> Vec; /// Everyone sends bytes to the king, who recieves those bytes, runs a computation on them, and /// redistributes the resulting bytes. @@ -61,9 +60,9 @@ pub trait MpcNet { /// The king's computation is given by a function, `f` /// proceeds. #[inline] - fn leader_compute(bytes: &[u8], f: impl Fn(Vec>) -> Vec>) -> Vec { - let king_response = Self::worker_send_or_leader_receive(bytes).map(f); - Self::worker_receive_or_leader_send(king_response) + fn leader_compute(&self, bytes: &[u8], f: impl Fn(Vec>) -> Vec>) -> Vec { + let king_response = self.worker_send_or_leader_receive(bytes).map(f); + self.worker_receive_or_leader_send(king_response) } fn uninit(); diff --git a/mpc-net/src/multi.rs b/mpc-net/src/multi.rs index c7fdd89..801ffad 100644 --- a/mpc-net/src/multi.rs +++ b/mpc-net/src/multi.rs @@ -1,41 +1,42 @@ -use std::{ - fs::File, - io::{BufRead, BufReader, Read, Write}, - net::{SocketAddr, TcpListener, TcpStream}, - sync::Mutex, -}; - use ark_std::{end_timer, perf_trace::TimerInfo, start_timer}; -use lazy_static::lazy_static; -use log::debug; -use rayon::prelude::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; +use async_smux::MuxStream; +use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::net::SocketAddr; +use std::sync::atomic::AtomicUsize; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::Mutex as TokioMutex; +use tokio_util::codec::{Framed, LengthDelimitedCodec}; -use crate::MpcNet; +use crate::{MPCNetError, MpcNet}; -lazy_static! { - static ref CONNECTIONS: Mutex = Mutex::new(Connections::default()); -} +pub type WrappedMuxStream = Framed, LengthDelimitedCodec>; -/// Macro for locking the FieldChannel singleton in the current scope. -macro_rules! get_ch { - () => { - CONNECTIONS.lock().expect("Poisoned FieldChannel") - }; +struct Peer { + id: usize, + listen_addr: SocketAddr, + streams: Option>>>, } -#[derive(Debug)] -struct Peer { - id: usize, - addr: SocketAddr, - stream: Option, +impl Debug for Peer { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut f = f.debug_struct("Peer"); + f.field("id", &self.id); + f.field("listen_addr", &self.listen_addr); + f.field("streams", &self.streams.is_some()); + f.finish() + } } -impl Default for Peer { - fn default() -> Self { +impl Clone for Peer { + fn clone(&self) -> Self { Self { - id: 0, - addr: "127.0.0.1:8000".parse().unwrap(), - stream: None, + id: self.id, + listen_addr: self.listen_addr, + streams: None, } } } @@ -50,15 +51,26 @@ pub struct Stats { } #[derive(Default, Debug)] -struct Connections { - id: usize, - peers: Vec, - stats: Stats, +struct MPCNetConnection { + pub id: usize, + pub listener: Option, + pub peers: HashMap>, + pub n_parties: usize, + pub upload: AtomicUsize, + pub download: AtomicUsize, } -impl Connections { +impl MPCNetConnection { /// Given a path and the `id` of oneself, initialize the structure - fn init_from_path(&mut self, path: &str, id: usize) { + fn init_from_path(path: &str, id: usize) -> Self { + let mut this = MPCNetConnection { + id: 0, + listener: None, + peers: Default::default(), + n_parties: 0, + upload: AtomicUsize::new(0), + download: AtomicUsize::new(0), + }; let f = BufReader::new(File::open(path).expect("host configuration path")); let mut peer_id = 0; for line in f.lines() { @@ -70,85 +82,36 @@ impl Connections { .unwrap_or_else(|e| panic!("bad socket address: {}:\n{}", trimmed, e)); let peer = Peer { id: peer_id, - addr, - stream: None, + listen_addr: addr, + streams: None, }; - self.peers.push(peer); + this.peers.insert(peer_id, peer); peer_id += 1; } } - assert!(id < self.peers.len()); - self.id = id; + assert!(id < this.peers.len()); + this.id = id; + this.n_parties = this.peers.len(); + this } - fn connect_to_all(&mut self) { - let timer = start_timer!(|| "Connecting"); - let n = self.peers.len(); - for from_id in 0..n { - for to_id in (from_id + 1)..n { - debug!("{} to {}", from_id, to_id); - if self.id == from_id { - let to_addr = self.peers[to_id].addr; - debug!("Contacting {}", to_id); - let stream = loop { - let mut ms_waited = 0; - match TcpStream::connect(to_addr) { - Ok(s) => break s, - Err(e) => match e.kind() { - std::io::ErrorKind::ConnectionRefused - | std::io::ErrorKind::ConnectionReset => { - ms_waited += 10; - std::thread::sleep(std::time::Duration::from_millis(10)); - if ms_waited % 3_000 == 0 { - debug!("Still waiting"); - } else if ms_waited > 30_000 { - panic!("Could not find peer in 30s"); - } - } - _ => { - panic!("Error during FieldChannel::new: {}", e); - } - }, - } - }; - stream.set_nodelay(true).unwrap(); - self.peers[to_id].stream = Some(stream); - } else if self.id == to_id { - debug!("Awaiting {}", from_id); - let listener = TcpListener::bind(self.peers[self.id].addr).unwrap(); - let (stream, _addr) = listener.accept().unwrap(); - stream.set_nodelay(true).unwrap(); - self.peers[from_id].stream = Some(stream); - } - } - // Sender for next round waits for note from this sender to prevent race on receipt. - if from_id + 1 < n { - if self.id == from_id { - self.peers[self.id + 1] - .stream - .as_mut() - .unwrap() - .write_all(&[0u8]) - .unwrap(); - } else if self.id == from_id + 1 { - self.peers[self.id - 1] - .stream - .as_mut() - .unwrap() - .read_exact(&mut [0u8]) - .unwrap(); - } - } - } - // Do a round with the king, to be sure everyone is ready - let from_all = self.send_to_king(&[self.id as u8]); - self.recv_from_king(from_all); - for id in 0..n { - if id != self.id { - assert!(self.peers[id].stream.is_some()); - } - } - end_timer!(timer); + + pub async fn listen(&mut self) -> Result<(), MPCNetError> { + let listen_addr = self.peers.get(&self.id).unwrap().listen_addr; + let listener = TcpListener::bind(listen_addr).await.unwrap(); + self.listener = Some(listener); + Ok(()) } + + async fn connect_to_all(&mut self) { + let n_minus_1 = self.n_parties - 1; + let self_id = self.id; + let peer_addrs = self + .peers + .iter() + .map(|(_, p)| p.listen_addr) + .collect::>(); + } + fn am_king(&self) -> bool { self.id == 0 } @@ -219,7 +182,7 @@ impl Connections { } else { self.stats.bytes_sent += m; self.peers[0] - .stream + .streams .as_mut() .unwrap() .write_all(bytes_out) @@ -251,7 +214,7 @@ impl Connections { end_timer!(timer); bytes_out[own_id].clone() } else { - let stream = self.peers[0].stream.as_mut().unwrap(); + let stream = self.peers[0].streams.as_mut().unwrap(); let mut bytes_size = [0u8; 8]; stream.read_exact(&mut bytes_size).unwrap(); let m = u64::from_le_bytes(bytes_size) as usize; @@ -263,7 +226,7 @@ impl Connections { } fn uninit(&mut self) { for p in &mut self.peers { - p.stream = None; + p.streams = None; } } } @@ -272,24 +235,24 @@ pub struct MpcMultiNet; impl MpcNet for MpcMultiNet { #[inline] - fn party_id() -> usize { + fn party_id(&self) -> usize { get_ch!().id } #[inline] - fn n_parties() -> usize { + fn n_parties(&self) -> usize { get_ch!().peers.len() } #[inline] fn init_from_file(path: &str, party_id: usize) { let mut ch = get_ch!(); - ch.init_from_path(path, party_id); + MPCNetConnection::init_from_path(path, party_id); ch.connect_to_all(); } #[inline] - fn is_init() -> bool { + fn is_init(&self) -> bool { get_ch!() .peers .first() @@ -313,17 +276,17 @@ impl MpcNet for MpcMultiNet { } #[inline] - fn broadcast_bytes(bytes: &[u8]) -> Vec> { + fn broadcast_bytes(&self, bytes: &[u8]) -> Vec> { get_ch!().broadcast(bytes) } #[inline] - fn worker_send_or_leader_receive(bytes: &[u8]) -> Option>> { + fn worker_send_or_leader_receive(&self, bytes: &[u8]) -> Option>> { get_ch!().send_to_king(bytes) } #[inline] - fn worker_receive_or_leader_send(bytes: Option>>) -> Vec { + fn worker_receive_or_leader_send(&self, bytes: Option>>) -> Vec { get_ch!().recv_from_king(bytes) }