diff --git a/examples/bin_test_marlin.rs b/examples/bin_test_marlin.rs index c258c4fd..410b5876 100644 --- a/examples/bin_test_marlin.rs +++ b/examples/bin_test_marlin.rs @@ -22,5 +22,6 @@ fn main() { marlin::mpc_test_prove_and_verify_pedersen(1); marlin::test_equality_zero(1); marlin::test_bit_decomposition(1); - marlin::test_enforce_smaller_eq_than(5); + marlin::test_enforce_smaller_eq_than(3); + marlin::test_smaller_than(5); } diff --git a/mpc-algebra/src/mpc_primitives.rs b/mpc-algebra/src/mpc_primitives.rs index 3c6240b9..59974e04 100644 --- a/mpc-algebra/src/mpc_primitives.rs +++ b/mpc-algebra/src/mpc_primitives.rs @@ -8,6 +8,9 @@ pub trait UniformBitRand: Sized { fn bit_rand(rng: &mut R) -> Self; // little-endian fn rand_number_bitwise(rng: &mut R) -> (Vec, Self::BaseField); + fn rand_number_bitwise_less_than_half_modulus( + rng: &mut R, + ) -> (Vec, Self::BaseField); } pub trait BitwiseLessThan { diff --git a/mpc-algebra/src/wire/boolean_field.rs b/mpc-algebra/src/wire/boolean_field.rs index 2fef5de8..7f116bf9 100644 --- a/mpc-algebra/src/wire/boolean_field.rs +++ b/mpc-algebra/src/wire/boolean_field.rs @@ -196,7 +196,48 @@ impl> UniformBitRand for MpcBo } }; - // bits to field elemetn (little endian) + // bits to field element (little endian) + let num = valid_bits + .iter() + .map(|b| b.field()) + .rev() + .fold(Self::BaseField::zero(), |acc, x| { + acc * Self::BaseField::from_public(F::from(2u8)) + x + }); + + (valid_bits, num) + } + + fn rand_number_bitwise_less_than_half_modulus( + rng: &mut R, + ) -> (Vec, Self::BaseField) { + let modulus_size = F::Params::MODULUS_BITS as usize; + + let mut half_modulus_bits = F::Params::MODULUS_MINUS_ONE_DIV_TWO + .to_bits_le() + .iter() + .map(|&b| Self::from(b)) + .collect::>(); + + half_modulus_bits = half_modulus_bits[..modulus_size].to_vec(); + + let valid_bits = loop { + let bits = (0..modulus_size) + .map(|_| Self::bit_rand(rng)) + .collect::>(); + + if bits + .clone() + .is_smaller_than_le(&half_modulus_bits) + .field() + .reveal() + .is_one() + { + break bits; + } + }; + + // bits to field element (little endian) let num = valid_bits .iter() .map(|b| b.field()) diff --git a/src/circuits.rs b/src/circuits.rs index bbb62393..7b6b991c 100644 --- a/src/circuits.rs +++ b/src/circuits.rs @@ -7,3 +7,4 @@ pub use werewolf::*; pub mod bit_decomposition; pub mod enforce_smaller_or_eq_than; pub mod equality_zero; +pub mod smaller_than; diff --git a/src/circuits/smaller_than.rs b/src/circuits/smaller_than.rs new file mode 100644 index 00000000..f92745e8 --- /dev/null +++ b/src/circuits/smaller_than.rs @@ -0,0 +1,48 @@ +use std::cmp::Ordering; + +use ark_ff::One; +use ark_ff::PrimeField; +use ark_r1cs_std::alloc::AllocVar; +use ark_r1cs_std::boolean::Boolean; +use ark_r1cs_std::eq::EqGadget; +use ark_r1cs_std::fields::fp::FpVar; +use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisError}; +use mpc_algebra::malicious_majority::MpcField; +use mpc_algebra::{MpcBoolean, MpcEqGadget, MpcFpVar}; + +type Fr = ark_bls12_377::Fr; +type MFr = MpcField; + +pub struct SmallerThanCircuit { + pub a: F, + pub b: F, + pub res: F, + pub cmp: Ordering, + pub check_eq: bool, +} + +impl ConstraintSynthesizer for SmallerThanCircuit { + fn generate_constraints(self, cs: ConstraintSystemRef) -> Result<(), SynthesisError> { + let a_var = MpcFpVar::new_witness(cs.clone(), || Ok(self.a))?; + let b_var = MpcFpVar::new_witness(cs.clone(), || Ok(self.b))?; + let res_var = MpcBoolean::new_witness(cs.clone(), || Ok(self.res))?; + let res2 = MpcFpVar::is_cmp(&a_var, &b_var, self.cmp, self.check_eq).unwrap(); + + res_var.enforce_equal(&res2); + + Ok(()) + } +} + +impl ConstraintSynthesizer for SmallerThanCircuit { + fn generate_constraints(self, cs: ConstraintSystemRef) -> Result<(), SynthesisError> { + let a_var = FpVar::new_witness(cs.clone(), || Ok(self.a))?; + let b_var = FpVar::new_witness(cs.clone(), || Ok(self.b))?; + let res_var = Boolean::new_witness(cs.clone(), || Ok(self.res.is_one()))?; + let res2 = FpVar::is_cmp(&a_var, &b_var, self.cmp, self.check_eq).unwrap(); + + res_var.enforce_equal(&res2); + + Ok(()) + } +} diff --git a/src/marlin.rs b/src/marlin.rs index c8e5a83a..fd3c952f 100644 --- a/src/marlin.rs +++ b/src/marlin.rs @@ -1,3 +1,5 @@ +use std::cmp::Ordering; + use ark_crypto_primitives::CommitmentScheme; use ark_ec::twisted_edwards_extended::GroupAffine; use ark_ff::{BigInteger, PrimeField}; @@ -11,7 +13,7 @@ use blake2::Blake2s; use itertools::Itertools; // use mpc_algebra::honest_but_curious::*; use mpc_algebra::{ - malicious_majority::*, BooleanWire, MpcBooleanField, SpdzFieldShare, UniformBitRand, + malicious_majority::*, BooleanWire, LessThan, MpcBooleanField, SpdzFieldShare, UniformBitRand, }; use mpc_algebra::{FromLocal, Reveal}; use mpc_net::{MpcMultiNet, MpcNet}; @@ -19,6 +21,7 @@ use mpc_net::{MpcMultiNet, MpcNet}; use ark_std::{One, Zero}; use crate::circuits::enforce_smaller_or_eq_than::SmallerEqThanCircuit; +use crate::circuits::smaller_than::{self, SmallerThanCircuit}; use crate::{ circuits::{ bit_decomposition::BitDecompositionCircuit, circuit::MyCircuit, @@ -313,7 +316,6 @@ pub fn test_bit_decomposition(n_iters: usize) { } } -// Test pub fn test_enforce_smaller_eq_than(n_iters: usize) { let rng = &mut test_rng(); @@ -351,3 +353,46 @@ pub fn test_enforce_smaller_eq_than(n_iters: usize) { } } } + +pub fn test_smaller_than(n_iters: usize) { + let rng = &mut test_rng(); + let (_, local_a_rand) = + MpcBooleanField::>::rand_number_bitwise_less_than_half_modulus(rng); + let (_, local_b_rand) = + MpcBooleanField::>::rand_number_bitwise_less_than_half_modulus(rng); + let local_res = local_a_rand.is_smaller_than(&local_b_rand); + + let local_circuit = SmallerThanCircuit { + a: local_a_rand.reveal(), + b: local_b_rand.reveal(), + res: local_res.reveal(), + cmp: Ordering::Less, + check_eq: true, + }; + let (mpc_index_pk, index_vk) = setup_and_index(local_circuit); + for _ in 0..n_iters { + let (_, a_rand) = + MpcBooleanField::>::rand_number_bitwise_less_than_half_modulus( + rng, + ); + let (_, b_rand) = + MpcBooleanField::>::rand_number_bitwise_less_than_half_modulus( + rng, + ); + let res = a_rand.is_smaller_than(&b_rand); + let mpc_circuit = SmallerThanCircuit { + a: a_rand, + b: b_rand, + res: res.field(), + cmp: Ordering::Less, + check_eq: true, + }; + let inputs = vec![]; + assert!(prove_and_verify( + &mpc_index_pk, + &index_vk, + mpc_circuit, + inputs + )); + } +}