Skip to content

Commit

Permalink
Improve code quality and documentation in mpz-ole
Browse files Browse the repository at this point in the history
  • Loading branch information
th4s committed Mar 7, 2024
1 parent 2c92b90 commit 0aa9250
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 41 deletions.
25 changes: 13 additions & 12 deletions ole/mpz-ole/src/ideal/ole.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rand::thread_rng;
use std::marker::PhantomData;
use utils_aio::{sink::IoSink, stream::IoStream};

/// Returns an ideal OLE pair
/// Returns an ideal OLE pair.
pub fn ideal_ole_pair<F: Field>() -> (IdealOLEProvider<F>, IdealOLEEvaluator<F>) {
let (sender, receiver) = mpsc::channel(10);

Expand All @@ -26,7 +26,7 @@ pub fn ideal_ole_pair<F: Field>() -> (IdealOLEProvider<F>, IdealOLEEvaluator<F>)
(provider, evaluator)
}

/// An ideal OLEProvider for field elements
/// An ideal OLE Provider.
pub struct IdealOLEProvider<F: Field> {
phantom: PhantomData<F>,
channel: mpsc::Sender<(Vec<F>, Vec<F>)>,
Expand All @@ -36,7 +36,7 @@ impl<F: Field> ProtocolMessage for IdealOLEProvider<F> {
type Msg = ();
}

/// An ideal OLEEvaluator for field elements
/// An ideal OLE Evaluator.
pub struct IdealOLEEvaluator<F: Field> {
phantom: PhantomData<F>,
channel: mpsc::Receiver<(Vec<F>, Vec<F>)>,
Expand All @@ -58,12 +58,13 @@ impl<F: Field> OLEeProvide<F> for IdealOLEProvider<F> {
factors: Vec<F>,
) -> Result<Vec<F>, OLEError> {
let mut rng = thread_rng();
let summands: Vec<F> = (0..factors.len()).map(|_| F::rand(&mut rng)).collect();
let offsets: Vec<F> = (0..factors.len()).map(|_| F::rand(&mut rng)).collect();

self.channel
.try_send((factors.clone(), summands.clone()))
.try_send((factors.clone(), offsets.clone()))
.expect("DummySender should be able to send");

Ok(summands)
Ok(offsets)
}
}

Expand All @@ -78,16 +79,16 @@ impl<F: Field> OLEeEvaluate<F> for IdealOLEEvaluator<F> {
_stream: &mut St,
input: Vec<F>,
) -> Result<Vec<F>, OLEError> {
let (factors, summands) = self
let (factors, offsets) = self
.channel
.next()
.await
.expect("DummySender should send a value");

let output: Vec<F> = input
.iter()
.zip(factors.iter().copied())
.zip(summands)
.zip(factors)
.zip(offsets)
.map(|((&a, b), x)| a * b + x)
.collect();

Expand All @@ -106,7 +107,7 @@ mod tests {

#[tokio::test]
async fn test_ideal_ole() {
let count = 16;
let count = 12;
let mut rng = Prg::from_seed(Block::ZERO);

let inputs: Vec<P256> = (0..count).map(|_| P256::rand(&mut rng)).collect();
Expand All @@ -119,7 +120,7 @@ mod tests {

let (mut provider, mut evaluator) = ideal_ole_pair::<P256>();

let summands = provider
let offsets = provider
.provide(&mut provider_sink, &mut provider_stream, factors.clone())
.await
.unwrap();
Expand All @@ -131,7 +132,7 @@ mod tests {
inputs
.iter()
.zip(factors)
.zip(summands)
.zip(offsets)
.zip(outputs)
.for_each(|(((&a, b), x), y)| assert_eq!(y, a * b + x));
}
Expand Down
13 changes: 7 additions & 6 deletions ole/mpz-ole/src/ideal/role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rand::thread_rng;
use std::marker::PhantomData;
use utils_aio::{sink::IoSink, stream::IoStream};

/// Returns an ideal ROLE pair
/// Returns an ideal ROLE pair.
pub fn ideal_role_pair<F: Field>() -> (IdealROLEProvider<F>, IdealROLEEvaluator<F>) {
let (sender, receiver) = mpsc::channel(10);

Expand All @@ -26,7 +26,7 @@ pub fn ideal_role_pair<F: Field>() -> (IdealROLEProvider<F>, IdealROLEEvaluator<
(provider, evaluator)
}

/// An ideal ROLEProvider for field elements
/// An ideal ROLE Provider.
pub struct IdealROLEProvider<F: Field> {
phantom: PhantomData<F>,
channel: mpsc::Sender<(Vec<F>, Vec<F>)>,
Expand All @@ -36,7 +36,7 @@ impl<F: Field> ProtocolMessage for IdealROLEProvider<F> {
type Msg = ();
}

/// An ideal ROLEEvaluator for field elements
/// An ideal ROLE Evaluator.
pub struct IdealROLEEvaluator<F: Field> {
phantom: PhantomData<F>,
channel: mpsc::Receiver<(Vec<F>, Vec<F>)>,
Expand Down Expand Up @@ -94,9 +94,9 @@ impl<F: Field> RandomOLEeEvaluate<F> for IdealROLEEvaluator<F> {

let yk: Vec<F> = ak
.iter()
.zip(bk.iter().copied())
.zip(bk.iter())
.zip(xk)
.map(|((&a, b), x)| a * b + x)
.map(|((&a, &b), x)| a * b + x)
.collect();

Ok((bk, yk))
Expand All @@ -112,7 +112,7 @@ mod tests {

#[tokio::test]
async fn test_ideal_role() {
let count = 16;
let count = 12;

let (send_channel, recv_channel) = MemoryDuplex::<()>::new();

Expand All @@ -125,6 +125,7 @@ mod tests {
.provide_random(&mut provider_sink, &mut provider_stream, count)
.await
.unwrap();

let (bk, yk) = evaluator
.evaluate_random(&mut evaluator_sink, &mut evaluator_stream, count)
.await
Expand Down
16 changes: 8 additions & 8 deletions ole/mpz-ole/src/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ use serde::{Deserialize, Serialize};
#[derive_err(Debug)]
/// A message type for ROLEe protocols.
pub enum ROLEeMessage<T, F: Field> {
/// Ciphertexts of the random OT protocol
/// Messages of the random OT protocol.
RandomOTMessage(T),
/// Random field elements sent by the provider
/// Random field elements sent by the provider.
///
/// These are u_i and e_k
/// These are u_i and e_k.
RandomProviderMsg(Vec<F>, Vec<F>),
/// Random field elements sent by the evaluator
/// Random field elements sent by the evaluator.
///
/// These are d_k
/// These are d_k.
RandomEvaluatorMsg(Vec<F>),
}

Expand All @@ -30,11 +30,11 @@ impl<T, F: Field> From<ROLEeMessageError<T, F>> for std::io::Error {
#[derive_err(Debug)]
/// A message type for OLEe protocols.
pub enum OLEeMessage<T, F: Field> {
/// Messages of the underlying ROLEe protocol
/// Messages of the underlying ROLEe protocol.
ROLEeMessage(T),
/// Field elements sent by the provider
/// Field elements sent by the provider.
ProviderDerand(Vec<F>),
/// Field elements sent by the evaluator
/// Field elements sent by the evaluator.
EvaluatorDerand(Vec<F>),
}

Expand Down
4 changes: 2 additions & 2 deletions ole/mpz-ole/src/ole/role/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ use utils_aio::{
stream::{ExpectStreamExt, IoStream},
};

/// An evaluator for OLEe.
/// An evaluator for OLE with errors.
pub struct OLEeEvaluator<const N: usize, T: RandomOLEeEvaluate<F>, F: Field> {
role_evaluator: T,
ole_core: OLEeCoreEvaluator<F>,
}

impl<const N: usize, T: RandomOLEeEvaluate<F>, F: Field> OLEeEvaluator<N, T, F> {
/// Create a new [`OLEeEvaluator`].
/// Creates a new [`OLEeEvaluator`].
pub fn new(role_evaluator: T) -> Self {
// Check that the right N is used depending on the needed bit size of the field.
let _: () = Check::<N, F>::IS_BITSIZE_CORRECT;
Expand Down
7 changes: 3 additions & 4 deletions ole/mpz-ole/src/ole/role/mod.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
//! Provides implementations of OLE with errors (OLEe) based on ROLEe.
//! Provides an implementation of OLEe based on ROLEe.
mod evaluator;
mod provider;

pub use evaluator::OLEeEvaluator;
pub use provider::OLEeProvider;

use crate::msg::OLEeMessage;
use futures::{SinkExt, StreamExt};
use mpz_share_conversion_core::Field;
use utils_aio::{sink::IoSink, stream::IoStream};

use crate::msg::OLEeMessage;

/// Converts a sink of OLEe messages into a sink of ROLEe messsages.
fn into_role_sink<'a, Si: IoSink<OLEeMessage<T, F>> + Send + Unpin, T: Send + 'a, F: Field>(
sink: &'a mut Si,
Expand Down Expand Up @@ -43,7 +42,7 @@ mod tests {

#[tokio::test]
async fn test_ole() {
let count = 16;
let count = 12;
let mut rng = Prg::from_seed(Block::ZERO);

let (sender_channel, receiver_channel) = MemoryDuplex::new();
Expand Down
4 changes: 2 additions & 2 deletions ole/mpz-ole/src/ole/role/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ use utils_aio::{
stream::{ExpectStreamExt, IoStream},
};

/// A provider for various OLE constructions.
/// A provider for OLE with errors.
pub struct OLEeProvider<const N: usize, T: RandomOLEeProvide<F>, F: Field> {
role_provider: T,
ole_core: OLEeCoreProvider<F>,
}

impl<const N: usize, T: RandomOLEeProvide<F>, F: Field> OLEeProvider<N, T, F> {
/// Create a new [`OLEeProvider`].
/// Creates a new [`OLEeProvider`].
pub fn new(role_provider: T) -> Self {
// Check that the right N is used depending on the needed bit size of the field.
let _: () = Check::<N, F>::IS_BITSIZE_CORRECT;
Expand Down
2 changes: 1 addition & 1 deletion ole/mpz-ole/src/role/rot/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use utils_aio::{
stream::{ExpectStreamExt, IoStream},
};

/// An evaluator for ROLEe.
/// An evaluator for ROLE with errors.
pub struct ROLEeEvaluator<const N: usize, T: RandomOTReceiver<bool, [u8; N]>, F: Field> {
rot_receiver: T,
role_core: ROLEeCoreEvaluator<N, F>,
Expand Down
8 changes: 4 additions & 4 deletions ole/mpz-ole/src/role/rot/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//! Provides implementations of ROLEe protocols based on random OT.
//! Provides an implementation of ROLEe based on random OT.
mod evaluator;
mod provider;
Expand All @@ -11,7 +11,7 @@ use futures::{SinkExt, StreamExt};
use mpz_share_conversion_core::Field;
use utils_aio::{sink::IoSink, stream::IoStream};

/// Converts a sink of random OLE messages into a sink of random OT messages.
/// Converts a sink of ROLE messages into a sink of random OT messages.
fn into_rot_sink<'a, Si: IoSink<ROLEeMessage<T, F>> + Send + Unpin, T: Send + 'a, F: Field>(
sink: &'a mut Si,
) -> impl IoSink<T> + Send + Unpin + 'a {
Expand All @@ -20,7 +20,7 @@ fn into_rot_sink<'a, Si: IoSink<ROLEeMessage<T, F>> + Send + Unpin, T: Send + 'a
}))
}

/// Converts a stream of random OLE messages into a stream of random OT messages.
/// Converts a stream of ROLE messages into a stream of random OT messages.
fn into_rot_stream<'a, St: IoStream<ROLEeMessage<T, F>> + Send + Unpin, T: Send + 'a, F: Field>(
stream: &'a mut St,
) -> impl IoStream<T> + Send + Unpin + 'a {
Expand All @@ -41,7 +41,7 @@ mod tests {

#[tokio::test]
async fn test_role() {
let count = 16;
let count = 12;
let (sender_channel, receiver_channel) = MemoryDuplex::new();

let (mut provider_sink, mut provider_stream) = sender_channel.split();
Expand Down
4 changes: 2 additions & 2 deletions ole/mpz-ole/src/role/rot/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ use utils_aio::{
stream::{ExpectStreamExt, IoStream},
};

/// A provider for ROLEe.
/// A provider for ROLE with errors.
pub struct ROLEeProvider<const N: usize, T: RandomOTSender<[[u8; N]; 2]>, F> {
rot_sender: T,
role_core: ROLEeCoreProvider<N, F>,
}

impl<const N: usize, T: RandomOTSender<[[u8; N]; 2]>, F: Field> ROLEeProvider<N, T, F> {
/// Create a new [`ROLEeProvider`].
/// Creates a new [`ROLEeProvider`].
pub fn new(rot_sender: T) -> Self {
// Check that the right N is used depending on the needed bit size of the field.
let _: () = Check::<N, F>::IS_BITSIZE_CORRECT;
Expand Down

0 comments on commit 0aa9250

Please sign in to comment.