Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(mpz-ot): Normalize OT and ideal functionalities #122

Merged
merged 8 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/mpz-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = "2021"
default = ["sync"]
sync = []
test-utils = []
ideal = []

[dependencies]
mpz-core.workspace = true
Expand Down
191 changes: 191 additions & 0 deletions crates/mpz-common/src/ideal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
//! Ideal functionality utilities.

use futures::channel::oneshot;
use std::{
any::Any,
collections::HashMap,
sync::{Arc, Mutex, MutexGuard},
};

use crate::{Context, ThreadId};

type BoxAny = Box<dyn Any + Send + 'static>;

#[derive(Debug, Default)]
struct Buffer {
alice: HashMap<ThreadId, (BoxAny, oneshot::Sender<BoxAny>)>,
bob: HashMap<ThreadId, (BoxAny, oneshot::Sender<BoxAny>)>,
}

/// The ideal functionality from the perspective of Alice.
#[derive(Debug)]
pub struct Alice<F> {
f: Arc<Mutex<F>>,
th4s marked this conversation as resolved.
Show resolved Hide resolved
buffer: Arc<Mutex<Buffer>>,
}

impl<F> Clone for Alice<F> {
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
buffer: self.buffer.clone(),
}
}
}

impl<F> Alice<F> {
/// Returns a lock to the ideal functionality.
pub fn get_mut(&mut self) -> MutexGuard<'_, F> {
self.f.lock().unwrap()
}

/// Calls the ideal functionality.
pub async fn call<Ctx, C, IA, IB, OA, OB>(&mut self, ctx: &mut Ctx, input: IA, call: C) -> OA
where
Ctx: Context,
C: FnOnce(&mut F, IA, IB) -> (OA, OB),
IA: Send + 'static,
IB: Send + 'static,
OA: Send + 'static,
OB: Send + 'static,
{
let receiver = {
let mut buffer = self.buffer.lock().unwrap();
if let Some((input_bob, ret_bob)) = buffer.bob.remove(ctx.id()) {
let input_bob = *input_bob
.downcast()
.expect("alice received correct input type for bob");

let (output_alice, output_bob) =
call(&mut self.f.lock().unwrap(), input, input_bob);

_ = ret_bob.send(Box::new(output_bob));

return output_alice;
}

let (sender, receiver) = oneshot::channel();
buffer
.alice
.insert(ctx.id().clone(), (Box::new(input), sender));
receiver
};

let output_alice = receiver.await.expect("bob did not drop the channel");
*output_alice
.downcast()
.expect("bob sent correct output type for alice")
}
}

/// The ideal functionality from the perspective of Bob.
#[derive(Debug)]
pub struct Bob<F> {
f: Arc<Mutex<F>>,
buffer: Arc<Mutex<Buffer>>,
}

impl<F> Clone for Bob<F> {
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
buffer: self.buffer.clone(),
}
}
}

impl<F> Bob<F> {
/// Returns a lock to the ideal functionality.
pub fn get_mut(&mut self) -> MutexGuard<'_, F> {
self.f.lock().unwrap()
}

/// Calls the ideal functionality.
pub async fn call<Ctx, C, IA, IB, OA, OB>(&mut self, ctx: &mut Ctx, input: IB, call: C) -> OB
where
Ctx: Context,
C: FnOnce(&mut F, IA, IB) -> (OA, OB),
IA: Send + 'static,
IB: Send + 'static,
OA: Send + 'static,
OB: Send + 'static,
{
let receiver = {
let mut buffer = self.buffer.lock().unwrap();
if let Some((input_alice, ret_alice)) = buffer.alice.remove(ctx.id()) {
let input_alice = *input_alice
.downcast()
.expect("bob received correct input type for alice");

let (output_alice, output_bob) =
call(&mut self.f.lock().unwrap(), input_alice, input);

_ = ret_alice.send(Box::new(output_alice));

return output_bob;
}

let (sender, receiver) = oneshot::channel();
buffer
.bob
.insert(ctx.id().clone(), (Box::new(input), sender));
receiver
};

let output_bob = receiver.await.expect("alice did not drop the channel");
*output_bob
.downcast()
.expect("alice sent correct output type for bob")
}
}

