diff --git a/mpc-algebra/examples/algebra.rs b/mpc-algebra/examples/algebra.rs index 3a36c75d..c089797b 100644 --- a/mpc-algebra/examples/algebra.rs +++ b/mpc-algebra/examples/algebra.rs @@ -3,10 +3,11 @@ use std::path::PathBuf; use ark_ff::{BigInteger, BigInteger256, Field, FpParameters, PrimeField, UniformRand}; use ark_ff::{One, Zero}; use ark_poly::reveal; +use ark_std::{end_timer, start_timer}; use log::debug; use mpc_algebra::{ AdditiveFieldShare, BitAdd, BitDecomposition, BitwiseLessThan, EqualityZero, - IntervalTestHalfModulus, LogicalOperations, MpcField, Reveal, UniformBitRand, + LessThan, LogicalOperations, MpcField, Reveal, UniformBitRand, }; use mpc_net::{MpcMultiNet as Net, MpcNet}; @@ -178,6 +179,20 @@ fn test_interval_test_half_modulus() { } } +fn test_less_than() { + let mut rng = ark_std::test_rng(); + + for _ in 0..5 { + let timer = start_timer!(|| "less_than test"); + let a = MF::bit_rand(&mut rng); + let b = MF::bit_rand(&mut rng); + + let res = a.less_than(&b); + assert_eq!(res.reveal().is_one(), a.reveal() < b.reveal()); + end_timer!(timer) + } +} + fn test_and() { let mut rng = ark_std::test_rng(); @@ -352,6 +367,8 @@ fn main() { test_bit_rand(); println!("Test bit_rand passed"); + test_less_than(); + println!("Test less_than passed"); test_interval_test_half_modulus(); println!("Test interval_test_half_modulus passed"); test_rand_number_bitwise(); diff --git a/mpc-algebra/src/mpc_primitives.rs b/mpc-algebra/src/mpc_primitives.rs index 6b63370c..775e5648 100644 --- a/mpc-algebra/src/mpc_primitives.rs +++ b/mpc-algebra/src/mpc_primitives.rs @@ -13,10 +13,11 @@ pub trait BitwiseLessThan { fn bitwise_lt(&self, other: &Self) -> Self::Output; } -pub trait IntervalTestHalfModulus { +pub trait LessThan : UniformBitRand { type Output; - + fn interval_test_half_modulus(&self) -> Self::Output; + fn less_than(&self, other: &Self) -> Self::Output; } pub trait LogicalOperations { diff --git a/mpc-algebra/src/wire/field.rs b/mpc-algebra/src/wire/field.rs index 4b9cfa74..5dbfd1a2 100644 --- a/mpc-algebra/src/wire/field.rs +++ b/mpc-algebra/src/wire/field.rs @@ -3,6 +3,7 @@ use derivative::Derivative; use mpc_trait::MpcWire; use num_bigint::BigUint; use rand::Rng; +use core::panic; use std::fmt::{self, Debug, Display}; use std::io::{self, Read, Write}; use std::iter::{Product, Sum}; @@ -22,7 +23,7 @@ use ark_serialize::{ // use crate::channel::MpcSerNet; use crate::share::field::FieldShare; -use crate::{BeaverSource, BitAdd, BitDecomposition, BitwiseLessThan, IntervalTestHalfModulus, LogicalOperations, Reveal}; +use crate::{BeaverSource, BitAdd, BitDecomposition, BitwiseLessThan, LessThan, LogicalOperations, Reveal}; use crate::{EqualityZero, UniformBitRand}; use mpc_net::{MpcMultiNet as Net, MpcNet}; @@ -301,9 +302,8 @@ impl> BitwiseLessThan for Vec> { } -impl> IntervalTestHalfModulus for MpcField { +impl> LessThan for MpcField { type Output = Self; - // check if shared value a is in the interval [0, modulus/2) fn interval_test_half_modulus(&self) -> Self::Output { // define double self as x @@ -334,6 +334,15 @@ impl> IntervalTestHalfModulus // return 1 - lsb_x one - lsb_x } + + fn less_than(&self, other: &Self) -> Self::Output { + // [z]=[b−a

p/2] + // ([z]∧[x])∨([z]∧[y])∨(¬[z]∧[x]∧[y])=[z(x+y)+(1−2z)xy]. + let z = (*other-self).interval_test_half_modulus(); + let x = self.interval_test_half_modulus(); + let y = Self::one() - other.interval_test_half_modulus(); + z*(x+y)+(Self::one()-Self::from_public(F::from(2u8))*z)*x*y + } } impl> LogicalOperations for Vec> {