Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mpz-common): async sync primitives #152

Merged
merged 4 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ tokio = "1.23"
tokio-util = "0.7"
scoped-futures = "0.1.3"
pollster = "0.3"
pin-project-lite = "0.2"

# serialization
ark-serialize = "0.4"
Expand Down
4 changes: 3 additions & 1 deletion crates/mpz-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ edition = "2021"

[features]
default = ["sync"]
sync = []
sync = ["tokio/sync"]
test-utils = ["uid-mux/test-utils"]
ideal = []
rayon = ["dep:rayon"]
Expand All @@ -16,6 +16,7 @@ mpz-core.workspace = true

futures.workspace = true
async-trait.workspace = true
pin-project-lite.workspace = true
scoped-futures.workspace = true
thiserror.workspace = true
serio.workspace = true
Expand All @@ -24,6 +25,7 @@ serde = { workspace = true, features = ["derive"] }
pollster.workspace = true
rayon = { workspace = true, optional = true }
cfg-if.workspace = true
tokio = { workspace = true, optional = true, default-features = false }

[dev-dependencies]
tokio = { workspace = true, features = [
Expand Down
109 changes: 109 additions & 0 deletions crates/mpz-common/src/sync/async_mutex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
//! Synchronized async mutex.

use pollster::FutureExt;
use tokio::sync::{Mutex as TokioMutex, MutexGuard};

use crate::{
context::Context,
sync::{AsyncSyncer, MutexError},
};

/// A mutex which synchronizes exclusive access to a resource across logical threads.
///
/// There are two configurations for a mutex, either as a leader or as a follower.
///
/// **Leader**
///
/// A leader mutex is the authority on the order in which threads can acquire a lock. When a
/// thread acquires a lock, it broadcasts a message to all follower mutexes, which then enforce
/// that this order is preserved.
///
/// **Follower**
///
/// A follower mutex waits for messages from the leader mutex to inform it of the order in which
/// threads can acquire a lock.
#[derive(Debug)]
pub struct AsyncMutex<T> {
inner: TokioMutex<T>,
syncer: AsyncSyncer,
}

impl<T> AsyncMutex<T> {
/// Creates a new leader mutex.
///
/// # Arguments
///
/// * `value` - The value protected by the mutex.
pub fn new_leader(value: T) -> Self {
Self {
inner: TokioMutex::new(value),
syncer: AsyncSyncer::new_leader(),
}
}

/// Creates a new follower mutex.
///
/// # Arguments
///
/// * `value` - The value protected by the mutex.
pub fn new_follower(value: T) -> Self {
Self {
inner: TokioMutex::new(value),
syncer: AsyncSyncer::new_follower(),
}
}

/// Returns a lock on the mutex.
pub async fn lock<Ctx: Context>(&self, ctx: &mut Ctx) -> Result<MutexGuard<'_, T>, MutexError> {
self.syncer
.sync(ctx.io_mut(), self.inner.lock())
.await
.map_err(MutexError::from)
}

/// Returns an unsynchronized blocking lock on the mutex.
///
/// # Warning
///
/// Do not use this method unless you are certain that the way you're mutating the state does
/// not require synchronization. Also, don't hold this lock across await points it will cause
/// deadlocks.
pub fn blocking_lock_unsync(&self) -> MutexGuard<'_, T> {
self.inner.lock().block_on()
}

/// Returns the inner value, consuming the mutex.
pub fn into_inner(self) -> T {
self.inner.into_inner()
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use super::*;

#[test]
fn test_async_mutex() {
let leader_mutex = Arc::new(AsyncMutex::new_leader(()));
let follower_mutex = Arc::new(AsyncMutex::new_follower(()));

let (mut ctx_a, mut ctx_b) = crate::executor::test_st_executor(8);

futures::executor::block_on(async {
futures::join!(
async {
drop(leader_mutex.lock(&mut ctx_a).await.unwrap());
drop(leader_mutex.lock(&mut ctx_a).await.unwrap());
drop(leader_mutex.lock(&mut ctx_a).await.unwrap());
},
async {
drop(follower_mutex.lock(&mut ctx_b).await.unwrap());
drop(follower_mutex.lock(&mut ctx_b).await.unwrap());
drop(follower_mutex.lock(&mut ctx_b).await.unwrap());
},
);
});
}
}
239 changes: 239 additions & 0 deletions crates/mpz-common/src/sync/async_syncer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
use std::{
collections::HashMap,
pin::Pin,
sync::{Arc, Mutex as StdMutex},
task::{ready, Context as StdContext, Poll, Waker},
};

use futures::{future::poll_fn, Future};
use serio::{stream::IoStreamExt, IoDuplex};
use tokio::sync::Mutex;

use crate::sync::{SyncError, Ticket};

/// An async version of [`Syncer`](crate::sync::Syncer).
#[derive(Debug, Clone)]
pub struct AsyncSyncer(SyncerInner);

impl AsyncSyncer {
/// Creates a new leader.
pub fn new_leader() -> Self {
Self(SyncerInner::Leader(Leader::default()))
}

/// Creates a new follower.
pub fn new_follower() -> Self {
Self(SyncerInner::Follower(Follower::default()))
}

/// Synchronizes the order of execution across logical threads.
///
/// # Arguments
///
/// * `io` - The I/O channel of the logical thread.
/// * `fut` - The future to await.
pub async fn sync<Io: IoDuplex + Unpin, Fut>(
&self,
io: &mut Io,
fut: Fut,
) -> Result<Fut::Output, SyncError>
where
Fut: Future,
{
match &self.0 {
SyncerInner::Leader(leader) => leader.sync(io, fut).await,
SyncerInner::Follower(follower) => follower.sync(io, fut).await,
}
}
}

#[derive(Debug, Clone)]
enum SyncerInner {
Leader(Leader),
Follower(Follower),
}

#[derive(Debug, Default, Clone)]
struct Leader {
tick: Arc<Mutex<Ticket>>,
}

impl Leader {
async fn sync<Io: IoDuplex + Unpin, Fut>(
&self,
io: &mut Io,
fut: Fut,
) -> Result<Fut::Output, SyncError>
where
Fut: Future,
{
let mut io = Pin::new(io);
poll_fn(|cx| io.as_mut().poll_ready(cx)).await?;
let (output, tick) = {
let mut tick_lock = self.tick.lock().await;
let output = fut.await;
let tick = tick_lock.increment_in_place();
(output, tick)
};
io.start_send(tick)?;
Ok(output)
}
}

#[derive(Debug, Default, Clone)]
struct Follower {
queue: Arc<StdMutex<Queue>>,
}

impl Follower {
async fn sync<Io: IoDuplex + Unpin, Fut>(
&self,
io: &mut Io,
fut: Fut,
) -> Result<Fut::Output, SyncError>
where
Fut: Future,
{
let tick = io.expect_next().await?;
Ok(Wait::new(&self.queue, tick, fut).await)
}
}

#[derive(Debug, Default)]
struct Queue {
// The current ticket.
tick: Ticket,
// Tasks waiting for their ticket to be accepted.
waiting: HashMap<Ticket, Waker>,
}

impl Queue {
// Wakes up the next waiting task.
fn wake_next(&mut self) {
if let Some(waker) = self.waiting.remove(&self.tick) {
waker.wake();
}
}
}

pin_project_lite::pin_project! {
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct Wait<'a, Fut> {
queue: &'a StdMutex<Queue>,
tick: Ticket,
#[pin]
fut: Fut,
}
}