/// Creates an ideal functionality, returning the perspectives of Alice and Bob.
pub fn ideal_f2p<F>(f: F) -> (Alice<F>, Bob<F>) {
let f = Arc::new(Mutex::new(f));
let buffer = Arc::new(Mutex::new(Buffer::default()));

(
Alice {
f: f.clone(),
buffer: buffer.clone(),
},
Bob { f, buffer },
)
}

#[cfg(test)]
mod test {
use crate::executor::test_st_executor;

use super::*;

#[test]
fn test_ideal() {
let (mut alice, mut bob) = ideal_f2p(());
let (mut ctx_a, mut ctx_b) = test_st_executor(8);

let (output_a, output_b) = futures::executor::block_on(async {
futures::join!(
alice.call(&mut ctx_a, 1u8, |&mut (), a: u8, b: u8| (a + b, a + b)),
bob.call(&mut ctx_b, 2u8, |&mut (), a: u8, b: u8| (a + b, a + b)),
)
});

assert_eq!(output_a, 3);
assert_eq!(output_b, 3);
}

#[test]
#[should_panic]
fn test_ideal_wrong_input_type() {
let (mut alice, mut bob) = ideal_f2p(());
let (mut ctx_a, mut ctx_b) = test_st_executor(8);

futures::executor::block_on(async {
futures::join!(
alice.call(&mut ctx_a, 1u16, |&mut (), a: u16, b: u16| (a + b, a + b)),
bob.call(&mut ctx_b, 2u8, |&mut (), a: u8, b: u8| (a + b, a + b)),
)
});
}
}
2 changes: 2 additions & 0 deletions crates/mpz-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
mod context;
pub mod executor;
mod id;
#[cfg(any(test, feature = "ideal"))]
pub mod ideal;
#[cfg(feature = "sync")]
pub mod sync;

Expand Down
6 changes: 6 additions & 0 deletions crates/mpz-ot-core/src/chou_orlandi/error.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use crate::TransferId;

