Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Threading Refactor #165

Merged
merged 34 commits into from
Jun 25, 2024
Merged
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
a23a095
feat: mpz-common (#107)
sinui0 Mar 7, 2024
7883f37
refactor: cointoss (#108)
sinui0 Mar 7, 2024
272b83d
refactor: mpz-ot (#109)
sinui0 Mar 7, 2024
a4f80dc
refactor: re-organize crates (#110)
sinui0 Mar 7, 2024
61275e8
Adds an ideal ROT functionality to mpz-ot-core (#102)
th4s Mar 7, 2024
13b0cae
refactor(mpz-ot): Normalize OT and ideal functionalities (#122)
sinui0 May 8, 2024
9891192
feat(mpz-common): add try_/join convenience macros (#126)
sinui0 May 13, 2024
9300c94
fix(mpz-ot): Ideal RCOT (#131)
sinui0 May 13, 2024
67564a4
docs: fix typos (#130)
themighty1 May 15, 2024
b35d392
feat(mpz-common): dummy executor (#132)
sinui0 May 15, 2024
3d523ec
feat(mpz-common): simple counter (#133)
sinui0 May 15, 2024
a10810a
refactor(mpz-garble-core): batched garbling (#140)
sinui0 May 28, 2024
b9e5f59
Add crate `mpz-ole-core` (#135)
th4s May 29, 2024
7292063
feat(mpz-common): multi-threaded executor (#136)
sinui0 May 29, 2024
6617c54
Add IO wrapper for OLE (#138)
th4s May 31, 2024
699acff
feat(mpz-common): Context::blocking (#141)
sinui0 May 31, 2024
c50a145
feat(mpz-common): scoped! macro (#143)
sinui0 May 31, 2024
ab82dbf
test(mpz-common): test mt executor concurrency (#145)
sinui0 Jun 4, 2024
81802ea
Add `mpz-share-conversion-core` (#147)
th4s Jun 5, 2024
b858189
refactor(mpz-garble): fix threading breaking changes (#144)
sinui0 Jun 5, 2024
b0f5a90
refactor(mpz-share-conversion): new impl (#146)
sinui0 Jun 5, 2024
6f1ef18
feat(mpz-common): add type alias for test st executor (#154)
sinui0 Jun 7, 2024
5fc90ff
feat(mpz-common): async sync primitives (#152)
sinui0 Jun 11, 2024
801381c
feat(mpz-ot): impl more OT traits on shared KOS (#153)
sinui0 Jun 11, 2024
6c7dec2
feat(mpz-garble): pre-commit inputs (#149)
sinui0 Jun 12, 2024
9c17b38
refactor: KOS and preprocessing traits (#155)
sinui0 Jun 12, 2024
640bda0
refactor(mpz-ot): add accept_reveal for verifiable ot (#158)
sinui0 Jun 14, 2024
1ec8979
fix(mpz-common): flush io in syncer (#157)
sinui0 Jun 14, 2024
34a0663
fix(mpz-ot): fix shared KOS verifiable ot receiver (#161)
sinui0 Jun 20, 2024
3272954
fix(mpz-garble): add thread id to otp ids (#162)
sinui0 Jun 20, 2024
461d3b4
refactor(mpz-common): new thread future (#163)
sinui0 Jun 24, 2024
836ad67
chore: bump serio and uid-mux (#164)
sinui0 Jun 25, 2024
5ad1bea
chore: move workspace manifest
sinui0 Jun 25, 2024
e25d123
fix: clippy --fix
sinui0 Jun 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor(mpz-garble-core): batched garbling (#140)
* refactor(mpz-garble-core): batched garbling

* Apply suggestions from code review

Co-authored-by: th4s <[email protected]>
Co-authored-by: dan <[email protected]>

* qualify comment

* remove unused msg module

* comments

---------

Co-authored-by: th4s <[email protected]>
Co-authored-by: dan <[email protected]>
3 people committed Jun 25, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit a10810ac233de9fdeae79759268d2f0e4a8bb18d
1 change: 1 addition & 0 deletions crates/mpz-garble-core/Cargo.toml
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@ rand_core.workspace = true
rand_chacha.workspace = true
regex = { workspace = true, optional = true }
once_cell.workspace = true
opaque-debug.workspace = true

serde = { workspace = true, features = ["derive"] }
serde_arrays.workspace = true
78 changes: 61 additions & 17 deletions crates/mpz-garble-core/benches/garble.rs
Original file line number Diff line number Diff line change
@@ -1,39 +1,83 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use mpz_circuits::circuits::AES128;
use mpz_garble_core::{ChaChaEncoder, Encoder, Generator};
use mpz_garble_core::{ChaChaEncoder, Encoder, Evaluator, Generator};

fn criterion_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("garble_circuits");
let mut gb_group = c.benchmark_group("garble");

let encoder = ChaChaEncoder::new([0u8; 32]);
let inputs = AES128
let full_inputs = AES128
.inputs()
.iter()
.map(|value| encoder.encode_by_type(0, &value.value_type()))
.collect::<Vec<_>>();
group.bench_function("aes128", |b| {

let active_inputs = vec![
full_inputs[0].clone().select([0u8; 16]).unwrap(),
full_inputs[1].clone().select([0u8; 16]).unwrap(),
];

gb_group.bench_function("aes128", |b| {
let mut gen = Generator::default();
b.iter(|| {
let mut gen = Generator::new(AES128.clone(), encoder.delta(), &inputs).unwrap();
let mut gen_iter = gen
.generate(&AES128, encoder.delta(), full_inputs.clone())
.unwrap();

let mut enc_gates = Vec::with_capacity(AES128.and_count());
for gate in gen.by_ref() {
enc_gates.push(gate);
}
let _: Vec<_> = gen_iter.by_ref().collect();

black_box(gen_iter.finish().unwrap())
})
});

gb_group.bench_function("aes128_batched", |b| {
let mut gen = Generator::default();
b.iter(|| {
let mut gen_iter = gen
.generate_batched(&AES128, encoder.delta(), full_inputs.clone())
.unwrap();

let _: Vec<_> = gen_iter.by_ref().collect();

black_box(gen_iter.finish().unwrap())
})
});

gb_group.bench_function("aes128_with_hash", |b| {
let mut gen = Generator::default();
b.iter(|| {
let mut gen_iter = gen
.generate(&AES128, encoder.delta(), full_inputs.clone())
.unwrap();

gen_iter.enable_hasher();

black_box(gen.outputs().unwrap())
let _: Vec<_> = gen_iter.by_ref().collect();

black_box(gen_iter.finish().unwrap())
})
});
group.bench_function("aes128_with_hash", |b| {

drop(gb_group);

let mut ev_group = c.benchmark_group("evaluate");

ev_group.bench_function("aes128", |b| {
let mut gen = Generator::default();
let mut gen_iter = gen
.generate(&AES128, encoder.delta(), full_inputs.clone())
.unwrap();
let gates: Vec<_> = gen_iter.by_ref().collect();

let mut ev = Evaluator::default();
b.iter(|| {
let mut gen =
Generator::new_with_hasher(AES128.clone(), encoder.delta(), &inputs).unwrap();
let mut ev_consumer = ev.evaluate(&AES128, active_inputs.clone()).unwrap();

let mut enc_gates = Vec::with_capacity(AES128.and_count());
for gate in gen.by_ref() {
enc_gates.push(gate);
for gate in &gates {
ev_consumer.next(*gate);
}

black_box(gen.outputs().unwrap())
black_box(ev_consumer.finish().unwrap());
})
});
}
26 changes: 24 additions & 2 deletions crates/mpz-garble-core/src/circuit.rs
Original file line number Diff line number Diff line change
@@ -3,15 +3,15 @@ use std::ops::Index;
use mpz_core::Block;
use serde::{Deserialize, Serialize};

use crate::EncodingCommitment;
use crate::{EncodingCommitment, DEFAULT_BATCH_SIZE};

/// Encrypted gate truth table
///
/// For the half-gate garbling scheme a truth table will typically have 2 rows, except for in
/// privacy-free garbling mode where it will be reduced to 1.
///
/// We do not yet support privacy-free garbling.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct EncryptedGate(#[serde(with = "serde_arrays")] pub(crate) [Block; 2]);

impl EncryptedGate {
@@ -35,6 +35,28 @@ impl Index<usize> for EncryptedGate {
}
}

/// A batch of encrypted gates.
///
/// # Parameters
///
/// - `N`: The size of a batch.
#[derive(Debug, Serialize, Deserialize)]
pub struct EncryptedGateBatch<const N: usize = DEFAULT_BATCH_SIZE>(
#[serde(with = "serde_arrays")] [EncryptedGate; N],
);

impl<const N: usize> EncryptedGateBatch<N> {
/// Creates a new batch of encrypted gates.
pub fn new(batch: [EncryptedGate; N]) -> Self {
Self(batch)
}

/// Returns the inner array.
pub fn into_array(self) -> [EncryptedGate; N] {
self.0
}
}

/// A garbled circuit
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GarbledCircuit {
11 changes: 10 additions & 1 deletion crates/mpz-garble-core/src/encoding/mod.rs
Original file line number Diff line number Diff line change
@@ -272,7 +272,7 @@ impl<const N: usize, S: LabelState> Index<usize> for Labels<N, S> {
}

/// Encoded bit label.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Label(Block);

impl Label {
@@ -350,6 +350,15 @@ impl BitXor<Delta> for &Label {
}
}

impl BitXor<&Delta> for Label {
type Output = Label;

#[inline]
fn bitxor(self, rhs: &Delta) -> Self::Output {
Label(self.0 ^ rhs.0)
}
}

impl AsRef<Block> for Label {
fn as_ref(&self) -> &Block {
&self.0
305 changes: 196 additions & 109 deletions crates/mpz-garble-core/src/evaluator.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use std::sync::Arc;
use core::fmt;

use blake3::Hasher;

use crate::{
circuit::EncryptedGate,
encoding::{state, EncodedValue, Label},
EncryptedGateBatch, DEFAULT_BATCH_SIZE,
};
use mpz_circuits::{
types::{BinaryRepr, TypeError},
Circuit, CircuitError, Gate,
};
use mpz_circuits::{types::TypeError, Circuit, CircuitError, Gate};
use mpz_core::{
aes::{FixedKeyAes, FIXED_KEY_AES},
hash::Hash,
@@ -54,66 +58,55 @@ pub(crate) fn and_gate(
Label::new(w_g ^ w_e)
}

/// Core evaluator type for evaluating a garbled circuit.
/// Output of the evaluator.
#[derive(Debug)]
pub struct EvaluatorOutput {
/// Encoded outputs of the circuit.
pub outputs: Vec<EncodedValue<state::Active>>,
/// Hash of the encrypted gates.
pub hash: Option<Hash>,
}

/// Garbled circuit evaluator.
#[derive(Debug)]
pub struct Evaluator {
/// Cipher to use to encrypt the gates
cipher: &'static FixedKeyAes,
/// Circuit to evaluate
circ: Arc<Circuit>,
/// Active label state
active_labels: Vec<Option<Label>>,
/// Current position in the circuit
pos: usize,
/// Current gate id
gid: usize,
/// Whether the evaluator is finished
complete: bool,
/// Hasher to use to hash the encrypted gates
hasher: Option<Hasher>,
/// Buffer for the active labels.
buffer: Vec<Label>,
}

impl Evaluator {
/// Creates a new evaluator for the given circuit.
///
/// # Arguments
///
/// * `circ` - The circuit to evaluate.
/// * `inputs` - The inputs to the circuit.
pub fn new(
circ: Arc<Circuit>,
inputs: &[EncodedValue<state::Active>],
) -> Result<Self, EvaluatorError> {
Self::new_with(circ, inputs, None)
impl Default for Evaluator {
fn default() -> Self {
Self {
buffer: Default::default(),
}
}
}

/// Creates a new evaluator for the given circuit. Evaluator will compute
/// a hash of the encrypted gates while they are evaluated.
impl Evaluator {
/// Returns a consumer over the encrypted gates of a circuit.
///
/// # Arguments
///
/// * `circ` - The circuit to evaluate.
/// * `inputs` - The inputs to the circuit.
pub fn new_with_hasher(
circ: Arc<Circuit>,
inputs: &[EncodedValue<state::Active>],
) -> Result<Self, EvaluatorError> {
Self::new_with(circ, inputs, Some(Hasher::new()))
}

fn new_with(
circ: Arc<Circuit>,
inputs: &[EncodedValue<state::Active>],
hasher: Option<Hasher>,
) -> Result<Self, EvaluatorError> {
/// * `inputs` - The input values to the circuit.
pub fn evaluate<'a>(
&'a mut self,
circ: &'a Circuit,
inputs: Vec<EncodedValue<state::Active>>,
) -> Result<EncryptedGateConsumer<'_, std::slice::Iter<'_, Gate>>, EvaluatorError> {
if inputs.len() != circ.inputs().len() {
return Err(CircuitError::InvalidInputCount(
circ.inputs().len(),
inputs.len(),
))?;
}

let mut active_labels: Vec<Option<Label>> = vec![None; circ.feed_count()];
for (encoded, input) in inputs.iter().zip(circ.inputs()) {
// Expand the buffer to fit the circuit
if circ.feed_count() > self.buffer.len() {
self.buffer.resize(circ.feed_count(), Default::default());
}

for (encoded, input) in inputs.into_iter().zip(circ.inputs()) {
if encoded.value_type() != input.value_type() {
return Err(TypeError::UnexpectedType {
expected: input.value_type(),
@@ -122,111 +115,205 @@ impl Evaluator {
}

for (label, node) in encoded.iter().zip(input.iter()) {
active_labels[node.id()] = Some(*label);
self.buffer[node.id()] = *label;
}
}

let mut ev = Self {
Ok(EncryptedGateConsumer::new(
circ.gates().iter(),
circ.outputs(),
&mut self.buffer,
circ.and_count(),
))
}

/// Returns a consumer over batched encrypted gates of a circuit.
///
/// # Arguments
///
/// * `circ` - The circuit to evaluate.
/// * `inputs` - The input values to the circuit.
pub fn evaluate_batched<'a>(
&'a mut self,
circ: &'a Circuit,
inputs: Vec<EncodedValue<state::Active>>,
) -> Result<EncryptedGateBatchConsumer<'_, std::slice::Iter<'_, Gate>>, EvaluatorError> {
self.evaluate(circ, inputs).map(EncryptedGateBatchConsumer)
}
}

/// Consumer over the encrypted gates of a circuit.
pub struct EncryptedGateConsumer<'a, I: Iterator> {
/// Cipher to use to encrypt the gates.
cipher: &'static FixedKeyAes,
/// Buffer for the active labels.
labels: &'a mut [Label],
/// Iterator over the gates.
gates: I,
/// Circuit outputs.
outputs: &'a [BinaryRepr],
/// Current gate id.
gid: usize,
/// Hasher to use to hash the encrypted gates.
hasher: Option<Hasher>,
/// Number of AND gates evaluated.
counter: usize,
/// Total number of AND gates in the circuit.
and_count: usize,
/// Whether the entire circuit has been garbled.
complete: bool,
}

impl<'a, I: Iterator> fmt::Debug for EncryptedGateConsumer<'a, I> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "EncryptedGateConsumer {{ .. }}")
}
}

impl<'a, I> EncryptedGateConsumer<'a, I>
where
I: Iterator<Item = &'a Gate>,
{
fn new(gates: I, outputs: &'a [BinaryRepr], labels: &'a mut [Label], and_count: usize) -> Self {
Self {
cipher: &(*FIXED_KEY_AES),
circ,
active_labels,
pos: 0,
gates,
outputs,
labels,
gid: 1,
hasher: None,
counter: 0,
and_count,
complete: false,
hasher,
};

// If circuit has no AND gates we can evaluate it immediately for cheap
if ev.circ.and_count() == 0 {
ev.evaluate(std::iter::empty());
}
}

Ok(ev)
/// Enables hashing of the encrypted gates.
pub fn enable_hasher(&mut self) {
self.hasher = Some(Hasher::new());
}

/// Evaluates the next batch of encrypted gates.
/// Returns `true` if the evaluator wants more encrypted gates.
#[inline]
pub fn evaluate<'a>(&mut self, mut encrypted_gates: impl Iterator<Item = &'a EncryptedGate>) {
let labels = &mut self.active_labels;
pub fn wants_gates(&self) -> bool {
self.counter != self.and_count
}

// Process gates until we run out of encrypted gates
while self.pos < self.circ.gates().len() {
match &self.circ.gates()[self.pos] {
Gate::Inv {
x: node_x,
z: node_z,
} => {
let x = labels[node_x.id()].expect("feed should be initialized");
labels[node_z.id()] = Some(x);
}
/// Evaluates the next encrypted gate in the circuit.
#[inline]
pub fn next(&mut self, encrypted_gate: EncryptedGate) {
while let Some(gate) = self.gates.next() {
match gate {
Gate::Xor {
x: node_x,
y: node_y,
z: node_z,
} => {
let x = labels[node_x.id()].expect("feed should be initialized");
let y = labels[node_y.id()].expect("feed should be initialized");
labels[node_z.id()] = Some(x ^ y);
let x = self.labels[node_x.id()];
let y = self.labels[node_y.id()];
self.labels[node_z.id()] = x ^ y;
}
Gate::And {
x: node_x,
y: node_y,
z: node_z,
} => {
if let Some(encrypted_gate) = encrypted_gates.next() {
if let Some(hasher) = &mut self.hasher {
hasher.update(&encrypted_gate.to_bytes());
}

let x = labels[node_x.id()].expect("feed should be initialized");
let y = labels[node_y.id()].expect("feed should be initialized");
let z = and_gate(self.cipher, &x, &y, encrypted_gate, self.gid);
labels[node_z.id()] = Some(z);
self.gid += 2;
} else {
// We ran out of encrypted gates, so we return until we get more
let x = self.labels[node_x.id()];
let y = self.labels[node_y.id()];
let z = and_gate(self.cipher, &x, &y, &encrypted_gate, self.gid);
self.labels[node_z.id()] = z;

self.gid += 2;
self.counter += 1;

if let Some(hasher) = &mut self.hasher {
hasher.update(&encrypted_gate.to_bytes());
}

// If we have more AND gates to evaluate, return.
if self.wants_gates() {
return;
}
}
Gate::Inv {
x: node_x,
z: node_z,
} => {
let x = self.labels[node_x.id()];
self.labels[node_z.id()] = x;
}
}
self.pos += 1;
}

self.complete = true;
}

/// Returns whether the evaluator has finished evaluating the circuit.
pub fn is_complete(&self) -> bool {
self.complete
}

/// Returns the active encoded outputs of the circuit.
pub fn outputs(&self) -> Result<Vec<EncodedValue<state::Active>>, EvaluatorError> {
if !self.is_complete() {
/// Returns the encoded outputs of the circuit.
pub fn finish(mut self) -> Result<EvaluatorOutput, EvaluatorError> {
if self.wants_gates() {
return Err(EvaluatorError::NotFinished);
}

Ok(self
.circ
.outputs()
// If there were 0 AND gates in the circuit, we need to evaluate the "free" gates now.
if !self.complete {
self.next(Default::default());
}

let outputs = self
.outputs
.iter()
.map(|output| {
let labels: Vec<Label> = output
.iter()
.map(|node| self.active_labels[node.id()].expect("feed should be initialized"))
.collect();
let labels: Vec<Label> = output.iter().map(|node| self.labels[node.id()]).collect();

EncodedValue::<state::Active>::from_labels(output.value_type(), &labels)
.expect("encoding should be correct")
})
.collect())
}
.collect();

/// Returns the hash of the encrypted gates.
pub fn hash(&self) -> Option<Hash> {
self.hasher.as_ref().map(|hasher| {
let hash: [u8; 32] = hasher.finalize().into();
Hash::from(hash)
Ok(EvaluatorOutput {
outputs,
hash: self.hasher.as_ref().map(|hasher| {
let hash: [u8; 32] = hasher.finalize().into();
Hash::from(hash)
}),
})
}
}

/// Consumer returned by [`Evaluator::evaluate_batched`].
#[derive(Debug)]
pub struct EncryptedGateBatchConsumer<'a, I: Iterator, const N: usize = DEFAULT_BATCH_SIZE>(
EncryptedGateConsumer<'a, I>,
);

impl<'a, I, const N: usize> EncryptedGateBatchConsumer<'a, I, N>
where
I: Iterator<Item = &'a Gate>,
{
/// Enables hashing of the encrypted gates.
pub fn enable_hasher(&mut self) {
self.0.enable_hasher()
}

/// Returns `true` if the evaluator wants more encrypted gates.
pub fn wants_gates(&self) -> bool {
self.0.wants_gates()
}

/// Evaluates the next batch of gates in the circuit.
#[inline]
pub fn next(&mut self, batch: EncryptedGateBatch<N>) {
for encrypted_gate in batch.into_array() {
self.0.next(encrypted_gate);
if !self.0.wants_gates() {
// Skipping any remaining gates which may have been used to pad the last batch.
return;
}
}
}

/// Returns the encoded outputs of the circuit, and the hash of the encrypted gates if present.
pub fn finish(self) -> Result<EvaluatorOutput, EvaluatorError> {
self.0.finish()
}
}
368 changes: 263 additions & 105 deletions crates/mpz-garble-core/src/generator.rs

Large diffs are not rendered by default.

166 changes: 123 additions & 43 deletions crates/mpz-garble-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -6,7 +6,9 @@
//!
//! ```
//! use mpz_circuits::circuits::AES128;
//! use mpz_garble_core::{Generator, Evaluator, ChaChaEncoder, Encoder};
//! use mpz_garble_core::{
//! Generator, Evaluator, ChaChaEncoder, Encoder, GeneratorOutput, EvaluatorOutput
//! };
//!
//!
//! let encoder = ChaChaEncoder::new([0u8; 32]);
@@ -19,30 +21,22 @@
//! let active_key = encoded_key.select(*key).unwrap();
//! let active_plaintext = encoded_plaintext.select(*plaintext).unwrap();
//!
//! let mut gen =
//! Generator::new(
//! AES128.clone(),
//! encoder.delta(),
//! &[encoded_key, encoded_plaintext]
//! ).unwrap();
//! let mut gen = Generator::default();
//! let mut ev = Evaluator::default();
//!
//! let mut ev =
//! Evaluator::new(
//! AES128.clone(),
//! &[active_key, active_plaintext]
//! ).unwrap();
//! let mut gen_iter = gen
//! .generate_batched(&AES128, encoder.delta(), vec![encoded_key, encoded_plaintext]).unwrap();
//! let mut ev_consumer = ev.evaluate_batched(&AES128, vec![active_key, active_plaintext]).unwrap();
//!
//! const BATCH_SIZE: usize = 1000;
//! while !(gen.is_complete() && ev.is_complete()) {
//! let batch: Vec<_> = gen.by_ref().take(BATCH_SIZE).collect();
//! ev.evaluate(batch.iter());
//! for batch in gen_iter.by_ref() {
//! ev_consumer.next(batch);
//! }
//!
//! let encoded_outputs = gen.outputs().unwrap();
//! let GeneratorOutput { outputs: encoded_outputs, .. } = gen_iter.finish().unwrap();
//! let encoded_ciphertext = encoded_outputs[0].clone();
//! let ciphertext_decoding = encoded_ciphertext.decoding();
//!
//! let active_outputs = ev.outputs().unwrap();
//! let EvaluatorOutput { outputs: active_outputs, .. } = ev_consumer.finish().unwrap();
//! let active_ciphertext = active_outputs[0].clone();
//! let ciphertext: [u8; 16] =
//! active_ciphertext.decode(&ciphertext_decoding).unwrap().try_into().unwrap();
@@ -57,23 +51,41 @@ pub(crate) mod circuit;
pub mod encoding;
mod evaluator;
mod generator;
pub mod msg;

pub use circuit::{EncryptedGate, GarbledCircuit};
pub use circuit::{EncryptedGate, EncryptedGateBatch, GarbledCircuit};
pub use encoding::{
state as encoding_state, ChaChaEncoder, Decoding, Delta, Encode, EncodedValue, Encoder,
EncodingCommitment, EqualityCheck, Label, ValueError,
};
pub use evaluator::{Evaluator, EvaluatorError};
pub use generator::{Generator, GeneratorError};
pub use evaluator::{
EncryptedGateBatchConsumer, EncryptedGateConsumer, Evaluator, EvaluatorError, EvaluatorOutput,
};
pub use generator::{
EncryptedGateBatchIter, EncryptedGateIter, Generator, GeneratorError, GeneratorOutput,
};

const KB: usize = 1024;
const BYTES_PER_GATE: usize = 32;

/// Maximum size of a batch in bytes.
const MAX_BATCH_SIZE: usize = 4 * KB;

/// Default amount of encrypted gates per batch.
///
/// Batches are stack allocated, so we will limit the size to `MAX_BATCH_SIZE`.
///
/// Additionally, because the size of each batch is static, if a circuit is smaller than a batch
/// we will be wasting some bandwidth sending empty bytes. This puts an upper limit on that
/// waste.
pub(crate) const DEFAULT_BATCH_SIZE: usize = MAX_BATCH_SIZE / BYTES_PER_GATE;

#[cfg(test)]
mod tests {
use aes::{
cipher::{BlockEncrypt, KeyInit},
Aes128,
};
use mpz_circuits::{circuits::AES128, types::Value};
use mpz_circuits::{circuits::AES128, types::Value, CircuitBuilder};
use mpz_core::aes::FIXED_KEY_AES;
use rand::SeedableRng;
use rand_chacha::ChaCha12Rng;
@@ -109,7 +121,6 @@ mod tests {

let key = [69u8; 16];
let msg = [42u8; 16];
const BATCH_SIZE: usize = 1000;

let expected: [u8; 16] = {
let cipher = Aes128::new_from_slice(&key).unwrap();
@@ -129,39 +140,108 @@ mod tests {
full_inputs[1].clone().select(msg).unwrap(),
];

let mut gen =
Generator::new_with_hasher(AES128.clone(), encoder.delta(), &full_inputs).unwrap();
let mut ev = Evaluator::new_with_hasher(AES128.clone(), &active_inputs).unwrap();

while !(gen.is_complete() && ev.is_complete()) {
let mut batch = Vec::with_capacity(BATCH_SIZE);
for enc_gate in gen.by_ref() {
batch.push(enc_gate);
if batch.len() == BATCH_SIZE {
break;
}
}
ev.evaluate(batch.iter());
}
let mut gen = Generator::default();
let mut ev = Evaluator::default();

let mut gen_iter = gen
.generate_batched(&AES128, encoder.delta(), full_inputs)
.unwrap();
let mut ev_consumer = ev.evaluate_batched(&AES128, active_inputs).unwrap();

let full_outputs = gen.outputs().unwrap();
let active_outputs = ev.outputs().unwrap();
gen_iter.enable_hasher();
ev_consumer.enable_hasher();

let gen_digest = gen.hash().unwrap();
let ev_digest = ev.hash().unwrap();
for batch in gen_iter.by_ref() {
ev_consumer.next(batch);
}

assert_eq!(gen_digest, ev_digest);
let GeneratorOutput {
outputs: full_outputs,
hash: gen_hash,
} = gen_iter.finish().unwrap();
let EvaluatorOutput {
outputs: active_outputs,
hash: ev_hash,
} = ev_consumer.finish().unwrap();

let outputs: Vec<Value> = active_outputs
.iter()
.zip(full_outputs)
.map(|(active_output, full_output)| {
full_output.commit().verify(&active_output).unwrap();
active_output.decode(&full_output.decoding()).unwrap()
})
.collect();

let actual: [u8; 16] = outputs[0].clone().try_into().unwrap();

assert_eq!(actual, expected);
assert_eq!(gen_hash, ev_hash);
}

// Tests garbling a circuit with no AND gates
#[test]
fn test_garble_no_and() {
let encoder = ChaChaEncoder::new([0; 32]);

let builder = CircuitBuilder::new();
let a = builder.add_input::<u8>();
let b = builder.add_input::<u8>();
let c = a ^ b;
builder.add_output(c);
let circ = builder.build().unwrap();
assert_eq!(circ.and_count(), 0);

let mut gen = Generator::default();
let mut ev = Evaluator::default();

let a = 1u8;
let b = 2u8;

let full_inputs: Vec<EncodedValue<encoding_state::Full>> = circ
.inputs()
.iter()
.map(|input| encoder.encode_by_type(0, &input.value_type()))
.collect();

let active_inputs: Vec<EncodedValue<encoding_state::Active>> = vec![
full_inputs[0].clone().select(a).unwrap(),
full_inputs[1].clone().select(b).unwrap(),
];

let mut gen_iter = gen
.generate_batched(&circ, encoder.delta(), full_inputs)
.unwrap();
let mut ev_consumer = ev.evaluate_batched(&circ, active_inputs).unwrap();

gen_iter.enable_hasher();
ev_consumer.enable_hasher();

for batch in gen_iter.by_ref() {
ev_consumer.next(batch);
}

let GeneratorOutput {
outputs: full_outputs,
hash: gen_hash,
} = gen_iter.finish().unwrap();
let EvaluatorOutput {
outputs: active_outputs,
hash: ev_hash,
} = ev_consumer.finish().unwrap();

let outputs: Vec<Value> = active_outputs
.iter()
.zip(full_outputs)
.map(|(active_output, full_output)| {
full_output.commit().verify(&active_output).unwrap();
active_output.decode(&full_output.decoding()).unwrap()
})
.collect();

let actual: u8 = outputs[0].clone().try_into().unwrap();

assert_eq!(actual, a ^ b);
assert_eq!(gen_hash, ev_hash);
}
}
29 changes: 0 additions & 29 deletions crates/mpz-garble-core/src/msg.rs

This file was deleted.