Skip to content

Commit

Permalink
Merge pull request #44 from Yoii-Inc/feat/mpc-boolean-field
Browse files Browse the repository at this point in the history
Feat/mpc boolean field
  • Loading branch information
taskooh authored May 22, 2024
2 parents 06fedff + 66badf6 commit 0cc1c9f
Show file tree
Hide file tree
Showing 6 changed files with 422 additions and 232 deletions.
88 changes: 59 additions & 29 deletions mpc-algebra/examples/algebra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use ark_ff::{One, Zero};
use ark_poly::reveal;
use ark_std::{end_timer, start_timer};
use log::debug;
use mpc_algebra::boolean_field::MpcBooleanField;
use mpc_algebra::{
share, AdditiveFieldShare, BitAdd, BitDecomposition, BitwiseLessThan, EqualityZero, LessThan, LogicalOperations, MpcField, Reveal, UniformBitRand
};
Expand All @@ -27,6 +28,8 @@ struct Opt {
type F = ark_bls12_377::Fr;
type S = AdditiveFieldShare<F>;
type MF = MpcField<F, S>;
type MBF = MpcBooleanField<F, S>;


fn test_add() {
// init communication protocol
Expand Down Expand Up @@ -64,7 +67,7 @@ fn test_div() {
}

fn test_sum() {
let a = vec![
let a = [
MF::from_public(F::from(1u64)),
MF::from_public(F::from(2u64)),
MF::from_public(F::from(3u64)),
Expand All @@ -80,7 +83,7 @@ fn test_bit_rand() {
let mut counter = [0, 0, 0];

for _ in 0..1000 {
let a = MF::bit_rand(&mut rng).reveal();
let a = MBF::bit_rand(&mut rng).reveal();

if a.is_zero() {
counter[0] += 1;
Expand All @@ -99,7 +102,7 @@ fn test_rand_number_bitwise() {
let mut rng = thread_rng();

for _ in 0..10 {
let (a, b) = MF::rand_number_bitwise(&mut rng);
let (a, b) = MBF::rand_number_bitwise(&mut rng);

let revealed_a = a.iter().map(|x| x.reveal()).collect::<Vec<_>>();
let revealed_b = b.reveal();
Expand Down Expand Up @@ -128,10 +131,10 @@ fn test_bitwise_lt() {

for _ in 0..10 {
let a = (0..modulus_size)
.map(|_| MF::bit_rand(rng))
.map(|_| MBF::bit_rand(rng))
.collect::<Vec<_>>();
let b = (0..modulus_size)
.map(|_| MF::bit_rand(rng))
.map(|_| MBF::bit_rand(rng))
.collect::<Vec<_>>();

let a_bigint =
Expand Down Expand Up @@ -186,9 +189,9 @@ fn test_less_than() {
fn test_and() {
let mut rng = ark_std::test_rng();

let a00 = vec![MF::zero(), MF::zero()];
let a10 = vec![MF::one(), MF::zero()];
let a11 = vec![MF::one(), MF::one()];
let a00 = vec![MBF::pub_false(),MBF::pub_true()];
let a10 = vec![MBF::pub_true(), MBF::pub_false()];
let a11 = vec![MBF::pub_true(), MBF::pub_true()];

assert_eq!(a00.kary_and().reveal(), F::zero());
assert_eq!(a10.kary_and().reveal(), F::zero());
Expand All @@ -197,7 +200,7 @@ fn test_and() {
let mut counter = [0, 0];

for _ in 0..100 {
let a = (0..3).map(|_| MF::bit_rand(&mut rng)).collect::<Vec<_>>();
let a = (0..3).map(|_| MBF::bit_rand(&mut rng)).collect::<Vec<_>>();

let res = a.kary_and();

Expand All @@ -214,9 +217,9 @@ fn test_and() {
fn test_or() {
let mut rng = thread_rng();

let a00 = vec![MF::zero(), MF::zero()];
let a10 = vec![MF::one(), MF::zero()];
let a11 = vec![MF::one(), MF::one()];
let a00 = vec![MBF::pub_false(), MBF::pub_false()];
let a10 = vec![MBF::pub_true(), MBF::pub_false()];
let a11 = vec![MBF::pub_true(), MBF::pub_true()];

assert_eq!(a00.kary_or().reveal(), F::zero());
assert_eq!(a10.kary_or().reveal(), F::one());
Expand All @@ -225,7 +228,7 @@ fn test_or() {
let mut counter = [0, 0];

for _ in 0..100 {
let a = (0..3).map(|_| MF::bit_rand(&mut rng)).collect::<Vec<_>>();
let a = (0..3).map(|_| MBF::bit_rand(&mut rng)).collect::<Vec<_>>();

let res = a.kary_or();

Expand All @@ -239,6 +242,28 @@ fn test_or() {
println!("OR counter is {:?}", counter);
}

fn test_xor() {
let mut rng = ark_std::test_rng();
let mut counter = [0, 0];

for _ in 0..100 {
let a = MBF::bit_rand(&mut rng);
let b = MBF::bit_rand(&mut rng);

let res = a ^ b;

println!("unbounded and is {:?}", res.reveal());
assert_eq!(res.reveal().is_one(),a.reveal().is_one() ^ b.reveal().is_one());
if res.reveal().is_zero() {
counter[0] += 1;
} else if res.reveal().is_one() {
counter[1] += 1;
}
}
println!("AND counter is {:?}", counter);
}


fn test_equality_zero() {
let mut rng = ark_std::test_rng();

Expand Down Expand Up @@ -268,27 +293,30 @@ fn test_equality_zero() {

fn test_carries() {
// a = 0101 = 5, b = 1100= 12
let mut a = vec![MF::from_add_shared(F::from(0u64)); 4];
let mut b = vec![MF::from_add_shared(F::from(0u64)); 4];
a[0] += MF::from_public(F::from(1u64));
a[2] += MF::one();
b[2] += MF::one();
b[3] += MF::one();
let mut a = vec![MBF::from_add_shared(F::zero()); 4];
let mut b = vec![MBF::from_add_shared(F::zero()); 4];
// TODO: improve how to initialize
a[0] = a[0] | MBF::pub_true();
a[2] = a[2] | MBF::pub_true();
b[2] = b[2] | MBF::pub_true();
b[3] = b[3] | MBF::pub_true();

// TODO: better way to initialize

let c = a.carries(&b);

// expected carries: 1100
assert_eq!(c.reveal(), vec![F::zero(), F::zero(), F::one(), F::one()]);

// a = 010011 = 19, b = 101010= 42
let mut a = vec![MF::from_add_shared(F::from(0u64)); 6];
let mut b = vec![MF::from_add_shared(F::from(0u64)); 6];
a[0] += MF::one();
a[1] += MF::one();
a[4] += MF::one();
b[1] += MF::one();
b[3] += MF::one();
b[5] += MF::one();
let mut a = vec![MBF::from_add_shared(F::from(0u64)); 6];
let mut b = vec![MBF::from_add_shared(F::from(0u64)); 6];
a[0] = a[0] | MBF::pub_true();
a[1] = a[1] | MBF::pub_true();
a[4] = a[4] | MBF::pub_true();
b[1] = b[1] | MBF::pub_true();
b[3] = b[3] | MBF::pub_true();
b[5] = b[5] | MBF::pub_true();

let c = a.carries(&b);

Expand All @@ -309,8 +337,8 @@ fn test_carries() {
fn test_bit_add() {
let rng = &mut thread_rng();

let (rand_a, a) = MF::rand_number_bitwise(rng);
let (rand_b, b) = MF::rand_number_bitwise(rng);
let (rand_a, a) = MBF::rand_number_bitwise(rng);
let (rand_b, b) = MBF::rand_number_bitwise(rng);

let c_vec = rand_a.bit_add(&rand_b);

Expand Down Expand Up @@ -372,6 +400,8 @@ fn main() {
println!("Test and passed");
test_or();
println!("Test or passed");
test_xor();
println!("Test xor passed");
test_equality_zero();
println!("Test equality_zero passed");

Expand Down
13 changes: 8 additions & 5 deletions mpc-algebra/src/mpc_primitives.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use rand::Rng;

pub trait UniformBitRand: Sized {
type BaseField;

fn bit_rand<R: Rng + ?Sized>(rng: &mut R) -> Self;
// little-endian
fn rand_number_bitwise<R: Rng + ?Sized>(rng: &mut R) -> (Vec<Self>, Self);
fn rand_number_bitwise<R: Rng + ?Sized>(rng: &mut R) -> (Vec<Self>, Self::BaseField);
}

pub trait BitwiseLessThan {
Expand All @@ -12,7 +14,7 @@ pub trait BitwiseLessThan {
fn is_smaller_than_le(&self, other: &Self) -> Self::Output;
}

pub trait LessThan : UniformBitRand {
pub trait LessThan {
type Output;

fn is_smaller_or_equal_than_mod_minus_one_div_two(&self) -> Self::Output;
Expand All @@ -28,13 +30,14 @@ pub trait LogicalOperations {
}

pub trait EqualityZero {
fn is_zero_shared(&self) -> Self;
type Output;
fn is_zero_shared(&self) -> Self::Output;
}

pub trait BitDecomposition {
type Output;
type BooleanField;

fn bit_decomposition(&self) -> Self::Output;
fn bit_decomposition(&self) -> Vec<Self::BooleanField>;
}

pub trait BitAdd {
Expand Down
27 changes: 11 additions & 16 deletions mpc-algebra/src/r1cs_helper/mpc_fp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,9 @@ use std::borrow::Borrow;
use ark_ff::{Field, PrimeField, SquareRootField};
use ark_r1cs_std::{
alloc::{AllocVar, AllocationMode},
fields::{
fp::{AllocatedFp, FpVar},
FieldOpsBounds, FieldVar,
},
fields::FieldOpsBounds,
impl_ops,
select::CondSelectGadget,
R1CSVar, ToConstraintFieldGadget,
R1CSVar,
};
use ark_r1cs_std::{impl_bounded_ops, Assignment};
use ark_relations::{
Expand Down Expand Up @@ -159,7 +155,7 @@ impl<F: PrimeField, S: FieldShare<F>> MpcAllocatedFp<F, S> {
#[tracing::instrument(target = "r1cs")]
pub fn add(&self, other: &Self) -> Self {
let value = match (self.value, other.value) {
(Some(val1), Some(val2)) => Some(val1 + &val2),
(Some(val1), Some(val2)) => Some(val1 + val2),
(..) => None,
};

Expand All @@ -176,7 +172,7 @@ impl<F: PrimeField, S: FieldShare<F>> MpcAllocatedFp<F, S> {
#[tracing::instrument(target = "r1cs")]
pub fn sub(&self, other: &Self) -> Self {
let value = match (self.value, other.value) {
(Some(val1), Some(val2)) => Some(val1 - &val2),
(Some(val1), Some(val2)) => Some(val1 - val2),
(..) => None,
};

Expand All @@ -193,7 +189,7 @@ impl<F: PrimeField, S: FieldShare<F>> MpcAllocatedFp<F, S> {
#[tracing::instrument(target = "r1cs")]
pub fn mul(&self, other: &Self) -> Self {
let product = MpcAllocatedFp::new_witness(self.cs.clone(), || {
Ok(self.value.get()? * &other.value.get()?)
Ok(self.value.get()? * other.value.get()?)
})
.unwrap();
self.cs
Expand Down Expand Up @@ -463,11 +459,11 @@ impl<F: PrimeField + SquareRootField, S: FieldShare<F>> MpcAllocatedFp<F, S> {
let is_zero_value = self.value.get()?.is_zero_shared();

let is_not_zero =
MpcBoolean::new_witness(self.cs.clone(), || Ok(MpcField::one() - is_zero_value))?;
MpcBoolean::new_witness(self.cs.clone(), || Ok((!is_zero_value).field()))?;

let multiplier = self
.cs
.new_witness_variable(|| (self.value.get()? + is_zero_value).inverse().get())?;
.new_witness_variable(|| (self.value.get()? + is_zero_value.field()).inverse().get())?;

self.cs
.enforce_constraint(lc!() + self.variable, lc!() + multiplier, is_not_zero.lc())?;
Expand Down Expand Up @@ -554,8 +550,7 @@ impl<F: PrimeField + SquareRootField, S: FieldShare<F>> MpcToBitsGadget<F, S>
#[tracing::instrument(target = "r1cs")]
fn to_non_unique_bits_le(&self) -> Result<Vec<MpcBoolean<F, S>>, SynthesisError> {
let cs = self.cs.clone();
use ark_ff::BitIteratorBE;
let mut bits = if let Some(value) = self.value {
let bits = if let Some(value) = self.value {
// let field_char = BitIteratorBE::new(F::characteristic());
// let bits: Vec<_> = BitIteratorBE::new(value.into_repr())
// .zip(field_char)
Expand All @@ -574,7 +569,7 @@ impl<F: PrimeField + SquareRootField, S: FieldShare<F>> MpcToBitsGadget<F, S>

let bits: Vec<_> = bits
.into_iter()
.map(|b| MpcBoolean::new_witness(cs.clone(), || b.get()))
.map(|b| MpcBoolean::new_witness(cs.clone(), || b.get().map(|b| b.field())))
.collect::<Result<_, _>>()?;

let mut lc = LinearCombination::zero();
Expand Down Expand Up @@ -690,9 +685,9 @@ impl<F: PrimeField, S: FieldShare<F>> MpcTwoBitLookupGadget<F, S> for MpcAllocat
})?;
let one = Variable::One;
b.cs().enforce_constraint(
lc!() + b[1].lc() * (c[3] - &c[2] - &c[1] + &c[0]) + (c[1] - &c[0], one),
lc!() + b[1].lc() * (c[3] - c[2] - c[1] + c[0]) + (c[1] - c[0], one),
lc!() + b[0].lc(),
lc!() + result.variable - (c[0], one) + b[1].lc() * (c[0] - &c[2]),
lc!() + result.variable - (c[0], one) + b[1].lc() * (c[0] - c[2]),
)?;

Ok(result)
Expand Down
1 change: 1 addition & 0 deletions mpc-algebra/src/wire.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod field;
pub mod boolean_field;
pub use field::*;
pub mod group;
pub use group::*;
Expand Down
Loading

0 comments on commit 0cc1c9f

Please sign in to comment.