/// Errors that can occur when using the CO15 sender.
#[derive(Debug, thiserror::Error)]
#[allow(missing_docs)]
pub enum SenderError {
#[error("invalid state: expected {0}")]
InvalidState(String),
#[error("id mismatch: expected {0}, got {1}")]
IdMismatch(TransferId, TransferId),
#[error("count mismatch: sender expected {0} but receiver sent {1}")]
CountMismatch(usize, usize),
#[error(transparent)]
Expand All @@ -16,6 +20,8 @@ pub enum SenderError {
pub enum ReceiverError {
#[error("invalid state: expected {0}")]
InvalidState(String),
#[error("id mismatch: expected {0}, got {1}")]
IdMismatch(TransferId, TransferId),
#[error("count mismatch: receiver expected {0} but sender sent {1}")]
CountMismatch(usize, usize),
}
Expand Down
6 changes: 6 additions & 0 deletions crates/mpz-ot-core/src/chou_orlandi/msgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use curve25519_dalek::RistrettoPoint;
use mpz_core::Block;
use serde::{Deserialize, Serialize};

use crate::TransferId;

/// Sender setup message.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SenderSetup {
Expand All @@ -14,13 +16,17 @@ pub struct SenderSetup {
/// Sender payload message.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SenderPayload {
/// The transfer ID.
pub id: TransferId,
/// The sender's ciphertexts
pub payload: Vec<[Block; 2]>,
}

/// Receiver payload message.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReceiverPayload {
/// The transfer ID.
pub id: TransferId,
/// The receiver's blinded choices.
pub blinded_choices: Vec<RistrettoPoint>,
}
Expand Down
21 changes: 18 additions & 3 deletions crates/mpz-ot-core/src/chou_orlandi/receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::chou_orlandi::{
msgs::{ReceiverPayload, ReceiverReveal, SenderPayload, SenderSetup},
ReceiverConfig, ReceiverError,
};
use crate::TransferId;

use itybity::{BitIterable, FromBitIterator, ToBits};
use mpz_core::Block;
Expand Down Expand Up @@ -89,6 +90,7 @@ impl Receiver {
state: state::Setup {
rng,
sender_base_table: RistrettoBasepointTable::create(&sender_setup.public_key),
transfer_id: TransferId::default(),
counter: 0,
choice_log: Vec::default(),
decryption_keys: Vec::default(),
Expand Down Expand Up @@ -129,7 +131,10 @@ impl Receiver<state::Setup> {
choice_log.extend(choices.iter_lsb0());
}

ReceiverPayload { blinded_choices }
ReceiverPayload {
id: self.state.transfer_id,
blinded_choices,
}
}

/// Receives the encrypted payload from the Sender, returning the plaintext messages corresponding
Expand All @@ -140,10 +145,18 @@ impl Receiver<state::Setup> {
/// * `payload` - The encrypted payload from the Sender
pub fn receive(&mut self, payload: SenderPayload) -> Result<Vec<Block>, ReceiverError> {
let state::Setup {
decryption_keys, ..
transfer_id: current_id,
decryption_keys,
..
} = &mut self.state;

let SenderPayload { payload } = payload;
let SenderPayload { id, payload } = payload;

// Check that the transfer id matches
let expected_id = current_id.next();
if id != expected_id {
return Err(ReceiverError::IdMismatch(expected_id, id));
}

// Check that the number of ciphertexts does not exceed the number of pending keys
if payload.len() > decryption_keys.len() {
Expand Down Expand Up @@ -267,6 +280,8 @@ pub mod state {
pub(super) rng: ChaCha20Rng,
/// Sender's public key (precomputed table)
pub(super) sender_base_table: RistrettoBasepointTable,
/// Current transfer id.
pub(super) transfer_id: TransferId,
/// Counts how many decryption keys we've computed so far
pub(super) counter: usize,
/// Log of the receiver's choice bits
Expand Down
32 changes: 25 additions & 7 deletions crates/mpz-ot-core/src/chou_orlandi/sender.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::chou_orlandi::{
hash_point,
msgs::{ReceiverPayload, ReceiverReveal, SenderPayload, SenderSetup},
Receiver, ReceiverConfig, SenderConfig, SenderError, SenderVerifyError,
use crate::{
chou_orlandi::{
hash_point,
msgs::{ReceiverPayload, ReceiverReveal, SenderPayload, SenderSetup},
Receiver, ReceiverConfig, SenderConfig, SenderError, SenderVerifyError,
},
TransferId,
};

use itybity::IntoBitIterator;
Expand Down Expand Up @@ -101,6 +104,7 @@ impl Sender {
state: state::Setup {
private_key,
public_key,
transfer_id: TransferId::default(),
counter: 0,
},
tape: self.tape,
Expand All @@ -124,11 +128,21 @@ impl Sender<state::Setup> {
let state::Setup {
private_key,
public_key,
transfer_id: current_id,
counter,
..
} = &mut self.state;

let ReceiverPayload { blinded_choices } = receiver_payload;
let ReceiverPayload {
id,
blinded_choices,
} = receiver_payload;

// Check that the transfer id matches
let expected_id = current_id.next();
if id != expected_id {
return Err(SenderError::IdMismatch(expected_id, id));
}

// Check that the number of inputs matches the number of choices
if inputs.len() != blinded_choices.len() {
Expand All @@ -154,7 +168,7 @@ impl Sender<state::Setup> {
payload[1] = input[1] ^ payload[1];
}

Ok(SenderPayload { payload })
Ok(SenderPayload { id, payload })
}

/// Returns the Receiver choices after verifying them against the tape.
Expand Down Expand Up @@ -199,7 +213,9 @@ impl Sender<state::Setup> {

let mut receiver = receiver.setup(SenderSetup { public_key });

let ReceiverPayload { blinded_choices } = receiver.receive_random(&choices);
let ReceiverPayload {
blinded_choices, ..
} = receiver.receive_random(&choices);

// Check that the simulated receiver's choices match the ones recorded in the tape
if blinded_choices != tape.receiver_choices {
Expand Down Expand Up @@ -296,6 +312,8 @@ pub mod state {
pub(super) private_key: Scalar,
// The public_key is `A == g^a` in [ref1]
pub(super) public_key: RistrettoPoint,
/// Current transfer id.
pub(super) transfer_id: TransferId,
/// Number of OTs sent so far
pub(super) counter: usize,
}
Expand Down
Loading