diff --git a/crates/mpz-common/src/ideal.rs b/crates/mpz-common/src/ideal.rs index 804472ef..1b6b3181 100644 --- a/crates/mpz-common/src/ideal.rs +++ b/crates/mpz-common/src/ideal.rs @@ -18,7 +18,7 @@ struct Buffer { } /// The ideal functionality from the perspective of Alice. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Alice { f: Arc>, buffer: Arc>, @@ -35,7 +35,7 @@ impl Clone for Alice { impl Alice { /// Returns a lock to the ideal functionality. - pub fn get_mut(&mut self) -> MutexGuard<'_, F> { + pub fn lock(&self) -> MutexGuard<'_, F> { self.f.lock().unwrap() } @@ -79,7 +79,7 @@ impl Alice { } /// The ideal functionality from the perspective of Bob. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Bob { f: Arc>, buffer: Arc>, @@ -96,7 +96,7 @@ impl Clone for Bob { impl Bob { /// Returns a lock to the ideal functionality. - pub fn get_mut(&mut self) -> MutexGuard<'_, F> { + pub fn lock(&self) -> MutexGuard<'_, F> { self.f.lock().unwrap() } diff --git a/crates/mpz-core/src/ggm_tree.rs b/crates/mpz-core/src/ggm_tree.rs index 913fffb6..840efcc6 100644 --- a/crates/mpz-core/src/ggm_tree.rs +++ b/crates/mpz-core/src/ggm_tree.rs @@ -32,33 +32,35 @@ impl GgmTree { assert_eq!(k0.len(), self.depth); assert_eq!(k1.len(), self.depth); let mut buf = [Block::ZERO; 8]; - self.tkprp.expand_1to2(tree, seed); - k0[0] = tree[0]; - k1[0] = tree[1]; + if self.depth > 1 { + self.tkprp.expand_1to2(tree, seed); + k0[0] = tree[0]; + k1[0] = tree[1]; - self.tkprp.expand_2to4(&mut buf, tree); - k0[1] = buf[0] ^ buf[2]; - k1[1] = buf[1] ^ buf[3]; - tree[0..4].copy_from_slice(&buf[0..4]); - - for h in 2..self.depth { - k0[h] = Block::ZERO; - k1[h] = Block::ZERO; - - // How many nodes there are in this layer - let sz = 1 << h; - for i in (0..=sz - 4).rev().step_by(4) { - self.tkprp.expand_4to8(&mut buf, &tree[i..]); - k0[h] ^= buf[0]; - k0[h] ^= buf[2]; - k0[h] ^= buf[4]; - k0[h] ^= buf[6]; - k1[h] ^= buf[1]; - k1[h] ^= buf[3]; - k1[h] ^= buf[5]; - k1[h] ^= buf[7]; + self.tkprp.expand_2to4(&mut buf, tree); + k0[1] = buf[0] ^ buf[2]; + k1[1] = buf[1] ^ buf[3]; + tree[0..4].copy_from_slice(&buf[0..4]); - tree[2 * i..2 * i + 8].copy_from_slice(&buf); + for h in 2..self.depth { + k0[h] = Block::ZERO; + k1[h] = Block::ZERO; + + // How many nodes there are in this layer + let sz = 1 << h; + for i in (0..=sz - 4).rev().step_by(4) { + self.tkprp.expand_4to8(&mut buf, &tree[i..]); + k0[h] ^= buf[0]; + k0[h] ^= buf[2]; + k0[h] ^= buf[4]; + k0[h] ^= buf[6]; + k1[h] ^= buf[1]; + k1[h] ^= buf[3]; + k1[h] ^= buf[5]; + k1[h] ^= buf[7]; + + tree[2 * i..2 * i + 8].copy_from_slice(&buf); + } } } } diff --git a/crates/mpz-ot-core/src/chou_orlandi/receiver.rs b/crates/mpz-ot-core/src/chou_orlandi/receiver.rs index 403802f9..d9638951 100644 --- a/crates/mpz-ot-core/src/chou_orlandi/receiver.rs +++ b/crates/mpz-ot-core/src/chou_orlandi/receiver.rs @@ -153,7 +153,7 @@ impl Receiver { let SenderPayload { id, payload } = payload; // Check that the transfer id matches - let expected_id = current_id.next(); + let expected_id = current_id.next_id(); if id != expected_id { return Err(ReceiverError::IdMismatch(expected_id, id)); } diff --git a/crates/mpz-ot-core/src/chou_orlandi/sender.rs b/crates/mpz-ot-core/src/chou_orlandi/sender.rs index 09a8b5a6..328354eb 100644 --- a/crates/mpz-ot-core/src/chou_orlandi/sender.rs +++ b/crates/mpz-ot-core/src/chou_orlandi/sender.rs @@ -139,7 +139,7 @@ impl Sender { } = receiver_payload; // Check that the transfer id matches - let expected_id = current_id.next(); + let expected_id = current_id.next_id(); if id != expected_id { return Err(SenderError::IdMismatch(expected_id, id)); } diff --git a/crates/mpz-ot-core/src/ferret/mod.rs b/crates/mpz-ot-core/src/ferret/mod.rs index 3ad7701e..ac73c005 100644 --- a/crates/mpz-ot-core/src/ferret/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mod.rs @@ -1,7 +1,4 @@ //! An implementation of the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) protocol. - -use mpz_core::lpn::LpnParameters; - pub mod cuckoo; pub mod error; pub mod mpcot; @@ -19,28 +16,13 @@ pub const CUCKOO_HASH_NUM: usize = 3; /// Trial numbers in Cuckoo hash insertion. pub const CUCKOO_TRIAL_NUM: usize = 100; -/// LPN parameters with regular noise. -/// Derived from https://github.com/emp-toolkit/emp-ot/blob/master/emp-ot/ferret/constants.h -pub const LPN_PARAMETERS_REGULAR: LpnParameters = LpnParameters { - n: 10180608, - k: 124000, - t: 4971, -}; - -/// LPN parameters with uniform noise. -/// Derived from Table 2. -pub const LPN_PARAMETERS_UNIFORM: LpnParameters = LpnParameters { - n: 10616092, - k: 588160, - t: 1324, -}; - /// The type of Lpn parameters. -#[derive(Debug)] +#[derive(Debug, Clone, Copy, Default)] pub enum LpnType { /// Uniform error distribution. Uniform, /// Regular error distribution. + #[default] Regular, } @@ -48,15 +30,15 @@ pub enum LpnType { mod tests { use super::*; - use msgs::LpnMatrixSeed; use receiver::Receiver; use sender::Sender; - use crate::ideal::{cot::IdealCOT, mpcot::IdealMpcot}; - use crate::test::assert_cot; - use crate::{MPCOTReceiverOutput, MPCOTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput}; + use crate::{ + ideal::{cot::IdealCOT, mpcot::IdealMpcot}, + test::assert_cot, + MPCOTReceiverOutput, MPCOTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput, + }; use mpz_core::{lpn::LpnParameters, prg::Prg}; - use rand::SeedableRng; const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { n: 9600, @@ -66,7 +48,7 @@ mod tests { #[test] fn ferret_test() { - let mut prg = Prg::from_seed([1u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_cot = IdealCOT::default(); let mut ideal_mpcot = IdealMpcot::default(); @@ -101,18 +83,8 @@ mod tests { ) .unwrap(); - let LpnMatrixSeed { - seed: lpn_matrix_seed, - } = seed; - let mut sender = sender - .setup( - delta, - LPN_PARAMETERS_TEST, - LpnType::Regular, - lpn_matrix_seed, - &v, - ) + .setup(delta, LPN_PARAMETERS_TEST, LpnType::Regular, seed, &v) .unwrap(); // extend once @@ -122,8 +94,15 @@ mod tests { let (MPCOTSenderOutput { s, .. }, MPCOTReceiverOutput { r, .. }) = ideal_mpcot.extend(&query.0, query.1); - let msgs = sender.extend(&s).unwrap(); - let (choices, received) = receiver.extend(&r).unwrap(); + sender.extend(s).unwrap(); + receiver.extend(r).unwrap(); + + let RCOTSenderOutput { msgs, .. } = sender.consume(2).unwrap(); + let RCOTReceiverOutput { + choices, + msgs: received, + .. + } = receiver.consume(2).unwrap(); assert_cot(delta, &choices, &msgs, &received); @@ -134,8 +113,15 @@ mod tests { let (MPCOTSenderOutput { s, .. }, MPCOTReceiverOutput { r, .. }) = ideal_mpcot.extend(&query.0, query.1); - let msgs = sender.extend(&s).unwrap(); - let (choices, received) = receiver.extend(&r).unwrap(); + sender.extend(s).unwrap(); + receiver.extend(r).unwrap(); + + let RCOTSenderOutput { msgs, .. } = sender.consume(sender.remaining()).unwrap(); + let RCOTReceiverOutput { + choices, + msgs: received, + .. + } = receiver.consume(receiver.remaining()).unwrap(); assert_cot(delta, &choices, &msgs, &received); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/mod.rs b/crates/mpz-ot-core/src/ferret/mpcot/mod.rs index e74dc38a..047780d4 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/mod.rs @@ -16,11 +16,10 @@ mod tests { use crate::ideal::spcot::IdealSpcot; use crate::{SPCOTReceiverOutput, SPCOTSenderOutput}; use mpz_core::prg::Prg; - use rand::SeedableRng; #[test] fn mpcot_general_test() { - let mut prg = Prg::from_seed([1u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_spcot = IdealSpcot::new_with_delta(delta); @@ -96,7 +95,7 @@ mod tests { #[test] fn mpcot_regular_test() { - let mut prg = Prg::from_seed([2u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_spcot = IdealSpcot::new_with_delta(delta); diff --git a/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs b/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs index 0f8613af..e4d362da 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs @@ -32,11 +32,11 @@ impl Receiver { /// # Argument /// /// * `hash_seed` - Random seed to generate hashes, will be sent to the sender. - pub fn setup(self, hash_seed: Block) -> (Receiver, HashSeed) { + pub fn setup(self, hash_seed: Block) -> (Receiver, HashSeed) { let mut prg = Prg::from_seed(hash_seed); let hashes = std::array::from_fn(|_| AesEncryptor::new(prg.random_block())); let recv = Receiver { - state: state::PreExtension { + state: state::Extension { counter: 0, hashes: Arc::new(hashes), }, @@ -48,7 +48,7 @@ impl Receiver { } } -impl Receiver { +impl Receiver { /// Performs the hash procedure in MPCOT extension. /// Outputs the length of each bucket plus 1. /// @@ -63,7 +63,7 @@ impl Receiver { self, alphas: &[u32], n: u32, - ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { + ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { if alphas.len() as u32 > n { return Err(ReceiverError::InvalidInput( "length of alphas should not exceed n".to_string(), @@ -104,7 +104,7 @@ impl Receiver { } let receiver = Receiver { - state: state::Extension { + state: state::ExtensionInternal { counter: self.state.counter, m, n, @@ -117,7 +117,7 @@ impl Receiver { Ok((receiver, p)) } } -impl Receiver { +impl Receiver { /// Performs MPCOT extension. /// /// See Step 5 in Figure 7. @@ -128,7 +128,7 @@ impl Receiver { pub fn extend( self, rt: &[Vec], - ) -> Result<(Receiver, Vec), ReceiverError> { + ) -> Result<(Receiver, Vec), ReceiverError> { if rt.len() != self.state.m { return Err(ReceiverError::InvalidInput( "the length rt should be m".to_string(), @@ -165,7 +165,7 @@ impl Receiver { } let receiver = Receiver { - state: state::PreExtension { + state: state::Extension { counter: self.state.counter + 1, hashes: self.state.hashes, }, @@ -182,8 +182,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The receiver's state. @@ -200,20 +200,20 @@ pub mod state { /// The receiver's state before extending. /// /// In this state the receiver performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Current MPCOT counter pub(super) counter: usize, /// The hashes to generate Cuckoo hash table. pub(super) hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, } - impl State for PreExtension {} + impl State for Extension {} - opaque_debug::implement!(PreExtension); + opaque_debug::implement!(Extension); /// The receiver's state of extension. /// /// In this state the receiver performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Current MPCOT counter pub(super) counter: usize, /// Current length of Cuckoo hash table, will possibly be changed in each extension. @@ -228,7 +228,7 @@ pub mod state { pub(super) buckets_length: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs b/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs index 2b226108..e1e7edfe 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs @@ -19,13 +19,13 @@ impl Receiver { } /// Completes the setup phase of the protocol. - pub fn setup(self) -> Receiver { + pub fn setup(self) -> Receiver { Receiver { - state: state::PreExtension { counter: 0 }, + state: state::Extension { counter: 0 }, } } } -impl Receiver { +impl Receiver { /// Performs the prepare procedure in MPCOT extension. /// Outputs the indices for SPCOT. /// @@ -38,7 +38,7 @@ impl Receiver { self, alphas: &[u32], n: u32, - ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { + ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { let t = alphas.len() as u32; if t > n { return Err(ReceiverError::InvalidInput( @@ -91,7 +91,7 @@ impl Receiver { .collect(); let receiver = Receiver { - state: state::Extension { + state: state::ExtensionInternal { counter: self.state.counter, n, queries_length, @@ -103,7 +103,7 @@ impl Receiver { } } -impl Receiver { +impl Receiver { /// Performs MPCOT extension. /// /// # Arguments. @@ -112,7 +112,7 @@ impl Receiver { pub fn extend( self, rt: &[Vec], - ) -> Result<(Receiver, Vec), ReceiverError> { + ) -> Result<(Receiver, Vec), ReceiverError> { if rt .iter() .zip(self.state.queries_depth.iter()) @@ -130,7 +130,7 @@ impl Receiver { } let receiver = Receiver { - state: state::PreExtension { + state: state::Extension { counter: self.state.counter + 1, }, }; @@ -145,8 +145,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The receiver's state. @@ -162,19 +162,19 @@ pub mod state { /// The receiver's state before extending. /// /// In this state the receiver performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Current MPCOT counter pub(super) counter: usize, } - impl State for PreExtension {} + impl State for Extension {} - opaque_debug::implement!(PreExtension); + opaque_debug::implement!(Extension); /// The receiver's state after the setup phase. /// /// In this state the receiver performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Current MPCOT counter #[allow(dead_code)] pub(super) counter: usize, @@ -186,7 +186,7 @@ pub mod state { pub(super) queries_depth: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/sender.rs b/crates/mpz-ot-core/src/ferret/mpcot/sender.rs index f1e49105..ad025574 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/sender.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/sender.rs @@ -31,12 +31,12 @@ impl Sender { /// /// * `delta` - The sender's global secret. /// * `hash_seed` - The seed for Cuckoo hash sent by the receiver. - pub fn setup(self, delta: Block, hash_seed: HashSeed) -> Sender { + pub fn setup(self, delta: Block, hash_seed: HashSeed) -> Sender { let HashSeed { seed: hash_seed } = hash_seed; let mut prg = Prg::from_seed(hash_seed); let hashes = std::array::from_fn(|_| AesEncryptor::new(prg.random_block())); Sender { - state: state::PreExtension { + state: state::Extension { delta, counter: 0, hashes: Arc::new(hashes), @@ -45,7 +45,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs the hash procedure in MPCOT extension. /// Outputs the length of each bucket plus 1. /// @@ -59,7 +59,7 @@ impl Sender { self, t: u32, n: u32, - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if t > n { return Err(SenderError::InvalidInput( "t should not exceed n".to_string(), @@ -86,7 +86,7 @@ impl Sender { } let sender = Sender { - state: state::Extension { + state: state::ExtensionInternal { delta: self.state.delta, counter: self.state.counter, m, @@ -101,7 +101,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs MPCOT extension. /// /// See Step 5 in Figure 7. @@ -112,7 +112,7 @@ impl Sender { pub fn extend( self, st: &[Vec], - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if st.len() != self.state.m { return Err(SenderError::InvalidInput( "the length st should be m".to_string(), @@ -147,7 +147,7 @@ impl Sender { } let sender = Sender { - state: state::PreExtension { + state: state::Extension { delta: self.state.delta, counter: self.state.counter + 1, hashes: self.state.hashes, @@ -166,8 +166,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The sender's state. @@ -184,7 +184,7 @@ pub mod state { /// The sender's state before extending. /// /// In this state the sender performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -193,13 +193,13 @@ pub mod state { pub(super) hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, } - impl State for PreExtension {} - opaque_debug::implement!(PreExtension); + impl State for Extension {} + opaque_debug::implement!(Extension); /// The sender's state of extension. /// /// In this state the sender performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -217,7 +217,7 @@ pub mod state { pub(super) buckets_length: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs b/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs index db0646b6..7afa5106 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs @@ -23,14 +23,14 @@ impl Sender { /// # Argument. /// /// * `delta` - The sender's global secret. - pub fn setup(self, delta: Block) -> Sender { + pub fn setup(self, delta: Block) -> Sender { Sender { - state: state::PreExtension { delta, counter: 0 }, + state: state::Extension { delta, counter: 0 }, } } } -impl Sender { +impl Sender { /// Performs the prepare procedure in MPCOT extension. /// Outputs the information for SPCOT. /// @@ -42,7 +42,7 @@ impl Sender { self, t: u32, n: u32, - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if t > n { return Err(SenderError::InvalidInput( "t should not exceed n".to_string(), @@ -78,7 +78,7 @@ impl Sender { } let sender = Sender { - state: state::Extension { + state: state::ExtensionInternal { delta: self.state.delta, counter: self.state.counter, n, @@ -91,7 +91,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs MPCOT extension. /// /// # Arguments. @@ -100,7 +100,7 @@ impl Sender { pub fn extend( self, st: &[Vec], - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if st .iter() .zip(self.state.queries_depth.iter()) @@ -117,7 +117,7 @@ impl Sender { } let sender = Sender { - state: state::PreExtension { + state: state::Extension { delta: self.state.delta, counter: self.state.counter + 1, }, @@ -135,8 +135,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The sender's state. @@ -153,20 +153,20 @@ pub mod state { /// The sender's state before extending. /// /// In this state the sender performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter pub(super) counter: usize, } - impl State for PreExtension {} - opaque_debug::implement!(PreExtension); + impl State for Extension {} + opaque_debug::implement!(Extension); /// The sender's state after the setup phase. /// /// In this state the sender performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -179,7 +179,7 @@ pub mod state { pub(super) queries_depth: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/receiver.rs b/crates/mpz-ot-core/src/ferret/receiver.rs index 4d08c69b..782d2b9e 100644 --- a/crates/mpz-ot-core/src/ferret/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/receiver.rs @@ -1,10 +1,15 @@ //! Ferret receiver +use std::collections::VecDeque; + use mpz_core::{ lpn::{LpnEncoder, LpnParameters}, Block, }; -use crate::ferret::{error::ReceiverError, LpnType}; +use crate::{ + ferret::{error::ReceiverError, LpnType}, + RCOTReceiverOutput, TransferId, +}; use super::msgs::LpnMatrixSeed; @@ -59,6 +64,9 @@ impl Receiver { u: u.to_vec(), w: w.to_vec(), e: Vec::default(), + id: TransferId::default(), + choices_buffer: VecDeque::new(), + msgs_buffer: VecDeque::new(), }, }, LpnMatrixSeed { seed }, @@ -67,12 +75,18 @@ impl Receiver { } impl Receiver { + /// Returns the current transfer id. + pub fn id(&self) -> TransferId { + self.state.id + } + + /// Returns the number of remaining COTs. + pub fn remaining(&self) -> usize { + self.state.choices_buffer.len() + } + /// The prepare precedure of extension, sample error vectors and outputs information for MPCOT. /// See step 3 and 4. - /// - /// # Arguments. - /// - /// * `lpn_type` - The type of LPN parameters. pub fn get_mpcot_query(&mut self) -> (Vec, usize) { match self.state.lpn_type { LpnType::Uniform => { @@ -100,13 +114,15 @@ impl Receiver { /// # Arguments. /// /// * `r` - The vector received from the MPCOT protocol. - pub fn extend(&mut self, r: &[Block]) -> Result<(Vec, Vec), ReceiverError> { + pub fn extend(&mut self, r: Vec) -> Result<(), ReceiverError> { if r.len() != self.state.lpn_parameters.n { return Err(ReceiverError("the length of r should be n".to_string())); } + self.state.id.next_id(); + // Compute z = A * w + r. - let mut z = r.to_vec(); + let mut z = r; self.state.lpn_encoder.compute(&mut z, &self.state.w); // Compute x = A * u + e. @@ -131,7 +147,32 @@ impl Receiver { // Update counter self.state.counter += 1; - Ok((x_, z_)) + self.state.choices_buffer.extend(x_); + self.state.msgs_buffer.extend(z_); + + Ok(()) + } + + /// Consumes `count` COTs. + pub fn consume( + &mut self, + count: usize, + ) -> Result, ReceiverError> { + if count > self.state.choices_buffer.len() { + return Err(ReceiverError(format!( + "insufficient OTs: {} < {count}", + self.state.choices_buffer.len() + ))); + } + + let choices = self.state.choices_buffer.drain(0..count).collect(); + let msgs = self.state.msgs_buffer.drain(0..count).collect(); + + Ok(RCOTReceiverOutput { + id: self.state.id.next_id(), + choices, + msgs, + }) } } @@ -176,6 +217,12 @@ pub mod state { /// Receiver's lpn error vector. pub(super) e: Vec, + + /// TransferID + pub(super) id: TransferId, + /// Extended OTs buffers. + pub(super) choices_buffer: VecDeque, + pub(super) msgs_buffer: VecDeque, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/ferret/sender.rs b/crates/mpz-ot-core/src/ferret/sender.rs index 9e8db180..e6af6452 100644 --- a/crates/mpz-ot-core/src/ferret/sender.rs +++ b/crates/mpz-ot-core/src/ferret/sender.rs @@ -1,10 +1,17 @@ //! Ferret sender. +use std::collections::VecDeque; + use mpz_core::{ lpn::{LpnEncoder, LpnParameters}, Block, }; -use crate::ferret::{error::SenderError, LpnType}; +use crate::{ + ferret::{error::SenderError, LpnType}, + RCOTSenderOutput, TransferId, +}; + +use super::msgs::LpnMatrixSeed; /// Ferret sender. #[derive(Debug, Default)] @@ -36,7 +43,7 @@ impl Sender { delta: Block, lpn_parameters: LpnParameters, lpn_type: LpnType, - seed: Block, + seed: LpnMatrixSeed, v: &[Block], ) -> Result, SenderError> { if v.len() != lpn_parameters.k { @@ -44,6 +51,7 @@ impl Sender { "the length of v should be equal to k".to_string(), )); } + let LpnMatrixSeed { seed } = seed; let lpn_encoder = LpnEncoder::<10>::new(seed, lpn_parameters.k as u32); Ok(Sender { @@ -54,15 +62,33 @@ impl Sender { lpn_type, lpn_encoder, v: v.to_vec(), + id: TransferId::default(), + msgs_buffer: VecDeque::new(), }, }) } } impl Sender { + /// Returns the current transfer id. + pub fn id(&self) -> TransferId { + self.state.id + } + + /// Returns the number of remaining COTs. + pub fn remaining(&self) -> usize { + self.state.msgs_buffer.len() + } + + /// Returns the delta correlation. + pub fn delta(&self) -> Block { + self.state.delta + } + /// Outputs the information for MPCOT. /// /// See step 3 and 4. + #[inline] pub fn get_mpcot_query(&self) -> (u32, u32) { ( self.state.lpn_parameters.t as u32, @@ -78,13 +104,15 @@ impl Sender { /// # Arguments. /// /// * `s` - The vector received from the MPCOT protocol. - pub fn extend(&mut self, s: &[Block]) -> Result, SenderError> { + pub fn extend(&mut self, s: Vec) -> Result<(), SenderError> { if s.len() != self.state.lpn_parameters.n { return Err(SenderError("the length of s should be n".to_string())); } + self.state.id.next_id(); + // Compute y = A * v + s - let mut y = s.to_vec(); + let mut y = s; self.state.lpn_encoder.compute(&mut y, &self.state.v); let y_ = y.split_off(self.state.lpn_parameters.k); @@ -94,13 +122,33 @@ impl Sender { // Update counter self.state.counter += 1; + self.state.msgs_buffer.extend(y_); - Ok(y_) + Ok(()) + } + + /// Consumes `count` COTs. + pub fn consume(&mut self, count: usize) -> Result, SenderError> { + if count > self.state.msgs_buffer.len() { + return Err(SenderError(format!( + "insufficient OTs: {} < {count}", + self.state.msgs_buffer.len() + ))); + } + + let msgs = self.state.msgs_buffer.drain(0..count).collect(); + + Ok(RCOTSenderOutput { + id: self.state.id.next_id(), + msgs, + }) } } /// The sender's state. pub mod state { + use crate::TransferId; + use super::*; mod sealed { @@ -141,6 +189,11 @@ pub mod state { /// Sender's COT message in the setup phase. pub(super) v: Vec, + + /// Transfer ID. + pub(crate) id: TransferId, + /// COT messages buffer. + pub(super) msgs_buffer: VecDeque, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/ferret/spcot/mod.rs b/crates/mpz-ot-core/src/ferret/spcot/mod.rs index 802efb66..63ebea15 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/mod.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/mod.rs @@ -7,8 +7,6 @@ pub mod sender; #[cfg(test)] mod tests { - use mpz_core::prg::Prg; - use super::{receiver::Receiver as SpcotReceiver, sender::Sender as SpcotSender}; use crate::{ferret::CSP, ideal::cot::IdealCOT, RCOTReceiverOutput, RCOTSenderOutput}; @@ -18,49 +16,82 @@ mod tests { let sender = SpcotSender::new(); let receiver = SpcotReceiver::new(); - let mut prg = Prg::new(); - let sender_seed = prg.random_block(); let delta = ideal_cot.delta(); - let mut sender = sender.setup(delta, sender_seed); + let mut sender = sender.setup(delta); let mut receiver = receiver.setup(); - let h1 = 8; - let alpha1 = 3; + let hs = [8, 4, 10]; + let alphas = [3, 2, 4]; - // Extend once - let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h1); + let h_sum = hs.iter().sum(); + // batch extension + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h_sum); let RCOTReceiverOutput { - choices: rs, - msgs: ts, + choices: rss, + msgs: tss, .. } = msg_for_receiver; - let RCOTSenderOutput { msgs: qs, .. } = msg_for_sender; - let maskbits = receiver.extend_mask_bits(h1, alpha1, &rs).unwrap(); - let msg_from_sender = sender.extend(h1, &qs, maskbits).unwrap(); + let RCOTSenderOutput { msgs: qss, .. } = msg_for_sender; + + let maskbits = receiver.extend_mask_bits(&hs, &alphas, &rss).unwrap(); + + let msg_from_sender = sender.extend(&hs, &qss, &maskbits).unwrap(); + + receiver + .extend(&hs, &alphas, &tss, &msg_from_sender) + .unwrap(); + + // Check + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(CSP); + + let RCOTReceiverOutput { + choices: x_star, + msgs: z_star, + .. + } = msg_for_receiver; + + let RCOTSenderOutput { msgs: y_star, .. } = msg_for_sender; + + let check_from_receiver = receiver.check_pre(&x_star).unwrap(); - receiver.extend(h1, alpha1, &ts, msg_from_sender).unwrap(); + let (mut output_sender, check) = sender.check(&y_star, check_from_receiver).unwrap(); - // Extend twice - let h2 = 4; - let alpha2 = 2; + let output_receiver = receiver.check(&z_star, check).unwrap(); - let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h2); + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + + // extend twice + let hs = [6, 9, 8]; + let alphas = [2, 1, 3]; + + let h_sum = hs.iter().sum(); + + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h_sum); let RCOTReceiverOutput { - choices: rs, - msgs: ts, + choices: rss, + msgs: tss, .. } = msg_for_receiver; - let RCOTSenderOutput { msgs: qs, .. } = msg_for_sender; - let maskbits = receiver.extend_mask_bits(h2, alpha2, &rs).unwrap(); + let RCOTSenderOutput { msgs: qss, .. } = msg_for_sender; + + let maskbits = receiver.extend_mask_bits(&hs, &alphas, &rss).unwrap(); - let msg_from_sender = sender.extend(h2, &qs, maskbits).unwrap(); + let msg_from_sender = sender.extend(&hs, &qss, &maskbits).unwrap(); - receiver.extend(h2, alpha2, &ts, msg_from_sender).unwrap(); + receiver + .extend(&hs, &alphas, &tss, &msg_from_sender) + .unwrap(); // Check let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(CSP); diff --git a/crates/mpz-ot-core/src/ferret/spcot/receiver.rs b/crates/mpz-ot-core/src/ferret/spcot/receiver.rs index 5e860f31..baf10ae2 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/receiver.rs @@ -6,6 +6,10 @@ use mpz_core::{ utils::blake3, Block, }; use rand_core::SeedableRng; +#[cfg(feature = "rayon")] +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; use super::msgs::{CheckFromReceiver, CheckFromSender, ExtendFromSender, MaskBits}; @@ -43,71 +47,101 @@ impl Receiver { } impl Receiver { - /// Performs the mask bit step in extension. + /// Performs the mask bit step in batch in extension. /// /// See step 4 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `alpha` - The chosen position. - /// * `rs` - The message from COT ideal functionality for the receiver. Only the random bits are used. + /// * `hs` - The depths of the GGM trees. + /// * `alphas` - The vector of chosen positions. + /// * `rss` - The message from COT ideal functionality for the receiver for all the tress. Only the random bits are used. pub fn extend_mask_bits( &mut self, - h: usize, - alpha: u32, - rs: &[bool], - ) -> Result { + hs: &[usize], + alphas: &[u32], + rss: &[bool], + ) -> Result, ReceiverError> { if self.state.extended { return Err(ReceiverError::InvalidState( "extension is not allowed".to_string(), )); } - if alpha >= (1 << h) { + if alphas.len() != hs.len() { + return Err(ReceiverError::InvalidLength( + "the length of alphas should be the length of hs".to_string(), + )); + } + + if alphas + .iter() + .zip(hs.iter()) + .any(|(alpha, h)| *alpha >= (1 << h)) + { return Err(ReceiverError::InvalidInput( "the input pos should be no more than 2^h-1".to_string(), )); } - if rs.len() != h { + let h_sum = hs.iter().sum(); + + if rss.len() != h_sum { return Err(ReceiverError::InvalidLength( - "the length of r should be h".to_string(), + "the length of r should be the sum of h".to_string(), )); } - // Step 4 in Figure 6 + let mut rs_s = vec![Vec::::new(); hs.len()]; + let mut rss_vec = rss.to_vec(); + for (index, h) in hs.iter().enumerate() { + rs_s[index] = rss_vec.drain(0..*h).collect(); + } - let bs: Vec = alpha - .iter_msb0() - .skip(32 - h) - // Computes alpha_i XOR r_i XOR 1. - .zip(rs.iter()) - .map(|(alpha, &r)| alpha == r) - .collect(); + // Step 4 in Figure 6 + let mut bss = vec![Vec::::new(); hs.len()]; + + let iter = bss + .iter_mut() + .zip(alphas.iter()) + .zip(hs.iter()) + .zip(rs_s.iter()) + .map(|(((bs, alpha), h), rs)| (bs, alpha, h, rs)); + + for (bs, alpha, h, rs) in iter { + *bs = alpha + .iter_msb0() + .skip(32 - h) + // Computes alpha_i XOR r_i XOR 1. + .zip(rs.iter()) + .map(|(alpha, &r)| alpha == r) + .collect(); + } // Updates hasher. - self.state.hasher.update(&bs.to_bytes()); + self.state.hasher.update(&bss.to_bytes()); + + let res: Vec = bss.into_iter().map(|bs| MaskBits { bs }).collect(); - Ok(MaskBits { bs }) + Ok(res) } - /// Performs the GGM reconstruction step in extension. This function can be called multiple times before checking. + /// Performs the GGM reconstruction step in batch in extension. This function can be called multiple times before checking. /// /// See step 5 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `alpha` - The chosen position. - /// * `ts` - The message from COT ideal functionality for the receiver. Only the chosen blocks are used. - /// * `extendfs` - The message sent by the sender. + /// * `hs` - The depths of the GGM trees. + /// * `alphas` - The vector of chosen positions. + /// * `tss` - The message from COT ideal functionality for the receiver. Only the chosen blocks are used. + /// * `extendfss` - The vector of messages sent by the sender. pub fn extend( &mut self, - h: usize, - alpha: u32, - ts: &[Block], - extendfs: ExtendFromSender, + hs: &[usize], + alphas: &[u32], + tss: &[Block], + extendfss: &[ExtendFromSender], ) -> Result<(), ReceiverError> { if self.state.extended { return Err(ReceiverError::InvalidState( @@ -115,61 +149,122 @@ impl Receiver { )); } - if alpha >= (1 << h) { + if alphas.len() != hs.len() { + return Err(ReceiverError::InvalidLength( + "the length of alphas should be the length of hs".to_string(), + )); + } + + if alphas + .iter() + .zip(hs.iter()) + .any(|(alpha, h)| *alpha >= (1 << h)) + { return Err(ReceiverError::InvalidInput( "the input pos should be no more than 2^h-1".to_string(), )); } - let ExtendFromSender { ms, sum } = extendfs; - if ts.len() != h { + let h_sum = hs.iter().sum(); + + if tss.len() != h_sum { return Err(ReceiverError::InvalidLength( - "the length of t should be h".to_string(), + "the length of tss should be the sum of h".to_string(), )); } - if ms.len() != h { + let mut ts_s = vec![Vec::::new(); hs.len()]; + let mut tss_vec = tss.to_vec(); + for (index, h) in hs.iter().enumerate() { + ts_s[index] = tss_vec.drain(0..*h).collect(); + } + + if extendfss.len() != hs.len() { return Err(ReceiverError::InvalidLength( - "the length of M should be h".to_string(), + "the length of extendfss should be the length of hs".to_string(), )); } - // Updates hasher - self.state.hasher.update(&ms.to_bytes()); - self.state.hasher.update(&sum.to_bytes()); - - let alpha_bar_vec: Vec = alpha.iter_msb0().skip(32 - h).map(|a| !a).collect(); - - // Step 5 in Figure 6. - let k: Vec = ms - .into_iter() - .zip(ts) - .zip(alpha_bar_vec.iter()) - .enumerate() - .map(|(i, (([m0, m1], &t), &b))| { - let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); - if !b { - // H(t, i|ell) ^ M0 - FIXED_KEY_AES.tccr(tweak, t) ^ m0 - } else { - // H(t, i|ell) ^ M1 - FIXED_KEY_AES.tccr(tweak, t) ^ m1 - } - }) - .collect(); + let mut ms_s = vec![Vec::<[Block; 2]>::new(); hs.len()]; + let mut sum_s = vec![Block::ZERO; hs.len()]; - // Reconstructs GGM tree except `ws[alpha]`. - let ggm_tree = GgmTree::new(h); - let mut tree = vec![Block::ZERO; 1 << h]; - ggm_tree.reconstruct(&mut tree, &k, &alpha_bar_vec); + for (index, extendfs) in extendfss.iter().enumerate() { + ms_s[index].clone_from(&extendfs.ms); + sum_s[index] = extendfs.sum; + } + + if ms_s.iter().zip(hs.iter()).any(|(ms, h)| ms.len() != *h) { + return Err(ReceiverError::InvalidLength( + "the length of ms should be h".to_string(), + )); + } + // Updates hasher + self.state.hasher.update(&ms_s.to_bytes()); + self.state.hasher.update(&sum_s.to_bytes()); + + let mut trees = vec![Vec::::new(); hs.len()]; + + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")]{ + let iter = alphas + .par_iter() + .zip(ms_s.par_iter()) + .zip(sum_s.par_iter()) + .zip(hs.par_iter()) + .zip(ts_s.par_iter()) + .zip(trees.par_iter_mut()) + .map(|(((((alpha, ms), sum), h), ts), tree)| (alpha, ms, sum, h, ts, tree)); + }else{ + let iter = alphas + .iter() + .zip(ms_s.iter()) + .zip(sum_s.iter()) + .zip(hs.iter()) + .zip(ts_s.iter()) + .zip(trees.iter_mut()) + .map(|(((((alpha, ms), sum), h), ts), tree)| (alpha, ms, sum, h, ts, tree)); + } + } - // Sets `tree[alpha]`, which is `ws[alpha]`. - tree[alpha as usize] = tree.iter().fold(sum, |acc, &x| acc ^ x); + iter.for_each(|(alpha, ms, sum, h, ts, tree)| { + let alpha_bar_vec: Vec = alpha.iter_msb0().skip(32 - h).map(|a| !a).collect(); + + // Step 5 in Figure 6. + let k: Vec = ms + .iter() + .zip(ts) + .zip(alpha_bar_vec.iter()) + .enumerate() + .map(|(i, (([m0, m1], &t), &b))| { + let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); + if !b { + // H(t, i|ell) ^ M0 + FIXED_KEY_AES.tccr(tweak, t) ^ *m0 + } else { + // H(t, i|ell) ^ M1 + FIXED_KEY_AES.tccr(tweak, t) ^ *m1 + } + }) + .collect(); + + // Reconstructs GGM tree except `ws[alpha]`. + let ggm_tree = GgmTree::new(*h); + *tree = vec![Block::ZERO; 1 << h]; + ggm_tree.reconstruct(tree, &k, &alpha_bar_vec); + + // Sets `tree[alpha]`, which is `ws[alpha]`. + tree[(*alpha) as usize] = tree.iter().fold(*sum, |acc, &x| acc ^ x); + }); + + for tree in trees { + self.state.unchecked_ws.extend_from_slice(&tree); + } - self.state.unchecked_ws.extend_from_slice(&tree); - self.state.alphas_and_length.push((alpha, 1 << h)); + for (alpha, h) in alphas.iter().zip(hs.iter()) { + self.state.alphas_and_length.push((*alpha, 1 << h)); + } - self.state.exec_counter += 1; + self.state.exec_counter += hs.len(); Ok(()) } @@ -248,7 +343,6 @@ impl Receiver { } self.state.cot_counter += self.state.unchecked_ws.len(); - self.state.extended = true; let mut res = Vec::new(); for (alpha, n) in &self.state.alphas_and_length { @@ -256,8 +350,19 @@ impl Receiver { res.push((tmp, *alpha)); } + self.state.hasher = blake3::Hasher::new(); + self.state.alphas_and_length.clear(); + self.state.chis.clear(); + self.state.unchecked_ws.clear(); + Ok(res) } + + /// Complete extension. + #[inline] + pub fn finalize(&mut self) { + self.state.extended = true; + } } /// The receiver's state. diff --git a/crates/mpz-ot-core/src/ferret/spcot/sender.rs b/crates/mpz-ot-core/src/ferret/spcot/sender.rs index fef1327e..a62ad3bb 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/sender.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/sender.rs @@ -5,6 +5,10 @@ use mpz_core::{ utils::blake3, Block, }; use rand_core::SeedableRng; +#[cfg(feature = "rayon")] +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; use super::msgs::{CheckFromReceiver, CheckFromSender, ExtendFromSender, MaskBits}; @@ -29,8 +33,7 @@ impl Sender { /// # Arguments /// /// * `delta` - The sender's global secret. - /// * `seed` - The random seed to generate PRG. - pub fn setup(self, delta: Block, seed: Block) -> Sender { + pub fn setup(self, delta: Block) -> Sender { Sender { state: state::Extension { delta, @@ -39,7 +42,6 @@ impl Sender { cot_counter: 0, exec_counter: 0, extended: false, - prg: Prg::from_seed(seed), hasher: blake3::Hasher::new(), }, } @@ -47,85 +49,137 @@ impl Sender { } impl Sender { - /// Performs the SPCOT extension. + /// Performs batch SPCOT extension. /// /// See Step 1-5 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `qs`- The blocks received by calling the COT functionality. - /// * `mask`- The mask bits sent by the receiver. + /// * `hs` - The depths of the GGM trees. + /// * `qss`- The blocks received by calling the COT functionality for hs trees. + /// * `masks`- The vector of mask bits sent by the receiver. pub fn extend( &mut self, - h: usize, - qs: &[Block], - mask: MaskBits, - ) -> Result { + hs: &[usize], + qss: &[Block], + masks: &[MaskBits], + ) -> Result, SenderError> { if self.state.extended { return Err(SenderError::InvalidState( "extension is not allowed".to_string(), )); } - if qs.len() != h { + let h_sum = hs.iter().sum(); + + if qss.len() != h_sum { return Err(SenderError::InvalidLength( - "the length of q should be h".to_string(), + "the length of qss should be the sum of h".to_string(), )); } - let MaskBits { bs } = mask; + let mut qs_s = vec![Vec::::new(); hs.len()]; + let mut qss_vec = qss.to_vec(); + for (index, h) in hs.iter().enumerate() { + qs_s[index] = qss_vec.drain(0..*h).collect(); + } - if bs.len() != h { + if masks.len() != hs.len() { + return Err(SenderError::InvalidLength( + "the length of masks should be the length of hs".to_string(), + )); + } + + let bss: Vec> = masks.iter().map(|m| m.clone().bs).collect(); + + if bss.iter().zip(hs.iter()).any(|(b, h)| b.len() != *h) { return Err(SenderError::InvalidLength( "the length of b should be h".to_string(), )); } // Updates hasher. - self.state.hasher.update(&bs.to_bytes()); + self.state.hasher.update(&bss.to_bytes()); // Step 3-4, Figure 6. // Generates a GGM tree with depth h and seed s. - let s = self.state.prg.random_block(); - let ggm_tree = GgmTree::new(h); - let mut k0 = vec![Block::ZERO; h]; - let mut k1 = vec![Block::ZERO; h]; - let mut tree = vec![Block::ZERO; 1 << h]; - ggm_tree.gen(s, &mut tree, &mut k0, &mut k1); + let mut trees = vec![Vec::::new(); hs.len()]; + let mut ms_s = vec![Vec::<[Block; 2]>::new(); hs.len()]; + let mut sum_s = vec![Block::ZERO; hs.len()]; + + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")]{ + let iter = trees + .par_iter_mut().zip(hs.par_iter()) + .zip(qs_s.par_iter()) + .zip(bss.par_iter()) + .zip(ms_s.par_iter_mut()) + .zip(sum_s.par_iter_mut()) + .map(|(((((tree, h), qs), bs), ms), sum)| (tree, h, qs, bs, ms, sum)); + }else{ + let iter = trees + .iter_mut() + .zip(hs.iter()) + .zip(qs_s.iter()) + .zip(bss.iter()) + .zip(ms_s.iter_mut()) + .zip(sum_s.iter_mut()) + .map(|(((((tree, h), qs), bs), ms), sum)| (tree, h, qs, bs, ms, sum)); + } + } + + iter.for_each(|(tree, h, qs, bs, ms, sum)| { + let s = Prg::new().random_block(); + let ggm_tree = GgmTree::new(*h); + let mut k0 = vec![Block::ZERO; *h]; + let mut k1 = vec![Block::ZERO; *h]; + *tree = vec![Block::ZERO; 1 << h]; + ggm_tree.gen(s, tree, &mut k0, &mut k1); + + // Computes the sum of the leaves and delta. + *sum = tree.iter().fold(self.state.delta, |acc, &x| acc ^ x); + + // Computes M0 and M1. + for (((i, &q), b), (k0, k1)) in + qs.iter().enumerate().zip(bs).zip(k0.into_iter().zip(k1)) + { + let mut m = if *b { + [q ^ self.state.delta, q] + } else { + [q, q ^ self.state.delta] + }; + let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); + FIXED_KEY_AES.tccr_many(&[tweak, tweak], &mut m); + m[0] ^= k0; + m[1] ^= k1; + ms.push(m); + } + }); // Stores the tree, i.e., the possible output of sender. - self.state.unchecked_vs.extend_from_slice(&tree); + for tree in trees { + self.state.unchecked_vs.extend_from_slice(&tree); + } // Stores the length of this extension. - self.state.vs_length.push(1 << h); - - // Computes the sum of the leaves and delta. - let sum = tree.iter().fold(self.state.delta, |acc, &x| acc ^ x); - - // Computes M0 and M1. - let mut ms: Vec<[Block; 2]> = Vec::with_capacity(qs.len()); - for (((i, &q), b), (k0, k1)) in qs.iter().enumerate().zip(bs).zip(k0.into_iter().zip(k1)) { - let mut m = if b { - [q ^ self.state.delta, q] - } else { - [q, q ^ self.state.delta] - }; - let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); - FIXED_KEY_AES.tccr_many(&[tweak, tweak], &mut m); - m[0] ^= k0; - m[1] ^= k1; - ms.push(m); + for h in hs { + self.state.vs_length.push(1 << h); } // Updates hasher - self.state.hasher.update(&ms.to_bytes()); - self.state.hasher.update(&sum.to_bytes()); + self.state.hasher.update(&ms_s.to_bytes()); + self.state.hasher.update(&sum_s.to_bytes()); - self.state.exec_counter += 1; + self.state.exec_counter += hs.len(); + + let res: Vec = ms_s + .into_iter() + .zip(sum_s.iter()) + .map(|(ms, &sum)| ExtendFromSender { ms, sum }) + .collect(); - Ok(ExtendFromSender { ms, sum }) + Ok(res) } /// Performs the consistency check for the resulting COTs. @@ -193,10 +247,18 @@ impl Sender { res.push(tmp); } - self.state.extended = true; + self.state.hasher = blake3::Hasher::new(); + self.state.unchecked_vs.clear(); + self.state.vs_length.clear(); Ok((res, CheckFromSender { hashed_v })) } + + /// Complete extension. + #[inline] + pub fn finalize(&mut self) { + self.state.extended = true; + } } /// The sender's state. @@ -239,8 +301,6 @@ pub mod state { /// This is to prevent the receiver from extending twice pub(super) extended: bool, - /// A PRG to generate random strings. - pub(super) prg: Prg, /// A hasher to generate chi seed. pub(super) hasher: blake3::Hasher, } diff --git a/crates/mpz-ot-core/src/ideal/cot.rs b/crates/mpz-ot-core/src/ideal/cot.rs index a28abef8..a842129d 100644 --- a/crates/mpz-ot-core/src/ideal/cot.rs +++ b/crates/mpz-ot-core/src/ideal/cot.rs @@ -76,7 +76,7 @@ impl IdealCOT { .collect(); self.counter += count; - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); ( RCOTSenderOutput { id, msgs }, diff --git a/crates/mpz-ot-core/src/ideal/mpcot.rs b/crates/mpz-ot-core/src/ideal/mpcot.rs index 44a5595f..c038331b 100644 --- a/crates/mpz-ot-core/src/ideal/mpcot.rs +++ b/crates/mpz-ot-core/src/ideal/mpcot.rs @@ -60,7 +60,7 @@ impl IdealMpcot { self.counter += 1; } - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); (MPCOTSenderOutput { id, s }, MPCOTReceiverOutput { id, r }) } diff --git a/crates/mpz-ot-core/src/ideal/ot.rs b/crates/mpz-ot-core/src/ideal/ot.rs index e389066e..76ebe630 100644 --- a/crates/mpz-ot-core/src/ideal/ot.rs +++ b/crates/mpz-ot-core/src/ideal/ot.rs @@ -55,7 +55,7 @@ impl IdealOT { self.counter += choices.len(); self.choices.extend(choices); - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); (OTSenderOutput { id }, OTReceiverOutput { id, msgs: chosen }) } diff --git a/crates/mpz-ot-core/src/ideal/rot.rs b/crates/mpz-ot-core/src/ideal/rot.rs index 8a8b5d68..e29b9204 100644 --- a/crates/mpz-ot-core/src/ideal/rot.rs +++ b/crates/mpz-ot-core/src/ideal/rot.rs @@ -68,7 +68,7 @@ impl IdealROT { .collect(); self.counter += count; - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); ( ROTSenderOutput { id, msgs }, @@ -103,7 +103,7 @@ impl IdealROT { .collect(); self.counter += choices.len(); - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); ( ROTSenderOutput { id, msgs }, diff --git a/crates/mpz-ot-core/src/ideal/spcot.rs b/crates/mpz-ot-core/src/ideal/spcot.rs index 12c5f829..93b3c720 100644 --- a/crates/mpz-ot-core/src/ideal/spcot.rs +++ b/crates/mpz-ot-core/src/ideal/spcot.rs @@ -61,7 +61,7 @@ impl IdealSpcot { self.counter += n; } - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); (SPCOTSenderOutput { id, v }, SPCOTReceiverOutput { id, w }) } diff --git a/crates/mpz-ot-core/src/kos/receiver.rs b/crates/mpz-ot-core/src/kos/receiver.rs index fdcad328..127c4f1d 100644 --- a/crates/mpz-ot-core/src/kos/receiver.rs +++ b/crates/mpz-ot-core/src/kos/receiver.rs @@ -330,7 +330,7 @@ impl Receiver { )); } - let id = self.state.transfer_id.next(); + let id = self.state.transfer_id.next_id(); let index = self.state.index - self.state.keys.len(); Ok(ReceiverKeys { diff --git a/crates/mpz-ot-core/src/kos/sender.rs b/crates/mpz-ot-core/src/kos/sender.rs index 24917940..23edff5c 100644 --- a/crates/mpz-ot-core/src/kos/sender.rs +++ b/crates/mpz-ot-core/src/kos/sender.rs @@ -294,7 +294,7 @@ impl Sender { return Err(SenderError::InsufficientSetup(count, self.state.keys.len())); } - let id = self.state.transfer_id.next(); + let id = self.state.transfer_id.next_id(); Ok(SenderKeys { id, diff --git a/crates/mpz-ot-core/src/lib.rs b/crates/mpz-ot-core/src/lib.rs index 8dd77287..b0b69260 100644 --- a/crates/mpz-ot-core/src/lib.rs +++ b/crates/mpz-ot-core/src/lib.rs @@ -45,7 +45,7 @@ impl std::fmt::Display for TransferId { impl TransferId { /// Returns the current transfer ID, incrementing `self` in-place. - pub(crate) fn next(&mut self) -> Self { + pub fn next_id(&mut self) -> Self { let id = *self; self.0 += 1; id diff --git a/crates/mpz-ot/examples/ferret.rs b/crates/mpz-ot/examples/ferret.rs new file mode 100644 index 00000000..f328e4d9 --- /dev/null +++ b/crates/mpz-ot/examples/ferret.rs @@ -0,0 +1 @@ +fn main() {} diff --git a/crates/mpz-ot/src/ferret/error.rs b/crates/mpz-ot/src/ferret/error.rs new file mode 100644 index 00000000..4e428a4b --- /dev/null +++ b/crates/mpz-ot/src/ferret/error.rs @@ -0,0 +1,342 @@ +use std::fmt::Display; + +/// Ferret sender error. +#[derive(Debug, thiserror::Error)] +pub struct SenderError { + kind: SenderErrorKind, + #[source] + source: Option>, +} + +impl SenderError { + pub(crate) fn state(msg: impl Into) -> Self { + Self { + kind: SenderErrorKind::State, + source: Some(msg.into().into()), + } + } + + pub(crate) fn io(msg: impl Into) -> Self { + Self { + kind: SenderErrorKind::Io, + source: Some(msg.into().into()), + } + } +} + +#[derive(Debug)] +enum SenderErrorKind { + Io, + State, + Core, + Rcot, + Mpcot, +} + +impl Display for SenderError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + SenderErrorKind::Io => f.write_str("io error")?, + SenderErrorKind::State => f.write_str("state error")?, + SenderErrorKind::Core => f.write_str("core error")?, + SenderErrorKind::Rcot => f.write_str("rcot error")?, + SenderErrorKind::Mpcot => f.write_str("mpcot error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } +} + +impl From for SenderError { + fn from(err: std::io::Error) -> Self { + Self { + kind: SenderErrorKind::Io, + source: Some(Box::new(err)), + } + } +} + +impl From for SenderError { + fn from(err: mpz_ot_core::ferret::error::SenderError) -> Self { + Self { + kind: SenderErrorKind::Core, + source: Some(Box::new(err)), + } + } +} + +impl From for SenderError { + fn from(err: crate::OTError) -> Self { + Self { + kind: SenderErrorKind::Rcot, + source: Some(Box::new(err)), + } + } +} + +impl From for SenderError { + fn from(err: MPCOTError) -> Self { + Self { + kind: SenderErrorKind::Mpcot, + source: Some(Box::new(err)), + } + } +} + +impl From for crate::OTError { + fn from(err: SenderError) -> Self { + crate::OTError::SenderError(Box::new(err)) + } +} + +/// Ferret receiver error. +#[derive(Debug, thiserror::Error)] +pub struct ReceiverError { + kind: ReceiverErrorKind, + #[source] + source: Option>, +} + +impl ReceiverError { + pub(crate) fn state(msg: impl Into) -> Self { + Self { + kind: ReceiverErrorKind::State, + source: Some(msg.into().into()), + } + } + + pub(crate) fn io(msg: impl Into) -> Self { + Self { + kind: ReceiverErrorKind::Io, + source: Some(msg.into().into()), + } + } +} + +#[derive(Debug)] +enum ReceiverErrorKind { + Io, + State, + Core, + Rcot, + Mpcot, +} + +impl Display for ReceiverError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + ReceiverErrorKind::Io => f.write_str("io error")?, + ReceiverErrorKind::State => f.write_str("state error")?, + ReceiverErrorKind::Core => f.write_str("core error")?, + ReceiverErrorKind::Rcot => f.write_str("rcot error")?, + ReceiverErrorKind::Mpcot => f.write_str("mpcot error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } +} + +impl From for ReceiverError { + fn from(err: std::io::Error) -> Self { + Self { + kind: ReceiverErrorKind::Io, + source: Some(Box::new(err)), + } + } +} + +impl From for ReceiverError { + fn from(err: mpz_ot_core::ferret::error::ReceiverError) -> Self { + Self { + kind: ReceiverErrorKind::Core, + source: Some(Box::new(err)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::OTError) -> Self { + Self { + kind: ReceiverErrorKind::Rcot, + source: Some(Box::new(err)), + } + } +} + +impl From for ReceiverError { + fn from(err: MPCOTError) -> Self { + Self { + kind: ReceiverErrorKind::Mpcot, + source: Some(Box::new(err)), + } + } +} + +impl From for crate::OTError { + fn from(err: ReceiverError) -> Self { + crate::OTError::ReceiverError(Box::new(err)) + } +} + +mod mpcot { + use super::*; + + /// MPCOT error. + #[derive(Debug, thiserror::Error)] + pub(crate) struct MPCOTError { + kind: ErrorKind, + #[source] + source: Option>, + } + + #[derive(Debug)] + enum ErrorKind { + Io, + Core, + Rcot, + Spcot, + } + + impl Display for MPCOTError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + ErrorKind::Io => f.write_str("io error")?, + ErrorKind::Core => f.write_str("core error")?, + ErrorKind::Rcot => f.write_str("rcot error")?, + ErrorKind::Spcot => f.write_str("spcot error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } + } + + impl From for MPCOTError { + fn from(err: std::io::Error) -> Self { + Self { + kind: ErrorKind::Io, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: mpz_ot_core::ferret::mpcot::error::SenderError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: mpz_ot_core::ferret::mpcot::error::ReceiverError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: SPCOTError) -> Self { + Self { + kind: ErrorKind::Spcot, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: crate::OTError) -> Self { + Self { + kind: ErrorKind::Rcot, + source: Some(Box::new(err)), + } + } + } +} +pub(crate) use mpcot::MPCOTError; + +mod spcot { + use super::*; + + /// SPCOT error. + #[derive(Debug, thiserror::Error)] + pub(crate) struct SPCOTError { + kind: ErrorKind, + #[source] + source: Option>, + } + + #[derive(Debug)] + enum ErrorKind { + Io, + Core, + Rcot, + } + + impl Display for SPCOTError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + ErrorKind::Io => f.write_str("io error")?, + ErrorKind::Core => f.write_str("core error")?, + ErrorKind::Rcot => f.write_str("rcot error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } + } + + impl From for SPCOTError { + fn from(err: std::io::Error) -> Self { + Self { + kind: ErrorKind::Io, + source: Some(Box::new(err)), + } + } + } + + impl From for SPCOTError { + fn from(err: mpz_ot_core::ferret::spcot::error::SenderError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for SPCOTError { + fn from(err: mpz_ot_core::ferret::spcot::error::ReceiverError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for SPCOTError { + fn from(err: crate::OTError) -> Self { + Self { + kind: ErrorKind::Rcot, + source: Some(Box::new(err)), + } + } + } +} +pub(crate) use spcot::SPCOTError; diff --git a/crates/mpz-ot/src/ferret/mod.rs b/crates/mpz-ot/src/ferret/mod.rs new file mode 100644 index 00000000..9d421885 --- /dev/null +++ b/crates/mpz-ot/src/ferret/mod.rs @@ -0,0 +1,256 @@ +//! An implementation of the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) protocol. +mod error; +mod mpcot; +mod receiver; +mod sender; +mod spcot; + +pub use error::{ReceiverError, SenderError}; +pub use receiver::Receiver; +pub use sender::Sender; + +use mpz_core::lpn::LpnParameters; +use mpz_ot_core::ferret::LpnType; + +/// Configuration of Ferret. +#[derive(Debug, Clone)] +pub struct FerretConfig { + lpn_parameters: LpnParameters, + lpn_type: LpnType, +} + +impl FerretConfig { + /// Create a new instance. + /// + /// # Arguments. + /// + /// * `lpn_parameters` - The parameters of LPN. + /// * `lpn_type` - The type of LPN. + pub fn new(lpn_parameters: LpnParameters, lpn_type: LpnType) -> Self { + Self { + lpn_parameters, + lpn_type, + } + } + + /// Get the lpn type + pub fn lpn_type(&self) -> LpnType { + self.lpn_type + } + + /// Get the lpn parameters + pub fn lpn_parameters(&self) -> LpnParameters { + self.lpn_parameters + } +} + +/// Ferret config with regular LPN parameters. +/// Parameters for setup with small extension output. +pub const FERRET_REGULAR_SETUP_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 102_400, + k: 6_750, + t: 1_600, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for extension with small extension output. +pub const FERRET_REGULAR_EXTENSION_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 1_740_800, + k: 66_400, + t: 1700, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for setup with medium extension output. +pub const FERRET_REGULAR_SETUP_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 283_648, + k: 18_584, + t: 1_108, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for extension with medium extension output. +pub const FERRET_REGULAR_EXTENSION_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 5_324_800, + k: 240_000, + t: 1_300, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for setup with large extension output. +pub const FERRET_REGULAR_SETUP_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 518_656, + k: 34_643, + t: 1_013, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for extension with large extension output. +pub const FERRET_REGULAR_EXTENSION_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 10_485_760, + k: 458_000, + t: 1280, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for setup with small extension output. +pub const FERRET_UNIFORM_SETUP_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 98_000, + k: 4_450, + t: 1_600, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for extension with small extension output. +pub const FERRET_UNIFORM_EXTENSION_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 1_071_888, + k: 40_800, + t: 1720, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for setup with medium extension output. +pub const FERRET_UNIFORM_SETUP_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 283_648, + k: 18_584, + t: 1_108, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for extension with medium extension output. +pub const FERRET_UNIFORM_EXTENSION_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 5_324_800, + k: 240_000, + t: 1_300, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for setup with large extension output. +pub const FERRET_UNIFORM_SETUP_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 545_656, + k: 34_643, + t: 1_050, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for extension with large extension output. +pub const FERRET_UNIFORM_EXTENSION_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 10_488_928, + k: 458_000, + t: 1_280, + }, + lpn_type: LpnType::Uniform, +}; + +#[cfg(test)] +mod tests { + use super::*; + use futures::TryFutureExt as _; + use mpz_common::executor::test_st_executor; + use mpz_core::lpn::LpnParameters; + use mpz_ot_core::{ferret::LpnType, test::assert_cot, RCOTReceiverOutput, RCOTSenderOutput}; + use rstest::*; + + use crate::{ideal::cot::ideal_rcot, Correlation, OTError, RandomCOTReceiver, RandomCOTSender}; + + // l = n - k = 8380 + const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { + n: 9600, + k: 1220, + t: 600, + }; + + #[rstest] + #[case::uniform(LpnType::Uniform)] + #[case::regular(LpnType::Regular)] + #[tokio::test] + async fn test_ferret(#[case] lpn_type: LpnType) { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (rcot_sender, rcot_receiver) = ideal_rcot(); + + let config = FerretConfig::new(LPN_PARAMETERS_TEST, lpn_type); + + let mut sender = Sender::new(config.clone(), rcot_sender); + let mut receiver = Receiver::new(config, rcot_receiver); + + tokio::try_join!( + sender.setup(&mut ctx_sender).map_err(OTError::from), + receiver.setup(&mut ctx_receiver).map_err(OTError::from) + ) + .unwrap(); + + // extend once. + let count = LPN_PARAMETERS_TEST.k; + tokio::try_join!( + sender.extend(&mut ctx_sender, count).map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, count) + .map_err(OTError::from) + ) + .unwrap(); + + // extend twice + let count = 10000; + tokio::try_join!( + sender.extend(&mut ctx_sender, count).map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, count) + .map_err(OTError::from) + ) + .unwrap(); + + let ( + RCOTSenderOutput { + id: sender_id, + msgs: u, + }, + RCOTReceiverOutput { + id: receiver_id, + choices: b, + msgs: w, + }, + ) = tokio::try_join!( + sender.send_random_correlated(&mut ctx_sender, count), + receiver.receive_random_correlated(&mut ctx_receiver, count) + ) + .unwrap(); + + assert_eq!(sender_id, receiver_id); + assert_cot(sender.delta(), &b, &u, &w); + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot.rs b/crates/mpz-ot/src/ferret/mpcot.rs new file mode 100644 index 00000000..be7de33a --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot.rs @@ -0,0 +1,185 @@ +//! Implementation of the Multiple-Point COT (mpcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +use mpz_common::{cpu::CpuBackend, Context}; +use mpz_core::{prg::Prg, Block}; +use mpz_ot_core::ferret::{ + mpcot::{ + msgs::HashSeed, receiver::Receiver as UniformReceiverCore, + receiver_regular::Receiver as RegularReceiverCore, sender::Sender as UniformSender, + sender_regular::Sender as RegularSender, + }, + LpnType, +}; +use serio::{stream::IoStreamExt as _, SinkExt as _}; + +use crate::{ + ferret::{error::MPCOTError as Error, spcot}, + RandomCOTReceiver, RandomCOTSender, +}; + +/// MPCOT send. +/// +/// # Arguments. +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT sender. +/// * `delta` - Delta correlation. +/// * `lpn_type` - The type of LPN. +/// * `t` - The number of queried indices. +/// * `n` - The total number of indices. +pub(crate) async fn send>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + delta: Block, + lpn_type: LpnType, + t: u32, + n: u32, +) -> Result, Error> { + match lpn_type { + LpnType::Uniform => { + let hash_seed: HashSeed = ctx.io_mut().expect_next().await?; + + let (sender, hs) = CpuBackend::blocking(move || { + UniformSender::new() + .setup(delta, hash_seed) + .pre_extend(t, n) + }) + .await?; + + let st = spcot::send(ctx, rcot, delta, &hs).await?; + + let (_, output) = CpuBackend::blocking(move || sender.extend(&st)).await?; + + Ok(output) + } + LpnType::Regular => { + let (sender, hs) = + CpuBackend::blocking(move || RegularSender::new().setup(delta).pre_extend(t, n)) + .await?; + + let st = spcot::send(ctx, rcot, delta, &hs).await?; + + let (_, output) = CpuBackend::blocking(move || sender.extend(&st)).await?; + + Ok(output) + } + } +} + +/// MPCOT receive. +/// +/// # Arguments +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT receiver. +/// * `lpn_type` - The type of LPN. +/// * `alphas` - The queried indices. +/// * `n` - The total number of indices. +pub(crate) async fn receive>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + lpn_type: LpnType, + alphas: Vec, + n: u32, +) -> Result, Error> { + match lpn_type { + LpnType::Uniform => { + let hash_seed = Prg::new().random_block(); + + let (receiver, hash_seed) = UniformReceiverCore::new().setup(hash_seed); + + ctx.io_mut().send(hash_seed).await?; + + let (receiver, h_and_pos) = + CpuBackend::blocking(move || receiver.pre_extend(&alphas, n)).await?; + + let mut hs = vec![0usize; h_and_pos.len()]; + + let mut pos = vec![0u32; h_and_pos.len()]; + for (index, (h, p)) in h_and_pos.iter().enumerate() { + hs[index] = *h; + pos[index] = *p; + } + + let rt = spcot::receive(ctx, rcot, &pos, &hs).await?; + let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); + let (_, output) = CpuBackend::blocking(move || receiver.extend(&rt)).await?; + + Ok(output) + } + LpnType::Regular => { + let receiver = RegularReceiverCore::new().setup(); + + let (receiver, h_and_pos) = + CpuBackend::blocking(move || receiver.pre_extend(&alphas, n)).await?; + + let mut hs = vec![0usize; h_and_pos.len()]; + + let mut pos = vec![0u32; h_and_pos.len()]; + for (index, (h, p)) in h_and_pos.iter().enumerate() { + hs[index] = *h; + pos[index] = *p; + } + + let rt = spcot::receive(ctx, rcot, &pos, &hs).await?; + let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); + let (_, output) = CpuBackend::blocking(move || receiver.extend(&rt)).await?; + + Ok(output) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ideal::cot::ideal_rcot; + use mpz_common::executor::test_st_executor; + use mpz_ot_core::ferret::LpnType; + use rstest::*; + + #[rstest] + #[case(LpnType::Uniform)] + #[case(LpnType::Regular)] + #[tokio::test] + async fn test_mpcot(#[case] lpn_type: LpnType) { + use crate::Correlation; + + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + let (mut rcot_sender, mut rcot_receiver) = ideal_rcot(); + + let alphas = match lpn_type { + LpnType::Uniform => vec![0, 1, 3, 4, 2], + LpnType::Regular => vec![0, 3, 4, 7, 9], + }; + + let t = alphas.len(); + let n = 10; + let delta = rcot_sender.delta(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + send( + &mut ctx_sender, + &mut rcot_sender, + delta, + lpn_type, + t as u32, + n + ), + receive( + &mut ctx_receiver, + &mut rcot_receiver, + lpn_type, + alphas.clone(), + n + ) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + } +} diff --git a/crates/mpz-ot/src/ferret/receiver.rs b/crates/mpz-ot/src/ferret/receiver.rs new file mode 100644 index 00000000..fbbb38eb --- /dev/null +++ b/crates/mpz-ot/src/ferret/receiver.rs @@ -0,0 +1,253 @@ +use std::mem; + +use async_trait::async_trait; +use mpz_common::{cpu::CpuBackend, Allocate, Context, Preprocess}; +use mpz_core::{prg::Prg, Block}; +use mpz_ot_core::{ + ferret::{ + receiver::{state, Receiver as ReceiverCore}, + LpnType, CSP, CUCKOO_HASH_NUM, + }, + RCOTReceiverOutput, +}; +use serio::SinkExt; + +use crate::{ + ferret::{mpcot, FerretConfig, ReceiverError}, + OTError, RandomCOTReceiver, +}; + +#[derive(Debug)] +pub(crate) enum State { + Initialized(Box>), + Extension(Box>), + Error, +} + +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, State::Error) + } +} + +/// Ferret Receiver. +#[derive(Debug)] +pub struct Receiver { + state: State, + config: FerretConfig, + rcot: RandomCOT, + alloc: usize, + buffer: ReceiverBuffer, + buffer_len: usize, +} + +impl Receiver { + /// Creates a new Receiver. + /// + /// # Arguments. + /// + /// * `config` - The Ferret config. + /// * `rcot` - The random COT in setup. + pub fn new(config: FerretConfig, rcot: RandomCOT) -> Self { + Self { + state: State::Initialized(Box::new(ReceiverCore::new())), + config, + rcot, + alloc: 0, + buffer: Default::default(), + buffer_len: 0, + } + } + + /// Setup for receiver. + /// + /// # Arguments. + /// + /// * `ctx` - The channel context. + pub async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), ReceiverError> + where + Ctx: Context, + RandomCOT: RandomCOTReceiver, + { + let State::Initialized(receiver) = self.state.take() else { + return Err(ReceiverError::state("receiver not in initialized state")); + }; + + let params = self.config.lpn_parameters(); + let lpn_type = self.config.lpn_type(); + + // Compute the number of buffered OTs. + self.buffer_len = match lpn_type { + // The number here is a rough estimation to ensure sufficient buffer. + // It is hard to precisely compute the number because of the Cuckoo hashes. + LpnType::Uniform => { + let m = (1.5 * (params.t as f32)).ceil() as usize; + m * ((2 * CUCKOO_HASH_NUM * params.n / m) + .checked_next_power_of_two() + .expect("The length should be less than usize::MAX / 2 - 1") + .ilog2() as usize) + + CSP + } + // In our chosen paramters, we always set n is divided by t and n/t is a power of 2. + LpnType::Regular => { + assert!(params.n % params.t == 0 && (params.n / params.t).is_power_of_two()); + params.t * ((params.n / params.t).ilog2() as usize) + CSP + } + }; + + // Get random blocks from ideal Random COT. + let RCOTReceiverOutput { + choices: mut u, + msgs: mut w, + id, + } = self + .rcot + .receive_random_correlated(ctx, params.k + self.buffer_len) + .await?; + + // Initiate buffer. + let buffer = RCOTReceiverOutput { + id, + choices: u.drain(0..self.buffer_len).collect(), + msgs: w.drain(0..self.buffer_len).collect(), + }; + self.buffer = ReceiverBuffer::new(buffer); + + let seed = Prg::new().random_block(); + + let (receiver, seed) = receiver.setup(params, lpn_type, seed, &u, &w)?; + + ctx.io_mut().send(seed).await?; + + self.state = State::Extension(Box::new(receiver)); + + Ok(()) + } + + /// Performs extension. + /// + /// # Arguments + /// + /// * `ctx` - Thread context. + /// * `count` - The number of OTs to extend. + pub async fn extend(&mut self, ctx: &mut Ctx, count: usize) -> Result<(), ReceiverError> + where + Ctx: Context, + RandomCOT: RandomCOTReceiver + Send, + { + let State::Extension(mut receiver) = self.state.take() else { + return Err(ReceiverError::state("receiver not in extension state")); + }; + + let lpn_type = self.config.lpn_type(); + let target = receiver.remaining() + count; + while receiver.remaining() < target { + let (alphas, n) = receiver.get_mpcot_query(); + + let r = mpcot::receive(ctx, &mut self.buffer, lpn_type, alphas, n as u32).await?; + + receiver = CpuBackend::blocking(move || receiver.extend(r).map(|()| receiver)).await?; + + // Update receiver buffer. + let buffer = receiver + .consume(self.buffer_len) + .map_err(ReceiverError::from) + .map_err(OTError::from)?; + + self.buffer = ReceiverBuffer::new(buffer); + } + + self.state = State::Extension(receiver); + + Ok(()) + } +} + +#[async_trait] +impl RandomCOTReceiver for Receiver +where + RandomCOT: Send, +{ + async fn receive_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + let State::Extension(receiver) = &mut self.state else { + return Err(ReceiverError::state("receiver not in extension state").into()); + }; + + receiver + .consume(count) + .map_err(ReceiverError::from) + .map_err(OTError::from) + } +} + +impl Allocate for Receiver { + fn alloc(&mut self, count: usize) { + self.alloc += count; + } +} + +#[async_trait] +impl Preprocess for Receiver +where + Ctx: Context, + RandomCOT: RandomCOTReceiver + Send, +{ + type Error = ReceiverError; + + async fn preprocess(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + let count = mem::take(&mut self.alloc); + self.extend(ctx, count).await + } +} + +#[derive(Debug)] +struct ReceiverBuffer { + buffer: RCOTReceiverOutput, +} + +impl ReceiverBuffer { + fn new(buffer: RCOTReceiverOutput) -> Self { + Self { buffer } + } +} + +impl Default for ReceiverBuffer { + fn default() -> Self { + ReceiverBuffer { + buffer: RCOTReceiverOutput { + id: Default::default(), + choices: Vec::new(), + msgs: Vec::new(), + }, + } + } +} + +#[async_trait] +impl RandomCOTReceiver for ReceiverBuffer { + async fn receive_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + if count > self.buffer.choices.len() { + return Err(ReceiverError::io(format!( + "insufficient OTs: {} < {count}", + self.buffer.choices.len() + )) + .into()); + } + + let choices = self.buffer.choices.drain(0..count).collect(); + let msgs = self.buffer.msgs.drain(0..count).collect(); + Ok(RCOTReceiverOutput { + id: self.buffer.id.next_id(), + choices, + msgs, + }) + } +} diff --git a/crates/mpz-ot/src/ferret/sender.rs b/crates/mpz-ot/src/ferret/sender.rs new file mode 100644 index 00000000..02884b2c --- /dev/null +++ b/crates/mpz-ot/src/ferret/sender.rs @@ -0,0 +1,294 @@ +use std::mem; + +use crate::{ferret::mpcot, Correlation, RandomCOTSender}; +use async_trait::async_trait; +use mpz_common::{cpu::CpuBackend, Allocate, Context, Preprocess}; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::{ + sender::{state, Sender as SenderCore}, + LpnType, CSP, CUCKOO_HASH_NUM, + }, + RCOTSenderOutput, +}; +use serio::stream::IoStreamExt; + +use super::{FerretConfig, SenderError}; +use crate::OTError; + +#[derive(Debug)] +pub(crate) enum State { + Initialized(SenderCore), + Extension(SenderCore), + Error, +} + +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, State::Error) + } +} + +/// Ferret Sender. +#[derive(Debug)] +pub struct Sender { + state: State, + config: FerretConfig, + rcot: RandomCOT, + alloc: usize, + buffer: SenderBuffer, + buffer_len: usize, +} + +impl Sender { + /// Creates a new Sender. + /// + /// # Argument + /// + /// `config` - The Ferret config. + /// `rcot` - The random COT in setup. + pub fn new(config: FerretConfig, rcot: RandomCOT) -> Self { + Self { + state: State::Initialized(SenderCore::new()), + config, + rcot, + alloc: 0, + buffer: Default::default(), + buffer_len: 0, + } + } + + /// Setup with provided delta. + /// + /// # Argument + /// + /// * `ctx` - The channel context. + pub async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), SenderError> + where + Ctx: Context, + RandomCOT: RandomCOTSender + Correlation, + { + let State::Initialized(sender) = self.state.take() else { + return Err(SenderError::state("sender not in initialized state")); + }; + + let params = self.config.lpn_parameters(); + let lpn_type = self.config.lpn_type(); + + // Compute the number of buffered OTs. + self.buffer_len = match lpn_type { + // The number here is a rough estimation to ensure sufficient buffer. + // It is hard to precisely compute the number because of the Cuckoo hashes. + LpnType::Uniform => { + let m = (1.5 * (params.t as f32)).ceil() as usize; + m * ((2 * CUCKOO_HASH_NUM * params.n / m) + .checked_next_power_of_two() + .expect("The length should be less than usize::MAX / 2 - 1") + .ilog2() as usize) + + CSP + } + // In our chosen paramters, we always set n is divided by t and n/t is a power of 2. + LpnType::Regular => { + assert!(params.n % params.t == 0 && (params.n / params.t).is_power_of_two()); + params.t * ((params.n / params.t).ilog2() as usize) + CSP + } + }; + + // Get random blocks from ideal Random COT. + let RCOTSenderOutput { msgs: mut v, id } = self + .rcot + .send_random_correlated(ctx, params.k + self.buffer_len) + .await?; + + // Initiate buffer. + let buffer = RCOTSenderOutput { + id, + msgs: v.drain(0..self.buffer_len).collect(), + }; + self.buffer = SenderBuffer::new(self.rcot.delta(), buffer); + + // Get seed for LPN matrix from receiver. + let seed = ctx.io_mut().expect_next().await?; + + // Ferret core setup. + let sender = sender.setup(self.rcot.delta(), params, lpn_type, seed, &v)?; + + self.state = State::Extension(sender); + + Ok(()) + } + + /// Performs extension. + /// + /// # Argument + /// + /// * `ctx` - Thread context. + /// * `count` - The number of OTs to extend. + pub async fn extend( + &mut self, + ctx: &mut Ctx, + count: usize, + ) -> Result<(), SenderError> + where + RandomCOT: RandomCOTSender + Send, + { + let State::Extension(mut sender) = self.state.take() else { + return Err(SenderError::state("sender not in extension state")); + }; + + let lpn_type = self.config.lpn_type(); + let delta = sender.delta(); + let target = sender.remaining() + count; + while sender.remaining() < target { + let (t, n) = sender.get_mpcot_query(); + + let s = mpcot::send(ctx, &mut self.buffer, delta, lpn_type, t, n).await?; + + sender = CpuBackend::blocking(move || sender.extend(s).map(|()| sender)).await?; + + // Update sender buffer. + let buffer = sender + .consume(self.buffer_len) + .map_err(SenderError::from) + .map_err(OTError::from)?; + + self.buffer = SenderBuffer::new(delta, buffer); + } + + self.state = State::Extension(sender); + + Ok(()) + } +} + +impl Correlation for Sender +where + RandomCOT: Correlation, +{ + type Correlation = Block; + + fn delta(&self) -> Self::Correlation { + self.rcot.delta() + } +} + +#[async_trait] +impl RandomCOTSender for Sender +where + RandomCOT: Correlation + Send, +{ + async fn send_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + let State::Extension(sender) = &mut self.state else { + return Err(SenderError::state("sender not in extension state").into()); + }; + + sender + .consume(count) + .map_err(SenderError::from) + .map_err(OTError::from) + } +} + +impl Allocate for Sender { + fn alloc(&mut self, count: usize) { + self.alloc += count; + } +} + +#[async_trait] +impl Preprocess for Sender +where + Ctx: Context, + RandomCOT: RandomCOTSender + Send, +{ + type Error = SenderError; + + async fn preprocess(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + let count = mem::take(&mut self.alloc); + self.extend(ctx, count).await + } +} + +#[derive(Debug)] +struct SenderBuffer { + delta: Block, + buffer: RCOTSenderOutput, +} + +impl SenderBuffer { + fn new(delta: Block, buffer: RCOTSenderOutput) -> Self { + Self { delta, buffer } + } +} + +impl Default for SenderBuffer { + fn default() -> Self { + let buffer = RCOTSenderOutput { + id: Default::default(), + msgs: Vec::new(), + }; + Self { + delta: Block::ZERO, + buffer, + } + } +} +impl Correlation for SenderBuffer { + type Correlation = Block; + + fn delta(&self) -> Self::Correlation { + self.delta + } +} + +#[async_trait] +impl RandomCOTSender for SenderBuffer { + async fn send_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + if count > self.buffer.msgs.len() { + return Err(SenderError::io(format!( + "insufficient OTs: {} < {count}", + self.buffer.msgs.len() + )) + .into()); + } + + let msgs = self.buffer.msgs.drain(0..count).collect(); + Ok(RCOTSenderOutput { + id: self.buffer.id.next_id(), + msgs, + }) + } +} + +#[derive(Debug)] +struct BootstrappedSender<'a>(&'a mut SenderCore); + +impl Correlation for BootstrappedSender<'_> { + type Correlation = Block; + + fn delta(&self) -> Block { + self.0.delta() + } +} + +#[async_trait] +impl RandomCOTSender for BootstrappedSender<'_> { + async fn send_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + self.0 + .consume(count) + .map_err(SenderError::from) + .map_err(OTError::from) + } +} diff --git a/crates/mpz-ot/src/ferret/spcot.rs b/crates/mpz-ot/src/ferret/spcot.rs new file mode 100644 index 00000000..e63a1aa9 --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot.rs @@ -0,0 +1,161 @@ +//! Implementation of the Single-Point COT (spcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +use mpz_common::{cpu::CpuBackend, Context}; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::{ + spcot::{ + msgs::{ExtendFromSender, MaskBits}, + receiver::Receiver as ReceiverCore, + sender::Sender as SenderCore, + }, + CSP, + }, + RCOTReceiverOutput, RCOTSenderOutput, +}; +use serio::{stream::IoStreamExt as _, SinkExt as _}; + +use crate::{ferret::error::SPCOTError as Error, RandomCOTReceiver, RandomCOTSender}; + +/// SPCOT send. +/// +/// # Arguments +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT sender. +/// * `delta` - Delta correlation. +/// * `hs` - The depth of the GGM trees. +pub(crate) async fn send>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + delta: Block, + hs: &[usize], +) -> Result>, Error> { + let mut sender = SenderCore::new().setup(delta); + + let h = hs.iter().sum(); + let RCOTSenderOutput { msgs: qss, .. } = rcot.send_random_correlated(ctx, h).await?; + + let masks: Vec = ctx.io_mut().expect_next().await?; + + // extend + let h_in = hs.to_vec(); + let (mut sender, extend_msg) = CpuBackend::blocking(move || { + sender + .extend(&h_in, &qss, &masks) + .map(|extend_msg| (sender, extend_msg)) + }) + .await?; + + ctx.io_mut().send(extend_msg).await?; + + // batch check + let RCOTSenderOutput { msgs: y_star, .. } = rcot.send_random_correlated(ctx, CSP).await?; + + let checkfr = ctx.io_mut().expect_next().await?; + + let (output, check_msg) = CpuBackend::blocking(move || sender.check(&y_star, checkfr)).await?; + + ctx.io_mut().send(check_msg).await?; + + Ok(output) +} + +/// SPCOT receive. +/// +/// # Arguments +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT receiver. +/// * `alphas` - Vector of chosen positions. +/// * `hs` - The depth of the GGM trees. +pub(crate) async fn receive>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + alphas: &[u32], + hs: &[usize], +) -> Result, u32)>, Error> { + let mut receiver = ReceiverCore::new().setup(); + + let h = hs.iter().sum(); + let RCOTReceiverOutput { + choices: rss, + msgs: tss, + .. + } = rcot.receive_random_correlated(ctx, h).await?; + + // extend + let h_in = hs.to_vec(); + let alphas_in = alphas.to_vec(); + let (mut receiver, masks) = CpuBackend::blocking(move || { + receiver + .extend_mask_bits(&h_in, &alphas_in, &rss) + .map(|mask| (receiver, mask)) + }) + .await?; + + ctx.io_mut().send(masks).await?; + + let extendfss: Vec = ctx.io_mut().expect_next().await?; + + let h_in = hs.to_vec(); + let alphas_in = alphas.to_vec(); + let mut receiver = CpuBackend::blocking(move || { + receiver + .extend(&h_in, &alphas_in, &tss, &extendfss) + .map(|_| receiver) + }) + .await?; + + // batch check + let RCOTReceiverOutput { + choices: x_star, + msgs: z_star, + .. + } = rcot.receive_random_correlated(ctx, CSP).await?; + + let (mut receiver, checkfr) = CpuBackend::blocking(move || { + receiver + .check_pre(&x_star) + .map(|checkfr| (receiver, checkfr)) + }) + .await?; + + ctx.io_mut().send(checkfr).await?; + let check = ctx.io_mut().expect_next().await?; + + let output = CpuBackend::blocking(move || receiver.check(&z_star, check)).await?; + + Ok(output) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ideal::cot::ideal_rcot, Correlation}; + use mpz_common::executor::test_st_executor; + + #[tokio::test] + async fn test_spcot() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + let (mut rcot_sender, mut rcot_receiver) = ideal_rcot(); + + let hs = [8usize, 4]; + let alphas = [4u32, 2]; + let delta = rcot_sender.delta(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + send(&mut ctx_sender, &mut rcot_sender, delta, &hs), + receive(&mut ctx_receiver, &mut rcot_receiver, &alphas, &hs) + ) + .unwrap(); + + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + } +} diff --git a/crates/mpz-ot/src/ideal/cot.rs b/crates/mpz-ot/src/ideal/cot.rs index b0084957..bc7df0a6 100644 --- a/crates/mpz-ot/src/ideal/cot.rs +++ b/crates/mpz-ot/src/ideal/cot.rs @@ -11,7 +11,9 @@ use mpz_ot_core::{ ideal::cot::IdealCOT, COTReceiverOutput, COTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput, }; -use crate::{COTReceiver, COTSender, OTError, OTSetup, RandomCOTReceiver, RandomCOTSender}; +use crate::{ + COTReceiver, COTSender, Correlation, OTError, OTSetup, RandomCOTReceiver, RandomCOTSender, +}; fn cot( f: &mut IdealCOT, @@ -46,9 +48,16 @@ pub fn ideal_rcot() -> (IdealCOTSender, IdealCOTReceiver) { } /// Ideal COT sender. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct IdealCOTSender(Alice); +impl IdealCOTSender { + /// Returns Alice. + pub fn alice(&mut self) -> &mut Alice { + &mut self.0 + } +} + #[async_trait] impl OTSetup for IdealCOTSender where @@ -75,6 +84,14 @@ where } } +impl Correlation for IdealCOTSender { + type Correlation = Block; + + fn delta(&self) -> Block { + self.0.lock().delta() + } +} + #[async_trait] impl COTSender for IdealCOTSender { async fn send_correlated( @@ -98,7 +115,7 @@ impl RandomCOTSender for IdealCOTSender { } /// Ideal COT receiver. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct IdealCOTReceiver(Bob); #[async_trait] @@ -163,7 +180,7 @@ mod tests { let (mut ctx_a, mut ctx_b) = test_st_executor(8); let (mut alice, mut bob) = ideal_cot(); - let delta = alice.0.get_mut().delta(); + let delta = alice.delta(); let count = 10; let choices = (0..count).map(|_| rng.gen()).collect::>(); @@ -194,7 +211,7 @@ mod tests { let (mut ctx_a, mut ctx_b) = test_st_executor(8); let (mut alice, mut bob) = ideal_rcot(); - let delta = alice.0.get_mut().delta(); + let delta = alice.delta(); let count = 10; diff --git a/crates/mpz-ot/src/lib.rs b/crates/mpz-ot/src/lib.rs index b9871eab..c1508883 100644 --- a/crates/mpz-ot/src/lib.rs +++ b/crates/mpz-ot/src/lib.rs @@ -10,6 +10,7 @@ )] pub mod chou_orlandi; +pub mod ferret; #[cfg(any(test, feature = "ideal"))] pub mod ideal; pub mod kos; @@ -60,9 +61,18 @@ pub trait OTSender { async fn send(&mut self, ctx: &mut Ctx, msgs: &[T]) -> Result; } +/// Correlation of COT messages. +pub trait Correlation { + /// The type of the correlation. + type Correlation; + + /// Returns the correlation. + fn delta(&self) -> Self::Correlation; +} + /// A correlated oblivious transfer sender. #[async_trait] -pub trait COTSender { +pub trait COTSender: Correlation { /// Obliviously transfers the correlated messages to the receiver. /// /// Returns the `0`-bit messages that were obliviously transferred. @@ -96,7 +106,7 @@ pub trait RandomOTSender { /// A random correlated oblivious transfer sender. #[async_trait] -pub trait RandomCOTSender { +pub trait RandomCOTSender: Correlation { /// Obliviously transfers the correlated messages to the receiver. /// /// Returns the `0`-bit messages that were obliviously transferred.