Skip to content

Commit

Permalink
Finish adapting mpz-ole to new IO model.
Browse files Browse the repository at this point in the history
  • Loading branch information
th4s committed Mar 7, 2024
1 parent ca1e22d commit 2de0b5a
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 126 deletions.
38 changes: 4 additions & 34 deletions ole/mpz-ole/src/ideal/ole.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use mpz_core::ProtocolMessage;
use mpz_share_conversion_core::Field;
use rand::thread_rng;
use std::marker::PhantomData;
use utils_aio::{sink::IoSink, stream::IoStream};

/// Returns an ideal OLE pair.
pub fn ideal_ole_pair<F: Field>() -> (IdealOLEProvider<F>, IdealOLEEvaluator<F>) {
Expand Down Expand Up @@ -48,15 +47,7 @@ impl<F: Field> ProtocolMessage for IdealOLEEvaluator<F> {

#[async_trait]
impl<F: Field> OLEeProvide<F> for IdealOLEProvider<F> {
async fn provide<
Si: IoSink<Self::Msg> + Send + Unpin,
St: IoStream<Self::Msg> + Send + Unpin,
>(
&mut self,
_sink: &mut Si,
_stream: &mut St,
factors: Vec<F>,
) -> Result<Vec<F>, OLEError> {
async fn provide(&mut self, factors: Vec<F>) -> Result<Vec<F>, OLEError> {
let mut rng = thread_rng();
let offsets: Vec<F> = (0..factors.len()).map(|_| F::rand(&mut rng)).collect();

Expand All @@ -70,15 +61,7 @@ impl<F: Field> OLEeProvide<F> for IdealOLEProvider<F> {

#[async_trait]
impl<F: Field> OLEeEvaluate<F> for IdealOLEEvaluator<F> {
async fn evaluate<
Si: IoSink<Self::Msg> + Send + Unpin,
St: IoStream<Self::Msg> + Send + Unpin,
>(
&mut self,
_sink: &mut Si,
_stream: &mut St,
input: Vec<F>,
) -> Result<Vec<F>, OLEError> {
async fn evaluate(&mut self, input: Vec<F>) -> Result<Vec<F>, OLEError> {
let (factors, offsets) = self
.channel
.next()
Expand All @@ -99,11 +82,9 @@ impl<F: Field> OLEeEvaluate<F> for IdealOLEEvaluator<F> {
#[cfg(test)]
mod tests {
use crate::{ideal::ole::ideal_ole_pair, OLEeEvaluate, OLEeProvide};
use futures::StreamExt;
use mpz_core::{prg::Prg, Block};
use mpz_share_conversion_core::fields::{p256::P256, UniformRand};
use rand::SeedableRng;
use utils_aio::duplex::MemoryDuplex;

#[tokio::test]
async fn test_ideal_ole() {
Expand All @@ -113,21 +94,10 @@ mod tests {
let inputs: Vec<P256> = (0..count).map(|_| P256::rand(&mut rng)).collect();
let factors: Vec<P256> = (0..count).map(|_| P256::rand(&mut rng)).collect();

let (send_channel, recv_channel) = MemoryDuplex::<()>::new();

let (mut provider_sink, mut provider_stream) = send_channel.split();
let (mut evaluator_sink, mut evaluator_stream) = recv_channel.split();

let (mut provider, mut evaluator) = ideal_ole_pair::<P256>();

let offsets = provider
.provide(&mut provider_sink, &mut provider_stream, factors.clone())
.await
.unwrap();
let outputs = evaluator
.evaluate(&mut evaluator_sink, &mut evaluator_stream, inputs.clone())
.await
.unwrap();
let offsets = provider.provide(factors.clone()).await.unwrap();
let outputs = evaluator.evaluate(inputs.clone()).await.unwrap();

inputs
.iter()
Expand Down
38 changes: 4 additions & 34 deletions ole/mpz-ole/src/ideal/role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use mpz_core::ProtocolMessage;
use mpz_share_conversion_core::Field;
use rand::thread_rng;
use std::marker::PhantomData;
use utils_aio::{sink::IoSink, stream::IoStream};

/// Returns an ideal ROLE pair.
pub fn ideal_role_pair<F: Field>() -> (IdealROLEProvider<F>, IdealROLEEvaluator<F>) {
Expand Down Expand Up @@ -48,15 +47,7 @@ impl<F: Field> ProtocolMessage for IdealROLEEvaluator<F> {

#[async_trait]
impl<F: Field> RandomOLEeProvide<F> for IdealROLEProvider<F> {
async fn provide_random<
Si: IoSink<Self::Msg> + Send + Unpin,
St: IoStream<Self::Msg> + Send + Unpin,
>(
&mut self,
_sink: &mut Si,
_stream: &mut St,
count: usize,
) -> Result<(Vec<F>, Vec<F>), OLEError> {
async fn provide_random(&mut self, count: usize) -> Result<(Vec<F>, Vec<F>), OLEError> {
let mut rng = thread_rng();

let ak: Vec<F> = (0..count).map(|_| F::rand(&mut rng)).collect();
Expand All @@ -72,15 +63,7 @@ impl<F: Field> RandomOLEeProvide<F> for IdealROLEProvider<F> {

#[async_trait]
impl<F: Field> RandomOLEeEvaluate<F> for IdealROLEEvaluator<F> {
async fn evaluate_random<
Si: IoSink<Self::Msg> + Send + Unpin,
St: IoStream<Self::Msg> + Send + Unpin,
>(
&mut self,
_sink: &mut Si,
_stream: &mut St,
count: usize,
) -> Result<(Vec<F>, Vec<F>), OLEError> {
async fn evaluate_random(&mut self, count: usize) -> Result<(Vec<F>, Vec<F>), OLEError> {
let bk: Vec<F> = {
let mut rng = thread_rng();
(0..count).map(|_| F::rand(&mut rng)).collect()
Expand All @@ -106,30 +89,17 @@ impl<F: Field> RandomOLEeEvaluate<F> for IdealROLEEvaluator<F> {
#[cfg(test)]
mod tests {
use crate::{ideal::role::ideal_role_pair, RandomOLEeEvaluate, RandomOLEeProvide};
use futures::StreamExt;
use mpz_share_conversion_core::fields::p256::P256;
use utils_aio::duplex::MemoryDuplex;

#[tokio::test]
async fn test_ideal_role() {
let count = 12;

let (send_channel, recv_channel) = MemoryDuplex::<()>::new();

let (mut provider_sink, mut provider_stream) = send_channel.split();
let (mut evaluator_sink, mut evaluator_stream) = recv_channel.split();

let (mut provider, mut evaluator) = ideal_role_pair::<P256>();

let (ak, xk) = provider
.provide_random(&mut provider_sink, &mut provider_stream, count)
.await
.unwrap();
let (ak, xk) = provider.provide_random(count).await.unwrap();

let (bk, yk) = evaluator
.evaluate_random(&mut evaluator_sink, &mut evaluator_stream, count)
.await
.unwrap();
let (bk, yk) = evaluator.evaluate_random(count).await.unwrap();

ak.iter()
.zip(bk)
Expand Down
24 changes: 14 additions & 10 deletions ole/mpz-ole/src/ole/role/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,46 @@ use async_trait::async_trait;
use futures::SinkExt;
use mpz_ole_core::ole::role::OLEeEvaluator as OLEeCoreEvaluator;
use mpz_share_conversion_core::Field;
use utils_aio::{
sink::IoSink,
stream::{ExpectStreamExt, IoStream},
};
use utils_aio::{duplex::Duplex, stream::ExpectStreamExt};

/// An evaluator for OLE with errors.
pub struct OLEeEvaluator<const N: usize, T: RandomOLEeEvaluate<F>, F: Field> {
pub struct OLEeEvaluator<const N: usize, T, F, IO> {
channel: IO,
role_evaluator: T,
ole_core: OLEeCoreEvaluator<F>,
}

impl<const N: usize, T: RandomOLEeEvaluate<F>, F: Field> OLEeEvaluator<N, T, F> {
impl<const N: usize, T, F: Field, IO> OLEeEvaluator<N, T, F, IO> {
/// Creates a new [`OLEeEvaluator`].
pub fn new(role_evaluator: T) -> Self {
pub fn new(channel: IO, role_evaluator: T) -> Self {
// Check that the right N is used depending on the needed bit size of the field.
let _: () = Check::<N, F>::IS_BITSIZE_CORRECT;

Self {
channel,
role_evaluator,
ole_core: OLEeCoreEvaluator::default(),
}
}
}

#[async_trait]
impl<const N: usize, T, F: Field> OLEeEvaluate<F> for OLEeEvaluator<N, T, F>
impl<const N: usize, T, F: Field, IO> OLEeEvaluate<F> for OLEeEvaluator<N, T, F, IO>
where
IO: Duplex<OLEeMessage<F>>,
T: RandomOLEeEvaluate<F> + Send,
{
async fn evaluate(&mut self, inputs: Vec<F>) -> Result<Vec<F>, OLEError> {
let (bk_dash, yk_dash) = self.role_evaluator.evaluate_random(inputs.len()).await?;

let vk: Vec<F> = self.ole_core.create_mask(&bk_dash, &inputs)?;

let uk: Vec<F> = stream.expect_next().await?.try_into_provider_derand()?;
sink.send(OLEeMessage::EvaluatorDerand(vk)).await?;
let uk: Vec<F> = self
.channel
.expect_next()
.await?
.try_into_provider_derand()?;
self.channel.send(OLEeMessage::EvaluatorDerand(vk)).await?;

let yk: Vec<F> = self.ole_core.generate_output(&inputs, &yk_dash, &uk)?;

Expand Down
13 changes: 5 additions & 8 deletions ole/mpz-ole/src/ole/role/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ pub use provider::OLEeProvider;
mod tests {
use super::{OLEeEvaluator, OLEeProvider};
use crate::{ideal::role::ideal_role_pair, OLEeEvaluate, OLEeProvide};
use futures::StreamExt;
use mpz_core::{prg::Prg, Block};
use mpz_share_conversion_core::fields::{p256::P256, UniformRand};
use rand::SeedableRng;
Expand All @@ -23,20 +22,18 @@ mod tests {

let (sender_channel, receiver_channel) = MemoryDuplex::new();

let (mut provider_sink, mut provider_stream) = sender_channel.split();
let (mut evaluator_sink, mut evaluator_stream) = receiver_channel.split();

let (role_provider, role_evaluator) = ideal_role_pair::<P256>();

let mut ole_provider = OLEeProvider::<32, _, P256>::new(role_provider);
let mut ole_evaluator = OLEeEvaluator::<32, _, P256>::new(role_evaluator);
let mut ole_provider = OLEeProvider::<32, _, P256, _>::new(sender_channel, role_provider);
let mut ole_evaluator =
OLEeEvaluator::<32, _, P256, _>::new(receiver_channel, role_evaluator);

let ak: Vec<P256> = (0..count).map(|_| P256::rand(&mut rng)).collect();
let bk: Vec<P256> = (0..count).map(|_| P256::rand(&mut rng)).collect();

let (provider_res, evaluator_res) = tokio::join!(
ole_provider.provide(&mut provider_sink, &mut provider_stream, ak.clone()),
ole_evaluator.evaluate(&mut evaluator_sink, &mut evaluator_stream, bk.clone())
ole_provider.provide(ak.clone()),
ole_evaluator.evaluate(bk.clone())
);

let xk = provider_res.unwrap();
Expand Down
24 changes: 14 additions & 10 deletions ole/mpz-ole/src/ole/role/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,47 @@ use async_trait::async_trait;
use futures::SinkExt;
use mpz_ole_core::ole::role::OLEeProvider as OLEeCoreProvider;
use mpz_share_conversion_core::Field;
use utils_aio::{
sink::IoSink,
stream::{ExpectStreamExt, IoStream},
};
use utils_aio::{duplex::Duplex, stream::ExpectStreamExt};

/// A provider for OLE with errors.
pub struct OLEeProvider<const N: usize, T: RandomOLEeProvide<F>, F: Field> {
pub struct OLEeProvider<const N: usize, T, F, IO> {
channel: IO,
role_provider: T,
ole_core: OLEeCoreProvider<F>,
}

impl<const N: usize, T: RandomOLEeProvide<F>, F: Field> OLEeProvider<N, T, F> {
impl<const N: usize, T, F: Field, IO> OLEeProvider<N, T, F, IO> {
/// Creates a new [`OLEeProvider`].
pub fn new(role_provider: T) -> Self {
pub fn new(channel: IO, role_provider: T) -> Self {
// Check that the right N is used depending on the needed bit size of the field.
let _: () = Check::<N, F>::IS_BITSIZE_CORRECT;

Self {
channel,
role_provider,
ole_core: OLEeCoreProvider::default(),
}
}
}

#[async_trait]
impl<const N: usize, T, F: Field> OLEeProvide<F> for OLEeProvider<N, T, F>
impl<const N: usize, T, F: Field, IO> OLEeProvide<F> for OLEeProvider<N, T, F, IO>
where
T: RandomOLEeProvide<F> + Send,
IO: Duplex<OLEeMessage<F>>,
Self: Send,
{
async fn provide(&mut self, factors: Vec<F>) -> Result<Vec<F>, OLEError> {
let (ak_dash, xk_dash) = self.role_provider.provide_random(factors.len()).await?;

let uk: Vec<F> = self.ole_core.create_mask(&ak_dash, &factors)?;

sink.send(OLEeMessage::ProviderDerand(uk)).await?;
let vk: Vec<F> = stream.expect_next().await?.try_into_evaluator_derand()?;
self.channel.send(OLEeMessage::ProviderDerand(uk)).await?;
let vk: Vec<F> = self
.channel
.expect_next()
.await?
.try_into_evaluator_derand()?;

let x_k: Vec<F> = self
.ole_core
Expand Down
26 changes: 15 additions & 11 deletions ole/mpz-ole/src/role/rot/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,34 @@ use futures::SinkExt;
use mpz_ole_core::role::ot::ROLEeEvaluator as ROLEeCoreEvaluator;
use mpz_ot::RandomOTReceiver;
use mpz_share_conversion_core::Field;
use utils_aio::{
sink::IoSink,
stream::{ExpectStreamExt, IoStream},
};
use utils_aio::{duplex::Duplex, stream::ExpectStreamExt};

/// An evaluator for ROLE with errors.
pub struct ROLEeEvaluator<const N: usize, T: RandomOTReceiver<bool, [u8; N]>, F: Field> {
pub struct ROLEeEvaluator<const N: usize, T, F, IO> {
channel: IO,
rot_receiver: T,
role_core: ROLEeCoreEvaluator<N, F>,
}

impl<const N: usize, T: RandomOTReceiver<bool, [u8; N]>, F: Field> ROLEeEvaluator<N, T, F> {
impl<const N: usize, T, F: Field, IO> ROLEeEvaluator<N, T, F, IO> {
/// Create a new [`ROLEeEvaluator`].
pub fn new(rot_receiver: T) -> Self {
pub fn new(channel: IO, rot_receiver: T) -> Self {
// Check that the right N is used depending on the needed bit size of the field.
let _: () = Check::<N, F>::IS_BITSIZE_CORRECT;

Self {
channel,
rot_receiver,
role_core: ROLEeCoreEvaluator::default(),
}
}
}

#[async_trait]
impl<const N: usize, T, F: Field> RandomOLEeEvaluate<F> for ROLEeEvaluator<N, T, F>
impl<const N: usize, T, F: Field, IO> RandomOLEeEvaluate<F> for ROLEeEvaluator<N, T, F, IO>
where
T: RandomOTReceiver<bool, [u8; N]> + Send,
IO: Duplex<ROLEeMessage<F>>,
Self: Send,
{
async fn evaluate_random(&mut self, count: usize) -> Result<(Vec<F>, Vec<F>), OLEError> {
Expand All @@ -40,12 +40,16 @@ where
.receive_random(count * F::BIT_SIZE as usize)
.await?;

let (ui, ek): (Vec<F>, Vec<F>) =
stream.expect_next().await?.try_into_random_provider_msg()?;
let (ui, ek): (Vec<F>, Vec<F>) = self
.channel
.expect_next()
.await?
.try_into_random_provider_msg()?;

let dk: Vec<F> = self.role_core.sample_d(count);

sink.send(ROLEeMessage::RandomEvaluatorMsg(dk.clone()))
self.channel
.send(ROLEeMessage::RandomEvaluatorMsg(dk.clone()))
.await?;

let (bk, yk) = self.role_core.generate_output(&fi, &tfi, &ui, &dk, &ek)?;
Expand Down
13 changes: 5 additions & 8 deletions ole/mpz-ole/src/role/rot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ pub use provider::ROLEeProvider;
mod tests {
use super::{ROLEeEvaluator, ROLEeProvider};
use crate::{RandomOLEeEvaluate, RandomOLEeProvide};
use futures::StreamExt;
use mpz_ot::ideal::ideal_random_ot_pair;
use mpz_share_conversion_core::fields::p256::P256;
use utils_aio::duplex::MemoryDuplex;
Expand All @@ -20,17 +19,15 @@ mod tests {
let count = 12;
let (sender_channel, receiver_channel) = MemoryDuplex::new();

let (mut provider_sink, mut provider_stream) = sender_channel.split();
let (mut evaluator_sink, mut evaluator_stream) = receiver_channel.split();

let (rot_sender, rot_receiver) = ideal_random_ot_pair::<[u8; 32]>([0; 32]);

let mut role_provider = ROLEeProvider::<32, _, P256>::new(rot_sender);
let mut role_evaluator = ROLEeEvaluator::<32, _, P256>::new(rot_receiver);
let mut role_provider = ROLEeProvider::<32, _, P256, _>::new(sender_channel, rot_sender);
let mut role_evaluator =
ROLEeEvaluator::<32, _, P256, _>::new(receiver_channel, rot_receiver);

let (provider_res, evaluator_res) = tokio::join!(
role_provider.provide_random(&mut provider_sink, &mut provider_stream, count),
role_evaluator.evaluate_random(&mut evaluator_sink, &mut evaluator_stream, count)
role_provider.provide_random(count),
role_evaluator.evaluate_random(count)
);

let (ak, xk) = provider_res.unwrap();
Expand Down
Loading

0 comments on commit 2de0b5a

Please sign in to comment.