From ebd6bb51892535c37a8f340783f01e33ffea3f01 Mon Sep 17 00:00:00 2001 From: sinu <65924192+sinui0@users.noreply.github.com> Date: Fri, 31 May 2024 06:39:49 -0700 Subject: [PATCH] Squashed commit of the following: commit 2f352710b5b5021041055e671e813c20dce6312b Author: sinu.eth <65924192+sinui0@users.noreply.github.com> Date: Fri May 31 06:36:31 2024 -0700 feat(mpz-common): scoped! macro (#143) commit 9b51bd4de44666fe05e9764f2e64040f0bb02b13 Author: sinu.eth <65924192+sinui0@users.noreply.github.com> Date: Fri May 31 06:35:16 2024 -0700 feat(mpz-common): Context::blocking (#141) * feat(mpz-common): Context::blocking * Apply suggestions from code review Co-authored-by: dan --------- Co-authored-by: dan commit 8f0b298656db18d4e8e713065ebe4db7426ca9d6 Author: th4s Date: Fri May 31 10:30:08 2024 +0200 Add IO wrapper for OLE (#138) * Add `mpz-ole` content of old branch. * Reworked message enum. * Refactored to work with new `mpz-ole-core`. * Add part of feedback. * Add more feedback. * Add opaque error type. * Add `Display` for `OLEErrorKind` * Use `ok_or_elese` for lazy heap alloc. * Adapted `mpz-ole` to `hybrid-array`. * WIP: Improving API of const generics... * Add random OT for `hybrid-array`. * Adapt `mpz-ole` to use new random OT. * Added feedback. * Use random OT over field elements instead of arrays. * Refactored ideal implementation to use `mpz-common`. * Added more feedback. --- crates/Cargo.toml | 2 + crates/mpz-common/src/cpu.rs | 8 +- crates/mpz-fields/Cargo.toml | 1 + crates/mpz-fields/src/gf2_128.rs | 20 ++++- crates/mpz-fields/src/lib.rs | 16 +++- crates/mpz-fields/src/p256.rs | 33 +++++-- crates/mpz-ole-core/src/receiver.rs | 5 ++ crates/mpz-ole-core/src/sender.rs | 5 ++ crates/mpz-ole/Cargo.toml | 37 ++++++++ crates/mpz-ole/src/ideal.rs | 87 ++++++++++++++++++ crates/mpz-ole/src/lib.rs | 131 ++++++++++++++++++++++++++++ crates/mpz-ole/src/rot/mod.rs | 56 ++++++++++++ crates/mpz-ole/src/rot/receiver.rs | 94 ++++++++++++++++++++ crates/mpz-ole/src/rot/sender.rs | 92 +++++++++++++++++++ 14 files changed, 573 insertions(+), 14 deletions(-) create mode 100644 crates/mpz-ole/Cargo.toml create mode 100644 crates/mpz-ole/src/ideal.rs create mode 100644 crates/mpz-ole/src/lib.rs create mode 100644 crates/mpz-ole/src/rot/mod.rs create mode 100644 crates/mpz-ole/src/rot/receiver.rs create mode 100644 crates/mpz-ole/src/rot/sender.rs diff --git a/crates/Cargo.toml b/crates/Cargo.toml index cb5c5362..61d239b3 100644 --- a/crates/Cargo.toml +++ b/crates/Cargo.toml @@ -16,6 +16,7 @@ members = [ "matrix-transpose", "clmul", "mpz-ole-core", + "mpz-ole" ] resolver = "2" @@ -39,6 +40,7 @@ mpz-garble = { path = "mpz-garble" } mpz-garble-core = { path = "mpz-garble-core" } mpz-share-conversion = { path = "mpz-share-conversion" } mpz-share-conversion-core = { path = "mpz-share-conversion-core" } +mpz-ole = { path = "mpz-ole" } mpz-ole-core = { path = "mpz-ole-core" } clmul = { path = "clmul" } matrix-transpose = { path = "matrix-transpose" } diff --git a/crates/mpz-common/src/cpu.rs b/crates/mpz-common/src/cpu.rs index c180823e..7cc4766c 100644 --- a/crates/mpz-common/src/cpu.rs +++ b/crates/mpz-common/src/cpu.rs @@ -21,7 +21,7 @@ mod st { pub struct SingleThreadedBackend; impl SingleThreadedBackend { - /// Execute a future on the CPU backend. + /// Executes a future on the CPU backend. #[inline] pub fn blocking_async(fut: F) -> impl Future + Send where @@ -31,7 +31,7 @@ mod st { fut } - /// Execute a closure on the CPU backend. + /// Executes a closure on the CPU backend. #[inline] pub fn blocking(f: F) -> impl Future + Send where @@ -71,7 +71,7 @@ mod rayon_backend { pub struct RayonBackend; impl RayonBackend { - /// Execute a future on the CPU backend. + /// Executes a future on the CPU backend. pub fn blocking_async(fut: F) -> impl Future + Send where F: Future + Send + 'static, @@ -87,7 +87,7 @@ mod rayon_backend { } } - /// Execute a closure on the CPU backend. + /// Executes a closure on the CPU backend. pub fn blocking(f: F) -> impl Future + Send where F: FnOnce() -> R + Send + 'static, diff --git a/crates/mpz-fields/Cargo.toml b/crates/mpz-fields/Cargo.toml index 720a4404..8e586c90 100644 --- a/crates/mpz-fields/Cargo.toml +++ b/crates/mpz-fields/Cargo.toml @@ -22,6 +22,7 @@ serde.workspace = true itybity.workspace = true typenum.workspace = true hybrid-array.workspace = true +thiserror.workspace = true [dev-dependencies] ghash_rc.workspace = true diff --git a/crates/mpz-fields/src/gf2_128.rs b/crates/mpz-fields/src/gf2_128.rs index 83c910f6..37fb0404 100644 --- a/crates/mpz-fields/src/gf2_128.rs +++ b/crates/mpz-fields/src/gf2_128.rs @@ -1,15 +1,15 @@ //! This module implements the extension field GF(2^128). -use std::ops::{Add, Mul, Neg}; - +use hybrid_array::Array; use itybity::{BitLength, FromBitIterator, GetBit, Lsb0, Msb0}; use rand::{distributions::Standard, prelude::Distribution}; use serde::{Deserialize, Serialize}; +use std::ops::{Add, Mul, Neg}; use mpz_core::Block; -use typenum::U128; +use typenum::{U128, U16}; -use crate::Field; +use crate::{Field, FieldError}; /// A type for holding field elements of Gf(2^128). #[derive(Copy, Clone, PartialOrd, Ord, PartialEq, Eq, Serialize, Deserialize)] @@ -44,6 +44,16 @@ impl From for Gf2_128 { } } +impl TryFrom> for Gf2_128 { + type Error = FieldError; + + fn try_from(value: Array) -> Result { + let inner: [u8; 16] = value.into(); + + Ok(Gf2_128(u128::from_be_bytes(inner))) + } +} + impl Distribution for Standard { fn sample(&self, rng: &mut R) -> Gf2_128 { Gf2_128(self.sample(rng)) @@ -101,6 +111,8 @@ impl Neg for Gf2_128 { impl Field for Gf2_128 { type BitSize = U128; + type ByteSize = U16; + fn zero() -> Self { Self::new(0) } diff --git a/crates/mpz-fields/src/lib.rs b/crates/mpz-fields/src/lib.rs index 372d5be8..d71956d8 100644 --- a/crates/mpz-fields/src/lib.rs +++ b/crates/mpz-fields/src/lib.rs @@ -8,13 +8,15 @@ pub mod gf2_128; pub mod p256; use std::{ + error::Error, fmt::Debug, ops::{Add, Mul, Neg}, }; -use hybrid_array::ArraySize; +use hybrid_array::{Array, ArraySize}; use itybity::{BitLength, FromBitIterator, GetBit, Lsb0, Msb0}; use rand::{distributions::Standard, prelude::Distribution, Rng}; +use thiserror::Error; use typenum::Unsigned; /// A trait for finite fields. @@ -38,13 +40,20 @@ pub trait Field: + GetBit + BitLength + Unpin + + TryFrom, Error = FieldError> { /// The number of bits of a field element. const BIT_SIZE: usize = ::USIZE; + /// The number of bytes of a field element. + const BYTE_SIZE: usize = ::USIZE; + /// The number of bits of a field element as a type number. type BitSize: ArraySize; + /// The number of bytes of a field element as a type number. + type ByteSize: ArraySize; + /// Return the additive identity element. fn zero() -> Self; @@ -64,6 +73,11 @@ pub trait Field: fn to_be_bytes(&self) -> Vec; } +/// Error type for finite fields. +#[derive(Debug, Error)] +#[error(transparent)] +pub struct FieldError(Box); + /// A trait for sampling random elements of the field. /// /// This is helpful, because we do not need to import other traits since this is a supertrait of diff --git a/crates/mpz-fields/src/p256.rs b/crates/mpz-fields/src/p256.rs index 1204cb33..b2824477 100644 --- a/crates/mpz-fields/src/p256.rs +++ b/crates/mpz-fields/src/p256.rs @@ -4,14 +4,18 @@ use std::ops::{Add, Mul, Neg}; use ark_ff::{BigInt, BigInteger, Field as ArkField, FpConfig, MontBackend, One, Zero}; use ark_secp256r1::{fq::Fq, FqConfig}; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Validate, +}; +use hybrid_array::Array; use itybity::{BitLength, FromBitIterator, GetBit, Lsb0, Msb0}; use num_bigint::ToBigUint; use rand::{distributions::Standard, prelude::Distribution}; use serde::{Deserialize, Serialize}; -use typenum::U256; +use thiserror::Error; +use typenum::{U256, U32}; -use crate::Field; +use crate::{Field, FieldError}; /// A type for holding field elements of P256. #[derive(Copy, Clone, PartialOrd, Ord, PartialEq, Eq, Serialize, Deserialize)] @@ -42,11 +46,23 @@ impl From for [u8; 32] { } impl TryFrom<[u8; 32]> for P256 { - type Error = ark_serialize::SerializationError; + type Error = FieldError; /// Converts little-endian bytes into a P256 field element. fn try_from(value: [u8; 32]) -> Result { - Fq::deserialize_with_mode(&value[..], Compress::No, Validate::Yes).map(P256) + Fq::deserialize_with_mode(&value[..], Compress::No, Validate::Yes) + .map(P256) + .map_err(|err| FieldError(Box::new(P256Error(err)))) + } +} + +impl TryFrom> for P256 { + type Error = FieldError; + + fn try_from(value: Array) -> Result { + let inner: [u8; 32] = value.into(); + + P256::try_from(inner) } } @@ -83,6 +99,8 @@ impl Neg for P256 { impl Field for P256 { type BitSize = U256; + type ByteSize = U32; + fn zero() -> Self { P256(::zero()) } @@ -139,6 +157,11 @@ impl FromBitIterator for P256 { } } +/// Helper type because [`SerializationError`] does not implement std::error::Error. +#[derive(Debug, Error)] +#[error("{0}")] +pub struct P256Error(SerializationError); + #[cfg(test)] mod tests { use super::*; diff --git a/crates/mpz-ole-core/src/receiver.rs b/crates/mpz-ole-core/src/receiver.rs index ca4fd67e..012995c9 100644 --- a/crates/mpz-ole-core/src/receiver.rs +++ b/crates/mpz-ole-core/src/receiver.rs @@ -96,6 +96,11 @@ impl OLEReceiver { Some((receiver_adjust, adjustments)) } + + /// Returns the number of preprocessed OLEs that are available. + pub fn cache_size(&self) -> usize { + self.cache.len() + } } /// Receiver adjustments waiting for [`BatchAdjust`] from the sender. diff --git a/crates/mpz-ole-core/src/sender.rs b/crates/mpz-ole-core/src/sender.rs index e245a35c..4c48d917 100644 --- a/crates/mpz-ole-core/src/sender.rs +++ b/crates/mpz-ole-core/src/sender.rs @@ -97,6 +97,11 @@ impl OLESender { Some((sender_adjust, adjustments)) } + + /// Returns the number of preprocessed OLEs that are available. + pub fn cache_size(&self) -> usize { + self.cache.len() + } } /// Sender adjustments waiting for [`BatchAdjust`] from the receiver. diff --git a/crates/mpz-ole/Cargo.toml b/crates/mpz-ole/Cargo.toml new file mode 100644 index 00000000..f4fdd903 --- /dev/null +++ b/crates/mpz-ole/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "mpz-ole" +version = "0.1.0" +edition = "2021" + +[lib] +name = "mpz_ole" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +ideal = ["mpz-common/ideal"] + +[dependencies] +mpz-fields.workspace = true +mpz-ot.workspace = true +mpz-core.workspace = true +mpz-ole-core.workspace = true +mpz-common.workspace = true + +serio.workspace = true + +thiserror.workspace = true +async-trait.workspace = true +futures.workspace = true +rand.workspace = true +itybity.workspace = true + +[dev-dependencies] +tokio = { workspace = true, features = [ + "net", + "macros", + "rt", + "rt-multi-thread", +] } +mpz-common = { workspace = true, features = ["test-utils", "ideal"] } +mpz-ot = { workspace = true, features = ["ideal"] } diff --git a/crates/mpz-ole/src/ideal.rs b/crates/mpz-ole/src/ideal.rs new file mode 100644 index 00000000..29734dea --- /dev/null +++ b/crates/mpz-ole/src/ideal.rs @@ -0,0 +1,87 @@ +//! Ideal OLE implementation. + +use crate::{OLEError, OLEReceiver, OLESender}; +use async_trait::async_trait; +use mpz_common::{ + ideal::{ideal_f2p, Alice, Bob}, + Context, +}; +use mpz_fields::Field; +use rand::thread_rng; + +/// Ideal OLESender. +pub struct IdealOLESender(Alice<()>); + +/// Ideal OLEReceiver. +pub struct IdealOLEReceiver(Bob<()>); + +/// Returns an OLE sender and receiver pair. +pub fn ideal_ole() -> (IdealOLESender, IdealOLEReceiver) { + let (alice, bob) = ideal_f2p(()); + + (IdealOLESender(alice), IdealOLEReceiver(bob)) +} + +fn ole(_: &mut (), alice_input: Vec, bob_input: Vec) -> (Vec, Vec) { + let mut rng = thread_rng(); + let alice_output: Vec = (0..alice_input.len()).map(|_| F::rand(&mut rng)).collect(); + + let bob_output: Vec = alice_input + .iter() + .zip(bob_input.iter()) + .zip(alice_output.iter().copied()) + .map(|((&a, &b), x)| a * b + x) + .collect(); + + (alice_output, bob_output) +} + +#[async_trait] +impl OLESender for IdealOLESender { + async fn send(&mut self, ctx: &mut Ctx, a_k: Vec) -> Result, OLEError> { + Ok(self.0.call(ctx, a_k, ole).await) + } +} + +#[async_trait] +impl OLEReceiver for IdealOLEReceiver { + async fn receive(&mut self, ctx: &mut Ctx, b_k: Vec) -> Result, OLEError> { + Ok(self.0.call(ctx, b_k, ole).await) + } +} + +#[cfg(test)] +mod tests { + use crate::{ideal::ideal_ole, OLEReceiver, OLESender}; + use mpz_common::executor::test_st_executor; + use mpz_core::{prg::Prg, Block}; + use mpz_fields::{p256::P256, UniformRand}; + use rand::SeedableRng; + + #[tokio::test] + async fn test_ideal_ole() { + let count = 12; + let mut rng = Prg::from_seed(Block::ZERO); + + let a_k: Vec = (0..count).map(|_| P256::rand(&mut rng)).collect(); + let b_k: Vec = (0..count).map(|_| P256::rand(&mut rng)).collect(); + + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(10); + + let (mut sender, mut receiver) = ideal_ole(); + + let (x_k, y_k) = tokio::try_join!( + sender.send(&mut ctx_sender, a_k.clone()), + receiver.receive(&mut ctx_receiver, b_k.clone()) + ) + .unwrap(); + + assert_eq!(x_k.len(), count); + assert_eq!(y_k.len(), count); + a_k.iter() + .zip(b_k) + .zip(x_k) + .zip(y_k) + .for_each(|(((&a, b), x), y)| assert_eq!(y, a * b + x)); + } +} diff --git a/crates/mpz-ole/src/lib.rs b/crates/mpz-ole/src/lib.rs new file mode 100644 index 00000000..6619b0f8 --- /dev/null +++ b/crates/mpz-ole/src/lib.rs @@ -0,0 +1,131 @@ +//! IO wrappers for Oblivious Linear Function Evaluation (OLE). + +#![deny(missing_docs, unreachable_pub, unused_must_use)] +#![deny(unsafe_code)] +#![deny(clippy::all)] + +use async_trait::async_trait; +use mpz_common::Context; +use mpz_fields::{Field, FieldError}; +use mpz_ole_core::OLEError as OLECoreError; +use mpz_ot::OTError; +use std::{ + error::Error, + fmt::{Debug, Display}, + io::Error as IOError, +}; + +#[cfg(feature = "ideal")] +pub mod ideal; +pub mod rot; + +/// Batch OLE Sender. +/// +/// The sender inputs field elements `a_k` and gets outputs `x_k`, such that +/// `y_k = a_k * b_k + x_k` holds, where `b_k` and `y_k` are the [`OLEReceiver`]'s inputs and outputs +/// respectively. +#[async_trait] +pub trait OLESender { + /// Sends his masked inputs to the [`OLEReceiver`]. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `inputs` - The sender's OLE inputs. + /// + /// # Returns + /// + /// * The sender's OLE outputs `x_k`. + async fn send(&mut self, ctx: &mut Ctx, inputs: Vec) -> Result, OLEError>; +} + +/// Batch OLE Receiver. +/// +/// The receiver inputs field elements `b_k` and gets outputs `y_k`, such that +/// `y_k = a_k * b_k + x_k` holds, where `a_k` and `x_k` are the [`OLESender`]'s inputs and outputs +/// respectively. +#[async_trait] +pub trait OLEReceiver { + /// Receives the masked inputs of the [`OLESender`]. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `inputs` - The receiver's OLE inputs. + /// + /// # Returns + /// + /// * The receiver's OLE outputs `y_k`. + async fn receive(&mut self, ctx: &mut Ctx, inputs: Vec) -> Result, OLEError>; +} + +/// An OLE error. +#[derive(Debug, thiserror::Error)] +pub struct OLEError { + kind: OLEErrorKind, + #[source] + source: Option>, +} + +impl OLEError { + fn new(kind: OLEErrorKind, source: E) -> Self + where + E: Into>, + { + Self { + kind, + source: Some(source.into()), + } + } +} + +impl Display for OLEError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.kind { + OLEErrorKind::OT => write!(f, "OT Error"), + OLEErrorKind::IO => write!(f, "IO Error"), + OLEErrorKind::Core => write!(f, "OLE Core Error"), + OLEErrorKind::Field => write!(f, "FieldError"), + OLEErrorKind::InsufficientOLEs => write!(f, "Insufficient OLEs"), + }?; + + if let Some(source) = self.source.as_ref() { + write!(f, " caused by: {source}")?; + } + + Ok(()) + } +} + +#[derive(Debug)] +pub(crate) enum OLEErrorKind { + OT, + IO, + Core, + Field, + InsufficientOLEs, +} + +impl From for OLEError { + fn from(value: OTError) -> Self { + Self::new(OLEErrorKind::OT, value) + } +} + +impl From for OLEError { + fn from(value: IOError) -> Self { + Self::new(OLEErrorKind::IO, value) + } +} + +impl From for OLEError { + fn from(value: OLECoreError) -> Self { + Self::new(OLEErrorKind::Core, value) + } +} + +impl From for OLEError { + fn from(value: FieldError) -> Self { + Self::new(OLEErrorKind::Field, value) + } +} diff --git a/crates/mpz-ole/src/rot/mod.rs b/crates/mpz-ole/src/rot/mod.rs new file mode 100644 index 00000000..5fde91c7 --- /dev/null +++ b/crates/mpz-ole/src/rot/mod.rs @@ -0,0 +1,56 @@ +//! Implementation of OLE with errors based on random OT. + +mod receiver; +mod sender; + +pub use receiver::OLEReceiver; +pub use sender::OLESender; + +#[cfg(test)] +mod tests { + use crate::{ + rot::{OLEReceiver, OLESender}, + OLEReceiver as _, OLESender as _, + }; + use mpz_common::executor::test_st_executor; + use mpz_core::{prg::Prg, Block}; + use mpz_fields::{p256::P256, UniformRand}; + use mpz_ot::ideal::rot::ideal_rot; + use rand::SeedableRng; + + #[tokio::test] + async fn test_ole() { + let count = 12; + let mut rng = Prg::from_seed(Block::ZERO); + + let (rot_sender, rot_receiver) = ideal_rot(); + + let mut ole_sender = OLESender::<_, P256>::new(rot_sender); + let mut ole_receiver = OLEReceiver::<_, P256>::new(rot_receiver); + + let a_k: Vec = (0..count).map(|_| P256::rand(&mut rng)).collect(); + let b_k: Vec = (0..count).map(|_| P256::rand(&mut rng)).collect(); + + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(10); + + tokio::try_join!( + ole_sender.preprocess(&mut ctx_sender, count), + ole_receiver.preprocess(&mut ctx_receiver, count) + ) + .unwrap(); + + let (x_k, y_k) = tokio::try_join!( + ole_sender.send(&mut ctx_sender, a_k.clone()), + ole_receiver.receive(&mut ctx_receiver, b_k.clone()) + ) + .unwrap(); + + assert_eq!(x_k.len(), count); + assert_eq!(y_k.len(), count); + a_k.iter() + .zip(b_k) + .zip(x_k) + .zip(y_k) + .for_each(|(((&a, b), x), y)| assert_eq!(y, a * b + x)); + } +} diff --git a/crates/mpz-ole/src/rot/receiver.rs b/crates/mpz-ole/src/rot/receiver.rs new file mode 100644 index 00000000..05dc0ec8 --- /dev/null +++ b/crates/mpz-ole/src/rot/receiver.rs @@ -0,0 +1,94 @@ +use crate::{OLEError, OLEErrorKind, OLEReceiver as OLEReceive}; +use async_trait::async_trait; +use itybity::ToBits; +use mpz_common::Context; +use mpz_fields::Field; +use mpz_ole_core::msg::{BatchAdjust, MaskedCorrelations}; +use mpz_ole_core::OLEReceiver as OLECoreReceiver; +use mpz_ot::RandomOTReceiver; +use serio::stream::IoStreamExt; +use serio::SinkExt; +use serio::{Deserialize, Serialize}; + +/// OLE receiver. +pub struct OLEReceiver { + rot_receiver: T, + core: OLECoreReceiver, +} + +impl OLEReceiver +where + F: Field + Serialize + Deserialize, +{ + /// Creates a new receiver. + pub fn new(rot_receiver: T) -> Self { + Self { + rot_receiver, + core: OLECoreReceiver::default(), + } + } +} + +impl OLEReceiver +where + F: Field + Serialize + Deserialize, +{ + /// Preprocesses OLEs. + /// + /// # Arguments + /// + /// * `count` - The number of OLEs to preprocess. + pub async fn preprocess( + &mut self, + ctx: &mut Ctx, + count: usize, + ) -> Result<(), OLEError> + where + T: RandomOTReceiver + Send, + { + let random_ot = self + .rot_receiver + .receive_random(ctx, count * F::BIT_SIZE) + .await?; + + let rot_msg: Vec = random_ot.msgs; + + let rot_choices: Vec = random_ot + .choices + .chunks(F::BIT_SIZE) + .map(|choice| F::from_lsb0_iter(choice.iter_lsb0())) + .collect(); + + let channel = ctx.io_mut(); + let masks = channel.expect_next::>().await?; + + self.core.preprocess(rot_choices, rot_msg, masks)?; + Ok(()) + } +} + +#[async_trait] +impl OLEReceive for OLEReceiver +where + F: Field + Serialize + Deserialize, +{ + async fn receive(&mut self, ctx: &mut Ctx, b_k: Vec) -> Result, OLEError> { + let len_requested = b_k.len(); + + let (receiver_adjust, adjust) = self.core.adjust(b_k).ok_or_else(|| { + OLEError::new( + OLEErrorKind::InsufficientOLEs, + format!("{} < {}", self.core.cache_size(), len_requested), + ) + })?; + + let channel = ctx.io_mut(); + channel.send(adjust).await?; + let adjust = channel.expect_next::>().await?; + + let shares = receiver_adjust.finish_adjust(adjust)?; + let y_k = shares.into_iter().map(|s| s.inner()).collect(); + + Ok(y_k) + } +} diff --git a/crates/mpz-ole/src/rot/sender.rs b/crates/mpz-ole/src/rot/sender.rs new file mode 100644 index 00000000..bf001a06 --- /dev/null +++ b/crates/mpz-ole/src/rot/sender.rs @@ -0,0 +1,92 @@ +use crate::{OLEError, OLEErrorKind, OLESender as OLESend}; +use async_trait::async_trait; +use mpz_common::Context; +use mpz_fields::Field; +use mpz_ole_core::msg::BatchAdjust; +use mpz_ole_core::OLESender as OLECoreSender; +use mpz_ot::RandomOTSender; +use rand::thread_rng; +use serio::stream::IoStreamExt; +use serio::SinkExt; +use serio::{Deserialize, Serialize}; + +/// OLE sender. +pub struct OLESender { + rot_sender: T, + core: OLECoreSender, +} + +impl OLESender +where + F: Field + Serialize + Deserialize, +{ + /// Creates a new sender. + pub fn new(rot_sender: T) -> Self { + Self { + rot_sender, + core: OLECoreSender::default(), + } + } +} + +impl OLESender +where + F: Field + Serialize + Deserialize, +{ + /// Preprocesses OLEs. + /// + /// # Arguments + /// + /// * `count` - The number of OLEs to preprocess. + pub async fn preprocess( + &mut self, + ctx: &mut Ctx, + count: usize, + ) -> Result<(), OLEError> + where + T: RandomOTSender + Send, + { + let random = { + let mut rng = thread_rng(); + (0..count).map(|_| F::rand(&mut rng)).collect() + }; + + let random_ot: Vec<[F; 2]> = self + .rot_sender + .send_random(ctx, count * F::BIT_SIZE) + .await? + .msgs; + + let channel = ctx.io_mut(); + + let masks = self.core.preprocess(random, random_ot)?; + channel.send(masks).await?; + + Ok(()) + } +} + +#[async_trait] +impl OLESend for OLESender +where + F: Field + Serialize + Deserialize, +{ + async fn send(&mut self, ctx: &mut Ctx, a_k: Vec) -> Result, OLEError> { + let len_requested = a_k.len(); + + let (sender_adjust, adjust) = self.core.adjust(a_k).ok_or_else(|| { + OLEError::new( + OLEErrorKind::InsufficientOLEs, + format!("{} < {}", self.core.cache_size(), len_requested), + ) + })?; + let channel = ctx.io_mut(); + channel.send(adjust).await?; + let adjust = channel.expect_next::>().await?; + + let shares = sender_adjust.finish_adjust(adjust)?; + let x_k = shares.into_iter().map(|s| s.inner()).collect(); + + Ok(x_k) + } +}