Skip to content

Commit

Permalink
feat(mpz-common): mt executor
Browse files Browse the repository at this point in the history
  • Loading branch information
sinui0 committed May 15, 2024
1 parent 94a8c3a commit ed58db0
Show file tree
Hide file tree
Showing 11 changed files with 341 additions and 56 deletions.
9 changes: 6 additions & 3 deletions crates/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ mpz-share-conversion-core = { path = "mpz-share-conversion-core" }
clmul = { path = "clmul" }
matrix-transpose = { path = "matrix-transpose" }

tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8f2fc9e" }
tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8f2fc9e" }
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "1d27bd7" }
tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "1d27bd7" }

# rand
rand_chacha = "0.3"
Expand Down Expand Up @@ -83,7 +83,10 @@ prost-build = "0.9"
bytes = "1"
yamux = "0.10"
bytemuck = { version = "1.13", features = ["derive"] }
serio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8f2fc9e" }
serio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "1d27bd7" }

# io
uid-mux = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "1d27bd7" }

# testing
prost = "0.9"
Expand Down
13 changes: 12 additions & 1 deletion crates/mpz-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ edition = "2021"
[features]
default = ["sync"]
sync = []
test-utils = []
test-utils = ["uid-mux/test-utils"]
ideal = []

[dependencies]
Expand All @@ -17,4 +17,15 @@ async-trait.workspace = true
scoped-futures.workspace = true
thiserror.workspace = true
serio.workspace = true
uid-mux.workspace = true
serde = { workspace = true, features = ["derive"] }

[dev-dependencies]
tokio = { workspace = true, features = [
"io-util",
"macros",
"rt-multi-thread",
] }
tokio-util = { workspace = true, features = ["compat"] }
uid-mux = { workspace = true, features = ["test-utils"] }
tracing-subscriber = { workspace = true, features = ["fmt"] }
80 changes: 56 additions & 24 deletions crates/mpz-common/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,50 @@
use core::fmt;

use async_trait::async_trait;

use scoped_futures::ScopedBoxFuture;
use serio::{IoSink, IoStream};

use crate::ThreadId;

/// An error for types that implement [`Context`].
#[derive(Debug, thiserror::Error)]
#[error("context error: {kind}")]
pub struct ContextError {
kind: ErrorKind,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync>>,
}

impl ContextError {
pub(crate) fn new_with_source<E: Into<Box<dyn std::error::Error + Send + Sync>>>(
kind: ErrorKind,
source: E,
) -> Self {
Self {
kind,
source: Some(source.into()),
}
}
}

#[derive(Debug)]
pub(crate) enum ErrorKind {
Mux,
}

impl fmt::Display for ErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ErrorKind::Mux => write!(f, "multiplexer error"),
}
}
}

