Skip to content

Commit

Permalink
refactor(mpz-ot): Normalize OT and ideal functionalities (#122)
Browse files Browse the repository at this point in the history
* add transfer id

* update co15 and kos15

* add Output type

* feat(mpz-common): ideal functionality utils

* refactor ideal functionalities and traits

* pr feedback

* impl ideal rot

* Update crates/mpz-ot/src/ideal/rot.rs

Co-authored-by: th4s <[email protected]>

---------

Co-authored-by: th4s <[email protected]>
  • Loading branch information
sinui0 and th4s authored May 8, 2024
1 parent 73441ff commit 42c7fe9
Show file tree
Hide file tree
Showing 43 changed files with 1,459 additions and 1,107 deletions.
1 change: 1 addition & 0 deletions crates/mpz-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = "2021"
default = ["sync"]
sync = []
test-utils = []
ideal = []

[dependencies]
mpz-core.workspace = true
Expand Down
191 changes: 191 additions & 0 deletions crates/mpz-common/src/ideal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
//! Ideal functionality utilities.
use futures::channel::oneshot;
use std::{
any::Any,
collections::HashMap,
sync::{Arc, Mutex, MutexGuard},
};

use crate::{Context, ThreadId};

type BoxAny = Box<dyn Any + Send + 'static>;

#[derive(Debug, Default)]
struct Buffer {
alice: HashMap<ThreadId, (BoxAny, oneshot::Sender<BoxAny>)>,
bob: HashMap<ThreadId, (BoxAny, oneshot::Sender<BoxAny>)>,
}

/// The ideal functionality from the perspective of Alice.
#[derive(Debug)]
pub struct Alice<F> {
f: Arc<Mutex<F>>,
buffer: Arc<Mutex<Buffer>>,
}

impl<F> Clone for Alice<F> {
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
buffer: self.buffer.clone(),
}
}
}

impl<F> Alice<F> {
/// Returns a lock to the ideal functionality.
pub fn get_mut(&mut self) -> MutexGuard<'_, F> {
self.f.lock().unwrap()
}

/// Calls the ideal functionality.
pub async fn call<Ctx, C, IA, IB, OA, OB>(&mut self, ctx: &mut Ctx, input: IA, call: C) -> OA
where
Ctx: Context,
C: FnOnce(&mut F, IA, IB) -> (OA, OB),
IA: Send + 'static,
IB: Send + 'static,
OA: Send + 'static,
OB: Send + 'static,
{
let receiver = {
let mut buffer = self.buffer.lock().unwrap();
if let Some((input_bob, ret_bob)) = buffer.bob.remove(ctx.id()) {
let input_bob = *input_bob
.downcast()
.expect("alice received correct input type for bob");

let (output_alice, output_bob) =
call(&mut self.f.lock().unwrap(), input, input_bob);

_ = ret_bob.send(Box::new(output_bob));

return output_alice;
}

let (sender, receiver) = oneshot::channel();
buffer
.alice
.insert(ctx.id().clone(), (Box::new(input), sender));
receiver
};

let output_alice = receiver.await.expect("bob did not drop the channel");
*output_alice
.downcast()
.expect("bob sent correct output type for alice")
}
}

/// The ideal functionality from the perspective of Bob.
#[derive(Debug)]
pub struct Bob<F> {
f: Arc<Mutex<F>>,
buffer: Arc<Mutex<Buffer>>,
}

impl<F> Clone for Bob<F> {
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
buffer: self.buffer.clone(),
}
}
}

impl<F> Bob<F> {
/// Returns a lock to the ideal functionality.
pub fn get_mut(&mut self) -> MutexGuard<'_, F> {
self.f.lock().unwrap()
}

/// Calls the ideal functionality.
pub async fn call<Ctx, C, IA, IB, OA, OB>(&mut self, ctx: &mut Ctx, input: IB, call: C) -> OB
where
Ctx: Context,
C: FnOnce(&mut F, IA, IB) -> (OA, OB),
IA: Send + 'static,
IB: Send + 'static,
OA: Send + 'static,
OB: Send + 'static,
{
let receiver = {
let mut buffer = self.buffer.lock().unwrap();
if let Some((input_alice, ret_alice)) = buffer.alice.remove(ctx.id()) {
let input_alice = *input_alice
.downcast()
.expect("bob received correct input type for alice");

let (output_alice, output_bob) =
call(&mut self.f.lock().unwrap(), input_alice, input);

_ = ret_alice.send(Box::new(output_alice));

return output_bob;
}

let (sender, receiver) = oneshot::channel();
buffer
.bob
.insert(ctx.id().clone(), (Box::new(input), sender));
receiver
};

let output_bob = receiver.await.expect("alice did not drop the channel");
*output_bob
.downcast()
.expect("alice sent correct output type for bob")
}
}