impl<'a, Fut> Wait<'a, Fut> {
fn new(queue: &'a StdMutex<Queue>, tick: Ticket, fut: Fut) -> Self {
Self { queue, tick, fut }
}
}

impl<'a, Fut> Future for Wait<'a, Fut>
where
Fut: Future,
{
type Output = Fut::Output;

fn poll(self: Pin<&mut Self>, cx: &mut StdContext<'_>) -> Poll<Self::Output> {
let mut queue_lock = self.queue.lock().unwrap();
if queue_lock.tick == self.tick {
let this = self.project();
let output = ready!(this.fut.poll(cx));
queue_lock.tick.increment_in_place();
queue_lock.wake_next();
Poll::Ready(output)
} else {
queue_lock.waiting.insert(self.tick, cx.waker().clone());
Poll::Pending
}
}
}

#[cfg(test)]
mod tests {
use futures::{executor::block_on, poll};
use serio::channel::duplex;

use super::*;

#[test]
fn test_syncer() {
let (mut io_0a, mut io_0b) = duplex(1);
let (mut io_1a, mut io_1b) = duplex(1);
let (mut io_2a, mut io_2b) = duplex(1);

let syncer_a = AsyncSyncer::new_leader();
let syncer_b = AsyncSyncer::new_follower();

let log_a = Arc::new(Mutex::new(Vec::new()));
let log_b = Arc::new(Mutex::new(Vec::new()));

block_on(async {
syncer_a
.sync(&mut io_0a, async {
let mut log = log_a.lock().await;
log.push(0);
})
.await
.unwrap();
syncer_a
.sync(&mut io_1a, async {
let mut log = log_a.lock().await;
log.push(1);
})
.await
.unwrap();
syncer_a
.sync(&mut io_2a, async {
let mut log = log_a.lock().await;
log.push(2);
})
.await
.unwrap();
});

let mut fut_a = Box::pin(syncer_b.sync(&mut io_2b, async {
let mut log = log_b.lock().await;
log.push(2);
}));

let mut fut_b = Box::pin(syncer_b.sync(&mut io_0b, async {
let mut log = log_b.lock().await;
log.push(0);
}));

let mut fut_c = Box::pin(syncer_b.sync(&mut io_1b, async {
let mut log = log_b.lock().await;
log.push(1);
}));

block_on(async move {
// Poll out of order.
assert!(poll!(&mut fut_a).is_pending());
assert!(poll!(&mut fut_c).is_pending());
assert!(poll!(&mut fut_b).is_ready());
assert!(poll!(&mut fut_c).is_ready());
assert!(poll!(&mut fut_a).is_ready());
});

let log_a = Arc::into_inner(log_a).unwrap().into_inner();
let log_b = Arc::into_inner(log_b).unwrap().into_inner();

assert_eq!(log_a, log_b);
}

#[test]
fn test_syncer_is_send() {
let (mut io, _) = duplex(1);
let syncer = AsyncSyncer::new_leader();

fn is_send<T: Send>(_: T) {}

is_send(syncer.sync(&mut io, async {}));
}
}
Loading