/// A thread context.
#[async_trait]
pub trait Context: Send {
/// The type of I/O channel used by the thread.
pub trait Context: Send + Sync {
/// I/O channel used by the thread.
type Io: IoSink + IoStream + Send + Unpin + 'static;

/// Returns the thread ID.
Expand All @@ -21,7 +57,7 @@ pub trait Context: Send {
///
/// Implementations may not be able to fork, in which case the closures are executed
/// sequentially.
async fn join<'a, A, B, RA, RB>(&'a mut self, a: A, b: B) -> (RA, RB)
async fn join<'a, A, B, RA, RB>(&'a mut self, a: A, b: B) -> Result<(RA, RB), ContextError>
where
A: for<'b> FnOnce(&'b mut Self) -> ScopedBoxFuture<'a, 'b, RA> + Send + 'a,
B: for<'b> FnOnce(&'b mut Self) -> ScopedBoxFuture<'a, 'b, RB> + Send + 'a,
Expand All @@ -36,7 +72,11 @@ pub trait Context: Send {
///
/// Implementations may not be able to fork, in which case the closures are executed
/// sequentially.
async fn try_join<'a, A, B, RA, RB, E>(&'a mut self, a: A, b: B) -> Result<(RA, RB), E>
async fn try_join<'a, A, B, RA, RB, E>(
&'a mut self,
a: A,
b: B,
) -> Result<Result<(RA, RB), E>, ContextError>
where
A: for<'b> FnOnce(&'b mut Self) -> ScopedBoxFuture<'a, 'b, Result<RA, E>> + Send + 'a,
B: for<'b> FnOnce(&'b mut Self) -> ScopedBoxFuture<'a, 'b, Result<RB, E>> + Send + 'a,
Expand All @@ -50,17 +90,12 @@ pub trait Context: Send {
/// This macro calls `Context::join` under the hood.
#[macro_export]
macro_rules! join {
($ctx:ident, $task_0:expr, $task_1:expr) => {
async {
use $crate::{scoped_futures::ScopedFutureExt, Context};
$ctx.join(
|$ctx| async { $task_0.await }.scope_boxed(),
|$ctx| async { $task_1.await }.scope_boxed(),
)
($ctx:ident, $task_0:expr, $task_1:expr) => {{
#[allow(unused_imports)]
use $crate::{scoped_futures::ScopedFutureExt, Context};
$ctx.join(|$ctx| $task_0.scope_boxed(), |$ctx| $task_1.scope_boxed())
.await
}
.await
};
}};
}

/// A convenience macro for forking a context and joining two tasks concurrently, returning an error
Expand All @@ -69,17 +104,12 @@ macro_rules! join {
/// This macro calls `Context::try_join` under the hood.
#[macro_export]
macro_rules! try_join {
($ctx:ident, $task_0:expr, $task_1:expr) => {
async {
use $crate::{scoped_futures::ScopedFutureExt, Context};
$ctx.try_join(
|$ctx| async { $task_0.await }.scope_boxed(),
|$ctx| async { $task_1.await }.scope_boxed(),
)
($ctx:ident, $task_0:expr, $task_1:expr) => {{
#[allow(unused_imports)]
use $crate::{scoped_futures::ScopedFutureExt, Context};
$ctx.try_join(|$ctx| $task_0.scope_boxed(), |$ctx| $task_1.scope_boxed())
.await
}
.await
};
}};
}

#[cfg(test)]
Expand All @@ -94,6 +124,7 @@ mod tests {
join!(ctx, async { println!("{:?}", ctx.id()) }, async {
println!("{:?}", ctx.id())
})
.unwrap()
});
}

Expand All @@ -107,6 +138,7 @@ mod tests {
async { Ok::<_, ()>(println!("{:?}", ctx.id())) },
async { Ok::<_, ()>(println!("{:?}", ctx.id())) }
)
.unwrap()
.unwrap();
});
}
Expand Down
57 changes: 56 additions & 1 deletion crates/mpz-common/src/executor/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
//! Executors.
mod mt;
mod st;

pub use mt::MTExecutor;
pub use st::STExecutor;

#[cfg(any(test, feature = "test-utils"))]
mod test_utils {
use serio::channel::{duplex, MemoryDuplex};
use std::future::IntoFuture;

use futures::{Future, FutureExt, TryFutureExt};
use serio::{
channel::{duplex, MemoryDuplex},
codec::Bincode,
};
use uid_mux::{
test_utils::test_yamux_pair_framed,
yamux::{ConnectionError, YamuxCtrl},
FramedMux, FramedUidMux,
};

use crate::ThreadId;

use super::*;

Expand All @@ -18,6 +33,46 @@ mod test_utils {

(STExecutor::new(io_0), STExecutor::new(io_1))
}

/// Test multi-threaded executor.
pub type TestMTExecutor = MTExecutor<
FramedMux<YamuxCtrl, Bincode>,
<FramedMux<YamuxCtrl, Bincode> as FramedUidMux<ThreadId>>::Framed,
>;

/// Creates a pair of multi-threaded executors with yamux I/O channels.
pub fn test_mt_executor(
io_buffer: usize,
) -> (
(TestMTExecutor, TestMTExecutor),
impl Future<Output = Result<(), ConnectionError>>,
) {
let ((mux_0, fut_0), (mux_1, fut_1)) = test_yamux_pair_framed(io_buffer, Bincode);

let mut fut_io =
futures::future::try_join(fut_0.into_future(), fut_1.into_future()).map_ok(|_| ());

let (ctx_0, ctx_1) = futures::executor::block_on(async {
let fut_exec =
futures::future::try_join(MTExecutor::new(mux_0), MTExecutor::new(mux_1));
futures::select! {
ctx = fut_exec.fuse() => ctx.unwrap(),
_ = (&mut fut_io).fuse() => panic!(),
}
});

((ctx_0, ctx_1), fut_io)
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_create_mt_executor() {
_ = test_mt_executor(1024);
}
}
}

#[cfg(any(test, feature = "test-utils"))]
Expand Down
Loading

0 comments on commit ed58db0

Please sign in to comment.