/// Creates an ideal functionality, returning the perspectives of Alice and Bob.
pub fn ideal_f2p<F>(f: F) -> (Alice<F>, Bob<F>) {
let f = Arc::new(Mutex::new(f));
let buffer = Arc::new(Mutex::new(Buffer::default()));

(
Alice {
f: f.clone(),
buffer: buffer.clone(),
},
Bob { f, buffer },
)
}

#[cfg(test)]
mod test {
use crate::executor::test_st_executor;

use super::*;

#[test]
fn test_ideal() {
let (mut alice, mut bob) = ideal_f2p(());
let (mut ctx_a, mut ctx_b) = test_st_executor(8);

let (output_a, output_b) = futures::executor::block_on(async {
futures::join!(
alice.call(&mut ctx_a, 1u8, |&mut (), a: u8, b: u8| (a + b, a + b)),
bob.call(&mut ctx_b, 2u8, |&mut (), a: u8, b: u8| (a + b, a + b)),
)
});

assert_eq!(output_a, 3);
assert_eq!(output_b, 3);
}

#[test]
#[should_panic]
fn test_ideal_wrong_input_type() {
let (mut alice, mut bob) = ideal_f2p(());
let (mut ctx_a, mut ctx_b) = test_st_executor(8);

futures::executor::block_on(async {
futures::join!(
alice.call(&mut ctx_a, 1u16, |&mut (), a: u16, b: u16| (a + b, a + b)),
bob.call(&mut ctx_b, 2u8, |&mut (), a: u8, b: u8| (a + b, a + b)),
)
});
}
}
2 changes: 2 additions & 0 deletions crates/mpz-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
mod context;
pub mod executor;
mod id;
#[cfg(any(test, feature = "ideal"))]
pub mod ideal;
#[cfg(feature = "sync")]
pub mod sync;

Expand Down
6 changes: 6 additions & 0 deletions crates/mpz-ot-core/src/chou_orlandi/error.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use crate::TransferId;

