From 7cfe967d88fe5ed15ee4ba7b3d76cfd0847bcffa Mon Sep 17 00:00:00 2001 From: sinu <65924192+sinui0@users.noreply.github.com> Date: Wed, 23 Oct 2024 15:02:47 -0700 Subject: [PATCH] feat: lazy ot --- Cargo.toml | 6 +- crates/mpz-cointoss/Cargo.toml | 2 +- crates/mpz-common/Cargo.toml | 14 +- crates/mpz-common/src/future.rs | 193 ++++++ crates/mpz-common/src/lib.rs | 26 +- crates/mpz-core/Cargo.toml | 1 + crates/mpz-core/src/bitvec.rs | 6 + crates/mpz-core/src/lib.rs | 1 + crates/mpz-core/src/prg.rs | 5 + crates/mpz-ot-core/Cargo.toml | 2 + crates/mpz-ot-core/benches/ot.rs | 54 +- crates/mpz-ot-core/examples/ot.rs | 33 - crates/mpz-ot-core/src/chou_orlandi.rs | 127 ++++ crates/mpz-ot-core/src/chou_orlandi/config.rs | 57 -- crates/mpz-ot-core/src/chou_orlandi/mod.rs | 176 ----- crates/mpz-ot-core/src/chou_orlandi/msgs.rs | 13 - .../mpz-ot-core/src/chou_orlandi/receiver.rs | 207 +++--- crates/mpz-ot-core/src/chou_orlandi/sender.rs | 189 ++---- crates/mpz-ot-core/src/cot.rs | 72 ++ crates/mpz-ot-core/src/cot/derandomize.rs | 390 +++++++++++ crates/mpz-ot-core/src/ferret/cuckoo.rs | 195 ------ crates/mpz-ot-core/src/ferret/error.rs | 11 - crates/mpz-ot-core/src/ferret/mod.rs | 142 ---- crates/mpz-ot-core/src/ferret/mpcot/error.rs | 24 - crates/mpz-ot-core/src/ferret/mpcot/mod.rs | 169 ----- crates/mpz-ot-core/src/ferret/mpcot/msgs.rs | 19 - .../mpz-ot-core/src/ferret/mpcot/receiver.rs | 234 ------- .../src/ferret/mpcot/receiver_regular.rs | 192 ------ crates/mpz-ot-core/src/ferret/mpcot/sender.rs | 223 ------- .../src/ferret/mpcot/sender_regular.rs | 185 ------ crates/mpz-ot-core/src/ferret/msgs.rs | 11 - crates/mpz-ot-core/src/ferret/receiver.rs | 184 ------ crates/mpz-ot-core/src/ferret/sender.rs | 149 ----- crates/mpz-ot-core/src/ferret/spcot/error.rs | 25 - crates/mpz-ot-core/src/ferret/spcot/mod.rs | 90 --- crates/mpz-ot-core/src/ferret/spcot/msgs.rs | 45 -- .../mpz-ot-core/src/ferret/spcot/receiver.rs | 310 --------- crates/mpz-ot-core/src/ferret/spcot/sender.rs | 251 ------- crates/mpz-ot-core/src/ideal.rs | 6 + crates/mpz-ot-core/src/ideal/cot.rs | 333 ++++++---- crates/mpz-ot-core/src/ideal/mod.rs | 7 - crates/mpz-ot-core/src/ideal/mpcot.rs | 97 --- crates/mpz-ot-core/src/ideal/ot.rs | 199 ++++-- crates/mpz-ot-core/src/ideal/rcot.rs | 311 +++++++++ crates/mpz-ot-core/src/ideal/rot.rs | 334 +++++++--- crates/mpz-ot-core/src/ideal/spcot.rs | 104 --- crates/mpz-ot-core/src/kos.rs | 337 ++++++++++ crates/mpz-ot-core/src/kos/config.rs | 50 +- crates/mpz-ot-core/src/kos/error.rs | 26 +- crates/mpz-ot-core/src/kos/mod.rs | 385 ----------- crates/mpz-ot-core/src/kos/msgs.rs | 12 +- crates/mpz-ot-core/src/kos/receiver.rs | 613 ++++-------------- crates/mpz-ot-core/src/kos/sender.rs | 418 +++++------- crates/mpz-ot-core/src/lib.rs | 135 +--- crates/mpz-ot-core/src/msgs.rs | 69 -- crates/mpz-ot-core/src/ot.rs | 57 ++ crates/mpz-ot-core/src/rcot.rs | 84 +++ crates/mpz-ot-core/src/rot.rs | 87 +++ crates/mpz-ot-core/src/rot/any.rs | 147 +++++ crates/mpz-ot-core/src/rot/randomize.rs | 232 +++++++ crates/mpz-ot-core/src/test.rs | 14 + crates/mpz-ot/Cargo.toml | 13 +- crates/mpz-ot/benches/ot.rs | 46 -- crates/mpz-ot/src/chou_orlandi.rs | 19 + crates/mpz-ot/src/chou_orlandi/error.rs | 63 -- crates/mpz-ot/src/chou_orlandi/mod.rs | 154 ----- crates/mpz-ot/src/chou_orlandi/receiver.rs | 231 +++---- crates/mpz-ot/src/chou_orlandi/sender.rs | 222 +++---- crates/mpz-ot/src/cot.rs | 8 + crates/mpz-ot/src/cot/derandomize.rs | 14 + crates/mpz-ot/src/cot/derandomize/receiver.rs | 123 ++++ crates/mpz-ot/src/cot/derandomize/sender.rs | 126 ++++ crates/mpz-ot/src/ideal.rs | 6 + crates/mpz-ot/src/ideal/cot.rs | 249 +++---- crates/mpz-ot/src/ideal/mod.rs | 5 - crates/mpz-ot/src/ideal/ot.rs | 176 ++--- crates/mpz-ot/src/ideal/rcot.rs | 131 ++++ crates/mpz-ot/src/ideal/rot.rs | 181 +++--- crates/mpz-ot/src/kos.rs | 33 + crates/mpz-ot/src/kos/error.rs | 92 --- crates/mpz-ot/src/kos/mod.rs | 258 -------- crates/mpz-ot/src/kos/receiver.rs | 475 +++++--------- crates/mpz-ot/src/kos/sender.rs | 491 +++++--------- crates/mpz-ot/src/kos/shared_receiver.rs | 139 ---- crates/mpz-ot/src/kos/shared_sender.rs | 117 ---- crates/mpz-ot/src/lib.rs | 235 +------ crates/mpz-ot/src/ot.rs | 3 + crates/mpz-ot/src/rcot.rs | 3 + crates/mpz-ot/src/rot.rs | 6 + crates/mpz-ot/src/rot/any.rs | 33 + crates/mpz-ot/src/rot/any/receiver.rs | 67 ++ crates/mpz-ot/src/rot/any/sender.rs | 67 ++ crates/mpz-ot/src/rot/randomize.rs | 27 + crates/mpz-ot/src/rot/randomize/receiver.rs | 71 ++ crates/mpz-ot/src/rot/randomize/sender.rs | 68 ++ crates/mpz-ot/src/test.rs | 152 +++++ rustfmt.toml | 2 + 97 files changed, 4958 insertions(+), 7238 deletions(-) create mode 100644 crates/mpz-common/src/future.rs create mode 100644 crates/mpz-core/src/bitvec.rs delete mode 100644 crates/mpz-ot-core/examples/ot.rs create mode 100644 crates/mpz-ot-core/src/chou_orlandi.rs delete mode 100644 crates/mpz-ot-core/src/chou_orlandi/config.rs delete mode 100644 crates/mpz-ot-core/src/chou_orlandi/mod.rs create mode 100644 crates/mpz-ot-core/src/cot.rs create mode 100644 crates/mpz-ot-core/src/cot/derandomize.rs delete mode 100644 crates/mpz-ot-core/src/ferret/cuckoo.rs delete mode 100644 crates/mpz-ot-core/src/ferret/error.rs delete mode 100644 crates/mpz-ot-core/src/ferret/mod.rs delete mode 100644 crates/mpz-ot-core/src/ferret/mpcot/error.rs delete mode 100644 crates/mpz-ot-core/src/ferret/mpcot/mod.rs delete mode 100644 crates/mpz-ot-core/src/ferret/mpcot/msgs.rs delete mode 100644 crates/mpz-ot-core/src/ferret/mpcot/receiver.rs delete mode 100644 crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs delete mode 100644 crates/mpz-ot-core/src/ferret/mpcot/sender.rs delete mode 100644 crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs delete mode 100644 crates/mpz-ot-core/src/ferret/msgs.rs delete mode 100644 crates/mpz-ot-core/src/ferret/receiver.rs delete mode 100644 crates/mpz-ot-core/src/ferret/sender.rs delete mode 100644 crates/mpz-ot-core/src/ferret/spcot/error.rs delete mode 100644 crates/mpz-ot-core/src/ferret/spcot/mod.rs delete mode 100644 crates/mpz-ot-core/src/ferret/spcot/msgs.rs delete mode 100644 crates/mpz-ot-core/src/ferret/spcot/receiver.rs delete mode 100644 crates/mpz-ot-core/src/ferret/spcot/sender.rs create mode 100644 crates/mpz-ot-core/src/ideal.rs delete mode 100644 crates/mpz-ot-core/src/ideal/mod.rs delete mode 100644 crates/mpz-ot-core/src/ideal/mpcot.rs create mode 100644 crates/mpz-ot-core/src/ideal/rcot.rs delete mode 100644 crates/mpz-ot-core/src/ideal/spcot.rs create mode 100644 crates/mpz-ot-core/src/kos.rs delete mode 100644 crates/mpz-ot-core/src/kos/mod.rs delete mode 100644 crates/mpz-ot-core/src/msgs.rs create mode 100644 crates/mpz-ot-core/src/ot.rs create mode 100644 crates/mpz-ot-core/src/rcot.rs create mode 100644 crates/mpz-ot-core/src/rot.rs create mode 100644 crates/mpz-ot-core/src/rot/any.rs create mode 100644 crates/mpz-ot-core/src/rot/randomize.rs delete mode 100644 crates/mpz-ot/benches/ot.rs create mode 100644 crates/mpz-ot/src/chou_orlandi.rs delete mode 100644 crates/mpz-ot/src/chou_orlandi/error.rs delete mode 100644 crates/mpz-ot/src/chou_orlandi/mod.rs create mode 100644 crates/mpz-ot/src/cot.rs create mode 100644 crates/mpz-ot/src/cot/derandomize.rs create mode 100644 crates/mpz-ot/src/cot/derandomize/receiver.rs create mode 100644 crates/mpz-ot/src/cot/derandomize/sender.rs create mode 100644 crates/mpz-ot/src/ideal.rs delete mode 100644 crates/mpz-ot/src/ideal/mod.rs create mode 100644 crates/mpz-ot/src/ideal/rcot.rs create mode 100644 crates/mpz-ot/src/kos.rs delete mode 100644 crates/mpz-ot/src/kos/error.rs delete mode 100644 crates/mpz-ot/src/kos/mod.rs delete mode 100644 crates/mpz-ot/src/kos/shared_receiver.rs delete mode 100644 crates/mpz-ot/src/kos/shared_sender.rs create mode 100644 crates/mpz-ot/src/ot.rs create mode 100644 crates/mpz-ot/src/rcot.rs create mode 100644 crates/mpz-ot/src/rot.rs create mode 100644 crates/mpz-ot/src/rot/any.rs create mode 100644 crates/mpz-ot/src/rot/any/receiver.rs create mode 100644 crates/mpz-ot/src/rot/any/sender.rs create mode 100644 crates/mpz-ot/src/rot/randomize.rs create mode 100644 crates/mpz-ot/src/rot/randomize/receiver.rs create mode 100644 crates/mpz-ot/src/rot/randomize/sender.rs create mode 100644 crates/mpz-ot/src/test.rs create mode 100644 rustfmt.toml diff --git a/Cargo.toml b/Cargo.toml index d1e7c492..e7773595 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,8 +44,8 @@ mpz-ole-core = { path = "crates/mpz-ole-core" } clmul = { path = "crates/clmul" } matrix-transpose = { path = "crates/matrix-transpose" } -tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6e0be94" } -tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6e0be94" } +tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "5899190" } +tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "5899190" } # rand rand_chacha = "0.3" @@ -124,3 +124,5 @@ typenum = "1" generic-array = "0.14" itybity = "0.2" enum-try-as-inner = "0.1.0" +bitvec = "1.0" +hashbrown = "0.14.5" diff --git a/crates/mpz-cointoss/Cargo.toml b/crates/mpz-cointoss/Cargo.toml index e5765d19..e2c9f476 100644 --- a/crates/mpz-cointoss/Cargo.toml +++ b/crates/mpz-cointoss/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] mpz-core.workspace = true -mpz-common.workspace = true +mpz-common = { workspace = true, features = ["ctx"] } mpz-cointoss-core.workspace = true futures.workspace = true diff --git a/crates/mpz-common/Cargo.toml b/crates/mpz-common/Cargo.toml index 2dd1e3f5..baa0ee0f 100644 --- a/crates/mpz-common/Cargo.toml +++ b/crates/mpz-common/Cargo.toml @@ -4,10 +4,14 @@ version = "0.1.0" edition = "2021" [features] -default = ["sync"] -sync = ["tokio/sync"] -test-utils = ["uid-mux/test-utils"] -ideal = [] +default = [] +ctx = [] +cpu = ["rayon"] +executor = ["ctx", "cpu"] +sync = ["tokio/sync", "ctx"] +future = [] +test-utils = ["dep:uid-mux", "uid-mux/test-utils"] +ideal = ["ctx"] rayon = ["dep:rayon"] force-st = [] @@ -20,7 +24,7 @@ pin-project-lite.workspace = true scoped-futures.workspace = true thiserror.workspace = true serio.workspace = true -uid-mux.workspace = true +uid-mux = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } pollster.workspace = true rayon = { workspace = true, optional = true } diff --git a/crates/mpz-common/src/future.rs b/crates/mpz-common/src/future.rs new file mode 100644 index 00000000..f8fb2055 --- /dev/null +++ b/crates/mpz-common/src/future.rs @@ -0,0 +1,193 @@ +//! Future types. + +use std::{ + future::Future, + mem, + pin::Pin, + task::{ready, Context, Poll}, +}; + +use futures::{channel::oneshot, FutureExt}; +use pin_project_lite::pin_project; + +/// Creates a new output future. +pub fn new_output() -> (Sender, MaybeDone) { + let (send, recv) = oneshot::channel(); + (Sender { send }, MaybeDone { recv }) +} + +/// A future output value. +/// +/// This trait extends [`std::future::Future`] for values which can be received +/// outside of a task context. +pub trait Output: Future> { + /// Success type. + type Ok; + + /// Attempts to receive the output outside of a task context, returning + /// `None` if it is not ready. + fn try_recv(&mut self) -> Result, Canceled>; +} + +/// An extension trait for [`Output`]. +pub trait OutputExt: Output { + /// Maps the output value to a different type. + fn map(self, f: F) -> Map + where + Self: Sized, + F: FnOnce(Self::Ok) -> O, + { + Map::new(self, f) + } +} + +impl OutputExt for T where T: Output {} + +/// Output canceled error. +#[derive(Debug, thiserror::Error)] +#[error("output canceled")] +pub struct Canceled { + _private: (), +} + +/// Sender of an output value. +#[derive(Debug)] +pub struct Sender { + send: oneshot::Sender, +} + +impl Sender { + /// Sends an output value. + pub fn send(self, value: T) { + let _ = self.send.send(value); + } +} + +/// An output value that may be ready. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct MaybeDone { + recv: oneshot::Receiver, +} + +impl Output for MaybeDone { + type Ok = T; + + fn try_recv(&mut self) -> Result, Canceled> { + match self.recv.try_recv() { + Ok(Some(value)) => Ok(Some(value)), + Ok(None) => Ok(None), + Err(oneshot::Canceled) => Err(Canceled { _private: () }), + } + } +} + +impl Future for MaybeDone { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.recv + .poll_unpin(cx) + .map_err(|_| Canceled { _private: () }) + } +} + +pin_project! { + /// Maps an output value to a different type. + /// + /// Returned by [`OutputExt::map`]. + #[derive(Debug)] + pub struct Map { + #[pin] + inner: MapInner, + } +} + +impl Map { + fn new(inner: I, f: F) -> Self { + Self { + inner: MapInner::Incomplete { inner, f }, + } + } +} + +impl Output for Map +where + I: Output, + F: FnOnce(I::Ok) -> O, +{ + type Ok = O; + + fn try_recv(&mut self) -> Result, Canceled> { + self.inner.try_recv() + } +} + +impl Future for Map +where + I: Output, + F: FnOnce(I::Ok) -> O, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().inner.poll(cx) + } +} + +pin_project! { + /// Maps an output value to a different type. + /// + /// Returned by [`OutputExt::map`]. + #[derive(Debug)] + #[project = MapProj] + #[project_replace = MapProjReplace] + #[must_use = "futures do nothing unless you `.await` or poll them"] + enum MapInner { + Incomplete { + #[pin] + inner: I, + f: F, + }, + Done, + } +} + +impl Output for MapInner +where + I: Output, + F: FnOnce(I::Ok) -> O, +{ + type Ok = O; + + fn try_recv(&mut self) -> Result, Canceled> { + let this = mem::replace(self, MapInner::Done); + match this { + MapInner::Incomplete { mut inner, f } => inner.try_recv().map(|res| res.map(f)), + MapInner::Done => Err(Canceled { _private: () }), + } + } +} + +impl Future for MapInner +where + I: Output, + F: FnOnce(I::Ok) -> O, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project() { + MapProj::Incomplete { inner, .. } => { + let output = ready!(inner.poll(cx)); + match self.project_replace(Self::Done) { + MapProjReplace::Incomplete { f, .. } => Poll::Ready(output.map(f)), + MapProjReplace::Done => unreachable!(), + } + } + MapProj::Done => { + panic!("Map must not be polled after it returned `Poll::Ready`") + } + } + } +} diff --git a/crates/mpz-common/src/lib.rs b/crates/mpz-common/src/lib.rs index 7308e138..266de1cf 100644 --- a/crates/mpz-common/src/lib.rs +++ b/crates/mpz-common/src/lib.rs @@ -1,9 +1,10 @@ //! Common functionality for `mpz`. //! -//! This crate provides various common functionalities needed for modeling protocol execution, I/O, -//! and multi-threading. +//! This crate provides various common functionalities needed for modeling +//! protocol execution, I/O, and multi-threading. //! -//! This crate does not provide any cryptographic primitives, see `mpz-core` for that. +//! This crate does not provide any cryptographic primitives, see `mpz-core` for +//! that. #![deny( unsafe_code, @@ -14,9 +15,14 @@ clippy::all )] +#[cfg(any(test, feature = "ctx"))] mod context; +#[cfg(any(test, feature = "cpu"))] pub mod cpu; +#[cfg(any(test, feature = "executor"))] pub mod executor; +#[cfg(any(test, feature = "future"))] +pub mod future; mod id; #[cfg(any(test, feature = "ideal"))] pub mod ideal; @@ -24,6 +30,7 @@ pub mod ideal; pub mod sync; use async_trait::async_trait; +#[cfg(any(test, feature = "ctx"))] pub use context::{Context, ContextError}; pub use id::{Counter, ThreadId}; @@ -46,6 +53,19 @@ pub trait Preprocess: Allocate { async fn preprocess(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error>; } +/// A functionality that can be flushed. +#[async_trait] +pub trait Flush { + /// Error type. + type Error: std::error::Error + Send + Sync + 'static; + + /// Returns `true` if the functionality wants to be flushed. + fn wants_flush(&self) -> bool; + + /// Flushes the functionality. + async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error>; +} + /// A convenience macro for creating a closure which returns a scoped future. /// /// # Example diff --git a/crates/mpz-core/Cargo.toml b/crates/mpz-core/Cargo.toml index 55adf7a0..c3df0aae 100644 --- a/crates/mpz-core/Cargo.toml +++ b/crates/mpz-core/Cargo.toml @@ -32,6 +32,7 @@ bytemuck = { workspace = true, features = ["derive"] } generic-array.workspace = true rayon = { workspace = true, optional = true } cfg-if.workspace = true +bitvec = { workspace = true, features = ["serde"] } [dev-dependencies] rstest.workspace = true diff --git a/crates/mpz-core/src/bitvec.rs b/crates/mpz-core/src/bitvec.rs new file mode 100644 index 00000000..d83bd9e0 --- /dev/null +++ b/crates/mpz-core/src/bitvec.rs @@ -0,0 +1,6 @@ +//! Bit vectors. + +/// Bit vector. +pub type BitVec = bitvec::vec::BitVec; +/// Bit slice. +pub type BitSlice = bitvec::slice::BitSlice; diff --git a/crates/mpz-core/src/lib.rs b/crates/mpz-core/src/lib.rs index 4fa60e0b..b9824a2c 100644 --- a/crates/mpz-core/src/lib.rs +++ b/crates/mpz-core/src/lib.rs @@ -3,6 +3,7 @@ #![deny(clippy::all)] pub mod aes; +pub mod bitvec; pub mod block; pub mod commit; pub mod ggm_tree; diff --git a/crates/mpz-core/src/prg.rs b/crates/mpz-core/src/prg.rs index d49f8d03..f9a8b5c9 100644 --- a/crates/mpz-core/src/prg.rs +++ b/crates/mpz-core/src/prg.rs @@ -118,6 +118,11 @@ impl Prg { Prg::from_seed(rand::random::()) } + /// Create a new PRG from a seed. + pub fn new_with_seed(seed: [u8; 16]) -> Self { + Prg::from_seed(Block::from(seed)) + } + /// Returns the current counter. pub fn counter(&self) -> u64 { self.0.core.counter diff --git a/crates/mpz-ot-core/Cargo.toml b/crates/mpz-ot-core/Cargo.toml index 8c109327..fe735622 100644 --- a/crates/mpz-ot-core/Cargo.toml +++ b/crates/mpz-ot-core/Cargo.toml @@ -16,6 +16,7 @@ test-utils = [] [dependencies] mpz-core.workspace = true +mpz-common = { workspace = true, features = ["future"] } clmul.workspace = true matrix-transpose.workspace = true @@ -38,6 +39,7 @@ opaque-debug.workspace = true cfg-if.workspace = true bytemuck = { workspace = true, features = ["derive"] } enum-try-as-inner.workspace = true +futures = { workspace = true } [dev-dependencies] rstest.workspace = true diff --git a/crates/mpz-ot-core/benches/ot.rs b/crates/mpz-ot-core/benches/ot.rs index 34a958e7..f3f770a4 100644 --- a/crates/mpz-ot-core/benches/ot.rs +++ b/crates/mpz-ot-core/benches/ot.rs @@ -1,8 +1,12 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; -use itybity::{IntoBitIterator, ToBits}; +use itybity::ToBits; use mpz_core::Block; -use mpz_ot_core::{chou_orlandi, kos}; -use rand::{Rng, RngCore, SeedableRng}; +use mpz_ot_core::{ + chou_orlandi, kos, + ot::{OTReceiver, OTSender}, + rcot::{RCOTReceiver, RCOTSender}, +}; +use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha12Rng; fn chou_orlandi(c: &mut Criterion) { @@ -11,8 +15,7 @@ fn chou_orlandi(c: &mut Criterion) { group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &n| { let msgs = vec![[Block::ONES; 2]; n]; let mut rng = ChaCha12Rng::seed_from_u64(0); - let mut choices = vec![0u8; n / 8]; - rng.fill_bytes(&mut choices); + let choices = (0..n).map(|_| rng.gen()).collect::>(); b.iter(|| { let sender = chou_orlandi::Sender::default(); let receiver = chou_orlandi::Receiver::default(); @@ -20,9 +23,14 @@ fn chou_orlandi(c: &mut Criterion) { let (sender_setup, mut sender) = sender.setup(); let mut receiver = receiver.setup(sender_setup); - let receiver_payload = receiver.receive_random(choices.as_slice()); - let sender_payload = sender.send(&msgs, receiver_payload).unwrap(); - black_box(receiver.receive(sender_payload).unwrap()) + let sender_output = sender.queue_send_ot(&msgs).unwrap(); + let receiver_output = receiver.queue_recv_ot(&choices).unwrap(); + + let receiver_payload = receiver.choose(); + let sender_payload = sender.send(receiver_payload).unwrap(); + receiver.receive(sender_payload).unwrap(); + + black_box((sender_output, receiver_output)) }) }); } @@ -32,11 +40,7 @@ fn kos(c: &mut Criterion) { let mut group = c.benchmark_group("kos"); for n in [1024, 262144] { group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &n| { - let msgs = vec![[Block::ONES; 2]; n]; let mut rng = ChaCha12Rng::seed_from_u64(0); - let mut choices = vec![0u8; n / 8]; - rng.fill_bytes(&mut choices); - let choices = choices.into_lsb0_vec(); let delta = Block::random(&mut rng); let chi_seed = Block::random(&mut rng); @@ -50,28 +54,24 @@ fn kos(c: &mut Criterion) { .unwrap(); b.iter(|| { - let sender = kos::Sender::new(kos::SenderConfig::default()); + let sender = kos::Sender::new(kos::SenderConfig::default(), delta); let receiver = kos::Receiver::new(kos::ReceiverConfig::default()); - let mut sender = sender.setup(delta, sender_seeds); + let mut sender = sender.setup(sender_seeds); let mut receiver = receiver.setup(receiver_seeds); - let receiver_setup = receiver.extend(choices.len() + 256).unwrap(); - sender.extend(msgs.len() + 256, receiver_setup).unwrap(); - - let receiver_check = receiver.check(chi_seed).unwrap(); - sender.check(chi_seed, receiver_check).unwrap(); - - let mut receiver_keys = receiver.keys(choices.len()).unwrap(); - let derandomize = receiver_keys.derandomize(&choices).unwrap(); + sender.alloc(n).unwrap(); + receiver.alloc(n).unwrap(); - let mut sender_keys = sender.keys(msgs.len()).unwrap(); - sender_keys.derandomize(derandomize).unwrap(); - let payload = sender_keys.encrypt_blocks(&msgs).unwrap(); + while receiver.wants_extend() { + let extend = receiver.extend().unwrap(); + sender.extend(extend).unwrap(); + } - let received = receiver_keys.decrypt_blocks(payload).unwrap(); + let check = receiver.check(chi_seed).unwrap(); + sender.check(chi_seed, check).unwrap(); - black_box(received) + black_box((sender, receiver)); }) }); } diff --git a/crates/mpz-ot-core/examples/ot.rs b/crates/mpz-ot-core/examples/ot.rs deleted file mode 100644 index 38376c64..00000000 --- a/crates/mpz-ot-core/examples/ot.rs +++ /dev/null @@ -1,33 +0,0 @@ -// This example demonstrates how to securely and privately transfer data using OT. -// In practical situations data would be communicated over a channel such as TCP. -// For simplicity, this example shows how to use CO15 OT in memory. - -use mpz_core::Block; -use mpz_ot_core::chou_orlandi::{Receiver, Sender}; - -pub fn main() { - // Receiver choice bits - let choices = vec![false, true, false, false, true, true, false, true]; - - println!("Receiver choices: {:?}", &choices); - - // Sender messages the receiver chooses from - let inputs = [[Block::ZERO, Block::ONES]; 8]; - - println!("Sender inputs: {:?}", &inputs); - - // First the sender creates a setup message and passes it to receiver - let (sender_setup, mut sender) = Sender::default().setup(); - - // Receiver takes sender's setup and generates the receiver payload - let mut receiver = Receiver::default().setup(sender_setup); - let receiver_payload = receiver.receive_random(&choices); - - // Finally, sender encrypts their inputs and sends them to receiver - let sender_payload = sender.send(&inputs, receiver_payload).unwrap(); - - // Receiver takes the encrypted inputs and is able to decrypt according to their choice bits - let received = receiver.receive(sender_payload).unwrap(); - - println!("Transferred messages: {:?}", received); -} diff --git a/crates/mpz-ot-core/src/chou_orlandi.rs b/crates/mpz-ot-core/src/chou_orlandi.rs new file mode 100644 index 00000000..757c6e5e --- /dev/null +++ b/crates/mpz-ot-core/src/chou_orlandi.rs @@ -0,0 +1,127 @@ +//! An implementation of the Chou-Orlandi [`CO15`](https://eprint.iacr.org/2015/267.pdf) oblivious transfer protocol. + +mod error; +pub mod msgs; +mod receiver; +mod sender; + +pub use error::{ReceiverError, SenderError, SenderVerifyError}; +pub use receiver::{state as receiver_state, Receiver}; +pub use sender::{state as sender_state, Sender}; + +use blake3::Hasher; +use curve25519_dalek::ristretto::RistrettoPoint; +use mpz_core::Block; + +/// Hashes a ristretto point to a symmetric key +/// +/// Prepending a tweak is suggested in Section 2, "Non-Malleability in Practice" +pub(crate) fn hash_point(point: &RistrettoPoint, tweak: u128) -> Block { + // Compute H(tweak || point) + let mut h = Hasher::new(); + h.update(&tweak.to_be_bytes()); + h.update(point.compress().as_bytes()); + let digest = h.finalize(); + let digest: &[u8; 32] = digest.as_bytes(); + + // Copy the first 16 bytes into a Block + let mut block = [0u8; 16]; + block.copy_from_slice(&digest[..16]); + block.into() +} + +#[cfg(test)] +mod tests { + use crate::{ + ot::{OTReceiver, OTReceiverOutput, OTSender, OTSenderOutput}, + test::assert_ot, + }; + + use super::*; + use mpz_common::future::Output; + use rstest::*; + + use rand::Rng; + use rand_chacha::ChaCha12Rng; + use rand_core::SeedableRng; + + const SENDER_SEED: [u8; 32] = [0u8; 32]; + const RECEIVER_SEED: [u8; 32] = [1u8; 32]; + + #[fixture] + fn choices() -> Vec { + let mut rng = ChaCha12Rng::seed_from_u64(0); + (0..128).map(|_| rng.gen()).collect() + } + + #[fixture] + fn data() -> Vec<[Block; 2]> { + let mut rng = ChaCha12Rng::seed_from_u64(0); + (0..128) + .map(|_| [rng.gen::<[u8; 16]>().into(), rng.gen::<[u8; 16]>().into()]) + .collect() + } + + fn setup() -> (Sender, Receiver) { + let sender = Sender::new_with_seed(SENDER_SEED); + let receiver = Receiver::new_with_seed(RECEIVER_SEED); + + let (sender_setup, sender) = sender.setup(); + let receiver = receiver.setup(sender_setup); + + (sender, receiver) + } + + #[rstest] + fn test_ot_pass(choices: Vec, data: Vec<[Block; 2]>) { + let (mut sender, mut receiver) = setup(); + + let mut sender_output = sender.queue_send_ot(&data).unwrap(); + let mut receiver_output = receiver.queue_recv_ot(&choices).unwrap(); + + let receiver_payload = receiver.choose(); + let sender_payload = sender.send(receiver_payload).unwrap(); + receiver.receive(sender_payload).unwrap(); + + let OTSenderOutput { id: sender_id } = sender_output.try_recv().unwrap().unwrap(); + let OTReceiverOutput { + id: receiver_id, + msgs, + } = receiver_output.try_recv().unwrap().unwrap(); + + assert_eq!(sender_id, receiver_id); + assert_ot(&choices, &data, &msgs); + } + + #[rstest] + fn test_multiple_ot_pass(choices: Vec, data: Vec<[Block; 2]>) { + let (mut sender, mut receiver) = setup(); + + let mut sender_output = sender.queue_send_ot(&data).unwrap(); + let mut sender_output2 = sender.queue_send_ot(&data).unwrap(); + let mut receiver_output = receiver.queue_recv_ot(&choices).unwrap(); + let mut receiver_output2 = receiver.queue_recv_ot(&choices).unwrap(); + + let receiver_payload = receiver.choose(); + let sender_payload = sender.send(receiver_payload).unwrap(); + receiver.receive(sender_payload).unwrap(); + + let OTSenderOutput { id: sender_id } = sender_output.try_recv().unwrap().unwrap(); + let OTReceiverOutput { + id: receiver_id, + msgs, + } = receiver_output.try_recv().unwrap().unwrap(); + + assert_eq!(sender_id, receiver_id); + assert_ot(&choices, &data, &msgs); + + let OTSenderOutput { id: sender_id2 } = sender_output2.try_recv().unwrap().unwrap(); + let OTReceiverOutput { + id: receiver_id2, + msgs: msgs2, + } = receiver_output2.try_recv().unwrap().unwrap(); + + assert_eq!(sender_id2, receiver_id2); + assert_ot(&choices, &data, &msgs2); + } +} diff --git a/crates/mpz-ot-core/src/chou_orlandi/config.rs b/crates/mpz-ot-core/src/chou_orlandi/config.rs deleted file mode 100644 index b211a43c..00000000 --- a/crates/mpz-ot-core/src/chou_orlandi/config.rs +++ /dev/null @@ -1,57 +0,0 @@ -use derive_builder::Builder; - -/// CO15 sender configuration. -#[derive(Debug, Default, Clone, Builder)] -pub struct SenderConfig { - /// Whether the Receiver should commit to their choices. - #[builder(setter(custom), default = "false")] - receiver_commit: bool, -} - -impl SenderConfigBuilder { - /// Sets the Receiver to commit to their choices. - pub fn receiver_commit(&mut self) -> &mut Self { - self.receiver_commit = Some(true); - self - } -} - -impl SenderConfig { - /// Creates a new builder for SenderConfig. - pub fn builder() -> SenderConfigBuilder { - SenderConfigBuilder::default() - } - - /// Whether the Receiver should commit to their choices. - pub fn receiver_commit(&self) -> bool { - self.receiver_commit - } -} - -/// CO15 receiver configuration. -#[derive(Debug, Default, Clone, Builder)] -pub struct ReceiverConfig { - /// Whether the Receiver should commit to their choices. - #[builder(setter(custom), default = "false")] - receiver_commit: bool, -} - -impl ReceiverConfigBuilder { - /// Sets the Receiver to commit to their choices. - pub fn receiver_commit(&mut self) -> &mut Self { - self.receiver_commit = Some(true); - self - } -} - -impl ReceiverConfig { - /// Creates a new builder for ReceiverConfig. - pub fn builder() -> ReceiverConfigBuilder { - ReceiverConfigBuilder::default() - } - - /// Whether the Receiver should commit to their choices. - pub fn receiver_commit(&self) -> bool { - self.receiver_commit - } -} diff --git a/crates/mpz-ot-core/src/chou_orlandi/mod.rs b/crates/mpz-ot-core/src/chou_orlandi/mod.rs deleted file mode 100644 index a3f2d278..00000000 --- a/crates/mpz-ot-core/src/chou_orlandi/mod.rs +++ /dev/null @@ -1,176 +0,0 @@ -//! An implementation of the Chou-Orlandi [`CO15`](https://eprint.iacr.org/2015/267.pdf) oblivious transfer protocol. - -mod config; -mod error; -pub mod msgs; -mod receiver; -mod sender; - -pub use config::{ - ReceiverConfig, ReceiverConfigBuilder, ReceiverConfigBuilderError, SenderConfig, - SenderConfigBuilder, SenderConfigBuilderError, -}; -pub use error::{ReceiverError, SenderError, SenderVerifyError}; -pub use receiver::{state as receiver_state, Receiver}; -pub use sender::{state as sender_state, Sender}; - -use blake3::Hasher; -use curve25519_dalek::ristretto::RistrettoPoint; -use mpz_core::Block; - -/// Hashes a ristretto point to a symmetric key -/// -/// Prepending a tweak is suggested in Section 2, "Non-Malleability in Practice" -pub(crate) fn hash_point(point: &RistrettoPoint, tweak: u128) -> Block { - // Compute H(tweak || point) - let mut h = Hasher::new(); - h.update(&tweak.to_be_bytes()); - h.update(point.compress().as_bytes()); - let digest = h.finalize(); - let digest: &[u8; 32] = digest.as_bytes(); - - // Copy the first 16 bytes into a Block - let mut block = [0u8; 16]; - block.copy_from_slice(&digest[..16]); - block.into() -} - -#[cfg(test)] -mod tests { - use super::*; - use itybity::IntoBitIterator; - use rstest::*; - - use rand::Rng; - use rand_chacha::ChaCha12Rng; - use rand_core::SeedableRng; - - const SENDER_SEED: [u8; 32] = [0u8; 32]; - const RECEIVER_SEED: [u8; 32] = [1u8; 32]; - - #[fixture] - fn choices() -> Vec { - let mut rng = ChaCha12Rng::seed_from_u64(0); - (0..128).map(|_| rng.gen()).collect() - } - - #[fixture] - fn data() -> Vec<[Block; 2]> { - let mut rng = ChaCha12Rng::seed_from_u64(0); - (0..128) - .map(|_| [rng.gen::<[u8; 16]>().into(), rng.gen::<[u8; 16]>().into()]) - .collect() - } - - #[fixture] - fn expected(data: Vec<[Block; 2]>, choices: Vec) -> Vec { - data.iter() - .zip(choices.iter()) - .map(|([a, b], choice)| if *choice { *b } else { *a }) - .collect() - } - - fn setup( - sender_config: SenderConfig, - receiver_config: ReceiverConfig, - ) -> (Sender, Receiver) { - let sender = Sender::new_with_seed(sender_config, SENDER_SEED); - let receiver = Receiver::new_with_seed(receiver_config, RECEIVER_SEED); - - let (sender_setup, sender) = sender.setup(); - let receiver = receiver.setup(sender_setup); - - (sender, receiver) - } - - #[rstest] - fn test_ot_pass(choices: Vec, data: Vec<[Block; 2]>, expected: Vec) { - let (mut sender, mut receiver) = setup(SenderConfig::default(), ReceiverConfig::default()); - - let receiver_payload = receiver.receive_random(&choices); - let sender_payload = sender.send(&data, receiver_payload).unwrap(); - - let received_data = receiver.receive(sender_payload).unwrap(); - - assert_eq!(received_data, expected); - } - - #[rstest] - fn test_multiple_ot_pass(choices: Vec, data: Vec<[Block; 2]>, expected: Vec) { - let (mut sender, mut receiver) = setup(SenderConfig::default(), ReceiverConfig::default()); - - let receiver_payload = receiver.receive_random(&choices); - let sender_payload = sender.send(&data, receiver_payload).unwrap(); - - let received_data = receiver.receive(sender_payload).unwrap(); - - assert_eq!(received_data, expected); - - let receiver_payload = receiver.receive_random(&choices); - let sender_payload = sender.send(&data, receiver_payload).unwrap(); - - let received_data = receiver.receive(sender_payload).unwrap(); - - assert_eq!(received_data, expected); - } - - #[rstest] - fn test_committed_ot_receiver_pass( - choices: Vec, - data: Vec<[Block; 2]>, - expected: Vec, - ) { - let (mut sender, mut receiver) = setup( - SenderConfig::builder().receiver_commit().build().unwrap(), - ReceiverConfig::builder().receiver_commit().build().unwrap(), - ); - - let receiver_payload = receiver.receive_random(&choices); - let sender_payload = sender.send(&data, receiver_payload).unwrap(); - - let received_data = receiver.receive(sender_payload).unwrap(); - - assert_eq!(received_data, expected); - - let receiver_reveal = receiver.reveal_choices().unwrap(); - - let verified_choices = sender - .verify_choices(RECEIVER_SEED, receiver_reveal) - .unwrap(); - - assert_eq!(choices, verified_choices.into_lsb0_vec()); - } - - #[rstest] - fn test_committed_ot_receiver_cheat_choice( - choices: Vec, - data: Vec<[Block; 2]>, - expected: Vec, - ) { - let (mut sender, mut receiver) = setup( - SenderConfig::builder().receiver_commit().build().unwrap(), - ReceiverConfig::builder().receiver_commit().build().unwrap(), - ); - - let receiver_payload = receiver.receive_random(&choices); - let sender_payload = sender.send(&data, receiver_payload).unwrap(); - - let received_data = receiver.receive(sender_payload).unwrap(); - - assert_eq!(received_data, expected); - - let mut receiver_reveal = receiver.reveal_choices().unwrap(); - - // Flip a bit - receiver_reveal.choices[0] ^= 1; - - let err = sender - .verify_choices(RECEIVER_SEED, receiver_reveal) - .unwrap_err(); - - assert!(matches!( - err, - SenderError::VerifyError(error::SenderVerifyError::InconsistentChoice) - )); - } -} diff --git a/crates/mpz-ot-core/src/chou_orlandi/msgs.rs b/crates/mpz-ot-core/src/chou_orlandi/msgs.rs index f19c1b8f..c70bf5a9 100644 --- a/crates/mpz-ot-core/src/chou_orlandi/msgs.rs +++ b/crates/mpz-ot-core/src/chou_orlandi/msgs.rs @@ -4,8 +4,6 @@ use curve25519_dalek::RistrettoPoint; use mpz_core::Block; use serde::{Deserialize, Serialize}; -use crate::TransferId; - /// Sender setup message. #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] pub struct SenderSetup { @@ -16,8 +14,6 @@ pub struct SenderSetup { /// Sender payload message. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct SenderPayload { - /// The transfer ID. - pub id: TransferId, /// The sender's ciphertexts pub payload: Vec<[Block; 2]>, } @@ -25,15 +21,6 @@ pub struct SenderPayload { /// Receiver payload message. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct ReceiverPayload { - /// The transfer ID. - pub id: TransferId, /// The receiver's blinded choices. pub blinded_choices: Vec, } - -/// Receiver reveal message. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ReceiverReveal { - /// The receiver's choices. - pub choices: Vec, -} diff --git a/crates/mpz-ot-core/src/chou_orlandi/receiver.rs b/crates/mpz-ot-core/src/chou_orlandi/receiver.rs index 403802f9..b2d3cbff 100644 --- a/crates/mpz-ot-core/src/chou_orlandi/receiver.rs +++ b/crates/mpz-ot-core/src/chou_orlandi/receiver.rs @@ -1,11 +1,16 @@ -use crate::chou_orlandi::{ - hash_point, - msgs::{ReceiverPayload, ReceiverReveal, SenderPayload, SenderSetup}, - ReceiverConfig, ReceiverError, +use std::{collections::VecDeque, mem}; + +use crate::{ + chou_orlandi::{ + hash_point, + msgs::{ReceiverPayload, SenderPayload, SenderSetup}, + ReceiverError, + }, + ot::{OTReceiver, OTReceiverOutput}, + TransferId, }; -use crate::TransferId; -use itybity::{BitIterable, FromBitIterator, ToBits}; +use mpz_common::future::{new_output, MaybeDone, Sender as OutputSender}; use mpz_core::Block; use curve25519_dalek::{ @@ -19,64 +24,49 @@ use rand_core::SeedableRng; #[cfg(feature = "rayon")] use rayon::prelude::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; +type Error = ReceiverError; +type Result = core::result::Result; + +#[derive(Debug)] +struct Queued { + count: usize, + sender: OutputSender>, +} + /// A [CO15](https://eprint.iacr.org/2015/267.pdf) receiver. #[derive(Debug, Default)] pub struct Receiver { - /// The receiver's configuration - config: ReceiverConfig, + queue: VecDeque, + choices: Vec, /// The current state of the protocol state: T, } impl Receiver { /// Creates a new receiver. - /// - /// # Committed Receiver - /// - /// ## ⚠️ Warning ⚠️ - /// - /// If the receiver is committed, the receiver's RNG seed must be unbiased such as generated by - /// a secure coin toss protocol with the sender. - /// - /// Use the [`new_with_seed`] method to provide a seed. - /// - /// # Arguments - /// - /// * `config` - The receiver's configuration - pub fn new(config: ReceiverConfig) -> Self { + pub fn new() -> Self { Self { - config, + queue: VecDeque::new(), + choices: Vec::new(), state: state::Initialized::default(), } } /// Creates a new receiver with the provided RNG seed. /// - /// # Committed Receiver - /// - /// ## ⚠️ Warning ⚠️ - /// - /// If the receiver is committed, the receiver's RNG seed must be unbiased such as generated by - /// a secure coin toss protocol with the sender. - /// /// # Arguments /// - /// * `config` - The receiver's configuration /// * `seed` - The RNG seed used to generate the receiver's keys - pub fn new_with_seed(config: ReceiverConfig, seed: [u8; 32]) -> Self { + pub fn new_with_seed(seed: [u8; 32]) -> Self { Self { - config, + queue: VecDeque::new(), + choices: Vec::new(), state: state::Initialized { rng: ChaCha20Rng::from_seed(seed), }, } } - /// Returns the receiver's configuration. - pub fn config(&self) -> &ReceiverConfig { - &self.config - } - /// Sets up the receiver. /// /// # Arguments @@ -86,13 +76,13 @@ impl Receiver { let state::Initialized { rng } = self.state; Receiver { - config: self.config, + queue: self.queue, + choices: self.choices, state: state::Setup { rng, sender_base_table: RistrettoBasepointTable::create(&sender_setup.public_key), transfer_id: TransferId::default(), counter: 0, - choice_log: Vec::default(), decryption_keys: Vec::default(), }, } @@ -100,65 +90,51 @@ impl Receiver { } impl Receiver { - /// Computes the decryption keys, returning the Receiver's payload to be sent to the Sender. - /// - /// # Arguments - /// - /// * `choices` - The receiver's choices - pub fn receive_random(&mut self, choices: &[T]) -> ReceiverPayload { + /// Returns whether the receiver wants to flush. + pub fn wants_flush(&self) -> bool { + !self.choices.is_empty() + } + + /// Sends the blinded choices to the Sender. + pub fn choose(&mut self) -> ReceiverPayload { let state::Setup { rng, sender_base_table, counter, - choice_log, - decryption_keys: cached_decryption_keys, + decryption_keys, .. } = &mut self.state; - let private_keys = choices - .iter_lsb0() + let choices = mem::take(&mut self.choices); + let private_keys = (0..choices.len()) .map(|_| Scalar::random(rng)) .collect::>(); - let (blinded_choices, decryption_keys) = - compute_decryption_keys(sender_base_table, &private_keys, choices, *counter); + let (blinded_choices, new_keys) = + compute_decryption_keys(sender_base_table, &private_keys, &choices, *counter); *counter += blinded_choices.len(); - cached_decryption_keys.extend(decryption_keys); + decryption_keys.extend(new_keys); - // If configured, log the choices - if self.config.receiver_commit() { - choice_log.extend(choices.iter_lsb0()); - } - - ReceiverPayload { - id: self.state.transfer_id, - blinded_choices, - } + ReceiverPayload { blinded_choices } } - /// Receives the encrypted payload from the Sender, returning the plaintext messages corresponding - /// to the Receiver's choices. + /// Receives the encrypted payload from the Sender. /// /// # Arguments /// /// * `payload` - The encrypted payload from the Sender - pub fn receive(&mut self, payload: SenderPayload) -> Result, ReceiverError> { + pub fn receive(&mut self, payload: SenderPayload) -> Result<()> { let state::Setup { - transfer_id: current_id, + transfer_id, decryption_keys, .. } = &mut self.state; - let SenderPayload { id, payload } = payload; - - // Check that the transfer id matches - let expected_id = current_id.next(); - if id != expected_id { - return Err(ReceiverError::IdMismatch(expected_id, id)); - } + let SenderPayload { payload } = payload; - // Check that the number of ciphertexts does not exceed the number of pending keys + // Check that the number of ciphertexts does not exceed the number of pending + // keys if payload.len() > decryption_keys.len() { return Err(ReceiverError::CountMismatch( decryption_keys.len(), @@ -166,33 +142,59 @@ impl Receiver { )); } - // Drain the decryption keys and decrypt the ciphertexts - Ok(decryption_keys - .drain(..payload.len()) - .zip(payload) - .map( - |((c, key), [ct0, ct1])| { - if c { - key ^ ct1 - } else { - key ^ ct0 - } - }, - ) - .collect::>()) + let mut msgs = + decryption_keys + .drain(..payload.len()) + .zip(payload) + .map( + |((c, key), [ct0, ct1])| { + if c { + key ^ ct1 + } else { + key ^ ct0 + } + }, + ); + + while let Some(Queued { count, sender }) = self.queue.pop_front() { + let output = OTReceiverOutput { + id: transfer_id.next(), + msgs: msgs.by_ref().take(count).collect(), + }; + + sender.send(output); + } + + Ok(()) + } +} + +impl OTReceiver for Receiver +where + T: state::State, +{ + type Error = Error; + type Future = MaybeDone>; + + fn alloc(&mut self, _count: usize) -> Result<(), Self::Error> { + Ok(()) } - /// Reveals the receiver's choices to the Sender - pub fn reveal_choices(self) -> Result { - let state::Setup { choice_log, .. } = self.state; + fn queue_recv_ot(&mut self, choices: &[bool]) -> Result { + let (sender, recv) = new_output(); + + self.choices.extend(choices); + self.queue.push_back(Queued { + count: choices.len(), + sender, + }); - Ok(ReceiverReveal { - choices: Vec::::from_lsb0_iter(choice_log), - }) + Ok(recv) } } -/// Computes the blinded choices `B` and the decryption keys for the OT receiver. +/// Computes the blinded choices `B` and the decryption keys for the OT +/// receiver. /// /// # Arguments /// @@ -200,11 +202,11 @@ impl Receiver { /// * `receiver_private_keys` - The private keys of the OT receiver /// * `choices` - The choices of the OT receiver /// * `offset` - The number of decryption keys that have already been computed -/// (used for the key derivation tweak) -fn compute_decryption_keys( +/// (used for the key derivation tweak) +fn compute_decryption_keys( base_table: &RistrettoBasepointTable, receiver_private_keys: &[Scalar], - choices: &[T], + choices: &[bool], offset: usize, ) -> (Vec, Vec<(bool, Block)>) { let zero = &Scalar::ZERO * base_table; @@ -213,12 +215,9 @@ fn compute_decryption_keys( cfg_if::cfg_if! { if #[cfg(feature = "rayon")] { - // itybity currently doesn't support `IndexedParallelIterator` for collections, - // so we allocate instead. - let temp = receiver_private_keys.iter().zip(choices.iter_lsb0()).collect::>(); - let iter = temp.into_par_iter().enumerate(); + let iter = receiver_private_keys.into_par_iter().zip(choices.into_par_iter().copied()).enumerate(); } else { - let iter = receiver_private_keys.iter().zip(choices.iter_lsb0()).enumerate(); + let iter = receiver_private_keys.iter().zip(choices).enumerate(); } } @@ -284,8 +283,6 @@ pub mod state { pub(super) transfer_id: TransferId, /// Counts how many decryption keys we've computed so far pub(super) counter: usize, - /// Log of the receiver's choice bits - pub(super) choice_log: Vec, /// The decryption key for each OT, with the corresponding choice bit pub(super) decryption_keys: Vec<(bool, Block)>, diff --git a/crates/mpz-ot-core/src/chou_orlandi/sender.rs b/crates/mpz-ot-core/src/chou_orlandi/sender.rs index 09a8b5a6..cd7641a1 100644 --- a/crates/mpz-ot-core/src/chou_orlandi/sender.rs +++ b/crates/mpz-ot-core/src/chou_orlandi/sender.rs @@ -1,13 +1,16 @@ +use std::{collections::VecDeque, mem}; + use crate::{ chou_orlandi::{ hash_point, - msgs::{ReceiverPayload, ReceiverReveal, SenderPayload, SenderSetup}, - Receiver, ReceiverConfig, SenderConfig, SenderError, SenderVerifyError, + msgs::{ReceiverPayload, SenderPayload, SenderSetup}, + SenderError, }, + ot::{OTSender, OTSenderOutput}, TransferId, }; -use itybity::IntoBitIterator; +use mpz_common::future::{new_output, MaybeDone, Sender as OutputSender}; use mpz_core::Block; use curve25519_dalek::{ @@ -19,50 +22,38 @@ use rand_chacha::ChaCha20Rng; #[cfg(feature = "rayon")] use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; -/// A tape used to record all the blinded choices made by the receiver, which -/// can later be used to perform a consistency check. -#[derive(Debug, Default)] -struct Tape { - receiver_choices: Vec, +type Error = SenderError; +type Result = core::result::Result; + +#[derive(Debug)] +struct Queued { + sender: OutputSender, } /// A [CO15](https://eprint.iacr.org/2015/267.pdf) sender. #[derive(Debug, Default)] pub struct Sender { - config: SenderConfig, + queue: VecDeque, + msgs: Vec<[Block; 2]>, /// Current state state: T, - /// Protocol tape - tape: Option, } impl Sender { /// Creates a new Sender - /// - /// # Arguments - /// - /// * `config` - The Sender's configuration - pub fn new(config: SenderConfig) -> Self { - let tape = if config.receiver_commit() { - Some(Tape::default()) - } else { - None - }; - + pub fn new() -> Self { Sender { - config, + queue: VecDeque::new(), + msgs: Vec::new(), state: state::Initialized::default(), - tape, } } /// Creates a new Sender with the provided RNG seed /// /// # Arguments - /// - /// * `config` - The Sender's configuration /// * `seed` - The RNG seed used to generate the sender's keys - pub fn new_with_seed(config: SenderConfig, seed: [u8; 32]) -> Self { + pub fn new_with_seed(seed: [u8; 32]) -> Self { let mut rng = ChaCha20Rng::from_seed(seed); let private_key = Scalar::random(&mut rng); @@ -72,24 +63,13 @@ impl Sender { public_key, }; - let tape = if config.receiver_commit() { - Some(Tape::default()) - } else { - None - }; - Sender { - config, + queue: VecDeque::new(), + msgs: Vec::new(), state, - tape, } } - /// Returns the Sender's configuration - pub fn config(&self) -> &SenderConfig { - &self.config - } - /// Returns the setup message to be sent to the receiver. pub fn setup(self) -> (SenderSetup, Sender) { let state::Initialized { @@ -100,129 +80,90 @@ impl Sender { ( SenderSetup { public_key }, Sender { - config: self.config, + queue: self.queue, + msgs: self.msgs, state: state::Setup { private_key, public_key, transfer_id: TransferId::default(), counter: 0, }, - tape: self.tape, }, ) } } impl Sender { - /// Obliviously sends `inputs` to the receiver. + /// Returns `true` if the sender wants to receive. + pub fn wants_recv(&self) -> bool { + !self.queue.is_empty() + } + + /// Obliviously sends messages to the receiver. /// /// # Arguments /// - /// * `inputs` - The inputs to be obliviously sent to the receiver. /// * `receiver_payload` - The receiver's choice payload. - pub fn send( - &mut self, - inputs: &[[Block; 2]], - receiver_payload: ReceiverPayload, - ) -> Result { + pub fn send(&mut self, receiver_payload: ReceiverPayload) -> Result { let state::Setup { private_key, public_key, - transfer_id: current_id, + transfer_id, counter, .. } = &mut self.state; - let ReceiverPayload { - id, - blinded_choices, - } = receiver_payload; - - // Check that the transfer id matches - let expected_id = current_id.next(); - if id != expected_id { - return Err(SenderError::IdMismatch(expected_id, id)); - } + let ReceiverPayload { blinded_choices } = receiver_payload; + let msgs = mem::take(&mut self.msgs); - // Check that the number of inputs matches the number of choices - if inputs.len() != blinded_choices.len() { + // Check that the number of messages matches the number of choices + if msgs.len() != blinded_choices.len() { return Err(SenderError::CountMismatch( - inputs.len(), + msgs.len(), blinded_choices.len(), )); } - if let Some(tape) = self.tape.as_mut() { - // Record the receiver's choices - tape.receiver_choices.extend_from_slice(&blinded_choices); - } - let mut payload = compute_encryption_keys(private_key, public_key, &blinded_choices, *counter); - *counter += inputs.len(); + *counter += msgs.len(); - // Encrypt the inputs - for (input, payload) in inputs.iter().zip(payload.iter_mut()) { - payload[0] = input[0] ^ payload[0]; - payload[1] = input[1] ^ payload[1]; + // Encrypt the messages + for (msg, payload) in msgs.iter().zip(payload.iter_mut()) { + payload[0] = msg[0] ^ payload[0]; + payload[1] = msg[1] ^ payload[1]; } - Ok(SenderPayload { id, payload }) - } - - /// Returns the Receiver choices after verifying them against the tape. - /// - /// # ⚠️ Warning ⚠️ - /// - /// The receiver's RNG seed must be unbiased such as generated by - /// a secure coin toss protocol with the sender. - /// - /// # Arguments - /// - /// * `receiver_seed` - The seed used to generate the receiver's private keys. - /// * `receiver_reveal` - The receiver's private inputs. - pub fn verify_choices( - self, - receiver_seed: [u8; 32], - receiver_reveal: ReceiverReveal, - ) -> Result, SenderError> { - let state::Setup { public_key, .. } = self.state; - - let Some(tape) = &self.tape else { - return Err(SenderVerifyError::TapeNotRecorded)?; - }; - - let ReceiverReveal { choices } = receiver_reveal; - - let choices = choices - .into_iter_lsb0() - .take(tape.receiver_choices.len()) - .collect::>(); - - // Check that the number of choices matches - if tape.receiver_choices.len() != choices.len() { - return Err(SenderVerifyError::ChoiceCountMismatch( - tape.receiver_choices.len(), - choices.len(), - ))?; + // Clear the queue. + for Queued { sender } in self.queue.drain(..) { + sender.send(OTSenderOutput { + id: transfer_id.next(), + }); } - // Simulate the receiver - let receiver = Receiver::new_with_seed(ReceiverConfig::default(), receiver_seed); + Ok(SenderPayload { payload }) + } +} + +impl OTSender for Sender +where + S: state::State, +{ + type Error = Error; + type Future = MaybeDone; - let mut receiver = receiver.setup(SenderSetup { public_key }); + fn alloc(&mut self, _count: usize) -> Result<()> { + Ok(()) + } - let ReceiverPayload { - blinded_choices, .. - } = receiver.receive_random(&choices); + fn queue_send_ot(&mut self, msgs: &[[Block; 2]]) -> Result { + let (sender, recv) = new_output(); - // Check that the simulated receiver's choices match the ones recorded in the tape - if blinded_choices != tape.receiver_choices { - return Err(SenderVerifyError::InconsistentChoice)?; - } + self.msgs.extend_from_slice(msgs); + self.queue.push_back(Queued { sender }); - Ok(choices) + Ok(recv) } } @@ -233,8 +174,8 @@ impl Sender { /// * `private_key` - The sender's private key. /// * `public_key` - The sender's public key. /// * `blinded_choices` - The receiver's blinded choices. -/// * `offset` - The number of OTs that have already been performed -/// (used for the key derivation tweak) +/// * `offset` - The number of OTs that have already been performed (used for +/// the key derivation tweak) fn compute_encryption_keys( private_key: &Scalar, public_key: &RistrettoPoint, diff --git a/crates/mpz-ot-core/src/cot.rs b/crates/mpz-ot-core/src/cot.rs new file mode 100644 index 00000000..aa901b41 --- /dev/null +++ b/crates/mpz-ot-core/src/cot.rs @@ -0,0 +1,72 @@ +//! Correlated oblivious transfer. + +mod derandomize; + +pub use derandomize::{ + Adjust, DerandCOTReceiver, DerandCOTReceiverError, DerandCOTSender, DerandCOTSenderError, +}; + +use mpz_common::future::Output; + +use crate::TransferId; + +/// Output the sender receives from the COT functionality. +#[derive(Debug)] +pub struct COTSenderOutput { + /// Transfer id. + pub id: TransferId, +} + +/// Correlated oblivious transfer sender. +pub trait COTSender { + /// Error type. + type Error: std::error::Error + Send + Sync + 'static; + /// Future type. + type Future: Output; + + /// Allocates `count` COTs for preprocessing. + fn alloc(&mut self, count: usize) -> Result<(), Self::Error>; + + /// Returns the number of available COTs. + fn available(&self) -> usize; + + /// Returns the global correlation key, `delta`. + fn delta(&self) -> T; + + /// Queues sending of COTs. + /// + /// # Arguments + /// + /// * `keys` - Keys to send. + fn queue_send_cot(&mut self, keys: &[T]) -> Result; +} + +/// Output the receiver receives from the COT functionality. +#[derive(Debug)] +pub struct COTReceiverOutput { + /// Transfer id. + pub id: TransferId, + /// Chosen messages. + pub msgs: Vec, +} + +/// Correlated oblivious transfer receiver. +pub trait COTReceiver { + /// Error type. + type Error: std::error::Error + Send + Sync + 'static; + /// Future type. + type Future: Output>; + + /// Allocates `count` COTs for preprocessing. + fn alloc(&mut self, count: usize) -> Result<(), Self::Error>; + + /// Returns the number of available COTs. + fn available(&self) -> usize; + + /// Queues receiving of COTs. + /// + /// # Arguments + /// + /// * `choices` - COT choices. + fn queue_recv_cot(&mut self, choices: &[T]) -> Result; +} diff --git a/crates/mpz-ot-core/src/cot/derandomize.rs b/crates/mpz-ot-core/src/cot/derandomize.rs new file mode 100644 index 00000000..90571d26 --- /dev/null +++ b/crates/mpz-ot-core/src/cot/derandomize.rs @@ -0,0 +1,390 @@ +use std::{collections::VecDeque, mem}; + +use mpz_common::future::{new_output, MaybeDone, Sender}; +use mpz_core::{bitvec::BitVec, Block}; +use serde::{Deserialize, Serialize}; + +use crate::{ + cot::{COTReceiver, COTReceiverOutput, COTSender, COTSenderOutput}, + rcot::{RCOTReceiver, RCOTReceiverOutput, RCOTSender, RCOTSenderOutput}, + Derandomize, +}; + +/// COT adjustment message. +#[derive(Debug, Serialize, Deserialize)] +pub struct Adjust { + adjust: Vec, +} + +#[derive(Debug)] +struct QueuedSend { + count: usize, + sender: Sender, +} + +/// Derandomized COT sender. +/// +/// This is a COT sender which derandomizes preprocessed RCOTs. +#[derive(Debug)] +pub struct DerandCOTSender { + rcot: T, + pending: usize, + adjust: Vec, + queue: VecDeque, +} + +impl DerandCOTSender { + /// Creates a new `DerandCOTSender`. + pub fn new(rcot: T) -> Self { + Self { + rcot, + pending: 0, + adjust: Vec::new(), + queue: VecDeque::new(), + } + } + + /// Returns a reference to the RCOT sender. + pub fn rcot(&self) -> &T { + &self.rcot + } + + /// Returns a mutable reference to the RCOT sender. + pub fn rcot_mut(&mut self) -> &mut T { + &mut self.rcot + } + + /// Returns the inner RCOT sender. + pub fn into_inner(self) -> T { + self.rcot + } +} + +impl DerandCOTSender +where + T: RCOTSender, +{ + /// Returns `true` if the sender wants to send adjustments. + pub fn wants_adjust(&self) -> bool { + self.pending > 0 + } + + /// Returns the adjustment message. + pub fn adjust( + &mut self, + derandomize: Derandomize, + ) -> Result, DerandCOTSenderError> { + let Derandomize { flip } = derandomize; + + if flip.len() != self.pending { + return Err(DerandCOTSenderError::new(format!( + "derandomize is wrong length: {} != {}", + flip.len(), + self.pending + ))); + } + + let mut i = 0; + let delta = self.delta(); + let mut adjust = mem::take(&mut self.adjust); + for QueuedSend { count, sender } in mem::take(&mut self.queue) { + let RCOTSenderOutput { id, keys } = self + .rcot + .try_send_rcot(count) + .map_err(DerandCOTSenderError::new)?; + + adjust[i..i + count] + .iter_mut() + .zip(&flip[i..i + count]) + .zip(keys) + .for_each(|((adjust, flip), key)| { + *adjust ^= key; + *adjust ^= if *flip { delta } else { Block::ZERO }; + }); + + i += count; + + sender.send(COTSenderOutput { id }); + } + + self.pending = 0; + + Ok(Adjust { adjust }) + } +} + +impl COTSender for DerandCOTSender +where + T: RCOTSender, +{ + type Error = DerandCOTSenderError; + type Future = MaybeDone; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.rcot.alloc(count).map_err(DerandCOTSenderError::new) + } + + fn available(&self) -> usize { + self.rcot.available() + } + + fn delta(&self) -> Block { + self.rcot.delta() + } + + fn queue_send_cot(&mut self, keys: &[Block]) -> Result { + let count = keys.len(); + let (sender, recv) = new_output(); + + self.adjust.extend_from_slice(keys); + self.pending += count; + self.queue.push_back(QueuedSend { count, sender }); + + Ok(recv) + } +} + +/// Error for [`DerandCOTSender`]. +#[derive(Debug, thiserror::Error)] +#[error("derandomized COT sender error: {source}")] +pub struct DerandCOTSenderError { + source: Box, +} + +impl DerandCOTSenderError { + fn new(err: E) -> Self + where + E: Into>, + { + Self { source: err.into() } + } +} + +#[derive(Debug)] +struct QueuedReceive { + count: usize, + sender: Sender>, +} + +/// Derandomized COT receiver. +#[derive(Debug)] +pub struct DerandCOTReceiver { + rcot: T, + pending: usize, + derandomize: BitVec, + queue: VecDeque, +} + +impl DerandCOTReceiver { + /// Creates a new `DerandCOTReceiver`. + pub fn new(rcot: T) -> Self { + Self { + rcot, + pending: 0, + derandomize: BitVec::new(), + queue: VecDeque::new(), + } + } + + /// Returns a reference to the RCOT receiver. + pub fn rcot(&self) -> &T { + &self.rcot + } + + /// Returns a mutable reference to the RCOT receiver. + pub fn rcot_mut(&mut self) -> &mut T { + &mut self.rcot + } + + /// Returns the inner RCOT receiver. + pub fn into_inner(self) -> T { + self.rcot + } +} + +impl DerandCOTReceiver +where + T: RCOTReceiver, +{ + /// Returns `true` if the receiver wants to adjust COTs. + pub fn wants_adjust(&self) -> bool { + self.pending > 0 + } + + /// Adjusts the COTs. + pub fn adjust( + &mut self, + ) -> Result<(Derandomize, ReceiveAdjust<'_, T>), DerandCOTReceiverError> { + let mut flip = self.derandomize.clone(); + let mut cots = Vec::new(); + let mut i = 0; + for QueuedReceive { count, .. } in self.queue.iter() { + let count = *count; + if self.rcot.available() < count { + break; + } + + let RCOTReceiverOutput { + id, + choices: masks, + msgs, + } = self + .rcot + .try_recv_rcot(count) + .map_err(DerandCOTReceiverError::new)?; + + // Mask choice bits. + flip[i..i + count] + .iter_mut() + .zip(masks) + .for_each(|(mut choice, mask)| *choice ^= mask); + + cots.push(COTReceiverOutput { id, msgs }); + i += count; + } + + Ok(( + Derandomize { flip }, + ReceiveAdjust { + recv: self, + cots, + count: i, + }, + )) + } +} + +/// Receiver returned by [`DerandCOTReceiver::adjust`]. +#[must_use] +pub struct ReceiveAdjust<'a, T> { + recv: &'a mut DerandCOTReceiver, + count: usize, + cots: Vec>, +} + +impl ReceiveAdjust<'_, T> +where + T: RCOTReceiver, +{ + /// Receives the adjusted COTs. + pub fn receive(self, adjust: Adjust) -> Result<(), DerandCOTReceiverError> { + let Adjust { adjust } = adjust; + + if adjust.len() != self.count { + return Err(DerandCOTReceiverError::new(format!( + "adjust is wrong length: {} != {}", + adjust.len(), + self.count + ))); + } + + let mut adjust = adjust.into_iter(); + let n = self.cots.len(); + for (mut output, QueuedReceive { sender, .. }) in + self.cots.into_iter().zip(self.recv.queue.drain(..n)) + { + output + .msgs + .iter_mut() + .zip(adjust.by_ref()) + .for_each(|(msg, adjust)| *msg ^= adjust); + + sender.send(output); + } + + self.recv.pending -= self.count; + + Ok(()) + } +} + +impl COTReceiver for DerandCOTReceiver +where + T: RCOTReceiver, +{ + type Error = DerandCOTReceiverError; + type Future = MaybeDone>; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.rcot.alloc(count).map_err(DerandCOTReceiverError::new) + } + + fn available(&self) -> usize { + self.rcot.available() + } + + fn queue_recv_cot(&mut self, choices: &[bool]) -> Result { + let count = choices.len(); + let (sender, recv) = new_output(); + + self.derandomize.extend(choices.iter().copied()); + self.pending += count; + self.queue.push_back(QueuedReceive { count, sender }); + + Ok(recv) + } +} + +/// Error for [`DerandCOTReceiver`]. +#[derive(Debug, thiserror::Error)] +#[error("derandomized COT receiver error: {source}")] +pub struct DerandCOTReceiverError { + source: Box, +} + +impl DerandCOTReceiverError { + fn new(err: E) -> Self + where + E: Into>, + { + Self { source: err.into() } + } +} + +#[cfg(test)] +mod tests { + use mpz_common::future::Output; + use rand::{rngs::StdRng, Rng, SeedableRng}; + + use crate::{ideal::rcot::IdealRCOT, test::assert_cot}; + + use super::*; + + #[test] + fn test_derandomize_cot() { + let mut rng = StdRng::seed_from_u64(0); + let delta = Block::random(&mut rng); + let rcot = IdealRCOT::new(rng.gen(), delta); + + let mut sender = DerandCOTSender::new(rcot.clone()); + let mut receiver = DerandCOTReceiver::new(rcot); + + let count = 10; + sender.alloc(count).unwrap(); + receiver.alloc(count).unwrap(); + + sender.rcot_mut().flush().unwrap(); + + let choices = (0..count).map(|_| rng.gen()).collect::>(); + let keys: Vec<_> = (0..count).map(|_| Block::random(&mut rng)).collect(); + + let mut sender_output = sender.queue_send_cot(&keys).unwrap(); + let mut receiver_output = receiver.queue_recv_cot(&choices).unwrap(); + + assert!(sender.wants_adjust()); + assert!(receiver.wants_adjust()); + + let (derandomize, recv) = receiver.adjust().unwrap(); + let adjust = sender.adjust(derandomize).unwrap(); + recv.receive(adjust).unwrap(); + + let COTSenderOutput { id: sender_id } = sender_output.try_recv().unwrap().unwrap(); + let COTReceiverOutput { + id: receiver_id, + msgs, + } = receiver_output.try_recv().unwrap().unwrap(); + + assert_eq!(sender_id, receiver_id); + assert_cot(delta, &choices, &keys, &msgs); + } +} diff --git a/crates/mpz-ot-core/src/ferret/cuckoo.rs b/crates/mpz-ot-core/src/ferret/cuckoo.rs deleted file mode 100644 index 938e2f03..00000000 --- a/crates/mpz-ot-core/src/ferret/cuckoo.rs +++ /dev/null @@ -1,195 +0,0 @@ -//! Implementation of Cuckoo hash. - -use std::sync::Arc; - -use mpz_core::{aes::AesEncryptor, Block}; - -use super::{CUCKOO_HASH_NUM, CUCKOO_TRIAL_NUM}; - -/// Cuckoo hash insertion error -#[derive(Debug, thiserror::Error)] -#[error("insertion loops")] -pub struct CuckooHashError; - -/// Errors that can occur when handling Buckets. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum BucketError { - #[error("invalid bucket state: expected {0}")] - NotInBucket(String), - #[error("invalid bucket index: expected {0}")] - OutOfRange(String), -} - -/// Item in Cuckoo hash table. -#[derive(Copy, Clone, Debug, PartialEq)] -pub struct Item { - /// Value in the table. - pub(crate) value: u32, - /// The hash index during the insertion. - pub(crate) hash_index: usize, -} - -/// Implementation of Cuckoo hash. See [here](https://eprint.iacr.org/2019/1084.pdf) for reference. -pub struct CuckooHash { - hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, -} - -impl CuckooHash { - /// Creates a new instance. - #[inline] - pub fn new(hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>) -> Self { - Self { hashes } - } - - /// Insert elements into a Cuckoo hash table. - /// - /// * Argument - /// - /// * `alphas` - A u32 vector being inserted. - #[inline] - pub fn insert(&self, alphas: &[u32]) -> Result>, CuckooHashError> { - // Always sets m = 1.5 * t. t is the length of `alphas`. - let m = compute_table_length(alphas.len() as u32); - - // Allocates table. - let mut table = vec![None; m]; - // Inserts each alpha. - for &value in alphas { - self.hash(&mut table, value)? - } - Ok(table) - } - - // Hash an element to a position with the current hash function. - #[inline] - fn hash(&self, table: &mut [Option], value: u32) -> Result<(), CuckooHashError> { - // The item consists of the value and hash index, starting from 0. - let mut item = Item { - value, - hash_index: 0, - }; - - for _ in 0..CUCKOO_TRIAL_NUM { - // Computes the position of the value. - let pos = hash_to_index(&self.hashes[item.hash_index], table.len(), item.value); - - // Inserts the value to position `pos`. - let opt_item = table[pos].replace(item); - - // If position `pos` is not empty before the above insertion, iteratively inserts the obtained value. - if let Some(x) = opt_item { - item = x; - item.hash_index = (item.hash_index + 1) % CUCKOO_HASH_NUM; - } else { - // If no value assigned to position `pos`, end the process. - return Ok(()); - } - } - Err(CuckooHashError) - } -} - -/// Implementation of Bucket. See step 3 in Figure 7 -pub struct Bucket { - // The hash functions. - hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, - // The number of buckets. - m: usize, -} - -impl Bucket { - /// Creates a new instance. - #[inline] - pub fn new(hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, m: usize) -> Self { - Self { hashes, m } - } - - /// Inserts the input vector [0..n-1] into buckets. - /// - /// # Argument - /// - /// * `n` - The length of the vector [0..n-1]. - #[inline] - pub fn insert(&self, n: u32) -> Vec> { - let mut buckets = vec![Vec::default(); self.m]; - // NOTE: the sorted step in Step 3.c can be removed. - for i in 0..n { - for (index, hash) in self.hashes.iter().enumerate() { - let pos = hash_to_index(hash, self.m, i); - buckets[pos].push(Item { - value: i, - hash_index: index, - }); - } - } - buckets - } -} - -// Always sets m = 1.5 * t. t is the length of `alphas`. See Section 7.1 Parameter Selection. -#[inline(always)] -pub(crate) fn compute_table_length(t: u32) -> usize { - (1.5 * (t as f32)).ceil() as usize -} - -// Hash the value into index using AES. -#[inline(always)] -pub(crate) fn hash_to_index(hash: &AesEncryptor, range: usize, value: u32) -> usize { - let mut blk: Block = bytemuck::cast::<_, Block>(value as u128); - blk = hash.encrypt_block(blk); - let res = u128::from_le_bytes(blk.to_bytes()); - (res as usize) % range -} - -// Finds the position of the item in each Bucket. -#[inline(always)] -pub(crate) fn find_pos(bucket: &[Item], item: &Item) -> Result { - let pos = bucket.iter().position(|&x| *item == x); - pos.ok_or(BucketError::NotInBucket("not in the bucket".to_string())) -} - -#[cfg(test)] -mod tests { - use crate::ferret::cuckoo::find_pos; - use std::sync::Arc; - - use super::{Bucket, CuckooHash}; - use mpz_core::{aes::AesEncryptor, prg::Prg}; - - #[test] - fn cockoo_hash_bucket_test() { - let mut prg = Prg::new(); - const NUM: usize = 50; - let hashes = Arc::new(std::array::from_fn(|_| { - AesEncryptor::new(prg.random_block()) - })); - let cuckoo = CuckooHash::new(hashes.clone()); - let input: [u32; NUM] = std::array::from_fn(|i| i as u32); - - let table = cuckoo.insert(&input).unwrap(); - - let bucket = Bucket::new(hashes, table.len()); - let buckets = bucket.insert((2 * NUM) as u32); - - assert!(table - .iter() - .zip(buckets.iter()) - .all(|(value, bin)| match value { - Some(x) => bin.contains(x), - None => true, - })); - - let _: Vec = table - .iter() - .zip(buckets.iter()) - .map(|(value, bin)| { - if let Some(x) = value { - find_pos(bin, x).unwrap() - } else { - bin.len() + 1 - } - }) - .collect(); - } -} diff --git a/crates/mpz-ot-core/src/ferret/error.rs b/crates/mpz-ot-core/src/ferret/error.rs deleted file mode 100644 index d209b77d..00000000 --- a/crates/mpz-ot-core/src/ferret/error.rs +++ /dev/null @@ -1,11 +0,0 @@ -//! Errors that can occur when using the Ferret protocol. - -/// Errors that can occur when using the Ferret sender. -#[derive(Debug, thiserror::Error)] -#[error("invalid input: expected {0}")] -pub struct SenderError(pub String); - -/// Errors that can occur when using the Ferret receiver. -#[derive(Debug, thiserror::Error)] -#[error("invalid input: expected {0}")] -pub struct ReceiverError(pub String); diff --git a/crates/mpz-ot-core/src/ferret/mod.rs b/crates/mpz-ot-core/src/ferret/mod.rs deleted file mode 100644 index 3ad7701e..00000000 --- a/crates/mpz-ot-core/src/ferret/mod.rs +++ /dev/null @@ -1,142 +0,0 @@ -//! An implementation of the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) protocol. - -use mpz_core::lpn::LpnParameters; - -pub mod cuckoo; -pub mod error; -pub mod mpcot; -pub mod msgs; -pub mod receiver; -pub mod sender; -pub mod spcot; - -/// Computational security parameter -pub const CSP: usize = 128; - -/// Number of hashes in Cuckoo hash. -pub const CUCKOO_HASH_NUM: usize = 3; - -/// Trial numbers in Cuckoo hash insertion. -pub const CUCKOO_TRIAL_NUM: usize = 100; - -/// LPN parameters with regular noise. -/// Derived from https://github.com/emp-toolkit/emp-ot/blob/master/emp-ot/ferret/constants.h -pub const LPN_PARAMETERS_REGULAR: LpnParameters = LpnParameters { - n: 10180608, - k: 124000, - t: 4971, -}; - -/// LPN parameters with uniform noise. -/// Derived from Table 2. -pub const LPN_PARAMETERS_UNIFORM: LpnParameters = LpnParameters { - n: 10616092, - k: 588160, - t: 1324, -}; - -/// The type of Lpn parameters. -#[derive(Debug)] -pub enum LpnType { - /// Uniform error distribution. - Uniform, - /// Regular error distribution. - Regular, -} - -#[cfg(test)] -mod tests { - use super::*; - - use msgs::LpnMatrixSeed; - use receiver::Receiver; - use sender::Sender; - - use crate::ideal::{cot::IdealCOT, mpcot::IdealMpcot}; - use crate::test::assert_cot; - use crate::{MPCOTReceiverOutput, MPCOTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput}; - use mpz_core::{lpn::LpnParameters, prg::Prg}; - use rand::SeedableRng; - - const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { - n: 9600, - k: 1220, - t: 600, - }; - - #[test] - fn ferret_test() { - let mut prg = Prg::from_seed([1u8; 16].into()); - let delta = prg.random_block(); - let mut ideal_cot = IdealCOT::default(); - let mut ideal_mpcot = IdealMpcot::default(); - - ideal_cot.set_delta(delta); - ideal_mpcot.set_delta(delta); - - let sender = Sender::new(); - let receiver = Receiver::new(); - - // Invoke Ideal COT to init the Ferret setup phase. - let (sender_cot, receiver_cot) = ideal_cot.random_correlated(LPN_PARAMETERS_TEST.k); - - let RCOTSenderOutput { msgs: v, .. } = sender_cot; - let RCOTReceiverOutput { - choices: u, - msgs: w, - .. - } = receiver_cot; - - // receiver generates the random seed of lpn matrix. - let lpn_matrix_seed = prg.random_block(); - - // init the setup of sender and receiver. - let (mut receiver, seed) = receiver - .setup( - LPN_PARAMETERS_TEST, - LpnType::Regular, - lpn_matrix_seed, - &u, - &w, - ) - .unwrap(); - - let LpnMatrixSeed { - seed: lpn_matrix_seed, - } = seed; - - let mut sender = sender - .setup( - delta, - LPN_PARAMETERS_TEST, - LpnType::Regular, - lpn_matrix_seed, - &v, - ) - .unwrap(); - - // extend once - let _ = sender.get_mpcot_query(); - let query = receiver.get_mpcot_query(); - - let (MPCOTSenderOutput { s, .. }, MPCOTReceiverOutput { r, .. }) = - ideal_mpcot.extend(&query.0, query.1); - - let msgs = sender.extend(&s).unwrap(); - let (choices, received) = receiver.extend(&r).unwrap(); - - assert_cot(delta, &choices, &msgs, &received); - - // extend twice - let _ = sender.get_mpcot_query(); - let query = receiver.get_mpcot_query(); - - let (MPCOTSenderOutput { s, .. }, MPCOTReceiverOutput { r, .. }) = - ideal_mpcot.extend(&query.0, query.1); - - let msgs = sender.extend(&s).unwrap(); - let (choices, received) = receiver.extend(&r).unwrap(); - - assert_cot(delta, &choices, &msgs, &received); - } -} diff --git a/crates/mpz-ot-core/src/ferret/mpcot/error.rs b/crates/mpz-ot-core/src/ferret/mpcot/error.rs deleted file mode 100644 index 8a0d9e24..00000000 --- a/crates/mpz-ot-core/src/ferret/mpcot/error.rs +++ /dev/null @@ -1,24 +0,0 @@ -//! Errors that can occur when using the MPCOT protocol. - -use crate::ferret::cuckoo::{BucketError, CuckooHashError}; -/// Errors that can occur when using the MPCOT sender. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum SenderError { - #[error("invalid input: expected {0}")] - InvalidInput(String), - #[error(transparent)] - BucketError(#[from] BucketError), -} - -/// Errors that can occur when using the MPCOT receiver. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum ReceiverError { - #[error("invalid input: expected {0}")] - InvalidInput(String), - #[error(transparent)] - CuckooHashError(#[from] CuckooHashError), - #[error(transparent)] - BucketError(#[from] BucketError), -} diff --git a/crates/mpz-ot-core/src/ferret/mpcot/mod.rs b/crates/mpz-ot-core/src/ferret/mpcot/mod.rs deleted file mode 100644 index e74dc38a..00000000 --- a/crates/mpz-ot-core/src/ferret/mpcot/mod.rs +++ /dev/null @@ -1,169 +0,0 @@ -//! Implementation of the Multiple-Point COT (mpcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. - -pub mod error; -pub mod msgs; -pub mod receiver; -pub mod receiver_regular; -pub mod sender; -pub mod sender_regular; - -#[cfg(test)] -mod tests { - use super::{ - receiver::Receiver as MpcotReceiver, receiver_regular::Receiver as RegularReceiver, - sender::Sender as MpcotSender, sender_regular::Sender as RegularSender, - }; - use crate::ideal::spcot::IdealSpcot; - use crate::{SPCOTReceiverOutput, SPCOTSenderOutput}; - use mpz_core::prg::Prg; - use rand::SeedableRng; - - #[test] - fn mpcot_general_test() { - let mut prg = Prg::from_seed([1u8; 16].into()); - let delta = prg.random_block(); - let mut ideal_spcot = IdealSpcot::new_with_delta(delta); - - let sender = MpcotSender::new(); - let receiver = MpcotReceiver::new(); - - // receiver chooses hash and setup. - let hash_seed = prg.random_block(); - let (receiver_pre, hash_seed) = receiver.setup(hash_seed); - // sender receives the hash and setup. - let sender_pre = sender.setup(delta, hash_seed); - - // extend once. - let alphas = [0, 1, 3, 4, 2]; - let t = alphas.len(); - let n = 10; - // sender generates the messages to invoke ideal spcot. - let (sender, sender_queries) = sender_pre.pre_extend(t as u32, n).unwrap(); - - let (receiver, mut queries) = receiver_pre.pre_extend(&alphas, n).unwrap(); - - assert!(sender_queries - .iter() - .zip(queries.iter()) - .all(|(x, (y, _))| *x == *y)); - - queries.iter_mut().for_each(|(x, _)| *x = 1 << (*x)); - - let (sender_spcot_msg, receiver_spcot_msg) = ideal_spcot.extend(&queries); - - let SPCOTSenderOutput { v: st, .. } = sender_spcot_msg; - let SPCOTReceiverOutput { w: rt, .. } = receiver_spcot_msg; - - let (sender_pre, mut output_sender) = sender.extend(&st).unwrap(); - let (receiver_pre, output_receiver) = receiver.extend(&rt).unwrap(); - - for i in alphas { - output_sender[i as usize] ^= delta; - } - - assert_eq!(output_sender, output_receiver); - - // extend twice. - let alphas = [5, 1, 7, 2]; - let t = alphas.len(); - let n = 16; - // sender generates the messages to invoke ideal spcot. - let (sender, sender_queries) = sender_pre.pre_extend(t as u32, n).unwrap(); - - let (receiver, mut queries) = receiver_pre.pre_extend(&alphas, n).unwrap(); - - assert!(sender_queries - .iter() - .zip(queries.iter()) - .all(|(x, (y, _))| *x == *y)); - - queries.iter_mut().for_each(|(x, _)| *x = 1 << (*x)); - - let (sender_spcot_msg, receiver_spcot_msg) = ideal_spcot.extend(&queries); - - let SPCOTSenderOutput { v: st, .. } = sender_spcot_msg; - let SPCOTReceiverOutput { w: rt, .. } = receiver_spcot_msg; - - let (_, mut output_sender) = sender.extend(&st).unwrap(); - let (_, output_receiver) = receiver.extend(&rt).unwrap(); - - for i in alphas { - output_sender[i as usize] ^= delta; - } - - assert_eq!(output_sender, output_receiver); - } - - #[test] - fn mpcot_regular_test() { - let mut prg = Prg::from_seed([2u8; 16].into()); - let delta = prg.random_block(); - let mut ideal_spcot = IdealSpcot::new_with_delta(delta); - - let sender = RegularSender::new(); - let receiver = RegularReceiver::new(); - - let sender_pre = sender.setup(delta); - let receiver_pre = receiver.setup(); - - // extend once. - let alphas = [0, 3, 4, 7, 9]; - let t = alphas.len(); - let n = 10; - - // sender generates the messages to invoke ideal spcot. - let (sender, sender_queries) = sender_pre.pre_extend(t as u32, n).unwrap(); - let (receiver, mut queries) = receiver_pre.pre_extend(&alphas, n).unwrap(); - - assert!(sender_queries - .iter() - .zip(queries.iter()) - .all(|(x, (y, _))| *x == *y)); - - queries.iter_mut().for_each(|(x, _)| *x = 1 << (*x)); - - let (sender_spcot_msg, receiver_spcot_msg) = ideal_spcot.extend(&queries); - - let SPCOTSenderOutput { v: st, .. } = sender_spcot_msg; - let SPCOTReceiverOutput { w: rt, .. } = receiver_spcot_msg; - - let (sender_pre, mut output_sender) = sender.extend(&st).unwrap(); - let (receiver_pre, output_receiver) = receiver.extend(&rt).unwrap(); - - for i in alphas { - output_sender[i as usize] ^= delta; - } - - assert_eq!(output_sender, output_receiver); - - // extend twice. - let alphas = [0, 3, 7, 9, 14, 15]; - let t = alphas.len(); - let n = 16; - - // sender generates the messages to invoke ideal spcot. - let (sender, sender_queries) = sender_pre.pre_extend(t as u32, n).unwrap(); - let (receiver, mut queries) = receiver_pre.pre_extend(&alphas, n).unwrap(); - - assert!(sender_queries - .iter() - .zip(queries.iter()) - .all(|(x, (y, _))| *x == *y)); - - queries.iter_mut().for_each(|(x, _)| *x = 1 << (*x)); - - let (sender_spcot_msg, receiver_spcot_msg) = ideal_spcot.extend(&queries); - - let SPCOTSenderOutput { v: st, .. } = sender_spcot_msg; - let SPCOTReceiverOutput { w: rt, .. } = receiver_spcot_msg; - - let (_, mut output_sender) = sender.extend(&st).unwrap(); - let (_, output_receiver) = receiver.extend(&rt).unwrap(); - - for i in alphas { - output_sender[i as usize] ^= delta; - } - - assert_eq!(output_sender, output_receiver); - } -} diff --git a/crates/mpz-ot-core/src/ferret/mpcot/msgs.rs b/crates/mpz-ot-core/src/ferret/mpcot/msgs.rs deleted file mode 100644 index 632a56c0..00000000 --- a/crates/mpz-ot-core/src/ferret/mpcot/msgs.rs +++ /dev/null @@ -1,19 +0,0 @@ -//! Messages for the MPCOT protocol. - -use mpz_core::Block; -use serde::{Deserialize, Serialize}; - -/// An MPCOT message. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[allow(missing_docs)] -pub enum Message { - SpcotMsg(SpcotMsg), - HashSeed(HashSeed), -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -/// The seed to generate Cuckoo hashes. -pub struct HashSeed { - /// The seed. - pub seed: Block, -} diff --git a/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs b/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs deleted file mode 100644 index 0f8613af..00000000 --- a/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs +++ /dev/null @@ -1,234 +0,0 @@ -//! MPCOT receiver for general indices. -use std::sync::Arc; - -use crate::ferret::{ - cuckoo::{find_pos, hash_to_index, Bucket, CuckooHash, Item}, - mpcot::error::ReceiverError, - CUCKOO_HASH_NUM, -}; -use mpz_core::{aes::AesEncryptor, prg::Prg, Block}; -use rand_core::SeedableRng; - -use super::msgs::HashSeed; - -/// MPCOT receiver. -#[derive(Debug, Default)] -pub struct Receiver { - state: T, -} - -impl Receiver { - /// Creates a new Receiver. - pub fn new() -> Self { - Receiver { - state: state::Initialized::default(), - } - } - - /// Completes the setup phase for PreExtend. - /// - /// See step 1 in Figure 6. - /// - /// # Argument - /// - /// * `hash_seed` - Random seed to generate hashes, will be sent to the sender. - pub fn setup(self, hash_seed: Block) -> (Receiver, HashSeed) { - let mut prg = Prg::from_seed(hash_seed); - let hashes = std::array::from_fn(|_| AesEncryptor::new(prg.random_block())); - let recv = Receiver { - state: state::PreExtension { - counter: 0, - hashes: Arc::new(hashes), - }, - }; - - let seed = HashSeed { seed: hash_seed }; - - (recv, seed) - } -} - -impl Receiver { - /// Performs the hash procedure in MPCOT extension. - /// Outputs the length of each bucket plus 1. - /// - /// See Step 1 to Step 4 in Figure 7. - /// - /// # Arguments - /// - /// * `alphas` - The queried indices. - /// * `n` - The total number of indices. - #[allow(clippy::type_complexity)] - pub fn pre_extend( - self, - alphas: &[u32], - n: u32, - ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { - if alphas.len() as u32 > n { - return Err(ReceiverError::InvalidInput( - "length of alphas should not exceed n".to_string(), - )); - } - let cuckoo = CuckooHash::new(self.state.hashes.clone()); - - // Inserts all the alpha's. - let table = cuckoo.insert(alphas)?; - - let m = table.len(); - - let bucket = Bucket::new(self.state.hashes.clone(), m); - - // Generates the buckets. - let buckets = bucket.insert(n); - - // Generates queries for SPCOT. - // See Step 4 in Figure 7. - let mut p = vec![]; - let mut buckets_length = vec![]; - for (alpha, bin) in table.iter().zip(buckets.iter()) { - // pad to power of 2. - let power_of_two = (bin.len() + 1) - .checked_next_power_of_two() - .expect("bucket length should be less than usize::MAX / 2 - 1"); - - let power = power_of_two.ilog2() as usize; - - if let Some(x) = alpha { - let pos = find_pos(bin, x)?; - p.push((power, pos as u32)); - } else { - p.push((power, bin.len() as u32)); - } - - buckets_length.push(power_of_two); - } - - let receiver = Receiver { - state: state::Extension { - counter: self.state.counter, - m, - n, - hashes: self.state.hashes.clone(), - buckets, - buckets_length, - }, - }; - - Ok((receiver, p)) - } -} -impl Receiver { - /// Performs MPCOT extension. - /// - /// See Step 5 in Figure 7. - /// - /// # Arguments - /// - /// * `rt` - The vector received from SPCOT protocol on multiple queries. - pub fn extend( - self, - rt: &[Vec], - ) -> Result<(Receiver, Vec), ReceiverError> { - if rt.len() != self.state.m { - return Err(ReceiverError::InvalidInput( - "the length rt should be m".to_string(), - )); - } - - if rt - .iter() - .zip(self.state.buckets_length.iter()) - .any(|(s, b)| s.len() != *b) - { - return Err(ReceiverError::InvalidInput( - "the length of st[i] should be self.state.buckets_length".to_string(), - )); - } - - let mut res = vec![Block::ZERO; self.state.n as usize]; - - for (value, x) in res.iter_mut().enumerate() { - for tau in 0..CUCKOO_HASH_NUM { - // Computes the index of `value`. - let bucket_index = - hash_to_index(&self.state.hashes[tau], self.state.m, value as u32); - let pos = find_pos( - &self.state.buckets[bucket_index], - &Item { - value: value as u32, - hash_index: tau, - }, - )?; - - *x ^= rt[bucket_index][pos]; - } - } - - let receiver = Receiver { - state: state::PreExtension { - counter: self.state.counter + 1, - hashes: self.state.hashes, - }, - }; - - Ok((receiver, res)) - } -} -/// The receiver's state. -pub mod state { - use super::*; - - mod sealed { - pub trait Sealed {} - - impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} - impl Sealed for super::Extension {} - } - - /// The receiver's state. - pub trait State: sealed::Sealed {} - - /// The receiver's initial state. - #[derive(Default)] - pub struct Initialized {} - - impl State for Initialized {} - - opaque_debug::implement!(Initialized); - - /// The receiver's state before extending. - /// - /// In this state the receiver performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { - /// Current MPCOT counter - pub(super) counter: usize, - /// The hashes to generate Cuckoo hash table. - pub(super) hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, - } - - impl State for PreExtension {} - - opaque_debug::implement!(PreExtension); - /// The receiver's state of extension. - /// - /// In this state the receiver performs MPCOT extension (potentially multiple times). - pub struct Extension { - /// Current MPCOT counter - pub(super) counter: usize, - /// Current length of Cuckoo hash table, will possibly be changed in each extension. - pub(super) m: usize, - /// The total number of indices in the current extension. - pub(super) n: u32, - /// The hashes to generate Cuckoo hash table. - pub(super) hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, - /// The buckets contains all the hash values, will be cleared after each extension. - pub(super) buckets: Vec>, - /// The padded buckets length (power of 2). - pub(super) buckets_length: Vec, - } - - impl State for Extension {} - - opaque_debug::implement!(Extension); -} diff --git a/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs b/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs deleted file mode 100644 index 2b226108..00000000 --- a/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs +++ /dev/null @@ -1,192 +0,0 @@ -//! MPCOT receiver for regular indices. Regular indices means the indices are evenly distributed. - -use mpz_core::Block; - -use crate::ferret::mpcot::error::ReceiverError; - -/// MPCOT receiver. -#[derive(Debug, Default)] -pub struct Receiver { - state: T, -} - -impl Receiver { - /// Creates a new Receiver. - pub fn new() -> Self { - Receiver { - state: state::Initialized::default(), - } - } - - /// Completes the setup phase of the protocol. - pub fn setup(self) -> Receiver { - Receiver { - state: state::PreExtension { counter: 0 }, - } - } -} -impl Receiver { - /// Performs the prepare procedure in MPCOT extension. - /// Outputs the indices for SPCOT. - /// - /// # Arguments. - /// - /// * `alphas` - The queried indices. - /// * `n` - The total number of indices. - #[allow(clippy::type_complexity)] - pub fn pre_extend( - self, - alphas: &[u32], - n: u32, - ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { - let t = alphas.len() as u32; - if t > n { - return Err(ReceiverError::InvalidInput( - "the length of alpha should not exceed n".to_string(), - )); - } - - // The range of each interval. - let k = (n + t - 1) / t; - - let queries_length = if n % t == 0 { - vec![k as usize; t as usize] - } else { - let mut tmp = vec![k as usize; (t - 1) as usize]; - tmp.push((n % k) as usize); - if tmp.iter().sum::() != n as usize { - return Err(ReceiverError::InvalidInput( - "the input parameters (t,n) are not regular".to_string(), - )); - } else { - tmp - } - }; - - let mut queries_depth = Vec::with_capacity(queries_length.len()); - for len in queries_length.iter() { - // pad `len` to power of 2. - let power = len - .checked_next_power_of_two() - .expect("len should be less than usize::MAX / 2 - 1") - .ilog2() as usize; - - queries_depth.push(power); - } - - if !alphas - .iter() - .enumerate() - .all(|(i, &alpha)| (i as u32) * k <= alpha && alpha < ((i + 1) as u32) * k) - { - return Err(ReceiverError::InvalidInput( - "the input position is not regular".to_string(), - )); - } - - let res: Vec<(usize, u32)> = queries_depth - .iter() - .zip(alphas.iter()) - .map(|(&d, &alpha)| (d, alpha % k)) - .collect(); - - let receiver = Receiver { - state: state::Extension { - counter: self.state.counter, - n, - queries_length, - queries_depth, - }, - }; - - Ok((receiver, res)) - } -} - -impl Receiver { - /// Performs MPCOT extension. - /// - /// # Arguments. - /// - /// * `rt` - The vector received from SPCOT protocol on multiple queries. - pub fn extend( - self, - rt: &[Vec], - ) -> Result<(Receiver, Vec), ReceiverError> { - if rt - .iter() - .zip(self.state.queries_depth.iter()) - .any(|(blks, m)| blks.len() != 1 << m) - { - return Err(ReceiverError::InvalidInput( - "the length of rt[i] should be 2^self.state.queries_depth[i]".to_string(), - )); - } - - let mut res: Vec = Vec::with_capacity(self.state.n as usize); - - for (blks, pos) in rt.iter().zip(self.state.queries_length.iter()) { - res.extend(&blks[..*pos]); - } - - let receiver = Receiver { - state: state::PreExtension { - counter: self.state.counter + 1, - }, - }; - - Ok((receiver, res)) - } -} -/// The receiver's state. -pub mod state { - - mod sealed { - pub trait Sealed {} - - impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} - impl Sealed for super::Extension {} - } - - /// The receiver's state. - pub trait State: sealed::Sealed {} - - /// The receiver's initial state. - #[derive(Default)] - pub struct Initialized {} - - impl State for Initialized {} - - opaque_debug::implement!(Initialized); - /// The receiver's state before extending. - /// - /// In this state the receiver performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { - /// Current MPCOT counter - pub(super) counter: usize, - } - - impl State for PreExtension {} - - opaque_debug::implement!(PreExtension); - - /// The receiver's state after the setup phase. - /// - /// In this state the receiver performs MPCOT extension (potentially multiple times). - pub struct Extension { - /// Current MPCOT counter - #[allow(dead_code)] - pub(super) counter: usize, - /// The total number of indices in the current extension. - pub(super) n: u32, - /// Current queries length. - pub(super) queries_length: Vec, - /// The depth of queries. - pub(super) queries_depth: Vec, - } - - impl State for Extension {} - - opaque_debug::implement!(Extension); -} diff --git a/crates/mpz-ot-core/src/ferret/mpcot/sender.rs b/crates/mpz-ot-core/src/ferret/mpcot/sender.rs deleted file mode 100644 index f1e49105..00000000 --- a/crates/mpz-ot-core/src/ferret/mpcot/sender.rs +++ /dev/null @@ -1,223 +0,0 @@ -//! MPCOT sender for general indices. -use std::sync::Arc; - -use crate::ferret::{ - cuckoo::{compute_table_length, find_pos, hash_to_index, Bucket, Item}, - mpcot::error::SenderError, - CUCKOO_HASH_NUM, -}; -use mpz_core::{aes::AesEncryptor, prg::Prg, Block}; -use rand_core::SeedableRng; - -use super::msgs::HashSeed; - -/// MPCOT sender. -#[derive(Debug, Default)] -pub struct Sender { - state: T, -} - -impl Sender { - /// Creates a new Sender. - pub fn new() -> Self { - Sender { - state: state::Initialized::default(), - } - } - - /// Completes the setup phase for PreExtend. - /// - /// # Arguments. - /// - /// * `delta` - The sender's global secret. - /// * `hash_seed` - The seed for Cuckoo hash sent by the receiver. - pub fn setup(self, delta: Block, hash_seed: HashSeed) -> Sender { - let HashSeed { seed: hash_seed } = hash_seed; - let mut prg = Prg::from_seed(hash_seed); - let hashes = std::array::from_fn(|_| AesEncryptor::new(prg.random_block())); - Sender { - state: state::PreExtension { - delta, - counter: 0, - hashes: Arc::new(hashes), - }, - } - } -} - -impl Sender { - /// Performs the hash procedure in MPCOT extension. - /// Outputs the length of each bucket plus 1. - /// - /// See Step 1 to Step 4 in Figure 7. - /// - /// # Arguments - /// - /// * `t` - The number of queried indices. - /// * `n` - The total number of indices. - pub fn pre_extend( - self, - t: u32, - n: u32, - ) -> Result<(Sender, Vec), SenderError> { - if t > n { - return Err(SenderError::InvalidInput( - "t should not exceed n".to_string(), - )); - } - - // Compute m = 1.5 * t. - let m = compute_table_length(t); - - let bucket = Bucket::new(self.state.hashes.clone(), m); - - // Generates the buckets. - let buckets = bucket.insert(n); - - // First pad (length + 1) to a pow of 2, then computes `log(length + 1)` of each bucket. - let mut bs = vec![]; - let mut buckets_length = vec![]; - for bin in buckets.iter() { - let power_of_two = (bin.len() + 1) - .checked_next_power_of_two() - .expect("bucket length should be less than usize::MAX / 2 - 1"); - bs.push(power_of_two.ilog2() as usize); - buckets_length.push(power_of_two); - } - - let sender = Sender { - state: state::Extension { - delta: self.state.delta, - counter: self.state.counter, - m, - n, - hashes: self.state.hashes, - buckets, - buckets_length, - }, - }; - - Ok((sender, bs)) - } -} - -impl Sender { - /// Performs MPCOT extension. - /// - /// See Step 5 in Figure 7. - /// - /// # Arguments - /// - /// * `st` - The vector received from SPCOT protocol on multiple queries. - pub fn extend( - self, - st: &[Vec], - ) -> Result<(Sender, Vec), SenderError> { - if st.len() != self.state.m { - return Err(SenderError::InvalidInput( - "the length st should be m".to_string(), - )); - } - - if st - .iter() - .zip(self.state.buckets_length.iter()) - .any(|(s, b)| s.len() != *b) - { - return Err(SenderError::InvalidInput( - "the length of st[i] should be self.state.buckets_length[i]".to_string(), - )); - } - let mut res = vec![Block::ZERO; self.state.n as usize]; - for (value, x) in res.iter_mut().enumerate() { - for tau in 0..CUCKOO_HASH_NUM { - // Computes the index of `value`. - let bucket_index = - hash_to_index(&self.state.hashes[tau], self.state.m, value as u32); - let pos = find_pos( - &self.state.buckets[bucket_index], - &Item { - value: value as u32, - hash_index: tau, - }, - )?; - - *x ^= st[bucket_index][pos]; - } - } - - let sender = Sender { - state: state::PreExtension { - delta: self.state.delta, - counter: self.state.counter + 1, - hashes: self.state.hashes, - }, - }; - - Ok((sender, res)) - } -} - -/// The sender's state. -pub mod state { - use super::*; - - mod sealed { - pub trait Sealed {} - - impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} - impl Sealed for super::Extension {} - } - - /// The sender's state. - pub trait State: sealed::Sealed {} - - /// The sender's initial state. - #[derive(Default)] - pub struct Initialized {} - - impl State for Initialized {} - - opaque_debug::implement!(Initialized); - - /// The sender's state before extending. - /// - /// In this state the sender performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { - /// Sender's global secret. - pub(super) delta: Block, - /// Current MPCOT counter - pub(super) counter: usize, - /// The hashes to generate Cuckoo hash table. - pub(super) hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, - } - - impl State for PreExtension {} - opaque_debug::implement!(PreExtension); - - /// The sender's state of extension. - /// - /// In this state the sender performs MPCOT extension (potentially multiple times). - pub struct Extension { - /// Sender's global secret. - pub(super) delta: Block, - /// Current MPCOT counter - pub(super) counter: usize, - - /// Current length of Cuckoo hash table, will possibly be changed in each extension. - pub(super) m: usize, - /// The total number of indices in the current extension. - pub(super) n: u32, - /// The hashes to generate Cuckoo hash table. - pub(super) hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, - /// The buckets contains all the hash values. - pub(super) buckets: Vec>, - /// The padded buckets length (power of 2). - pub(super) buckets_length: Vec, - } - - impl State for Extension {} - - opaque_debug::implement!(Extension); -} diff --git a/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs b/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs deleted file mode 100644 index db0646b6..00000000 --- a/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs +++ /dev/null @@ -1,185 +0,0 @@ -//! MPCOT sender for regular indices. Regular indices means the indices are evenly distributed. - -use mpz_core::Block; - -use crate::ferret::mpcot::error::SenderError; - -/// MPCOT sender. -#[derive(Debug, Default)] -pub struct Sender { - state: T, -} - -impl Sender { - /// Creates a new Sender. - pub fn new() -> Self { - Sender { - state: state::Initialized::default(), - } - } - - /// Completes the setup phase of the protocol. - /// - /// # Argument. - /// - /// * `delta` - The sender's global secret. - pub fn setup(self, delta: Block) -> Sender { - Sender { - state: state::PreExtension { delta, counter: 0 }, - } - } -} - -impl Sender { - /// Performs the prepare procedure in MPCOT extension. - /// Outputs the information for SPCOT. - /// - /// # Arguments. - /// - /// * `t` - The number of queried indices. - /// * `n` - The total number of indices. - pub fn pre_extend( - self, - t: u32, - n: u32, - ) -> Result<(Sender, Vec), SenderError> { - if t > n { - return Err(SenderError::InvalidInput( - "t should not exceed n".to_string(), - )); - } - - // The range of each interval. - let k = (n + t - 1) / t; - - let queries_length = if n % t == 0 { - vec![k as usize; t as usize] - } else { - let mut tmp = vec![k as usize; (t - 1) as usize]; - tmp.push((n % k) as usize); - if tmp.iter().sum::() != n as usize { - return Err(SenderError::InvalidInput( - "the input parameters (t,n) are not regular".to_string(), - )); - } else { - tmp - } - }; - - let mut queries_depth = Vec::with_capacity(queries_length.len()); - - for len in queries_length.iter() { - // pad `len`` to power of 2. - let power = len - .checked_next_power_of_two() - .expect("len should be less than usize::MAX / 2 - 1") - .ilog2() as usize; - queries_depth.push(power); - } - - let sender = Sender { - state: state::Extension { - delta: self.state.delta, - counter: self.state.counter, - n, - queries_length, - queries_depth: queries_depth.clone(), - }, - }; - - Ok((sender, queries_depth)) - } -} - -impl Sender { - /// Performs MPCOT extension. - /// - /// # Arguments. - /// - /// * `st` - The vector received from SPCOT protocol on multiple queries. - pub fn extend( - self, - st: &[Vec], - ) -> Result<(Sender, Vec), SenderError> { - if st - .iter() - .zip(self.state.queries_depth.iter()) - .any(|(blks, m)| blks.len() != 1 << m) - { - return Err(SenderError::InvalidInput( - "the length of st[i] should be 2^self.state.queries_depth[i]".to_string(), - )); - } - let mut res: Vec = Vec::with_capacity(self.state.n as usize); - - for (blks, pos) in st.iter().zip(self.state.queries_length.iter()) { - res.extend(&blks[..*pos]); - } - - let sender = Sender { - state: state::PreExtension { - delta: self.state.delta, - counter: self.state.counter + 1, - }, - }; - - Ok((sender, res)) - } -} -/// The sender's state. -pub mod state { - - use super::*; - - mod sealed { - pub trait Sealed {} - - impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} - impl Sealed for super::Extension {} - } - - /// The sender's state. - pub trait State: sealed::Sealed {} - - /// The sender's initial state. - #[derive(Default)] - pub struct Initialized {} - - impl State for Initialized {} - - opaque_debug::implement!(Initialized); - - /// The sender's state before extending. - /// - /// In this state the sender performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { - /// Sender's global secret. - pub(super) delta: Block, - /// Current MPCOT counter - pub(super) counter: usize, - } - - impl State for PreExtension {} - opaque_debug::implement!(PreExtension); - - /// The sender's state after the setup phase. - /// - /// In this state the sender performs MPCOT extension (potentially multiple times). - pub struct Extension { - /// Sender's global secret. - pub(super) delta: Block, - /// Current MPCOT counter - pub(super) counter: usize, - /// The total number of indices in the current extension. - pub(super) n: u32, - /// Current queries from sender, will possibly be changed in each extension. - pub(super) queries_length: Vec, - /// The depth of queries. - pub(super) queries_depth: Vec, - } - - impl State for Extension {} - - opaque_debug::implement!(Extension); -} diff --git a/crates/mpz-ot-core/src/ferret/msgs.rs b/crates/mpz-ot-core/src/ferret/msgs.rs deleted file mode 100644 index 4c4a426f..00000000 --- a/crates/mpz-ot-core/src/ferret/msgs.rs +++ /dev/null @@ -1,11 +0,0 @@ -//! Ferret protocol messages. - -use mpz_core::Block; -use serde::{Deserialize, Serialize}; - -/// The seed to generate Lpn matrix. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct LpnMatrixSeed { - /// The seed. - pub seed: Block, -} diff --git a/crates/mpz-ot-core/src/ferret/receiver.rs b/crates/mpz-ot-core/src/ferret/receiver.rs deleted file mode 100644 index 4d08c69b..00000000 --- a/crates/mpz-ot-core/src/ferret/receiver.rs +++ /dev/null @@ -1,184 +0,0 @@ -//! Ferret receiver -use mpz_core::{ - lpn::{LpnEncoder, LpnParameters}, - Block, -}; - -use crate::ferret::{error::ReceiverError, LpnType}; - -use super::msgs::LpnMatrixSeed; - -/// Ferret receiver. -#[derive(Debug, Default)] -pub struct Receiver { - state: T, -} - -impl Receiver { - /// Create a new Receiver. - pub fn new() -> Self { - Receiver { - state: state::Initialized::default(), - } - } - - /// Completes the setup phase of the protocol. - /// - /// See step 1 and 2 in Figure 9. - /// - /// # Arguments - /// - /// * `lpn_parameters` - The lpn parameters. - /// * `seed` - The seed to generate lpn matrix. - /// * `lpn_type` - The lpn type. - /// * `u` - The bits received from the COT ideal functionality. - /// * `w` - The vector received from the COT ideal functionality. - pub fn setup( - self, - lpn_parameters: LpnParameters, - lpn_type: LpnType, - seed: Block, - u: &[bool], - w: &[Block], - ) -> Result<(Receiver, LpnMatrixSeed), ReceiverError> { - if u.len() != lpn_parameters.k || w.len() != lpn_parameters.k { - return Err(ReceiverError( - "the length of u and w should be k".to_string(), - )); - } - - let lpn_encoder = LpnEncoder::<10>::new(seed, lpn_parameters.k as u32); - - Ok(( - Receiver { - state: state::Extension { - counter: 0, - lpn_parameters, - lpn_encoder, - lpn_type, - u: u.to_vec(), - w: w.to_vec(), - e: Vec::default(), - }, - }, - LpnMatrixSeed { seed }, - )) - } -} - -impl Receiver { - /// The prepare precedure of extension, sample error vectors and outputs information for MPCOT. - /// See step 3 and 4. - /// - /// # Arguments. - /// - /// * `lpn_type` - The type of LPN parameters. - pub fn get_mpcot_query(&mut self) -> (Vec, usize) { - match self.state.lpn_type { - LpnType::Uniform => { - self.state.e = self.state.lpn_parameters.sample_uniform_error_vector(); - } - - LpnType::Regular => { - self.state.e = self.state.lpn_parameters.sample_regular_error_vector(); - } - } - let mut alphas = Vec::with_capacity(self.state.lpn_parameters.t); - for (i, x) in self.state.e.iter().enumerate() { - if *x != Block::ZERO { - alphas.push(i as u32); - } - } - (alphas, self.state.lpn_parameters.n) - } - - /// Performs the Ferret extension. - /// Outputs exactly l = n - t COTs. - /// - /// See step 5 and 6. - /// - /// # Arguments. - /// - /// * `r` - The vector received from the MPCOT protocol. - pub fn extend(&mut self, r: &[Block]) -> Result<(Vec, Vec), ReceiverError> { - if r.len() != self.state.lpn_parameters.n { - return Err(ReceiverError("the length of r should be n".to_string())); - } - - // Compute z = A * w + r. - let mut z = r.to_vec(); - self.state.lpn_encoder.compute(&mut z, &self.state.w); - - // Compute x = A * u + e. - let u_block = self - .state - .u - .iter() - .map(|x| if *x { Block::ONE } else { Block::ZERO }) - .collect::>(); - let mut x = self.state.e.clone(); - self.state.lpn_encoder.compute(&mut x, &u_block); - - let mut x = x.iter().map(|a| a.lsb() == 1).collect::>(); - - let x_ = x.split_off(self.state.lpn_parameters.k); - let z_ = z.split_off(self.state.lpn_parameters.k); - - // Update u, w - self.state.u = x; - self.state.w = z; - - // Update counter - self.state.counter += 1; - - Ok((x_, z_)) - } -} - -/// The receiver's state. -pub mod state { - use super::*; - - mod sealed { - pub trait Sealed {} - impl Sealed for super::Initialized {} - impl Sealed for super::Extension {} - } - - /// The receiver's state. - pub trait State: sealed::Sealed {} - - /// The receiver's initial state. - #[derive(Default)] - pub struct Initialized {} - - impl State for Initialized {} - - opaque_debug::implement!(Initialized); - - /// The receiver's state after the setup phase. - /// - /// In this state the sender performs Ferret extension (potentially multiple times). - pub struct Extension { - /// Current Ferret counter. - pub(super) counter: usize, - - /// Lpn parameters. - pub(super) lpn_parameters: LpnParameters, - /// Lpn encoder. - pub(super) lpn_encoder: LpnEncoder<10>, - /// Lpn type. - pub(super) lpn_type: LpnType, - - /// Receiver's COT messages in the setup phase. - pub(super) u: Vec, - pub(super) w: Vec, - - /// Receiver's lpn error vector. - pub(super) e: Vec, - } - - impl State for Extension {} - - opaque_debug::implement!(Extension); -} diff --git a/crates/mpz-ot-core/src/ferret/sender.rs b/crates/mpz-ot-core/src/ferret/sender.rs deleted file mode 100644 index 9e8db180..00000000 --- a/crates/mpz-ot-core/src/ferret/sender.rs +++ /dev/null @@ -1,149 +0,0 @@ -//! Ferret sender. -use mpz_core::{ - lpn::{LpnEncoder, LpnParameters}, - Block, -}; - -use crate::ferret::{error::SenderError, LpnType}; - -/// Ferret sender. -#[derive(Debug, Default)] -pub struct Sender { - state: T, -} - -impl Sender { - /// Creates a new Sender. - pub fn new() -> Self { - Sender { - state: state::Initialized::default(), - } - } - - /// Completes the setup phase of the protocol. - /// - /// See step 1 and 2 in Figure 9. - /// - /// # Arguments - /// - /// * `delta` - The sender's global secret. - /// * `lpn_parameters` - The lpn parameters. - /// * `lpn_type` - The lpn type. - /// * `seed` - The seed received from receiver to generate lpn matrix. - /// * `v` - The vector received from the COT ideal functionality. - pub fn setup( - self, - delta: Block, - lpn_parameters: LpnParameters, - lpn_type: LpnType, - seed: Block, - v: &[Block], - ) -> Result, SenderError> { - if v.len() != lpn_parameters.k { - return Err(SenderError( - "the length of v should be equal to k".to_string(), - )); - } - let lpn_encoder = LpnEncoder::<10>::new(seed, lpn_parameters.k as u32); - - Ok(Sender { - state: state::Extension { - delta, - counter: 0, - lpn_parameters, - lpn_type, - lpn_encoder, - v: v.to_vec(), - }, - }) - } -} - -impl Sender { - /// Outputs the information for MPCOT. - /// - /// See step 3 and 4. - pub fn get_mpcot_query(&self) -> (u32, u32) { - ( - self.state.lpn_parameters.t as u32, - self.state.lpn_parameters.n as u32, - ) - } - - /// Performs the Ferret extension. - /// Outputs exactly l = n-t COTs. - /// - /// See step 5 and 6. - /// - /// # Arguments. - /// - /// * `s` - The vector received from the MPCOT protocol. - pub fn extend(&mut self, s: &[Block]) -> Result, SenderError> { - if s.len() != self.state.lpn_parameters.n { - return Err(SenderError("the length of s should be n".to_string())); - } - - // Compute y = A * v + s - let mut y = s.to_vec(); - self.state.lpn_encoder.compute(&mut y, &self.state.v); - - let y_ = y.split_off(self.state.lpn_parameters.k); - - // Update v to y[0..k] - self.state.v = y; - - // Update counter - self.state.counter += 1; - - Ok(y_) - } -} - -/// The sender's state. -pub mod state { - use super::*; - - mod sealed { - pub trait Sealed {} - - impl Sealed for super::Initialized {} - impl Sealed for super::Extension {} - } - - /// The sender's state. - pub trait State: sealed::Sealed {} - - /// The sender's initial state. - #[derive(Default)] - pub struct Initialized {} - - impl State for Initialized {} - - opaque_debug::implement!(Initialized); - - /// The sender's state after the setup phase. - /// - /// In this state the sender performs Ferret extension (potentially multiple times). - pub struct Extension { - /// Sender's global secret. - #[allow(dead_code)] - pub(super) delta: Block, - /// Current Ferret counter. - pub(super) counter: usize, - - /// Lpn type. - #[allow(dead_code)] - pub(super) lpn_type: LpnType, - /// Lpn parameters. - pub(super) lpn_parameters: LpnParameters, - /// Lpn encoder. - pub(super) lpn_encoder: LpnEncoder<10>, - - /// Sender's COT message in the setup phase. - pub(super) v: Vec, - } - - impl State for Extension {} - - opaque_debug::implement!(Extension); -} diff --git a/crates/mpz-ot-core/src/ferret/spcot/error.rs b/crates/mpz-ot-core/src/ferret/spcot/error.rs deleted file mode 100644 index bf94e2e2..00000000 --- a/crates/mpz-ot-core/src/ferret/spcot/error.rs +++ /dev/null @@ -1,25 +0,0 @@ -//! Errors that can occur when using the SPCOT. - -/// Errors that can occur when using the SPCOT sender. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum SenderError { - #[error("invalid state: expected {0}")] - InvalidState(String), - #[error("invalid length: expected {0}")] - InvalidLength(String), -} - -/// Errors that can occur when using the SPCOT receiver. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum ReceiverError { - #[error("invalid state: expected {0}")] - InvalidState(String), - #[error("invalid input: expected {0}")] - InvalidInput(String), - #[error("invalid length: expected {0}")] - InvalidLength(String), - #[error("consistency check failed")] - ConsistencyCheckFailed, -} diff --git a/crates/mpz-ot-core/src/ferret/spcot/mod.rs b/crates/mpz-ot-core/src/ferret/spcot/mod.rs deleted file mode 100644 index 802efb66..00000000 --- a/crates/mpz-ot-core/src/ferret/spcot/mod.rs +++ /dev/null @@ -1,90 +0,0 @@ -//! Implementation of the Single-Point COT (spcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. - -pub mod error; -pub mod msgs; -pub mod receiver; -pub mod sender; - -#[cfg(test)] -mod tests { - use mpz_core::prg::Prg; - - use super::{receiver::Receiver as SpcotReceiver, sender::Sender as SpcotSender}; - use crate::{ferret::CSP, ideal::cot::IdealCOT, RCOTReceiverOutput, RCOTSenderOutput}; - - #[test] - fn spcot_test() { - let mut ideal_cot = IdealCOT::default(); - let sender = SpcotSender::new(); - let receiver = SpcotReceiver::new(); - - let mut prg = Prg::new(); - let sender_seed = prg.random_block(); - let delta = ideal_cot.delta(); - - let mut sender = sender.setup(delta, sender_seed); - let mut receiver = receiver.setup(); - - let h1 = 8; - let alpha1 = 3; - - // Extend once - let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h1); - - let RCOTReceiverOutput { - choices: rs, - msgs: ts, - .. - } = msg_for_receiver; - let RCOTSenderOutput { msgs: qs, .. } = msg_for_sender; - let maskbits = receiver.extend_mask_bits(h1, alpha1, &rs).unwrap(); - - let msg_from_sender = sender.extend(h1, &qs, maskbits).unwrap(); - - receiver.extend(h1, alpha1, &ts, msg_from_sender).unwrap(); - - // Extend twice - let h2 = 4; - let alpha2 = 2; - - let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h2); - - let RCOTReceiverOutput { - choices: rs, - msgs: ts, - .. - } = msg_for_receiver; - let RCOTSenderOutput { msgs: qs, .. } = msg_for_sender; - - let maskbits = receiver.extend_mask_bits(h2, alpha2, &rs).unwrap(); - - let msg_from_sender = sender.extend(h2, &qs, maskbits).unwrap(); - - receiver.extend(h2, alpha2, &ts, msg_from_sender).unwrap(); - - // Check - let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(CSP); - - let RCOTReceiverOutput { - choices: x_star, - msgs: z_star, - .. - } = msg_for_receiver; - - let RCOTSenderOutput { msgs: y_star, .. } = msg_for_sender; - - let check_from_receiver = receiver.check_pre(&x_star).unwrap(); - - let (mut output_sender, check) = sender.check(&y_star, check_from_receiver).unwrap(); - - let output_receiver = receiver.check(&z_star, check).unwrap(); - - assert!(output_sender - .iter_mut() - .zip(output_receiver.iter()) - .all(|(vs, (ws, alpha))| { - vs[*alpha as usize] ^= delta; - vs == ws - })); - } -} diff --git a/crates/mpz-ot-core/src/ferret/spcot/msgs.rs b/crates/mpz-ot-core/src/ferret/spcot/msgs.rs deleted file mode 100644 index 22e88a4b..00000000 --- a/crates/mpz-ot-core/src/ferret/spcot/msgs.rs +++ /dev/null @@ -1,45 +0,0 @@ -//! Messages for the SPCOT protocol - -use mpz_core::{hash::Hash, Block}; -use serde::{Deserialize, Serialize}; - -/// An SPCOT message. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[allow(missing_docs)] -pub enum Message { - CotMsg(CotMsg), - MaskBits(MaskBits), - ExtendFromSender(ExtendFromSender), - CheckFromReceiver(CheckFromReceiver), - CheckFromSender(CheckFromSender), -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -/// The mask bits sent by the receiver. -pub struct MaskBits { - /// The mask bits sent by the receiver. - pub bs: Vec, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -/// The extend messages sent by the sender. -pub struct ExtendFromSender { - /// The mask `m0` and `m1`. - pub ms: Vec<[Block; 2]>, - /// The sum of the ggm tree leaves and delta. - pub sum: Block, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -/// The consistency check message sent from the receiver. -pub struct CheckFromReceiver { - /// The `x'` from the receiver. - pub x_prime: Vec, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -/// The consistency check message sent from the sender. -pub struct CheckFromSender { - /// The hashed `V` from the sender. - pub hashed_v: Hash, -} diff --git a/crates/mpz-ot-core/src/ferret/spcot/receiver.rs b/crates/mpz-ot-core/src/ferret/spcot/receiver.rs deleted file mode 100644 index 5e860f31..00000000 --- a/crates/mpz-ot-core/src/ferret/spcot/receiver.rs +++ /dev/null @@ -1,310 +0,0 @@ -//! SPCOT receiver -use crate::ferret::{spcot::error::ReceiverError, CSP}; -use itybity::ToBits; -use mpz_core::{ - aes::FIXED_KEY_AES, ggm_tree::GgmTree, hash::Hash, prg::Prg, serialize::CanonicalSerialize, - utils::blake3, Block, -}; -use rand_core::SeedableRng; - -use super::msgs::{CheckFromReceiver, CheckFromSender, ExtendFromSender, MaskBits}; - -/// SPCOT receiver. -#[derive(Debug, Default)] -pub struct Receiver { - state: T, -} - -impl Receiver { - /// Creates a new Receiver. - pub fn new() -> Self { - Receiver { - state: state::Initialized::default(), - } - } - - /// Completes the setup phase of the protocol. - /// - /// See step 1 in Figure 6. - /// - pub fn setup(self) -> Receiver { - Receiver { - state: state::Extension { - unchecked_ws: Vec::default(), - chis: Vec::default(), - alphas_and_length: Vec::default(), - cot_counter: 0, - exec_counter: 0, - extended: false, - hasher: blake3::Hasher::new(), - }, - } - } -} - -impl Receiver { - /// Performs the mask bit step in extension. - /// - /// See step 4 in Figure 6. - /// - /// # Arguments - /// - /// * `h` - The depth of the GGM tree. - /// * `alpha` - The chosen position. - /// * `rs` - The message from COT ideal functionality for the receiver. Only the random bits are used. - pub fn extend_mask_bits( - &mut self, - h: usize, - alpha: u32, - rs: &[bool], - ) -> Result { - if self.state.extended { - return Err(ReceiverError::InvalidState( - "extension is not allowed".to_string(), - )); - } - - if alpha >= (1 << h) { - return Err(ReceiverError::InvalidInput( - "the input pos should be no more than 2^h-1".to_string(), - )); - } - - if rs.len() != h { - return Err(ReceiverError::InvalidLength( - "the length of r should be h".to_string(), - )); - } - - // Step 4 in Figure 6 - - let bs: Vec = alpha - .iter_msb0() - .skip(32 - h) - // Computes alpha_i XOR r_i XOR 1. - .zip(rs.iter()) - .map(|(alpha, &r)| alpha == r) - .collect(); - - // Updates hasher. - self.state.hasher.update(&bs.to_bytes()); - - Ok(MaskBits { bs }) - } - - /// Performs the GGM reconstruction step in extension. This function can be called multiple times before checking. - /// - /// See step 5 in Figure 6. - /// - /// # Arguments - /// - /// * `h` - The depth of the GGM tree. - /// * `alpha` - The chosen position. - /// * `ts` - The message from COT ideal functionality for the receiver. Only the chosen blocks are used. - /// * `extendfs` - The message sent by the sender. - pub fn extend( - &mut self, - h: usize, - alpha: u32, - ts: &[Block], - extendfs: ExtendFromSender, - ) -> Result<(), ReceiverError> { - if self.state.extended { - return Err(ReceiverError::InvalidState( - "extension is not allowed".to_string(), - )); - } - - if alpha >= (1 << h) { - return Err(ReceiverError::InvalidInput( - "the input pos should be no more than 2^h-1".to_string(), - )); - } - - let ExtendFromSender { ms, sum } = extendfs; - if ts.len() != h { - return Err(ReceiverError::InvalidLength( - "the length of t should be h".to_string(), - )); - } - - if ms.len() != h { - return Err(ReceiverError::InvalidLength( - "the length of M should be h".to_string(), - )); - } - - // Updates hasher - self.state.hasher.update(&ms.to_bytes()); - self.state.hasher.update(&sum.to_bytes()); - - let alpha_bar_vec: Vec = alpha.iter_msb0().skip(32 - h).map(|a| !a).collect(); - - // Step 5 in Figure 6. - let k: Vec = ms - .into_iter() - .zip(ts) - .zip(alpha_bar_vec.iter()) - .enumerate() - .map(|(i, (([m0, m1], &t), &b))| { - let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); - if !b { - // H(t, i|ell) ^ M0 - FIXED_KEY_AES.tccr(tweak, t) ^ m0 - } else { - // H(t, i|ell) ^ M1 - FIXED_KEY_AES.tccr(tweak, t) ^ m1 - } - }) - .collect(); - - // Reconstructs GGM tree except `ws[alpha]`. - let ggm_tree = GgmTree::new(h); - let mut tree = vec![Block::ZERO; 1 << h]; - ggm_tree.reconstruct(&mut tree, &k, &alpha_bar_vec); - - // Sets `tree[alpha]`, which is `ws[alpha]`. - tree[alpha as usize] = tree.iter().fold(sum, |acc, &x| acc ^ x); - - self.state.unchecked_ws.extend_from_slice(&tree); - self.state.alphas_and_length.push((alpha, 1 << h)); - - self.state.exec_counter += 1; - - Ok(()) - } - - /// Performs the decomposition and bit-mask steps in check. - /// - /// See step 7 in Figure 6. - /// - /// # Arguments - /// - /// * `x_star` - The message from COT ideal functionality for the receiver. Only the random bits are used. - pub fn check_pre(&mut self, x_star: &[bool]) -> Result { - if x_star.len() != CSP { - return Err(ReceiverError::InvalidLength(format!( - "the length of x* should be {CSP}" - ))); - } - - let seed = *self.state.hasher.finalize().as_bytes(); - let mut prg = Prg::from_seed(Block::try_from(&seed[0..16]).unwrap()); - - // The sum of all the chi[alpha]. - let mut sum_chi_alpha = Block::ZERO; - - for (alpha, n) in &self.state.alphas_and_length { - let mut chis = vec![Block::ZERO; *n as usize]; - prg.random_blocks(&mut chis); - sum_chi_alpha ^= chis[*alpha as usize]; - self.state.chis.extend_from_slice(&chis); - } - - let x_prime: Vec = sum_chi_alpha - .iter_lsb0() - .zip(x_star) - .map(|(x, &x_star)| x != x_star) - .collect(); - - Ok(CheckFromReceiver { x_prime }) - } - - /// Performs the final step of the consistency check. - /// - /// See step 9 in Figure 6. - /// - /// # Arguments - /// - /// * `z_star` - The message from COT ideal functionality for the receiver. Only the chosen blocks are used. - /// * `check` - The hashed value sent by the Sender. - pub fn check( - &mut self, - z_star: &[Block], - check: CheckFromSender, - ) -> Result, u32)>, ReceiverError> { - let CheckFromSender { hashed_v } = check; - - if z_star.len() != CSP { - return Err(ReceiverError::InvalidLength(format!( - "the length of z* should be {CSP}" - ))); - } - - // Computes the base X^i - let base: Vec = (0..CSP).map(|x| bytemuck::cast((1_u128) << x)).collect(); - - // Computes Z. - let mut w = Block::inn_prdt_red(z_star, &base); - - // Computes W. - w ^= Block::inn_prdt_red(&self.state.chis, &self.state.unchecked_ws); - - // Computes H'(W) - let hashed_w = Hash::from(blake3(&w.to_bytes())); - - if hashed_v != hashed_w { - return Err(ReceiverError::ConsistencyCheckFailed); - } - - self.state.cot_counter += self.state.unchecked_ws.len(); - self.state.extended = true; - - let mut res = Vec::new(); - for (alpha, n) in &self.state.alphas_and_length { - let tmp: Vec = self.state.unchecked_ws.drain(..*n as usize).collect(); - res.push((tmp, *alpha)); - } - - Ok(res) - } -} - -/// The receiver's state. -pub mod state { - use super::*; - - mod sealed { - pub trait Sealed {} - - impl Sealed for super::Initialized {} - impl Sealed for super::Extension {} - } - - /// The receiver's state. - pub trait State: sealed::Sealed {} - - /// The receiver's initial state. - #[derive(Default)] - pub struct Initialized {} - - impl State for Initialized {} - - opaque_debug::implement!(Initialized); - - /// The receiver's state after the setup phase. - /// - /// In this state the receiver performs COT extension and outputs random choice bits (potentially multiple times). - pub struct Extension { - /// Receiver's output blocks. - pub(super) unchecked_ws: Vec, - /// Receiver's random challenges chis. - pub(super) chis: Vec, - /// Stores the alpha and the length in each extend phase. - pub(super) alphas_and_length: Vec<(u32, u32)>, - - /// Current COT counter - pub(super) cot_counter: usize, - /// Current execution counter - pub(super) exec_counter: usize, - /// This is to prevent the receiver from extending twice - pub(super) extended: bool, - - /// A hasher to generate chi seed from the protocol transcript. - pub(super) hasher: blake3::Hasher, - } - - impl State for Extension {} - - opaque_debug::implement!(Extension); -} diff --git a/crates/mpz-ot-core/src/ferret/spcot/sender.rs b/crates/mpz-ot-core/src/ferret/spcot/sender.rs deleted file mode 100644 index fef1327e..00000000 --- a/crates/mpz-ot-core/src/ferret/spcot/sender.rs +++ /dev/null @@ -1,251 +0,0 @@ -//! SPCOT sender. -use crate::ferret::{spcot::error::SenderError, CSP}; -use mpz_core::{ - aes::FIXED_KEY_AES, ggm_tree::GgmTree, hash::Hash, prg::Prg, serialize::CanonicalSerialize, - utils::blake3, Block, -}; -use rand_core::SeedableRng; - -use super::msgs::{CheckFromReceiver, CheckFromSender, ExtendFromSender, MaskBits}; - -/// SPCOT sender. -#[derive(Debug, Default)] -pub struct Sender { - state: T, -} - -impl Sender { - /// Creates a new Sender. - pub fn new() -> Self { - Sender { - state: state::Initialized::default(), - } - } - - /// Completes the setup phase of the protocol. - /// - /// See step 1 in Figure 6. - /// - /// # Arguments - /// - /// * `delta` - The sender's global secret. - /// * `seed` - The random seed to generate PRG. - pub fn setup(self, delta: Block, seed: Block) -> Sender { - Sender { - state: state::Extension { - delta, - unchecked_vs: Vec::default(), - vs_length: Vec::default(), - cot_counter: 0, - exec_counter: 0, - extended: false, - prg: Prg::from_seed(seed), - hasher: blake3::Hasher::new(), - }, - } - } -} - -impl Sender { - /// Performs the SPCOT extension. - /// - /// See Step 1-5 in Figure 6. - /// - /// # Arguments - /// - /// * `h` - The depth of the GGM tree. - /// * `qs`- The blocks received by calling the COT functionality. - /// * `mask`- The mask bits sent by the receiver. - pub fn extend( - &mut self, - h: usize, - qs: &[Block], - mask: MaskBits, - ) -> Result { - if self.state.extended { - return Err(SenderError::InvalidState( - "extension is not allowed".to_string(), - )); - } - - if qs.len() != h { - return Err(SenderError::InvalidLength( - "the length of q should be h".to_string(), - )); - } - - let MaskBits { bs } = mask; - - if bs.len() != h { - return Err(SenderError::InvalidLength( - "the length of b should be h".to_string(), - )); - } - - // Updates hasher. - self.state.hasher.update(&bs.to_bytes()); - - // Step 3-4, Figure 6. - - // Generates a GGM tree with depth h and seed s. - let s = self.state.prg.random_block(); - let ggm_tree = GgmTree::new(h); - let mut k0 = vec![Block::ZERO; h]; - let mut k1 = vec![Block::ZERO; h]; - let mut tree = vec![Block::ZERO; 1 << h]; - ggm_tree.gen(s, &mut tree, &mut k0, &mut k1); - - // Stores the tree, i.e., the possible output of sender. - self.state.unchecked_vs.extend_from_slice(&tree); - - // Stores the length of this extension. - self.state.vs_length.push(1 << h); - - // Computes the sum of the leaves and delta. - let sum = tree.iter().fold(self.state.delta, |acc, &x| acc ^ x); - - // Computes M0 and M1. - let mut ms: Vec<[Block; 2]> = Vec::with_capacity(qs.len()); - for (((i, &q), b), (k0, k1)) in qs.iter().enumerate().zip(bs).zip(k0.into_iter().zip(k1)) { - let mut m = if b { - [q ^ self.state.delta, q] - } else { - [q, q ^ self.state.delta] - }; - let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); - FIXED_KEY_AES.tccr_many(&[tweak, tweak], &mut m); - m[0] ^= k0; - m[1] ^= k1; - ms.push(m); - } - - // Updates hasher - self.state.hasher.update(&ms.to_bytes()); - self.state.hasher.update(&sum.to_bytes()); - - self.state.exec_counter += 1; - - Ok(ExtendFromSender { ms, sum }) - } - - /// Performs the consistency check for the resulting COTs. - /// - /// See Step 6-9 in Figure 6. - /// - /// # Arguments - /// - /// * `y_star` - The blocks received from the ideal functionality for the check. - /// * `checkfr` - The bits received from the receiver for the check. - pub fn check( - &mut self, - y_star: &[Block], - checkfr: CheckFromReceiver, - ) -> Result<(Vec>, CheckFromSender), SenderError> { - let CheckFromReceiver { x_prime } = checkfr; - - if y_star.len() != CSP { - return Err(SenderError::InvalidLength(format!( - "the length of y* should be {CSP}" - ))); - } - - if x_prime.len() != CSP { - return Err(SenderError::InvalidLength(format!( - "the length of x' should be {CSP}" - ))); - } - - // Step 8 in Figure 6. - - // Computes y = y_star + x' * Delta - let y: Vec = y_star - .iter() - .zip(x_prime.iter()) - .map(|(&y, &x)| if x { y ^ self.state.delta } else { y }) - .collect(); - - // Computes the base X^i - let base: Vec = (0..CSP).map(|x| bytemuck::cast((1_u128) << x)).collect(); - - // Computes Y - let mut v = Block::inn_prdt_red(&y, &base); - - // Computes V - let seed = *self.state.hasher.finalize().as_bytes(); - let mut prg = Prg::from_seed(Block::try_from(&seed[0..16]).unwrap()); - - let mut chis = Vec::new(); - for n in &self.state.vs_length { - let mut chi = vec![Block::ZERO; *n as usize]; - prg.random_blocks(&mut chi); - chis.extend_from_slice(&chi); - } - v ^= Block::inn_prdt_red(&chis, &self.state.unchecked_vs); - - // Computes H'(V) - let hashed_v = Hash::from(blake3(&v.to_bytes())); - - self.state.cot_counter += self.state.unchecked_vs.len(); - - let mut res = Vec::new(); - for n in &self.state.vs_length { - let tmp: Vec = self.state.unchecked_vs.drain(..*n as usize).collect(); - res.push(tmp); - } - - self.state.extended = true; - - Ok((res, CheckFromSender { hashed_v })) - } -} - -/// The sender's state. -pub mod state { - use super::*; - - mod sealed { - pub trait Sealed {} - - impl Sealed for super::Initialized {} - impl Sealed for super::Extension {} - } - - /// The sender's state. - pub trait State: sealed::Sealed {} - - /// The sender's initial state. - #[derive(Default)] - pub struct Initialized {} - - impl State for Initialized {} - - opaque_debug::implement!(Initialized); - - /// The sender's state after the setup phase. - /// - /// In this state the sender performs COT extension with random choice bits (potentially multiple times). Also in this state the sender responds to COT requests. - pub struct Extension { - /// Sender's global secret. - pub(super) delta: Block, - /// Sender's output blocks, support multiple extensions. - pub(super) unchecked_vs: Vec, - /// Store the length of each extension. - pub(super) vs_length: Vec, - - /// Current COT counter - pub(super) cot_counter: usize, - /// Current execution counter - pub(super) exec_counter: usize, - /// This is to prevent the receiver from extending twice - pub(super) extended: bool, - - /// A PRG to generate random strings. - pub(super) prg: Prg, - /// A hasher to generate chi seed. - pub(super) hasher: blake3::Hasher, - } - - impl State for Extension {} - - opaque_debug::implement!(Extension); -} diff --git a/crates/mpz-ot-core/src/ideal.rs b/crates/mpz-ot-core/src/ideal.rs new file mode 100644 index 00000000..aac1b71d --- /dev/null +++ b/crates/mpz-ot-core/src/ideal.rs @@ -0,0 +1,6 @@ +//! Ideal functionalities. + +pub mod cot; +pub mod ot; +pub mod rcot; +pub mod rot; diff --git a/crates/mpz-ot-core/src/ideal/cot.rs b/crates/mpz-ot-core/src/ideal/cot.rs index a28abef8..6e1b71b6 100644 --- a/crates/mpz-ot-core/src/ideal/cot.rs +++ b/crates/mpz-ot-core/src/ideal/cot.rs @@ -1,167 +1,262 @@ //! Ideal Correlated Oblivious Transfer functionality. -use mpz_core::{prg::Prg, Block}; -use rand::{Rng, SeedableRng}; -use rand_chacha::ChaCha8Rng; +use std::{ + mem, + sync::{Arc, Mutex}, +}; -use crate::TransferId; -use crate::{COTReceiverOutput, COTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput}; +use mpz_common::future::{new_output, MaybeDone, Output, Sender}; +use mpz_core::Block; -/// The ideal COT functionality. -#[derive(Debug)] +use crate::{ + cot::{COTReceiver, COTReceiverOutput, COTSender, COTSenderOutput}, + TransferId, +}; + +type Error = IdealCOTError; +type Result = core::result::Result; + +#[derive(Debug, Default)] +struct SenderState { + alloc: usize, + transfer_id: TransferId, + queue: Vec<(usize, Sender)>, +} + +#[derive(Debug, Default)] +struct ReceiverState { + alloc: usize, + transfer_id: TransferId, + queue: Vec<(usize, Sender>)>, +} + +/// Ideal COT functionality. +#[derive(Debug, Clone)] pub struct IdealCOT { + inner: Arc>, +} + +#[derive(Debug)] +struct Inner { delta: Block, - transfer_id: TransferId, - counter: usize, - prg: Prg, + + sender_state: SenderState, + receiver_state: ReceiverState, + + keys: Vec, + choices: Vec, } impl IdealCOT { - /// Creates a new ideal OT functionality. + /// Creates a new ideal COT functionality. /// /// # Arguments /// - /// * `seed` - The seed for the PRG. - /// * `delta` - The correlation. - pub fn new(seed: Block, delta: Block) -> Self { + /// * `delta` - Global correlation key. + pub fn new(delta: Block) -> Self { IdealCOT { - delta, - transfer_id: TransferId::default(), - counter: 0, - prg: Prg::from_seed(seed), + inner: Arc::new(Mutex::new(Inner { + delta, + sender_state: SenderState::default(), + receiver_state: ReceiverState::default(), + keys: Vec::new(), + choices: Vec::new(), + })), } } - /// Returns the correlation, delta. - pub fn delta(&self) -> Block { - self.delta + /// Transfers correlated OTs. + pub fn transfer( + &mut self, + choices: &[bool], + keys: &[Block], + ) -> Result<(COTSenderOutput, COTReceiverOutput)> { + if choices.len() != keys.len() { + return Err(Error::new(format!( + "choices and keys length mismatch: {} != {}", + choices.len(), + keys.len() + ))); + } + + let mut sender_output = self.queue_send_cot(keys)?; + let mut receiver_output = self.queue_recv_cot(choices)?; + + self.flush()?; + + Ok(( + sender_output.try_recv().unwrap().unwrap(), + receiver_output.try_recv().unwrap().unwrap(), + )) } - /// Sets the correlation, delta. - pub fn set_delta(&mut self, delta: Block) { - self.delta = delta; + /// Returns `true` if the functionality wants to be flushed. + pub fn wants_flush(&self) -> bool { + let this = self.inner.lock().unwrap(); + let sender_queue = this.sender_state.queue.len(); + let receiver_queue = this.receiver_state.queue.len(); + + sender_queue > 0 && receiver_queue > 0 && sender_queue == receiver_queue } - /// Returns the current transfer id. - pub fn transfer_id(&self) -> TransferId { - self.transfer_id + /// Flushes the functionality. + pub fn flush(&mut self) -> Result<()> { + let mut this = self.inner.lock().unwrap(); + if this.sender_state.alloc != this.receiver_state.alloc { + return Err(Error::new(format!( + "sender and receiver alloc out of sync: {} != {}", + this.sender_state.alloc, this.receiver_state.alloc + ))); + } else if this.keys.len() != this.choices.len() { + return Err(Error::new(format!( + "keys and choices length mismatch: {} != {}", + this.keys.len(), + this.choices.len() + ))); + } + + this.sender_state.alloc = 0; + this.receiver_state.alloc = 0; + + let keys = mem::take(&mut this.keys); + let choices = mem::take(&mut this.choices); + let sender_queue = mem::take(&mut this.sender_state.queue); + let receiver_queue = mem::take(&mut this.receiver_state.queue); + + let delta = this.delta; + let mut msgs = keys.into_iter().zip(choices).map( + move |(key, choice)| { + if choice { + key ^ delta + } else { + key + } + }, + ); + + for ((sender_count, sender_output), (receiver_count, receiver_output)) in + sender_queue.into_iter().zip(receiver_queue.into_iter()) + { + let sender_id = this.sender_state.transfer_id.next(); + let receiver_id = this.receiver_state.transfer_id.next(); + + if sender_count != receiver_count { + return Err(Error::new(format!("number of messages and choices do not match ({sender_id}): {sender_count} != {receiver_count}"))); + } + + sender_output.send(COTSenderOutput { id: sender_id }); + receiver_output.send(COTReceiverOutput { + id: receiver_id, + msgs: msgs.by_ref().take(receiver_count).collect(), + }); + } + + Ok(()) } +} + +impl COTSender for IdealCOT { + type Error = Error; + type Future = MaybeDone; - /// Returns the number of OTs executed. - pub fn count(&self) -> usize { - self.counter + fn alloc(&mut self, count: usize) -> Result<()> { + let mut this = self.inner.lock().unwrap(); + this.sender_state.alloc += count; + Ok(()) } - /// Executes random correlated oblivious transfers. - /// - /// The functionality deals random choices to the receiver, along with the corresponding messages. - /// - /// # Arguments - /// - /// * `count` - The number of COTs to execute. - pub fn random_correlated( - &mut self, - count: usize, - ) -> (RCOTSenderOutput, RCOTReceiverOutput) { - let mut msgs = vec![Block::ZERO; count]; - let mut choices = vec![false; count]; - - self.prg.random_blocks(&mut msgs); - self.prg.random_bools(&mut choices); - - let chosen: Vec = msgs - .iter() - .zip(choices.iter()) - .map(|(&q, &r)| if r { q ^ self.delta } else { q }) - .collect(); - - self.counter += count; - let id = self.transfer_id.next(); - - ( - RCOTSenderOutput { id, msgs }, - RCOTReceiverOutput { - id, - choices, - msgs: chosen, - }, - ) + fn available(&self) -> usize { + let this = self.inner.lock().unwrap(); + this.keys.len() } - /// Executes correlated oblivious transfers with choices provided by the receiver. - /// - /// # Arguments - /// - /// * `choices` - The choices made by the receiver. - pub fn correlated( + fn delta(&self) -> Block { + let this = self.inner.lock().unwrap(); + this.delta + } + + fn queue_send_cot( &mut self, - choices: Vec, - ) -> (COTSenderOutput, COTReceiverOutput) { - let (sender_output, mut receiver_output) = self.random_correlated(choices.len()); - - receiver_output - .msgs - .iter_mut() - .zip(choices.iter().zip(receiver_output.choices)) - .for_each(|(msg, (&actual_choice, random_choice))| { - if actual_choice ^ random_choice { - *msg ^= self.delta - } - }); + keys: &[Block], + ) -> Result, Self::Error> { + let mut this = self.inner.lock().unwrap(); - ( - COTSenderOutput { - id: sender_output.id, - msgs: sender_output.msgs, - }, - COTReceiverOutput { - id: receiver_output.id, - msgs: receiver_output.msgs, - }, - ) + this.keys.extend_from_slice(keys); + + let (send, recv) = new_output(); + + this.sender_state.queue.push((keys.len(), send)); + + Ok(recv) } } -impl Default for IdealCOT { - fn default() -> Self { - let mut rng = ChaCha8Rng::seed_from_u64(0); - Self::new(rng.gen(), rng.gen()) +impl COTReceiver for IdealCOT { + type Error = Error; + type Future = MaybeDone>; + + fn alloc(&mut self, count: usize) -> Result<()> { + let mut this = self.inner.lock().unwrap(); + this.receiver_state.alloc += count; + Ok(()) + } + + fn available(&self) -> usize { + let this = self.inner.lock().unwrap(); + this.choices.len() + } + + fn queue_recv_cot(&mut self, choices: &[bool]) -> Result>> { + let mut this = self.inner.lock().unwrap(); + + this.choices.extend_from_slice(choices); + + let (send, recv) = new_output(); + + this.receiver_state.queue.push((choices.len(), send)); + + Ok(recv) + } +} + +/// Error for [`IdealCOT`]. +#[derive(Debug, thiserror::Error)] +#[error("ideal COT error: {0}")] +pub struct IdealCOTError(String); + +impl IdealCOTError { + fn new(msg: impl Into) -> Self { + Self(msg.into()) } } #[cfg(test)] mod tests { + use rand::{rngs::StdRng, Rng, SeedableRng}; + use super::*; use crate::test::assert_cot; #[test] - fn test_ideal_rcot() { - let mut ideal = IdealCOT::default(); + fn test_ideal_cot() { + let mut rng = StdRng::seed_from_u64(0); + let delta = Block::random(&mut rng); + let mut ideal = IdealCOT::new(delta); + + let count = 128; + let choices = (0..count).map(|_| rng.gen()).collect::>(); + let keys = (0..count).map(|_| rng.gen()).collect::>(); let ( - RCOTSenderOutput { msgs, .. }, - RCOTReceiverOutput { - choices, + COTSenderOutput { id: sender_id }, + COTReceiverOutput { + id: receiver_id, msgs: received, - .. }, - ) = ideal.random_correlated(100); - - assert_cot(ideal.delta(), &choices, &msgs, &received) - } - - #[test] - fn test_ideal_cot() { - let mut ideal = IdealCOT::default(); - - let mut rng = ChaCha8Rng::seed_from_u64(0); - let mut choices = vec![false; 100]; - rng.fill(&mut choices[..]); - - let (COTSenderOutput { msgs, .. }, COTReceiverOutput { msgs: received, .. }) = - ideal.correlated(choices.clone()); + ) = ideal.transfer(&choices, &keys).unwrap(); - assert_cot(ideal.delta(), &choices, &msgs, &received) + assert_eq!(sender_id, receiver_id); + assert_cot(delta, &choices, &keys, &received) } } diff --git a/crates/mpz-ot-core/src/ideal/mod.rs b/crates/mpz-ot-core/src/ideal/mod.rs deleted file mode 100644 index 8e1bcb61..00000000 --- a/crates/mpz-ot-core/src/ideal/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -//! Define ideal functionalities of OTs. - -pub mod cot; -pub mod mpcot; -pub mod ot; -pub mod rot; -pub mod spcot; diff --git a/crates/mpz-ot-core/src/ideal/mpcot.rs b/crates/mpz-ot-core/src/ideal/mpcot.rs deleted file mode 100644 index 44a5595f..00000000 --- a/crates/mpz-ot-core/src/ideal/mpcot.rs +++ /dev/null @@ -1,97 +0,0 @@ -//! Ideal functionality for the multi-point correlated OT. - -use mpz_core::{prg::Prg, Block}; -use rand::{Rng, SeedableRng}; -use rand_chacha::ChaCha8Rng; - -use crate::{MPCOTReceiverOutput, MPCOTSenderOutput, TransferId}; - -/// The ideal MPCOT functionality. -#[derive(Debug)] -pub struct IdealMpcot { - delta: Block, - transfer_id: TransferId, - counter: usize, - prg: Prg, -} - -impl IdealMpcot { - /// Creates a new ideal MPCOT functionality. - pub fn new(seed: Block, delta: Block) -> Self { - IdealMpcot { - delta, - transfer_id: TransferId::default(), - counter: 0, - prg: Prg::from_seed(seed), - } - } - - /// Returns the correlation, delta. - pub fn delta(&self) -> Block { - self.delta - } - - /// Sets the correlation, delta. - pub fn set_delta(&mut self, delta: Block) { - self.delta = delta; - } - - /// Performs the extension of MPCOT. - /// - /// # Argument - /// - /// * `alphas` - The positions in each extension. - /// * `n` - The length of the vector. - pub fn extend( - &mut self, - alphas: &[u32], - n: usize, - ) -> (MPCOTSenderOutput, MPCOTReceiverOutput) { - assert!(alphas.len() < n); - let mut s = vec![Block::ZERO; n]; - let mut r = vec![Block::ZERO; n]; - self.prg.random_blocks(&mut s); - r.copy_from_slice(&s); - - for alpha in alphas { - assert!((*alpha as usize) < n); - r[*alpha as usize] ^= self.delta; - - self.counter += 1; - } - - let id = self.transfer_id.next(); - - (MPCOTSenderOutput { id, s }, MPCOTReceiverOutput { id, r }) - } -} - -impl Default for IdealMpcot { - fn default() -> Self { - let mut rng = ChaCha8Rng::seed_from_u64(0); - IdealMpcot::new(rng.gen(), rng.gen()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn ideal_mpcot_test() { - let mut ideal = IdealMpcot::default(); - - let alphas = [1, 3, 4, 6]; - let n = 20; - - let (MPCOTSenderOutput { mut s, .. }, MPCOTReceiverOutput { r, .. }) = - ideal.extend(&alphas, n); - - for alpha in alphas { - assert!((alpha as usize) < n); - s[alpha as usize] ^= ideal.delta(); - } - - assert!(s.iter_mut().zip(r.iter()).all(|(s, r)| *s == *r)); - } -} diff --git a/crates/mpz-ot-core/src/ideal/ot.rs b/crates/mpz-ot-core/src/ideal/ot.rs index e389066e..6ab3d411 100644 --- a/crates/mpz-ot-core/src/ideal/ot.rs +++ b/crates/mpz-ot-core/src/ideal/ot.rs @@ -1,39 +1,98 @@ //! Ideal Chosen-Message Oblivious Transfer functionality. -use crate::{OTReceiverOutput, OTSenderOutput, TransferId}; +use std::{ + mem, + sync::{Arc, Mutex}, +}; + +use mpz_common::future::{new_output, MaybeDone, Output, Sender}; +use mpz_core::Block; + +use crate::{ + ot::{OTReceiver, OTReceiverOutput, OTSender, OTSenderOutput}, + TransferId, +}; -/// The ideal OT functionality. #[derive(Debug, Default)] -pub struct IdealOT { +struct SenderState { + transfer_id: TransferId, + queue: Vec<(usize, Sender)>, +} + +#[derive(Debug, Default)] +struct ReceiverState { transfer_id: TransferId, - counter: usize, - /// Log of choices made by the receiver. + queue: Vec<(usize, Sender>)>, +} + +/// The ideal OT functionality. +#[derive(Debug, Default, Clone)] +pub struct IdealOT { + inner: Arc>, +} + +#[derive(Debug, Default)] +struct Inner { + sender_state: SenderState, + receiver_state: ReceiverState, + + msgs: Vec<[Block; 2]>, choices: Vec, } impl IdealOT { /// Creates a new ideal OT functionality. pub fn new() -> Self { - IdealOT { - transfer_id: TransferId::default(), - counter: 0, - choices: Vec::new(), - } + Self::default() } - /// Returns the current transfer id. - pub fn transfer_id(&self) -> TransferId { - self.transfer_id - } + /// Returns `true` if the functionality wants to be flushed. + pub fn wants_flush(&self) -> bool { + let this = self.inner.lock().unwrap(); + let sender_queue = this.sender_state.queue.len(); + let receiver_queue = this.receiver_state.queue.len(); - /// Returns the number of OTs executed. - pub fn count(&self) -> usize { - self.counter + sender_queue > 0 && receiver_queue > 0 && sender_queue == receiver_queue } - /// Returns the choices made by the receiver. - pub fn choices(&self) -> &[bool] { - &self.choices + /// Flushes the functionality. + pub fn flush(&mut self) -> Result<(), IdealOTError> { + let mut this = self.inner.lock().unwrap(); + + if this.msgs.len() != this.choices.len() { + return Err(IdealOTError::new( + "number of messages and choices do not match", + )); + } + + let sender_queue = mem::take(&mut this.sender_state.queue); + let receiver_queue = mem::take(&mut this.receiver_state.queue); + let msgs = mem::take(&mut this.msgs); + let choices = mem::take(&mut this.choices); + + let mut msgs = msgs + .into_iter() + .zip(choices) + .map(|([zero, one], choice)| if choice { one } else { zero }); + + for ((sender_count, sender_output), (receiver_count, receiver_output)) in + sender_queue.into_iter().zip(receiver_queue) + { + let sender_id = this.sender_state.transfer_id.next(); + let receiver_id = this.receiver_state.transfer_id.next(); + + if sender_count != receiver_count { + return Err(IdealOTError::new(format!("number of messages and choices do not match ({sender_id}): {sender_count} != {receiver_count}"))); + } + + sender_output.send(OTSenderOutput { id: sender_id }); + receiver_output.send(OTReceiverOutput { + id: receiver_id, + msgs: msgs.by_ref().take(sender_count).collect(), + }); + } + + Ok(()) } /// Executes chosen-message oblivious transfers. @@ -42,52 +101,94 @@ impl IdealOT { /// /// * `choices` - The choices made by the receiver. /// * `msgs` - The sender's messages. - pub fn chosen( + pub fn transfer( &mut self, - choices: Vec, - msgs: Vec<[T; 2]>, - ) -> (OTSenderOutput, OTReceiverOutput) { - let chosen = choices - .iter() - .zip(msgs.iter()) - .map(|(&choice, [zero, one])| if choice { *one } else { *zero }) - .collect(); - - self.counter += choices.len(); - self.choices.extend(choices); - let id = self.transfer_id.next(); - - (OTSenderOutput { id }, OTReceiverOutput { id, msgs: chosen }) + choices: &[bool], + msgs: &[[Block; 2]], + ) -> Result<(OTSenderOutput, OTReceiverOutput), IdealOTError> { + let mut sender_output = self.queue_send_ot(msgs)?; + let mut receiver_output = self.queue_recv_ot(choices)?; + + self.flush()?; + + Ok(( + sender_output.try_recv().unwrap().unwrap(), + receiver_output.try_recv().unwrap().unwrap(), + )) + } +} + +impl OTSender for IdealOT { + type Error = IdealOTError; + type Future = MaybeDone; + + fn alloc(&mut self, _count: usize) -> Result<(), Self::Error> { + Ok(()) + } + + fn queue_send_ot(&mut self, msgs: &[[Block; 2]]) -> Result { + let mut this = self.inner.lock().unwrap(); + this.msgs.extend_from_slice(msgs); + + let (sender, recv) = new_output(); + + this.sender_state.queue.push((msgs.len(), sender)); + + Ok(recv) + } +} + +impl OTReceiver for IdealOT { + type Error = IdealOTError; + type Future = MaybeDone>; + + fn alloc(&mut self, _count: usize) -> Result<(), Self::Error> { + Ok(()) + } + + fn queue_recv_ot(&mut self, choices: &[bool]) -> Result { + let mut this = self.inner.lock().unwrap(); + this.choices.extend_from_slice(choices); + + let (sender, recv) = new_output(); + + this.receiver_state.queue.push((choices.len(), sender)); + + Ok(recv) + } +} + +/// Error for [`IdealOT`]. +#[derive(Debug, thiserror::Error)] +#[error("Ideal OT error: {0}")] +pub struct IdealOTError(String); + +impl IdealOTError { + fn new(msg: impl Into) -> Self { + Self(msg.into()) } } #[cfg(test)] mod tests { use mpz_core::Block; - use rand::{Rng, SeedableRng}; - use rand_chacha::ChaCha8Rng; + use rand::{rngs::StdRng, Rng, SeedableRng}; + + use crate::test::assert_ot; use super::*; #[test] fn test_ideal_ot() { - let mut rng = ChaCha8Rng::seed_from_u64(0); + let mut rng = StdRng::seed_from_u64(0); let mut choices = vec![false; 100]; rng.fill(&mut choices[..]); let msgs: Vec<[Block; 2]> = (0..100).map(|_| [rng.gen(), rng.gen()]).collect(); let (OTSenderOutput { .. }, OTReceiverOutput { msgs: chosen, .. }) = - IdealOT::default().chosen(choices.clone(), msgs.clone()); - - assert!(choices.into_iter().zip(msgs.into_iter().zip(chosen)).all( - |(choice, (msg, chosen))| { - if choice { - chosen == msg[1] - } else { - chosen == msg[0] - } - } - )); + IdealOT::default().transfer(&choices, &msgs).unwrap(); + + assert_ot(&choices, &msgs, &chosen); } } diff --git a/crates/mpz-ot-core/src/ideal/rcot.rs b/crates/mpz-ot-core/src/ideal/rcot.rs new file mode 100644 index 00000000..73ed520c --- /dev/null +++ b/crates/mpz-ot-core/src/ideal/rcot.rs @@ -0,0 +1,311 @@ +//! Ideal Random Correlated Oblivious Transfer functionality. + +use std::{ + mem, + sync::{Arc, Mutex}, +}; + +use mpz_common::future::{new_output, MaybeDone, Sender}; +use mpz_core::{prg::Prg, Block}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +use crate::{ + rcot::{RCOTReceiver, RCOTReceiverOutput, RCOTSender, RCOTSenderOutput}, + TransferId, +}; + +type Error = IdealRCOTError; +type Result = core::result::Result; + +#[derive(Debug, Default)] +struct SenderState { + alloc: usize, + transfer_id: TransferId, + queue: Vec<(usize, Sender>)>, +} + +#[derive(Debug, Default)] +struct ReceiverState { + alloc: usize, + transfer_id: TransferId, + queue: Vec<(usize, Sender>)>, +} + +/// Ideal RCOT functionality. +#[derive(Debug, Clone)] +pub struct IdealRCOT { + inner: Arc>, +} + +#[derive(Debug)] +struct Inner { + delta: Block, + prg: Prg, + + sender_state: SenderState, + receiver_state: ReceiverState, + + keys: Vec, + msgs: Vec, + choices: Vec, +} + +impl IdealRCOT { + /// Creates a new ideal RCOT functionality. + /// + /// # Arguments + /// + /// * `seed` - Seed for the PRG. + /// * `delta` - Global correlation key. + pub fn new(seed: Block, delta: Block) -> Self { + IdealRCOT { + inner: Arc::new(Mutex::new(Inner { + delta, + prg: Prg::from_seed(seed), + sender_state: SenderState::default(), + receiver_state: ReceiverState::default(), + keys: Vec::new(), + msgs: Vec::new(), + choices: Vec::new(), + })), + } + } + + /// Allocates `count` random correlated OTs. + pub fn alloc(&mut self, count: usize) { + let mut this = self.inner.lock().unwrap(); + this.sender_state.alloc += count; + this.receiver_state.alloc += count; + } + + /// Transfers `count` random correlated OTs. + pub fn transfer( + &mut self, + count: usize, + ) -> Result<(RCOTSenderOutput, RCOTReceiverOutput)> { + Ok((self.try_send_rcot(count)?, self.try_recv_rcot(count)?)) + } + + /// Returns `true` if the functionality wants to be flushed. + pub fn wants_flush(&self) -> bool { + let this = self.inner.lock().unwrap(); + let sender_queue = this.sender_state.queue.len(); + let receiver_queue = this.receiver_state.queue.len(); + + sender_queue > 0 && receiver_queue > 0 && sender_queue == receiver_queue + } + + /// Flushes pending operations. + pub fn flush(&mut self) -> Result<()> { + let mut this = self.inner.lock().unwrap(); + if this.sender_state.alloc != this.receiver_state.alloc { + return Err(Error::new(format!( + "sender and receiver alloc out of sync: {} != {}", + this.sender_state.alloc, this.receiver_state.alloc + ))); + } + + let count = this.sender_state.alloc; + + let keys = (0..count).map(|_| this.prg.gen()).collect::>(); + let choices = (0..count).map(|_| this.prg.gen()).collect::>(); + let msgs = keys + .iter() + .zip(&choices) + .map(|(key, choice)| if *choice { *key ^ this.delta } else { *key }) + .collect::>(); + + this.keys.extend_from_slice(&keys); + this.choices.extend_from_slice(&choices); + this.msgs.extend_from_slice(&msgs); + + this.sender_state.alloc = 0; + this.receiver_state.alloc = 0; + + let mut i = 0; + for (count, sender) in mem::take(&mut this.sender_state.queue) { + let keys = this.keys[i..i + count].to_vec(); + i += count; + sender.send(RCOTSenderOutput { + id: this.sender_state.transfer_id, + keys, + }); + } + this.keys.drain(..i); + + i = 0; + for (count, sender) in mem::take(&mut this.receiver_state.queue) { + let choices = this.choices[i..i + count].to_vec(); + let keys = this.msgs[i..i + count].to_vec(); + i += count; + sender.send(RCOTReceiverOutput { + id: this.receiver_state.transfer_id, + choices, + msgs: keys, + }); + } + this.choices.drain(..i); + this.msgs.drain(..i); + + Ok(()) + } +} + +impl RCOTSender for IdealRCOT { + type Error = Error; + type Future = MaybeDone>; + + fn alloc(&mut self, count: usize) -> Result<()> { + let mut this = self.inner.lock().unwrap(); + this.sender_state.alloc += count; + Ok(()) + } + + fn available(&self) -> usize { + let this = self.inner.lock().unwrap(); + this.keys.len() + } + + fn delta(&self) -> Block { + let this = self.inner.lock().unwrap(); + this.delta + } + + fn try_send_rcot(&mut self, count: usize) -> Result> { + let mut this = self.inner.lock().unwrap(); + if count > this.keys.len() { + return Err(Error::new(format!( + "not enough sender RCOTs available: {} < {}", + this.keys.len(), + count + ))); + } + + let id = this.sender_state.transfer_id.next(); + let keys = this.keys.drain(..count).collect(); + + Ok(RCOTSenderOutput { id, keys }) + } + + fn queue_send_rcot( + &mut self, + count: usize, + ) -> Result>, Self::Error> { + let mut this = self.inner.lock().unwrap(); + let (send, recv) = new_output(); + + let available = this.keys.len(); + if available >= count { + let id = this.sender_state.transfer_id.next(); + let keys = this.keys.drain(..count).collect(); + + send.send(RCOTSenderOutput { id, keys }); + } else { + this.sender_state.queue.push((count, send)); + } + + Ok(recv) + } +} + +impl RCOTReceiver for IdealRCOT { + type Error = Error; + type Future = MaybeDone>; + + fn alloc(&mut self, count: usize) -> Result<()> { + let mut this = self.inner.lock().unwrap(); + this.receiver_state.alloc += count; + Ok(()) + } + + fn available(&self) -> usize { + let this = self.inner.lock().unwrap(); + this.choices.len() + } + + fn try_recv_rcot(&mut self, count: usize) -> Result> { + let mut this = self.inner.lock().unwrap(); + if count > this.choices.len() { + return Err(Error::new(format!( + "not enough receiver RCOTs available: {} < {}", + this.choices.len(), + count + ))); + } + + let choices = this.choices.drain(..count).collect(); + let msgs = this.msgs.drain(..count).collect(); + + Ok(RCOTReceiverOutput { + id: this.receiver_state.transfer_id.next(), + choices, + msgs, + }) + } + + fn queue_recv_rcot( + &mut self, + count: usize, + ) -> Result>> { + let mut this = self.inner.lock().unwrap(); + let (send, recv) = new_output(); + + let available = this.choices.len(); + if available >= count { + let id = this.receiver_state.transfer_id.next(); + let choices = this.choices.drain(..count).collect(); + let msgs = this.msgs.drain(..count).collect(); + + send.send(RCOTReceiverOutput { id, choices, msgs }); + } else { + this.receiver_state.queue.push((count, send)); + } + + Ok(recv) + } +} + +impl Default for IdealRCOT { + fn default() -> Self { + let mut rng = ChaCha8Rng::seed_from_u64(0); + Self::new(rng.gen(), rng.gen()) + } +} + +/// Error for [`IdealRCOT`]. +#[derive(Debug, thiserror::Error)] +#[error("ideal RCOT error: {0}")] +pub struct IdealRCOTError(String); + +impl IdealRCOTError { + fn new(msg: impl Into) -> Self { + Self(msg.into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::test::assert_cot; + + #[test] + fn test_ideal_rcot() { + let mut ideal = IdealRCOT::default(); + + ideal.alloc(100); + ideal.flush().unwrap(); + + let ( + RCOTSenderOutput { keys: msgs, .. }, + RCOTReceiverOutput { + choices, + msgs: received, + .. + }, + ) = ideal.transfer(100).unwrap(); + + assert_cot(ideal.delta(), &choices, &msgs, &received) + } +} diff --git a/crates/mpz-ot-core/src/ideal/rot.rs b/crates/mpz-ot-core/src/ideal/rot.rs index 8a8b5d68..28cc70fc 100644 --- a/crates/mpz-ot-core/src/ideal/rot.rs +++ b/crates/mpz-ot-core/src/ideal/rot.rs @@ -1,20 +1,51 @@ //! Ideal Random Oblivious Transfer functionality. -use mpz_core::{prg::Prg, Block}; -use rand::{ - distributions::{Distribution, Standard}, - Rng, SeedableRng, +use std::{ + mem, + sync::{Arc, Mutex}, }; + +use mpz_common::future::{new_output, MaybeDone, Sender}; +use mpz_core::{prg::Prg, Block}; +use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; -use crate::{ROTReceiverOutput, ROTSenderOutput, TransferId}; +use crate::{ + rot::{ROTReceiver, ROTReceiverOutput, ROTSender, ROTSenderOutput}, + TransferId, +}; + +type Error = IdealROTError; +type Result = core::result::Result; + +#[derive(Debug, Default)] +struct SenderState { + alloc: usize, + transfer_id: TransferId, + queue: Vec<(usize, Sender>)>, +} + +#[derive(Debug, Default)] +struct ReceiverState { + alloc: usize, + transfer_id: TransferId, + queue: Vec<(usize, Sender>)>, +} /// The ideal ROT functionality. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct IdealROT { - transfer_id: TransferId, - counter: usize, + inner: Arc>, +} + +#[derive(Debug)] +struct Inner { prg: Prg, + sender_state: SenderState, + receiver_state: ReceiverState, + keys: Vec<[Block; 2]>, + msgs: Vec, + choices: Vec, } impl IdealROT { @@ -25,94 +56,196 @@ impl IdealROT { /// * `seed` - The seed for the PRG. pub fn new(seed: Block) -> Self { IdealROT { - transfer_id: TransferId::default(), - counter: 0, - prg: Prg::from_seed(seed), + inner: Arc::new(Mutex::new(Inner { + prg: Prg::from_seed(seed), + sender_state: SenderState::default(), + receiver_state: ReceiverState::default(), + keys: Vec::new(), + msgs: Vec::new(), + choices: Vec::new(), + })), } } - /// Returns the current transfer id. - pub fn transfer_id(&self) -> TransferId { - self.transfer_id + /// Returns `count` random ROTs. + pub fn transfer( + &mut self, + count: usize, + ) -> Result<(ROTSenderOutput<[Block; 2]>, ROTReceiverOutput)> { + Ok((self.try_send_rot(count)?, self.try_recv_rot(count)?)) } - /// Returns the number of OTs executed. - pub fn count(&self) -> usize { - self.counter - } + /// Returns `true` if the functionality wants to be flushed. + pub fn wants_flush(&self) -> bool { + let this = self.inner.lock().unwrap(); + let sender_queue = this.sender_state.queue.len(); + let receiver_queue = this.receiver_state.queue.len(); - /// Executes random oblivious transfers. - /// - /// # Arguments - /// - /// * `count` - The number of OTs to execute. - pub fn random( - &mut self, - count: usize, - ) -> (ROTSenderOutput<[T; 2]>, ROTReceiverOutput) - where - Standard: Distribution, - { - let mut choices = vec![false; count]; + sender_queue > 0 && receiver_queue > 0 && sender_queue == receiver_queue + } - self.prg.random_bools(&mut choices); + /// Flushes the functionality. + pub fn flush(&mut self) -> Result<()> { + let mut this = self.inner.lock().unwrap(); + if this.sender_state.alloc != this.receiver_state.alloc { + return Err(Error::new(format!( + "sender and receiver alloc out of sync: {} != {}", + this.sender_state.alloc, this.receiver_state.alloc + ))); + } - let msgs: Vec<[T; 2]> = (0..count) - .map(|_| [self.prg.sample(Standard), self.prg.sample(Standard)]) - .collect(); + let count = this.sender_state.alloc; - let chosen = choices + let keys = (0..count) + .map(|_| [this.prg.gen(), this.prg.gen()]) + .collect::>(); + let choices = (0..count).map(|_| this.prg.gen()).collect::>(); + let msgs = keys .iter() - .zip(msgs.iter()) - .map(|(&choice, [zero, one])| if choice { *one } else { *zero }) - .collect(); + .zip(&choices) + .map(|(keys, choice)| keys[*choice as usize]) + .collect::>(); - self.counter += count; - let id = self.transfer_id.next(); + this.keys.extend_from_slice(&keys); + this.choices.extend_from_slice(&choices); + this.msgs.extend_from_slice(&msgs); - ( - ROTSenderOutput { id, msgs }, - ROTReceiverOutput { - id, + this.sender_state.alloc = 0; + this.receiver_state.alloc = 0; + + let mut i = 0; + for (count, sender) in mem::take(&mut this.sender_state.queue) { + let keys = this.keys[i..i + count].to_vec(); + i += count; + sender.send(ROTSenderOutput { + id: this.sender_state.transfer_id, + keys, + }); + } + this.keys.drain(..i); + + i = 0; + for (count, sender) in mem::take(&mut this.receiver_state.queue) { + let choices = this.choices[i..i + count].to_vec(); + let keys = this.msgs[i..i + count].to_vec(); + i += count; + sender.send(ROTReceiverOutput { + id: this.receiver_state.transfer_id, choices, - msgs: chosen, - }, - ) + msgs: keys, + }); + } + this.choices.drain(..i); + this.msgs.drain(..i); + + Ok(()) } +} - /// Executes random oblivious transfers with choices provided by the receiver. - /// - /// # Arguments - /// - /// * `choices` - The choices made by the receiver. - pub fn random_with_choices( +impl ROTSender<[Block; 2]> for IdealROT { + type Error = IdealROTError; + type Future = MaybeDone>; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + let mut this = self.inner.lock().unwrap(); + this.sender_state.alloc += count; + Ok(()) + } + + fn available(&self) -> usize { + let this = self.inner.lock().unwrap(); + this.keys.len() + } + + fn try_send_rot(&mut self, count: usize) -> Result, Self::Error> { + let mut this = self.inner.lock().unwrap(); + if this.keys.len() < count { + return Err(IdealROTError::new(format!( + "not enough ROTs available: {} < {}", + this.keys.len(), + count + ))); + } + + let keys = this.keys.drain(..count).collect(); + Ok(ROTSenderOutput { + id: this.sender_state.transfer_id.next(), + keys, + }) + } + + fn queue_send_rot(&mut self, count: usize) -> Result { + let mut this = self.inner.lock().unwrap(); + let (sender, recv) = new_output(); + + if this.keys.len() >= count { + let keys = this.keys.drain(..count).collect(); + sender.send(ROTSenderOutput { + id: this.sender_state.transfer_id.next(), + keys, + }); + } else { + this.sender_state.queue.push((count, sender)); + } + + Ok(recv) + } +} + +impl ROTReceiver for IdealROT { + type Error = IdealROTError; + type Future = MaybeDone>; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + let mut this = self.inner.lock().unwrap(); + this.receiver_state.alloc += count; + Ok(()) + } + + fn available(&self) -> usize { + let this = self.inner.lock().unwrap(); + this.choices.len() + } + + fn try_recv_rot( &mut self, - choices: Vec, - ) -> (ROTSenderOutput<[T; 2]>, ROTReceiverOutput) - where - Standard: Distribution, - { - let msgs: Vec<[T; 2]> = (0..choices.len()) - .map(|_| [self.prg.sample(Standard), self.prg.sample(Standard)]) - .collect(); - - let chosen = choices - .iter() - .zip(msgs.iter()) - .map(|(&choice, [zero, one])| if choice { *one } else { *zero }) - .collect(); + count: usize, + ) -> Result, Self::Error> { + let mut this = self.inner.lock().unwrap(); + if this.choices.len() < count { + return Err(IdealROTError::new(format!( + "not enough ROTs available: {} < {}", + this.choices.len(), + count + ))); + } - self.counter += choices.len(); - let id = self.transfer_id.next(); + let choices = this.choices.drain(..count).collect(); + let msgs = this.msgs.drain(..count).collect(); + Ok(ROTReceiverOutput { + id: this.receiver_state.transfer_id.next(), + choices, + msgs, + }) + } - ( - ROTSenderOutput { id, msgs }, - ROTReceiverOutput { - id, + fn queue_recv_rot(&mut self, count: usize) -> Result { + let mut this = self.inner.lock().unwrap(); + let (sender, recv) = new_output(); + + if this.choices.len() >= count { + let choices = this.choices.drain(..count).collect(); + let keys = this.msgs.drain(..count).collect(); + sender.send(ROTReceiverOutput { + id: this.receiver_state.transfer_id.next(), choices, - msgs: chosen, - }, - ) + msgs: keys, + }); + } else { + this.receiver_state.queue.push((count, sender)); + } + + Ok(recv) } } @@ -123,6 +256,17 @@ impl Default for IdealROT { } } +/// Error for [`IdealROT`]. +#[derive(Debug, thiserror::Error)] +#[error("ideal ROT error: {0}")] +pub struct IdealROTError(String); + +impl IdealROTError { + fn new(msg: impl Into) -> Self { + IdealROTError(msg.into()) + } +} + #[cfg(test)] mod tests { use crate::test::assert_rot; @@ -131,33 +275,29 @@ mod tests { #[test] fn test_ideal_rot() { - let ( - ROTSenderOutput { msgs, .. }, - ROTReceiverOutput { - choices, - msgs: received, - .. - }, - ) = IdealROT::default().random::(100); + let mut rng = ChaCha8Rng::seed_from_u64(0); + let mut ideal = IdealROT::new(rng.gen()); - assert_rot(&choices, &msgs, &received) - } + let count = 10; - #[test] - fn test_ideal_rot_with_choices() { - let mut rng = ChaCha8Rng::seed_from_u64(0); - let mut choices = vec![false; 100]; - rng.fill(&mut choices[..]); + ROTSender::alloc(&mut ideal, count).unwrap(); + ROTReceiver::alloc(&mut ideal, count).unwrap(); + + ideal.flush().unwrap(); let ( - ROTSenderOutput { msgs, .. }, + ROTSenderOutput { + id: sender_id, + keys, + }, ROTReceiverOutput { + id: receiver_id, choices, - msgs: received, - .. + msgs, }, - ) = IdealROT::default().random_with_choices::(choices); + ) = ideal.transfer(count).unwrap(); - assert_rot(&choices, &msgs, &received) + assert_eq!(sender_id, receiver_id); + assert_rot(&choices, &keys, &msgs); } } diff --git a/crates/mpz-ot-core/src/ideal/spcot.rs b/crates/mpz-ot-core/src/ideal/spcot.rs deleted file mode 100644 index 12c5f829..00000000 --- a/crates/mpz-ot-core/src/ideal/spcot.rs +++ /dev/null @@ -1,104 +0,0 @@ -//! Ideal functionality for single-point correlated OT. - -use mpz_core::{prg::Prg, Block}; - -use crate::{SPCOTReceiverOutput, SPCOTSenderOutput, TransferId}; - -/// The ideal SPCOT functionality. -#[derive(Debug)] -pub struct IdealSpcot { - delta: Block, - transfer_id: TransferId, - counter: usize, - prg: Prg, -} - -impl IdealSpcot { - /// Initiate the functionality. - pub fn new() -> Self { - let mut prg = Prg::new(); - let delta = prg.random_block(); - IdealSpcot { - delta, - transfer_id: TransferId::default(), - counter: 0, - prg, - } - } - - /// Initiate with a given delta - pub fn new_with_delta(delta: Block) -> Self { - let prg = Prg::new(); - IdealSpcot { - delta, - transfer_id: TransferId::default(), - counter: 0, - prg, - } - } - - /// Performs the batch extension of SPCOT. - /// - /// # Argument - /// - /// * `pos` - The positions in each extension. - pub fn extend( - &mut self, - pos: &[(usize, u32)], - ) -> (SPCOTSenderOutput, SPCOTReceiverOutput) { - let mut v = vec![]; - let mut w = vec![]; - - for (n, alpha) in pos { - assert!((*alpha as usize) < *n); - let mut v_tmp = vec![Block::ZERO; *n]; - self.prg.random_blocks(&mut v_tmp); - let mut w_tmp = v_tmp.clone(); - w_tmp[*alpha as usize] ^= self.delta; - - v.push(v_tmp); - w.push(w_tmp); - self.counter += n; - } - - let id = self.transfer_id.next(); - - (SPCOTSenderOutput { id, v }, SPCOTReceiverOutput { id, w }) - } -} - -impl Default for IdealSpcot { - fn default() -> Self { - Self::new() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn ideal_spcot_test() { - let mut ideal_spcot = IdealSpcot::new(); - let delta = ideal_spcot.delta; - - let pos = [(10, 2), (20, 3)]; - - let (SPCOTSenderOutput { mut v, .. }, SPCOTReceiverOutput { w, .. }) = - ideal_spcot.extend(&pos); - - v.iter_mut() - .zip(w.iter()) - .zip(pos.iter()) - .for_each(|((v, w), (n, p))| { - assert_eq!(v.len(), *n); - assert_eq!(w.len(), *n); - v[*p as usize] ^= delta; - }); - - assert!(v - .iter() - .zip(w.iter()) - .all(|(v, w)| v.iter().zip(w.iter()).all(|(x, y)| *x == *y))); - } -} diff --git a/crates/mpz-ot-core/src/kos.rs b/crates/mpz-ot-core/src/kos.rs new file mode 100644 index 00000000..46d992d8 --- /dev/null +++ b/crates/mpz-ot-core/src/kos.rs @@ -0,0 +1,337 @@ +//! An implementation of the [`KOS15`](https://eprint.iacr.org/2015/546.pdf) oblivious transfer extension protocol. + +mod config; +mod error; +pub mod msgs; +mod receiver; +mod sender; + +pub use config::{ + ReceiverConfig, ReceiverConfigBuilder, ReceiverConfigBuilderError, SenderConfig, + SenderConfigBuilder, SenderConfigBuilderError, +}; +pub use error::{ReceiverError, SenderError}; +use mpz_core::Block; +pub use receiver::{state as receiver_state, Receiver}; +pub use sender::{state as sender_state, Sender}; +use serde::{Deserialize, Serialize}; + +/// Computational security parameter +pub const CSP: usize = 128; +/// Statistical security parameter +pub const SSP: usize = 128; + +/// Returns the size in bytes of the extension matrix for a given number of OTs. +fn extension_matrix_size(count: usize) -> usize { + count * CSP / 8 +} + +/// Extend message sent from Receiver to Sender. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(try_from = "validation::ExtendUnchecked")] +pub struct Extend { + count: usize, + us: Vec, +} + +/// Check message sent from Receiver to Sender. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct Check { + x: Block, + t0: Block, + t1: Block, +} + +mod validation { + use super::*; + + #[derive(Deserialize)] + pub(super) struct ExtendUnchecked { + count: usize, + us: Vec, + } + + impl TryFrom for Extend { + type Error = String; + + fn try_from(value: ExtendUnchecked) -> Result { + let ExtendUnchecked { count, us } = value; + + if us.len() != extension_matrix_size(count) { + return Err("invalid extension matrix size".to_string()); + } + + Ok(Extend { count, us }) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + rcot::{RCOTReceiver, RCOTReceiverOutput, RCOTSender, RCOTSenderOutput}, + test::assert_cot, + }; + + use super::*; + use itybity::ToBits; + use rstest::*; + + use mpz_core::Block; + + use rand::Rng; + use rand_chacha::ChaCha12Rng; + use rand_core::SeedableRng; + + #[fixture] + fn choices() -> Vec { + let mut rng = ChaCha12Rng::seed_from_u64(0); + (0..128).map(|_| rng.gen()).collect() + } + + #[fixture] + fn data() -> Vec<[Block; 2]> { + let mut rng = ChaCha12Rng::seed_from_u64(1); + (0..128) + .map(|_| [rng.gen::<[u8; 16]>().into(), rng.gen::<[u8; 16]>().into()]) + .collect() + } + + #[fixture] + fn delta() -> Block { + let mut rng = ChaCha12Rng::seed_from_u64(2); + rng.gen::<[u8; 16]>().into() + } + + #[fixture] + fn receiver_seeds() -> [[Block; 2]; CSP] { + let mut rng = ChaCha12Rng::seed_from_u64(3); + std::array::from_fn(|_| [rng.gen(), rng.gen()]) + } + + #[fixture] + fn sender_seeds(delta: Block, receiver_seeds: [[Block; 2]; CSP]) -> [Block; CSP] { + delta + .iter_lsb0() + .zip(receiver_seeds) + .map(|(b, seeds)| if b { seeds[1] } else { seeds[0] }) + .collect::>() + .try_into() + .unwrap() + } + + #[fixture] + fn chi_seed() -> Block { + let mut rng = ChaCha12Rng::seed_from_u64(4); + rng.gen::<[u8; 16]>().into() + } + + #[fixture] + fn expected(data: Vec<[Block; 2]>, choices: Vec) -> Vec { + data.iter() + .zip(choices.iter()) + .map(|([a, b], choice)| if *choice { *b } else { *a }) + .collect() + } + + #[rstest] + fn test_kos_extension( + delta: Block, + sender_seeds: [Block; CSP], + receiver_seeds: [[Block; 2]; CSP], + chi_seed: Block, + ) { + let count = 128; + + let sender = Sender::new(SenderConfig::default(), delta); + let receiver = Receiver::new(ReceiverConfig::default()); + + let mut sender = sender.setup(sender_seeds); + let mut receiver = receiver.setup(receiver_seeds); + + sender.alloc(count).unwrap(); + receiver.alloc(count).unwrap(); + + assert!(sender.wants_extend()); + assert!(receiver.wants_extend()); + + while receiver.wants_extend() { + sender.extend(receiver.extend().unwrap()).unwrap(); + } + + assert!(!sender.wants_extend()); + assert!(!receiver.wants_extend()); + assert!(sender.wants_check()); + assert!(receiver.wants_check()); + + let receiver_check = receiver.check(chi_seed).unwrap(); + sender.check(chi_seed, receiver_check).unwrap(); + + assert_eq!(sender.available(), count); + assert_eq!(receiver.available(), count); + + let RCOTSenderOutput { + id: sender_id, + keys, + } = sender.try_send_rcot(count).unwrap(); + let RCOTReceiverOutput { + id: receiver_id, + choices, + msgs, + } = receiver.try_recv_rcot(count).unwrap(); + + assert_eq!(sender_id, receiver_id); + assert_cot(delta, &choices, &keys, &msgs); + } + + #[rstest] + fn test_kos_extension_stream_extends( + delta: Block, + sender_seeds: [Block; CSP], + receiver_seeds: [[Block; 2]; CSP], + chi_seed: Block, + ) { + let sender_config = SenderConfig::default(); + let receiver_config = ReceiverConfig::default(); + + let count = sender_config.batch_size() * 3; + + let sender = Sender::new(sender_config, delta); + let receiver = Receiver::new(receiver_config); + + let mut sender = sender.setup(sender_seeds); + let mut receiver = receiver.setup(receiver_seeds); + + sender.alloc(count).unwrap(); + receiver.alloc(count).unwrap(); + + assert!(sender.wants_extend()); + assert!(receiver.wants_extend()); + + while receiver.wants_extend() { + sender.extend(receiver.extend().unwrap()).unwrap(); + } + + assert!(!sender.wants_extend()); + assert!(!receiver.wants_extend()); + assert!(sender.wants_check()); + assert!(receiver.wants_check()); + + let receiver_check = receiver.check(chi_seed).unwrap(); + sender.check(chi_seed, receiver_check).unwrap(); + + assert_eq!(sender.available(), count); + assert_eq!(receiver.available(), count); + + let RCOTSenderOutput { + id: sender_id, + keys, + } = sender.try_send_rcot(count).unwrap(); + let RCOTReceiverOutput { + id: receiver_id, + choices, + msgs, + } = receiver.try_recv_rcot(count).unwrap(); + + assert_eq!(sender_id, receiver_id); + assert_cot(delta, &choices, &keys, &msgs); + } + + #[rstest] + fn test_kos_extension_multiple_extends_fail( + delta: Block, + sender_seeds: [Block; CSP], + receiver_seeds: [[Block; 2]; CSP], + chi_seed: Block, + ) { + let count = 128; + + let sender = Sender::new(SenderConfig::default(), delta); + let receiver = Receiver::new(ReceiverConfig::default()); + + let mut sender = sender.setup(sender_seeds); + let mut receiver = receiver.setup(receiver_seeds); + + sender.alloc(count).unwrap(); + receiver.alloc(count).unwrap(); + + while receiver.wants_extend() { + sender.extend(receiver.extend().unwrap()).unwrap(); + } + + let receiver_check = receiver.check(chi_seed).unwrap(); + sender.check(chi_seed, receiver_check).unwrap(); + + assert!(sender.alloc(1).is_err()); + assert!(receiver.alloc(1).is_err()); + assert!(!sender.wants_extend()); + assert!(!receiver.wants_extend()); + assert!(receiver.extend().is_err()); + } + + #[rstest] + fn test_kos_extension_insufficient_setup( + delta: Block, + sender_seeds: [Block; CSP], + receiver_seeds: [[Block; 2]; CSP], + chi_seed: Block, + ) { + let count = 128; + + let sender = Sender::new(SenderConfig::default(), delta); + let receiver = Receiver::new(ReceiverConfig::default()); + + let mut sender = sender.setup(sender_seeds); + let mut receiver = receiver.setup(receiver_seeds); + + sender.alloc(count).unwrap(); + receiver.alloc(count).unwrap(); + + while receiver.wants_extend() { + sender.extend(receiver.extend().unwrap()).unwrap(); + } + + let receiver_check = receiver.check(chi_seed).unwrap(); + sender.check(chi_seed, receiver_check).unwrap(); + + let err = sender.try_send_rcot(count + 1).unwrap_err(); + assert!(matches!(err, SenderError::InsufficientSetup { .. })); + + let err = receiver.try_recv_rcot(count + 1).unwrap_err(); + assert!(matches!(err, ReceiverError::InsufficientSetup { .. })); + } + + #[rstest] + fn test_kos_extension_bad_consistency_check( + delta: Block, + sender_seeds: [Block; CSP], + receiver_seeds: [[Block; 2]; CSP], + chi_seed: Block, + ) { + let count = 128; + + let sender = Sender::new(SenderConfig::default(), delta); + let receiver = Receiver::new(ReceiverConfig::default()); + + let mut sender = sender.setup(sender_seeds); + let mut receiver = receiver.setup(receiver_seeds); + + sender.alloc(count).unwrap(); + receiver.alloc(count).unwrap(); + + while receiver.wants_extend() { + let mut extend = receiver.extend().unwrap(); + + // Flip a bit in the receiver's extension message (breaking the mono-chrome + // choice vector) + *extend.us.first_mut().unwrap() ^= 1; + + sender.extend(extend).unwrap(); + } + + let receiver_check = receiver.check(chi_seed).unwrap(); + let err = sender.check(chi_seed, receiver_check).unwrap_err(); + + assert!(matches!(err, SenderError::ConsistencyCheckFailed)); + } +} diff --git a/crates/mpz-ot-core/src/kos/config.rs b/crates/mpz-ot-core/src/kos/config.rs index aa2b1249..f3c03f2b 100644 --- a/crates/mpz-ot-core/src/kos/config.rs +++ b/crates/mpz-ot-core/src/kos/config.rs @@ -1,18 +1,20 @@ use derive_builder::Builder; +const DEFAULT_BATCH_SIZE: usize = 4096; + /// KOS15 sender configuration. -#[derive(Debug, Default, Clone, Builder)] +#[derive(Debug, Clone, Builder)] pub struct SenderConfig { - /// Enables committed sender functionality. - #[builder(setter(custom), default = "false")] - sender_commit: bool, + /// Batch size for each flush. + #[builder(default = "DEFAULT_BATCH_SIZE")] + batch_size: usize, } -impl SenderConfigBuilder { - /// Enables committed sender functionality. - pub fn sender_commit(&mut self) -> &mut Self { - self.sender_commit = Some(true); - self +impl Default for SenderConfig { + fn default() -> Self { + Self { + batch_size: DEFAULT_BATCH_SIZE, + } } } @@ -22,25 +24,25 @@ impl SenderConfig { SenderConfigBuilder::default() } - /// Enables committed sender functionality. - pub fn sender_commit(&self) -> bool { - self.sender_commit + /// Returns the batch size for each flush. + pub fn batch_size(&self) -> usize { + self.batch_size } } /// KOS15 receiver configuration. -#[derive(Debug, Default, Clone, Builder)] +#[derive(Debug, Clone, Builder)] pub struct ReceiverConfig { - /// Enables committed sender functionality. - #[builder(setter(custom), default = "false")] - sender_commit: bool, + /// Batch size for each flush. + #[builder(default = "DEFAULT_BATCH_SIZE")] + batch_size: usize, } -impl ReceiverConfigBuilder { - /// Enables committed sender functionality. - pub fn sender_commit(&mut self) -> &mut Self { - self.sender_commit = Some(true); - self +impl Default for ReceiverConfig { + fn default() -> Self { + Self { + batch_size: DEFAULT_BATCH_SIZE, + } } } @@ -50,8 +52,8 @@ impl ReceiverConfig { ReceiverConfigBuilder::default() } - /// Enables committed sender functionality. - pub fn sender_commit(&self) -> bool { - self.sender_commit + /// Returns the batch size for each flush. + pub fn batch_size(&self) -> usize { + self.batch_size } } diff --git a/crates/mpz-ot-core/src/kos/error.rs b/crates/mpz-ot-core/src/kos/error.rs index 7acbcd3d..83839717 100644 --- a/crates/mpz-ot-core/src/kos/error.rs +++ b/crates/mpz-ot-core/src/kos/error.rs @@ -8,16 +8,16 @@ pub enum SenderError { InvalidState(String), #[error("invalid count, must be a multiple of 64: {0}")] InvalidCount(usize), - #[error("count mismatch: expected {0}, got {1}")] - CountMismatch(usize, usize), + #[error("count mismatch: expected {expected}, got {actual}")] + CountMismatch { expected: usize, actual: usize }, #[error("id mismatch: expected {0}, got {1}")] IdMismatch(TransferId, TransferId), #[error("invalid extend")] InvalidExtend, #[error("consistency check failed")] ConsistencyCheckFailed, - #[error("not enough OTs are setup: expected {0}, actual {1}")] - InsufficientSetup(usize, usize), + #[error("not enough OTs are setup: expected {expected}, actual {actual}")] + InsufficientSetup { expected: usize, actual: usize }, } /// Errors that can occur when using the KOS15 receiver. @@ -32,22 +32,8 @@ pub enum ReceiverError { CountMismatch(usize, usize), #[error("id mismatch: expected {0}, got {1}")] IdMismatch(TransferId, TransferId), - #[error("not enough OTs are setup: expected {0}, actual {1}")] - InsufficientSetup(usize, usize), + #[error("not enough OTs are setup: expected {expected}, actual {actual}")] + InsufficientSetup { expected: usize, actual: usize }, #[error("invalid payload")] InvalidPayload(String), - #[error(transparent)] - ReceiverVerifyError(#[from] ReceiverVerifyError), -} - -/// Errors that can occur during verification of the sender's messages. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum ReceiverVerifyError { - #[error("tape was not recorded")] - TapeNotRecorded, - #[error("invalid transfer id: {0}")] - InvalidTransferId(TransferId), - #[error("payload inconsistent")] - InconsistentPayload, } diff --git a/crates/mpz-ot-core/src/kos/mod.rs b/crates/mpz-ot-core/src/kos/mod.rs deleted file mode 100644 index bf3e2b41..00000000 --- a/crates/mpz-ot-core/src/kos/mod.rs +++ /dev/null @@ -1,385 +0,0 @@ -//! An implementation of the [`KOS15`](https://eprint.iacr.org/2015/546.pdf) oblivious transfer extension protocol. - -mod config; -mod error; -pub mod msgs; -mod receiver; -mod sender; - -pub use config::{ - ReceiverConfig, ReceiverConfigBuilder, ReceiverConfigBuilderError, SenderConfig, - SenderConfigBuilder, SenderConfigBuilderError, -}; -pub use error::{ReceiverError, ReceiverVerifyError, SenderError}; -use rand_chacha::ChaCha20Rng; -use rand_core::SeedableRng; -pub use receiver::{state as receiver_state, PayloadRecord, Receiver, ReceiverKeys}; -pub use sender::{state as sender_state, Sender, SenderKeys}; - -/// Computational security parameter -pub const CSP: usize = 128; -/// Statistical security parameter -pub const SSP: usize = 128; -/// Rng to use for secret sharing the IKNP matrix. -pub(crate) type Rng = ChaCha20Rng; -/// Rng seed type -pub(crate) type RngSeed = ::Seed; - -/// AES-128 CTR used for encryption. -pub(crate) type Aes128Ctr = ctr::Ctr64LE; - -/// Pads the number of OTs to accommodate for the KOS extension check and -/// the extension matrix transpose optimization. -pub fn pad_ot_count(mut count: usize) -> usize { - // Add OTs for the KOS extension check. - count += CSP + SSP; - // Round up the OTs to extend to the nearest multiple of 64 (matrix transpose optimization). - (count + 63) & !63 -} - -/// Returns the size in bytes of the extension matrix for a given number of OTs. -pub fn extension_matrix_size(count: usize) -> usize { - count * CSP / 8 -} - -#[cfg(test)] -mod tests { - use super::*; - use itybity::ToBits; - use rstest::*; - - use mpz_core::Block; - - use rand::Rng; - use rand_chacha::ChaCha12Rng; - use rand_core::SeedableRng; - - #[fixture] - fn choices() -> Vec { - let mut rng = ChaCha12Rng::seed_from_u64(0); - (0..128).map(|_| rng.gen()).collect() - } - - #[fixture] - fn data() -> Vec<[Block; 2]> { - let mut rng = ChaCha12Rng::seed_from_u64(1); - (0..128) - .map(|_| [rng.gen::<[u8; 16]>().into(), rng.gen::<[u8; 16]>().into()]) - .collect() - } - - #[fixture] - fn delta() -> Block { - let mut rng = ChaCha12Rng::seed_from_u64(2); - rng.gen::<[u8; 16]>().into() - } - - #[fixture] - fn receiver_seeds() -> [[Block; 2]; CSP] { - let mut rng = ChaCha12Rng::seed_from_u64(3); - std::array::from_fn(|_| [rng.gen(), rng.gen()]) - } - - #[fixture] - fn sender_seeds(delta: Block, receiver_seeds: [[Block; 2]; CSP]) -> [Block; CSP] { - delta - .iter_lsb0() - .zip(receiver_seeds) - .map(|(b, seeds)| if b { seeds[1] } else { seeds[0] }) - .collect::>() - .try_into() - .unwrap() - } - - #[fixture] - fn chi_seed() -> Block { - let mut rng = ChaCha12Rng::seed_from_u64(4); - rng.gen::<[u8; 16]>().into() - } - - #[fixture] - fn expected(data: Vec<[Block; 2]>, choices: Vec) -> Vec { - data.iter() - .zip(choices.iter()) - .map(|([a, b], choice)| if *choice { *b } else { *a }) - .collect() - } - - #[rstest] - fn test_kos_extension( - delta: Block, - sender_seeds: [Block; CSP], - receiver_seeds: [[Block; 2]; CSP], - chi_seed: Block, - choices: Vec, - data: Vec<[Block; 2]>, - expected: Vec, - ) { - let sender = Sender::new(SenderConfig::default()); - let receiver = Receiver::new(ReceiverConfig::default()); - - let mut sender = sender.setup(delta, sender_seeds); - let mut receiver = receiver.setup(receiver_seeds); - - let receiver_setup = receiver.extend(choices.len() + 256).unwrap(); - sender.extend(data.len() + 256, receiver_setup).unwrap(); - - let receiver_check = receiver.check(chi_seed).unwrap(); - sender.check(chi_seed, receiver_check).unwrap(); - - let mut receiver_keys = receiver.keys(choices.len()).unwrap(); - let derandomize = receiver_keys.derandomize(&choices).unwrap(); - - let mut sender_keys = sender.keys(data.len()).unwrap(); - sender_keys.derandomize(derandomize).unwrap(); - let payload = sender_keys.encrypt_blocks(&data).unwrap(); - - let received = receiver_keys.decrypt_blocks(payload).unwrap(); - - assert_eq!(received, expected); - } - - #[rstest] - fn test_kos_extension_bytes( - delta: Block, - sender_seeds: [Block; CSP], - receiver_seeds: [[Block; 2]; CSP], - chi_seed: Block, - choices: Vec, - data: Vec<[Block; 2]>, - expected: Vec, - ) { - let sender = Sender::new(SenderConfig::default()); - let receiver = Receiver::new(ReceiverConfig::default()); - - let mut sender = sender.setup(delta, sender_seeds); - let mut receiver = receiver.setup(receiver_seeds); - - let receiver_setup = receiver.extend(choices.len() + 256).unwrap(); - sender.extend(data.len() + 256, receiver_setup).unwrap(); - - let receiver_check = receiver.check(chi_seed).unwrap(); - sender.check(chi_seed, receiver_check).unwrap(); - - let mut receiver_keys = receiver.keys(choices.len()).unwrap(); - let derandomize = receiver_keys.derandomize(&choices).unwrap(); - - let data: Vec<_> = data - .iter() - .map(|[a, b]| [a.to_bytes(), b.to_bytes()]) - .collect(); - - let mut sender_keys = sender.keys(data.len()).unwrap(); - sender_keys.derandomize(derandomize).unwrap(); - let payload = sender_keys.encrypt_bytes(&data).unwrap(); - - let received = receiver_keys.decrypt_bytes::<16>(payload).unwrap(); - - let expected = expected.iter().map(|b| b.to_bytes()).collect::>(); - - assert_eq!(received, expected); - } - - #[rstest] - fn test_kos_extension_stream_extends( - delta: Block, - sender_seeds: [Block; CSP], - receiver_seeds: [[Block; 2]; CSP], - chi_seed: Block, - choices: Vec, - data: Vec<[Block; 2]>, - expected: Vec, - ) { - let sender = Sender::new(SenderConfig::default()); - let receiver = Receiver::new(ReceiverConfig::default()); - - let mut sender = sender.setup(delta, sender_seeds); - let mut receiver = receiver.setup(receiver_seeds); - - let receiver_setup = receiver.extend(choices.len()).unwrap(); - sender.extend(choices.len(), receiver_setup).unwrap(); - - // Extend 256 more - let receiver_setup = receiver.extend(256).unwrap(); - sender.extend(256, receiver_setup).unwrap(); - - let receiver_check = receiver.check(chi_seed).unwrap(); - sender.check(chi_seed, receiver_check).unwrap(); - - let mut receiver_keys = receiver.keys(choices.len()).unwrap(); - let derandomize = receiver_keys.derandomize(&choices).unwrap(); - - let mut sender_keys = sender.keys(data.len()).unwrap(); - sender_keys.derandomize(derandomize).unwrap(); - let payload = sender_keys.encrypt_blocks(&data).unwrap(); - - let received = receiver_keys.decrypt_blocks(payload).unwrap(); - - assert_eq!(received, expected); - } - - #[rstest] - fn test_kos_extension_multiple_extends_fail( - delta: Block, - sender_seeds: [Block; CSP], - receiver_seeds: [[Block; 2]; CSP], - chi_seed: Block, - ) { - let sender = Sender::new(SenderConfig::default()); - let receiver = Receiver::new(ReceiverConfig::default()); - - let mut sender = sender.setup(delta, sender_seeds); - let mut receiver = receiver.setup(receiver_seeds); - - let receiver_setup = receiver.extend(256).unwrap(); - sender.extend(256, receiver_setup).unwrap(); - - // Perform check - let receiver_check = receiver.check(chi_seed).unwrap(); - sender.check(chi_seed, receiver_check).unwrap(); - - // Extending more should fail - let receiver_setup = receiver.extend(256).unwrap_err(); - - assert!(matches!(receiver_setup, ReceiverError::InvalidState(_))); - } - - #[rstest] - fn test_kos_extension_insufficient_setup( - delta: Block, - sender_seeds: [Block; CSP], - receiver_seeds: [[Block; 2]; CSP], - chi_seed: Block, - ) { - let sender = Sender::new(SenderConfig::default()); - let receiver = Receiver::new(ReceiverConfig::default()); - - let mut sender = sender.setup(delta, sender_seeds); - let mut receiver = receiver.setup(receiver_seeds); - - let receiver_setup = receiver.extend(64).unwrap(); - sender.extend(64, receiver_setup).unwrap(); - - // Perform check - let err = receiver.check(chi_seed).unwrap_err(); - - assert!(matches!(err, ReceiverError::InsufficientSetup(_, _))); - } - - #[rstest] - fn test_kos_extension_bad_consistency_check( - delta: Block, - sender_seeds: [Block; CSP], - receiver_seeds: [[Block; 2]; CSP], - chi_seed: Block, - ) { - let sender = Sender::new(SenderConfig::default()); - let receiver = Receiver::new(ReceiverConfig::default()); - - let mut sender = sender.setup(delta, sender_seeds); - let mut receiver = receiver.setup(receiver_seeds); - - let mut receiver_setup = receiver.extend(512).unwrap(); - - // Flip a bit in the receiver's extension message (breaking the mono-chrome choice vector) - *receiver_setup.us.first_mut().unwrap() ^= 1; - - sender.extend(512, receiver_setup).unwrap(); - - let receiver_check = receiver.check(chi_seed).unwrap(); - let err = sender.check(chi_seed, receiver_check).unwrap_err(); - - assert!(matches!(err, SenderError::ConsistencyCheckFailed)); - } - - #[rstest] - fn test_kos_extension_verify_messages( - delta: Block, - sender_seeds: [Block; CSP], - receiver_seeds: [[Block; 2]; CSP], - chi_seed: Block, - choices: Vec, - data: Vec<[Block; 2]>, - expected: Vec, - ) { - let sender = Sender::new(SenderConfig::default()); - let receiver = Receiver::new(ReceiverConfig::builder().sender_commit().build().unwrap()); - - let mut sender = sender.setup(delta, sender_seeds); - let mut receiver = receiver.setup(receiver_seeds); - - let receiver_setup = receiver.extend(choices.len() + 256).unwrap(); - sender.extend(data.len() + 256, receiver_setup).unwrap(); - - let receiver_check = receiver.check(chi_seed).unwrap(); - sender.check(chi_seed, receiver_check).unwrap(); - - let mut receiver_keys = receiver.keys(choices.len()).unwrap(); - let derandomize = receiver_keys.derandomize(&choices).unwrap(); - - let mut sender_keys = sender.keys(data.len()).unwrap(); - sender_keys.derandomize(derandomize).unwrap(); - let payload = sender_keys.encrypt_blocks(&data).unwrap(); - - let id = payload.id; - - let received = receiver_keys.decrypt_blocks(payload).unwrap(); - - assert_eq!(received, expected); - - let receiver = receiver.start_verification(delta).unwrap(); - - receiver.remove_record(id).unwrap().verify(&data).unwrap(); - } - - #[rstest] - fn test_kos_extension_verify_messages_fail( - delta: Block, - sender_seeds: [Block; CSP], - receiver_seeds: [[Block; 2]; CSP], - chi_seed: Block, - choices: Vec, - mut data: Vec<[Block; 2]>, - expected: Vec, - ) { - let sender = Sender::new(SenderConfig::default()); - let receiver = Receiver::new(ReceiverConfig::builder().sender_commit().build().unwrap()); - - let mut sender = sender.setup(delta, sender_seeds); - let mut receiver = receiver.setup(receiver_seeds); - - let receiver_setup = receiver.extend(choices.len() + 256).unwrap(); - sender.extend(data.len() + 256, receiver_setup).unwrap(); - - let receiver_check = receiver.check(chi_seed).unwrap(); - sender.check(chi_seed, receiver_check).unwrap(); - - let mut receiver_keys = receiver.keys(choices.len()).unwrap(); - let derandomize = receiver_keys.derandomize(&choices).unwrap(); - - let mut sender_keys = sender.keys(data.len()).unwrap(); - sender_keys.derandomize(derandomize).unwrap(); - let payload = sender_keys.encrypt_blocks(&data).unwrap(); - - let id = payload.id; - - let received = receiver_keys.decrypt_blocks(payload).unwrap(); - - assert_eq!(received, expected); - - data[0][0] = Block::default(); - - let receiver = receiver.start_verification(delta).unwrap(); - - let err = receiver - .remove_record(id) - .unwrap() - .verify(&data) - .unwrap_err(); - - assert!(matches!( - err, - ReceiverError::ReceiverVerifyError(ReceiverVerifyError::InconsistentPayload) - )); - } -} diff --git a/crates/mpz-ot-core/src/kos/msgs.rs b/crates/mpz-ot-core/src/kos/msgs.rs index 6710f274..1f9e5165 100644 --- a/crates/mpz-ot-core/src/kos/msgs.rs +++ b/crates/mpz-ot-core/src/kos/msgs.rs @@ -5,7 +5,8 @@ use serde::{Deserialize, Serialize}; use crate::TransferId; -/// Extension message sent by the receiver to agree upon the number of OTs to set up. +/// Extension message sent by the receiver to agree upon the number of OTs to +/// set up. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct StartExtend { /// The number of OTs to set up. @@ -49,15 +50,6 @@ impl Iterator for ExtendChunks { } } -/// Values for the correlation check sent by the receiver. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[allow(missing_docs)] -pub struct Check { - pub x: Block, - pub t0: Block, - pub t1: Block, -} - /// Sender payload message. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct SenderPayload { diff --git a/crates/mpz-ot-core/src/kos/receiver.rs b/crates/mpz-ot-core/src/kos/receiver.rs index fdcad328..8a4fc9cf 100644 --- a/crates/mpz-ot-core/src/kos/receiver.rs +++ b/crates/mpz-ot-core/src/kos/receiver.rs @@ -1,39 +1,34 @@ -use std::{ - collections::HashMap, - sync::{Arc, Mutex}, -}; +use std::{collections::VecDeque, mem}; use crate::{ - kos::{ - error::ReceiverVerifyError, - msgs::{Check, Ciphertexts, Extend, SenderPayload}, - Aes128Ctr, ReceiverConfig, ReceiverError, Rng, RngSeed, CSP, SSP, - }, - msgs::Derandomize, + kos::{Check, Extend, ReceiverConfig, ReceiverError, CSP, SSP}, + rcot::{RCOTReceiver, RCOTReceiverOutput}, TransferId, }; -use itybity::{FromBitIterator, IntoBits, ToBits}; -use mpz_core::{aes::FIXED_KEY_AES, Block}; +use itybity::{FromBitIterator, IntoBits}; +use mpz_common::future::{new_output, MaybeDone, Sender}; +use mpz_core::{prg::Prg, Block}; -use blake3::Hasher; -use cipher::{KeyIvInit, StreamCipher}; use rand::{thread_rng, Rng as _, SeedableRng}; -use rand_chacha::ChaCha20Rng; use rand_core::RngCore; #[cfg(feature = "rayon")] use rayon::prelude::*; -#[derive(Debug, Default)] -struct Tape { - records: HashMap, +#[derive(Debug)] +struct Queued { + count: usize, + sender: Sender>, } /// KOS15 receiver. #[derive(Debug, Default)] pub struct Receiver { config: ReceiverConfig, + alloc: usize, + transfer_id: TransferId, + queue: VecDeque, state: T, } @@ -54,15 +49,15 @@ impl Receiver { /// /// * `config` - The Receiver's configuration pub fn new(config: ReceiverConfig) -> Self { - let tape = if config.sender_commit() { - Some(Default::default()) - } else { - None - }; - Receiver { config, - state: state::Initialized { tape }, + // We need to extend CSP + SSP OTs for the consistency check. + // Right now we only support one extension, so we just alloc + // them here. + alloc: CSP + SSP, + transfer_id: TransferId::default(), + queue: VecDeque::default(), + state: state::Initialized {}, } } @@ -72,80 +67,48 @@ impl Receiver { /// /// * `seeds` - The receiver's rng seeds pub fn setup(self, seeds: [[Block; 2]; CSP]) -> Receiver { - let rngs = seeds - .iter() - .map(|seeds| { - seeds.map(|seed| { - // Stretch the Block-sized seed to a 32-byte seed. - let mut seed_ = RngSeed::default(); - seed_ - .iter_mut() - .zip(seed.to_bytes().into_iter().cycle()) - .for_each(|(s, c)| *s = c); - Rng::from_seed(seed_) - }) - }) - .collect(); - Receiver { config: self.config, + alloc: self.alloc, + transfer_id: self.transfer_id, + queue: self.queue, state: state::Extension { - rngs, - ts: Vec::default(), - keys: Vec::default(), + rngs: seeds + .into_iter() + .map(|seeds| seeds.map(|seed| Prg::from_seed(seed))) + .collect(), + msgs: Vec::default(), choices: Vec::default(), - index: 0, - transfer_id: TransferId::default(), extended: false, unchecked_ts: Vec::default(), unchecked_choices: Vec::default(), - tape: self.state.tape, }, } } } impl Receiver { - /// Returns the current transfer id. - pub fn current_transfer_id(&self) -> TransferId { - self.state.transfer_id + /// Returns `true` if the receiver wants to extend. + pub fn wants_extend(&self) -> bool { + self.alloc != 0 && !self.state.extended } - /// The number of remaining OTs which can be consumed. - pub fn remaining(&self) -> usize { - self.state.keys.len() + /// Returns `true` if the receiver wants to run the consistency check. + pub fn wants_check(&self) -> bool { + self.alloc == 0 && !self.state.unchecked_ts.is_empty() } /// Perform the IKNP OT extension. - /// - /// The provided count _must_ be a multiple of 64, otherwise an error will be returned. - /// - /// # Sacrificial OTs - /// - /// Performing the consistency check sacrifices 256 OTs, so be sure to - /// extend enough OTs to compensate for this. - /// - /// # Streaming - /// - /// Extension can be performed in a streaming fashion by calling this method multiple times, sending - /// the `Extend` messages to the sender in-between calls. - /// - /// The freshly extended OTs are not available until after the consistency check has been - /// performed. See [`Receiver::check`]. - /// - /// # Arguments - /// - /// * `count` - The number of OTs to extend (must be a multiple of 64). - pub fn extend(&mut self, count: usize) -> Result { + pub fn extend(&mut self) -> Result { if self.state.extended { return Err(ReceiverError::InvalidState( "extending more than once is currently disabled".to_string(), )); } - if count % 64 != 0 { - return Err(ReceiverError::InvalidCount(count)); - } + let count = self.config.batch_size().min(self.alloc); + // round up count to a multiple of 64 + let count = (count + 63) & !63; const NROWS: usize = CSP; let row_width = count / 8; @@ -199,43 +162,33 @@ impl Receiver { .map(|t| Block::try_from(t).unwrap()), ); self.state.unchecked_choices.extend(choices); + self.alloc = self.alloc.saturating_sub(count); - Ok(Extend { us }) + Ok(Extend { count, us }) } /// Performs the correlation check for all outstanding OTS. /// /// See section 3.1 of the paper for more details. /// - /// # Sacrificial OTs - /// - /// Performing this check sacrifices 256 OTs for the consistency check, so be sure to - /// extend enough OTs to compensate for this. - /// /// # ⚠️ Warning ⚠️ /// - /// The provided seed must be unbiased! It should be generated using a secure - /// coin-toss protocol **after** the receiver has sent their setup message, ie - /// after they have already committed to their choice vectors. + /// The provided seed must be unbiased! It should be generated using a + /// secure coin-toss protocol **after** the receiver has sent their + /// setup message, ie after they have already committed to their choice + /// vectors. /// /// # Arguments /// /// * `chi_seed` - The seed used to generate the consistency check weights. pub fn check(&mut self, chi_seed: Block) -> Result { - // Make sure we have enough sacrificial OTs to perform the consistency check. - if self.state.unchecked_ts.len() < CSP + SSP { - return Err(ReceiverError::InsufficientSetup( - CSP + SSP, - self.state.unchecked_ts.len(), + if !self.wants_check() { + return Err(ReceiverError::InvalidState( + "receiver not ready to check".to_string(), )); } - let mut seed = RngSeed::default(); - seed.iter_mut() - .zip(chi_seed.to_bytes().into_iter().cycle()) - .for_each(|(s, c)| *s = c); - - let mut rng = Rng::from_seed(seed); + let mut rng = Prg::from_seed(chi_seed); let mut unchecked_ts = std::mem::take(&mut self.state.unchecked_ts); let mut unchecked_choices = std::mem::take(&mut self.state.unchecked_choices); @@ -284,389 +237,125 @@ impl Receiver { unchecked_ts.truncate(nrows); unchecked_choices.truncate(nrows); - cfg_if::cfg_if! { - if #[cfg(feature = "rayon")] { - let iter = unchecked_ts.par_iter().enumerate(); - } else { - let iter = unchecked_ts.iter().enumerate(); - } - } - - let cipher = &(*FIXED_KEY_AES); - let keys = iter - .map(|(j, t)| { - let j = Block::from(((self.state.index + j) as u128).to_be_bytes()); - cipher.tccr(j, *t) - }) - .collect::>(); - - self.state.index += keys.len(); + // Add to existing msgs. + self.state.msgs.extend_from_slice(&unchecked_ts); + self.state.choices.extend_from_slice(&unchecked_choices); + self.state.extended = true; - // Add to existing keys. - self.state.keys.extend(keys); - self.state.choices.extend(unchecked_choices); + // Resolve any queued transfers. + if !self.queue.is_empty() { + let mut i = 0; + for Queued { count, sender } in mem::take(&mut self.queue) { + let choices = self.state.choices[i..i + count].to_vec(); + let msgs = self.state.msgs[i..i + count].to_vec(); + i += count; + sender.send(RCOTReceiverOutput { + id: self.transfer_id.next(), + choices, + msgs, + }); + } - // If we're recording, we track `ts` too - if self.state.tape.is_some() { - self.state.ts.extend(unchecked_ts); + self.state.choices.drain(..i); + self.state.msgs.drain(..i); } - // Disable any further extensions. - self.state.extended = true; - Ok(Check { x, t0, t1 }) } +} - /// Returns receiver's keys for the given number of OTs. - /// - /// # Arguments - /// - /// * `count` - The number of keys to take. - pub fn keys(&mut self, count: usize) -> Result { - if count > self.state.keys.len() { - return Err(ReceiverError::InsufficientSetup( - count, - self.state.keys.len(), - )); - } +impl RCOTReceiver for Receiver { + type Error = ReceiverError; + type Future = MaybeDone>; - let id = self.state.transfer_id.next(); - let index = self.state.index - self.state.keys.len(); + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.alloc += count; - Ok(ReceiverKeys { - id, - index, - keys: self.state.keys.drain(..count).collect(), - choices: self.state.choices.drain(..count).collect(), - ts: if self.state.tape.is_some() { - Some(self.state.ts.drain(..count).collect()) - } else { - None - }, - tape: self.state.tape.clone(), - }) + Ok(()) } - /// Enters the verification state for verifiable OT. - /// - /// # ⚠️ Warning ⚠️ - /// - /// The authenticity of `delta` must be established outside the context of this function. This - /// can be achieved using verifiable base OT. - /// - /// # Arguments - /// - /// * `delta` - The sender's base OT choice bits. - pub fn start_verification( - mut self, - delta: Block, - ) -> Result, ReceiverError> { - let Some(tape) = self.state.tape.take() else { - return Err(ReceiverVerifyError::TapeNotRecorded)?; - }; - - Ok(Receiver { - config: self.config, - state: state::Verify { tape, delta }, - }) + fn available(&self) -> usize { + 0 } -} -impl Receiver { - /// Returns the [`PayloadRecord`] for the given transfer id if it exists. - /// - /// # Errors - /// - /// Returns an error if the record does not exist. - /// - /// # Arguments - /// - /// * `id` - The transfer id - pub fn remove_record(&self, id: TransferId) -> Result { - let PayloadRecordNoDelta { - index, - choices, - ts, - keys, - ciphertext_digest, - } = self - .state - .tape - .lock() - .unwrap() - .records - .remove(&id) - .ok_or(ReceiverVerifyError::InvalidTransferId(id)) - .map_err(ReceiverError::from)?; - - Ok(PayloadRecord { - index, - choices, - ts, - keys, - delta: self.state.delta, - ciphertext_digest, - }) + fn try_recv_rcot( + &mut self, + _count: usize, + ) -> Result, Self::Error> { + return Err(ReceiverError::InvalidState( + "receiver has not been setup yet".to_string(), + )); } -} -/// KOS receiver's keys for a single transfer. -/// -/// Returned by the [`Receiver::keys`] method, used in cases where the receiver -/// wishes to reserve a set of keys for a transfer, but hasn't yet received the -/// payload. -pub struct ReceiverKeys { - /// Transfer ID - id: TransferId, - /// Start index of the OTs - index: usize, - /// Decryption keys - keys: Vec, - /// The Receiver's choices. If derandomization is performed, these are the overwritten - /// with the derandomized choices. - choices: Vec, - - /// Receiver `ts` - ts: Option>, - /// Receiver tape - tape: Option>>, -} + fn queue_recv_rcot(&mut self, count: usize) -> Result { + let (sender, recv) = new_output(); -opaque_debug::implement!(ReceiverKeys); + self.queue.push_back(Queued { count, sender }); -impl ReceiverKeys { - /// Returns the transfer ID. - pub fn id(&self) -> TransferId { - self.id + return Ok(recv); } +} - /// Derandomizes the receiver's choices. - pub fn derandomize(&mut self, choices: &[bool]) -> Result { - if choices.len() != self.choices.len() { - return Err(ReceiverError::CountMismatch( - self.choices.len(), - choices.len(), - )); - } - - let derandomize = Derandomize { - id: self.id, - count: self.choices.len() as u32, - flip: Vec::::from_lsb0_iter( - self.choices - .iter() - .zip(choices) - .map(|(setup_choice, new_choice)| setup_choice ^ new_choice), - ), - }; - - self.choices.copy_from_slice(choices); - - Ok(derandomize) - } - - /// Decrypts the sender's payload. - pub fn decrypt_blocks(mut self, payload: SenderPayload) -> Result, ReceiverError> { - let SenderPayload { id, ciphertexts } = payload; - - let Ciphertexts::Blocks { ciphertexts } = ciphertexts else { - return Err(ReceiverError::InvalidPayload( - "expected block ciphertexts".to_string(), - )); - }; - - if id != self.id { - return Err(ReceiverError::IdMismatch(self.id, id)); - } +impl RCOTReceiver for Receiver { + type Error = ReceiverError; + type Future = MaybeDone>; - if ciphertexts.len() / 2 != self.keys.len() { - return Err(ReceiverError::CountMismatch( - self.keys.len(), - ciphertexts.len() / 2, + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + if self.state.extended { + return Err(ReceiverError::InvalidState( + "extending more than once is currently disabled".to_string(), )); } - if let Some(tape) = self.tape.take() { - let ts = self.ts.take().expect("ts set if tape is set"); - - let mut hasher = Hasher::default(); - ciphertexts.iter().for_each(|ct| { - hasher.update(&ct.to_bytes()); - }); + self.alloc += count; - tape.lock().unwrap().records.insert( - id, - PayloadRecordNoDelta { - index: self.index, - choices: Vec::from_lsb0_iter(self.choices.iter().copied()), - ts, - keys: self.keys.clone(), - ciphertext_digest: hasher.finalize().into(), - }, - ); - } - - Ok(self - .keys - .into_iter() - .zip(self.choices) - .zip(ciphertexts.chunks(2)) - .map(|((key, c), ct)| if c { key ^ ct[1] } else { key ^ ct[0] }) - .collect()) + Ok(()) } - /// Decrypts the sender's payload. - /// - /// # Verifiable OT - /// - /// Verifiable OT with KOS does not currently support byte payloads, so no record of this payload - /// will be recorded. - pub fn decrypt_bytes( - self, - payload: SenderPayload, - ) -> Result, ReceiverError> { - let SenderPayload { id, ciphertexts } = payload; - - let Ciphertexts::Bytes { - ciphertexts, - iv, - length, - } = ciphertexts - else { - return Err(ReceiverError::InvalidPayload( - "expected byte ciphertexts".to_string(), - )); - }; - - if id != self.id { - return Err(ReceiverError::IdMismatch(self.id, id)); - } - - let length = length as usize; - if length != N { - return Err(ReceiverError::InvalidPayload(format!( - "invalid message length: expected {}, got {}", - N, length - ))); - } + fn available(&self) -> usize { + self.state.msgs.len() + } - if ciphertexts.len() / (2 * length) != self.keys.len() { - return Err(ReceiverError::CountMismatch( - self.keys.len(), - ciphertexts.len() / (2 * length), - )); + fn try_recv_rcot( + &mut self, + count: usize, + ) -> Result, Self::Error> { + if self.available() < count { + return Err(ReceiverError::InsufficientSetup { + expected: count, + actual: self.available(), + }); } - let iv: [u8; 16] = iv - .try_into() - .map_err(|_| ReceiverError::InvalidPayload("invalid iv length".to_string()))?; - - Ok(self - .keys - .into_iter() - .zip(self.choices) - .zip(ciphertexts.chunks(2 * N)) - .map(|((key, c), ct)| { - // Initialize AES-CTR with the key from ROT. - let mut e = Aes128Ctr::new(&key.into(), &iv.into()); - - let mut msg = [0u8; N]; - if c { - msg.copy_from_slice(&ct[N..]) - } else { - msg.copy_from_slice(&ct[..N]) - }; - - e.apply_keystream(&mut msg); - - msg - }) - .collect()) - } + let choices = self.state.choices.drain(..count).collect(); + let keys = self.state.msgs.drain(..count).collect(); - /// Returns the choices and the keys - pub fn take_choices_and_keys(self) -> (Vec, Vec) { - (self.choices, self.keys) + Ok(RCOTReceiverOutput { + id: self.transfer_id.next(), + choices, + msgs: keys, + }) } -} -struct PayloadRecordNoDelta { - /// The starting index for the corresponding OTs. This is used to compute the - /// "tweak" for the randomization. - index: usize, - /// The receiver's choices for the transfer. - choices: Vec, - ts: Vec, - keys: Vec, - ciphertext_digest: [u8; 32], -} + fn queue_recv_rcot(&mut self, count: usize) -> Result { + if self.available() >= count { + let output = self.try_recv_rcot(count)?; + let (sender, recv) = new_output(); + sender.send(output); -opaque_debug::implement!(PayloadRecordNoDelta); - -/// A record of a transfer's payload. -pub struct PayloadRecord { - /// The starting index for the corresponding OTs. This is used to compute the - /// "tweak" for the randomization. - index: usize, - /// The receiver's choices for the transfer. - choices: Vec, - ts: Vec, - keys: Vec, - /// The sender's base OT choice bits. - delta: Block, - ciphertext_digest: [u8; 32], -} - -opaque_debug::implement!(PayloadRecord); + return Ok(recv); + } else if !self.state.extended { + let (sender, recv) = new_output(); -impl PayloadRecord { - /// Checks the purported messages against the record - /// - /// # Arguments - /// - /// * `purported_msgs` - The purported messages sent by the sender. - pub fn verify(self, purported_msgs: &[[Block; 2]]) -> Result<(), ReceiverError> { - let PayloadRecord { - index: counter, - choices, - ts, - keys, - delta, - ciphertext_digest, - } = self; - - // Here we compute the complementary key to the one used earlier in the protocol. - // - // From this, we encrypt the purported messages and check that the ciphertext digests match. - let cipher = &(*FIXED_KEY_AES); - let mut hasher = Hasher::default(); - for (j, (((c, t), key), msgs)) in choices - .iter_lsb0() - .zip(ts) - .zip(keys) - .zip(purported_msgs) - .enumerate() - { - let j = Block::new(((counter + j) as u128).to_be_bytes()); - let key_ = cipher.tccr(j, t ^ delta); - - let (ct0, ct1) = if c { - (msgs[0] ^ key_, msgs[1] ^ key) - } else { - (msgs[0] ^ key, msgs[1] ^ key_) - }; + self.queue.push_back(Queued { count, sender }); - hasher.update(&ct0.to_bytes()); - hasher.update(&ct1.to_bytes()); - } - - let digest: [u8; 32] = hasher.finalize().into(); - - if ciphertext_digest != digest { - return Err(ReceiverVerifyError::InconsistentPayload)?; + return Ok(recv); + } else { + return Err(ReceiverError::InsufficientSetup { + expected: count, + actual: self.available(), + }); } - - Ok(()) } } @@ -679,7 +368,6 @@ pub mod state { impl Sealed for super::Initialized {} impl Sealed for super::Extension {} - impl Sealed for super::Verify {} } /// The receiver's state. @@ -687,10 +375,7 @@ pub mod state { /// The receiver's initial state. #[derive(Default)] - pub struct Initialized { - /// Protocol tape - pub(super) tape: Option>>, - } + pub struct Initialized {} impl State for Initialized {} @@ -698,21 +383,11 @@ pub mod state { /// The receiver's state after the setup phase. /// - /// In this state the receiver performs OT extension (potentially multiple times). Also in this - /// state the receiver sends OT requests. + /// In this state the receiver performs OT extension (potentially multiple + /// times). Also in this state the receiver sends OT requests. pub struct Extension { /// Receiver's rngs - pub(super) rngs: Vec<[ChaCha20Rng; 2]>, - /// Receiver's ts - pub(super) ts: Vec, - /// Receiver's keys - pub(super) keys: Vec, - /// Receiver's random choices - pub(super) choices: Vec, - /// Current OT index - pub(super) index: usize, - /// Current transfer id - pub(super) transfer_id: TransferId, + pub(super) rngs: Vec<[Prg; 2]>, /// Whether extension has occurred yet /// @@ -724,23 +399,13 @@ pub mod state { /// Receiver's unchecked choices pub(super) unchecked_choices: Vec, - /// Protocol tape - pub(super) tape: Option>>, + /// Receiver's chosen messages. + pub(super) msgs: Vec, + /// Receiver's random choices. + pub(super) choices: Vec, } impl State for Extension {} opaque_debug::implement!(Extension); - - /// The receiver's state after receiving the sender's base OT choice bits, a.k.a delta. - pub struct Verify { - /// Protocol tape - pub(super) tape: Arc>, - /// The sender's base OT choice bits. - pub(super) delta: Block, - } - - impl State for Verify {} - - opaque_debug::implement!(Verify); } diff --git a/crates/mpz-ot-core/src/kos/sender.rs b/crates/mpz-ot-core/src/kos/sender.rs index 24917940..4335c6ff 100644 --- a/crates/mpz-ot-core/src/kos/sender.rs +++ b/crates/mpz-ot-core/src/kos/sender.rs @@ -1,19 +1,15 @@ +use std::{collections::VecDeque, mem}; + use crate::{ - kos::{ - extension_matrix_size, - msgs::{Check, Ciphertexts, Extend, SenderPayload}, - Aes128Ctr, Rng, RngSeed, SenderConfig, SenderError, CSP, SSP, - }, - msgs::Derandomize, + kos::{Check, Extend, SenderConfig, SenderError, CSP, SSP}, + rcot::{RCOTSender, RCOTSenderOutput}, TransferId, }; -use cipher::{KeyIvInit, StreamCipher}; -use itybity::ToBits; -use mpz_core::{aes::FIXED_KEY_AES, Block}; +use mpz_common::future::{new_output, MaybeDone, Sender as OutputSender}; +use mpz_core::{prg::Prg, Block}; use rand::{Rng as _, SeedableRng}; -use rand_chacha::ChaCha20Rng; use rand_core::RngCore; cfg_if::cfg_if! { @@ -25,10 +21,20 @@ cfg_if::cfg_if! { } } +#[derive(Debug)] +struct Queued { + count: usize, + sender: OutputSender>, +} + /// KOS15 sender. -#[derive(Debug, Default)] +#[derive(Debug)] pub struct Sender { config: SenderConfig, + alloc: usize, + queue: VecDeque, + transfer_id: TransferId, + delta: Block, state: T, } @@ -42,15 +48,24 @@ where } } -impl Sender { +impl Sender { /// Creates a new Sender /// /// # Arguments /// - /// * `config` - The Sender's configuration - pub fn new(config: SenderConfig) -> Self { + /// * `config` - Sender's configuration. + /// * `delta` - Global COT correlation. + /// * `base_ot` - Base OT. + pub fn new(config: SenderConfig, delta: Block) -> Self { Sender { config, + // We need to extend CSP + SSP OTs for the consistency check. + // Right now we only support one extension, so we just alloc + // them here. + alloc: CSP + SSP, + transfer_id: TransferId::default(), + queue: VecDeque::default(), + delta, state: state::Initialized::default(), } } @@ -61,28 +76,16 @@ impl Sender { /// /// * `delta` - The sender's base OT choice bits /// * `seeds` - The rng seeds chosen during base OT - pub fn setup(self, delta: Block, seeds: [Block; CSP]) -> Sender { - let rngs = seeds - .iter() - .map(|seed| { - // Stretch the Block-sized seed to a 32-byte seed. - let mut seed_ = RngSeed::default(); - seed_ - .iter_mut() - .zip(seed.to_bytes().into_iter().cycle()) - .for_each(|(s, c)| *s = c); - Rng::from_seed(seed_) - }) - .collect(); - + pub fn setup(self, seeds: [Block; CSP]) -> Sender { Sender { config: self.config, + alloc: self.alloc, + transfer_id: self.transfer_id, + queue: self.queue, + delta: self.delta, state: state::Extension { - delta, - rngs, + rngs: seeds.into_iter().map(|seed| Prg::from_seed(seed)).collect(), keys: Vec::default(), - transfer_id: TransferId::default(), - counter: 0, extended: false, unchecked_qs: Vec::default(), }, @@ -91,62 +94,53 @@ impl Sender { } impl Sender { - /// The number of remaining OTs which can be consumed. - pub fn remaining(&self) -> usize { - self.state.keys.len() + /// Returns `true` if the sender wants to extend. + pub fn wants_extend(&self) -> bool { + self.alloc != 0 + } + + /// Returns `true` if the sender wants to run the consistency check. + pub fn wants_check(&self) -> bool { + self.alloc == 0 && !self.state.unchecked_qs.is_empty() } /// Perform the IKNP OT extension. /// - /// The provided count _must_ be a multiple of 64, otherwise an error will be returned. - /// - /// # Sacrificial OTs - /// - /// Performing the consistency check sacrifices 256 OTs, so be sure to extend enough to - /// compensate for this. - /// - /// # Streaming - /// - /// Extension can be performed in a streaming fashion by processing an extension in batches via - /// multiple calls to this method. - /// - /// The freshly extended OTs are not available until after the consistency check has been - /// performed. See [`Sender::check`]. - /// /// # Arguments /// - /// * `count` - The number of additional OTs to extend (must be a multiple of 64). - /// * `extend` - The receiver's setup message. - pub fn extend(&mut self, count: usize, extend: Extend) -> Result<(), SenderError> { + /// * `extend` - Extend message from the receiver. + pub fn extend(&mut self, extend: Extend) -> Result<(), SenderError> { if self.state.extended { return Err(SenderError::InvalidState( "extending more than once is currently disabled".to_string(), )); } - if count % 64 != 0 { - return Err(SenderError::InvalidCount(count)); + let Extend { count, us } = extend; + + let expected_count = self.config.batch_size().min(self.alloc); + // round up count to a multiple of 64 + let expected_count = (expected_count + 63) & !63; + if count != expected_count { + return Err(SenderError::CountMismatch { + expected: expected_count, + actual: count, + }); } const NROWS: usize = CSP; let row_width = count / 8; - let Extend { us } = extend; - - if us.len() != extension_matrix_size(count) { - return Err(SenderError::InvalidExtend); - } - let mut qs = vec![0u8; NROWS * row_width]; cfg_if::cfg_if! { if #[cfg(feature = "rayon")] { - let iter = self.state.delta + let iter = self.delta .par_iter_lsb0() .zip(self.state.rngs.par_iter_mut()) .zip(qs.par_chunks_exact_mut(row_width)) .zip(us.par_chunks_exact(row_width)); } else { - let iter = self.state.delta + let iter = self.delta .iter_lsb0() .zip(self.state.rngs.iter_mut()) .zip(qs.chunks_exact_mut(row_width)) @@ -159,7 +153,8 @@ impl Sender { iter.for_each(|(((b, rng), q), u)| { // Reuse `q` to avoid memory allocation for tⁱ_∆ᵢ rng.fill_bytes(q); - // If `b` (i.e. ∆ᵢ) is true, xor `u` into `q`, otherwise xor 0 into `q` (constant time). + // If `b` (i.e. ∆ᵢ) is true, xor `u` into `q`, otherwise xor 0 into `q` + // (constant time). let u = if b { u } else { &zero }; q.iter_mut().zip(u).for_each(|(q, u)| *q ^= u); }); @@ -173,6 +168,7 @@ impl Sender { let q: Block = q.try_into().unwrap(); q })); + self.alloc = self.alloc.saturating_sub(count); Ok(()) } @@ -181,36 +177,31 @@ impl Sender { /// /// See section 3.1 of the paper for more details. /// - /// # Sacrificial OTs - /// - /// Performing this check sacrifices 256 OTs for the consistency check, so be sure to - /// extend enough OTs to compensate for this. - /// /// # ⚠️ Warning ⚠️ /// - /// The provided seed must be unbiased! It should be generated using a secure - /// coin-toss protocol **after** the receiver has sent their extension message, ie - /// after they have already committed to their choice vectors. + /// The provided seed must be unbiased! It should be generated using a + /// secure coin-toss protocol **after** the receiver has sent their + /// extension message, ie after they have already committed to their + /// choice vectors. /// /// # Arguments /// /// * `chi_seed` - The seed used to generate the consistency check weights. /// * `receiver_check` - The receiver's consistency check message. pub fn check(&mut self, chi_seed: Block, receiver_check: Check) -> Result<(), SenderError> { + if !self.wants_check() { + return Err(SenderError::InvalidState("not ready to check".to_string())); + } + // Make sure we have enough sacrificial OTs to perform the consistency check. if self.state.unchecked_qs.len() < CSP + SSP { - return Err(SenderError::InsufficientSetup( - CSP + SSP, - self.state.unchecked_qs.len(), - )); + return Err(SenderError::InsufficientSetup { + expected: CSP + SSP, + actual: self.state.unchecked_qs.len(), + }); } - let mut seed = RngSeed::default(); - seed.iter_mut() - .zip(chi_seed.to_bytes().into_iter().cycle()) - .for_each(|(s, c)| *s = c); - - let mut rng = Rng::from_seed(seed); + let mut rng = Prg::from_seed(chi_seed); let mut unchecked_qs = std::mem::take(&mut self.state.unchecked_qs); @@ -242,7 +233,7 @@ impl Sender { } let Check { x, t0, t1 } = receiver_check; - let tmp = x.clmul(self.state.delta); + let tmp = x.clmul(self.delta); let check = (check.0 ^ tmp.0, check.1 ^ tmp.1); // The Receiver is malicious. @@ -256,198 +247,120 @@ impl Sender { let nrows = unchecked_qs.len() - (CSP + SSP); unchecked_qs.truncate(nrows); - // Figure 7, "Randomization" - cfg_if::cfg_if! { - if #[cfg(feature = "rayon")] { - let iter = unchecked_qs.into_par_iter().enumerate(); - } else { - let iter = unchecked_qs.into_iter().enumerate(); + self.state.keys.extend_from_slice(&unchecked_qs); + self.state.extended = true; + + // Resolve any queued transfers. + if !self.queue.is_empty() { + let mut i = 0; + for Queued { count, sender } in mem::take(&mut self.queue) { + let keys = self.state.keys[i..i + count].to_vec(); + i += count; + sender.send(RCOTSenderOutput { + id: self.transfer_id.next(), + keys, + }); } - } - let cipher = &(*FIXED_KEY_AES); - let keys = iter - .map(|(j, q)| { - let j = Block::new(((self.state.counter + j) as u128).to_be_bytes()); + self.state.keys.drain(..i); + } - let k0 = cipher.tccr(j, q); - let k1 = cipher.tccr(j, q ^ self.state.delta); + Ok(()) + } +} - [k0, k1] - }) - .collect::>(); +impl RCOTSender for Sender { + type Error = SenderError; + type Future = MaybeDone>; - self.state.counter += keys.len(); - self.state.keys.extend(keys); - self.state.extended = true; + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.alloc += count; Ok(()) } - /// Reserves a set of keys which can be used to encrypt a payload later. - /// - /// # Arguments - /// - /// * `count` - The number of keys to reserve. - pub fn keys(&mut self, count: usize) -> Result { - if count > self.state.keys.len() { - return Err(SenderError::InsufficientSetup(count, self.state.keys.len())); - } + fn available(&self) -> usize { + 0 + } - let id = self.state.transfer_id.next(); + fn delta(&self) -> Block { + self.delta + } - Ok(SenderKeys { - id, - keys: self.state.keys.drain(..count).collect(), - derandomize: None, - }) + fn try_send_rcot(&mut self, _count: usize) -> Result, Self::Error> { + return Err(SenderError::InvalidState( + "sender has not been setup yet".to_string(), + )); } -} -/// KOS sender's keys for a single transfer. -/// -/// Returned by the [`Sender::keys`] method, used in cases where the sender -/// wishes to reserve a set of keys for use later, while still being able to process -/// other payloads. -pub struct SenderKeys { - /// Transfer ID - id: TransferId, - /// Encryption keys - keys: Vec<[Block; 2]>, - /// Derandomization - derandomize: Option, -} + fn queue_send_rcot(&mut self, count: usize) -> Result { + let (sender, recv) = new_output(); + + self.queue.push_back(Queued { count, sender }); -impl SenderKeys { - /// Returns the transfer ID. - pub fn id(&self) -> TransferId { - self.id + return Ok(recv); } +} - /// Applies Beaver derandomization to correct the receiver's choices made during extension. - pub fn derandomize(&mut self, derandomize: Derandomize) -> Result<(), SenderError> { - if derandomize.id != self.id { - return Err(SenderError::IdMismatch(self.id, derandomize.id)); - } +impl RCOTSender for Sender { + type Error = SenderError; + type Future = MaybeDone>; - if derandomize.count as usize != self.keys.len() { - return Err(SenderError::CountMismatch( - self.keys.len(), - derandomize.count as usize, + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + if self.state.extended { + return Err(SenderError::InvalidState( + "extending more than once is currently disabled".to_string(), )); } - self.derandomize = Some(derandomize); + self.alloc += count; Ok(()) } - /// Encrypts the provided messages using the keys. - /// - /// # Arguments - /// - /// * `msgs` - The messages to encrypt - pub fn encrypt_blocks(self, msgs: &[[Block; 2]]) -> Result { - if msgs.len() != self.keys.len() { - return Err(SenderError::InsufficientSetup(msgs.len(), self.keys.len())); - } + fn available(&self) -> usize { + self.state.keys.len() + } - // If we have derandomization, use it to correct the receiver's choices, else we use - // default - let flip = self - .derandomize - .map(|x| x.flip) - .unwrap_or_else(|| vec![0; self.keys.len() / 8 + 1]); - - // Encrypt the chosen messages using the generated keys from ROT. - let ciphertexts = self - .keys - .into_iter() - .zip(msgs) - .zip(flip.iter_lsb0()) - .flat_map(|(([k0, k1], [m0, m1]), flip)| { - // Use Beaver derandomization to correct the receiver's choices - // from the extension phase. - if flip { - [k1 ^ *m0, k0 ^ *m1] - } else { - [k0 ^ *m0, k1 ^ *m1] - } - }) - .collect(); - - Ok(SenderPayload { - id: self.id, - ciphertexts: Ciphertexts::Blocks { ciphertexts }, - }) + fn delta(&self) -> Block { + self.delta } - /// Encrypts the provided messages using the keys. - /// - /// # Arguments - /// - /// * `msgs` - The messages to encrypt - pub fn encrypt_bytes( - self, - msgs: &[[[u8; N]; 2]], - ) -> Result { - if msgs.len() != self.keys.len() { - return Err(SenderError::InsufficientSetup(msgs.len(), self.keys.len())); + fn try_send_rcot(&mut self, count: usize) -> Result, Self::Error> { + if self.available() < count { + return Err(SenderError::InsufficientSetup { + expected: count, + actual: self.available(), + }); } - // Generate a random IV which is used for all messages. - // This is safe because every message is encrypted with a different key. - let iv: [u8; 16] = rand::thread_rng().gen(); - - // If we have derandomization, use it to correct the receiver's choices, else we use - // default - let flip = self - .derandomize - .map(|x| x.flip) - .unwrap_or_else(|| vec![0; self.keys.len() / 8 + 1]); - - // Encrypt the chosen messages using the generated keys from ROT. - let ciphertexts = self - .keys - .into_iter() - .zip(msgs) - .zip(flip.iter_lsb0()) - .flat_map(|(([k0, k1], [m0, m1]), flip)| { - // Initialize AES-CTR with the keys from ROT. - let mut e0 = Aes128Ctr::new(&k0.into(), &iv.into()); - let mut e1 = Aes128Ctr::new(&k1.into(), &iv.into()); - - let mut m0 = *m0; - let mut m1 = *m1; - - // Use Beaver derandomization to correct the receiver's choices - // from the extension phase. - if flip { - e1.apply_keystream(&mut m0); - e0.apply_keystream(&mut m1); - } else { - e0.apply_keystream(&mut m0); - e1.apply_keystream(&mut m1); - } - - [m0, m1] - }) - .flatten() - .collect(); - - Ok(SenderPayload { - id: self.id, - ciphertexts: Ciphertexts::Bytes { - ciphertexts, - iv: iv.to_vec(), - length: N as u32, - }, + let keys = self.state.keys.drain(..count).collect(); + + Ok(RCOTSenderOutput { + id: self.transfer_id.next(), + keys, }) } - /// Returns the keys - pub fn take_keys(self) -> Vec<[Block; 2]> { - self.keys + fn queue_send_rcot(&mut self, count: usize) -> Result { + if self.available() >= count { + let output = self.try_send_rcot(count)?; + let (sender, recv) = new_output(); + sender.send(output); + + return Ok(recv); + } else if !self.state.extended { + let (sender, recv) = new_output(); + + self.queue.push_back(Queued { count, sender }); + + return Ok(recv); + } else { + return Err(SenderError::InsufficientSetup { + expected: count, + actual: self.available(), + }); + } } } @@ -475,28 +388,19 @@ pub mod state { /// The sender's state after the setup phase. /// - /// In this state the sender performs OT extension (potentially multiple times). Also in this - /// state the sender responds to OT requests. + /// In this state the sender performs OT extension (potentially multiple + /// times). Also in this state the sender responds to OT requests. pub struct Extension { - /// Sender's base OT choices - pub(super) delta: Block, /// Receiver's rngs seeded from seeds obliviously received from base OT - pub(super) rngs: Vec, - /// Sender's keys - pub(super) keys: Vec<[Block; 2]>, - - /// Current transfer id - pub(super) transfer_id: TransferId, - /// Current OT counter - pub(super) counter: usize, - + pub(super) rngs: Vec, /// Whether extension has occurred yet /// /// This is to prevent the receiver from extending twice pub(super) extended: bool, - /// Sender's unchecked qs pub(super) unchecked_qs: Vec, + /// Sender's keys + pub(super) keys: Vec, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/lib.rs b/crates/mpz-ot-core/src/lib.rs index 8dd77287..98c0c51c 100644 --- a/crates/mpz-ot-core/src/lib.rs +++ b/crates/mpz-ot-core/src/lib.rs @@ -1,12 +1,14 @@ -//! Low-level crate containing core functionalities for oblivious transfer protocols. +//! Low-level crate containing core functionalities for oblivious transfer +//! protocols. //! -//! This crate is not intended to be used directly. Instead, use the higher-level APIs provided by -//! the `mpz-ot` crate. +//! This crate is not intended to be used directly. Instead, use the +//! higher-level APIs provided by the `mpz-ot` crate. //! //! # ⚠️ Warning ⚠️ //! -//! Some implementations make assumptions about invariants which may not be checked if using these -//! low-level APIs naively. Failing to uphold these invariants may result in security vulnerabilities. +//! Some implementations make assumptions about invariants which may not be +//! checked if using these low-level APIs naively. Failing to uphold these +//! invariants may result in security vulnerabilities. //! //! USE AT YOUR OWN RISK. @@ -19,13 +21,16 @@ clippy::all )] +use mpz_core::bitvec::BitVec; use serde::{Deserialize, Serialize}; pub mod chou_orlandi; -pub mod ferret; +pub mod cot; pub mod ideal; pub mod kos; -pub mod msgs; +pub mod ot; +pub mod rcot; +pub mod rot; #[cfg(any(test, feature = "test-utils"))] pub mod test; @@ -44,6 +49,10 @@ impl std::fmt::Display for TransferId { } impl TransferId { + pub(crate) fn as_u64(&self) -> u64 { + self.0 + } + /// Returns the current transfer ID, incrementing `self` in-place. pub(crate) fn next(&mut self) -> Self { let id = *self; @@ -52,112 +61,10 @@ impl TransferId { } } -/// The output the sender receives from the COT functionality. -#[derive(Debug)] -pub struct COTSenderOutput { - /// The transfer id. - pub id: TransferId, - /// The `0-bit` messages. - pub msgs: Vec, -} - -/// The output the receiver receives from the COT functionality. -#[derive(Debug)] -pub struct COTReceiverOutput { - /// The transfer id. - pub id: TransferId, - /// The chosen messages. - pub msgs: Vec, -} - -/// The output the sender receives from the random COT functionality. -#[derive(Debug)] -pub struct RCOTSenderOutput { - /// The transfer id. - pub id: TransferId, - /// The `0-bit` messages. - pub msgs: Vec, -} - -/// The output the receiver receives from the random COT functionality. -#[derive(Debug)] -pub struct RCOTReceiverOutput { - /// The transfer id. - pub id: TransferId, - /// The choice bits. - pub choices: Vec, - /// The chosen messages. - pub msgs: Vec, -} - -/// The output the sender receives from the ROT functionality. -#[derive(Debug)] -pub struct ROTSenderOutput { - /// The transfer id. - pub id: TransferId, - /// The random messages. - pub msgs: Vec, -} - -/// The output the receiver receives from the ROT functionality. -#[derive(Debug)] -pub struct ROTReceiverOutput { - /// The transfer id. - pub id: TransferId, - /// The choice bits. - pub choices: Vec, - /// The chosen messages. - pub msgs: Vec, -} - -/// The output the sender receives from the OT functionality. -#[derive(Debug)] -pub struct OTSenderOutput { - /// The transfer id. - pub id: TransferId, -} - -/// The output the receiver receives from the OT functionality. -#[derive(Debug)] -pub struct OTReceiverOutput { - /// The transfer id. - pub id: TransferId, - /// The chosen messages. - pub msgs: Vec, -} - -/// The output that sender receives from the SPCOT functionality. -#[derive(Debug)] -pub struct SPCOTSenderOutput { - /// The transfer id. - pub id: TransferId, - /// The random blocks that sender receives from the SPCOT functionality. - pub v: Vec>, -} - -/// The output that receiver receives from the SPCOT functionality. -#[derive(Debug)] -pub struct SPCOTReceiverOutput { - /// The transfer id. - pub id: TransferId, - /// The random blocks that receiver receives from the SPCOT functionality. - pub w: Vec>, -} - -/// The output that sender receives from the MPCOT functionality. -#[derive(Debug)] -pub struct MPCOTSenderOutput { - /// The transfer id. - pub id: TransferId, - /// The random blocks that sender receives from the MPCOT functionality. - pub s: Vec, -} - -/// The output that receiver receives from the MPCOT functionality. +/// A message sent by the receiver which a sender can use to perform +/// Beaver derandomization. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct MPCOTReceiverOutput { - /// The transfer id. - pub id: TransferId, - /// The random blocks that receiver receives from the MPCOT functionality. - pub r: Vec, +pub struct Derandomize { + /// Correction bits + pub flip: BitVec, } diff --git a/crates/mpz-ot-core/src/msgs.rs b/crates/mpz-ot-core/src/msgs.rs deleted file mode 100644 index 809443a3..00000000 --- a/crates/mpz-ot-core/src/msgs.rs +++ /dev/null @@ -1,69 +0,0 @@ -//! General OT message types - -use serde::{Deserialize, Serialize}; - -use crate::TransferId; - -/// A message sent by the receiver which a sender can use to perform -/// Beaver derandomization. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(try_from = "UncheckedDerandomize")] -pub struct Derandomize { - /// Transfer ID - pub id: TransferId, - /// The number of choices to derandomize. - pub count: u32, - /// Correction bits - pub flip: Vec, -} - -#[derive(Debug, Deserialize)] -struct UncheckedDerandomize { - id: TransferId, - count: u32, - flip: Vec, -} - -impl TryFrom for Derandomize { - type Error = std::io::Error; - - fn try_from(value: UncheckedDerandomize) -> Result { - // Divide by 8, rounding up - let expected_len = (value.count as usize + 7) / 8; - - if value.flip.len() != expected_len { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "flip length does not match count", - )); - } - - Ok(Derandomize { - id: value.id, - count: value.count, - flip: value.flip, - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_unchecked_derandomize() { - assert!(Derandomize::try_from(UncheckedDerandomize { - id: TransferId::default(), - count: 0, - flip: vec![], - }) - .is_ok()); - - assert!(Derandomize::try_from(UncheckedDerandomize { - id: TransferId::default(), - count: 9, - flip: vec![0], - }) - .is_err()); - } -} diff --git a/crates/mpz-ot-core/src/ot.rs b/crates/mpz-ot-core/src/ot.rs new file mode 100644 index 00000000..45fe4213 --- /dev/null +++ b/crates/mpz-ot-core/src/ot.rs @@ -0,0 +1,57 @@ +//! Chosen-message oblivious transfer. + +use mpz_common::future::Output; + +use crate::TransferId; + +/// Output the sender receives from the OT functionality. +#[derive(Debug)] +pub struct OTSenderOutput { + /// Transfer id. + pub id: TransferId, +} + +/// Oblivious transfer sender. +pub trait OTSender { + /// Error type. + type Error: std::error::Error + Send + Sync + 'static; + /// Future type. + type Future: Output; + + /// Allocates `count` OTs for preprocessing. + fn alloc(&mut self, count: usize) -> Result<(), Self::Error>; + + /// Queues sending of OTs. + /// + /// # Arguments + /// + /// * `msgs` - Messages to send. + fn queue_send_ot(&mut self, msgs: &[[T; 2]]) -> Result; +} + +/// Output the receiver receives from the OT functionality. +#[derive(Debug)] +pub struct OTReceiverOutput { + /// Transfer id. + pub id: TransferId, + /// Chosen messages. + pub msgs: Vec, +} + +/// Oblivious transfer receiver. +pub trait OTReceiver { + /// Error type. + type Error: std::error::Error + Send + Sync + 'static; + /// Future type. + type Future: Output>; + + /// Allocates `count` OTs for preprocessing. + fn alloc(&mut self, count: usize) -> Result<(), Self::Error>; + + /// Queues receiving of OTs. + /// + /// # Arguments + /// + /// * `choices` - OT choices. + fn queue_recv_ot(&mut self, choices: &[T]) -> Result; +} diff --git a/crates/mpz-ot-core/src/rcot.rs b/crates/mpz-ot-core/src/rcot.rs new file mode 100644 index 00000000..e7137388 --- /dev/null +++ b/crates/mpz-ot-core/src/rcot.rs @@ -0,0 +1,84 @@ +//! Random correlated oblivious transfer. + +use mpz_common::future::Output; + +use crate::TransferId; + +/// Output the sender receives from the random COT functionality. +#[derive(Debug)] +pub struct RCOTSenderOutput { + /// Transfer id. + pub id: TransferId, + /// Random keys. + pub keys: Vec, +} + +/// Random correlated oblivious transfer sender. +pub trait RCOTSender { + /// Error type. + type Error: std::error::Error + Send + Sync + 'static; + /// Future type. + type Future: Output>; + + /// Allocates `count` RCOTs for preprocessing. + fn alloc(&mut self, count: usize) -> Result<(), Self::Error>; + + /// Returns the number of available RCOTs. + fn available(&self) -> usize; + + /// Returns the global correlation key, `delta`. + fn delta(&self) -> T; + + /// Returns preprocessed RCOTs, if available. + /// + /// # Arguments + /// + /// * `count` - Number of preprocessed RCOTs to try to consume. + fn try_send_rcot(&mut self, count: usize) -> Result, Self::Error>; + + /// Queues `count` RCOTs for sending. + /// + /// # Arguments + /// + /// * `count` - Number of RCOTs to queue for sending. + fn queue_send_rcot(&mut self, count: usize) -> Result; +} + +/// Output the receiver receives from the random COT functionality. +#[derive(Debug)] +pub struct RCOTReceiverOutput { + /// Transfer id. + pub id: TransferId, + /// Choice bits. + pub choices: Vec, + /// Chosen messages. + pub msgs: Vec, +} + +/// Random correlated oblivious transfer receiver. +pub trait RCOTReceiver { + /// Error type. + type Error: std::error::Error + Send + Sync + 'static; + /// Future type. + type Future: Output>; + + /// Allocates `count` RCOTs for preprocessing. + fn alloc(&mut self, count: usize) -> Result<(), Self::Error>; + + /// Returns the number of available RCOTs. + fn available(&self) -> usize; + + /// Returns preprocessed RCOTs, if available. + /// + /// # Arguments + /// + /// * `count` - Number of preprocessed RCOTs to try to consume. + fn try_recv_rcot(&mut self, count: usize) -> Result, Self::Error>; + + /// Queues `count` RCOTs for receiving. + /// + /// # Arguments + /// + /// * `count` - Number of RCOTs to queue for receiving. + fn queue_recv_rcot(&mut self, count: usize) -> Result; +} diff --git a/crates/mpz-ot-core/src/rot.rs b/crates/mpz-ot-core/src/rot.rs new file mode 100644 index 00000000..f5128da3 --- /dev/null +++ b/crates/mpz-ot-core/src/rot.rs @@ -0,0 +1,87 @@ +//! Random oblivious transfer. + +mod any; +mod randomize; + +pub use any::{AnyReceiver, AnySender}; +pub use randomize::{RandomizeRCOTReceiver, RandomizeRCOTSender}; + +use mpz_common::future::Output; + +use crate::TransferId; + +/// Output the sender receives from the ROT functionality. +#[derive(Debug)] +pub struct ROTSenderOutput { + /// Transfer id. + pub id: TransferId, + /// Random keys. + pub keys: Vec, +} + +/// Random oblivious transfer sender. +pub trait ROTSender { + /// Error type. + type Error: std::error::Error + Send + Sync + 'static; + /// Future type. + type Future: Output>; + + /// Allocates `count` ROTs for preprocessing. + fn alloc(&mut self, count: usize) -> Result<(), Self::Error>; + + /// Returns the number of available ROTs. + fn available(&self) -> usize; + + /// Returns preprocessed ROTs, if available. + /// + /// # Arguments + /// + /// * `count` - Number of preprocessed ROTs to try to consume. + fn try_send_rot(&mut self, count: usize) -> Result, Self::Error>; + + /// Queues sending of ROTs. + /// + /// # Arguments + /// + /// * `count` - Number of ROTs to send. + fn queue_send_rot(&mut self, count: usize) -> Result; +} + +/// Output the receiver receives from the ROT functionality. +#[derive(Debug)] +pub struct ROTReceiverOutput { + /// Transfer id. + pub id: TransferId, + /// Random choices. + pub choices: Vec, + /// Chosen msgs. + pub msgs: Vec, +} + +/// Random oblivious transfer receiver. +pub trait ROTReceiver { + /// Error type. + type Error: std::error::Error + Send + Sync + 'static; + /// Future type. + type Future: Output>; + + /// Allocates `count` ROTs for preprocessing. + fn alloc(&mut self, count: usize) -> Result<(), Self::Error>; + + /// Returns the number of available ROTs. + fn available(&self) -> usize; + + /// Returns preprocessed ROTs, if available. + /// + /// # Arguments + /// + /// * `count` - Number of preprocessed ROTs to try to consume. + fn try_recv_rot(&mut self, count: usize) -> Result, Self::Error>; + + /// Queues receiving of ROTs. + /// + /// # Arguments + /// + /// * `count` - Number of ROTs to receive. + fn queue_recv_rot(&mut self, count: usize) -> Result; +} diff --git a/crates/mpz-ot-core/src/rot/any.rs b/crates/mpz-ot-core/src/rot/any.rs new file mode 100644 index 00000000..6b946687 --- /dev/null +++ b/crates/mpz-ot-core/src/rot/any.rs @@ -0,0 +1,147 @@ +use mpz_common::future::{Map, OutputExt}; +use mpz_core::{prg::Prg, Block}; +use rand::{distributions::Standard, prelude::Distribution, Rng}; + +use crate::rot::{ROTReceiver, ROTReceiverOutput, ROTSender, ROTSenderOutput}; + +/// A ROT sender which sends any type implementing `rand` traits. +#[derive(Debug)] +pub struct AnySender { + rot: T, +} + +impl AnySender { + /// Creates a new `AnySender`. + pub fn new(rot: T) -> Self { + Self { rot } + } + + /// Returns a reference to the inner sender. + pub fn rot(&self) -> &T { + &self.rot + } + + /// Returns a mutable reference to the inner sender. + pub fn rot_mut(&mut self) -> &mut T { + &mut self.rot + } + + /// Returns the inner sender. + pub fn into_inner(self) -> T { + self.rot + } +} + +impl ROTSender<[U; 2]> for AnySender +where + T: ROTSender<[Block; 2]>, + Standard: Distribution, +{ + type Error = T::Error; + type Future = Map) -> ROTSenderOutput<[U; 2]>>; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.rot.alloc(count) + } + + fn available(&self) -> usize { + self.rot.available() + } + + fn try_send_rot(&mut self, count: usize) -> Result, Self::Error> { + self.rot.try_send_rot(count).map(map_sender) + } + + fn queue_send_rot(&mut self, count: usize) -> Result { + self.rot + .queue_send_rot(count) + .map(|output| output.map(map_sender as fn(_) -> _)) + } +} + +fn map_sender(output: ROTSenderOutput<[Block; 2]>) -> ROTSenderOutput<[T; 2]> +where + Standard: Distribution, +{ + let ROTSenderOutput { id, keys } = output; + let keys = keys + .into_iter() + .map(|[k0, k1]| { + let mut prg_0 = Prg::new_with_seed(k0.to_bytes()); + let mut prg_1 = Prg::new_with_seed(k1.to_bytes()); + + [prg_0.gen(), prg_1.gen()] + }) + .collect(); + ROTSenderOutput { id, keys } +} + +/// A ROT receiver which receives any type implementing `rand` traits. +#[derive(Debug)] +pub struct AnyReceiver { + rot: T, +} + +impl AnyReceiver { + /// Creates a new `AnyReceiver`. + pub fn new(rot: T) -> Self { + Self { rot } + } + + /// Returns a reference to the inner receiver. + pub fn rot(&self) -> &T { + &self.rot + } + + /// Returns a mutable reference to the inner receiver. + pub fn rot_mut(&mut self) -> &mut T { + &mut self.rot + } + + /// Returns the inner receiver. + pub fn into_inner(self) -> T { + self.rot + } +} + +impl ROTReceiver for AnyReceiver +where + T: ROTReceiver, + Standard: Distribution, +{ + type Error = T::Error; + type Future = Map) -> ROTReceiverOutput>; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.rot.alloc(count) + } + + fn available(&self) -> usize { + self.rot.available() + } + + fn try_recv_rot(&mut self, count: usize) -> Result, Self::Error> { + self.rot.try_recv_rot(count).map(map_receiver) + } + + fn queue_recv_rot(&mut self, count: usize) -> Result { + self.rot + .queue_recv_rot(count) + .map(|output| output.map(map_receiver as fn(_) -> _)) + } +} + +fn map_receiver(output: ROTReceiverOutput) -> ROTReceiverOutput +where + Standard: Distribution, +{ + let ROTReceiverOutput { id, choices, msgs } = output; + let msgs = msgs + .into_iter() + .map(|msg| { + let mut prg = Prg::new_with_seed(msg.to_bytes()); + prg.gen() + }) + .collect(); + ROTReceiverOutput { id, choices, msgs } +} diff --git a/crates/mpz-ot-core/src/rot/randomize.rs b/crates/mpz-ot-core/src/rot/randomize.rs new file mode 100644 index 00000000..0ba31faf --- /dev/null +++ b/crates/mpz-ot-core/src/rot/randomize.rs @@ -0,0 +1,232 @@ +use mpz_common::future::{Map, OutputExt}; +use mpz_core::{aes::FIXED_KEY_AES, Block}; + +use crate::{ + rcot::{RCOTReceiver, RCOTReceiverOutput, RCOTSender, RCOTSenderOutput}, + rot::{ROTReceiver, ROTReceiverOutput, ROTSender, ROTSenderOutput}, +}; + +// We have to Box the closure because it's not name-able in the associated type. +type FnSender = Box) -> ROTSenderOutput<[Block; 2]>>; + +/// ROT sender which randomizes the output of an RCOT sender. +#[derive(Debug)] +pub struct RandomizeRCOTSender { + rcot: T, +} + +impl RandomizeRCOTSender { + /// Creates a new [`RandomizeRCOTSender`]. + /// + /// # Arguments + /// + /// * `rcot` - RCOT sender. + pub fn new(rcot: T) -> Self { + Self { rcot } + } + + /// Returns a reference to the RCOT sender. + pub fn rcot(&self) -> &T { + &self.rcot + } + + /// Returns a mutable reference to the RCOT sender. + pub fn rcot_mut(&mut self) -> &mut T { + &mut self.rcot + } + + /// Returns the RCOT sender. + pub fn into_inner(self) -> T { + self.rcot + } +} + +impl ROTSender<[Block; 2]> for RandomizeRCOTSender +where + T: RCOTSender, +{ + type Error = T::Error; + type Future = Map; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.rcot.alloc(count) + } + + fn available(&self) -> usize { + self.rcot.available() + } + + fn try_send_rot(&mut self, count: usize) -> Result, Self::Error> { + let delta = self.rcot.delta(); + self.rcot + .try_send_rcot(count) + .map(|output| randomize_sender(delta, output)) + } + + fn queue_send_rot(&mut self, count: usize) -> Result { + let delta = self.rcot.delta(); + self.rcot.queue_send_rcot(count).map(move |output| { + output.map(Box::new(move |output| randomize_sender(delta, output)) as FnSender) + }) + } +} + +fn randomize_sender(delta: Block, output: RCOTSenderOutput) -> ROTSenderOutput<[Block; 2]> { + let RCOTSenderOutput { id, keys } = output; + + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")] { + use rayon::prelude::*; + let iter = keys.into_par_iter().enumerate(); + } else { + let iter = keys.into_iter().enumerate(); + } + } + + let cipher = &(*FIXED_KEY_AES); + let keys = iter + .map(|(i, key)| { + // Transfer ID ensures a unique tweak for each ROT. + let j = ((id.as_u64() as u128) << 64) + (i as u128); + let j = Block::new(j.to_be_bytes()); + + let k0 = cipher.tccr(j, key); + let k1 = cipher.tccr(j, key ^ delta); + + [k0, k1] + }) + .collect(); + + ROTSenderOutput { id, keys } +} + +/// ROT receiver which randomizes the output of an RCOT receiver. +#[derive(Debug)] +pub struct RandomizeRCOTReceiver { + rcot: T, +} + +impl RandomizeRCOTReceiver { + /// Creates a new [`RandomizeRCOTReceiver`]. + /// + /// # Arguments + /// + /// * `rcot` - RCOT receiver. + pub fn new(rcot: T) -> Self { + Self { rcot } + } + + /// Returns a reference to the RCOT receiver. + pub fn rcot(&self) -> &T { + &self.rcot + } + + /// Returns a mutable reference to the RCOT receiver. + pub fn rcot_mut(&mut self) -> &mut T { + &mut self.rcot + } + + /// Returns the RCOT receiver. + pub fn into_inner(self) -> T { + self.rcot + } +} + +impl ROTReceiver for RandomizeRCOTReceiver +where + T: RCOTReceiver, +{ + type Error = T::Error; + type Future = + Map) -> ROTReceiverOutput>; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.rcot.alloc(count) + } + + fn available(&self) -> usize { + self.rcot.available() + } + + fn try_recv_rot( + &mut self, + count: usize, + ) -> Result, Self::Error> { + self.rcot.try_recv_rcot(count).map(randomize_receiver) + } + + fn queue_recv_rot(&mut self, count: usize) -> Result { + self.rcot + .queue_recv_rcot(count) + .map(|output| output.map(randomize_receiver as fn(_) -> _)) + } +} + +fn randomize_receiver(output: RCOTReceiverOutput) -> ROTReceiverOutput { + let RCOTReceiverOutput { + id, + choices, + mut msgs, + } = output; + + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")] { + use rayon::prelude::*; + let iter = msgs.par_iter_mut().enumerate(); + } else { + let iter = msgs.iter_mut().enumerate(); + } + } + + let cipher = &(*FIXED_KEY_AES); + iter.for_each(|(i, msg)| { + // Transfer ID ensures a unique tweak for each ROT. + let j = ((id.as_u64() as u128) << 64) + (i as u128); + let j = Block::new(j.to_be_bytes()); + + *msg = cipher.tccr(j, *msg); + }); + + ROTReceiverOutput { id, choices, msgs } +} + +#[cfg(test)] +mod tests { + use mpz_common::future::Output; + use rand::{rngs::StdRng, Rng, SeedableRng}; + + use super::*; + + use crate::{ideal::rcot::IdealRCOT, test::assert_rot}; + + #[test] + fn test_randomize_rcot() { + let mut rng = StdRng::seed_from_u64(0); + let rcot = IdealRCOT::new(rng.gen(), rng.gen()); + + let mut sender = RandomizeRCOTSender::new(rcot.clone()); + let mut receiver = RandomizeRCOTReceiver::new(rcot); + + let count = 128; + sender.alloc(count).unwrap(); + let mut sender_output = sender.queue_send_rot(count).unwrap(); + + receiver.alloc(count).unwrap(); + let mut receiver_output = receiver.queue_recv_rot(count).unwrap(); + + sender.rcot_mut().flush().unwrap(); + + let ROTSenderOutput { + id: sender_id, + keys, + } = sender_output.try_recv().unwrap().unwrap(); + let ROTReceiverOutput { + id: receiver_id, + choices, + msgs, + } = receiver_output.try_recv().unwrap().unwrap(); + + assert_eq!(sender_id, receiver_id); + assert_rot(&choices, &keys, &msgs); + } +} diff --git a/crates/mpz-ot-core/src/test.rs b/crates/mpz-ot-core/src/test.rs index cc65d87e..74c1a282 100644 --- a/crates/mpz-ot-core/src/test.rs +++ b/crates/mpz-ot-core/src/test.rs @@ -2,6 +2,20 @@ use mpz_core::Block; +/// Asserts the correctness of oblivious transfer. +pub fn assert_ot(choices: &[bool], msgs: &[[Block; 2]], received: &[Block]) { + assert!(choices + .iter() + .zip(msgs.iter().zip(received)) + .all(|(&choice, (&msg, &received))| { + if choice { + received == msg[1] + } else { + received == msg[0] + } + })); +} + /// Asserts the correctness of correlated oblivious transfer. pub fn assert_cot(delta: Block, choices: &[bool], msgs: &[Block], received: &[Block]) { assert!(choices diff --git a/crates/mpz-ot/Cargo.toml b/crates/mpz-ot/Cargo.toml index cee9db73..3e40b6ec 100644 --- a/crates/mpz-ot/Cargo.toml +++ b/crates/mpz-ot/Cargo.toml @@ -11,8 +11,9 @@ name = "mpz_ot" [features] default = ["rayon"] -rayon = ["mpz-ot-core/rayon"] +rayon = ["mpz-ot-core/rayon", "mpz-common/rayon"] ideal = ["mpz-common/ideal"] +test-utils = ["mpz-ot-core/test-utils"] [dependencies] mpz-core.workspace = true @@ -38,7 +39,11 @@ serio.workspace = true cfg-if.workspace = true [dev-dependencies] -mpz-common = { workspace = true, features = ["test-utils", "ideal"] } +mpz-common = { workspace = true, features = [ + "test-utils", + "ideal", + "executor", +] } mpz-ot-core = { workspace = true, features = ["test-utils"] } rstest = { workspace = true } criterion = { workspace = true, features = ["async_tokio"] } @@ -48,7 +53,3 @@ tokio = { workspace = true, features = [ "rt", "rt-multi-thread", ] } - -[[bench]] -name = "ot" -harness = false diff --git a/crates/mpz-ot/benches/ot.rs b/crates/mpz-ot/benches/ot.rs deleted file mode 100644 index 4acb7b4e..00000000 --- a/crates/mpz-ot/benches/ot.rs +++ /dev/null @@ -1,46 +0,0 @@ -use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; -use mpz_common::executor::test_st_executor; -use mpz_core::Block; -use mpz_ot::{ - chou_orlandi::{Receiver, Sender}, - OTReceiver, OTSender, OTSetup, -}; - -fn chou_orlandi(c: &mut Criterion) { - let rt = tokio::runtime::Runtime::new().unwrap(); - let mut group = c.benchmark_group("chou_orlandi"); - for n in [128, 256, 1024] { - group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &n| { - let msgs = vec![[Block::ONES; 2]; n]; - let choices = vec![false; n]; - b.to_async(&rt).iter(|| async { - let (mut sender_ctx, mut receiver_ctx) = test_st_executor(8); - - let mut sender = Sender::default(); - let mut receiver = Receiver::default(); - - futures::try_join!( - sender.setup(&mut sender_ctx), - receiver.setup(&mut receiver_ctx) - ) - .unwrap(); - - let (_, received) = futures::try_join!( - sender.send(&mut sender_ctx, &msgs), - receiver.receive(&mut receiver_ctx, &choices) - ) - .unwrap(); - - black_box(received) - }) - }); - } -} - -criterion_group! { - name = chou_orlandi_benches; - config = Criterion::default().sample_size(50); - targets = chou_orlandi -} - -criterion_main!(chou_orlandi_benches); diff --git a/crates/mpz-ot/src/chou_orlandi.rs b/crates/mpz-ot/src/chou_orlandi.rs new file mode 100644 index 00000000..0bc705fb --- /dev/null +++ b/crates/mpz-ot/src/chou_orlandi.rs @@ -0,0 +1,19 @@ +//! [`CO15`](https://eprint.iacr.org/2015/267.pdf) Chou-Orlandi oblivious transfer protocol. + +mod receiver; +mod sender; + +pub use receiver::Receiver; +pub use sender::{Sender, SenderError}; + +#[cfg(test)] +mod tests { + use crate::test::test_ot; + + use super::*; + + #[tokio::test] + async fn test_chou_orlandi() { + test_ot(Sender::new(), Receiver::new()).await; + } +} diff --git a/crates/mpz-ot/src/chou_orlandi/error.rs b/crates/mpz-ot/src/chou_orlandi/error.rs deleted file mode 100644 index 9bdde51e..00000000 --- a/crates/mpz-ot/src/chou_orlandi/error.rs +++ /dev/null @@ -1,63 +0,0 @@ -use crate::OTError; - -/// A Chou-Orlandi sender error. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum SenderError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::chou_orlandi::SenderError), - #[error("{0}")] - StateError(String), - #[error("coin-toss error: {0}")] - CointossError(#[from] mpz_cointoss::CointossError), - #[error("invalid configuration: {0}")] - InvalidConfig(String), -} - -impl From for OTError { - fn from(err: SenderError) -> Self { - match err { - SenderError::IOError(e) => e.into(), - e => OTError::SenderError(Box::new(e)), - } - } -} - -impl From for SenderError { - fn from(err: crate::chou_orlandi::sender::StateError) -> Self { - SenderError::StateError(err.to_string()) - } -} - -/// A Chou-Orlandi receiver error. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum ReceiverError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::chou_orlandi::ReceiverError), - #[error("{0}")] - StateError(String), - #[error("coin-toss error: {0}")] - CointossError(#[from] mpz_cointoss::CointossError), - #[error("invalid configuration: {0}")] - InvalidConfig(String), -} - -impl From for OTError { - fn from(err: ReceiverError) -> Self { - match err { - ReceiverError::IOError(e) => e.into(), - e => OTError::ReceiverError(Box::new(e)), - } - } -} - -impl From for ReceiverError { - fn from(err: crate::chou_orlandi::receiver::StateError) -> Self { - ReceiverError::StateError(err.to_string()) - } -} diff --git a/crates/mpz-ot/src/chou_orlandi/mod.rs b/crates/mpz-ot/src/chou_orlandi/mod.rs deleted file mode 100644 index df3fda7a..00000000 --- a/crates/mpz-ot/src/chou_orlandi/mod.rs +++ /dev/null @@ -1,154 +0,0 @@ -//! An implementation of the Chou-Orlandi [`CO15`](https://eprint.iacr.org/2015/267.pdf) oblivious transfer protocol. -//! -//! # Examples -//! -//! ``` -//! use mpz_common::executor::test_st_executor; -//! use mpz_ot::{ -//! chou_orlandi::{Receiver, Sender, SenderConfig, ReceiverConfig}, -//! OTReceiver, OTSender, OTSetup -//! }; -//! use mpz_core::Block; -//! -//! # futures::executor::block_on(async { -//! let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); -//! -//! let mut sender = Sender::default(); -//! let mut receiver = Receiver::default(); -//! -//! // Perform the setup. -//! futures::try_join!( -//! sender.setup(&mut ctx_sender), -//! receiver.setup(&mut ctx_receiver) -//! ).unwrap(); -//! -//! // Perform the transfer. -//! let messages = vec![[Block::ZERO, Block::ONES], [Block::ZERO, Block::ONES]]; -//! -//! let (_, output_receiver) = futures::try_join!( -//! sender.send(&mut ctx_sender, &messages), -//! receiver.receive(&mut ctx_receiver, &[true, false]) -//! ).unwrap(); -//! -//! assert_eq!(output_receiver.msgs, vec![Block::ONES, Block::ZERO]); -//! # }); -//! ``` - -mod error; -mod receiver; -mod sender; - -pub use error::{ReceiverError, SenderError}; -pub use receiver::Receiver; -pub use sender::Sender; - -pub use mpz_ot_core::chou_orlandi::{ - msgs, ReceiverConfig, ReceiverConfigBuilder, ReceiverConfigBuilderError, SenderConfig, - SenderConfigBuilder, SenderConfigBuilderError, -}; - -#[cfg(test)] -mod tests { - use futures::TryFutureExt; - use itybity::ToBits; - use mpz_common::executor::test_st_executor; - use mpz_common::Context; - use mpz_core::Block; - use rand::Rng; - use rand_chacha::ChaCha12Rng; - use rand_core::SeedableRng; - - use crate::{CommittedOTReceiver, OTError, OTReceiver, OTSender, OTSetup, VerifiableOTSender}; - - use super::*; - use rstest::*; - - #[fixture] - fn choices() -> Vec { - let mut rng = ChaCha12Rng::seed_from_u64(0); - (0..128).map(|_| rng.gen()).collect() - } - - #[fixture] - fn data() -> Vec<[Block; 2]> { - let mut rng = ChaCha12Rng::seed_from_u64(0); - (0..128) - .map(|_| [rng.gen::<[u8; 16]>().into(), rng.gen::<[u8; 16]>().into()]) - .collect() - } - - fn choose( - data: impl Iterator, - choices: impl Iterator, - ) -> impl Iterator { - data.zip(choices) - .map(|([zero, one], choice)| if choice { one } else { zero }) - } - - async fn setup( - sender_config: SenderConfig, - receiver_config: ReceiverConfig, - sender_ctx: &mut impl Context, - receiver_ctx: &mut impl Context, - ) -> (Sender, Receiver) { - let mut sender = Sender::new(sender_config); - let mut receiver = Receiver::new(receiver_config); - - tokio::try_join!(sender.setup(sender_ctx), receiver.setup(receiver_ctx)).unwrap(); - - (sender, receiver) - } - - #[rstest] - #[tokio::test] - async fn test_chou_orlandi(data: Vec<[Block; 2]>, choices: Vec) { - let (mut sender_ctx, mut receiver_ctx) = test_st_executor(8); - let (mut sender, mut receiver) = setup( - SenderConfig::default(), - ReceiverConfig::default(), - &mut sender_ctx, - &mut receiver_ctx, - ) - .await; - - let (output_sender, output_receiver) = tokio::try_join!( - sender.send(&mut sender_ctx, &data).map_err(OTError::from), - receiver - .receive(&mut receiver_ctx, &choices) - .map_err(OTError::from) - ) - .unwrap(); - - let expected = choose(data.iter().copied(), choices.iter_lsb0()).collect::>(); - - assert_eq!(output_sender.id, output_receiver.id); - assert_eq!(output_receiver.msgs, expected); - } - - #[rstest] - #[tokio::test] - async fn test_chou_orlandi_committed_receiver(data: Vec<[Block; 2]>, choices: Vec) { - let (mut sender_ctx, mut receiver_ctx) = test_st_executor(8); - let (mut sender, mut receiver) = setup( - SenderConfig::builder().receiver_commit().build().unwrap(), - ReceiverConfig::builder().receiver_commit().build().unwrap(), - &mut sender_ctx, - &mut receiver_ctx, - ) - .await; - - tokio::try_join!( - sender.send(&mut sender_ctx, &data), - receiver.receive(&mut receiver_ctx, &choices) - ) - .unwrap(); - - let (verified_choices, _) = tokio::try_join!( - sender.verify_choices(&mut sender_ctx), - receiver.reveal_choices(&mut receiver_ctx) - ) - .unwrap(); - - assert_eq!(verified_choices, choices); - } -} diff --git a/crates/mpz-ot/src/chou_orlandi/receiver.rs b/crates/mpz-ot/src/chou_orlandi/receiver.rs index 91145515..405b9009 100644 --- a/crates/mpz-ot/src/chou_orlandi/receiver.rs +++ b/crates/mpz-ot/src/chou_orlandi/receiver.rs @@ -1,201 +1,158 @@ use async_trait::async_trait; -use itybity::BitIterable; -use mpz_cointoss as cointoss; -use mpz_common::Context; +use mpz_common::{Context, Flush}; use mpz_core::Block; -use mpz_ot_core::chou_orlandi::msgs::SenderPayload; -use mpz_ot_core::chou_orlandi::{ - receiver_state as state, Receiver as ReceiverCore, ReceiverConfig, +use mpz_ot_core::{ + chou_orlandi::{receiver_state as state, Receiver as Core, ReceiverError as CoreError}, + ot::OTReceiver, }; -use enum_try_as_inner::EnumTryAsInner; -use rand::{thread_rng, Rng}; use serio::{stream::IoStreamExt as _, SinkExt as _}; use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; -use crate::{CommittedOTReceiver, OTError, OTReceiver, OTReceiverOutput, OTSetup}; +type Error = ReceiverError; -use super::ReceiverError; - -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] -pub(crate) enum State { - Initialized { - config: ReceiverConfig, - seed: Option<[u8; 32]>, - }, - Setup(Box>), - Complete, +#[derive(Debug)] +enum State { + Initialized(Core), + Setup(Core), Error, } +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, Self::Error) + } +} + /// Chou-Orlandi receiver. #[derive(Debug)] pub struct Receiver { state: State, - cointoss_sender: Option>, } impl Default for Receiver { fn default() -> Self { Self { - state: State::Initialized { - config: ReceiverConfig::default(), - seed: None, - }, - cointoss_sender: None, + state: State::Initialized(Core::new()), } } } impl Receiver { /// Creates a new receiver. - /// - /// # Arguments - /// - /// * `config` - The receiver's configuration - pub fn new(config: ReceiverConfig) -> Self { - Self { - state: State::Initialized { config, seed: None }, - cointoss_sender: None, - } + pub fn new() -> Self { + Self::default() } /// Creates a new receiver with the provided RNG seed. /// /// # Arguments /// - /// * `config` - The receiver's configuration /// * `seed` - The RNG seed used to generate the receiver's keys. - pub fn new_with_seed(config: ReceiverConfig, seed: [u8; 32]) -> Self { + pub fn new_with_seed(seed: [u8; 32]) -> Self { Self { - state: State::Initialized { - config, - seed: Some(seed), - }, - cointoss_sender: None, + state: State::Initialized(Core::new_with_seed(seed)), } } } -#[async_trait] -impl OTSetup for Receiver { - async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), OTError> { - if self.state.is_setup() { - return Ok(()); - } - - let (config, seed) = std::mem::replace(&mut self.state, State::Error) - .try_into_initialized() - .map_err(ReceiverError::from)?; +impl OTReceiver for Receiver { + type Error = Error; + type Future = >::Future; - // If the receiver is committed, we generate the seed using a cointoss. - let seed = if config.receiver_commit() { - if seed.is_some() { - return Err(ReceiverError::InvalidConfig( - "committed receiver seed must be generated using coin toss".to_string(), - ))?; - } - - let cointoss_seed = thread_rng().gen(); - let (seeds, cointoss_sender) = cointoss::Sender::new(vec![cointoss_seed]) - .commit(ctx) - .await - .map_err(ReceiverError::from)? - .receive(ctx) - .await - .map_err(ReceiverError::from)?; - - self.cointoss_sender = Some(cointoss_sender); - - let seed = seeds[0].to_bytes(); - // Stretch seed to 32 bytes - let mut stretched_seed = [0u8; 32]; - stretched_seed[..16].copy_from_slice(&seed); - stretched_seed[16..].copy_from_slice(&seed); - - stretched_seed - } else { - seed.unwrap_or_else(|| thread_rng().gen()) - }; - - let sender_setup = ctx.io_mut().expect_next().await?; - let receiver = - Backend::spawn(move || ReceiverCore::new_with_seed(config, seed).setup(sender_setup)) - .await; - - self.state = State::Setup(Box::new(receiver)); + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + match &mut self.state { + State::Initialized(receiver) => receiver.alloc(count).map_err(Error::from), + State::Setup(receiver) => receiver.alloc(count).map_err(Error::from), + State::Error => Err(Error::state("can not allocate, receiver is in error state")), + } + } - Ok(()) + fn queue_recv_ot(&mut self, choices: &[bool]) -> Result { + match &mut self.state { + State::Initialized(receiver) => receiver.queue_recv_ot(choices).map_err(Error::from), + State::Setup(receiver) => receiver.queue_recv_ot(choices).map_err(Error::from), + State::Error => Err(Error::state("can not queue ot, receiver is in error state")), + } } } #[async_trait] -impl OTReceiver for Receiver +impl Flush for Receiver where Ctx: Context, - T: BitIterable + Send + Sync + Clone + 'static, { - async fn receive( - &mut self, - ctx: &mut Ctx, - choices: &[T], - ) -> Result, OTError> { - let mut receiver = std::mem::replace(&mut self.state, State::Error) - .try_into_setup() - .map_err(ReceiverError::from)?; - - let choices = choices.to_vec(); - let (mut receiver, receiver_payload) = Backend::spawn(move || { - let payload = receiver.receive_random(&choices); - (receiver, payload) - }) - .await; + type Error = Error; - ctx.io_mut().send(receiver_payload).await?; + fn wants_flush(&self) -> bool { + match &self.state { + State::Initialized(_) => true, + State::Setup(receiver) => receiver.wants_flush(), + State::Error => false, + } + } - let sender_payload: SenderPayload = ctx.io_mut().expect_next().await?; - let id = sender_payload.id; + async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + let mut receiver = match self.state.take() { + State::Initialized(receiver) => { + let payload = ctx.io_mut().expect_next().await?; + receiver.setup(payload) + } + State::Setup(receiver) => receiver, + State::Error => return Err(Error::state("can not flush, receiver is in error state")), + }; - let (receiver, msgs) = Backend::spawn(move || { - receiver - .receive(sender_payload) - .map(|msgs| (receiver, msgs)) + if !receiver.wants_flush() { + self.state = State::Setup(receiver); + return Ok(()); + } + + let (payload, mut receiver) = Backend::spawn(|| { + let payload = receiver.choose(); + (payload, receiver) }) - .await - .map_err(ReceiverError::from)?; + .await; + + ctx.io_mut().send(payload).await?; + let payload = ctx.io_mut().expect_next().await?; + + receiver.receive(payload)?; self.state = State::Setup(receiver); - Ok(OTReceiverOutput { id, msgs }) + Ok(()) } } -#[async_trait] -impl CommittedOTReceiver for Receiver { - async fn reveal_choices(&mut self, ctx: &mut Ctx) -> Result<(), OTError> { - let receiver = std::mem::replace(&mut self.state, State::Error) - .try_into_setup() - .map_err(ReceiverError::from)?; - - let Some(cointoss_sender) = self.cointoss_sender.take() else { - return Err(ReceiverError::InvalidConfig( - "receiver not configured to commit".to_string(), - ) - .into()); - }; +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct ReceiverError(#[from] ErrorRepr); - cointoss_sender - .finalize(ctx) - .await - .map_err(ReceiverError::from)?; +impl ReceiverError { + fn state(err: impl Into) -> Self { + Self(ErrorRepr::State(err.into())) + } +} - let reveal = receiver.reveal_choices().map_err(ReceiverError::from)?; - ctx.io_mut().send(reveal).await?; +#[derive(Debug, thiserror::Error)] +enum ErrorRepr { + #[error("core error: {0}")] + Core(#[source] CoreError), + #[error("state error: {0}")] + State(String), + #[error("io error: {0}")] + Io(#[source] std::io::Error), +} - self.state = State::Complete; +impl From for ReceiverError { + fn from(e: CoreError) -> Self { + Self(ErrorRepr::Core(e)) + } +} - Ok(()) +impl From for ReceiverError { + fn from(e: std::io::Error) -> Self { + Self(ErrorRepr::Io(e)) } } diff --git a/crates/mpz-ot/src/chou_orlandi/sender.rs b/crates/mpz-ot/src/chou_orlandi/sender.rs index 610f891d..3f22cbee 100644 --- a/crates/mpz-ot/src/chou_orlandi/sender.rs +++ b/crates/mpz-ot/src/chou_orlandi/sender.rs @@ -1,168 +1,154 @@ -use crate::{ - chou_orlandi::SenderError, OTError, OTSender, OTSenderOutput, OTSetup, VerifiableOTSender, -}; - use async_trait::async_trait; -use mpz_cointoss as cointoss; -use mpz_common::Context; +use mpz_common::{future::MaybeDone, Context, Flush}; use mpz_core::Block; -use mpz_ot_core::chou_orlandi::{sender_state as state, Sender as SenderCore, SenderConfig}; -use rand::{thread_rng, Rng}; -use serio::{stream::IoStreamExt, SinkExt as _}; +use mpz_ot_core::{ + chou_orlandi::{sender_state as state, Sender as Core, SenderError as CoreError}, + ot::{OTSender, OTSenderOutput}, +}; +use serio::{stream::IoStreamExt, SinkExt}; use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; -use enum_try_as_inner::EnumTryAsInner; - -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] -pub(crate) enum State { - Initialized(SenderCore), - Setup(SenderCore), - Complete, - Error, -} +type Error = SenderError; /// Chou-Orlandi sender. #[derive(Debug)] pub struct Sender { state: State, - /// The coin toss receiver after revealing one's own seed but before receiving a decommitment - /// from the coin toss sender. - cointoss_receiver: Option>, } -impl Default for Sender { - fn default() -> Self { - Self { - state: State::Initialized(SenderCore::new(SenderConfig::default())), - cointoss_receiver: None, - } +#[derive(Debug)] +enum State { + Initialized(Core), + Setup(Core), + Error, +} + +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, Self::Error) } } impl Sender { /// Creates a new Sender - /// - /// # Arguments - /// - /// * `config` - The sender's configuration - pub fn new(config: SenderConfig) -> Self { - Self { - state: State::Initialized(SenderCore::new(config)), - cointoss_receiver: None, - } + pub fn new() -> Self { + Self::default() } /// Creates a new Sender with the provided RNG seed /// /// # Arguments /// - /// * `config` - The sender's configuration /// * `seed` - The RNG seed used to generate the sender's keys - pub fn new_with_seed(config: SenderConfig, seed: [u8; 32]) -> Self { + pub fn new_with_seed(seed: [u8; 32]) -> Self { Self { - state: State::Initialized(SenderCore::new_with_seed(config, seed)), - cointoss_receiver: None, + state: State::Initialized(Core::new_with_seed(seed)), } } } -#[async_trait] -impl OTSetup for Sender { - async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), OTError> { - if self.state.is_setup() { - return Ok(()); - } - - let sender = std::mem::replace(&mut self.state, State::Error) - .try_into_initialized() - .map_err(SenderError::from)?; - - // If the receiver is committed, we run the cointoss protocol - if sender.config().receiver_commit() { - let cointoss_seed = thread_rng().gen(); - self.cointoss_receiver = Some( - cointoss::Receiver::new(vec![cointoss_seed]) - .receive(ctx) - .await - .map_err(SenderError::from)?, - ); +impl Default for Sender { + fn default() -> Self { + Self { + state: State::Initialized(Core::new()), } + } +} - let (msg, sender) = sender.setup(); - - ctx.io_mut().send(msg).await?; +impl OTSender for Sender { + type Error = Error; + type Future = MaybeDone; - self.state = State::Setup(sender); + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + match &mut self.state { + State::Initialized(sender) => sender.alloc(count).map_err(Error::from), + State::Setup(sender) => sender.alloc(count).map_err(Error::from), + State::Error => Err(Error::state("can not allocate, sender is in error state")), + } + } - Ok(()) + fn queue_send_ot(&mut self, msgs: &[[Block; 2]]) -> Result { + match &mut self.state { + State::Initialized(sender) => sender.queue_send_ot(msgs).map_err(Error::from), + State::Setup(sender) => sender.queue_send_ot(msgs).map_err(Error::from), + State::Error => Err(Error::state("can not queue ot, sender is in error state")), + } } } #[async_trait] -impl OTSender for Sender { - async fn send( - &mut self, - ctx: &mut Ctx, - input: &[[Block; 2]], - ) -> Result { - let mut sender = std::mem::replace(&mut self.state, State::Error) - .try_into_setup() - .map_err(SenderError::from)?; - - let receiver_payload = ctx.io_mut().expect_next().await?; - - let input = input.to_vec(); - let (sender, payload) = Backend::spawn(move || { - sender - .send(&input, receiver_payload) - .map(|payload| (sender, payload)) - }) - .await - .map_err(SenderError::from)?; - - let id = payload.id; +impl Flush for Sender +where + Ctx: Context, +{ + type Error = Error; + + fn wants_flush(&self) -> bool { + match &self.state { + State::Initialized(_) => true, + State::Setup(sender) => sender.wants_recv(), + State::Error => false, + } + } + + async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + let mut sender = match self.state.take() { + State::Initialized(sender) => { + let (setup, sender) = sender.setup(); + ctx.io_mut().send(setup).await?; + sender + } + State::Setup(sender) => sender, + State::Error => return Err(Error::state("can not flush, sender is in error state")), + }; + + if !sender.wants_recv() { + self.state = State::Setup(sender); + return Ok(()); + } + + let payload = ctx.io_mut().expect_next().await?; + + let (payload, sender) = + Backend::spawn(|| sender.send(payload).map(|payload| (payload, sender))).await?; ctx.io_mut().send(payload).await?; self.state = State::Setup(sender); - Ok(OTSenderOutput { id }) + Ok(()) } } -#[async_trait] -impl VerifiableOTSender for Sender { - async fn verify_choices(&mut self, ctx: &mut Ctx) -> Result, OTError> { - let sender = std::mem::replace(&mut self.state, State::Error) - .try_into_setup() - .map_err(SenderError::from)?; - - let Some(cointoss_receiver) = self.cointoss_receiver.take() else { - Err(SenderError::InvalidConfig( - "receiver commitment not enabled".to_string(), - ))? - }; +/// Error for [`Sender`]. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct SenderError(#[from] ErrorRepr); - let seed = cointoss_receiver - .finalize(ctx) - .await - .map_err(SenderError::from)?; - - let seed = seed[0].to_bytes(); - // Stretch seed to 32 bytes - let mut stretched_seed = [0u8; 32]; - stretched_seed[..16].copy_from_slice(&seed); - stretched_seed[16..].copy_from_slice(&seed); +impl SenderError { + fn state(msg: impl Into) -> Self { + Self(ErrorRepr::State(msg.into())) + } +} - let receiver_reveal = ctx.io_mut().expect_next().await?; - let verified_choices = - Backend::spawn(move || sender.verify_choices(stretched_seed, receiver_reveal)) - .await - .map_err(SenderError::from)?; +#[derive(Debug, thiserror::Error)] +enum ErrorRepr { + #[error("core error: {0}")] + Core(#[from] CoreError), + #[error("state error: {0}")] + State(String), + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} - self.state = State::Complete; +impl From for SenderError { + fn from(err: CoreError) -> Self { + SenderError(ErrorRepr::Core(err)) + } +} - Ok(verified_choices) +impl From for SenderError { + fn from(err: std::io::Error) -> Self { + SenderError(ErrorRepr::Io(err)) } } diff --git a/crates/mpz-ot/src/cot.rs b/crates/mpz-ot/src/cot.rs new file mode 100644 index 00000000..1dd33e2e --- /dev/null +++ b/crates/mpz-ot/src/cot.rs @@ -0,0 +1,8 @@ +//! Correlated OT. + +mod derandomize; + +pub use derandomize::{ + DerandCOTReceiver, DerandCOTReceiverError, DerandCOTSender, DerandCOTSenderError, +}; +pub use mpz_ot_core::cot::{COTReceiver, COTReceiverOutput, COTSender, COTSenderOutput}; diff --git a/crates/mpz-ot/src/cot/derandomize.rs b/crates/mpz-ot/src/cot/derandomize.rs new file mode 100644 index 00000000..1e5677e5 --- /dev/null +++ b/crates/mpz-ot/src/cot/derandomize.rs @@ -0,0 +1,14 @@ +mod receiver; +mod sender; + +pub use receiver::{DerandCOTReceiver, DerandCOTReceiverError}; +pub use sender::{DerandCOTSender, DerandCOTSenderError}; + +#[cfg(test)] +mod tests { + + #[tokio::test] + async fn test_derandomize_cot() { + todo!() + } +} diff --git a/crates/mpz-ot/src/cot/derandomize/receiver.rs b/crates/mpz-ot/src/cot/derandomize/receiver.rs new file mode 100644 index 00000000..3a7662d0 --- /dev/null +++ b/crates/mpz-ot/src/cot/derandomize/receiver.rs @@ -0,0 +1,123 @@ +use async_trait::async_trait; +use mpz_common::{Context, ContextError, Flush}; +use mpz_core::Block; +use mpz_ot_core::cot::{DerandCOTReceiver as Core, DerandCOTReceiverError as CoreError}; +use serio::{stream::IoStreamExt, SinkExt}; + +use crate::{cot::COTReceiver, rcot::RCOTReceiver}; + +type Error = DerandCOTReceiverError; + +/// Derandomized COT receiver. +/// +/// This is a COT receiver which derandomizes preprocessed RCOTs. +#[derive(Debug)] +pub struct DerandCOTReceiver { + core: Core, +} + +impl DerandCOTReceiver { + /// Creates a new `DerandCOTReceiver`. + pub fn new(rcot: T) -> Self { + Self { + core: Core::new(rcot), + } + } + + /// Returns the inner RCOT receiver. + pub fn into_inner(self) -> T { + self.core.into_inner() + } +} + +impl COTReceiver for DerandCOTReceiver +where + T: RCOTReceiver, +{ + type Error = Error; + type Future = as COTReceiver>::Future; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.core.alloc(count).map_err(Error::from) + } + + fn available(&self) -> usize { + self.core.available() + } + + fn queue_recv_cot(&mut self, choices: &[bool]) -> Result { + self.core.queue_recv_cot(choices).map_err(Error::from) + } +} + +#[async_trait] +impl Flush for DerandCOTReceiver +where + Ctx: Context, + T: RCOTReceiver + Flush + Send, +{ + type Error = Error; + + fn wants_flush(&self) -> bool { + self.core.wants_adjust() + } + + async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + if self.core.rcot().wants_flush() { + self.core.rcot_mut().flush(ctx).await.map_err(Error::rcot)?; + } + + if self.wants_flush() { + let (derandomize, recv) = self.core.adjust()?; + ctx.io_mut().send(derandomize).await?; + let adjust = ctx.io_mut().expect_next().await?; + recv.receive(adjust)?; + } + + Ok(()) + } +} + +/// Error for [`DerandCOTReceiver`]. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct DerandCOTReceiverError(#[from] ErrorRepr); + +impl DerandCOTReceiverError { + fn rcot(err: E) -> Self + where + E: Into>, + { + Self(ErrorRepr::Rcot(err.into())) + } +} + +#[derive(Debug, thiserror::Error)] +enum ErrorRepr { + #[error("core error: {0}")] + Core(#[from] CoreError), + #[error("rcot error: {0}")] + Rcot(Box), + #[error("context error: {0}")] + Context(#[from] ContextError), + #[error("io error: {0}")] + Io(#[from] std::io::Error), +} + +impl From for DerandCOTReceiverError { + fn from(err: CoreError) -> Self { + Self(ErrorRepr::Core(err)) + } +} + +impl From for DerandCOTReceiverError { + fn from(err: ContextError) -> Self { + Self(ErrorRepr::Context(err)) + } +} + +impl From for DerandCOTReceiverError { + fn from(err: std::io::Error) -> Self { + Self(ErrorRepr::Io(err)) + } +} diff --git a/crates/mpz-ot/src/cot/derandomize/sender.rs b/crates/mpz-ot/src/cot/derandomize/sender.rs new file mode 100644 index 00000000..644f9ff5 --- /dev/null +++ b/crates/mpz-ot/src/cot/derandomize/sender.rs @@ -0,0 +1,126 @@ +use async_trait::async_trait; +use mpz_common::{Context, ContextError, Flush}; +use mpz_core::Block; +use mpz_ot_core::cot::{DerandCOTSender as Core, DerandCOTSenderError as CoreError}; +use serio::{stream::IoStreamExt, SinkExt}; + +use crate::{cot::COTSender, rcot::RCOTSender}; + +type Error = DerandCOTSenderError; + +/// Derandomized COT sender. +/// +/// This is a COT sender which derandomizes preprocessed RCOTs. +#[derive(Debug)] +pub struct DerandCOTSender { + core: Core, +} + +impl DerandCOTSender { + /// Creates a new `DerandCOTSender`. + pub fn new(rcot: T) -> Self { + Self { + core: Core::new(rcot), + } + } + + /// Returns the inner RCOT sender. + pub fn into_inner(self) -> T { + self.core.into_inner() + } +} + +impl COTSender for DerandCOTSender +where + T: RCOTSender, +{ + type Error = Error; + type Future = as COTSender>::Future; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.core.alloc(count).map_err(Error::from) + } + + fn available(&self) -> usize { + self.core.available() + } + + fn delta(&self) -> Block { + self.core.delta() + } + + fn queue_send_cot(&mut self, keys: &[Block]) -> Result { + self.core.queue_send_cot(keys).map_err(Error::from) + } +} + +#[async_trait] +impl Flush for DerandCOTSender +where + Ctx: Context, + T: RCOTSender + Flush + Send, +{ + type Error = Error; + + fn wants_flush(&self) -> bool { + self.core.wants_adjust() + } + + async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + if self.core.rcot().wants_flush() { + self.core.rcot_mut().flush(ctx).await.map_err(Error::rcot)?; + } + + if self.wants_flush() { + let derandomize = ctx.io_mut().expect_next().await?; + let adjust = self.core.adjust(derandomize)?; + ctx.io_mut().send(adjust).await?; + } + + Ok(()) + } +} + +/// Error for [`DerandCOTSender`]. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct DerandCOTSenderError(#[from] ErrorRepr); + +impl DerandCOTSenderError { + fn rcot(err: E) -> Self + where + E: Into>, + { + Self(ErrorRepr::Rcot(err.into())) + } +} + +#[derive(Debug, thiserror::Error)] +enum ErrorRepr { + #[error("core error: {0}")] + Core(#[from] CoreError), + #[error("rcot error: {0}")] + Rcot(Box), + #[error("context error: {0}")] + Context(#[from] ContextError), + #[error("io error: {0}")] + Io(#[from] std::io::Error), +} + +impl From for DerandCOTSenderError { + fn from(err: CoreError) -> Self { + Self(ErrorRepr::Core(err)) + } +} + +impl From for DerandCOTSenderError { + fn from(err: ContextError) -> Self { + Self(ErrorRepr::Context(err)) + } +} + +impl From for DerandCOTSenderError { + fn from(err: std::io::Error) -> Self { + Self(ErrorRepr::Io(err)) + } +} diff --git a/crates/mpz-ot/src/ideal.rs b/crates/mpz-ot/src/ideal.rs new file mode 100644 index 00000000..b555adb3 --- /dev/null +++ b/crates/mpz-ot/src/ideal.rs @@ -0,0 +1,6 @@ +//! Ideal OT functionalities. + +pub mod cot; +pub mod ot; +pub mod rcot; +pub mod rot; diff --git a/crates/mpz-ot/src/ideal/cot.rs b/crates/mpz-ot/src/ideal/cot.rs index b0084957..cc260bd9 100644 --- a/crates/mpz-ot/src/ideal/cot.rs +++ b/crates/mpz-ot/src/ideal/cot.rs @@ -1,223 +1,120 @@ -//! Ideal functionality for correlated oblivious transfer. +//! Ideal functionality for correlated OT. use async_trait::async_trait; - -use mpz_common::{ - ideal::{ideal_f2p, Alice, Bob}, - Allocate, Context, Preprocess, -}; +use mpz_common::Flush; use mpz_core::Block; use mpz_ot_core::{ - ideal::cot::IdealCOT, COTReceiverOutput, COTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput, + cot::{COTReceiver, COTSender}, + ideal::cot::{IdealCOT as Core, IdealCOTError as CoreError}, }; -use crate::{COTReceiver, COTSender, OTError, OTSetup, RandomCOTReceiver, RandomCOTSender}; - -fn cot( - f: &mut IdealCOT, - sender_count: usize, - choices: Vec, -) -> (COTSenderOutput, COTReceiverOutput) { - assert_eq!(sender_count, choices.len()); - - f.correlated(choices) +/// Returns a new ideal COT sender and receiver. +pub fn ideal_cot(delta: Block) -> (IdealCOTSender, IdealCOTReceiver) { + let core = Core::new(delta); + ( + IdealCOTSender { core: core.clone() }, + IdealCOTReceiver { core }, + ) } -fn rcot( - f: &mut IdealCOT, - sender_count: usize, - receiver_count: usize, -) -> (RCOTSenderOutput, RCOTReceiverOutput) { - assert_eq!(sender_count, receiver_count); - - f.random_correlated(sender_count) -} - -/// Returns an ideal COT sender and receiver. -pub fn ideal_cot() -> (IdealCOTSender, IdealCOTReceiver) { - let (alice, bob) = ideal_f2p(IdealCOT::default()); - (IdealCOTSender(alice), IdealCOTReceiver(bob)) -} - -/// Returns an ideal random COT sender and receiver. -pub fn ideal_rcot() -> (IdealCOTSender, IdealCOTReceiver) { - let (alice, bob) = ideal_f2p(IdealCOT::default()); - (IdealCOTSender(alice), IdealCOTReceiver(bob)) +/// Ideal COT sender. +pub struct IdealCOTSender { + core: Core, } -/// Ideal COT sender. -#[derive(Debug, Clone)] -pub struct IdealCOTSender(Alice); +impl COTSender for IdealCOTSender { + type Error = IdealCOTError; + type Future = >::Future; -#[async_trait] -impl OTSetup for IdealCOTSender -where - Ctx: Context, -{ - async fn setup(&mut self, _ctx: &mut Ctx) -> Result<(), OTError> { - Ok(()) + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + COTSender::alloc(&mut self.core, count).map_err(From::from) } -} - -impl Allocate for IdealCOTSender { - fn alloc(&mut self, _count: usize) {} -} -#[async_trait] -impl Preprocess for IdealCOTSender -where - Ctx: Context, -{ - type Error = OTError; + fn available(&self) -> usize { + COTSender::available(&self.core) + } - async fn preprocess(&mut self, _ctx: &mut Ctx) -> Result<(), OTError> { - Ok(()) + fn delta(&self) -> Block { + COTSender::delta(&self.core) } -} -#[async_trait] -impl COTSender for IdealCOTSender { - async fn send_correlated( - &mut self, - ctx: &mut Ctx, - count: usize, - ) -> Result, OTError> { - Ok(self.0.call(ctx, count, cot).await) + fn queue_send_cot(&mut self, msgs: &[Block]) -> Result { + self.core.queue_send_cot(msgs).map_err(From::from) } } #[async_trait] -impl RandomCOTSender for IdealCOTSender { - async fn send_random_correlated( - &mut self, - ctx: &mut Ctx, - count: usize, - ) -> Result, OTError> { - Ok(self.0.call(ctx, count, rcot).await) +impl Flush for IdealCOTSender { + type Error = IdealCOTError; + + fn wants_flush(&self) -> bool { + self.core.wants_flush() } -} -/// Ideal COT receiver. -#[derive(Debug, Clone)] -pub struct IdealCOTReceiver(Bob); + async fn flush(&mut self, _ctx: &mut Ctx) -> Result<(), Self::Error> { + if self.core.wants_flush() { + self.core.flush().map_err(IdealCOTError::from)?; + } -#[async_trait] -impl OTSetup for IdealCOTReceiver -where - Ctx: Context, -{ - async fn setup(&mut self, _ctx: &mut Ctx) -> Result<(), OTError> { Ok(()) } } -impl Allocate for IdealCOTReceiver { - fn alloc(&mut self, _count: usize) {} +/// Ideal COT receiver. +pub struct IdealCOTReceiver { + core: Core, } -#[async_trait] -impl Preprocess for IdealCOTReceiver -where - Ctx: Context, -{ - type Error = OTError; +impl COTReceiver for IdealCOTReceiver { + type Error = IdealCOTError; + type Future = >::Future; - async fn preprocess(&mut self, _ctx: &mut Ctx) -> Result<(), OTError> { - Ok(()) + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + COTReceiver::alloc(&mut self.core, count).map_err(From::from) } -} -#[async_trait] -impl COTReceiver for IdealCOTReceiver { - async fn receive_correlated( - &mut self, - ctx: &mut Ctx, - choices: &[bool], - ) -> Result, OTError> { - Ok(self.0.call(ctx, choices.to_vec(), cot).await) + fn available(&self) -> usize { + COTReceiver::available(&self.core) + } + + fn queue_recv_cot(&mut self, choices: &[bool]) -> Result { + self.core.queue_recv_cot(choices).map_err(From::from) } } #[async_trait] -impl RandomCOTReceiver for IdealCOTReceiver { - async fn receive_random_correlated( - &mut self, - ctx: &mut Ctx, - count: usize, - ) -> Result, OTError> { - Ok(self.0.call(ctx, count, rcot).await) +impl Flush for IdealCOTReceiver { + type Error = IdealCOTError; + + fn wants_flush(&self) -> bool { + self.core.wants_flush() + } + + async fn flush(&mut self, _ctx: &mut Ctx) -> Result<(), Self::Error> { + if self.core.wants_flush() { + self.core.flush().map_err(IdealCOTError::from)?; + } + + Ok(()) } } +/// Ideal COT error. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct IdealCOTError(#[from] CoreError); + #[cfg(test)] mod tests { + use rand::{rngs::StdRng, Rng, SeedableRng}; + use super::*; - use mpz_common::executor::test_st_executor; - use mpz_ot_core::test::assert_cot; - use rand::{Rng, SeedableRng}; - use rand_chacha::ChaCha12Rng; + use crate::test::test_cot; #[tokio::test] async fn test_ideal_cot() { - let mut rng = ChaCha12Rng::seed_from_u64(0); - let (mut ctx_a, mut ctx_b) = test_st_executor(8); - let (mut alice, mut bob) = ideal_cot(); - - let delta = alice.0.get_mut().delta(); - - let count = 10; - let choices = (0..count).map(|_| rng.gen()).collect::>(); - - let ( - COTSenderOutput { - id: id_a, - msgs: sender_msgs, - }, - COTReceiverOutput { - id: id_b, - msgs: receiver_msgs, - }, - ) = tokio::try_join!( - alice.send_correlated(&mut ctx_a, count), - bob.receive_correlated(&mut ctx_b, &choices) - ) - .unwrap(); - - assert_eq!(id_a, id_b); - assert_eq!(count, sender_msgs.len()); - assert_eq!(count, receiver_msgs.len()); - assert_cot(delta, &choices, &sender_msgs, &receiver_msgs); - } - - #[tokio::test] - async fn test_ideal_rcot() { - let (mut ctx_a, mut ctx_b) = test_st_executor(8); - let (mut alice, mut bob) = ideal_rcot(); - - let delta = alice.0.get_mut().delta(); - - let count = 10; - - let ( - RCOTSenderOutput { - id: id_a, - msgs: sender_msgs, - }, - RCOTReceiverOutput { - id: id_b, - choices, - msgs: receiver_msgs, - }, - ) = tokio::try_join!( - alice.send_random_correlated(&mut ctx_a, count), - bob.receive_random_correlated(&mut ctx_b, count) - ) - .unwrap(); - - assert_eq!(id_a, id_b); - assert_eq!(count, sender_msgs.len()); - assert_eq!(count, receiver_msgs.len()); - assert_eq!(count, choices.len()); - assert_cot(delta, &choices, &sender_msgs, &receiver_msgs); + let mut rng = StdRng::seed_from_u64(0); + let (sender, receiver) = ideal_cot(rng.gen()); + test_cot(sender, receiver).await; } } diff --git a/crates/mpz-ot/src/ideal/mod.rs b/crates/mpz-ot/src/ideal/mod.rs deleted file mode 100644 index e8f57c57..00000000 --- a/crates/mpz-ot/src/ideal/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -//! Ideal implementations of the OT protocols. - -pub mod cot; -pub mod ot; -pub mod rot; diff --git a/crates/mpz-ot/src/ideal/ot.rs b/crates/mpz-ot/src/ideal/ot.rs index 326438dd..43d46f10 100644 --- a/crates/mpz-ot/src/ideal/ot.rs +++ b/crates/mpz-ot/src/ideal/ot.rs @@ -1,167 +1,93 @@ //! Ideal functionality for chosen-message oblivious transfer. -use std::marker::PhantomData; - use async_trait::async_trait; - -use mpz_common::{ - ideal::{ideal_f2p, Alice, Bob}, - Allocate, Context, Preprocess, +use mpz_common::Flush; +use mpz_core::Block; +use mpz_ot_core::{ + ideal::ot::{IdealOT as Core, IdealOTError as CoreError}, + ot::{OTReceiver, OTSender}, }; -use mpz_ot_core::{ideal::ot::IdealOT, TransferId}; - -use crate::{ - CommittedOTReceiver, CommittedOTSender, OTError, OTReceiver, OTReceiverOutput, OTSender, - OTSenderOutput, OTSetup, VerifiableOTReceiver, VerifiableOTSender, -}; - -fn ot( - f: &mut IdealOT, - sender_msgs: Vec<[T; 2]>, - receiver_choices: Vec, -) -> (OTSenderOutput, OTReceiverOutput) { - assert_eq!(sender_msgs.len(), receiver_choices.len()); - f.chosen(receiver_choices, sender_msgs) -} - -fn verify(f: &mut IdealOT, _: (), _: ()) -> (Vec, ()) { - (f.choices().to_vec(), ()) -} - -/// Returns an ideal OT sender and receiver. -pub fn ideal_ot() -> (IdealOTSender, IdealOTReceiver) { - let (alice, bob) = ideal_f2p(IdealOT::default()); +/// Returns a new ideal OT sender and receiver. +pub fn ideal_ot() -> (IdealOTSender, IdealOTReceiver) { + let core = Core::new(); ( - IdealOTSender(alice, PhantomData), - IdealOTReceiver(bob, PhantomData), + IdealOTSender { core: core.clone() }, + IdealOTReceiver { core }, ) } /// Ideal OT sender. -#[derive(Debug, Clone)] -pub struct IdealOTSender(Alice, PhantomData T>); - -#[async_trait] -impl OTSetup for IdealOTSender -where - Ctx: Context, -{ - async fn setup(&mut self, _ctx: &mut Ctx) -> Result<(), OTError> { - Ok(()) - } +pub struct IdealOTSender { + core: Core, } -impl Allocate for IdealOTSender { - fn alloc(&mut self, _count: usize) {} -} +impl OTSender for IdealOTSender { + type Error = IdealOTError; + type Future = >::Future; -#[async_trait] -impl Preprocess for IdealOTSender -where - Ctx: Context, -{ - type Error = OTError; - - async fn preprocess(&mut self, _ctx: &mut Ctx) -> Result<(), OTError> { - Ok(()) + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + OTSender::alloc(&mut self.core, count).map_err(From::from) } -} -#[async_trait] -impl OTSender - for IdealOTSender<[T; 2]> -{ - async fn send(&mut self, ctx: &mut Ctx, msgs: &[[T; 2]]) -> Result { - Ok(self.0.call(ctx, msgs.to_vec(), ot).await) + fn queue_send_ot(&mut self, msgs: &[[Block; 2]]) -> Result { + self.core.queue_send_ot(msgs).map_err(From::from) } } #[async_trait] -impl CommittedOTSender - for IdealOTSender<[T; 2]> -{ - async fn reveal(&mut self, _ctx: &mut Ctx) -> Result<(), OTError> { - Ok(()) - } -} +impl Flush for IdealOTSender { + type Error = IdealOTError; -#[async_trait] -impl VerifiableOTSender - for IdealOTSender<[T; 2]> -{ - async fn verify_choices(&mut self, ctx: &mut Ctx) -> Result, OTError> { - Ok(self.0.call(ctx, (), verify).await) + fn wants_flush(&self) -> bool { + self.core.wants_flush() } -} -/// Ideal OT receiver. -#[derive(Debug, Clone)] -pub struct IdealOTReceiver(Bob, PhantomData T>); + async fn flush(&mut self, _ctx: &mut Ctx) -> Result<(), Self::Error> { + if self.core.wants_flush() { + self.core.flush().map_err(IdealOTError::from)?; + } -#[async_trait] -impl OTSetup for IdealOTReceiver -where - Ctx: Context, -{ - async fn setup(&mut self, _ctx: &mut Ctx) -> Result<(), OTError> { Ok(()) } } -impl Allocate for IdealOTReceiver { - fn alloc(&mut self, _count: usize) {} +/// Ideal OT receiver. +pub struct IdealOTReceiver { + core: Core, } -#[async_trait] -impl Preprocess for IdealOTReceiver -where - Ctx: Context, -{ - type Error = OTError; +impl OTReceiver for IdealOTReceiver { + type Error = IdealOTError; + type Future = >::Future; - async fn preprocess(&mut self, _ctx: &mut Ctx) -> Result<(), OTError> { - Ok(()) + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + OTReceiver::alloc(&mut self.core, count).map_err(From::from) } -} -#[async_trait] -impl OTReceiver - for IdealOTReceiver -{ - async fn receive( - &mut self, - ctx: &mut Ctx, - choices: &[bool], - ) -> Result, OTError> { - Ok(self.0.call(ctx, choices.to_vec(), ot).await) + fn queue_recv_ot(&mut self, choices: &[bool]) -> Result { + self.core.queue_recv_ot(choices).map_err(From::from) } } #[async_trait] -impl CommittedOTReceiver - for IdealOTReceiver -{ - async fn reveal_choices(&mut self, ctx: &mut Ctx) -> Result<(), OTError> { - self.0.call(ctx, (), verify).await; - Ok(()) - } -} +impl Flush for IdealOTReceiver { + type Error = IdealOTError; -#[async_trait] -impl VerifiableOTReceiver - for IdealOTReceiver -{ - async fn accept_reveal(&mut self, _ctx: &mut Ctx) -> Result<(), OTError> { - Ok(()) + fn wants_flush(&self) -> bool { + self.core.wants_flush() } - async fn verify( - &mut self, - _ctx: &mut Ctx, - _id: TransferId, - _msgs: &[V], - ) -> Result<(), OTError> { + async fn flush(&mut self, _ctx: &mut Ctx) -> Result<(), Self::Error> { + if self.core.wants_flush() { + self.core.flush().map_err(IdealOTError::from)?; + } + Ok(()) } } + +/// Ideal OT error. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct IdealOTError(#[from] CoreError); diff --git a/crates/mpz-ot/src/ideal/rcot.rs b/crates/mpz-ot/src/ideal/rcot.rs new file mode 100644 index 00000000..f3186dde --- /dev/null +++ b/crates/mpz-ot/src/ideal/rcot.rs @@ -0,0 +1,131 @@ +//! Ideal functionality for random correlated OT. + +use async_trait::async_trait; +use mpz_common::Flush; +use mpz_core::Block; +use mpz_ot_core::{ + ideal::rcot::{IdealRCOT as Core, IdealRCOTError as CoreError}, + rcot::{RCOTReceiver, RCOTReceiverOutput, RCOTSender, RCOTSenderOutput}, +}; + +/// Returns a new ideal RCOT sender and receiver. +pub fn ideal_rcot(seed: Block, delta: Block) -> (IdealRCOTSender, IdealRCOTReceiver) { + let core = Core::new(seed, delta); + ( + IdealRCOTSender { core: core.clone() }, + IdealRCOTReceiver { core }, + ) +} + +/// Ideal RCOT sender. +pub struct IdealRCOTSender { + core: Core, +} + +impl RCOTSender for IdealRCOTSender { + type Error = IdealRCOTError; + type Future = >::Future; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + RCOTSender::alloc(&mut self.core, count).map_err(From::from) + } + + fn available(&self) -> usize { + RCOTSender::available(&self.core) + } + + fn delta(&self) -> Block { + RCOTSender::delta(&self.core) + } + + fn try_send_rcot(&mut self, count: usize) -> Result, Self::Error> { + self.core.try_send_rcot(count).map_err(From::from) + } + + fn queue_send_rcot(&mut self, count: usize) -> Result { + self.core.queue_send_rcot(count).map_err(From::from) + } +} + +#[async_trait] +impl Flush for IdealRCOTSender { + type Error = IdealRCOTError; + + fn wants_flush(&self) -> bool { + self.core.wants_flush() + } + + async fn flush(&mut self, _ctx: &mut Ctx) -> Result<(), Self::Error> { + if self.core.wants_flush() { + self.core.flush().map_err(IdealRCOTError::from)?; + } + + Ok(()) + } +} + +/// Ideal RCOT receiver. +pub struct IdealRCOTReceiver { + core: Core, +} + +impl RCOTReceiver for IdealRCOTReceiver { + type Error = IdealRCOTError; + type Future = >::Future; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + RCOTReceiver::alloc(&mut self.core, count).map_err(From::from) + } + + fn available(&self) -> usize { + RCOTReceiver::available(&self.core) + } + + fn try_recv_rcot( + &mut self, + count: usize, + ) -> Result, Self::Error> { + self.core.try_recv_rcot(count).map_err(From::from) + } + + fn queue_recv_rcot(&mut self, count: usize) -> Result { + self.core.queue_recv_rcot(count).map_err(From::from) + } +} + +#[async_trait] +impl Flush for IdealRCOTReceiver { + type Error = IdealRCOTError; + + fn wants_flush(&self) -> bool { + self.core.wants_flush() + } + + async fn flush(&mut self, _ctx: &mut Ctx) -> Result<(), Self::Error> { + if self.core.wants_flush() { + self.core.flush().map_err(IdealRCOTError::from)?; + } + + Ok(()) + } +} + +/// Ideal RCOT error. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct IdealRCOTError(#[from] CoreError); + +#[cfg(test)] +mod tests { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + use super::*; + use crate::test::test_rcot; + + #[tokio::test] + async fn test_ideal_rcot() { + let mut rng = StdRng::seed_from_u64(0); + let (sender, receiver) = ideal_rcot(rng.gen(), rng.gen()); + test_rcot(sender, receiver).await; + } +} diff --git a/crates/mpz-ot/src/ideal/rot.rs b/crates/mpz-ot/src/ideal/rot.rs index 315ab186..2104b0ee 100644 --- a/crates/mpz-ot/src/ideal/rot.rs +++ b/crates/mpz-ot/src/ideal/rot.rs @@ -1,120 +1,127 @@ -//! Ideal functionality for random oblivious transfer. +//! Ideal functionality for random correlated OT. use async_trait::async_trait; - -use mpz_common::{ - ideal::{ideal_f2p, Alice, Bob}, - Allocate, Context, Preprocess, +use mpz_common::Flush; +use mpz_core::Block; +use mpz_ot_core::{ + ideal::rot::{IdealROT as Core, IdealROTError as CoreError}, + rot::{ROTReceiver, ROTReceiverOutput, ROTSender, ROTSenderOutput}, }; -use mpz_ot_core::{ideal::rot::IdealROT, ROTReceiverOutput, ROTSenderOutput}; -use rand::distributions::{Distribution, Standard}; - -use crate::{OTError, OTSetup, RandomOTReceiver, RandomOTSender}; - -fn rot( - f: &mut IdealROT, - sender_count: usize, - receiver_count: usize, -) -> (ROTSenderOutput<[T; 2]>, ROTReceiverOutput) -where - Standard: Distribution, -{ - assert_eq!(sender_count, receiver_count); - - f.random(sender_count) -} -/// Returns an ideal ROT sender and receiver. -pub fn ideal_rot() -> (IdealROTSender, IdealROTReceiver) { - let (alice, bob) = ideal_f2p(IdealROT::default()); - (IdealROTSender(alice), IdealROTReceiver(bob)) +/// Returns a new ideal ROT sender and receiver. +pub fn ideal_rot(seed: Block) -> (IdealROTSender, IdealROTReceiver) { + let core = Core::new(seed); + ( + IdealROTSender { core: core.clone() }, + IdealROTReceiver { core }, + ) } /// Ideal ROT sender. -#[derive(Debug, Clone)] -pub struct IdealROTSender(Alice); +pub struct IdealROTSender { + core: Core, +} -#[async_trait] -impl OTSetup for IdealROTSender -where - Ctx: Context, -{ - async fn setup(&mut self, _ctx: &mut Ctx) -> Result<(), OTError> { - Ok(()) +impl ROTSender<[Block; 2]> for IdealROTSender { + type Error = IdealROTError; + type Future = >::Future; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + ROTSender::alloc(&mut self.core, count).map_err(From::from) } -} -impl Allocate for IdealROTSender { - fn alloc(&mut self, _count: usize) {} -} + fn available(&self) -> usize { + ROTSender::available(&self.core) + } -#[async_trait] -impl Preprocess for IdealROTSender -where - Ctx: Context, -{ - type Error = OTError; + fn try_send_rot(&mut self, count: usize) -> Result, Self::Error> { + self.core.try_send_rot(count).map_err(From::from) + } - async fn preprocess(&mut self, _ctx: &mut Ctx) -> Result<(), OTError> { - Ok(()) + fn queue_send_rot(&mut self, count: usize) -> Result { + self.core.queue_send_rot(count).map_err(From::from) } } #[async_trait] -impl RandomOTSender for IdealROTSender -where - Standard: Distribution, -{ - async fn send_random( - &mut self, - ctx: &mut Ctx, - count: usize, - ) -> Result, OTError> { - Ok(self.0.call(ctx, count, rot).await) +impl Flush for IdealROTSender { + type Error = IdealROTError; + + fn wants_flush(&self) -> bool { + self.core.wants_flush() } -} -/// Ideal ROT receiver. -#[derive(Debug, Clone)] -pub struct IdealROTReceiver(Bob); + async fn flush(&mut self, _ctx: &mut Ctx) -> Result<(), Self::Error> { + if self.core.wants_flush() { + self.core.flush().map_err(IdealROTError::from)?; + } -#[async_trait] -impl OTSetup for IdealROTReceiver -where - Ctx: Context, -{ - async fn setup(&mut self, _ctx: &mut Ctx) -> Result<(), OTError> { Ok(()) } } -impl Allocate for IdealROTReceiver { - fn alloc(&mut self, _count: usize) {} +/// Ideal OT receiver. +pub struct IdealROTReceiver { + core: Core, +} + +impl ROTReceiver for IdealROTReceiver { + type Error = IdealROTError; + type Future = >::Future; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + ROTReceiver::alloc(&mut self.core, count).map_err(From::from) + } + + fn available(&self) -> usize { + ROTReceiver::available(&self.core) + } + + fn try_recv_rot( + &mut self, + count: usize, + ) -> Result, Self::Error> { + self.core.try_recv_rot(count).map_err(From::from) + } + + fn queue_recv_rot(&mut self, count: usize) -> Result { + self.core.queue_recv_rot(count).map_err(From::from) + } } #[async_trait] -impl Preprocess for IdealROTReceiver -where - Ctx: Context, -{ - type Error = OTError; +impl Flush for IdealROTReceiver { + type Error = IdealROTError; + + fn wants_flush(&self) -> bool { + self.core.wants_flush() + } + + async fn flush(&mut self, _ctx: &mut Ctx) -> Result<(), Self::Error> { + if self.core.wants_flush() { + self.core.flush().map_err(IdealROTError::from)?; + } - async fn preprocess(&mut self, _ctx: &mut Ctx) -> Result<(), OTError> { Ok(()) } } -#[async_trait] -impl RandomOTReceiver - for IdealROTReceiver -where - Standard: Distribution, -{ - async fn receive_random( - &mut self, - ctx: &mut Ctx, - count: usize, - ) -> Result, OTError> { - Ok(self.0.call(ctx, count, rot).await) +/// Ideal OT error. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct IdealROTError(#[from] CoreError); + +#[cfg(test)] +mod tests { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + use super::*; + use crate::test::test_rot; + + #[tokio::test] + async fn test_ideal_rot() { + let mut rng = StdRng::seed_from_u64(0); + let (sender, receiver) = ideal_rot(rng.gen()); + test_rot(sender, receiver).await; } } diff --git a/crates/mpz-ot/src/kos.rs b/crates/mpz-ot/src/kos.rs new file mode 100644 index 00000000..4a419c1e --- /dev/null +++ b/crates/mpz-ot/src/kos.rs @@ -0,0 +1,33 @@ +//! [`KOS15`](https://eprint.iacr.org/2015/546.pdf) oblivious transfer extension protocol. + +mod receiver; +mod sender; + +pub use receiver::Receiver; +pub use sender::Sender; + +pub use mpz_ot_core::kos::{ + msgs, ReceiverConfig, ReceiverConfigBuilder, ReceiverConfigBuilderError, SenderConfig, + SenderConfigBuilder, SenderConfigBuilderError, +}; + +#[cfg(test)] +mod tests { + use mpz_core::Block; + use rand::{rngs::StdRng, SeedableRng}; + + use super::*; + + use crate::{ideal::ot::ideal_ot, test::test_rcot}; + + #[tokio::test] + async fn test_kos_rcot() { + let mut rng = StdRng::seed_from_u64(0); + let (base_sender, base_receiver) = ideal_ot(); + let delta = Block::random(&mut rng); + let sender = Sender::new(SenderConfig::default(), delta, base_receiver); + let receiver = Receiver::new(ReceiverConfig::default(), base_sender); + + test_rcot(sender, receiver).await; + } +} diff --git a/crates/mpz-ot/src/kos/error.rs b/crates/mpz-ot/src/kos/error.rs deleted file mode 100644 index 05361cd5..00000000 --- a/crates/mpz-ot/src/kos/error.rs +++ /dev/null @@ -1,92 +0,0 @@ -use crate::OTError; - -/// A KOS sender error. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum SenderError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::kos::SenderError), - #[error(transparent)] - BaseOTError(#[from] crate::OTError), - #[error("coin-toss error: {0}")] - CointossError(#[from] mpz_cointoss::CointossError), - #[error("{0}")] - StateError(String), - #[error("configuration error: {0}")] - ConfigError(String), - #[error("{0}")] - Other(String), -} - -impl From for OTError { - fn from(err: SenderError) -> Self { - match err { - SenderError::IOError(e) => e.into(), - e => OTError::SenderError(Box::new(e)), - } - } -} - -impl From for SenderError { - fn from(err: crate::kos::SenderStateError) -> Self { - SenderError::StateError(err.to_string()) - } -} - -impl From for OTError { - fn from(err: mpz_ot_core::kos::SenderError) -> Self { - SenderError::from(err).into() - } -} - -/// A KOS receiver error. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum ReceiverError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::kos::ReceiverError), - #[error(transparent)] - BaseOTError(#[from] crate::OTError), - #[error("coin-toss error: {0}")] - CointossError(#[from] mpz_cointoss::CointossError), - #[error("{0}")] - StateError(String), - #[error("configuration error: {0}")] - ConfigError(String), - #[error(transparent)] - VerifyError(#[from] ReceiverVerifyError), - #[error("{0}")] - Other(String), -} - -impl From for OTError { - fn from(err: ReceiverError) -> Self { - match err { - ReceiverError::IOError(e) => e.into(), - e => OTError::ReceiverError(Box::new(e)), - } - } -} - -impl From for ReceiverError { - fn from(err: crate::kos::ReceiverStateError) -> Self { - ReceiverError::StateError(err.to_string()) - } -} - -impl From for OTError { - fn from(err: mpz_ot_core::kos::ReceiverError) -> Self { - ReceiverError::from(err).into() - } -} - -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum ReceiverVerifyError { - #[error("delta value is not inconsistent")] - InconsistentDelta, -} diff --git a/crates/mpz-ot/src/kos/mod.rs b/crates/mpz-ot/src/kos/mod.rs deleted file mode 100644 index 0aa55927..00000000 --- a/crates/mpz-ot/src/kos/mod.rs +++ /dev/null @@ -1,258 +0,0 @@ -//! An implementation of the [`KOS15`](https://eprint.iacr.org/2015/546.pdf) oblivious transfer extension protocol. - -mod error; -mod receiver; -mod sender; -mod shared_receiver; -mod shared_sender; - -pub use error::{ReceiverError, ReceiverVerifyError, SenderError}; -pub use receiver::Receiver; -pub use sender::Sender; -pub use shared_receiver::SharedReceiver; -pub use shared_sender::SharedSender; - -pub(crate) use receiver::StateError as ReceiverStateError; -pub(crate) use sender::StateError as SenderStateError; - -pub use mpz_ot_core::kos::{ - msgs, PayloadRecord, ReceiverConfig, ReceiverConfigBuilder, ReceiverConfigBuilderError, - ReceiverKeys, SenderConfig, SenderConfigBuilder, SenderConfigBuilderError, SenderKeys, -}; - -// If we're testing we use a smaller chunk size to make sure the chunking code paths are tested. -cfg_if::cfg_if! { - if #[cfg(test)] { - pub(crate) const EXTEND_CHUNK_SIZE: usize = 1024; - } else { - /// The size of the chunks used to send the extension matrix, 4MB. - pub(crate) const EXTEND_CHUNK_SIZE: usize = 4 * 1024 * 1024; - } -} - -#[cfg(test)] -mod tests { - use super::*; - use rstest::*; - - use futures::TryFutureExt; - use itybity::ToBits; - use mpz_common::{executor::test_st_executor, Context}; - use mpz_core::Block; - use rand::Rng; - use rand_chacha::ChaCha12Rng; - use rand_core::SeedableRng; - - use crate::{ - ideal::ot::{ideal_ot, IdealOTReceiver, IdealOTSender}, - CommittedOTSender, OTError, OTReceiver, OTSender, OTSetup, RandomOTReceiver, - RandomOTSender, VerifiableOTReceiver, - }; - - #[fixture] - fn choices() -> Vec { - let mut rng = ChaCha12Rng::seed_from_u64(0); - (0..128).map(|_| rng.gen()).collect() - } - - #[fixture] - fn data() -> Vec<[Block; 2]> { - let mut rng = ChaCha12Rng::seed_from_u64(0); - (0..128) - .map(|_| [rng.gen::<[u8; 16]>().into(), rng.gen::<[u8; 16]>().into()]) - .collect() - } - - fn choose( - data: impl Iterator, - choices: impl Iterator, - ) -> impl Iterator { - data.zip(choices) - .map(|([zero, one], choice)| if choice { one } else { zero }) - } - - async fn setup( - sender_config: SenderConfig, - receiver_config: ReceiverConfig, - ctx_sender: &mut Ctx, - ctx_receiver: &mut Ctx, - count: usize, - ) -> ( - Sender>, - Receiver>, - ) { - let (base_sender, base_receiver) = ideal_ot(); - - let mut sender = Sender::new(sender_config, base_receiver); - let mut receiver = Receiver::new(receiver_config, base_sender); - - tokio::try_join!(sender.setup(ctx_sender), receiver.setup(ctx_receiver)).unwrap(); - tokio::try_join!( - sender.extend(ctx_sender, count).map_err(OTError::from), - receiver.extend(ctx_receiver, count).map_err(OTError::from) - ) - .unwrap(); - - (sender, receiver) - } - - #[rstest] - #[tokio::test] - async fn test_kos(data: Vec<[Block; 2]>, choices: Vec) { - let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); - let (mut sender, mut receiver) = setup( - SenderConfig::default(), - ReceiverConfig::default(), - &mut ctx_sender, - &mut ctx_receiver, - data.len(), - ) - .await; - - let (output_sender, output_receiver) = tokio::try_join!( - OTSender::<_, [Block; 2]>::send(&mut sender, &mut ctx_sender, &data) - .map_err(OTError::from), - OTReceiver::<_, bool, Block>::receive(&mut receiver, &mut ctx_receiver, &choices) - .map_err(OTError::from) - ) - .unwrap(); - - let expected = choose(data.iter().copied(), choices.iter_lsb0()).collect::>(); - - assert_eq!(output_sender.id, output_receiver.id); - assert_eq!(output_receiver.msgs, expected); - } - - #[tokio::test] - async fn test_kos_random() { - let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); - let (mut sender, mut receiver) = setup( - SenderConfig::default(), - ReceiverConfig::default(), - &mut ctx_sender, - &mut ctx_receiver, - 10, - ) - .await; - - let (output_sender, output_receiver) = tokio::try_join!( - RandomOTSender::<_, [Block; 2]>::send_random(&mut sender, &mut ctx_sender, 10), - RandomOTReceiver::<_, bool, Block>::receive_random( - &mut receiver, - &mut ctx_receiver, - 10 - ) - ) - .unwrap(); - - let expected = output_sender - .msgs - .into_iter() - .zip(output_receiver.choices) - .map(|(output, choice)| output[choice as usize]) - .collect::>(); - - assert_eq!(output_sender.id, output_receiver.id); - assert_eq!(output_receiver.msgs, expected); - } - - #[rstest] - #[tokio::test] - async fn test_kos_bytes(data: Vec<[Block; 2]>, choices: Vec) { - let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); - let (mut sender, mut receiver) = setup( - SenderConfig::default(), - ReceiverConfig::default(), - &mut ctx_sender, - &mut ctx_receiver, - data.len(), - ) - .await; - - let data: Vec<_> = data - .into_iter() - .map(|[a, b]| [a.to_bytes(), b.to_bytes()]) - .collect(); - - let (output_sender, output_receiver) = tokio::try_join!( - OTSender::<_, [[u8; 16]; 2]>::send(&mut sender, &mut ctx_sender, &data) - .map_err(OTError::from), - OTReceiver::<_, bool, [u8; 16]>::receive(&mut receiver, &mut ctx_receiver, &choices) - .map_err(OTError::from) - ) - .unwrap(); - - let expected = choose(data.iter().copied(), choices.iter_lsb0()).collect::>(); - - assert_eq!(output_sender.id, output_receiver.id); - assert_eq!(output_receiver.msgs, expected); - } - - #[rstest] - #[tokio::test] - async fn test_kos_committed_sender(data: Vec<[Block; 2]>, choices: Vec) { - let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); - let (mut sender, mut receiver) = setup( - SenderConfig::builder().sender_commit().build().unwrap(), - ReceiverConfig::builder().sender_commit().build().unwrap(), - &mut ctx_sender, - &mut ctx_receiver, - data.len(), - ) - .await; - - let (output_sender, output_receiver) = tokio::try_join!( - OTSender::<_, [Block; 2]>::send(&mut sender, &mut ctx_sender, &data) - .map_err(OTError::from), - OTReceiver::<_, bool, Block>::receive(&mut receiver, &mut ctx_receiver, &choices) - .map_err(OTError::from) - ) - .unwrap(); - - let expected = choose(data.iter().copied(), choices.iter_lsb0()).collect::>(); - - assert_eq!(output_sender.id, output_receiver.id); - assert_eq!(output_receiver.msgs, expected); - - tokio::try_join!( - CommittedOTSender::reveal(&mut sender, &mut ctx_sender), - receiver.accept_reveal(&mut ctx_receiver) - ) - .unwrap(); - - receiver - .verify(&mut ctx_receiver, output_receiver.id, &data) - .await - .unwrap(); - } - - #[rstest] - #[tokio::test] - async fn test_shared_kos(data: Vec<[Block; 2]>, choices: Vec) { - let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); - let (sender, receiver) = setup( - SenderConfig::default(), - ReceiverConfig::default(), - &mut ctx_sender, - &mut ctx_receiver, - data.len(), - ) - .await; - - let mut receiver = SharedReceiver::new(receiver); - let mut sender = SharedSender::new(sender); - - let (output_sender, output_receiver) = tokio::try_join!( - OTSender::<_, [Block; 2]>::send(&mut sender, &mut ctx_sender, &data) - .map_err(OTError::from), - OTReceiver::<_, bool, Block>::receive(&mut receiver, &mut ctx_receiver, &choices) - .map_err(OTError::from) - ) - .unwrap(); - - let expected = choose(data.iter().copied(), choices.iter_lsb0()).collect::>(); - - assert_eq!(output_sender.id, output_receiver.id); - assert_eq!(output_receiver.msgs, expected); - } -} diff --git a/crates/mpz-ot/src/kos/receiver.rs b/crates/mpz-ot/src/kos/receiver.rs index 62fa4c75..039873c9 100644 --- a/crates/mpz-ot/src/kos/receiver.rs +++ b/crates/mpz-ot/src/kos/receiver.rs @@ -1,387 +1,222 @@ -use std::mem; - use async_trait::async_trait; -use futures::TryFutureExt as _; -use itybity::{FromBitIterator, IntoBitIterator}; -use mpz_cointoss as cointoss; -use mpz_common::{try_join, Allocate, Context, Preprocess}; -use mpz_core::{prg::Prg, Block}; +use rand::{thread_rng, Rng}; +use serio::SinkExt as _; + +use mpz_cointoss::{self as cointoss, cointoss_sender}; +use mpz_common::{future::MaybeDone, scoped, Context, ContextError, Flush}; +use mpz_core::Block; use mpz_ot_core::{ - kos::{ - msgs::{SenderPayload, StartExtend}, - pad_ot_count, receiver_state as state, Receiver as ReceiverCore, ReceiverConfig, - ReceiverKeys, CSP, - }, - OTReceiverOutput, ROTReceiverOutput, TransferId, + kos::{receiver_state as state, Receiver as Core, ReceiverConfig, ReceiverError as CoreError}, + ot::OTSender, + rcot::{RCOTReceiver, RCOTReceiverOutput}, }; -use enum_try_as_inner::EnumTryAsInner; -use rand::{ - distributions::{Distribution, Standard}, - thread_rng, Rng, -}; -use rand_core::SeedableRng; -use serio::{stream::IoStreamExt as _, SinkExt as _}; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; - -use super::{ReceiverError, ReceiverVerifyError, EXTEND_CHUNK_SIZE}; -use crate::{ - OTError, OTReceiver, OTSender, OTSetup, RandomOTReceiver, VerifiableOTReceiver, - VerifiableOTSender, -}; +type Error = ReceiverError; -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] -pub(crate) enum State { - Initialized(Box>), - Extension(Box>), - Verify(ReceiverCore), +#[derive(Debug)] +enum State { + Initialized { + base_ot: BaseOT, + receiver: Core, + }, + Extension(Core), Error, } +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, Self::Error) + } +} + /// KOS receiver. #[derive(Debug)] pub struct Receiver { - state: State, - base: BaseOT, - alloc: usize, - cointoss_receiver: Option>, + state: State, } -impl Receiver -where - BaseOT: Send, -{ - /// Creates a new receiver. +impl Receiver { + /// Creates a new Receiver /// /// # Arguments /// - /// * `config` - The receiver's configuration - pub fn new(config: ReceiverConfig, base: BaseOT) -> Self { + /// * `config` - The Receiver's configuration. + /// * `base_ot` - Base OT. + pub fn new(config: ReceiverConfig, base_ot: BaseOT) -> Self { Self { - state: State::Initialized(Box::new(ReceiverCore::new(config))), - base, - alloc: 0, - cointoss_receiver: None, + state: State::Initialized { + base_ot, + receiver: Core::new(config), + }, } } +} - /// The number of remaining OTs which can be consumed. - pub fn remaining(&self) -> Result { - Ok(self.state.try_as_extension()?.remaining()) - } +impl RCOTReceiver for Receiver { + type Error = Error; + type Future = MaybeDone>; - pub(crate) fn state(&self) -> &State { - &self.state + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + match &mut self.state { + State::Initialized { receiver, .. } => receiver.alloc(count).map_err(Error::from), + State::Extension(receiver) => receiver.alloc(count).map_err(Error::from), + State::Error => Err(Error::state("can not allocate, receiver in error state")), + } } - /// Returns the provided number of keys. - pub(crate) fn take_keys(&mut self, count: usize) -> Result { - self.state - .try_as_extension_mut()? - .keys(count) - .map_err(ReceiverError::from) + fn available(&self) -> usize { + match &self.state { + State::Initialized { .. } | State::Error => 0, + State::Extension(receiver) => receiver.available(), + } } - /// Performs OT extension. - /// - /// # Arguments - /// - /// * `sink` - The sink to send messages to the sender - /// * `stream` - The stream to receive messages from the sender - /// * `count` - The number of OTs to extend - pub async fn extend( + fn try_recv_rcot( &mut self, - ctx: &mut Ctx, count: usize, - ) -> Result<(), ReceiverError> { - let mut ext_receiver = - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - let count = pad_ot_count(count); - - // Extend the OTs. - let (mut ext_receiver, extend) = Backend::spawn(move || { - ext_receiver - .extend(count) - .map(|extend| (ext_receiver, extend)) - }) - .await?; - - // Send the extend message and cointoss commitment. - ctx.io_mut().feed(StartExtend { count }).await?; - for extend in extend.into_chunks(EXTEND_CHUNK_SIZE) { - ctx.io_mut().feed(extend).await?; + ) -> Result, Self::Error> { + match &mut self.state { + State::Initialized { receiver, .. } => { + receiver.try_recv_rcot(count).map_err(Error::from) + } + State::Extension(receiver) => receiver.try_recv_rcot(count).map_err(Error::from), + State::Error => Err(Error::state("can not send, receiver in error state")), } - ctx.io_mut().flush().await?; - - // Sample chi_seed with coin-toss. - let seed = thread_rng().gen(); - let chi_seed = cointoss::cointoss_sender(ctx, vec![seed]).await?[0]; - - // Compute consistency check. - let (ext_receiver, check) = Backend::spawn(move || { - ext_receiver - .check(chi_seed) - .map(|check| (ext_receiver, check)) - }) - .await?; - - // Send correlation check value. - ctx.io_mut().send(check).await?; - - self.state = State::Extension(ext_receiver); - - Ok(()) } -} - -impl Receiver -where - BaseOT: Send, -{ - pub(crate) async fn verify_delta( - &mut self, - ctx: &mut Ctx, - ) -> Result<(), ReceiverError> - where - BaseOT: VerifiableOTSender, - { - let receiver = std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - // Finalize coin toss to determine expected delta - let Some(cointoss_receiver) = self.cointoss_receiver.take() else { - return Err(ReceiverError::ConfigError( - "committed sender not configured".to_string(), - ))?; - }; - - let expected_delta = cointoss_receiver - .finalize(ctx) - .await - .map_err(ReceiverError::from)?[0]; - - // Receive delta by verifying the sender's base OT choices. - let choices = self.base.verify_choices(ctx).await?; - let actual_delta = <[u8; 16]>::from_lsb0_iter(choices).into(); - - if expected_delta != actual_delta { - return Err(ReceiverError::from(ReceiverVerifyError::InconsistentDelta)); + fn queue_recv_rcot(&mut self, count: usize) -> Result { + match &mut self.state { + State::Initialized { receiver, .. } => { + receiver.queue_recv_rcot(count).map_err(Error::from) + } + State::Extension(receiver) => receiver.queue_recv_rcot(count).map_err(Error::from), + State::Error => Err(Error::state("can not queue, receiver in error state")), } - - self.state = State::Verify(receiver.start_verification(actual_delta)?); - - Ok(()) } } #[async_trait] -impl OTSetup for Receiver +impl Flush for Receiver where Ctx: Context, - BaseOT: OTSetup + OTSender + Send, + BaseOT: OTSender + Flush + Send, { - async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), OTError> { - if self.state.is_extension() { - return Ok(()); - } + type Error = Error; - let ext_receiver = std::mem::replace(&mut self.state, State::Error) - .try_into_initialized() - .map_err(ReceiverError::from)?; - - // If the sender is committed, we run a coin toss - if ext_receiver.config().sender_commit() { - let cointoss_seed = thread_rng().gen(); - let (cointoss_receiver, _) = try_join!( - ctx, - cointoss::Receiver::new(vec![cointoss_seed]) - .receive(ctx) - .map_err(ReceiverError::from), - self.base.setup(ctx).map_err(ReceiverError::from) - )??; - - self.cointoss_receiver = Some(cointoss_receiver); - } else { - self.base.setup(ctx).await?; + fn wants_flush(&self) -> bool { + match &self.state { + State::Initialized { .. } => true, + State::Extension(receiver) => receiver.wants_extend(), + State::Error => false, } + } - let seeds: [[Block; 2]; CSP] = std::array::from_fn(|_| thread_rng().gen()); + async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + let mut receiver = match self.state.take() { + State::Initialized { + mut base_ot, + receiver, + } => { + let (receiver, seeds) = { + let mut rng = thread_rng(); + let seeds = std::array::from_fn(|_| rng.gen()); + (receiver.setup(seeds), seeds) + }; + + _ = base_ot.queue_send_ot(&seeds).map_err(Error::base_ot)?; + base_ot.flush(ctx).await.map_err(Error::base_ot)?; + + receiver + } + State::Extension(receiver) => receiver, + State::Error => return Err(Error::state("can not flush, receiver in error state")), + }; - // Send seeds to sender - self.base.send(ctx, &seeds).await?; + if !receiver.wants_extend() { + self.state = State::Extension(receiver); + return Ok(()); + } - let ext_receiver = ext_receiver.setup(seeds); + let receiver = ctx + .blocking(scoped!(move |ctx| { + while receiver.wants_extend() { + let extend = receiver.extend()?; + ctx.io_mut().send(extend).await?; + } - self.state = State::Extension(Box::new(ext_receiver)); + let seed = thread_rng().gen(); - Ok(()) - } -} + // See issue #176. + let chi_seed = cointoss_sender(ctx, vec![seed]).await?[0]; -impl Allocate for Receiver { - fn alloc(&mut self, count: usize) { - self.alloc += count; - } -} + let receiver_check = receiver.check(chi_seed)?; -#[async_trait] -impl Preprocess for Receiver -where - Ctx: Context, - BaseOT: OTSetup + OTSender + Send, -{ - type Error = OTError; + ctx.io_mut().send(receiver_check).await?; - async fn preprocess(&mut self, ctx: &mut Ctx) -> Result<(), OTError> { - if self.state.is_initialized() { - self.setup(ctx).await?; - } + Ok::<_, Error>(receiver) + })) + .await??; - let count = mem::take(&mut self.alloc); - if count == 0 { - return Ok(()); - } + self.state = State::Extension(receiver); - self.extend(ctx, count).await.map_err(OTError::from) + Ok(()) } } -#[async_trait] -impl OTReceiver for Receiver -where - Ctx: Context, - BaseOT: Send, -{ - async fn receive( - &mut self, - ctx: &mut Ctx, - choices: &[bool], - ) -> Result, OTError> { - let receiver = self - .state - .try_as_extension_mut() - .map_err(ReceiverError::from)?; - - let mut receiver_keys = receiver.keys(choices.len()).map_err(ReceiverError::from)?; - - let choices = choices.into_lsb0_vec(); - let derandomize = receiver_keys - .derandomize(&choices) - .map_err(ReceiverError::from)?; - - // Send derandomize message - ctx.io_mut().send(derandomize).await?; - - // Receive payload - let payload: SenderPayload = ctx.io_mut().expect_next().await?; - let id = payload.id; - - let received = Backend::spawn(move || { - receiver_keys - .decrypt_blocks(payload) - .map_err(ReceiverError::from) - }) - .await?; - - Ok(OTReceiverOutput { id, msgs: received }) - } -} +/// Error for [`Receiver`]. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct ReceiverError(#[from] ErrorRepr); -#[async_trait] -impl RandomOTReceiver for Receiver -where - Ctx: Context, - Standard: Distribution, - BaseOT: Send, -{ - async fn receive_random( - &mut self, - _ctx: &mut Ctx, - count: usize, - ) -> Result, OTError> { - let receiver = self - .state - .try_as_extension_mut() - .map_err(ReceiverError::from)?; +impl ReceiverError { + fn base_ot(err: E) -> Self + where + E: Into>, + { + Self(ErrorRepr::BaseOT(err.into())) + } - let keys = receiver.keys(count).map_err(ReceiverError::from)?; - let id = keys.id(); - let (choices, keys) = keys.take_choices_and_keys(); + fn state(msg: impl Into) -> Self { + Self(ErrorRepr::State(msg.into())) + } +} - let msgs = keys.into_iter().map(|k| Prg::from_seed(k).gen()).collect(); +#[derive(Debug, thiserror::Error)] +enum ErrorRepr { + #[error("core error: {0}")] + Core(#[from] CoreError), + #[error("base OT error: {0}")] + BaseOT(Box), + #[error("cointoss error: {0}")] + Cointoss(#[from] cointoss::CointossError), + #[error("state error: {0}")] + State(String), + #[error("context error: {0}")] + Context(#[from] ContextError), + #[error("io error: {0}")] + Io(#[from] std::io::Error), +} - Ok(ROTReceiverOutput { id, choices, msgs }) +impl From for ReceiverError { + fn from(err: CoreError) -> Self { + Self(ErrorRepr::Core(err)) } } -#[async_trait] -impl OTReceiver for Receiver -where - Ctx: Context, - BaseOT: Send, -{ - async fn receive( - &mut self, - ctx: &mut Ctx, - choices: &[bool], - ) -> Result, OTError> { - let receiver = self - .state - .try_as_extension_mut() - .map_err(ReceiverError::from)?; - - let mut receiver_keys = receiver.keys(choices.len()).map_err(ReceiverError::from)?; - - let choices = choices.into_lsb0_vec(); - let derandomize = receiver_keys - .derandomize(&choices) - .map_err(ReceiverError::from)?; - - // Send derandomize message - ctx.io_mut().send(derandomize).await?; - - // Receive payload - let payload: SenderPayload = ctx.io_mut().expect_next().await?; - let id = payload.id; - - let received = Backend::spawn(move || { - receiver_keys - .decrypt_bytes(payload) - .map_err(ReceiverError::from) - }) - .await?; - - Ok(OTReceiverOutput { id, msgs: received }) +impl From for ReceiverError { + fn from(err: cointoss::CointossError) -> Self { + Self(ErrorRepr::Cointoss(err)) } } -#[async_trait] -impl VerifiableOTReceiver for Receiver -where - Ctx: Context, - BaseOT: VerifiableOTSender + Send, -{ - async fn accept_reveal(&mut self, ctx: &mut Ctx) -> Result<(), OTError> { - self.verify_delta(ctx).await.map_err(OTError::from) +impl From for ReceiverError { + fn from(err: ContextError) -> Self { + Self(ErrorRepr::Context(err)) } +} - async fn verify( - &mut self, - _ctx: &mut Ctx, - id: TransferId, - msgs: &[[Block; 2]], - ) -> Result<(), OTError> { - let receiver = self.state.try_as_verify().map_err(ReceiverError::from)?; - - let record = receiver.remove_record(id).map_err(ReceiverError::from)?; - - let msgs = msgs.to_vec(); - Backend::spawn(move || record.verify(&msgs)) - .await - .map_err(ReceiverError::from)?; - - Ok(()) +impl From for ReceiverError { + fn from(err: std::io::Error) -> Self { + Self(ErrorRepr::Io(err)) } } diff --git a/crates/mpz-ot/src/kos/sender.rs b/crates/mpz-ot/src/kos/sender.rs index f0e8b37a..71fbeeab 100644 --- a/crates/mpz-ot/src/kos/sender.rs +++ b/crates/mpz-ot/src/kos/sender.rs @@ -1,407 +1,224 @@ -use std::mem; - use async_trait::async_trait; -use enum_try_as_inner::EnumTryAsInner; -use futures::TryFutureExt; use itybity::IntoBits; -use mpz_cointoss as cointoss; -use mpz_common::{try_join, Allocate, Context, Preprocess}; -use mpz_core::{prg::Prg, Block}; +use mpz_cointoss::{self as cointoss, cointoss_receiver}; +use mpz_common::{future::MaybeDone, scoped, Context, ContextError, Flush}; +use mpz_core::Block; use mpz_ot_core::{ - kos::{ - extension_matrix_size, - msgs::{Extend, StartExtend}, - pad_ot_count, sender_state as state, Sender as SenderCore, SenderConfig, SenderKeys, CSP, - }, - OTSenderOutput, ROTSenderOutput, -}; -use rand::{ - distributions::{Distribution, Standard}, - thread_rng, Rng, + kos::{sender_state as state, Sender as Core, SenderConfig, SenderError as CoreError}, + ot::{OTReceiver, OTReceiverOutput}, + rcot::{RCOTSender, RCOTSenderOutput}, }; -use rand_core::SeedableRng; -use serio::{stream::IoStreamExt as _, SinkExt as _}; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; +use rand::{thread_rng, Rng}; +use serio::stream::IoStreamExt as _; -use crate::{ - kos::SenderError, CommittedOTReceiver, CommittedOTSender, OTError, OTReceiver, OTSender, - OTSetup, RandomOTSender, -}; +type Error = SenderError; -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] -pub(crate) enum State { - Initialized(SenderCore), - Extension(SenderCore), - Complete, +#[derive(Debug)] +enum State { + Initialized { + base_ot: BaseOT, + sender: Core, + }, + Extension(Core), Error, } +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, Self::Error) + } +} + /// KOS sender. #[derive(Debug)] pub struct Sender { - state: State, - base: BaseOT, - alloc: usize, - cointoss_sender: Option>, + state: State, } -impl Sender { +impl Sender { /// Creates a new Sender /// /// # Arguments /// - /// * `config` - The Sender's configuration - pub fn new(config: SenderConfig, base: BaseOT) -> Self { + /// * `config` - The Sender's configuration. + /// * `delta` - Global COT correlation. + /// * `base_ot` - Base OT. + pub fn new(config: SenderConfig, delta: Block, base_ot: BaseOT) -> Self { Self { - state: State::Initialized(SenderCore::new(config)), - base, - alloc: 0, - cointoss_sender: None, + state: State::Initialized { + base_ot, + sender: Core::new(config, delta), + }, } } +} - /// The number of remaining OTs which can be consumed. - pub fn remaining(&self) -> Result { - Ok(self.state.try_as_extension()?.remaining()) - } +impl RCOTSender for Sender { + type Error = Error; + type Future = MaybeDone>; - /// Returns the provided number of keys. - pub(crate) fn take_keys(&mut self, count: usize) -> Result { - self.state - .try_as_extension_mut()? - .keys(count) - .map_err(SenderError::from) + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + match &mut self.state { + State::Initialized { sender, .. } => sender.alloc(count).map_err(Error::from), + State::Extension(sender) => sender.alloc(count).map_err(Error::from), + State::Error => Err(Error::state("can not allocate, sender in error state")), + } } - /// Performs the base OT setup with the provided delta. - /// - /// # Arguments - /// - /// * `sink` - The sink to send messages to the base OT sender - /// * `stream` - The stream to receive messages from the base OT sender - /// * `delta` - The delta value to use for the base OT setup. - pub async fn setup_with_delta( - &mut self, - ctx: &mut Ctx, - delta: Block, - ) -> Result<(), SenderError> - where - BaseOT: OTReceiver, - { - if self.state.try_as_initialized()?.config().sender_commit() { - return Err(SenderError::ConfigError( - "committed sender can not choose delta".to_string(), - )); + fn available(&self) -> usize { + match &self.state { + State::Initialized { .. } | State::Error => 0, + State::Extension(sender) => sender.available(), } - - self._setup_with_delta(ctx, delta).await } - async fn _setup_with_delta( - &mut self, - ctx: &mut Ctx, - delta: Block, - ) -> Result<(), SenderError> - where - BaseOT: OTReceiver, - { - let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; - - let choices = delta.into_lsb0_vec(); - let base_output = self.base.receive(ctx, &choices).await?; - - let seeds: [Block; CSP] = base_output - .msgs - .try_into() - .expect("seeds should be CSP length"); - - let ext_sender = ext_sender.setup(delta, seeds); - - self.state = State::Extension(ext_sender); - - Ok(()) + fn delta(&self) -> Block { + match &self.state { + State::Initialized { sender, .. } => sender.delta(), + State::Extension(sender) => sender.delta(), + State::Error => panic!("sender left in error state"), + } } - /// Performs OT extension. - /// - /// # Arguments - /// - /// * `channel` - The channel to communicate with the receiver. - /// * `count` - The number of OTs to extend. - pub async fn extend( - &mut self, - ctx: &mut Ctx, - count: usize, - ) -> Result<(), SenderError> { - let mut ext_sender = - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - let count = pad_ot_count(count); - - let StartExtend { - count: receiver_count, - } = ctx.io_mut().expect_next().await?; - - if count != receiver_count { - return Err(SenderError::ConfigError( - "sender and receiver count mismatch".to_string(), - )); + fn try_send_rcot(&mut self, count: usize) -> Result, Self::Error> { + match &mut self.state { + State::Initialized { sender, .. } => sender.try_send_rcot(count).map_err(Error::from), + State::Extension(sender) => sender.try_send_rcot(count).map_err(Error::from), + State::Error => Err(Error::state("can not send, sender in error state")), } + } - let expected_us = extension_matrix_size(count); - let mut extend = Extend { - us: Vec::with_capacity(expected_us), - }; + fn queue_send_rcot(&mut self, count: usize) -> Result { + match &mut self.state { + State::Initialized { sender, .. } => sender.queue_send_rcot(count).map_err(Error::from), + State::Extension(sender) => sender.queue_send_rcot(count).map_err(Error::from), + State::Error => Err(Error::state("can not queue, sender in error state")), + } + } +} - // Receive extension matrix from the receiver. - while extend.us.len() < expected_us { - let Extend { us: chunk } = ctx.io_mut().expect_next().await?; +#[async_trait] +impl Flush for Sender +where + Ctx: Context, + BaseOT: OTReceiver + Flush + Send, + BaseOT::Future: Send, +{ + type Error = Error; - extend.us.extend(chunk); + fn wants_flush(&self) -> bool { + match &self.state { + State::Initialized { .. } => true, + State::Extension(sender) => sender.wants_extend(), + State::Error => false, } + } - // Extend the OTs. - let mut ext_sender = - Backend::spawn(move || ext_sender.extend(count, extend).map(|_| ext_sender)).await?; + async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + let mut sender = match self.state.take() { + State::Initialized { + mut base_ot, + sender, + } => { + let choices = sender.delta().into_lsb0_vec(); + let seeds = base_ot.queue_recv_ot(&choices).map_err(Error::base_ot)?; + base_ot.flush(ctx).await.map_err(Error::base_ot)?; - // Sample chi_seed with coin-toss. - let seed: Block = thread_rng().gen(); - let chi_seed = cointoss::cointoss_receiver(ctx, vec![seed]).await?[0]; + let OTReceiverOutput { msgs: seeds, .. } = seeds.await.map_err(Error::base_ot)?; - // Receive the receiver's check. - let receiver_check = ctx.io_mut().expect_next().await?; + let seeds = seeds.try_into().expect("seeds should be 128 long"); - // Check consistency of extension. - let ext_sender = Backend::spawn(move || { - ext_sender - .check(chi_seed, receiver_check) - .map(|_| ext_sender) - }) - .await?; + sender.setup(seeds) + } + State::Extension(sender) => sender, + State::Error => return Err(Error::state("can not flush, sender in error state")), + }; - self.state = State::Extension(ext_sender); + if !sender.wants_extend() { + self.state = State::Extension(sender); + return Ok(()); + } - Ok(()) - } -} + let sender = ctx + .blocking(scoped!(move |ctx| { + while sender.wants_extend() { + let extend = ctx.io_mut().expect_next().await?; + sender.extend(extend)?; + } -impl Sender { - pub(crate) async fn reveal(&mut self, ctx: &mut Ctx) -> Result<(), SenderError> - where - BaseOT: CommittedOTReceiver, - { - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + let seed = thread_rng().gen(); - // Reveal coin toss payload - let Some(sender) = self.cointoss_sender.take() else { - return Err(SenderError::ConfigError( - "committed sender not configured".to_string(), - ))?; - }; + // See issue #176. + let chi_seed = cointoss_receiver(ctx, vec![seed]).await?[0]; + + let receiver_check = ctx.io_mut().expect_next().await?; - sender.finalize(ctx).await.map_err(SenderError::from)?; + sender.check(chi_seed, receiver_check)?; - // Reveal base OT choices - self.base.reveal_choices(ctx).await?; + Ok::<_, Error>(sender) + })) + .await??; - // This sender is no longer usable, so mark it as complete. - self.state = State::Complete; + self.state = State::Extension(sender); Ok(()) } } -#[async_trait] -impl OTSetup for Sender -where - Ctx: Context, - BaseOT: OTSetup + OTReceiver + Send + 'static, -{ - async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), OTError> { - if self.state.is_extension() { - return Ok(()); - } - - let sender = std::mem::replace(&mut self.state, State::Error) - .try_into_initialized() - .map_err(SenderError::from)?; - - // If the sender is committed, we sample delta using a coin toss. - let delta = if sender.config().sender_commit() { - let cointoss_seed = thread_rng().gen(); - - // Execute coin-toss protocol and base OT setup concurrently. - let ((seeds, cointoss_sender), _) = try_join!( - ctx, - async { - cointoss::Sender::new(vec![cointoss_seed]) - .commit(ctx) - .await? - .receive(ctx) - .await - .map_err(SenderError::from) - }, - self.base.setup(ctx).map_err(SenderError::from) - )??; - - // Store the sender to finalize the cointoss protocol later. - self.cointoss_sender = Some(cointoss_sender); - - seeds[0] - } else { - self.base.setup(ctx).await?; - Block::random(&mut thread_rng()) - }; - - self.state = State::Initialized(sender); +/// Error for [`Sender`]. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct SenderError(#[from] ErrorRepr); - self._setup_with_delta(ctx, delta) - .await - .map_err(OTError::from) +impl SenderError { + fn base_ot(err: E) -> Self + where + E: Into>, + { + Self(ErrorRepr::BaseOT(err.into())) } -} -impl Allocate for Sender { - fn alloc(&mut self, count: usize) { - self.alloc += count; + fn state(msg: impl Into) -> Self { + Self(ErrorRepr::State(msg.into())) } } -#[async_trait] -impl Preprocess for Sender -where - Ctx: Context, - BaseOT: OTSetup + OTReceiver + Send + 'static, -{ - type Error = OTError; - - async fn preprocess(&mut self, ctx: &mut Ctx) -> Result<(), OTError> { - if self.state.is_initialized() { - self.setup(ctx).await?; - } - - let count = mem::take(&mut self.alloc); - if count == 0 { - return Ok(()); - } - - self.extend(ctx, count).await.map_err(OTError::from) - } +#[derive(Debug, thiserror::Error)] +enum ErrorRepr { + #[error("core error: {0}")] + Core(#[from] CoreError), + #[error("base OT error: {0}")] + BaseOT(Box), + #[error("cointoss error: {0}")] + Cointoss(#[from] cointoss::CointossError), + #[error("state error: {0}")] + State(String), + #[error("context error: {0}")] + Context(#[from] ContextError), + #[error("io error: {0}")] + Io(#[from] std::io::Error), } -#[async_trait] -impl OTSender for Sender -where - Ctx: Context, - BaseOT: Send, -{ - async fn send( - &mut self, - ctx: &mut Ctx, - msgs: &[[Block; 2]], - ) -> Result { - let sender = self - .state - .try_as_extension_mut() - .map_err(SenderError::from)?; - - let derandomize = ctx.io_mut().expect_next().await?; - - let mut sender_keys = sender.keys(msgs.len()).map_err(SenderError::from)?; - sender_keys - .derandomize(derandomize) - .map_err(SenderError::from)?; - let payload = sender_keys - .encrypt_blocks(msgs) - .map_err(SenderError::from)?; - let id = payload.id; - - ctx.io_mut() - .send(payload) - .await - .map_err(SenderError::from)?; - - Ok(OTSenderOutput { id }) +impl From for SenderError { + fn from(err: CoreError) -> Self { + Self(ErrorRepr::Core(err)) } } -#[async_trait] -impl OTSender for Sender -where - Ctx: Context, - BaseOT: Send, -{ - async fn send( - &mut self, - ctx: &mut Ctx, - msgs: &[[[u8; N]; 2]], - ) -> Result { - let sender = self - .state - .try_as_extension_mut() - .map_err(SenderError::from)?; - - let derandomize = ctx.io_mut().expect_next().await?; - - let mut sender_keys = sender.keys(msgs.len()).map_err(SenderError::from)?; - sender_keys - .derandomize(derandomize) - .map_err(SenderError::from)?; - let payload = sender_keys.encrypt_bytes(msgs).map_err(SenderError::from)?; - let id = payload.id; - - ctx.io_mut() - .send(payload) - .await - .map_err(SenderError::from)?; - - Ok(OTSenderOutput { id }) +impl From for SenderError { + fn from(err: cointoss::CointossError) -> Self { + Self(ErrorRepr::Cointoss(err)) } } -#[async_trait] -impl RandomOTSender for Sender -where - Ctx: Context, - Standard: Distribution, - BaseOT: Send, -{ - async fn send_random( - &mut self, - _ctx: &mut Ctx, - count: usize, - ) -> Result, OTError> { - let sender = self - .state - .try_as_extension_mut() - .map_err(SenderError::from)?; - - let keys = sender.keys(count).map_err(SenderError::from)?; - let id = keys.id(); - - let msgs = keys - .take_keys() - .into_iter() - .map(|[k0, k1]| { - let mut prg_0 = Prg::from_seed(k0); - let mut prg_1 = Prg::from_seed(k1); - - [prg_0.gen::(), prg_1.gen::()] - }) - .collect(); - - Ok(ROTSenderOutput { id, msgs }) +impl From for SenderError { + fn from(err: ContextError) -> Self { + Self(ErrorRepr::Context(err)) } } -#[async_trait] -impl CommittedOTSender for Sender -where - Ctx: Context, - BaseOT: CommittedOTReceiver + Send, -{ - async fn reveal(&mut self, ctx: &mut Ctx) -> Result<(), OTError> { - self.reveal(ctx).await.map_err(OTError::from) +impl From for SenderError { + fn from(err: std::io::Error) -> Self { + Self(ErrorRepr::Io(err)) } } diff --git a/crates/mpz-ot/src/kos/shared_receiver.rs b/crates/mpz-ot/src/kos/shared_receiver.rs deleted file mode 100644 index df9a47e5..00000000 --- a/crates/mpz-ot/src/kos/shared_receiver.rs +++ /dev/null @@ -1,139 +0,0 @@ -use std::sync::Arc; - -use async_trait::async_trait; -use itybity::IntoBitIterator; -use mpz_common::{sync::AsyncMutex, Allocate, Context, Preprocess}; -use mpz_core::Block; -use mpz_ot_core::{kos::msgs::SenderPayload, OTReceiverOutput, ROTReceiverOutput, TransferId}; -use rand::distributions::{Distribution, Standard}; -use serio::{stream::IoStreamExt, SinkExt}; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; - -use crate::{ - kos::{Receiver, ReceiverError}, - OTError, OTReceiver, OTSender, OTSetup, RandomOTReceiver, VerifiableOTReceiver, - VerifiableOTSender, -}; - -/// A shared KOS receiver. -#[derive(Debug)] -pub struct SharedReceiver { - inner: Arc>>, -} - -impl Clone for SharedReceiver { - fn clone(&self) -> Self { - Self { - inner: self.inner.clone(), - } - } -} - -impl SharedReceiver { - /// Creates a new shared receiver. - pub fn new(receiver: Receiver) -> Self { - Self { - // KOS receiver is always the leader. - inner: Arc::new(AsyncMutex::new_leader(receiver)), - } - } -} - -impl Allocate for SharedReceiver { - fn alloc(&mut self, count: usize) { - self.inner.blocking_lock_unsync().alloc(count); - } -} - -#[async_trait] -impl Preprocess for SharedReceiver -where - Ctx: Context, - BaseOT: OTSetup + OTSender + Send, -{ - type Error = OTError; - - async fn preprocess(&mut self, ctx: &mut Ctx) -> Result<(), OTError> { - self.inner.lock(ctx).await?.preprocess(ctx).await - } -} - -#[async_trait] -impl OTReceiver for SharedReceiver -where - Ctx: Context, - BaseOT: Send, -{ - async fn receive( - &mut self, - ctx: &mut Ctx, - choices: &[bool], - ) -> Result, OTError> { - let mut keys = self.inner.lock(ctx).await?.take_keys(choices.len())?; - - let choices = choices.into_lsb0_vec(); - let derandomize = keys.derandomize(&choices).map_err(ReceiverError::from)?; - - // Send derandomize message - ctx.io_mut().send(derandomize).await?; - - // Receive payload - let payload: SenderPayload = ctx.io_mut().expect_next().await?; - let id = payload.id; - - let msgs = - Backend::spawn(move || keys.decrypt_blocks(payload).map_err(ReceiverError::from)) - .await?; - - Ok(OTReceiverOutput { id, msgs }) - } -} - -#[async_trait] -impl RandomOTReceiver for SharedReceiver -where - Ctx: Context, - Standard: Distribution, - BaseOT: Send, -{ - async fn receive_random( - &mut self, - ctx: &mut Ctx, - count: usize, - ) -> Result, OTError> { - self.inner.lock(ctx).await?.receive_random(ctx, count).await - } -} - -#[async_trait] -impl VerifiableOTReceiver for SharedReceiver -where - Ctx: Context, - BaseOT: VerifiableOTSender + Send, -{ - async fn accept_reveal(&mut self, ctx: &mut Ctx) -> Result<(), OTError> { - self.inner.lock(ctx).await?.accept_reveal(ctx).await - } - - async fn verify( - &mut self, - _ctx: &mut Ctx, - id: TransferId, - msgs: &[[Block; 2]], - ) -> Result<(), OTError> { - let record = { - let inner = self.inner.blocking_lock_unsync(); - - let receiver = inner.state().try_as_verify().map_err(ReceiverError::from)?; - - receiver.remove_record(id).map_err(ReceiverError::from)? - }; - - let msgs = msgs.to_vec(); - Backend::spawn(move || record.verify(&msgs)) - .await - .map_err(ReceiverError::from)?; - - Ok(()) - } -} diff --git a/crates/mpz-ot/src/kos/shared_sender.rs b/crates/mpz-ot/src/kos/shared_sender.rs deleted file mode 100644 index 19bc7a97..00000000 --- a/crates/mpz-ot/src/kos/shared_sender.rs +++ /dev/null @@ -1,117 +0,0 @@ -use std::sync::Arc; - -use async_trait::async_trait; - -use mpz_common::{sync::AsyncMutex, Allocate, Context, Preprocess}; -use mpz_core::Block; -use rand::distributions::{Distribution, Standard}; -use serio::{stream::IoStreamExt as _, SinkExt as _}; - -use crate::{ - kos::{Sender, SenderError}, - CommittedOTReceiver, CommittedOTSender, OTError, OTReceiver, OTSender, OTSenderOutput, OTSetup, - ROTSenderOutput, RandomOTSender, -}; - -/// A shared KOS sender. -#[derive(Debug)] -pub struct SharedSender { - inner: Arc>>, -} - -impl Clone for SharedSender { - fn clone(&self) -> Self { - Self { - inner: self.inner.clone(), - } - } -} - -impl SharedSender { - /// Creates a new shared sender. - pub fn new(sender: Sender) -> Self { - Self { - // KOS sender is always the follower. - inner: Arc::new(AsyncMutex::new_follower(sender)), - } - } -} - -impl Allocate for SharedSender { - fn alloc(&mut self, count: usize) { - self.inner.blocking_lock_unsync().alloc(count); - } -} - -#[async_trait] -impl Preprocess for SharedSender -where - Ctx: Context, - BaseOT: OTSetup + OTReceiver + Send + 'static, -{ - type Error = OTError; - - async fn preprocess(&mut self, ctx: &mut Ctx) -> Result<(), OTError> { - self.inner.lock(ctx).await?.preprocess(ctx).await - } -} - -#[async_trait] -impl OTSender for SharedSender -where - Ctx: Context, - BaseOT: OTReceiver + Send + 'static, -{ - async fn send( - &mut self, - ctx: &mut Ctx, - msgs: &[[Block; 2]], - ) -> Result { - let mut keys = self.inner.lock(ctx).await?.take_keys(msgs.len())?; - - let derandomize = ctx.io_mut().expect_next().await?; - - keys.derandomize(derandomize).map_err(SenderError::from)?; - let payload = keys.encrypt_blocks(msgs).map_err(SenderError::from)?; - let id = payload.id; - - ctx.io_mut() - .send(payload) - .await - .map_err(SenderError::from)?; - - Ok(OTSenderOutput { id }) - } -} - -#[async_trait] -impl RandomOTSender for SharedSender -where - Ctx: Context, - Standard: Distribution, - BaseOT: Send, -{ - async fn send_random( - &mut self, - ctx: &mut Ctx, - count: usize, - ) -> Result, OTError> { - self.inner.lock(ctx).await?.send_random(ctx, count).await - } -} - -#[async_trait] -impl CommittedOTSender for SharedSender -where - Ctx: Context, - BaseOT: CommittedOTReceiver + Send + 'static, -{ - async fn reveal(&mut self, ctx: &mut Ctx) -> Result<(), OTError> { - self.inner - .lock(ctx) - .await? - .reveal(ctx) - .await - .map_err(OTError::from) - } -} diff --git a/crates/mpz-ot/src/lib.rs b/crates/mpz-ot/src/lib.rs index b9871eab..11887b0c 100644 --- a/crates/mpz-ot/src/lib.rs +++ b/crates/mpz-ot/src/lib.rs @@ -1,4 +1,4 @@ -//! Implementations of oblivious transfer protocols. +//! Oblivious transfer protocols. #![deny( unsafe_code, @@ -10,233 +10,14 @@ )] pub mod chou_orlandi; +pub mod cot; #[cfg(any(test, feature = "ideal"))] pub mod ideal; pub mod kos; +pub mod ot; +pub mod rcot; +pub mod rot; +#[cfg(any(test, feature = "test-utils"))] +pub mod test; -use async_trait::async_trait; - -pub use mpz_ot_core::{ - COTReceiverOutput, COTSenderOutput, OTReceiverOutput, OTSenderOutput, RCOTReceiverOutput, - RCOTSenderOutput, ROTReceiverOutput, ROTSenderOutput, TransferId, -}; - -/// An oblivious transfer error. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum OTError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error("context error: {0}")] - Context(#[from] mpz_common::ContextError), - #[error("mutex error: {0}")] - Mutex(#[from] mpz_common::sync::MutexError), - #[error("sender error: {0}")] - SenderError(Box), - #[error("receiver error: {0}")] - ReceiverError(Box), -} - -/// An oblivious transfer protocol that needs to perform a one-time setup. -#[async_trait] -pub trait OTSetup { - /// Runs any one-time setup for the protocol. - /// - /// # Arguments - /// - /// * `ctx` - The thread context. - async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), OTError>; -} - -/// An oblivious transfer sender. -#[async_trait] -pub trait OTSender { - /// Obliviously transfers the messages to the receiver. - /// - /// # Arguments - /// - /// * `ctx` - The thread context. - /// * `msgs` - The messages to obliviously transfer. - async fn send(&mut self, ctx: &mut Ctx, msgs: &[T]) -> Result; -} - -/// A correlated oblivious transfer sender. -#[async_trait] -pub trait COTSender { - /// Obliviously transfers the correlated messages to the receiver. - /// - /// Returns the `0`-bit messages that were obliviously transferred. - /// - /// # Arguments - /// - /// * `ctx` - The thread context. - /// * `count` - The number of correlated messages to obliviously transfer. - async fn send_correlated( - &mut self, - ctx: &mut Ctx, - count: usize, - ) -> Result, OTError>; -} - -/// A random OT sender. -#[async_trait] -pub trait RandomOTSender { - /// Outputs pairs of random messages. - /// - /// # Arguments - /// - /// * `ctx` - The thread context. - /// * `count` - The number of pairs of random messages to output. - async fn send_random( - &mut self, - ctx: &mut Ctx, - count: usize, - ) -> Result, OTError>; -} - -/// A random correlated oblivious transfer sender. -#[async_trait] -pub trait RandomCOTSender { - /// Obliviously transfers the correlated messages to the receiver. - /// - /// Returns the `0`-bit messages that were obliviously transferred. - /// - /// # Arguments - /// - /// * `ctx` - The thread context. - /// * `count` - The number of correlated messages to obliviously transfer. - async fn send_random_correlated( - &mut self, - ctx: &mut Ctx, - count: usize, - ) -> Result, OTError>; -} - -/// An oblivious transfer receiver. -#[async_trait] -pub trait OTReceiver { - /// Obliviously receives data from the sender. - /// - /// # Arguments - /// - /// * `ctx` - The thread context. - /// * `choices` - The choices made by the receiver. - async fn receive( - &mut self, - ctx: &mut Ctx, - choices: &[T], - ) -> Result, OTError>; -} - -/// A correlated oblivious transfer receiver. -#[async_trait] -pub trait COTReceiver { - /// Obliviously receives correlated messages from the sender. - /// - /// # Arguments - /// - /// * `ctx` - The thread context. - /// * `choices` - The choices made by the receiver. - async fn receive_correlated( - &mut self, - ctx: &mut Ctx, - choices: &[T], - ) -> Result, OTError>; -} - -/// A random OT receiver. -#[async_trait] -pub trait RandomOTReceiver { - /// Outputs the choice bits and the corresponding messages. - /// - /// # Arguments - /// - /// * `ctx` - The thread context. - /// * `count` - The number of random messages to receive. - async fn receive_random( - &mut self, - ctx: &mut Ctx, - count: usize, - ) -> Result, OTError>; -} - -/// A random correlated oblivious transfer receiver. -#[async_trait] -pub trait RandomCOTReceiver { - /// Obliviously receives correlated messages with random choices. - /// - /// Returns a tuple of the choices and the messages, respectively. - /// - /// # Arguments - /// - /// * `ctx` - The thread context. - /// * `count` - The number of correlated messages to obliviously receive. - async fn receive_random_correlated( - &mut self, - ctx: &mut Ctx, - count: usize, - ) -> Result, OTError>; -} - -/// An oblivious transfer sender that is committed to its messages and can reveal them -/// to the receiver to verify them. -#[async_trait] -pub trait CommittedOTSender: OTSender { - /// Reveals all messages sent to the receiver. - /// - /// # Warning - /// - /// Obviously, you should be sure you want to do this before calling this function! - /// - /// # Arguments - /// - /// * `ctx` - The thread context. - async fn reveal(&mut self, ctx: &mut Ctx) -> Result<(), OTError>; -} - -/// An oblivious transfer sender that can verify the receiver's choices. -#[async_trait] -pub trait VerifiableOTSender: OTSender { - /// Receives the purported choices made by the receiver and verifies them. - /// - /// # Arguments - /// - /// * `ctx` - The thread context. - async fn verify_choices(&mut self, ctx: &mut Ctx) -> Result, OTError>; -} - -/// An oblivious transfer receiver that is committed to its choices and can reveal them -/// to the sender to verify them. -#[async_trait] -pub trait CommittedOTReceiver: OTReceiver { - /// Reveals the choices made by the receiver. - /// - /// # Warning - /// - /// Obviously, you should be sure you want to do this before calling this function! - /// - /// # Arguments - /// - /// * `ctx` - The thread context. - async fn reveal_choices(&mut self, ctx: &mut Ctx) -> Result<(), OTError>; -} - -/// An oblivious transfer receiver that can verify the sender's messages. -#[async_trait] -pub trait VerifiableOTReceiver: OTReceiver { - /// Accepts revealed secrets from the sender which are requried to verify previous messages. - /// - /// # Arguments - /// - /// * `ctx` - The thread context. - async fn accept_reveal(&mut self, ctx: &mut Ctx) -> Result<(), OTError>; - - /// Verifies purported messages sent by the sender. - /// - /// # Arguments - /// - /// * `ctx` - The thread context. - /// * `id` - The transfer id of the messages to verify. - /// * `msgs` - The purported messages sent by the sender. - async fn verify(&mut self, ctx: &mut Ctx, id: TransferId, msgs: &[V]) -> Result<(), OTError>; -} +pub use mpz_ot_core::TransferId; diff --git a/crates/mpz-ot/src/ot.rs b/crates/mpz-ot/src/ot.rs new file mode 100644 index 00000000..3fb1bc7d --- /dev/null +++ b/crates/mpz-ot/src/ot.rs @@ -0,0 +1,3 @@ +//! Chosen-message OT. + +pub use mpz_ot_core::ot::{OTReceiver, OTReceiverOutput, OTSender, OTSenderOutput}; diff --git a/crates/mpz-ot/src/rcot.rs b/crates/mpz-ot/src/rcot.rs new file mode 100644 index 00000000..e7d6ffe8 --- /dev/null +++ b/crates/mpz-ot/src/rcot.rs @@ -0,0 +1,3 @@ +//! Random correlated OT. + +pub use mpz_ot_core::rcot::{RCOTReceiver, RCOTReceiverOutput, RCOTSender, RCOTSenderOutput}; \ No newline at end of file diff --git a/crates/mpz-ot/src/rot.rs b/crates/mpz-ot/src/rot.rs new file mode 100644 index 00000000..2ad09bfe --- /dev/null +++ b/crates/mpz-ot/src/rot.rs @@ -0,0 +1,6 @@ +//! Random OT. + +pub mod any; +pub mod randomize; + +pub use mpz_ot_core::rot::{ROTReceiver, ROTReceiverOutput, ROTSender, ROTSenderOutput}; diff --git a/crates/mpz-ot/src/rot/any.rs b/crates/mpz-ot/src/rot/any.rs new file mode 100644 index 00000000..650fb589 --- /dev/null +++ b/crates/mpz-ot/src/rot/any.rs @@ -0,0 +1,33 @@ +//! Adapter for using any type as the message type in a ROT protocol. + +mod receiver; +mod sender; + +pub use receiver::AnyReceiver; +pub use sender::AnySender; + +#[cfg(test)] +mod tests { + use rand::{distributions::Standard, prelude::Distribution, rngs::StdRng, Rng, SeedableRng}; + + use super::*; + use crate::{ideal::rot::ideal_rot, test::test_rot}; + + #[derive(Clone, Copy, PartialEq)] + struct Foo { + foo: [u8; 32], + } + + impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> Foo { + Foo { foo: rng.gen() } + } + } + + #[tokio::test] + async fn test_any_rot() { + let mut rng = StdRng::seed_from_u64(0); + let (sender, receiver) = ideal_rot(rng.gen()); + test_rot::<_, _, Foo>(AnySender::new(sender), AnyReceiver::new(receiver)).await + } +} diff --git a/crates/mpz-ot/src/rot/any/receiver.rs b/crates/mpz-ot/src/rot/any/receiver.rs new file mode 100644 index 00000000..52ed8ffd --- /dev/null +++ b/crates/mpz-ot/src/rot/any/receiver.rs @@ -0,0 +1,67 @@ +use async_trait::async_trait; +use mpz_common::{Context, Flush}; +use mpz_core::Block; +use mpz_ot_core::rot::{AnyReceiver as Core, ROTReceiver, ROTReceiverOutput}; +use rand::{distributions::Standard, prelude::Distribution}; + +/// A ROT receiver which recvs any type implementing `rand` traits. +#[derive(Debug)] +pub struct AnyReceiver { + core: Core, +} + +impl AnyReceiver { + /// Creates a new `AnyReceiver`. + pub fn new(rot: T) -> Self { + Self { + core: Core::new(rot), + } + } + + /// Returns the inner receiver. + pub fn into_inner(self) -> T { + self.core.into_inner() + } +} + +impl ROTReceiver for AnyReceiver +where + T: ROTReceiver, + Standard: Distribution, +{ + type Error = T::Error; + type Future = as ROTReceiver>::Future; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.core.alloc(count) + } + + fn available(&self) -> usize { + self.core.available() + } + + fn try_recv_rot(&mut self, count: usize) -> Result, Self::Error> { + self.core.try_recv_rot(count) + } + + fn queue_recv_rot(&mut self, count: usize) -> Result { + self.core.queue_recv_rot(count) + } +} + +#[async_trait] +impl Flush for AnyReceiver +where + Ctx: Context, + T: Flush + Send, +{ + type Error = T::Error; + + fn wants_flush(&self) -> bool { + self.core.rot().wants_flush() + } + + async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + self.core.rot_mut().flush(ctx).await + } +} diff --git a/crates/mpz-ot/src/rot/any/sender.rs b/crates/mpz-ot/src/rot/any/sender.rs new file mode 100644 index 00000000..13cf880e --- /dev/null +++ b/crates/mpz-ot/src/rot/any/sender.rs @@ -0,0 +1,67 @@ +use async_trait::async_trait; +use mpz_common::{Context, Flush}; +use mpz_core::Block; +use mpz_ot_core::rot::{AnySender as Core, ROTSender, ROTSenderOutput}; +use rand::{distributions::Standard, prelude::Distribution}; + +/// A ROT sender which sends any type implementing `rand` traits. +#[derive(Debug)] +pub struct AnySender { + core: Core, +} + +impl AnySender { + /// Creates a new `AnySender`. + pub fn new(rot: T) -> Self { + Self { + core: Core::new(rot), + } + } + + /// Returns the inner sender. + pub fn into_inner(self) -> T { + self.core.into_inner() + } +} + +impl ROTSender<[U; 2]> for AnySender +where + T: ROTSender<[Block; 2]>, + Standard: Distribution, +{ + type Error = T::Error; + type Future = as ROTSender<[U; 2]>>::Future; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.core.alloc(count) + } + + fn available(&self) -> usize { + self.core.available() + } + + fn try_send_rot(&mut self, count: usize) -> Result, Self::Error> { + self.core.try_send_rot(count) + } + + fn queue_send_rot(&mut self, count: usize) -> Result { + self.core.queue_send_rot(count) + } +} + +#[async_trait] +impl Flush for AnySender +where + Ctx: Context, + T: Flush + Send, +{ + type Error = T::Error; + + fn wants_flush(&self) -> bool { + self.core.rot().wants_flush() + } + + async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + self.core.rot_mut().flush(ctx).await + } +} diff --git a/crates/mpz-ot/src/rot/randomize.rs b/crates/mpz-ot/src/rot/randomize.rs new file mode 100644 index 00000000..e74046cb --- /dev/null +++ b/crates/mpz-ot/src/rot/randomize.rs @@ -0,0 +1,27 @@ +//! Adapter to convert an RCOT protocol to ROT. + +mod receiver; +mod sender; + +pub use receiver::RandomizeRCOTReceiver; +pub use sender::RandomizeRCOTSender; + +#[cfg(test)] +mod tests { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + use crate::{ideal::rcot::ideal_rcot, test::test_rot}; + + use super::*; + + #[tokio::test] + async fn test_randomize_rcot() { + let mut rng = StdRng::seed_from_u64(0); + let (sender, receiver) = ideal_rcot(rng.gen(), rng.gen()); + test_rot( + RandomizeRCOTSender::new(sender), + RandomizeRCOTReceiver::new(receiver), + ) + .await + } +} diff --git a/crates/mpz-ot/src/rot/randomize/receiver.rs b/crates/mpz-ot/src/rot/randomize/receiver.rs new file mode 100644 index 00000000..87d1800d --- /dev/null +++ b/crates/mpz-ot/src/rot/randomize/receiver.rs @@ -0,0 +1,71 @@ +use async_trait::async_trait; +use mpz_common::{Context, Flush}; +use mpz_core::Block; +use mpz_ot_core::{ + rcot::RCOTReceiver, + rot::{ROTReceiver, ROTReceiverOutput, RandomizeRCOTReceiver as Core}, +}; + +/// Randomize RCOT receiver. +#[derive(Debug)] +pub struct RandomizeRCOTReceiver { + core: Core, +} + +impl RandomizeRCOTReceiver { + /// Creates a new receiver. + pub fn new(rcot: T) -> Self { + Self { + core: Core::new(rcot), + } + } + + /// Returns the inner receiver. + pub fn into_inner(self) -> T { + self.core.into_inner() + } +} + +impl ROTReceiver for RandomizeRCOTReceiver +where + T: RCOTReceiver, +{ + type Error = as ROTReceiver>::Error; + type Future = as ROTReceiver>::Future; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.core.alloc(count) + } + + fn available(&self) -> usize { + self.core.available() + } + + fn try_recv_rot( + &mut self, + count: usize, + ) -> Result, Self::Error> { + self.core.try_recv_rot(count) + } + + fn queue_recv_rot(&mut self, count: usize) -> Result { + self.core.queue_recv_rot(count) + } +} + +#[async_trait] +impl Flush for RandomizeRCOTReceiver +where + Ctx: Context, + T: Flush + Send, +{ + type Error = T::Error; + + fn wants_flush(&self) -> bool { + self.core.rcot().wants_flush() + } + + async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + self.core.rcot_mut().flush(ctx).await + } +} diff --git a/crates/mpz-ot/src/rot/randomize/sender.rs b/crates/mpz-ot/src/rot/randomize/sender.rs new file mode 100644 index 00000000..9c282355 --- /dev/null +++ b/crates/mpz-ot/src/rot/randomize/sender.rs @@ -0,0 +1,68 @@ +use async_trait::async_trait; +use mpz_common::{Context, Flush}; +use mpz_core::Block; +use mpz_ot_core::{ + rcot::RCOTSender, + rot::{ROTSender, ROTSenderOutput, RandomizeRCOTSender as Core}, +}; + +/// Randomize RCOT sender. +#[derive(Debug)] +pub struct RandomizeRCOTSender { + core: Core, +} + +impl RandomizeRCOTSender { + /// Creates a new sender. + pub fn new(rcot: T) -> Self { + Self { + core: Core::new(rcot), + } + } + + /// Returns the inner sender. + pub fn into_inner(self) -> T { + self.core.into_inner() + } +} + +impl ROTSender<[Block; 2]> for RandomizeRCOTSender +where + T: RCOTSender, +{ + type Error = as ROTSender<[Block; 2]>>::Error; + type Future = as ROTSender<[Block; 2]>>::Future; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.core.alloc(count) + } + + fn available(&self) -> usize { + self.core.available() + } + + fn try_send_rot(&mut self, count: usize) -> Result, Self::Error> { + self.core.try_send_rot(count) + } + + fn queue_send_rot(&mut self, count: usize) -> Result { + self.core.queue_send_rot(count) + } +} + +#[async_trait] +impl Flush for RandomizeRCOTSender +where + Ctx: Context, + T: Flush + Send, +{ + type Error = T::Error; + + fn wants_flush(&self) -> bool { + self.core.rcot().wants_flush() + } + + async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + self.core.rcot_mut().flush(ctx).await + } +} diff --git a/crates/mpz-ot/src/test.rs b/crates/mpz-ot/src/test.rs new file mode 100644 index 00000000..d352763b --- /dev/null +++ b/crates/mpz-ot/src/test.rs @@ -0,0 +1,152 @@ +use mpz_common::{ + executor::{test_st_executor, TestSTExecutor}, + Flush, +}; +use mpz_core::Block; +use mpz_ot_core::{ + cot::{COTReceiver, COTSender}, + ot::{OTReceiver, OTSender}, + rcot::{RCOTReceiver, RCOTReceiverOutput, RCOTSender, RCOTSenderOutput}, + rot::{ROTReceiver, ROTReceiverOutput, ROTSender, ROTSenderOutput}, + test::{assert_cot, assert_ot, assert_rot}, +}; +use rand::{rngs::StdRng, Rng, SeedableRng}; + +/// Tests OT functionality. +pub async fn test_ot(mut sender: S, mut receiver: R) +where + S: OTSender + Flush, + R: OTReceiver + Flush, +{ + let (mut sender_ctx, mut receiver_ctx) = test_st_executor(8); + + let mut rng = StdRng::seed_from_u64(0); + let msgs = (0..128).map(|_| [rng.gen(), rng.gen()]).collect::>(); + let choices = (0..128).map(|_| rng.gen()).collect::>(); + + let (output_sender, output_receiver) = futures::join! { + async { + sender.alloc(msgs.len()).unwrap(); + let output = sender.queue_send_ot(&msgs).unwrap(); + sender.flush(&mut sender_ctx).await.unwrap(); + output.await.unwrap() + }, + async { + receiver.alloc(choices.len()).unwrap(); + let output = receiver.queue_recv_ot(&choices).unwrap(); + receiver.flush(&mut receiver_ctx).await.unwrap(); + output.await.unwrap() + } + }; + + assert_eq!(output_sender.id, output_receiver.id); + assert_ot(&choices, &msgs, &output_receiver.msgs); +} + +/// Tests RCOT functionality. +pub async fn test_rcot(mut sender: S, mut receiver: R) +where + S: RCOTSender + Flush, + R: RCOTReceiver + Flush, +{ + let (mut sender_ctx, mut receiver_ctx) = test_st_executor(8); + + let count = 128; + let ( + RCOTSenderOutput { + id: sender_id, + keys, + }, + RCOTReceiverOutput { + id: receiver_id, + choices, + msgs, + }, + ) = futures::join! { + async { + sender.alloc(count).unwrap(); + let output = sender.queue_send_rcot(count).unwrap(); + sender.flush(&mut sender_ctx).await.unwrap(); + output.await.unwrap() + }, + async { + receiver.alloc(count).unwrap(); + let output = receiver.queue_recv_rcot(count).unwrap(); + receiver.flush(&mut receiver_ctx).await.unwrap(); + output.await.unwrap() + } + }; + + assert_eq!(sender_id, receiver_id); + assert_cot(sender.delta(), &choices, &keys, &msgs); +} + +/// Tests COT functionality. +pub async fn test_cot(mut sender: S, mut receiver: R) +where + S: COTSender + Flush, + R: COTReceiver + Flush, +{ + let (mut sender_ctx, mut receiver_ctx) = test_st_executor(8); + + let mut rng = StdRng::seed_from_u64(0); + let keys = (0..128).map(|_| rng.gen()).collect::>(); + let choices = (0..128).map(|_| rng.gen()).collect::>(); + + let (output_sender, output_receiver) = futures::join! { + async { + sender.alloc(keys.len()).unwrap(); + let output = sender.queue_send_cot(&keys).unwrap(); + sender.flush(&mut sender_ctx).await.unwrap(); + output.await.unwrap() + }, + async { + receiver.alloc(choices.len()).unwrap(); + let output = receiver.queue_recv_cot(&choices).unwrap(); + receiver.flush(&mut receiver_ctx).await.unwrap(); + output.await.unwrap() + } + }; + + assert_eq!(output_sender.id, output_receiver.id); + assert_cot(sender.delta(), &choices, &keys, &output_receiver.msgs); +} + +/// Tests ROT functionality. +pub async fn test_rot(mut sender: S, mut receiver: R) +where + S: ROTSender<[T; 2]> + Flush, + R: ROTReceiver + Flush, + T: Copy + PartialEq, +{ + let (mut sender_ctx, mut receiver_ctx) = test_st_executor(8); + + let count = 128; + let ( + ROTSenderOutput { + id: sender_id, + keys, + }, + ROTReceiverOutput { + id: receiver_id, + choices, + msgs, + }, + ) = futures::join! { + async { + sender.alloc(count).unwrap(); + let output = sender.queue_send_rot(count).unwrap(); + sender.flush(&mut sender_ctx).await.unwrap(); + output.await.unwrap() + }, + async { + receiver.alloc(count).unwrap(); + let output = receiver.queue_recv_rot(count).unwrap(); + receiver.flush(&mut receiver_ctx).await.unwrap(); + output.await.unwrap() + } + }; + + assert_eq!(sender_id, receiver_id); + assert_rot(&choices, &keys, &msgs); +} diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 00000000..47a41249 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,2 @@ +imports_granularity = "Crate" +wrap_comments = true