Skip to content

Commit

Permalink
refactor: chunk KOS extend message (#98)
Browse files Browse the repository at this point in the history
* refactor: chunk KOS extend message

* clippy

* fix chunking

* Update ot/mpz-ot-core/src/kos/msgs.rs

Co-authored-by: dan <[email protected]>

* clippy

* pad before rounding

* update comment

* require extend count to be a multiple of 64

---------

Co-authored-by: dan <[email protected]>
  • Loading branch information
sinui0 and themighty1 authored Feb 13, 2024
1 parent d6848d9 commit 850636f
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 36 deletions.
4 changes: 4 additions & 0 deletions ot/mpz-ot-core/src/kos/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
pub enum SenderError {
#[error("invalid state: expected {0}")]
InvalidState(String),
#[error("invalid count, must be a multiple of 64: {0}")]
InvalidCount(usize),
#[error("count mismatch: expected {0}, got {1}")]
CountMismatch(usize, usize),
#[error("id mismatch: expected {0}, got {1}")]
Expand All @@ -22,6 +24,8 @@ pub enum SenderError {
pub enum ReceiverError {
#[error("invalid state: expected {0}")]
InvalidState(String),
#[error("invalid count, must be a multiple of 64: {0}")]
InvalidCount(usize),
#[error("count mismatch: expected {0}, got {1}")]
CountMismatch(usize, usize),
#[error("id mismatch: expected {0}, got {1}")]
Expand Down
14 changes: 14 additions & 0 deletions ot/mpz-ot-core/src/kos/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,20 @@ pub(crate) type RngSeed = <Rng as SeedableRng>::Seed;
/// AES-128 CTR used for encryption.
pub(crate) type Aes128Ctr = ctr::Ctr64LE<aes::Aes128>;

/// Pads the number of OTs to accomodate for the KOS extension check and
/// the extension matrix transpose optimization.
pub fn pad_ot_count(mut count: usize) -> usize {
// Add OTs for the KOS extension check.
count += CSP + SSP;
// Round up the OTs to extend to the nearest multiple of 64 (matrix transpose optimization).
(count + 63) & !63
}

/// Returns the size in bytes of the extension matrix for a given number of OTs.
pub fn extension_matrix_size(count: usize) -> usize {
count * CSP / 8
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
40 changes: 38 additions & 2 deletions ot/mpz-ot-core/src/kos/msgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::msgs::Derandomize;
#[allow(missing_docs)]
pub enum Message<BaseMsg> {
BaseMsg(BaseMsg),
StartExtend(StartExtend),
Extend(Extend),
Check(Check),
Derandomize(Derandomize),
Expand All @@ -33,15 +34,50 @@ impl<BaseMsg> From<MessageError<BaseMsg>> for std::io::Error {
}
}

/// Extension message sent by the receiver.
/// Extension message sent by the receiver to agree upon the number of OTs to set up.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Extend {
pub struct StartExtend {
/// The number of OTs to set up.
pub count: usize,
}

/// Extension message sent by the receiver.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Extend {
/// The receiver's extension vectors.
pub us: Vec<u8>,
}

impl Extend {
/// Returns an iterator over the chunks of the message.
pub fn into_chunks(self, chunk_size: usize) -> ExtendChunks {
ExtendChunks {
chunk_size,
us: self.us.into_iter(),
}
}
}

/// Iterator over the chunks of an extension message.
pub struct ExtendChunks {
chunk_size: usize,
us: <Vec<u8> as IntoIterator>::IntoIter,
}

impl Iterator for ExtendChunks {
type Item = Extend;

fn next(&mut self) -> Option<Self::Item> {
if self.us.len() == 0 {
None
} else {
Some(Extend {
us: self.us.by_ref().take(self.chunk_size).collect::<Vec<_>>(),
})
}
}
}

/// Values for the correlation check sent by the receiver.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
Expand Down
11 changes: 7 additions & 4 deletions ot/mpz-ot-core/src/kos/receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ impl Receiver<state::Extension> {

/// Perform the IKNP OT extension.
///
/// The provided count _must_ be a multiple of 64, otherwise an error will be returned.
///
/// # Sacrificial OTs
///
/// Performing the consistency check sacrifices 256 OTs, so be sure to
Expand All @@ -132,16 +134,17 @@ impl Receiver<state::Extension> {
///
/// # Arguments
///
/// * `count` - The number of OTs to extend.
/// * `count` - The number of OTs to extend (must be a multiple of 64).
pub fn extend(&mut self, count: usize) -> Result<Extend, ReceiverError> {
if self.state.extended {
return Err(ReceiverError::InvalidState(
"extending more than once is currently disabled".to_string(),
));
}

// Round up the OTs to extend to the nearest multiple of 64 (matrix transpose optimization).
let count = (count + 63) & !63;
if count % 64 != 0 {
return Err(ReceiverError::InvalidCount(count));
}

const NROWS: usize = CSP;
let row_width = count / 8;
Expand Down Expand Up @@ -196,7 +199,7 @@ impl Receiver<state::Extension> {
);
self.state.unchecked_choices.extend(choices);

Ok(Extend { count, us })
Ok(Extend { us })
}

/// Performs the correlation check for all outstanding OTS.
Expand Down
24 changes: 10 additions & 14 deletions ot/mpz-ot-core/src/kos/sender.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{
kos::{
extension_matrix_size,
msgs::{Check, Ciphertexts, Extend, SenderPayload},
Aes128Ctr, Rng, RngSeed, SenderConfig, SenderError, CSP, SSP,
},
Expand Down Expand Up @@ -96,6 +97,8 @@ impl Sender<state::Extension> {

/// Perform the IKNP OT extension.
///
/// The provided count _must_ be a multiple of 64, otherwise an error will be returned.
///
/// # Sacrificial OTs
///
/// Performing the consistency check sacrifices 256 OTs, so be sure to extend enough to
Expand All @@ -111,32 +114,25 @@ impl Sender<state::Extension> {
///
/// # Arguments
///
/// * `count` - The number of additional OTs to extend
/// * `extend` - The receiver's setup message
/// * `count` - The number of additional OTs to extend (must be a multiple of 64).
/// * `extend` - The receiver's setup message.
pub fn extend(&mut self, count: usize, extend: Extend) -> Result<(), SenderError> {
if self.state.extended {
return Err(SenderError::InvalidState(
"extending more than once is currently disabled".to_string(),
));
}

// Round up the OTs to extend to the nearest multiple of 64 (matrix transpose optimization).
let count = (count + 63) & !63;
if count % 64 != 0 {
return Err(SenderError::InvalidCount(count));
}

const NROWS: usize = CSP;
let row_width = count / 8;

let Extend {
us,
count: receiver_count,
} = extend;

// Make sure the number of OTs to extend matches the receiver's setup message.
if receiver_count != count {
return Err(SenderError::CountMismatch(receiver_count, count));
}
let Extend { us } = extend;

if us.len() != NROWS * row_width {
if us.len() != extension_matrix_size(count) {
return Err(SenderError::InvalidExtend);
}

Expand Down
1 change: 1 addition & 0 deletions ot/mpz-ot/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ itybity.workspace = true
enum-try-as-inner.workspace = true
opaque-debug.workspace = true
serde = { workspace = true, optional = true }
cfg-if.workspace = true

[dev-dependencies]
rstest = { workspace = true }
Expand Down
10 changes: 10 additions & 0 deletions ot/mpz-ot/src/kos/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ pub use mpz_ot_core::kos::{
};
use utils_aio::{sink::IoSink, stream::IoStream};

// If we're testing we use a smaller chunk size to make sure the chunking code paths are tested.
cfg_if::cfg_if! {
if #[cfg(test)] {
pub(crate) const EXTEND_CHUNK_SIZE: usize = 1024;
} else {
/// The size of the chunks used to send the extension matrix, 4MB.
pub(crate) const EXTEND_CHUNK_SIZE: usize = 4 * 1024 * 1024;
}
}

/// Converts a sink of KOS messages into a sink of base OT messages.
pub(crate) fn into_base_sink<'a, Si: IoSink<msgs::Message<T>> + Send + Unpin, T: Send + 'a>(
sink: &'a mut Si,
Expand Down
19 changes: 14 additions & 5 deletions ot/mpz-ot/src/kos/receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use futures::SinkExt;
use itybity::{FromBitIterator, IntoBitIterator};
use mpz_core::{cointoss, prg::Prg, Block, ProtocolMessage};
use mpz_ot_core::kos::{
msgs::Message, receiver_state as state, Receiver as ReceiverCore, ReceiverConfig, CSP, SSP,
msgs::{Message, StartExtend},
pad_ot_count, receiver_state as state, Receiver as ReceiverCore, ReceiverConfig, CSP,
};

use enum_try_as_inner::EnumTryAsInner;
Expand All @@ -15,7 +16,9 @@ use utils_aio::{
stream::{ExpectStreamExt, IoStream},
};

use super::{into_base_sink, into_base_stream, ReceiverError, ReceiverVerifyError};
use super::{
into_base_sink, into_base_stream, ReceiverError, ReceiverVerifyError, EXTEND_CHUNK_SIZE,
};
use crate::{
OTError, OTReceiver, OTSender, OTSetup, RandomOTReceiver, VerifiableOTReceiver,
VerifiableOTSender,
Expand Down Expand Up @@ -90,9 +93,11 @@ where
let mut ext_receiver =
std::mem::replace(&mut self.state, State::Error).try_into_extension()?;

// Extend the OTs, adding padding for the consistency check.
let count = pad_ot_count(count);

// Extend the OTs.
let (mut ext_receiver, extend) = Backend::spawn(move || {
let extend = ext_receiver.extend(count + CSP + SSP);
let extend = ext_receiver.extend(count);

(ext_receiver, extend)
})
Expand All @@ -105,7 +110,11 @@ where
let (cointoss_sender, cointoss_commitment) = cointoss::Sender::new(vec![seed]).send();

// Send the extend message and cointoss commitment
sink.feed(Message::Extend(extend)).await?;
sink.feed(Message::StartExtend(StartExtend { count }))
.await?;
for extend in extend.into_chunks(EXTEND_CHUNK_SIZE) {
sink.feed(Message::Extend(extend)).await?;
}
sink.feed(Message::CointossCommit(cointoss_commitment))
.await?;
sink.flush().await?;
Expand Down
45 changes: 34 additions & 11 deletions ot/mpz-ot/src/kos/sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use futures_util::SinkExt;
use itybity::IntoBits;
use mpz_core::{cointoss, prg::Prg, Block, ProtocolMessage};
use mpz_ot_core::kos::{
msgs::Message, sender_state as state, Sender as SenderCore, SenderConfig, CSP, SSP,
extension_matrix_size,
msgs::{Extend, Message, StartExtend},
pad_ot_count, sender_state as state, Sender as SenderCore, SenderConfig, CSP,
};
use rand::{thread_rng, Rng};
use rand_core::{RngCore, SeedableRng};
Expand Down Expand Up @@ -138,23 +140,44 @@ where
let mut ext_sender =
std::mem::replace(&mut self.state, State::Error).try_into_extension()?;

// Receive extend message from the receiver.
let extend = stream
let count = pad_ot_count(count);

let StartExtend {
count: receiver_count,
} = stream
.expect_next()
.await?
.try_into_extend()
.try_into_start_extend()
.map_err(SenderError::from)?;

if count != receiver_count {
return Err(SenderError::ConfigError(
"sender and receiver count mismatch".to_string(),
));
}

let expected_us = extension_matrix_size(count);
let mut extend = Extend {
us: Vec::with_capacity(expected_us),
};

// Receive extension matrix from the receiver.
while extend.us.len() < expected_us {
let Extend { us: chunk } = stream
.expect_next()
.await?
.try_into_extend()
.map_err(SenderError::from)?;

extend.us.extend(chunk);
}

// Receive coin toss commitments from the receiver.
let commitment = stream.expect_next().await?.try_into_cointoss_commit()?;

// Extend the OTs, adding padding for the consistency check.
let mut ext_sender = Backend::spawn(move || {
ext_sender
.extend(count + CSP + SSP, extend)
.map(|_| ext_sender)
})
.await?;
// Extend the OTs.
let mut ext_sender =
Backend::spawn(move || ext_sender.extend(count, extend).map(|_| ext_sender)).await?;

// Execute coin toss protocol for consistency check.
let seed: Block = thread_rng().gen();
Expand Down

0 comments on commit 850636f

Please sign in to comment.