/// Errors that can occur when using the CO15 sender.
#[derive(Debug, thiserror::Error)]
#[allow(missing_docs)]
pub enum SenderError {
#[error("invalid state: expected {0}")]
InvalidState(String),
#[error("id mismatch: expected {0}, got {1}")]
IdMismatch(TransferId, TransferId),
#[error("count mismatch: sender expected {0} but receiver sent {1}")]
CountMismatch(usize, usize),
#[error(transparent)]
Expand All @@ -16,6 +20,8 @@ pub enum SenderError {
pub enum ReceiverError {
#[error("invalid state: expected {0}")]
InvalidState(String),
#[error("id mismatch: expected {0}, got {1}")]
IdMismatch(TransferId, TransferId),
#[error("count mismatch: receiver expected {0} but sender sent {1}")]
CountMismatch(usize, usize),
}
Expand Down
6 changes: 6 additions & 0 deletions crates/mpz-ot-core/src/chou_orlandi/msgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use curve25519_dalek::RistrettoPoint;
use mpz_core::Block;
use serde::{Deserialize, Serialize};

use crate::TransferId;

/// Sender setup message.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SenderSetup {
Expand All @@ -14,13 +16,17 @@ pub struct SenderSetup {
/// Sender payload message.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SenderPayload {
/// The transfer ID.
pub id: TransferId,
/// The sender's ciphertexts
pub payload: Vec<[Block; 2]>,
}

/// Receiver payload message.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReceiverPayload {
/// The transfer ID.
pub id: TransferId,
/// The receiver's blinded choices.
pub blinded_choices: Vec<RistrettoPoint>,
}
Expand Down
21 changes: 18 additions & 3 deletions crates/mpz-ot-core/src/chou_orlandi/receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::chou_orlandi::{
msgs::{ReceiverPayload, ReceiverReveal, SenderPayload, SenderSetup},
ReceiverConfig, ReceiverError,
};
use crate::TransferId;

use itybity::{BitIterable, FromBitIterator, ToBits};
use mpz_core::Block;
Expand Down Expand Up @@ -89,6 +90,7 @@ impl Receiver {
state: state::Setup {
rng,
sender_base_table: RistrettoBasepointTable::create(&sender_setup.public_key),
transfer_id: TransferId::default(),
counter: 0,
choice_log: Vec::default(),
decryption_keys: Vec::default(),
Expand Down Expand Up @@ -129,7 +131,10 @@ impl Receiver<state::Setup> {
choice_log.extend(choices.iter_lsb0());
}

ReceiverPayload { blinded_choices }
ReceiverPayload {
id: self.state.transfer_id,
blinded_choices,
}
}

/// Receives the encrypted payload from the Sender, returning the plaintext messages corresponding
Expand All @@ -140,10 +145,18 @@ impl Receiver<state::Setup> {
/// * `payload` - The encrypted payload from the Sender
pub fn receive(&mut self, payload: SenderPayload) -> Result<Vec<Block>, ReceiverError> {
let state::Setup {
decryption_keys, ..
transfer_id: current_id,
decryption_keys,
..
} = &mut self.state;

let SenderPayload { payload } = payload;
let SenderPayload { id, payload } = payload;

// Check that the transfer id matches
let expected_id = current_id.next();
if id != expected_id {
return Err(ReceiverError::IdMismatch(expected_id, id));
}

// Check that the number of ciphertexts does not exceed the number of pending keys
if payload.len() > decryption_keys.len() {
Expand Down Expand Up @@ -267,6 +280,8 @@ pub mod state {
pub(super) rng: ChaCha20Rng,
/// Sender's public key (precomputed table)
pub(super) sender_base_table: RistrettoBasepointTable,
/// Current transfer id.
pub(super) transfer_id: TransferId,
/// Counts how many decryption keys we've computed so far
pub(super) counter: usize,
/// Log of the receiver's choice bits
Expand Down
32 changes: 25 additions & 7 deletions crates/mpz-ot-core/src/chou_orlandi/sender.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::chou_orlandi::{
hash_point,
msgs::{ReceiverPayload, ReceiverReveal, SenderPayload, SenderSetup},
Receiver, ReceiverConfig, SenderConfig, SenderError, SenderVerifyError,
use crate::{
chou_orlandi::{
hash_point,
msgs::{ReceiverPayload, ReceiverReveal, SenderPayload, SenderSetup},
Receiver, ReceiverConfig, SenderConfig, SenderError, SenderVerifyError,
},
TransferId,
};

use itybity::IntoBitIterator;
Expand Down Expand Up @@ -101,6 +104,7 @@ impl Sender {
state: state::Setup {
private_key,
public_key,
transfer_id: TransferId::default(),
counter: 0,
},
tape: self.tape,
Expand All @@ -124,11 +128,21 @@ impl Sender<state::Setup> {
let state::Setup {
private_key,
public_key,
transfer_id: current_id,
counter,
..
} = &mut self.state;

let ReceiverPayload { blinded_choices } = receiver_payload;
let ReceiverPayload {
id,
blinded_choices,
} = receiver_payload;

// Check that the transfer id matches
let expected_id = current_id.next();
if id != expected_id {
return Err(SenderError::IdMismatch(expected_id, id));
}

// Check that the number of inputs matches the number of choices
if inputs.len() != blinded_choices.len() {
Expand All @@ -154,7 +168,7 @@ impl Sender<state::Setup> {
payload[1] = input[1] ^ payload[1];
}

Ok(SenderPayload { payload })
Ok(SenderPayload { id, payload })
}

/// Returns the Receiver choices after verifying them against the tape.
Expand Down Expand Up @@ -199,7 +213,9 @@ impl Sender<state::Setup> {

let mut receiver = receiver.setup(SenderSetup { public_key });

let ReceiverPayload { blinded_choices } = receiver.receive_random(&choices);
let ReceiverPayload {
blinded_choices, ..
} = receiver.receive_random(&choices);

// Check that the simulated receiver's choices match the ones recorded in the tape
if blinded_choices != tape.receiver_choices {
Expand Down Expand Up @@ -296,6 +312,8 @@ pub mod state {
pub(super) private_key: Scalar,
// The public_key is `A == g^a` in [ref1]
pub(super) public_key: RistrettoPoint,
/// Current transfer id.
pub(super) transfer_id: TransferId,
/// Number of OTs sent so far
pub(super) counter: usize,
}
Expand Down
Loading

0 comments on commit 42c7fe9

Please sign in to comment.